概览

这篇博客解析caffe函数入口caffe.cpp,主要内容为caffe启动框架,基本不涉及深度学习的具体内容,内容十分基础,适合新手阅读。下面所有的代码解析都以训练lenet手写数字体识别为例,其运行参数为:

caffe train --solver=examples/mnist/lenet_solver.prototxt $@

main函数

先把main函数贴上来

int main(int argc, char** argv) {// Print output to stderr (while still logging).FLAGS_alsologtostderr = 1;// Set versiongflags::SetVersionString(AS_STRING(CAFFE_VERSION));// Usage message.gflags::SetUsageMessage("command line brew\n""usage: caffe <command> <args>\n\n""commands:\n""  train           train or finetune a model\n""  test            score a model\n""  device_query    show GPU diagnostic information\n""  time            benchmark model execution time");// Run tool or show usage.caffe::GlobalInit(&argc, &argv);if (argc == 2) {
#ifdef WITH_PYTHON_LAYERtry {
#endifreturn GetBrewFunction(caffe::string(argv[1]))();
#ifdef WITH_PYTHON_LAYER} catch (bp::error_already_set) {PyErr_Print();return 1;}
#endif} else {gflags::ShowUsageWithFlagsRestrict(argv[0], "tools/caffe");}
}

main函数上来就是一个变量FLAGS_alsologtostderr,但vscode找不到该变量的定义。其实这个变量包括其他带有FLAGS前缀的变量是由gflags定义的,gflags 是 google 开源的用于处理命令行参数的项目。alsologtostderr指将日志输出到标准错误流中去。后面SetVersionString 的作用是当你使用caffe --version时能打印出caffe的版本信息,CAFFE_VERSION由Makefile指定.紧接着SetUsageMessage实际上设置了caffe的帮助信息,当运行caffe参数不正确或者使用--help参数时打印出usage信息。caffe::GlobalInit函数会根据命令行参数做一些初始化的工作,其定义在common.cpp中,具体如下:

void GlobalInit(int* pargc, char*** pargv) {// Google flags.::gflags::ParseCommandLineFlags(pargc, pargv, true);// Google logging.::google::InitGoogleLogging(*(pargv)[0]);// Provide a backtrace on segfault.::google::InstallFailureSignalHandler();
}

对于训练手写数字体识别:

只有一个参数solver =examples/mnist/lenet_solver.prototxt 解析后可以以FLAGS_solver来访问。包括solver model等用户自定义的命令行参数(非gflags默认的参数)定义在caffe.cpp里:

DEFINE_string(gpu, "","Optional; run in GPU mode on given device IDs separated by ','.""Use '-gpu all' to run on all available GPUs. The effective training ""batch size is multiplied by the number of devices.");
DEFINE_string(solver, "","The solver definition protocol buffer text file.");
DEFINE_string(model, "","The model definition protocol buffer text file.");

对于gflags更详细的信息可以参考google gflags 库完全使用

后面的InitGoogleLogging和InstallFailureSignalHandler用来处理日志和运行错误。

那么main函数怎么根据train test等参数进入到相应的train函数或test函数中去呢?

看这一行代码:

return GetBrewFunction(caffe::string(argv[1]))();

这个函数可以根据第一个参数argv[1](argv[0]是caffe本身的路径)来返回相应的函数,接下来我们来看GetBrewFunction是怎么实现这个功能的。

typedef int (*BrewFunction)(); //定义了一个函数指针类型,该类型指针指向一个参数为空返回值为int的函数
typedef std::map<caffe::string, BrewFunction> BrewMap;//定义了一个map类型,该类型的变量维护一个字典,函数名称(string)作为key,函数指针(BrewFunction)作为value
BrewMap g_brew_map;#define RegisterBrewFunction(func) \
namespace { \
class __Registerer_##func { \   //##表示合并字符串public: /* NOLINT */ \__Registerer_##func() { \g_brew_map[#func] = &func; \ #为字符串} \
}; \
__Registerer_##func g_registerer_##func; \
}static BrewFunction GetBrewFunction(const caffe::string& name) {if (g_brew_map.count(name)) {return g_brew_map[name];//根据name中的具体内容返回相应的函数指针} else {LOG(ERROR) << "Available caffe actions:";for (BrewMap::iterator it = g_brew_map.begin();it != g_brew_map.end(); ++it) {LOG(ERROR) << "\t" << it->first;}LOG(FATAL) << "Unknown action: " << name;return NULL;  // not reachable, just to suppress old compiler warnings.}
}
//下面是一个例子,详细说明train函数怎么填充到g_brew_map中
int train(){
}
RegisterBrewFunction(train)//这一句会根据宏定义被替换成下面的内容namespace{
class __Registerer_train{public:__Registerer_train(){g_brew_map["train"] = &train;}
};
__Registerer_train g_registerer_train; //实例化的过程中将train函数填充到字典g_brew_map中去了
}

