版权声明:本文为博主原创文章,转载请标注出处。 https://blog.csdn.net/Yan_Joy/article/details/53079185

Solver是求解学习模型的核心配置文件,网络确定后,solver就决定了学习的效果。本文结合caffe.proto和网上资料,对solver配置进行学习。

Solver
Caffe学习系列(7):solver及其配置,denny402


Solver在caffe中的定义

通常的solver文件与net文件相互关联,同样的net我们往往使用不同的solver尝试得到最好的效果,其运行代码为:

caffe train --solver=*_slover.prototxt
  • 1

关于solver的一切,都在caffe.proto文件中message SolverParameter 这一部分。

网络文件源

  // Proto filename for the train net, possibly combined with one or more// test nets.optional string net = 24;// Inline train net param, possibly combined with one or more test nets.optional NetParameter net_param = 25;optional string train_net = 1; // Proto filename for the train net.repeated string test_net = 2; // Proto filenames for the test nets.optional NetParameter train_net_param = 21; // Inline train net params.repeated NetParameter test_net_param = 22; // Inline test net params.
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

这是最开始的部分,需要说明net文件的位置。在这四个train_net_param, train_net, net_param, net字段中至少需要出现一个,当出现多个时,就会按着(1) test_net_param, (2) test_net, (3) net_param/net 的顺序依次求解。必须为每个test_net指定一个test_iter。还可以为每个test_net指定test_level和/或test_stage。注意的是:文件的路径要从caffe的根目录开始,其它的所有配置都是这样。
可以看到这几行的标签序号并不是顺序的,也说明caffe在不断地修改,下一个可用的序号是41。

网络状态

  // The states for the train/test nets. Must be unspecified or// specified once per net.//// By default, all states will have solver = true;// train_state will have phase = TRAIN,// and all test_state's will have phase = TEST.// Other defaults are set according to the NetState defaults.optional NetState train_state = 26;repeated NetState test_state = 27;
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

网络状态必须是未指定的或者只能在一个网络中指定一次。
关于NetState,其定义为:

message NetState {optional Phase phase = 1 [default = TEST];optional int32 level = 2 [default = 0];repeated string stage = 3;
}
enum Phase {TRAIN = 0;TEST = 1;
}
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

迭代器

  // The number of iterations for each test net.repeated int32 test_iter = 3;
  • 1
  • 2

首先是test_iter,这需要与test layer中的batch_size结合起来理解。mnist数据中测试样本总数为10000,一次性执行全部数据效率很低,因此我们将测试数据分成几个批次来执行,每个批次的数量就是batch_size。假设我们设置batch_size为100,则需要迭代100次才能将10000个数据全部执行完。因此test_iter设置为100。执行完一次全部数据,称之为一个epoch。

  // The number of iterations between two testing phases.optional int32 test_interval = 4 [default = 0];optional bool test_compute_loss = 19 [default = false];// If true, run an initial test pass before the first iteration,// ensuring memory availability and printing the starting value of the loss.optional bool test_initialization = 32 [default = true];
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

test_interval是指测试间隔,每训练test_interval次,进行一次测试。同时test_compute_loss可以选择是否计算loss。test_initialization是指在第一次迭代前,计算初始的loss以确保内存可用。

  optional float base_lr = 5; // The base learning rate// the number of iterations between displaying info. If display = 0, no info// will be displayed.optional int32 display = 6;// Display the loss averaged over the last average_loss iterationsoptional int32 average_loss = 33 [default = 1];optional int32 max_iter = 7; // the maximum number of iterations// accumulate gradients over `iter_size` x `batch_size` instancesoptional int32 iter_size = 36 [default = 1];
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

base_lr指基础的学习率;display是信息显示间隔,迭代一定次数显示一次信息。average_loss用于显示在上次average_loss迭代中的平均损失。max_iter是最大迭代次数,需要合适设置达到精度、震荡的平衡。iter_size是迭代器大小,梯度的计算是通过iter_size x batch_size决定的。

学习策略

  optional string lr_policy = 8;optional float gamma = 9; // The parameter to compute the learning rate.optional float power = 10; // The parameter to compute the learning rate.optional float momentum = 11; // The momentum value.optional float weight_decay = 12; // The weight decay.// regularization types supported: L1 and L2// controlled by weight_decayoptional string regularization_type = 29 [default = "L2"];// the stepsize for learning rate policy "step"optional int32 stepsize = 13;// the stepsize for learning rate policy "multistep"repeated int32 stepvalue = 34;
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

