转载请注明出处:

http://www.cnblogs.com/darkknightzh/p/6221664.html

参考网址:

https://github.com/torch/nn/issues/873

http://stackoverflow.com/questions/37459812/finetune-a-torch-model

https://github.com/torch/nn/blob/master/doc/module.md

https://github.com/torch/torch7/blob/master/doc/utility.md

=====================================================

170928更新(可以微调层):

参考网址:

http://www.thedataware.com/post/the-torch-adventures-setting-layer-wise-training-parameters

https://github.com/NVIDIA/DIGITS/tree/master/examples/fine-tuning

https://github.com/NVIDIA/DIGITS/blob/master/examples/fine-tuning/lenet-fine-tune.lua#L56

https://stackoverflow.com/questions/37459812/finetune-a-torch-model

https://www.zhihu.com/question/44376850

说明:目前就第一个网址的能finetune参数。

深度学习中目前有参数的为:卷积层-conv(weight+bias),batchnorm层:bn(weight+bias),全连接层-linear(weight+bias)。

因而在torch中使用local params, gradParams = model:parameters()的话,默认得到的#params为上面这三种类型的层的数量之和再乘以2。如果对应没有bias,则该层参数参数数量为1。

使用http://www.thedataware.com/post/the-torch-adventures-setting-layer-wise-training-parameters的方法,可以更新某个层。该文章是每个层设置不同的学习率,如果只某些特定的层学习率不为0,其它层学习率均为0(或者先定义fineTuneLayerIdx={10,11,12},而后for i = 1, #params改成for i = 1, #fineTuneLayerIdx来减少计算量),则会只更新这些层的参数。需要注意的是,如果fine tune最后几层还好,可以print(params),来看一下参数,然后计算一下哪些参数是需要更新的,如果更新中间的层。。。只能自己去对应了(特别是如Inception,Resnet这种网络中间层的参数,对应起来更加蛋疼了吧)。

该网址中对每层都设置学习率的代码如下:

local params, gradParams = model:parameters() -- Set the learning rate to 0.01
local learningRates = torch.Tensor(#params):fill(0.01)
-- Set the learning rate of the second layer to 0.001
learningRates[2] = 0.001optimState = {}
for i = 1, #params dotable.insert(optimState, {learningRate = learningRates[i],learningRateDecay = 0.0001,momentum = 0.9,dampening = 0.0,weightDecay = 5e-4})
endfor e = 1, epochs do-- Get MNIST batchX, Y = get_mnist_batch(batch_size)-- forward -> backward (outside of feval)
  model:zeroGradParameters()out = model:forward(X)err = criterion:forward(out, Y)gradOutputs = criterion:backward(out, Y)model:backward(X, gradOutputs)-- layer-wise optimizationfor i = 1, #params dolocal feval = function(x)return err, gradParams[i]end-- run optimizer
    optim.sgd(feval, params[i], optimState[i])endend
-- model trained

View Code

如果使用fineTuneLayerIdx,即只微调部分层,代码如下:

local params, gradParams = model:parameters() -- 需要finetune的参数层(不是网络层。网络层:内部可能还有更小的网络,比如densenet,resnext等;
-- 参数层:正常情况下,一个conv,bn,linear等各有2个参数层,所以参数曾可能比网络成多很多)
local fineTuneLayerIdx = {30,34,35} -- Set the learning rate to 0.01
local learningRates = torch.Tensor(#fineTuneLayerIdx):fill(0.01)
-- Set the learning rate of the second layer to 0.001
learningRates[2] = 0.001optimState = {}
for i = 1, #fineTuneLayerIdx dotable.insert(optimState, {learningRate = learningRates[i],learningRateDecay = 0.0001,momentum = 0.9,dampening = 0.0,weightDecay = 5e-4})
endfor e = 1, epochs do-- Get MNIST batchX, Y = get_mnist_batch(batch_size)-- forward -> backward (outside of feval)
  model:zeroGradParameters()out = model:forward(X)err = criterion:forward(out, Y)gradOutputs = criterion:backward(out, Y)model:backward(X, gradOutputs)-- layer-wise optimizationfor i = 1, #fineTuneLayerIdx dolocal feval = function(x)return err, gradParams[fineTuneLayerIdx[i]]end-- run optimizer
    optim.sgd(feval, params[fineTuneLayerIdx[i]], optimState[i])endend
-- model trained

View Code

需要注意的是,如果使用model:parameters(),需要optimState为多个table,不能为下面这样简单的一个table:

optimState = { -- 使用model:parameters()时,使用这种optimState有问题learningRate = learningRates, learningRateDecay = 0.0001, momentum = 0.9, dampening = 0.0, weightDecay = 5e-4 }

否则在第二次运行到optim.sgd(feval, params[fineTuneLayerIdx[i]], optimState[i])时,可能会提示维度不一样。

另外,https://www.zhihu.com/question/44376850中“知乎用户”的回答也和这个类似,只不过不知道那个网址中的和这个网址中的谁先谁后吧。

如果使用https://stackoverflow.com/questions/37459812/finetune-a-torch-model中的方法,即:

for i=1, x doc = model:get(i)c.updateGradInput = function(self, inp, out) endc.accGradParameters = function(self,inp, out) end
end

我这边有conv、bn,linear这三种层,会提示下面bn层的错误,不清楚是我这边程序的问题,还是怎么回事。

如果使用https://github.com/NVIDIA/DIGITS/blob/master/examples/fine-tuning/lenet-fine-tune.lua#L56这种方法,其实和上面的类似,只不过没有设置每层的updateGradInput这个。只设置一个的话,同样的输入,每次输出不一样(我把所有的conv,bn,linear都设置了= function(self, inp, out) end,为了看一下输出是否一致。理论上如果这些层参数都不更新,同样的输入,最终的输出应该相同),即感觉没能fine tune特定的层。

170928更新结束

=====================================================

161229更新:

感谢@linzhineng 。

即便按照本文这样设置,实际上在微调时,其它层的参数还是会变化。现在凌乱了,不清楚如何微调了/(ㄒoㄒ)/~~

难道只能手动修改更新过程吗?

161229更新结束:

=====================================================

由于torch每个模块均有train参数,当其为true时进行训练,当期为false时进行测试。因而,如果要对训练好的模型进行微调,如只对某模块调整参数,其他模块参数固定,则可以使用第一个参考网址中soumith的方法(该方法固定某模块,和本文目的是反的):

model:training()
model:apply(function(m) if torch.type(m):find("BatchNormalization") then m:evaluate() end end)

说明:一般来说,在训练时,需要设置model:training(),在测试时,需要设置model:evaluate()。因而微调参数时,上面代码加在训练代码中model:training()后面就可以了(需要适当的修改)。

第四个网址给出了[string] torch.type(object)。因而,对上面的代码修改如下:如果要达到微调某一模块参数(如全连接层Linear),只需要使用:

   model:evaluate()model:apply(function(m) if torch.type(m):find('Linear')  thenm:training()end
   end)

说明:上面代码测试后成功。但是遇到了一个很诡异的问题。如果第一行改为model:training(),在找到对应的层后,改为m: evaluate (),没有成功(对应的torch.type(m):find('Linear')==nil),所以才使用了上面的代码。还有一点,如果判断torch.type(m):find('Linear')==nil,最后没有成功改了m的train变量的值,具体不太清楚,最终使用了上面给出的代码。

上面torch.type(m)会返回模块的名字,如:

nn.Sequential
nn.SpatialConvolution
nn.SpatialBatchNormalization
nn.ReLU
nn.SpatialMaxPooling
nn.SpatialConvolution
nn.SpatialBatchNormalization
nn.ReLU

上面torch.type(m):find("BatchNormalization"),如果在某层找到了BatchNormalization,则返回找到的起始和结束位置,否则返回nil。

还有,微调时,一般都只微调某一层,但是torch中很多层名字相同,如果要改特定的一层,如conv层,还要继续修改代码,判断是否是需要的那个conv层,否则会将所有的conv层参数都修改。

注意:如果网络定义使用了Inception层,此处不光返回Inception,还会返回Inception里面各个层(如nn.Sequential,nn.InceptionHisign,nn.DepthConcat等)。

在torch/install/share/lua/5.1/nn/Module.lua中,有如下代码:

function Module:training()self.train = true
endfunction Module:evaluate()self.train = false
end

直觉上,torch中这种方式不如caffe的fine tuning时,设置对应层lr_mult=0容易。

第三个网址有对apply,training,evaluate的较详细的说明。

此外,第二个网址通过updateGradInput和accGradParameters来达到固定某层参数的效果,不过没有试过。

(原)torch中微调某层参数相关推荐

  1. Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化

    Pytorch 学习(6):Pytorch中的torch.nn  Convolution Layers  卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...

  2. 【增强学习】Torch中的增强学习层

    要想在Torch框架下解决计算机视觉中的增强学习问题(例如Visual Attention),可以使用Nicholas Leonard提供的dpnn包.这个包对Torch中原有nn包进行了强大的扩展, ...

  3. Pytorch中torch.nn.Softmax的dim参数含义

    自己搞了一晚上终于搞明白了,下文说的很透彻,做个记录,方便以后翻阅 Pytorch中torch.nn.Softmax的dim参数含义

  4. NLP-预训练模型-2019:ALBert【 轻Bert;使用 “输入层向量矩阵分解”、“跨层参数共享” 减少参数量;使用SOP代替NSP】【较Bert而言缩短训练及推理时间】

    预训练模型(Pretrained model):一般情况下预训练模型都是大型模型,具备复杂的网络结构,众多的参数量,以及在足够大的数据集下进行训练而产生的模型. 在NLP领域,预训练模型往往是语言模型 ...

  5. Caffe常用层参数介绍

    DATA crop:截取原图像中一个固定patch layers {name: "data"type: DATAtop: "data"top: "la ...

  6. pytorch:固定部分层参数,固定单个模型

    文章目录 固定部分层参数 固定指定层的参数 不同层设置不同的学习率 固定部分层参数 class RESNET_attention(nn.Module):def __init__(self, model ...

  7. 卷积神经网络(CNN)中,卷积层、激活函数、池化层、全链接层术语解析

    本文内容转自https://www.cnblogs.com/zf-blog/p/6075286.html和https://www.cnblogs.com/rgvb178/p/6055213.html ...

  8. pytorch---之BN层参数详解及应用(1,2,3)(1,2)?

    BN层参数详解(1,2) 一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层(对 ...

  9. pytorch 批量归一化BatchNorm1d和BatchNorm2d的用法、BN层参数 running_mean running_var变量计算 验证

    前提知识 BN层包括mean var gamma beta四个参数,.对于图像来说(4,3,2,2),一组特征图,一个通道的特征图对应一组参数,即四个参数均为维度为通道数的一维向量,图中gamma.b ...

最新文章

  1. 【IM】关于条件随机场CRF的理解
  2. Git初始配置【一】
  3. ansible高级应用示例
  4. 解析JVM线程同步机制
  5. java模糊查询比对方法_Java多条件和模糊查询
  6. matlab trendsurface,MATLAB 添加新的预测性维护产品
  7. git-管理修改-强化暂存区的意识
  8. 高德地图 Android API 的基站定位原理及使用方法
  9. coupled/decoupled
  10. 王思聪名下企业被拍卖1100万债权,此前还债20亿 网友:拍下等于“接盘侠”?...
  11. 支持drupal的空间
  12. JSF通过EL读取List中的值
  13. Nginx资源合并优化模块nginx-http-concat
  14. [2018.08.07 T1] 签到?
  15. 蓝桥杯练习系统特殊回文数(python)
  16. npm学习(十七)之node_modules中的bin文件夹
  17. 【前端基础面试题】如何用CSS画一个三角形(详解)
  18. 华为麦芒5刷机_TWRP_Magisk(Root)_Xposed流程
  19. 卷尺精度标准_卷尺检验技术标准
  20. KENALLRYLLDKDD|359821-54-8

热门文章

  1. linux超实用的管理命令
  2. jQuery学习笔记--目录
  3. 安装vmware 6.52 Red Hat Enterprise Linux 5(rhel-5.1-server-i386-dvd) openldap2.4
  4. pexpect oracle,expect免交互脚本编程
  5. 人声处理_10款免费的人声处理工具
  6. springboot项目中pom.xml文件的颜色变成灰色,图标变成蜘蛛图形
  7. php printf 0.2f,php printf()
  8. linux mysql 5.6.24_Mysql实例Linux安装MySQL5.6.24使用文字说明
  9. 幻读(phantom read)
  10. 核心组件:IRule