前言

在我们对神经网络模型进行训练时,往往需要对模型进行初始化或者加载预训练模型。本文将对模型的权重初始化与加载预训练模型做一个学习记录,以便后续查询使用。

权重初始化

常见的初始化方法

PyTorch 在 torch.nn.init 中提供了常用的初始化方法函数,这里主要简要介绍Xavier初始化与kaiming初始化。

Xavier初始化

Xavier 初始化方法,方法来源于2010年的一篇论文《Understanding the difficulty of training deep feedforward neural networks》

公式推导是从“方差一致性”出发,初始化的分布有均匀分布和正态分布两种。

Xavier 均匀分布

torch.nn.init.xavier_uniform_(tensor, gain=1.0)

该初始化方法服从均匀分布 U ∼ ( − a , a ) U\sim(-a,a) U∼(−a,a),其中a为:
a = gain ⁡ × 6 fan_in ⁡ + fan_out ⁡ a=\operatorname{gain} \times \sqrt{\frac{6}{\operatorname{fan\_in} +\operatorname{fan\_out}}} a=gain×fan_in+fan_out6​ ​

该初始化方法中有一个参数 gain,增益的大小是依据激活函数类型来设定
eg:

nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain(‘relu’))

PS:上述初始化方法,也称为 Glorot initialization

使用方法示例:

for m in model.modules():if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.xavier_uniform_(m.weight)

Xavier正态分布

torch.nn.init.xavier_normal_(tensor, gain=1.0)

该初始化方法服从正态分布 N ( 0 , s t d 2 ) \mathcal{N}\left(0, \mathrm{std}^{2}\right) N(0,std2):
std ⁡ = gain ⁡ × 2 fan_in ⁡ + fan_out ⁡ \operatorname{std}=\operatorname{gain} \times \sqrt{\frac{2}{\operatorname{fan\_in} +\operatorname{fan\_out}}} std=gain×fan_in+fan_out2​ ​
使用方法示例:

for m in model.modules():if isinstance(m, (nn.Conv2d, nn.Linear)):nn.init.xavier_normal_(m.weight)

kaiming初始化

kaiming初始化,方法来源于2015年的一篇论文《 Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》

公式推导同样从“方差一致性”出法,kaiming是针对xavier初始化方法在relu这一类激活函数表现不佳而提出的改进,详细可以参看论文。

kaiming均匀分布

torch.nn.init.kaiming_uniform_(tensor, a=0, mode=‘fan_in’, nonlinearity=‘leaky_relu’)

该初始化方法服从均匀分布 U ( \mathcal{U}( U( -bound, bound ) ) ):
bound  = gain ⁡ × 3 fan_mode  \text { bound }=\operatorname{gain} \times \sqrt{\frac{3}{\text { fan\_mode }}}  bound =gain× fan_mode 3​ ​
其中,a为激活函数的负半轴的斜率,mode可选为fan_infan_out, fan_in使正向传播时,方差一致; fan_out使反向传播时,方差一致。
nonlinearity 建议选择 reluleaky_relu ,默认值为 leaky_relu

kaiming正态分布

torch.nn.init.kaiming_normal_(tensor, a=0, mode=‘fan_in’, nonlinearity=‘leaky_relu’)

该初始化方法服从正态分布 N ( 0 , s t d 2 ) \mathcal{N}\left(0, \mathrm{std}^{2}\right) N(0,std2):
s t d = gain  fan_mode  \mathrm{std}=\frac{\text { gain }}{\sqrt{\text { fan\_mode }}} std= fan_mode  ​ gain ​
其中,a为激活函数的负半轴的斜率,mode可选为fan_infan_out, fan_in使正向传播时,方差一致; fan_out使反向传播时,方差一致。
nonlinearity 建议选择 reluleaky_relu ,默认值为 leaky_relu

模型权重初始化

# 定义权值初始化
def initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):torch.nn.init.xavier_normal_(m.weight.data)if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()elif isinstance(m, nn.Linear):torch.nn.init.normal_(m.weight.data, 0, 0.01)m.bias.data.zero_()

保存与加载模型

pytorch在保存模型时,可以保存整个神经网络的的结构信息和模型参数信息,save的对象是网络net;也可以只保存神经网络的训练模型参数,save的对象是net.state_dict()

# 保存和加载整个模型
torch.save(model_object, 'model.pth')
model = torch.load('model.pth')  # 仅保存和加载模型参数
torch.save(model_object.state_dict(), 'params.pth')
model_object.load_state_dict(torch.load('params.pth'))

加载预训练模型

# load params, 这里加载的是模型的参数,不是整个模型
pretrained_dict = torch.load('net_params.pkl')
# 仅保存了整个模型, 需要使用以下语句
# pretrained_dict = torch.load('net_params.pkl').state_dict()# 获取当前网络的dict
net_state_dict = net.state_dict()# 剔除不匹配的权值参数
pretrained_dict_1 = {k: v for k, v in pretrained_dict.items() if k in net_state_dict}# 更新新模型参数字典
net_state_dict.update(pretrained_dict_1)# 将包含预训练模型参数的字典"放"到新模型中
net.load_state_dict(net_state_dict)

pytorch预训练模型的简单修改与使用

以resnet预训练模型举例,resnet源代码的pytorch官方实现。 resnet网络最后一层分类层fc是对1000种类型进行划分,如果自己的数据集只有6类,可以只对fc层进行修改:

#调用模型
model = torchvision.models.resnet50(pretrained=True)#提取fc层中固定的参数
fc_features = model.fc.in_features#修改类别
model.fc = nn.Linear(fc_features, 6)

按需设置学习率

# ================================= #
#         按需设置学习率
# ================================= ## 将fc3层的参数从原始网络参数中剔除
ignored_params = list(map(id, net.fc3.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, net.parameters())# 为fc3层设置需要的学习率
optimizer = optim.SGD([{'params': base_params},{'params': net.fc3.parameters(), 'lr': lr_init*10}],  lr_init, momentum=0.9, weight_decay=1e-4)criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)     # 设置学习率下降策略

参考资料

  1. 《Pytorch模型训练实用教程》
  2. 《Understanding the difficulty of training deep feedforward neural
    networks》
  3. 《 Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》
  4. PyTorch Documentation

Pytorch网络模型权重初始化、保存与加载模型、加载预训练模型、按需设置学习率相关推荐

  1. pytorch 之 加载不同形式的预训练模型

    我们在学习pytorch时,不可避免的要加载不同的预训练模型.而且pytorch下的预训练模型有很多种形式,我们又该如何加载呢.今天,我就为大家介绍三种常用的模型形式以及其加载方式. 1.pth形式和 ...

  2. Paddle加载NLP的各类预训练模型方法总结(以文本分类任务为例,包含完整代码)

    一.Introduction 最近宅在家,有空只能搞搞NLP的比赛.由于缺乏GPU的加持,只好白嫖百度的AI Studio(毕竟人家提供免费的Tesla V100).在此不得不赞扬一下优秀的国产深度学 ...

  3. 自然语言处理--gensim.word2vec 模块加载使用谷歌的预训练模型googlenews-vectors-negative300.bin.gz

    词向量将词的语义表示为训练语料库中上下文中的向量,可以把词向量看作是一个权重或分数的列表,列表中的每个权重或分数都对应于这个词在某个特定维度的含义. 很多公司都提供了预训练好的词向量模型,而且有很多针 ...

  4. Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

    文章目录 一.概述 二.代码编写 1. 数据处理 2. 准备配置文件 3. 自定义DataSet和DataLoader 4. 构建模型 5. 训练模型 6. 编写预测模块 三.效果展示 四.源码地址 ...

  5. Pytorch预训练模型加载

    1. 保存模型:torch.save(model.state_dict(), PATH) 加载模型:model.load_state_dict(torch.load(PATH)) model.eval ...

  6. Pytorch:NLP 迁移学习、NLP中的标准数据集、NLP中的常用预训练模型、加载和使用预训练模型、huggingface的transfomers微调脚本文件

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) run_glue.py微调脚本代码 python命令执行run ...

  7. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

  8. 几种常用的权重初始化方法

    来源:投稿 作者:175 编辑:学姐 在深度学习中,权重的初始值非常重要,权重初始化方法甚至关系到模型能否收敛.本文主要介绍两种权重初始化方法. 为什么需要随机初始值 我们知道,神经网络一般在初始化权 ...

  9. ie iframe加载 只有head_mmdetection中加载模型不匹配问题

    最近在学习mmdetection框架,在加载预训练模型时偶尔会遇到如下问题: The model and loaded state dict do not match exactly 正常情况下只是把 ...

最新文章

  1. 网络应用 axIos +vue的应用
  2. 网络:TCP/UDP
  3. 福州java培训哪里好_南通java培训哪家好
  4. 信息如何实现病毒式传播?一文看懂Gossip协议
  5. react 组件名称重复_设计可重复使用的React组件
  6. 关于Tomcat+Nginx负载均衡与Jmeter服务器测压的日记
  7. 【OGG】 RAC环境下管理OGG的高可用 (五)
  8. python数据库管理软件_数据库管理工具神器-DataGrip,可同时管理多个主流数据库[SQL Server,MySQL,Oracle等]连接 - Python社区...
  9. Android开发笔记(七十九)资源与权限校验
  10. Entity Framework 4 in Action读书笔记——第六章:理解实体的生命周期(三)
  11. 决策树(六)--随机森林
  12. 软件项目管理大作业_《软件工程》软件项目管理实验
  13. java入门简单小项目_JAVA入门_java项目接入Mysql8.0
  14. 彻底删除IE的缓存问题
  15. 米发,免费域名转发 301重定向 URL跳转服务
  16. 5个很少被提到但能提高NLP工作效率的Python库
  17. 北京python培训班价格
  18. 困扰成都青年的20年癫痫在三博脑科医院终结
  19. html的表单可以加背景图片,如何装饰表单的背景和字符
  20. Java最新面试题100道,包含答案示例(41-50题)

热门文章

  1. 量子计算(八):观测量和计算基下的测量
  2. JavaScript 开发者应懂的 33 个概念
  3. 分布式存储之CAP理论
  4. 毛哥的快乐生活(29) 大河弯弯
  5. 趣谈计算机网络1 - 通讯协议综述
  6. 一个线程死掉就等于整个进程死掉
  7. 计算机网络课程设计任务书
  8. python制作exe可执行文件的方法---使用pyinstaller
  9. 使用生成器和多线程为Keras训练模型的fit函数提供数据
  10. android studio 出现: Design editor is unavailable until a successful build 解决方法