caffe源码c++学习笔记

原文地址:http://blog.csdn.net/hjimce/article/details/48933845

作者:hjimce

一、预测分类

最近几天为了希望深入理解caffe,于是便开始学起了caffe函数的c++调用,caffe的函数调用例子网上很少,需要自己慢慢的摸索,即便是找到了例子,有的时候caffe版本不一样,也会出现错误。对于预测分类的函数调用,caffe为我们提供了一个例子,一开始我懒得解读这个例子,网上找了一些分类预测的例子,总是会出现各种各样的错误,于是没办法最后只能老老实实的学官方给的例子比较实在,因此最后自己把代码解读了一下,然后自己整理成自己的类,这个类主要用于训练好模型后,我们要进行调用预测一张新输入图片的类别。

头文件:

[cpp] view plaincopy
  1. /*
  2. * Classifier.h
  3. *
  4. *  Created on: Oct 6, 2015
  5. *      Author: hjimce
  6. */
  7. #ifndef CLASSIFIER_H_
  8. #define CLASSIFIER_H_
  9. #include <caffe/caffe.hpp>
  10. #include <opencv2/core/core.hpp>
  11. #include <opencv2/highgui/highgui.hpp>
  12. #include <opencv2/imgproc/imgproc.hpp>
  13. #include <algorithm>
  14. #include <iosfwd>
  15. #include <memory>
  16. #include <string>
  17. #include <utility>
  18. #include <vector>
  19. using namespace caffe;
  20. using std::string;
  21. /* std::pair (标签, 属于该标签的概率)*/
  22. typedef std::pair<string, float> Prediction;
  23. class Classifier
  24. {
  25. public:
  26. Classifier(const string& model_file, const string& trained_file,const string& mean_file);
  27. std::vector<Prediction> Classify(const cv::Mat& img, int N = 1);//N的默认值,我选择1,因为我的项目判断的图片,一般图片里面就只有一个种类
  28. void SetLabelString(std::vector<string>strlabel);//用于设置label的名字,有n个类,那么就有n个string的名字
  29. private:
  30. void SetMean(const string& mean_file);
  31. std::vector<float> Predict(const cv::Mat& img);
  32. void WrapInputLayer(std::vector<cv::Mat>* input_channels);
  33. void Preprocess(const cv::Mat& img,
  34. std::vector<cv::Mat>* input_channels);
  35. private:
  36. shared_ptr<Net<float> > net_;//网络
  37. cv::Size input_geometry_;//网络输入图片的大小cv::Size(height,width)
  38. int num_channels_;//网络输入图片的通道数
  39. cv::Mat mean_;//均值图片
  40. std::vector<string> labels_;
  41. };
  42. #endif /* CLASSIFIER_H_ */

源文件:

[cpp] view plaincopy
  1. /*
  2. * Classifier.cpp
  3. *
  4. *  Created on: Oct 6, 2015
  5. *      Author: hjimce
  6. */
  7. #include "Classifier.h"
  8. using namespace caffe;
  9. Classifier::Classifier(const string& model_file,const string& trained_file,const string& mean_file)
  10. {
  11. //设置计算模式为CPU
  12. Caffe::set_mode(Caffe::CPU);
  13. //加载网络模型,
  14. net_.reset(new Net<float>(model_file, TEST));
  15. //加载已经训练好的参数
  16. net_->CopyTrainedLayersFrom(trained_file);
  17. CHECK_EQ(net_->num_inputs(), 1) << "Network should have exactly one input.";
  18. CHECK_EQ(net_->num_outputs(), 1) << "Network should have exactly one output.";
  19. //输入层
  20. Blob<float>* input_layer = net_->input_blobs()[0];
  21. num_channels_ = input_layer->channels();
  22. //输入层一般是彩色图像、或灰度图像,因此需要进行判断,对于Alexnet为三通道彩色图像
  23. CHECK(num_channels_ == 3 || num_channels_ == 1)<< "Input layer should have 1 or 3 channels.";
  24. //网络输入层的图片的大小,对于Alexnet大小为227*227
  25. input_geometry_ = cv::Size(input_layer->width(), input_layer->height());
  26. //设置均值
  27. SetMean(mean_file);
  28. }
  29. static bool PairCompare(const std::pair<float, int>& lhs,
  30. const std::pair<float, int>& rhs) {
  31. return lhs.first > rhs.first;
  32. }
  33. //函数用于返回向量v的前N个最大值的索引,也就是返回概率最大的五种物体的标签
  34. //如果你是二分类问题,那么这个N直接选择1
  35. static std::vector<int> Argmax(const std::vector<float>& v, int N)
  36. {
  37. //根据v的大小进行排序,因为要返回索引,所以需要借助于pair
  38. std::vector<std::pair<float, int> > pairs;
  39. for (size_t i = 0; i < v.size(); ++i)
  40. pairs.push_back(std::make_pair(v[i], i));
  41. std::partial_sort(pairs.begin(), pairs.begin() + N, pairs.end(), PairCompare);
  42. std::vector<int> result;
  43. for (int i = 0; i < N; ++i)
  44. result.push_back(pairs[i].second);
  45. return result;
  46. }
  47. //预测函数,输入一张图片img,希望预测的前N种概率最大的,我们一般取N等于1
  48. //输入预测结果为std::make_pair,每个对包含这个物体的名字,及其相对于的概率
  49. std::vector<Prediction> Classifier::Classify(const cv::Mat& img, int N) {
  50. std::vector<float> output = Predict(img);
  51. N = std::min<int>(labels_.size(), N);
  52. std::vector<int> maxN = Argmax(output, N);
  53. std::vector<Prediction> predictions;
  54. for (int i = 0; i < N; ++i) {
  55. int idx = maxN[i];
  56. predictions.push_back(std::make_pair(labels_[idx], output[idx]));
  57. }
  58. return predictions;
  59. }
  60. void Classifier::SetLabelString(std::vector<string>strlabel)
  61. {
  62. labels_=strlabel;
  63. }
  64. //加载均值文件
  65. void Classifier::SetMean(const string& mean_file)
  66. {
  67. BlobProto blob_proto;
  68. ReadProtoFromBinaryFileOrDie(mean_file.c_str(), &blob_proto);
  69. /*把BlobProto 转换为 Blob<float>类型 */
  70. Blob<float> mean_blob;
  71. mean_blob.FromProto(blob_proto);
  72. //验证均值图片的通道个数是否与网络的输入图片的通道个数相同
  73. CHECK_EQ(mean_blob.channels(), num_channels_)<< "Number of channels of mean file doesn't match input layer.";
  74. //把三通道的图片分开存储,三张图片按顺序保存到channels中
  75. std::vector<cv::Mat> channels;
  76. float* data = mean_blob.mutable_cpu_data();
  77. for (int i = 0; i < num_channels_; ++i) {
  78. cv::Mat channel(mean_blob.height(), mean_blob.width(), CV_32FC1, data);
  79. channels.push_back(channel);
  80. data += mean_blob.height() * mean_blob.width();
  81. }
  82. //重新合成一张图片
  83. cv::Mat mean;
  84. cv::merge(channels, mean);
  85. //计算每个通道的均值,得到一个三维的向量channel_mean,然后把三维的向量扩展成一张新的均值图片
  86. //这种图片的每个通道的像素值是相等的,这张均值图片的大小将和网络的输入要求一样
  87. cv::Scalar channel_mean = cv::mean(mean);
  88. mean_ = cv::Mat(input_geometry_, mean.type(), channel_mean);
  89. }
  90. //预测函数,输入一张图片
  91. std::vector<float> Classifier::Predict(const cv::Mat& img)
  92. {
  93. //?
  94. Blob<float>* input_layer = net_->input_blobs()[0];
  95. input_layer->Reshape(1, num_channels_, input_geometry_.height, input_geometry_.width);
  96. net_->Reshape();
  97. //输入带预测的图片数据,然后进行预处理,包括归一化、缩放等操作
  98. std::vector<cv::Mat> input_channels;
  99. WrapInputLayer(&input_channels);
  100. Preprocess(img, &input_channels);
  101. //前向传导
  102. net_->ForwardPrefilled();
  103. //把最后一层输出值,保存到vector中,结果就是返回每个类的概率
  104. Blob<float>* output_layer = net_->output_blobs()[0];
  105. const float* begin = output_layer->cpu_data();
  106. const float* end = begin + output_layer->channels();
  107. return std::vector<float>(begin, end);
  108. }
  109. /* 这个其实是为了获得net_网络的输入层数据的指针,然后后面我们直接把输入图片数据拷贝到这个指针里面*/
  110. void Classifier::WrapInputLayer(std::vector<cv::Mat>* input_channels)
  111. {
  112. Blob<float>* input_layer = net_->input_blobs()[0];
  113. int width = input_layer->width();
  114. int height = input_layer->height();
  115. float* input_data = input_layer->mutable_cpu_data();
  116. for (int i = 0; i < input_layer->channels(); ++i) {
  117. cv::Mat channel(height, width, CV_32FC1, input_data);
  118. input_channels->push_back(channel);
  119. input_data += width * height;
  120. }
  121. }
  122. //图片预处理函数,包括图片缩放、归一化、3通道图片分开存储
  123. //对于三通道输入CNN,经过该函数返回的是std::vector<cv::Mat>因为是三通道数据,索引用了vector
  124. void Classifier::Preprocess(const cv::Mat& img,std::vector<cv::Mat>* input_channels)
  125. {
  126. /*1、通道处理,因为我们如果是Alexnet网络,那么就应该是三通道输入*/
  127. cv::Mat sample;
  128. //如果输入图片是一张彩色图片,但是CNN的输入是一张灰度图像,那么我们需要把彩色图片转换成灰度图片
  129. if (img.channels() == 3 && num_channels_ == 1)
  130. cv::cvtColor(img, sample, CV_BGR2GRAY);
  131. else if (img.channels() == 4 && num_channels_ == 1)
  132. cv::cvtColor(img, sample, CV_BGRA2GRAY);
  133. //如果输入图片是灰度图片,或者是4通道图片,而CNN的输入要求是彩色图片,因此我们也需要把它转化成三通道彩色图片
  134. else if (img.channels() == 4 && num_channels_ == 3)
  135. cv::cvtColor(img, sample, CV_BGRA2BGR);
  136. else if (img.channels() == 1 && num_channels_ == 3)
  137. cv::cvtColor(img, sample, CV_GRAY2BGR);
  138. else
  139. sample = img;
  140. /*2、缩放处理,因为我们输入的一张图片如果是任意大小的图片,那么我们就应该把它缩放到227×227*/
  141. cv::Mat sample_resized;
  142. if (sample.size() != input_geometry_)
  143. cv::resize(sample, sample_resized, input_geometry_);
  144. else
  145. sample_resized = sample;
  146. /*3、数据类型处理,因为我们的图片是uchar类型,我们需要把数据转换成float类型*/
  147. cv::Mat sample_float;
  148. if (num_channels_ == 3)
  149. sample_resized.convertTo(sample_float, CV_32FC3);
  150. else
  151. sample_resized.convertTo(sample_float, CV_32FC1);
  152. //均值归一化,为什么没有大小归一化?
  153. cv::Mat sample_normalized;
  154. cv::subtract(sample_float, mean_, sample_normalized);
  155. /* 3通道数据分开存储 */
  156. cv::split(sample_normalized, *input_channels);
  157. CHECK(reinterpret_cast<float*>(input_channels->at(0).data) == net_->input_blobs()[0]->cpu_data()) << "Input channels are not wrapping the input layer of the network.";
  158. }

