打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
Caffe 用于解决预测(回归)问题
userphoto

2016.10.08

关注

最近在基于caffe做目标检测的问题,需要利用caffe来训练一个回归网络,用来预测object在图像中的位置(x1,y1,width,height)。但是现有的caffe版本(happynear版本)只适用于二分类问题的数据集转换,所以需要修改caffe源码,使之也可以转换回归问题的数据集。

主要是参照 http://blog.csdn.NET/baobei0112/article/details/47606559 进行修改。但是这份博客使用的不是happynear的caffe版本,所以源码改动的地方差异较大。下面我会记录我改动的地方。

一.源码修改

1.修改caffe.proto,位于/src/caffe/proto

36行改成  repeated float label = 5;,然后运行extract_proto.bat

2.修改data_layer.hpp

  1. #ifndef CAFFE_DATA_LAYERS_HPP_  
  2. #define CAFFE_DATA_LAYERS_HPP_  
  3. #include <string>  
  4. #include <utility>  
  5. #include <vector>  
  6. #include "hdf5/hdf5.h"  
  7. #include "caffe/blob.hpp"  
  8. #include "caffe/common.hpp"  
  9. #include "caffe/data_reader.hpp"  
  10. #include "caffe/data_transformer.hpp"  
  11. #include "caffe/filler.hpp"  
  12. #include "caffe/internal_thread.hpp"  
  13. #include "caffe/layer.hpp"  
  14. #include "caffe/proto/caffe.pb.h"  
  15. #include "caffe/util/blocking_queue.hpp"  
  16. #include "caffe/util/db.hpp"  
  17. #define HDF5_DATA_DATASET_NAME "data"  
  18. #define HDF5_DATA_LABEL_NAME "label"  
  19. namespace caffe {  
  20. /** 
  21. * @brief Provides base for data layers that feed blobs to the Net. 
  22. * TODO(dox): thorough documentation for Forward and proto params. 
  23. */  
  24. template <typename Dtype>  
  25. class BaseDataLayer : public Layer<Dtype> {  
  26. public:  
  27. explicit BaseDataLayer(const LayerParameter& param);  
  28. // LayerSetUp: implements common data layer setup functionality, and calls  
  29. // DataLayerSetUp to do special data layer setup for individual layer types.  
  30. // This method may not be overridden except by the BasePrefetchingDataLayer.  
  31. virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  32. const vector<Blob<Dtype>*>& top);  
  33. // Data layers should be shared by multiple solvers in parallel  
  34. virtual inline bool ShareInParallel() const { return true; }  
  35. virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  36. const vector<Blob<Dtype>*>& top) {  
  37. }  
  38. // Data layers have no bottoms, so reshaping is trivial.  
  39. virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  40. const vector<Blob<Dtype>*>& top) {  
  41. }  
  42. virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  43. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {  
  44. }  
  45. virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  46. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {  
  47. }  
  48. protected:  
  49. TransformationParameter transform_param_;  
  50. shared_ptr<DataTransformer<Dtype> > data_transformer_;  
  51. bool output_labels_;  
  52. };  
  53. template <typename Dtype>  
  54. class Batch {  
  55. public:  
  56. Blob<Dtype> data_, label_;  
  57. };  
  58. template <typename Dtype>  
  59. class BasePrefetchingDataLayer :  
  60. public BaseDataLayer<Dtype>, public InternalThread {  
  61. public:  
  62. explicit BasePrefetchingDataLayer(const LayerParameter& param);  
  63. // LayerSetUp: implements common data layer setup functionality, and calls  
  64. // DataLayerSetUp to do special data layer setup for individual layer types.  
  65. // This method may not be overridden.  
  66. void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  67. const vector<Blob<Dtype>*>& top);  
  68. virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  69. const vector<Blob<Dtype>*>& top);  
  70. virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,  
  71. const vector<Blob<Dtype>*>& top);  
  72. // Prefetches batches (asynchronously if to GPU memory)  
  73. static const int PREFETCH_COUNT = 3;  
  74. protected:  
  75. virtual void InternalThreadEntry();  
  76. virtual void load_batch(Batch<Dtype>* batch) = 0;  
  77. Batch<Dtype> prefetch_[PREFETCH_COUNT];  
  78. BlockingQueue<Batch<Dtype>*> prefetch_free_;  
  79. BlockingQueue<Batch<Dtype>*> prefetch_full_;  
  80. Blob<Dtype> transformed_data_;  
  81. };  
  82. template <typename Dtype>  
  83. class DataLayer : public BasePrefetchingDataLayer<Dtype> {  
  84. public:  
  85. explicit DataLayer(const LayerParameter& param);  
  86. virtual ~DataLayer();  
  87. virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  88. const vector<Blob<Dtype>*>& top);  
  89. // DataLayer uses DataReader instead for sharing for parallelism  
  90. virtual inline bool ShareInParallel() const { return false; }  
  91. virtual inline const char* type() const { return "Data"; }  
  92. virtual inline int ExactNumBottomBlobs() const { return 0; }  
  93. virtual inline int MinTopBlobs() const { return 1; }  
  94. virtual inline int MaxTopBlobs() const { return 2; }  
  95. protected:  
  96. virtual void load_batch(Batch<Dtype>* batch);  
  97. DataReader reader_;  
  98. };  
  99. /** 
  100. * @brief Provides data to the Net generated by a Filler. 
  101. * TODO(dox): thorough documentation for Forward and proto params. 
  102. */  
  103. template <typename Dtype>  
  104. class DummyDataLayer : public Layer<Dtype> {  
  105. public:  
  106. explicit DummyDataLayer(const LayerParameter& param)  
  107. : Layer<Dtype>(param) {}  
  108. virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  109. const vector<Blob<Dtype>*>& top);  
  110. // Data layers should be shared by multiple solvers in parallel  
  111. virtual inline bool ShareInParallel() const { return true; }  
  112. // Data layers have no bottoms, so reshaping is trivial.  
  113. virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  114. const vector<Blob<Dtype>*>& top) {  
  115. }  
  116. virtual inline const char* type() const { return "DummyData"; }  
  117. virtual inline int ExactNumBottomBlobs() const { return 0; }  
  118. virtual inline int MinTopBlobs() const { return 1; }  
  119. protected:  
  120. virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  121. const vector<Blob<Dtype>*>& top);  
  122. virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  123. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {  
  124. }  
  125. virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  126. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {  
  127. }  
  128. vector<shared_ptr<Filler<Dtype> > > fillers_;  
  129. vector<bool> refill_;  
  130. };  
  131. /** 
  132. * @brief Provides data to the Net from HDF5 files. 
  133. * TODO(dox): thorough documentation for Forward and proto params. 
  134. */  
  135. template <typename Dtype>  
  136. class HDF5DataLayer : public Layer<Dtype> {  
  137. public:  
  138. explicit HDF5DataLayer(const LayerParameter& param)  
  139. : Layer<Dtype>(param) {}  
  140. virtual ~HDF5DataLayer();  
  141. virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  142. const vector<Blob<Dtype>*>& top);  
  143. // Data layers should be shared by multiple solvers in parallel  
  144. virtual inline bool ShareInParallel() const { return true; }  
  145. // Data layers have no bottoms, so reshaping is trivial.  
  146. virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  147. const vector<Blob<Dtype>*>& top) {  
  148. }  
  149. virtual inline const char* type() const { return "HDF5Data"; }  
  150. virtual inline int ExactNumBottomBlobs() const { return 0; }  
  151. virtual inline int MinTopBlobs() const { return 1; }  
  152. protected:  
  153. virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  154. const vector<Blob<Dtype>*>& top);  
  155. virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,  
  156. const vector<Blob<Dtype>*>& top);  
  157. virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  158. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {  
  159. }  
  160. virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  161. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {  
  162. }  
  163. virtual void LoadHDF5FileData(const char* filename);  
  164. std::vector<std::string> hdf_filenames_;  
  165. unsigned int num_files_;  
  166. unsigned int current_file_;  
  167. hsize_t current_row_;  
  168. std::vector<shared_ptr<Blob<Dtype> > > hdf_blobs_;  
  169. std::vector<unsigned int> data_permutation_;  
  170. std::vector<unsigned int> file_permutation_;  
  171. };  
  172. /** 
  173. * @brief Write blobs to disk as HDF5 files. 
  174. * TODO(dox): thorough documentation for Forward and proto params. 
  175. */  
  176. template <typename Dtype>  
  177. class HDF5OutputLayer : public Layer<Dtype> {  
  178. public:  
  179. explicit HDF5OutputLayer(const LayerParameter& param)  
  180. : Layer<Dtype>(param), file_opened_(false) {}  
  181. virtual ~HDF5OutputLayer();  
  182. virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  183. const vector<Blob<Dtype>*>& top);  
  184. // Data layers should be shared by multiple solvers in parallel  
  185. virtual inline bool ShareInParallel() const { return true; }  
  186. // Data layers have no bottoms, so reshaping is trivial.  
  187. virtual void Reshape(const vector<Blob<Dtype>*>& bottom,  
  188. const vector<Blob<Dtype>*>& top) {  
  189. }  
  190. virtual inline const char* type() const { return "HDF5Output"; }  
  191. // TODO: no limit on the number of blobs  
  192. virtual inline int ExactNumBottomBlobs() const { return 2; }  
  193. virtual inline int ExactNumTopBlobs() const { return 0; }  
  194. inline std::string file_name() const { return file_name_; }  
  195. protected:  
  196. virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  197. const vector<Blob<Dtype>*>& top);  
  198. virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,  
  199. const vector<Blob<Dtype>*>& top);  
  200. virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,  
  201. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);  
  202. virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,  
  203. const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);  
  204. virtual void SaveBlobs();  
  205. bool file_opened_;  
  206. std::string file_name_;  
  207. hid_t file_id_;  
  208. Blob<Dtype> data_blob_;  
  209. Blob<Dtype> label_blob_;  
  210. };  
  211. /** 
  212. * @brief Provides data to the Net from image files. 
  213. * TODO(dox): thorough documentation for Forward and proto params. 
  214. */  
  215. template <typename Dtype>  
  216. class ImageDataLayer : public BasePrefetchingDataLayer<Dtype> {  
  217. public:  
  218. explicit ImageDataLayer(const LayerParameter& param)  
  219. : BasePrefetchingDataLayer<Dtype>(param) {}  
  220. virtual ~ImageDataLayer();  
  221. virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  222. const vector<Blob<Dtype>*>& top);  
  223. virtual inline const char* type() const { return "ImageData"; }  
  224. virtual inline int ExactNumBottomBlobs() const { return 0; }  
  225. virtual inline int ExactNumTopBlobs() const { return 2; }  
  226. vector<std::pair<std::string, std:: vector<float>> > lines_;  
  227. shared_ptr<Caffe::RNG> prefetch_rng_;  
  228. virtual void ShuffleImages();  
  229. virtual void load_batch(Batch<Dtype>* batch);  
  230. int lines_id_;  
  231. };  
  232. /** 
  233. * @brief Provides data to the Net from memory. 
  234. * TODO(dox): thorough documentation for Forward and proto params. 
  235. */  
  236. template <typename Dtype>  
  237. class MemoryDataLayer : public BaseDataLayer<Dtype> {  
  238. public:  
  239. explicit MemoryDataLayer(const LayerParameter& param)  
  240. : BaseDataLayer<Dtype>(param), has_new_data_(false) {}  
  241. virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  242. const vector<Blob<Dtype>*>& top);  
  243. virtual inline const char* type() const { return "MemoryData"; }  
  244. virtual inline int ExactNumBottomBlobs() const { return 0; }  
  245. virtual inline int ExactNumTopBlobs() const { return 2; }  
  246. virtual void AddDatumVector(const vector<Datum>& datum_vector);  
  247. #ifdef USE_OPENCV  
  248. virtual void AddMatVector(const vector<cv::Mat>& mat_vector,  
  249. const vector<int>& labels);  
  250. #endif // USE_OPENCV  
  251. // Reset should accept const pointers, but can't, because the memory  
  252. // will be given to Blob, which is mutable  
  253. void Reset(Dtype* data, Dtype* label, int n);  
  254. void set_batch_size(int new_size);  
  255. int batch_size() { return batch_size_; }  
  256. int channels() { return channels_; }  
  257. int height() { return height_; }  
  258. int width() { return width_; }  
  259. protected:  
  260. virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  261. const vector<Blob<Dtype>*>& top);  
  262. int batch_size_, channels_, height_, width_, size_;  
  263. Dtype* data_;  
  264. Dtype* labels_;  
  265. int n_;  
  266. size_t pos_;  
  267. Blob<Dtype> added_data_;  
  268. Blob<Dtype> added_label_;  
  269. bool has_new_data_;  
  270. };  
  271. /** 
  272. * @brief Provides data to the Net from windows of images files, specified 
  273. * by a window data file. 
  274. * TODO(dox): thorough documentation for Forward and proto params. 
  275. */  
  276. template <typename Dtype>  
  277. class WindowDataLayer : public BasePrefetchingDataLayer<Dtype> {  
  278. public:  
  279. explicit WindowDataLayer(const LayerParameter& param)  
  280. : BasePrefetchingDataLayer<Dtype>(param) {}  
  281. virtual ~WindowDataLayer();  
  282. virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  283. const vector<Blob<Dtype>*>& top);  
  284. virtual inline const char* type() const { return "WindowData"; }  
  285. virtual inline int ExactNumBottomBlobs() const { return 0; }  
  286. virtual inline int ExactNumTopBlobs() const { return 2; }  
  287. protected:  
  288. virtual unsigned int PrefetchRand();  
  289. virtual void load_batch(Batch<Dtype>* batch);  
  290. shared_ptr<Caffe::RNG> prefetch_rng_;  
  291. vector<std::pair<std::string, vector<int> > > image_database_;  
  292. enum WindowField { IMAGE_INDEX, LABEL, OVERLAP, X1, Y1, X2, Y2, NUM };  
  293. vector<vector<float> > fg_windows_;  
  294. vector<vector<float> > bg_windows_;  
  295. Blob<Dtype> data_mean_;  
  296. vector<Dtype> mean_values_;  
  297. bool has_mean_file_;  
  298. bool has_mean_values_;  
  299. bool cache_images_;  
  300. vector<std::pair<std::string, Datum > > image_database_cache_;  
  301. };  
  302. /** 
  303. * @brief Provides data to the Net from image files. 
  304. * TODO(dox): thorough documentation for Forward and proto params. 
  305. */  
  306. template <typename Dtype>  
  307. class MultiLabelImageDataLayer : public BasePrefetchingDataLayer<Dtype> {  
  308. public:  
  309. explicit MultiLabelImageDataLayer(const LayerParameter& param)  
  310. : BasePrefetchingDataLayer<Dtype>(param) {}  
  311. virtual ~MultiLabelImageDataLayer();  
  312. virtual void DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  313. const vector<Blob<Dtype>*>& top);  
  314. virtual inline const char* type() const { return "MultiLabelImageData"; }  
  315. virtual inline int ExactNumBottomBlobs() const { return 0; }  
  316. virtual inline int ExactNumTopBlobs() const { return 2; }  
  317. protected:  
  318. shared_ptr<Caffe::RNG> prefetch_rng_;  
  319. virtual void ShuffleImages();  
  320. virtual void load_batch(Batch<Dtype>* batch);  
  321. vector<std::pair<std::string, shared_ptr<vector<Dtype> > > > lines_;  
  322. int label_count;  
  323. int lines_id_;  
  324. };  
  325. } // namespace caffe  
  326. #endif // CAFFE_DATA_LAYERS_HPP_  