根据上面一些注释,我们可以看出一个大概的框架:

1 定义一个字典,存储函数名到函数指针的映射。

2 通过RegisterBrewFunction(func)的宏定义来填充这个字典。

3 调用GetBrewFunction根据函数名返回相应的函数指针。

train函数

下面具体看train函数

// Train / Finetune a model.
int train() {CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train."; //FLAGS_solver <= 0 会输出CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())// snapshot 和 weight参数都没有,不管<< "Give a snapshot to resume training or weights to finetune ""but not both.";vector<string> stages = get_stages_from_flags(); //stages参数也没有,跳过caffe::SolverParameter solver_param;caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);//该行从lenet_solver.prototxt读取参数到solver_param中solver_param.mutable_train_state()->set_level(FLAGS_level); //level参数也没有,跳过for (int i = 0; i < stages.size(); i++) {solver_param.mutable_train_state()->add_stage(stages[i]);}// If the gpus flag is not provided, allow the mode and device to be set// in the solver prototxt.if (FLAGS_gpu.size() == 0     //从solverparam中读取GPU的信息,是否使用GPU,GPU的id之类的,初期可以不用特别关注&& solver_param.has_solver_mode()&& solver_param.solver_mode() == caffe::SolverParameter_SolverMode_GPU) {if (solver_param.has_device_id()) {FLAGS_gpu = "" +boost::lexical_cast<string>(solver_param.device_id());} else {  // Set default GPU if unspecifiedFLAGS_gpu = "" + boost::lexical_cast<string>(0);}}vector<int> gpus;get_gpus(&gpus);if (gpus.size() == 0) {LOG(INFO) << "Use CPU.";Caffe::set_mode(Caffe::CPU);} else {ostringstream s;for (int i = 0; i < gpus.size(); ++i) {s << (i ? ", " : "") << gpus[i];}LOG(INFO) << "Using GPUs " << s.str();
#ifndef CPU_ONLYcudaDeviceProp device_prop;for (int i = 0; i < gpus.size(); ++i) {cudaGetDeviceProperties(&device_prop, gpus[i]);LOG(INFO) << "GPU " << gpus[i] << ": " << device_prop.name;}
#endifsolver_param.set_device_id(gpus[0]);Caffe::SetDevice(gpus[0]);Caffe::set_mode(Caffe::GPU);Caffe::set_solver_count(gpus.size());}caffe::SignalHandler signal_handler(GetRequestedAction(FLAGS_sigint_effect),GetRequestedAction(FLAGS_sighup_effect));if (FLAGS_snapshot.size()) {solver_param.clear_weights();} else if (FLAGS_weights.size()) {solver_param.clear_weights();solver_param.add_weights(FLAGS_weights);}
//根据solver_param,生成solvershared_ptr<caffe::Solver<float> >solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));solver->SetActionFunction(signal_handler.GetActionFunction());if (FLAGS_snapshot.size()) {LOG(INFO) << "Resuming from " << FLAGS_snapshot;solver->Restore(FLAGS_snapshot.c_str());}LOG(INFO) << "Starting Optimization";if (gpus.size() > 1) {
#ifdef USE_NCCLcaffe::NCCL<float> nccl(solver);nccl.Run(gpus, FLAGS_snapshot.size() > 0 ? FLAGS_snapshot.c_str() : NULL);
#elseLOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL";
#endif} else {//求解solversolver->Solve();}LOG(INFO) << "Optimization Done.";return 0;
}

solver的实例化

这里不涉及任何solver内部的细节,包括生成_net和test_net,具体的求解方法等内容,只剖析caffe怎样根据solverparam.type实例化不同的solver类。实际上这些内容和上面讲的根据命令行参数执行train还是test等函数的方法十分相似,但其过程更加复杂,还是简要的分析一下。