只要是梯度下降法来求解优化,都会有一个学习率,也叫步长。base_lr用于设置基础学习率,在迭代的过程中,可以对基础学习率进行调整。怎么样进行调整,就是调整的策略,由lr_policy来设置。caffe提供了多种policy:

  • fixed: 总是返回base_lr(学习率不变)
  • step: 返回 base_lr * gamma ^ (floor(iter / step))
    还需要设置stepsize参数以确定step,iter表示当前迭代次数。
  • exp: 返回base_lr * gamma ^ iter, iter为当前迭代次数
  • inv: 如果设置为inv,还需要设置一个power, 返回base_lr * (1 + gamma * iter) ^ (- power)
  • multistep: 如果设置为multistep,则还需要设置一个stepvalue。这个参数和step很相似,step是均匀等间隔变化,而multistep则是根据stepvalue值变化。
  • poly: 学习率进行多项式误差, 返回 base_lr (1 - iter/max_iter) ^ (power)
  • sigmoid: 学习率进行sigmod衰减,返回 base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))。

multistep示例:

base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "multistep"
gamma: 0.9
stepvalue: 5000
stepvalue: 7000
stepvalue: 8000
stepvalue: 9000
stepvalue: 9500
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

之后有momentum,上次梯度更新的权重;weight_decay权重衰减,防止过拟合;regularization_type正则化方式。

clip_gradients

optional float clip_gradients = 35 [default = -1];
  • 1

参数梯度的实际L2范数较大时,将clip_gradients设置为> = 0,以将参数梯度剪切到该L2范数。具体作用还不是很理解。

snapshot快照

  optional int32 snapshot = 14 [default = 0]; // The snapshot intervaloptional string snapshot_prefix = 15; // The prefix for the snapshot.// whether to snapshot diff in the results or not. Snapshotting diff will help// debugging but the final protocol buffer size will be much larger.optional bool snapshot_diff = 16 [default = false];enum SnapshotFormat {HDF5 = 0;BINARYPROTO = 1;}optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

快照可以将训练出来的model和solver状态进行保存,snapshot用于设置训练多少次后进行保存,默认为0,不保存。snapshot_prefix设置保存路径。还可以设置snapshot_diff,是否保存梯度值,保存有利于调试,但需要较大空间存储,默认为false,不保存。也可以设置snapshot_format,保存的类型。有两种选择:HDF5 和BINARYPROTO ,默认为BINARYPROTO。

运行模式

  enum SolverMode {CPU = 0;GPU = 1;}optional SolverMode solver_mode = 17 [default = GPU];// the device_id will that be used in GPU mode. Use device_id = 0 in default.optional int32 device_id = 18 [default = 0];// If non-negative, the seed with which the Solver will initialize the Caffe// random number generator -- useful for reproducible results. Otherwise,// (and by default) initialize using a seed derived from the system clock.optional int64 random_seed = 20 [default = -1];
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

设置CPU或GPU模式,在GPU下还可以指定使用哪一块GPU运行。random_seed用于初始生成随机数种子。

Solver类型

  // type of the solveroptional string type = 40 [default = "SGD"];// numerical stability for RMSProp, AdaGrad and AdaDelta and Adamoptional float delta = 31 [default = 1e-8];// parameters for the Adam solveroptional float momentum2 = 39 [default = 0.999];// RMSProp decay value// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)optional float rms_decay = 38;
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

type是solver的类型,目前有SGD、NESTEROV、ADAGRAD、RMSPROP、ADADELTA、ADAM = 5这六类。之后的一些是这些类型的特有参数,根据需要设置。

杂项

  // If true, print information about the state of the net that may help with// debugging learning problems.optional bool debug_info = 23 [default = false];// If false, don't save a snapshot after training finishes.optional bool snapshot_after_train = 28 [default = true];
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6

debug_info用于输出调试信息。snapshot_after_train用于训练后是否输出快照。

