PyTorch参数初始化和Finetune
前言
这篇文章算是论坛PyTorch Forums关于参数初始化和finetune的总结,也是我在写代码中用的算是“最佳实践”吧。最后希望大家没事多逛逛论坛,有很多高质量的回答。
参数初始化
参数的初始化其实就是对参数赋值。而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了data,grad等借口,这就意味着我们可以直接对这些参数进行操作赋值了。这就是PyTorch简洁高效所在。
所以我们可以进行如下操作进行初始化,当然其实有其他的方法,但是这种方法是PyTorch作者所推崇的:
def weight_init(m):
# 使用isinstance来判断m属于什么类型if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))elif isinstance(m, nn.BatchNorm2d):
# m中的weight,bias其实都是Variable,为了能学习参数以及后向传播m.weight.data.fill_(1)m.bias.data.zero_()
Finetune
往往在加载了预训练模型的参数之后,我们需要finetune模型,可以使用不同的方式finetune。
局部微调
有时候我们加载了训练模型后,只想调节最后的几层,其他层不训练。其实不训练也就意味着不进行梯度计算,PyTorch中提供的requires_grad使得对训练的控制变得非常简单。
model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():param.requires_grad = False
# 替换最后的全连接层, 改为训练100类
# 新构造的模块的参数默认requires_grad为True
model.fc = nn.Linear(512, 100)# 只优化最后的分类层
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)
全局微调
有时候我们需要对全局都进行finetune,只不过我们希望改换过的层和其他层的学习速率不一样,这时候我们可以把其他层和新层在optimizer中单独赋予不同的学习速率。比如:
ignored_params = list(map(id, model.fc.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params,model.parameters())optimizer = torch.optim.SGD([{'params': base_params},{'params': model.fc.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)
其中base_params使用1e-3来训练,model.fc.parameters使用1e-2来训练,momentum是二者共有的。
PyTorch参数初始化和Finetune相关推荐
- PyTorch模型读写、参数初始化、Finetune
使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...
- Pytorch基础知识整理(六)参数初始化
参数初始化的目的是限定网络权重参数的初始分布,试图让权重参数更接近参数空间的最优解,从而加速训练.pytorch中网络默认初始化参数为随机均匀分布,设定额外的参数初始化并非总能加速训练. 1,模板 在 ...
- pytorch中的参数初始化方法
参数初始化(Weight Initialization) PyTorch 中参数的默认初始化在各个层的 reset_parameters() 方法中.例如:nn.Linear 和 nn.Conv2D, ...
- PyTorch学习:参数初始化
Sequential 模型的参数初始化 import numpy as np import torch from torch import nn# 定义一个 Sequential 模型 net1 = ...
- pytorch tensor 初始化_Pytorch - nn.init 参数初始化方法
Pytorch 的参数初始化 - 给定非线性函数的推荐增益值(gain value):nonlinearity 非线性函数gain 增益 Linear / Identity1 Conv{1,2,3}D ...
- Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化
Pytorch 学习(6):Pytorch中的torch.nn Convolution Layers 卷积层参数初始化 class Conv1d(_ConvNd):......def __init ...
- 网络优化(三)——参数初始化
文章目录 1. 基于固定方差的参数初始化 2. 基于方差缩放的参数初始化 2.1 Xavier 初始化 2.2 Kaiming初始化 3. 正交初始化 神经网络的参数学习是一个非凸优化问题.当使用梯度 ...
- 深度学习参数初始化(二)Kaiming初始化 含代码
目录 一.介绍 二.基础知识 三.Kaiming初始化的假设条件 四.Kaiming初始化的简单的公式推导 1.前向传播 2.反向传播 五.Pytorch实现 深度学习参数初始化系列: (一)Xavi ...
- 派生类参数初始化列表和基类构造函数顺序
今天被问到了一个问题,随便回了一句,父类还没有构建,怎么能初始化父类的成员. 派生类构造函数的参数初始化列表,为什么不能初始化基类的成员? 例如下面的是不可以的 class Rectangle : p ...
最新文章
- 视频+PPT | 企业服务进阶第一课:客户全生命周期运营总览
- Apache Httpd + Subversion 搭建HTTP访问的SVN服务器
- Linux 基础知识系列第一篇
- Struts2学习(二):第一个Action
- 从0到1打造一款react-native App(二)Navigation+Redux
- SQL连接查询_ INNER JOIN
- 1811114每日一句
- go php 框架,go框架 - Go语言中文网 - Golang中文社区
- 如何实现向APP推送消息
- vue项目如何部署?history与hash模式部署时的区别
- (转载)消息队列详解
- android 实现3d扫描,DIY:让Android手机轻松变3D扫描仪
- 天猫order前后台
- 淘宝商品详情,1688商品详情滑块的解决方法和接口
- MNIST手写数字识别之MLP实现
- python项目之杠子老虎鸡虫
- 博弈论与信息经济学-重复博弈
- Loadrunner12.55windows-linux-os安装详细教程
- IoT 开发,我们需要学习哪些内容?
- 本地缓存-loadingCache