3.改动data_layer.cpp

  1. #ifdef USE_OPENCV  
  2. #include <opencv2/core/core.hpp>  
  3. #endif  // USE_OPENCV  
  4. #include <stdint.h>  
  5.   
  6. #include <vector>  
  7.   
  8. #include "caffe/data_layers.hpp"  
  9. #include "caffe/proto/caffe.pb.h"  
  10. #include "caffe/util/benchmark.hpp"  
  11.   
  12. namespace caffe {  
  13.   
  14. template <typename Dtype>  
  15. DataLayer<Dtype>::DataLayer(const LayerParameter& param)  
  16.   : BasePrefetchingDataLayer<Dtype>(param),  
  17.     reader_(param) {  
  18. }  
  19.   
  20. template <typename Dtype>  
  21. DataLayer<Dtype>::~DataLayer() {  
  22.   this->StopInternalThread();  
  23. }  
  24.   
  25. template <typename Dtype>  
  26. void DataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  27.       const vector<Blob<Dtype>*>& top) {  
  28.   const int batch_size = this->layer_param_.data_param().batch_size();  
  29.   // Read a data point, and use it to initialize the top blob.  
  30.   Datum& datum = *(reader_.full().peek());  
  31.   
  32.   // Use data_transformer to infer the expected blob shape from datum.  
  33.   vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);  
  34.   this->transformed_data_.Reshape(top_shape);  
  35.   // Reshape top[0] and prefetch_data according to the batch_size.  
  36.   top_shape[0] = batch_size;  
  37.   top[0]->Reshape(top_shape);  
  38.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  39.     this->prefetch_[i].data_.Reshape(top_shape);  
  40.   }  
  41.   LOG(INFO) << "output data size: " << top[0]->num() << ","  
  42.       << top[0]->channels() << "," << top[0]->height() << ","  
  43.       << top[0]->width();  
  44.   // label  
  45.   if (this->output_labels_) {  
  46.       /* 
  47.       vector<int> label_shape(1, batch_size); 
  48.     top[1]->Reshape(label_shape); 
  49.     for (int i = 0; i < this->PREFETCH_COUNT; ++i) { 
  50.       this->prefetch_[i].label_.Reshape(label_shape); 
  51.     } 
  52.     */  
  53.       top[1]->Reshape(batch_size,4,1,1);  
  54.       for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  55.           this->prefetch_[i].label_.Reshape(batch_size, 4, 1, 1);  
  56.       }  
  57.   }  
  58. }  
  59.   
  60. // This function is called on prefetch thread  
  61. template<typename Dtype>  
  62. void DataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {  
  63.   CPUTimer batch_timer;  
  64.   batch_timer.Start();  
  65.   double read_time = 0;  
  66.   double trans_time = 0;  
  67.   CPUTimer timer;  
  68.   CHECK(batch->data_.count());  
  69.   CHECK(this->transformed_data_.count());  
  70.   
  71.   // Reshape according to the first datum of each batch  
  72.   // on single input batches allows for inputs of varying dimension.  
  73.   const int batch_size = this->layer_param_.data_param().batch_size();  
  74.   Datum& datum = *(reader_.full().peek());  
  75.   // Use data_transformer to infer the expected blob shape from datum.  
  76.   vector<int> top_shape = this->data_transformer_->InferBlobShape(datum);  
  77.   this->transformed_data_.Reshape(top_shape);  
  78.   // Reshape batch according to the batch_size.  
  79.   top_shape[0] = batch_size;  
  80.   batch->data_.Reshape(top_shape);  
  81.   
  82.   Dtype* top_data = batch->data_.mutable_cpu_data();  
  83.   Dtype* top_label = NULL;  // suppress warnings about uninitialized variables  
  84.   if (this->output_labels_) {  
  85.       top_label = batch->label_.mutable_cpu_data();  
  86.   }  
  87.   /* 
  88.   if (this->output_labels_) { 
  89.       for (int label_i = 0; label_i < datum.label_size(); label_i++){ 
  90.           top_label[item_id*datum.label_size() + label_i] = datum.label(label_i); 
  91.       } 
  92.   } 
  93.   */  
  94.   for (int item_id = 0; item_id < batch_size; ++item_id) {  
  95.     timer.Start();  
  96.     // get a datum  
  97.     Datum& datum = *(reader_.full().pop("Waiting for data"));  
  98.     read_time += timer.MicroSeconds();  
  99.     timer.Start();  
  100.     // Apply data transformations (mirror, scale, crop...)  
  101.     int offset = batch->data_.offset(item_id);  
  102.     this->transformed_data_.set_cpu_data(top_data + offset);  
  103.     this->data_transformer_->Transform(datum, &(this->transformed_data_));  
  104.     // Copy label.  
  105.     if (this->output_labels_) {  
  106.      // top_label[item_id] = datum.label();  
  107.         for (int label_i = 0; label_i < datum.label_size(); label_i++){  
  108.             top_label[item_id*datum.label_size()+label_i] = datum.label(label_i);  
  109.         }  
  110.     }  
  111.     trans_time += timer.MicroSeconds();  
  112.   
  113.     reader_.free().push(const_cast<Datum*>(&datum));  
  114.   }  
  115.   timer.Stop();  
  116.   batch_timer.Stop();  
  117.   DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";  
  118.   DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";  
  119.   DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";  
  120. }  
  121.   
  122. INSTANTIATE_CLASS(DataLayer);  
  123. REGISTER_LAYER_CLASS(Data);  
  124.   
  125. }  // namespace caffe  