shared_ptr<caffe::Solver<float>>solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

caffe.cpp中的train函数中通过上述的代码定义了一个指向Solver<float>的shared_ptr。其中主要是通过调用SolverRegistry这个类的静态成员函数CreateSolver得到一个指向Solver的指针来构造shared_ptr类型的solver。而且由于C++多态的特性,solver是一个指向基类Solver类型的指针,通过solver这个智能指针来调用各个成员函数会调用到各个子类(SGDSolver等)的函数。

下面分析SolverRegistry具体是怎么做的:

typedef Solver<Dtype>* (*Creator)(const SolverParameter&);typedef std::map<string, Creator> CreatorRegistry;static CreatorRegistry& Registry() {static CreatorRegistry* g_registry_ = new CreatorRegistry();return *g_registry_;}
  static Solver<Dtype>* CreateSolver(const SolverParameter& param) {const string& type = param.type();CreatorRegistry& registry = Registry();CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type<< " (known types: " << SolverTypeListString() << ")";return registry[type](param);}

从上述代码可以看到也是维护了一个map由solverparam.type返回具体的solver<Dtype>指针

SolverRegistry类的构造函数是private的,也就是用我们没有办法去构造一个这个类的变量,这个类也没有数据成员,所有的成员函数也都是static的,可以直接调用。 CreateSolver函数先定义了string类型的变量type,表示Solver的类型,然后定义了一个key类型为string,value类型为Creator的map,变量名为registry,其中Creator是一个函数指针类型,指向的函数的参数为SolverParameter类型,返回类型为Solver<Dtype>*。如果是一个已经register过的Solver类型,那么registry.count(type)应该为1,然后通过registry这个map返回了我们需要类型的Solver的creator,并调用这个creator函数,将creator返回的Solver<Dtype>*返回。

Registry函数中定义了一个static的变量g_registry,这个变量是一个指向CreatorRegistry这个map类型的指针,然后直接返回,因为这个变量是static的,所以即使多次调用这个函数,也只会定义一个g_registry,可以在其他地方修改这个map里的内容,。事实上各个Solver的register的过程正是向g_registry指向的那个map里添加以Solver的type为key,对应的Creator函数指针为value的内容。

那包括SGDSolver等各种solver是怎么注册的呢?下面以注册SGDSolver为例说明

solver_factory.hpp文件中有两个宏定义如下:

#define REGISTER_SOLVER_CREATOR(type, creator)                                 \static SolverRegisterer<float> g_creator_f_##type(#type, creator<float>);    \static SolverRegisterer<double> g_creator_d_##type(#type, creator<double>)   \#define REGISTER_SOLVER_CLASS(type)                                            \template <typename Dtype>                                                    \Solver<Dtype>* Creator_##type##Solver(                                       \const SolverParameter& param)                                            \{                                                                            \return new type##Solver<Dtype>(param);                                     \}                                                                            \REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver)

sgd_solver.cpp文件末尾有

REGISTER_SOLVER_CLASS(SGD);

根据宏定义替换的结果如下:

template <typename Dtype>
Solver<Dtype>* Creator_SGD_Solver(const SolverParameter& param)
{return new SGDSolver<Dtype>(param);
}
static SolverRegisterer<float> g_creator_f_SGD("SGD",Creator_SGD_Solver<float>);
static SolverRegisterer<double> g_creator_f_SGD("SGD",Creator_SGD_Solver<double>);

即根据宏定义,定义了一个Creator函数指针可指的函数Creator_SGD_Solver,然后通过下面的函数将key和value注册进去:

template <typename Dtype>
class SolverRegisterer {public:SolverRegisterer(const string& type,Solver<Dtype>* (*creator)(const SolverParameter&)) {// LOG(INFO) << "Registering solver type: " << type;SolverRegistry<Dtype>::AddCreator(type, creator);}
};

AddCreator函数的源码不在此展示,具体细节阅读solver_factory.hpp

至此,生成solver的工厂模式应该讲清楚了,caffe的启动框架也差不多清晰了,接下来就是solver怎么根据solver_params生成net,以及net的前向和反向计算了。

参考:

Caffe中Solver解析

google gflags 库完全使用

caffe函数入口caffe.cpp详解相关推荐