调用实例,下面这个实例是要用于性别预测的例子:

[cpp] view plaincopy
  1. //============================================================================
  2. // Name        : caffepredict.cpp
  3. // Author      :
  4. // Version     :
  5. // Copyright   : Your copyright notice
  6. // Description : Hello World in C++, Ansi-style
  7. //============================================================================
  8. #include <string>
  9. #include <vector>
  10. #include <fstream>
  11. #include "caffe/caffe.hpp"
  12. #include <opencv2/opencv.hpp>
  13. #include"Classifier.h"
  14. int main()
  15. {
  16. caffe::Caffe::set_mode(caffe::Caffe::CPU);
  17. cv::Mat src1;
  18. src1 = cv::imread("4.jpg");
  19. Classifier cl("deploy.prototxt", "gender_net.caffemodel","imagenet_mean.binaryproto");
  20. std::vector<string>label;
  21. label.push_back("male");
  22. label.push_back("female");
  23. cl.SetLabelString(label);
  24. std::vector<Prediction>pre=cl.Classify(src1);
  25. cv::imshow("1.jpg",src1);
  26. std::cout <<pre[0].first<< std::endl;
  27. return 0;
  28. }

二、文件数据

[cpp] view plaincopy
  1. /函数的作用是读取一张图片,并保存到到datum中
  2. //第一个参数:filename图片文件路径名
  3. //第二个参数:label图片的分类标签
  4. //第三、四个参数:图片resize新的宽高
  5. //调用方法:
  6. /*Datum datum
  7. ReadImageToDatum(“1.jpg”, 10, 256, 256, true,&datum)*/
  8. //把图片1.jpg,其标签为10的图片缩放到256*256,并保存为彩色图片,最后保存到datum当中
  9. bool ReadImageToDatum(const string& filename, const int label,
  10. const int height, const int width, const bool is_color,
  11. const std::string & encoding, Datum* datum) {
  12. cv::Mat cv_img = ReadImageToCVMat(filename, height, width, is_color);//读取图片到cv::Mat
  13. if (cv_img.data) {
  14. if (encoding.size()) {
  15. if ( (cv_img.channels() == 3) == is_color && !height && !width &&
  16. matchExt(filename, encoding) )
  17. return ReadFileToDatum(filename, label, datum);
  18. std::vector<uchar> buf;
  19. cv::imencode("."+encoding, cv_img, buf);
  20. datum->set_data(std::string(reinterpret_cast<char*>(&buf[0]),
  21. buf.size()));
  22. datum->set_label(label);
  23. datum->set_encoded(true);
  24. return true;
  25. }
  26. CVMatToDatum(cv_img, datum);//把图片由cv::Mat转换成Datum
  27. datum->set_label(label);//设置图片的标签
  28. return true;
  29. } else {
  30. return false;
  31. }
  32. }

