前言

这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。

参数初始化

参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。 

所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:

def weight_init(m):
# 使用isinstance来判断m属于什么类型if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播m.weight.data.fill_(1)m.bias.data.zero_()

Finetune

往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。

局部微调

有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

全局微调

有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:

ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,model.parameters())optimizer = torch.optim.SGD([{'params': base_params},{'params': model.fc.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)

其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。

PyTorch参数初始化和Finetune相关推荐

  1. PyTorch模型读写、参数初始化、Finetune

    使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...

  2. Pytorch基础知识整理(六)参数初始化

    参数初始化的目的是限定网络权重参数的初始分布,试图让权重参数更接近参数空间的最优解,从而加速训练.pytorch中网络默认初始化参数为随机均匀分布,设定额外的参数初始化并非总能加速训练. 1,模板 在 ...

  3. pytorch中的参数初始化方法

    参数初始化(Weight Initialization) PyTorch 中参数的默认初始化在各个层的 reset_parameters() 方法中.例如:nn.Linear 和 nn.Conv2D, ...

  4. PyTorch学习:参数初始化

    Sequential 模型的参数初始化 import numpy as np import torch from torch import nn# 定义一个 Sequential 模型 net1 = ...

  5. pytorch tensor 初始化_Pytorch - nn.init 参数初始化方法

    Pytorch 的参数初始化 - 给定非线性函数的推荐增益值(gain value):nonlinearity 非线性函数gain 增益 Linear / Identity1 Conv{1,2,3}D ...

  6. Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化

    Pytorch 学习(6):Pytorch中的torch.nn  Convolution Layers  卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...

  7. 网络优化(三)——参数初始化

    文章目录 1. 基于固定方差的参数初始化 2. 基于方差缩放的参数初始化 2.1 Xavier 初始化 2.2 Kaiming初始化 3. 正交初始化 神经网络的参数学习是一个非凸优化问题.当使用梯度 ...

  8. 深度学习参数初始化(二)Kaiming初始化 含代码

    目录 一.介绍 二.基础知识 三.Kaiming初始化的假设条件 四.Kaiming初始化的简单的公式推导 1.前向传播 2.反向传播 五.Pytorch实现 深度学习参数初始化系列: (一)Xavi ...

  9. 派生类参数初始化列表和基类构造函数顺序

    今天被问到了一个问题,随便回了一句,父类还没有构建,怎么能初始化父类的成员. 派生类构造函数的参数初始化列表,为什么不能初始化基类的成员? 例如下面的是不可以的 class Rectangle : p ...

最新文章

  1. 视频+PPT | 企业服务进阶第一课:客户全生命周期运营总览
  2. Apache Httpd + Subversion 搭建HTTP访问的SVN服务器
  3. Linux 基础知识系列第一篇
  4. Struts2学习(二):第一个Action
  5. 从0到1打造一款react-native App(二)Navigation+Redux
  6. SQL连接查询_ INNER JOIN
  7. 1811114每日一句
  8. go php 框架,go框架 - Go语言中文网 - Golang中文社区
  9. 如何实现向APP推送消息
  10. vue项目如何部署?history与hash模式部署时的区别
  11. (转载)消息队列详解
  12. android 实现3d扫描,DIY:让Android手机轻松变3D扫描仪
  13. 天猫order前后台
  14. 淘宝商品详情,1688商品详情滑块的解决方法和接口
  15. MNIST手写数字识别之MLP实现
  16. python项目之杠子老虎鸡虫
  17. 博弈论与信息经济学-重复博弈
  18. Loadrunner12.55windows-linux-os安装详细教程
  19. IoT 开发,我们需要学习哪些内容?
  20. 本地缓存-loadingCache

热门文章

  1. OO第三单元单元总结
  2. CodeForces - 1076E Vasya and a Tree 树剖?nono dfs+树状数组
  3. Git初学者:msysgit和tortoisegit
  4. Loadrunner pacing与thinktime
  5. 运动状态最佳心率计算器 (Target-Heart-Rate Calculator)
  6. 【化学信息学】药物的分子结构
  7. SpringMVC之Ambiguous mapping(模棱两可的映射)
  8. 失败后总结的三个坑!
  9. Android之mp3播放器开发过程
  10. 第三届中兴捧月程序设计大赛 西大ATeam作品 望大家投票支持