4.修改image_data_layer.cpp中label部分

  1. #ifdef USE_OPENCV  
  2. #include <opencv2/core/core.hpp>  
  3.   
  4. #include <fstream>  // NOLINT(readability/streams)  
  5. #include <iostream>  // NOLINT(readability/streams)  
  6. #include <string>  
  7. #include <utility>  
  8. #include <vector>  
  9.   
  10. #include "caffe/data_layers.hpp"  
  11. #include "caffe/util/benchmark.hpp"  
  12. #include "caffe/util/io.hpp"  
  13. #include "caffe/util/math_functions.hpp"  
  14. #include "caffe/util/rng.hpp"  
  15.   
  16. namespace caffe {  
  17.   
  18. template <typename Dtype>  
  19. ImageDataLayer<Dtype>::~ImageDataLayer<Dtype>() {  
  20.   this->StopInternalThread();  
  21. }  
  22.   
  23. template <typename Dtype>  
  24. void ImageDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  25.       const vector<Blob<Dtype>*>& top) {  
  26.   const int new_height = this->layer_param_.image_data_param().new_height();  
  27.   const int new_width  = this->layer_param_.image_data_param().new_width();  
  28.   const bool is_color  = this->layer_param_.image_data_param().is_color();  
  29.   string root_folder = this->layer_param_.image_data_param().root_folder();  
  30.   
  31.   CHECK((new_height == 0 && new_width == 0) ||  
  32.       (new_height > 0 && new_width > 0)) << "Current implementation requires "  
  33.       "new_height and new_width to be set at the same time.";  
  34.   // Read the file with filenames and labels  
  35.   const string& source = this->layer_param_.image_data_param().source();  
  36.   LOG(INFO) << "Opening file " << source;  
  37.   std::ifstream infile(source.c_str());  
  38.   string filename;  
  39.   //int label;  
  40.   float x1, y1, x2, y2;  
  41.   while (infile >> filename >> x1 >> y1 >> x2 >> y2) {  
  42.       std::vector<float> vec_label;  
  43.       vec_label.push_back(x1);  
  44.       vec_label.push_back(y1);  
  45.       vec_label.push_back(x2);  
  46.       vec_label.push_back(y2);  
  47.     lines_.push_back(std::make_pair(filename, vec_label));  
  48.   }  
  49.   
  50.   if (this->layer_param_.image_data_param().shuffle()) {  
  51.     // randomly shuffle data  
  52.     LOG(INFO) << "Shuffling data";  
  53.     const unsigned int prefetch_rng_seed = caffe_rng_rand();  
  54.     prefetch_rng_.reset(new Caffe::RNG(prefetch_rng_seed));  
  55.     ShuffleImages();  
  56.   }  
  57.   LOG(INFO) << "A total of " << lines_.size() << " images.";  
  58.   
  59.   lines_id_ = 0;  
  60.   // Check if we would need to randomly skip a few data points  
  61.   if (this->layer_param_.image_data_param().rand_skip()) {  
  62.     unsigned int skip = caffe_rng_rand() %  
  63.         this->layer_param_.image_data_param().rand_skip();  
  64.     LOG(INFO) << "Skipping first " << skip << " data points.";  
  65.     CHECK_GT(lines_.size(), skip) << "Not enough points to skip";  
  66.     lines_id_ = skip;  
  67.   }  
  68.   // Read an image, and use it to initialize the top blob.  
  69.   cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,  
  70.                                     new_height, new_width, is_color);  
  71.   CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;  
  72.   // Use data_transformer to infer the expected blob shape from a cv_image.  
  73.   vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);  
  74.   this->transformed_data_.Reshape(top_shape);  
  75.   // Reshape prefetch_data and top[0] according to the batch_size.  
  76.   const int batch_size = this->layer_param_.image_data_param().batch_size();  
  77.   CHECK_GT(batch_size, 0) << "Positive batch size required";  
  78.   top_shape[0] = batch_size;  
  79.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  80.     this->prefetch_[i].data_.Reshape(top_shape);  
  81.   }  
  82.   top[0]->Reshape(top_shape);  
  83.   
  84.   LOG(INFO) << "output data size: " << top[0]->num() << ","  
  85.       << top[0]->channels() << "," << top[0]->height() << ","  
  86.       << top[0]->width();  
  87.   // label  
  88.   vector<int> label_shape(1, batch_size);  
  89.   top[1]->Reshape(batch_size,4,1,1);  
  90.   for (int i = 0; i < this->PREFETCH_COUNT; ++i) {  
  91.       this->prefetch_[i].label_.Reshape(batch_size, 4, 1, 1);  
  92.   }  
  93. }  
  94.   
  95. template <typename Dtype>  
  96. void ImageDataLayer<Dtype>::ShuffleImages() {  
  97.   caffe::rng_t* prefetch_rng =  
  98.       static_cast<caffe::rng_t*>(prefetch_rng_->generator());  
  99.   shuffle(lines_.begin(), lines_.end(), prefetch_rng);  
  100. }  
  101.   
  102. // This function is called on prefetch thread  
  103. template <typename Dtype>  
  104. void ImageDataLayer<Dtype>::load_batch(Batch<Dtype>* batch) {  
  105.   CPUTimer batch_timer;  
  106.   batch_timer.Start();  
  107.   double read_time = 0;  
  108.   double trans_time = 0;  
  109.   CPUTimer timer;  
  110.   CHECK(batch->data_.count());  
  111.   CHECK(this->transformed_data_.count());  
  112.   ImageDataParameter image_data_param = this->layer_param_.image_data_param();  
  113.   const int batch_size = image_data_param.batch_size();  
  114.   const int new_height = image_data_param.new_height();  
  115.   const int new_width = image_data_param.new_width();  
  116.   const bool is_color = image_data_param.is_color();  
  117.   string root_folder = image_data_param.root_folder();  
  118.   
  119.   // Reshape according to the first image of each batch  
  120.   // on single input batches allows for inputs of varying dimension.  
  121.   cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,  
  122.       new_height, new_width, is_color);  
  123.   CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;  
  124.   // Use data_transformer to infer the expected blob shape from a cv_img.  
  125.   vector<int> top_shape = this->data_transformer_->InferBlobShape(cv_img);  
  126.   this->transformed_data_.Reshape(top_shape);  
  127.   // Reshape batch according to the batch_size.  
  128.   top_shape[0] = batch_size;  
  129.   batch->data_.Reshape(top_shape);  
  130.   
  131.   Dtype* prefetch_data = batch->data_.mutable_cpu_data();  
  132.   //Dtype* prefetch_label = batch->label_.mutable_cpu_data();  
  133.   Dtype* prefetch_label = NULL;  
  134.   // datum scales  
  135.   const int lines_size = lines_.size();  
  136.   for (int item_id = 0; item_id < batch_size; ++item_id) {  
  137.     // get a blob  
  138.     timer.Start();  
  139.     CHECK_GT(lines_size, lines_id_);  
  140.     cv::Mat cv_img = ReadImageToCVMat(root_folder + lines_[lines_id_].first,  
  141.         new_height, new_width, is_color);  
  142.     CHECK(cv_img.data) << "Could not load " << lines_[lines_id_].first;  
  143.     read_time += timer.MicroSeconds();  
  144.     timer.Start();  
  145.     // Apply transformations (mirror, crop...) to the image  
  146.     int offset = batch->data_.offset(item_id);  
  147.     this->transformed_data_.set_cpu_data(prefetch_data + offset);  
  148.     this->data_transformer_->Transform(cv_img, &(this->transformed_data_));  
  149.     trans_time += timer.MicroSeconds();  
  150.     for (int label_i = 0; label_i < (lines_[lines_id_].second).size(); label_i++){  
  151.         prefetch_label[item_id*(lines_[lines_id_].second).size() + label_i] = (lines_[lines_id_].second)[label_i];  
  152.     }  
  153.     //prefetch_label[item_id] = lines_[lines_id_].second;  
  154.     // go to the next iter  
  155.     lines_id_++;  
  156.     if (lines_id_ >= lines_size) {  
  157.       // We have reached the end. Restart from the first.  
  158.       DLOG(INFO) << "Restarting data prefetching from start.";  
  159.       lines_id_ = 0;  
  160.       if (this->layer_param_.image_data_param().shuffle()) {  
  161.         ShuffleImages();  
  162.       }  
  163.     }  
  164.   }  
  165.   batch_timer.Stop();  
  166.   DLOG(INFO) << "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";  
  167.   DLOG(INFO) << "     Read time: " << read_time / 1000 << " ms.";  
  168.   DLOG(INFO) << "Transform time: " << trans_time / 1000 << " ms.";  
  169. }  
  170.   
  171. INSTANTIATE_CLASS(ImageDataLayer);  
  172. REGISTER_LAYER_CLASS(ImageData);  
  173.   
  174. }  // namespace caffe  
  175. #endif  // USE_OPENCV  