  1. 互斥量、条件变量与pthread_cond_wait()函数的使用,详解(二)

    互斥量.条件变量与pthread_cond_wait()函数的使用,详解(二) 1.Linux"线程" 进程与线程之间是有区别的,不过linux内核只提供了轻量进程的支持,未实现线 ...

  2. 如何使用指向类的成员函数的指针(详解!)

    原文:如何使用指向类的成员函数的指针(详解!) 另外一篇英文参考:Member Function Pointers and the Fastest Possible C++ Delegates 我们首 ...

  3. c语言 access编程,C语言中access/_access函数的使用实例详解

    在Linux下,access函数的声明在文件中,声明如下: int access(const char *pathname, int mode); access函数用来判断指定的文件或目录是否存在(F ...

  4. 在python中使用关键字define定义函数_python自定义函数def的应用详解

    这里是三岁,来和大家唠唠自定义函数,这一个神奇的东西,带大家白话玩转自定义函数 自定义函数,编程里面的精髓! def 自定义函数的必要函数:def 使用方法:def 函数名(参数1,参数2,参数-): ...

  5. 函数assert()详解

    函数assert()详解: 断言assert是一个宏,该宏在<assert>中,,当使用assert时候,给他个参数,即一个判读为真的表达式.预处理器产生测试该断言的代码,如果断言不为真, ...

  6. php。defined,PHP defined()函数的使用图文详解

    PHP defined()函数的使用图文详解 PHP defined() 函数 例子 定义和用法 defined() 函数检查某常量是否存在. 若常量存在,则返回 true,否则返回 false. 语 ...

  7. python中tile的用法_python3中numpy函数tile的用法详解

    tile函数位于python模块 numpy.lib.shape_base中,他的功能是重复某个数组.比如tile(A,n),功能是将数组A重复n次,构成一个新的数组,我们还是使用具体的例子来说明问题 ...

  8. Delphi Format函数功能及用法详解

    DELPHI中Format函数功能及用法详解 DELPHI中Format函数功能及用法详解function Format(const Format: string; const Args: array ...

  9. python中的json函数_python中装饰器、内置函数、json的详解

    装饰器 装饰器本质上是一个Python函数,它可以让其他函数在不需要做任何代码变动的前提下增加额外功能,装饰器的返回值也是一个函数对象. 先看简单例子: def run(): time.sleep(1 ...

最新文章

  1. 面试者面试官,双向角度的程序员面试指南!
  2. 打印swift 变量的类型
  3. Office安装时报错1907的解决方法
  4. python元类的概念_Python中的元类编程 | 学步园
  5. LINUX下忘记MySQL的ROOT密码后修改,以及添加访问IP。
  6. linux iscsi
  7. strace命令学习
  8. 如果给一个单位做相关的软件,你认为最重要的是需要得到谁的支持,为什么...
  9. mac homebrew
  10. IIS上部署asp.net core2.1项目
  11. offsetTop和scrollTop的差别
  12. DataPipeline | PayPal庞姬桦:大数据在小微企业贷款上的运用
  13. this关键字的使用案例
  14. 行政区村界线_市政府批复!崇川区部分行政区划调整
  15. ImportError: No module named apex
  16. css样式的属性包括,css字体样式属性有哪些
  17. Excel表格数据导入
  18. 有什么好用的表单工具?
  19. ProGet 22.0 Enterprise Crack by Xacker
  20. PCL学习笔记5-sample consensus采样一致性算法

热门文章

  1. NU Virgos(圣女天团)
  2. mvn上传pom/jar至Nexus私服
  3. 嘉为科技出席2021腾讯云启产业生态年会,荣获“年度通用明星奖”
  4. 浅谈面向对象的编程思想:如何优雅地把大象装进冰箱?
  5. Tableau 添加加权平均参考线
  6. 网易猛犸:数据质量漫谈
  7. 5.1劳动节|致敬每一位数字安全劳动者
  8. 苏宁易购财报看点:加码线上业务布局,注册会员增至6.23亿人
  9. c语言编程植物信息查询系统,C:\WINDOWS\Desktop\导航库\植物\xzjs\hzc.htm
  10. 新彩虹世界密码系统是多少_希望最近的世界密码日是我们需要的最后一个