Solver 配置详解
Solver是求解学习模型的核心配置文件,网络确定后,solver就决定了学习的效果。本文结合caffe.proto和网上资料,对solver配置进行学习。
Solver
Caffe学习系列(7):solver及其配置,denny402
Solver在caffe中的定义
通常的solver文件与net文件相互关联,同样的net我们往往使用不同的solver尝试得到最好的效果,其运行代码为:
caffe train --solver=*_slover.prototxt
- 1
关于solver的一切,都在caffe.proto文件中message SolverParameter 这一部分。
网络文件源
// Proto filename for the train net, possibly combined with one or more// test nets.optional string net = 24;// Inline train net param, possibly combined with one or more test nets.optional NetParameter net_param = 25;optional string train_net = 1; // Proto filename for the train net.repeated string test_net = 2; // Proto filenames for the test nets.optional NetParameter train_net_param = 21; // Inline train net params.repeated NetParameter test_net_param = 22; // Inline test net params.
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
这是最开始的部分,需要说明net文件的位置。在这四个train_net_param, train_net, net_param, net字段中至少需要出现一个,当出现多个时,就会按着(1) test_net_param, (2) test_net, (3) net_param/net 的顺序依次求解。必须为每个test_net指定一个test_iter。还可以为每个test_net指定test_level和/或test_stage。注意的是:文件的路径要从caffe的根目录开始,其它的所有配置都是这样。
可以看到这几行的标签序号并不是顺序的,也说明caffe在不断地修改,下一个可用的序号是41。
网络状态
// The states for the train/test nets. Must be unspecified or// specified once per net.//// By default, all states will have solver = true;// train_state will have phase = TRAIN,// and all test_state's will have phase = TEST.// Other defaults are set according to the NetState defaults.optional NetState train_state = 26;repeated NetState test_state = 27;
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
网络状态必须是未指定的或者只能在一个网络中指定一次。
关于NetState,其定义为:
message NetState {optional Phase phase = 1 [default = TEST];optional int32 level = 2 [default = 0];repeated string stage = 3;
}
enum Phase {TRAIN = 0;TEST = 1;
}
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
迭代器
// The number of iterations for each test net.repeated int32 test_iter = 3;
- 1
- 2
首先是test_iter
,这需要与test layer中的batch_size结合起来理解。mnist数据中测试样本总数为10000,一次性执行全部数据效率很低,因此我们将测试数据分成几个批次来执行,每个批次的数量就是batch_size。假设我们设置batch_size为100,则需要迭代100次才能将10000个数据全部执行完。因此test_iter设置为100。执行完一次全部数据,称之为一个epoch。
// The number of iterations between two testing phases.optional int32 test_interval = 4 [default = 0];optional bool test_compute_loss = 19 [default = false];// If true, run an initial test pass before the first iteration,// ensuring memory availability and printing the starting value of the loss.optional bool test_initialization = 32 [default = true];
- 1
- 2
- 3
- 4
- 5
- 6
test_interval
是指测试间隔,每训练test_interval次,进行一次测试。同时test_compute_loss
可以选择是否计算loss。test_initialization
是指在第一次迭代前,计算初始的loss以确保内存可用。
optional float base_lr = 5; // The base learning rate// the number of iterations between displaying info. If display = 0, no info// will be displayed.optional int32 display = 6;// Display the loss averaged over the last average_loss iterationsoptional int32 average_loss = 33 [default = 1];optional int32 max_iter = 7; // the maximum number of iterations// accumulate gradients over `iter_size` x `batch_size` instancesoptional int32 iter_size = 36 [default = 1];
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
base_lr
指基础的学习率;display
是信息显示间隔,迭代一定次数显示一次信息。average_loss
用于显示在上次average_loss迭代中的平均损失。max_iter
是最大迭代次数,需要合适设置达到精度、震荡的平衡。iter_size
是迭代器大小,梯度的计算是通过iter_size
x batch_size
决定的。
学习策略
optional string lr_policy = 8;optional float gamma = 9; // The parameter to compute the learning rate.optional float power = 10; // The parameter to compute the learning rate.optional float momentum = 11; // The momentum value.optional float weight_decay = 12; // The weight decay.// regularization types supported: L1 and L2// controlled by weight_decayoptional string regularization_type = 29 [default = "L2"];// the stepsize for learning rate policy "step"optional int32 stepsize = 13;// the stepsize for learning rate policy "multistep"repeated int32 stepvalue = 34;
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
只要是梯度下降法来求解优化,都会有一个学习率,也叫步长。base_lr用于设置基础学习率,在迭代的过程中,可以对基础学习率进行调整。怎么样进行调整,就是调整的策略,由lr_policy来设置。caffe提供了多种policy:
- fixed: 总是返回base_lr(学习率不变)
- step: 返回 base_lr * gamma ^ (floor(iter / step))
还需要设置stepsize参数以确定step,iter表示当前迭代次数。 - exp: 返回base_lr * gamma ^ iter, iter为当前迭代次数
- inv: 如果设置为inv,还需要设置一个power, 返回base_lr * (1 + gamma * iter) ^ (- power)
- multistep: 如果设置为multistep,则还需要设置一个stepvalue。这个参数和step很相似,step是均匀等间隔变化,而multistep则是根据stepvalue值变化。
- poly: 学习率进行多项式误差, 返回 base_lr (1 - iter/max_iter) ^ (power)
- sigmoid: 学习率进行sigmod衰减,返回 base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))。
multistep示例:
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "multistep"
gamma: 0.9
stepvalue: 5000
stepvalue: 7000
stepvalue: 8000
stepvalue: 9000
stepvalue: 9500
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
之后有momentum
,上次梯度更新的权重;weight_decay
权重衰减,防止过拟合;regularization_type
正则化方式。
clip_gradients
optional float clip_gradients = 35 [default = -1];
- 1
参数梯度的实际L2范数较大时,将clip_gradients设置为> = 0,以将参数梯度剪切到该L2范数。具体作用还不是很理解。
snapshot快照
optional int32 snapshot = 14 [default = 0]; // The snapshot intervaloptional string snapshot_prefix = 15; // The prefix for the snapshot.// whether to snapshot diff in the results or not. Snapshotting diff will help// debugging but the final protocol buffer size will be much larger.optional bool snapshot_diff = 16 [default = false];enum SnapshotFormat {HDF5 = 0;BINARYPROTO = 1;}optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
快照可以将训练出来的model和solver状态进行保存,snapshot
用于设置训练多少次后进行保存,默认为0,不保存。snapshot_prefix
设置保存路径。还可以设置snapshot_diff
,是否保存梯度值,保存有利于调试,但需要较大空间存储,默认为false,不保存。也可以设置snapshot_format
,保存的类型。有两种选择:HDF5 和BINARYPROTO ,默认为BINARYPROTO。
运行模式
enum SolverMode {CPU = 0;GPU = 1;}optional SolverMode solver_mode = 17 [default = GPU];// the device_id will that be used in GPU mode. Use device_id = 0 in default.optional int32 device_id = 18 [default = 0];// If non-negative, the seed with which the Solver will initialize the Caffe// random number generator -- useful for reproducible results. Otherwise,// (and by default) initialize using a seed derived from the system clock.optional int64 random_seed = 20 [default = -1];
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
设置CPU或GPU模式,在GPU下还可以指定使用哪一块GPU运行。random_seed
用于初始生成随机数种子。
Solver类型
// type of the solveroptional string type = 40 [default = "SGD"];// numerical stability for RMSProp, AdaGrad and AdaDelta and Adamoptional float delta = 31 [default = 1e-8];// parameters for the Adam solveroptional float momentum2 = 39 [default = 0.999];// RMSProp decay value// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)optional float rms_decay = 38;
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
type
是solver的类型,目前有SGD、NESTEROV、ADAGRAD、RMSPROP、ADADELTA、ADAM = 5这六类。之后的一些是这些类型的特有参数,根据需要设置。
杂项
// If true, print information about the state of the net that may help with// debugging learning problems.optional bool debug_info = 23 [default = false];// If false, don't save a snapshot after training finishes.optional bool snapshot_after_train = 28 [default = true];
- 1
- 2
- 3
- 4
- 5
- 6
debug_info
用于输出调试信息。snapshot_after_train
用于训练后是否输出快照。
Solver 配置详解相关推荐
- elasticsearch-.yml(中文配置详解)
此elasticsearch-.yml配置文件,是在$ES_HOME/config/下 elasticsearch-.yml(中文配置详解) # ======================== El ...
- (ASA) Cisco Web ××× 配置详解 [三部曲之一]
(ASA) Cisco Web ××× 配置详解 [三部曲之一] 注意:本文仅对Web×××特性和配置作介绍,不包含SSL ×××配置,SSL ×××配置将在本版的后续文章中进行介绍. 首先,先来 ...
- mybatis 同名方法_MyBatis(四):xml配置详解
目录 1.我们将 数据库的配置语句写在 db.properties 文件中 2.在 mybatis-configuration.xml 中加载db.properties文件并读取 通过源码我们可以分析 ...
- logback节点配置详解
logback节点配置详解 一:根节点 <configuration></configuration> 属性 : debug : 默认为false ,设置为true时,将打印出 ...
- PM配置详解之一:企业结构
1.维护计划工厂 功能说明 在公司结构中定义维护工厂(通常已经作为后勤工厂存在)和维护计划工厂(简称计划工厂). 维护工厂:设备所安装的位置,如某机组安装在合营公司,那么合营公司就是此机组的维护工厂, ...
- 转 Log4j.properties配置详解
一.Log4j简介 Log4j有三个主要的组件:Loggers(记录器),Appenders (输出源)和Layouts(布局).这里可简单理解为日志类别,日志要输出的地方和日志以何种形式输出.综合使 ...
- Iptables防火墙配置详解
iptables防火墙配置详解 iptables简介 iptables是基于内核的防火墙,功能非常强大,iptables内置了filter,nat和mangle三张表. (1)filter表负责过滤数 ...
- spring之旅第四篇-注解配置详解
spring之旅第四篇-注解配置详解 一.引言 最近因为找工作,导致很长时间没有更新,找工作的时候你会明白浪费的时间后面都是要还的,现在的每一点努力,将来也会给你回报的,但行好事,莫问前程!努力总不会 ...
- php-fpm 启动参数及重要配置详解
2019独角兽企业重金招聘Python工程师标准>>> php-fpm 启动参数及重要配置详解 约定几个目录 /usr/local/php/sbin/php-fpm /usr/loc ...
最新文章
- HTTP.sys 远程执行代码验证工具
- 【数据结构与算法】2.深度优先搜索DFS、广度优先搜索BFS
- c语言裂变,干货:社群是如何实现裂变的?
- 通过单步调试的方式学习 Angular 中 TView 和 LView 的概念
- Tournament CodeForces - 27B(dfs)
- redis session 超时时间_Shiro性能优化:解决Session频繁读写问题
- 实体验证---测试代码
- 在iOS平台使用libcurl
- 【统计分析】4 空间点数据分析与ArcGIS
- 如何开发出一款仿映客直播APP项目实践篇 -【原理篇】
- QT 播放器之界面布局
- K8S还没用,又出个K9S,什么鬼?
- CRNN原理详解、代码实现及BUG分析
- yourenduwanglai的鬼话连篇(九)
- ACM比赛一些需要注意的极端情况
- NeuroImage:慢性疼痛病人功能脑社区变化的网络结构
- python画小动物_三分钟识别所有小动物!
- 充电电流用软件测试准吗,充电设备 篇一:一次不严谨的测试,但估计iPhone用户看了都会买...
- R语言入门(17)-读写excel文件
- 深度置信网 DBNs
热门文章
- win7与linux切换,Windows 7停更后不想用Win10?教你直接换上Linux再战
- 万年历(c语言)编程,C语言实现的万年历
- Python项目-Day26-数据加密-hash加盐加密-token-jwt
- @TOM VIP邮箱,打造商务办公新场景,定位职场人的贴心助手!
- ArcGIS三维分析之ArcGlobe简要说明
- VMware ESXI上开虚机玩KVM
- 常见的百度云搜索引擎入口合集
- 一维到三维的推广(1D and 3D generalizations of models)
- 网络资产扫描工具 -- Goby
- linux环境下mysql主从数据库配置(maser-slave-replication)