5.修改memory_data_layer.cpp

  1. #ifdef USE_OPENCV  
  2. #include <opencv2/core/core.hpp>  
  3. #endif  // USE_OPENCV  
  4.   
  5. #include <vector>  
  6.   
  7. #include "caffe/data_layers.hpp"  
  8.   
  9. namespace caffe {  
  10.   
  11. template <typename Dtype>  
  12. void MemoryDataLayer<Dtype>::DataLayerSetUp(const vector<Blob<Dtype>*>& bottom,  
  13.      const vector<Blob<Dtype>*>& top) {  
  14.   batch_size_ = this->layer_param_.memory_data_param().batch_size();  
  15.   channels_ = this->layer_param_.memory_data_param().channels();  
  16.   height_ = this->layer_param_.memory_data_param().height();  
  17.   width_ = this->layer_param_.memory_data_param().width();  
  18.   size_ = channels_ * height_ * width_;  
  19.   CHECK_GT(batch_size_ * size_, 0) <<  
  20.       "batch_size, channels, height, and width must be specified and"  
  21.       " positive in memory_data_param";  
  22.   vector<int> label_shape(1, batch_size_);  
  23.   top[0]->Reshape(batch_size_, channels_, height_, width_);  
  24.   top[1]->Reshape(label_shape);  
  25.   added_data_.Reshape(batch_size_, channels_, height_, width_);  
  26.   added_label_.Reshape(label_shape);  
  27.   data_ = NULL;  
  28.   labels_ = NULL;  
  29.   added_data_.cpu_data();  
  30.   added_label_.cpu_data();  
  31. }  
  32.   
  33. template <typename Dtype>  
  34. void MemoryDataLayer<Dtype>::AddDatumVector(const vector<Datum>& datum_vector) {  
  35.   CHECK(!has_new_data_) <<  
  36.       "Can't add data until current data has been consumed.";  
  37.   size_t num = datum_vector.size();  
  38.   CHECK_GT(num, 0) << "There is no datum to add.";  
  39.   CHECK_EQ(num % batch_size_, 0) <<  
  40.       "The added data must be a multiple of the batch size.";  
  41.   added_data_.Reshape(num, channels_, height_, width_);  
  42.   added_label_.Reshape(num, 1, 1, 1);  
  43.   // Apply data transformations (mirror, scale, crop...)  
  44.   this->data_transformer_->Transform(datum_vector, &added_data_);  
  45.   // Copy Labels  
  46.   Dtype* top_label = added_label_.mutable_cpu_data();  
  47.   for (int item_id = 0; item_id < num; ++item_id) {  
  48.     //top_label[item_id] = datum_vector[item_id].label();  
  49.       int label_num = datum_vector[item_id].label_size();  
  50.       for (int label_i = 0; label_i < label_num; label_i++){  
  51.           top_label[item_id*label_num + label_i] = datum_vector[item_id].label(label_i);  
  52.       }  
  53.   }  
  54.   // num_images == batch_size_  
  55.   Dtype* top_data = added_data_.mutable_cpu_data();  
  56.   Reset(top_data, top_label, num);  
  57.   has_new_data_ = true;  
  58. }  
  59.   
  60. #ifdef USE_OPENCV  
  61. template <typename Dtype>  
  62. void MemoryDataLayer<Dtype>::AddMatVector(const vector<cv::Mat>& mat_vector,  
  63.     const vector<int>& labels) {  
  64.   size_t num = mat_vector.size();  
  65.   CHECK(!has_new_data_) <<  
  66.       "Can't add mat until current data has been consumed.";  
  67.   CHECK_GT(num, 0) << "There is no mat to add";  
  68.   CHECK_EQ(num % batch_size_, 0) <<  
  69.       "The added data must be a multiple of the batch size.";  
  70.   added_data_.Reshape(num, channels_, height_, width_);  
  71.   added_label_.Reshape(num, 1, 1, 1);  
  72.   // Apply data transformations (mirror, scale, crop...)  
  73.   this->data_transformer_->Transform(mat_vector, &added_data_);  
  74.   // Copy Labels  
  75.   Dtype* top_label = added_label_.mutable_cpu_data();  
  76.   for (int item_id = 0; item_id < num; ++item_id) {  
  77.     top_label[item_id] = labels[item_id];  
  78.   }  
  79.   // num_images == batch_size_  
  80.   Dtype* top_data = added_data_.mutable_cpu_data();  
  81.   Reset(top_data, top_label, num);  
  82.   has_new_data_ = true;  
  83. }  
  84. #endif  // USE_OPENCV  
  85.   
  86. template <typename Dtype>  
  87. void MemoryDataLayer<Dtype>::Reset(Dtype* data, Dtype* labels, int n) {  
  88.   CHECK(data);  
  89.   CHECK(labels);  
  90.   CHECK_EQ(n % batch_size_, 0) << "n must be a multiple of batch size";  
  91.   // Warn with transformation parameters since a memory array is meant to  
  92.   // be generic and no transformations are done with Reset().  
  93.   if (this->layer_param_.has_transform_param()) {  
  94.     LOG(WARNING) << this->type() << " does not transform array data on Reset()";  
  95.   }  
  96.   data_ = data;  
  97.   labels_ = labels;  
  98.   n_ = n;  
  99.   pos_ = 0;  
  100. }  
  101.   
  102. template <typename Dtype>  
  103. void MemoryDataLayer<Dtype>::set_batch_size(int new_size) {  
  104.   CHECK(!has_new_data_) <<  
  105.       "Can't change batch_size until current data has been consumed.";  
  106.   batch_size_ = new_size;  
  107.   added_data_.Reshape(batch_size_, channels_, height_, width_);  
  108.   added_label_.Reshape(batch_size_, 1, 1, 1);  
  109. }  
  110.   
  111. template <typename Dtype>  
  112. void MemoryDataLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,  
  113.       const vector<Blob<Dtype>*>& top) {  
  114.   CHECK(data_) << "MemoryDataLayer needs to be initalized by calling Reset";  
  115.   top[0]->Reshape(batch_size_, channels_, height_, width_);  
  116.   top[1]->Reshape(batch_size_, 1, 1, 1);  
  117.   top[0]->set_cpu_data(data_ + pos_ * size_);  
  118.   top[1]->set_cpu_data(labels_ + pos_);  
  119.   pos_ = (pos_ + batch_size_) % n_;  
  120.   if (pos_ == 0)  
  121.     has_new_data_ = false;  
  122. }  
  123.   
  124. INSTANTIATE_CLASS(MemoryDataLayer);  
  125. REGISTER_LAYER_CLASS(MemoryData);  
  126.   
  127. }  // namespace caffe  