**********************作者:hjimce   时间:2015.10.1   地址:http://blog.csdn.net/hjimce 转载请保留本行信息********************

深度学习(七)caffe源码c++学习笔记相关推荐

  1. 深度学习框架Caffe源码解析

    作者:薛云峰(https://github.com/HolidayXue),主要从事视频图像算法的研究, 本文来源微信公众号:深度学习大讲堂.  原文:深度学习框架Caffe源码解析  欢迎技术投稿. ...

  2. caffe源码c++学习笔记

    转载自:深度学习(七)caffe源码c++学习笔记 - hjimce的专栏 - 博客频道 - CSDN.NET http://blog.csdn.net/hjimce/article/details/ ...

  3. caffe源码深入学习6:超级详细的im2col绘图解析,分析caffe卷积操作的底层实现

       在先前的两篇博客中,笔者详细解析了caffe卷积层的定义与实现,可是在conv_layer.cpp与base_conv_layer.cpp中,卷积操作的实现仍然被隐藏,通过im2col_cpu函 ...

  4. caffe源码分析-layer

    本文主要分析caffe layer层,主要内容如下: 从整体上说明下caffe的layer层的类别,以及作用 通过proto定义与类Layer简要说明下Layer的核心成员变量; Layer类的核心成 ...

  5. STL源码剖析学习七:stack和queue

    STL源码剖析学习七:stack和queue stack是一种先进后出的数据结构,只有一个出口. 允许新增.删除.获取最顶端的元素,没有任何办法可以存取其他元素,不允许有遍历行为. 缺省情况下用deq ...

  6. caffe源码学习:softmaxWithLoss前向计算

    caffe源码学习:softmaxWithLoss 在caffe中softmaxwithLoss是由两部分组成,softmax+Loss组成,其实主要就是为了caffe框架的可扩展性. 表达式(1)是 ...

  7. 免费学习机器学习和深度学习的源码、学习笔记和框架分享

    机器学习和深度学习的免费学习源码.学习笔记和框架分享 python笔记 源码 python导入模块的的几种方式 在python中,字典按值排序 python中set的基本常用方法 python取出fr ...

  8. caffe源码学习——1.熟悉protobuf,会读caffe.proto

    要想学习caffe源码,首当其冲的要阅读的,就是caffe.proto这个文件.它定义了caffe中用到的许多结构化数据. caffe采用了Protocol Buffers的数据格式. 那么,Prot ...

  9. 深度学习03-sklearn.LinearRegression 源码学习

    在上次的代码重写中使用了sklearn.LinearRegression 类进行了线性回归之后猜测其使用的是常用的梯度下降+反向传播算法实现,所以今天来学习它的源码实现.但是在看到源码的一瞬间突然有种 ...

最新文章

  1. android点击展开textview,《Android APP可能有的东西》之UI篇:展开TextView全文
  2. linux sar命令 性能监控
  3. 思科与华为生成树协议的对接
  4. 又不能起床python好学吗
  5. java类与接口练习
  6. java程序员_哪些书是不可错过的?Java程序员书单分享
  7. Spring IOC学习心得之注册bean的依赖关系
  8. Bootstrap table的基础用法
  9. c#实现文件转base64和base64转文件(文件为任意格式)
  10. 如何快速实现一个抽签小程序
  11. 利用python构建马科维茨_Markowitz投资组合之Python模拟
  12. 微云Android2.2apk,微云安卓版V6.2.10
  13. OWASP Top 10 简单介绍
  14. 携程景区爬取 + 保存Excel
  15. [树的直径 树形DP] UOJ #11【UTR #1】ydc的大树
  16. HINSTANCE/HWND/CWnd/HANDLE 的区别
  17. -webkit-tap-highlight-color
  18. springboot 实现图片上传功能
  19. 软件测试用例 单元测试,软件单元测试的测试用例编写方法
  20. gridControl自动增加行添加数据

热门文章

  1. 为什么大型科技公司更会发生人员流失 标准 ceo 软件 技术 图 阅读2479 原文:Why Good People Leave Large Tech Companies 作者:steve
  2. 抢椅子游戏java_游戏教案小班抢椅子
  3. Modular Arithmetic 模算术
  4. APM - 零侵入监控Http服务
  5. 实战SSM_O2O商铺_12【商铺注册】View层之前台页面
  6. python 寻找比目标字母大的最小字符
  7. python 链表两数相加
  8. linux 病毒脚本,解析常见的Linux病毒
  9. python实验过程心得体会_20192416 实验四《Python程序设计》综合实践报告
  10. csgo卡程序关不掉_微信推QQ小程序,取代QQ?网友:这功能有用?