Solver 配置详解相关推荐

  1. elasticsearch-.yml(中文配置详解)

    此elasticsearch-.yml配置文件,是在$ES_HOME/config/下 elasticsearch-.yml(中文配置详解) # ======================== El ...

  2. (ASA) Cisco Web ××× 配置详解 [三部曲之一]

    (ASA) Cisco Web ××× 配置详解 [三部曲之一] 注意:本文仅对Web×××特性和配置作介绍,不包含SSL ×××配置,SSL ×××配置将在本版的后续文章中进行介绍.   首先,先来 ...

  3. mybatis 同名方法_MyBatis(四):xml配置详解

    目录 1.我们将 数据库的配置语句写在 db.properties 文件中 2.在 mybatis-configuration.xml 中加载db.properties文件并读取 通过源码我们可以分析 ...

  4. logback节点配置详解

    logback节点配置详解 一:根节点 <configuration></configuration> 属性 : debug : 默认为false ,设置为true时,将打印出 ...

  5. PM配置详解之一:企业结构

    1.维护计划工厂 功能说明 在公司结构中定义维护工厂(通常已经作为后勤工厂存在)和维护计划工厂(简称计划工厂). 维护工厂:设备所安装的位置,如某机组安装在合营公司,那么合营公司就是此机组的维护工厂, ...

  6. 转 Log4j.properties配置详解

    一.Log4j简介 Log4j有三个主要的组件:Loggers(记录器),Appenders (输出源)和Layouts(布局).这里可简单理解为日志类别,日志要输出的地方和日志以何种形式输出.综合使 ...

  7. Iptables防火墙配置详解

    iptables防火墙配置详解 iptables简介 iptables是基于内核的防火墙,功能非常强大,iptables内置了filter,nat和mangle三张表. (1)filter表负责过滤数 ...

  8. spring之旅第四篇-注解配置详解

    spring之旅第四篇-注解配置详解 一.引言 最近因为找工作,导致很长时间没有更新,找工作的时候你会明白浪费的时间后面都是要还的,现在的每一点努力,将来也会给你回报的,但行好事,莫问前程!努力总不会 ...

  9. php-fpm 启动参数及重要配置详解

    2019独角兽企业重金招聘Python工程师标准>>> php-fpm 启动参数及重要配置详解 约定几个目录 /usr/local/php/sbin/php-fpm /usr/loc ...

最新文章

  1. HTTP.sys 远程执行代码验证工具
  2. 【数据结构与算法】2.深度优先搜索DFS、广度优先搜索BFS
  3. c语言裂变,干货:社群是如何实现裂变的?
  4. 通过单步调试的方式学习 Angular 中 TView 和 LView 的概念
  5. Tournament CodeForces - 27B(dfs)
  6. redis session 超时时间_Shiro性能优化:解决Session频繁读写问题
  7. 实体验证---测试代码
  8. 在iOS平台使用libcurl
  9. 【统计分析】4 空间点数据分析与ArcGIS
  10. 如何开发出一款仿映客直播APP项目实践篇 -【原理篇】
  11. QT 播放器之界面布局
  12. K8S还没用,又出个K9S,什么鬼?
  13. CRNN原理详解、代码实现及BUG分析
  14. yourenduwanglai的鬼话连篇(九)
  15. ACM比赛一些需要注意的极端情况
  16. NeuroImage:慢性疼痛病人功能脑社区变化的网络结构
  17. python画小动物_三分钟识别所有小动物!
  18. 充电电流用软件测试准吗,充电设备 篇一:一次不严谨的测试,但估计iPhone用户看了都会买...
  19. R语言入门(17)-读写excel文件
  20. 深度置信网 DBNs

热门文章

  1. win7与linux切换,Windows 7停更后不想用Win10?教你直接换上Linux再战
  2. 万年历(c语言)编程,C语言实现的万年历
  3. Python项目-Day26-数据加密-hash加盐加密-token-jwt
  4. @TOM VIP邮箱,打造商务办公新场景,定位职场人的贴心助手!
  5. ArcGIS三维分析之ArcGlobe简要说明
  6. VMware ESXI上开虚机玩KVM
  7. 常见的百度云搜索引擎入口合集
  8. 一维到三维的推广(1D and 3D generalizations of models)
  9. 网络资产扫描工具 -- Goby
  10. linux环境下mysql主从数据库配置(maser-slave-replication)