6.修改convet_imaget.cpp

  1. // This program converts a set of images to a lmdb/leveldb by storing them  
  2. // as Datum proto buffers.  
  3. // Usage:  
  4. //   convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME  
  5. //  
  6. // where ROOTFOLDER is the root folder that holds all the images, and LISTFILE  
  7. // should be a list of files as well as their labels, in the format as  
  8. //   subfolder1/file1.JPEG 7  
  9. //   ....  
  10.   
  11. #include <algorithm>  
  12. #include <fstream>  // NOLINT(readability/streams)  
  13. #include <string>  
  14. #include <utility>  
  15. #include <vector>  
  16.   
  17. #include "boost/scoped_ptr.hpp"  
  18. #include "gflags/gflags.h"  
  19. #include "glog/logging.h"  
  20.   
  21. #include "caffe/proto/caffe.pb.h"  
  22. #include "caffe/util/db.hpp"  
  23. #include "caffe/util/io.hpp"  
  24. #include "caffe/util/rng.hpp"  
  25.   
  26. using namespace caffe;  // NOLINT(build/namespaces)  
  27. using std::pair;  
  28. using boost::scoped_ptr;  
  29.   
  30. DEFINE_bool(gray, false,  
  31.     "When this option is on, treat images as grayscale ones");  
  32. DEFINE_bool(shuffle, false,  
  33.     "Randomly shuffle the order of images and their labels");  
  34. DEFINE_string(backend, "lmdb",  
  35.         "The backend {lmdb, leveldb} for storing the result");  
  36. DEFINE_int32(resize_width, 0, "Width images are resized to");  
  37. DEFINE_int32(resize_height, 0, "Height images are resized to");  
  38. DEFINE_bool(check_size, false,  
  39.     "When this option is on, check that all the datum have the same size");  
  40. DEFINE_bool(encoded, false,  
  41.     "When this option is on, the encoded image will be save in datum");  
  42. DEFINE_string(encode_type, "",  
  43.     "Optional: What type should we encode the image as ('png','jpg',...).");  
  44.   
  45. int main(int argc, char** argv) {  
  46. #ifdef USE_OPENCV  
  47.   //::google::InitGoogleLogging(argv[0]);  
  48.   // Print output to stderr (while still logging)  
  49.   FLAGS_alsologtostderr = 1;  
  50.   
  51. #ifndef GFLAGS_GFLAGS_H_  
  52.   namespace gflags = google;  
  53. #endif  
  54.   
  55.   gflags::SetUsageMessage("Convert a set of images to the leveldb/lmdb\n"  
  56.         "format used as input for Caffe.\n"  
  57.         "Usage:\n"  
  58.         "    convert_imageset [FLAGS] ROOTFOLDER/ LISTFILE DB_NAME\n"  
  59.         "The ImageNet dataset for the training demo is at\n"  
  60.         "    http://www.image-net.org/download-images\n");  
  61.   caffe::GlobalInit(&argc, &argv);  
  62.   
  63.   if (argc < 4) {  
  64.     gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/convert_imageset");  
  65.     return 1;  
  66.   }  
  67.   
  68.   const bool is_color = !FLAGS_gray;  
  69.   const bool check_size = FLAGS_check_size;  
  70.   const bool encoded = FLAGS_encoded;  
  71.   const string encode_type = FLAGS_encode_type;  
  72.   
  73.   std::ifstream infile(argv[2]);  
  74.   std::vector<std::pair<std::string, vector<float>> > lines;  
  75.   std::string filename;  
  76.   /* 
  77.   int label; 
  78.   while (infile >> filename >> label) { 
  79.     lines.push_back(std::make_pair(filename, label)); 
  80.   } 
  81.   */  
  82.   float x1, y1, x2, y2;  
  83.   while (infile >> filename >> x1 >> y1 >> x2 >> y2) {  
  84.       std::vector<float> vec_label;  
  85.       vec_label.push_back(x1);  
  86.       vec_label.push_back(y1);  
  87.       vec_label.push_back(x2);  
  88.       vec_label.push_back(y2);  
  89.       lines.push_back(std::make_pair(filename, vec_label));  
  90.   }  
  91.   if (FLAGS_shuffle) {  
  92.     // randomly shuffle data  
  93.     LOG(INFO) << "Shuffling data";  
  94.     shuffle(lines.begin(), lines.end());  
  95.   }  
  96.   LOG(INFO) << "A total of " << lines.size() << " images.";  
  97.   
  98.   if (encode_type.size() && !encoded)  
  99.     LOG(INFO) << "encode_type specified, assuming encoded=true.";  
  100.   
  101.   int resize_height = std::max<int>(0, FLAGS_resize_height);  
  102.   int resize_width = std::max<int>(0, FLAGS_resize_width);  
  103.   
  104.   // Create new DB  
  105.   scoped_ptr<db::DB> db(db::GetDB(FLAGS_backend));  
  106.   db->Open(argv[3], db::NEW);  
  107.   scoped_ptr<db::Transaction> txn(db->NewTransaction());  
  108.   
  109.   // Storing to db  
  110.   std::string root_folder(argv[1]);  
  111.   Datum datum;  
  112.   int count = 0;  
  113.   const int kMaxKeyLength = 256;  
  114.   char key_cstr[kMaxKeyLength];  
  115.   int data_size = 0;  
  116.   bool data_size_initialized = false;  
  117.   
  118.   for (int line_id = 0; line_id < lines.size(); ++line_id) {  
  119.     bool status;  
  120.     std::string enc = encode_type;  
  121.     if (encoded && !enc.size()) {  
  122.       // Guess the encoding type from the file name  
  123.       string fn = lines[line_id].first;  
  124.       size_t p = fn.rfind('.');  
  125.       if ( p == fn.npos )  
  126.         LOG(WARNING) << "Failed to guess the encoding of '" << fn << "'";  
  127.       enc = fn.substr(p);  
  128.       std::transform(enc.begin(), enc.end(), enc.begin(), ::tolower);  
  129.     }  
  130.     status = ReadImageToDatum(root_folder + lines[line_id].first,  
  131.         lines[line_id].second, resize_height, resize_width, is_color,  
  132.         enc, &datum);  
  133.     if (status == false) continue;  
  134.     if (check_size) {  
  135.       if (!data_size_initialized) {  
  136.         data_size = datum.channels() * datum.height() * datum.width();  
  137.         data_size_initialized = true;  
  138.       } else {  
  139.         const std::string& data = datum.data();  
  140.         CHECK_EQ(data.size(), data_size) << "Incorrect data field size "  
  141.             << data.size();  
  142.       }  
  143.     }  
  144.     // sequential  
  145.     int length = sprintf_s(key_cstr, kMaxKeyLength, "%08d_%s", line_id,  
  146.         lines[line_id].first.c_str());  
  147.   
  148.     // Put in db  
  149.     string out;  
  150.     CHECK(datum.SerializeToString(&out));  
  151.     txn->Put(string(key_cstr, length), out);  
  152.   
  153.     if (++count % 1000 == 0) {  
  154.       // Commit db  
  155.       txn->Commit();  
  156.       txn.reset(db->NewTransaction());  
  157.       LOG(INFO) << "Processed " << count << " files.";  
  158.     }  
  159.   }  
  160.   // write the last batch  
  161.   if (count % 1000 != 0) {  
  162.     txn->Commit();  
  163.     LOG(INFO) << "Processed " << count << " files.";  
  164.   }  
  165. #else  
  166.   LOG(FATAL) << "This tool requires OpenCV; compile with USE_OPENCV.";  
  167. #endif  // USE_OPENCV  
  168.   return 0;  
  169. }  

7.修改io.cpp (只贴了部分需要修改的程序)

  1. bool ReadImageToDatum(const string& filename, const std::vector<float> labels,  
  2.     const int height, const int width, const bool is_color,  
  3.     const std::string & encoding, Datum* datum) {  
  4.   cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);  
  5.   if (cv_img.data) {  
  6.     if (encoding.size()) {  
  7.       if ( (cv_img.channels() == 3) == is_color && !height && !width &&  
  8.           matchExt(filename, encoding) )  
  9.         return ReadFileToDatum(filename, labels, datum);  
  10.       std::vector<uchar> buf;  
  11.       cv::imencode("."+encoding, cv_img, buf);  
  12.       datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),  
  13.                       buf.size()));  
  14.      // datum->set_label(label);  
  15.       datum->mutable_label()->Clear();  
  16.       for (int label_i = 0; label_i < labels.size(); label_i++){  
  17.           datum->add_label(labels[label_i]);  
  18.       }  
  19.       datum->set_encoded(true);  
  20.       return true;  
  21.     }  
  22.     CVMatToDatum(cv_img, datum);  
  23.    // datum->set_label(label);  
  24.     datum->mutable_label()->Clear();  
  25.     for (int label_i = 0; label_i < labels.size(); label_i++){  
  26.         datum->add_label(labels[label_i]);  
  27.     }  
  28.     return true;  
  29.   } else {  
  30.     return false;  
  31.   }  
  32. }  
  33. #endif  // USE_OPENCV  
  34.   
  35. bool ReadFileToDatum(const string& filename, const std::vector<float> labels,  
  36.     Datum* datum) {  
  37.   std::streampos size;  
  38.   
  39.   fstream file(filename.c_str(), ios::in|ios::binary|ios::ate);  
  40.   if (file.is_open()) {  
  41.     size = file.tellg();  
  42.     std::string buffer(size, ' ');  
  43.     file.seekg(0, ios::beg);  
  44.     file.read(&buffer[0], size);  
  45.     file.close();  
  46.     datum->set_data(buffer);  
  47.   //  datum->set_label(label);  
  48.     datum->mutable_label()->Clear();  
  49.     for (int label_i = 0; label_i < labels.size(); label_i++){  
  50.         datum->add_label(labels[label_i]);  
  51.     }  
  52.     datum->set_encoded(true);  
  53.     return true;  
  54.   } else {  
  55.     return false;  
  56.   }  
  57. }  

8.修改io.hpp (只贴了部分需要修改的程序)

  1. bool ReadFileToDatum(const string& filename, const std::vector<float> labels, Datum* datum);  
  2.   
  3. inline bool ReadFileToDatum(const string& filename, Datum* datum) {  
  4.  // return ReadFileToDatum(filename, -1, datum);  
  5.     return 0;  
  6. }  
  7.   
  8. bool ReadImageToDatum(const string& filename, const std::vector<float> labels,  
  9.     const int height, const int width, const bool is_color,  
  10.     const std::string & encoding, Datum* datum);  
  11.   
  12. inline bool ReadImageToDatum(const string& filename, const std::vector<float> labels,  
  13.     const int height, const int width, const bool is_color, Datum* datum) {  
  14.   return ReadImageToDatum(filename, labels, height, width, is_color,  
  15.                           "", datum);  
  16. }  
  17.   
  18. inline bool ReadImageToDatum(const string& filename, const std::vector<float> labels,  
  19.     const int height, const int width, Datum* datum) {  
  20.   return ReadImageToDatum(filename, labels, height, width, true, datum);  
  21. }  
  22.   
  23. inline bool ReadImageToDatum(const string& filename, const std::vector<float> labels,  
  24.     const bool is_color, Datum* datum) {  
  25.   return ReadImageToDatum(filename, labels, 0, 0, is_color, datum);  
  26. }  
  27.   
  28. inline bool ReadImageToDatum(const string& filename, const std::vector<float> labels,  
  29.     Datum* datum) {  
  30.   return ReadImageToDatum(filename, labels, 0, 0, true, datum);  
  31. }  
  32.   
  33. inline bool ReadImageToDatum(const string& filename, const std::vector<float> labels,  
  34.     const std::string & encoding, Datum* datum) {  
  35.   return ReadImageToDatum(filename, labels, 0, 0, true, encoding, datum);  
  36. }  

完成上述修改之后即可进行编译得到新的convert_image_set等可执行程序。


二.将自己的数据集转成leveldb格式

基本跟http://blog.csdn.net/messiran10/article/details/49159559的流程一样,主要是以下两点需要变化:

1.样本说明文件

train_samples/10007.jpg 0.491667 0.529412 0.450000 0.352941 需要把一维的label转成4维的label

2.模型配置文件

需要把softmax loss层换成 平方损失层

需要去掉accuracy层(否则会出错)






本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
caffe源码学习(六) 自定义层
深度学习caffe的代码怎么读?
[caffe解读] caffe从数学公式到代码实现1-导论
跨平台Caffe及I/O模型与并行方案(三)
基于NumPy实现随机梯度下降算法
caffe 源码分析[三]:Euclidean loss layer
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服