权重初始化对于训练神经网络至关重要,好的初始化权重可以有效的避免梯度消失等问题的发生。

在pytorch的使用过程中有几种权重初始化的方法供大家参考。

注意:第一种方法不推荐。尽量使用后两种方法。

# not recommend

def weights_init(m):

classname = m.__class__.__name__

if classname.find('Conv') != -1:

m.weight.data.normal_(0.0, 0.02)

elif classname.find('BatchNorm') != -1:

m.weight.data.normal_(1.0, 0.02)

m.bias.data.fill_(0)

# recommend

def initialize_weights(m):

if isinstance(m, nn.Conv2d):

m.weight.data.normal_(0, 0.02)

m.bias.data.zero_()

elif isinstance(m, nn.Linear):

m.weight.data.normal_(0, 0.02)

m.bias.data.zero_()

# recommend

def weights_init(m):

if isinstance(m, nn.Conv2d):

nn.init.xavier_normal_(m.weight.data)

nn.init.xavier_normal_(m.bias.data)

elif isinstance(m, nn.BatchNorm2d):

nn.init.constant_(m.weight,1)

nn.init.constant_(m.bias, 0)

elif isinstance(m, nn.BatchNorm1d):

nn.init.constant_(m.weight,1)

nn.init.constant_(m.bias, 0)

编写好weights_init函数后,可以使用模型的apply方法对模型进行权重初始化。

net = Residual() # generate an instance network from the Net class

net.apply(weights_init) # apply weight init

补充知识:Pytorch权值初始化及参数分组

1. 模型参数初始化

# ————————————————— 利用model.apply(weights_init)实现初始化

def weights_init(m):

classname = m.__class__.__name__

if classname.find('Conv') != -1:

n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels

m.weight.data.normal_(0, math.sqrt(2. / n))

if m.bias is not None:

m.bias.data.zero_()

elif classname.find('BatchNorm') != -1:

m.weight.data.fill_(1)

m.bias.data.zero_()

elif classname.find('Linear') != -1:

n = m.weight.size(1)

m.weight.data.normal_(0, 0.01)

m.bias.data = torch.ones(m.bias.data.size())

# ————————————————— 直接放在__init__构造函数中实现初始化

for m in self.modules():

if isinstance(m, nn.Conv2d):

n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels

m.weight.data.normal_(0, math.sqrt(2. / n))

if m.bias is not None:

m.bias.data.zero_()

elif isinstance(m, nn.BatchNorm2d):

m.weight.data.fill_(1)

m.bias.data.zero_()

elif isinstance(m, nn.BatchNorm1d):

m.weight.data.fill_(1)

m.bias.data.zero_()

elif isinstance(m, nn.Linear):

nn.init.xavier_uniform_(m.weight.data)

if m.bias is not None:

m.bias.data.zero_()

# —————————————————

self.weight = Parameter(torch.Tensor(out_features, in_features))

self.bias = Parameter(torch.FloatTensor(out_features))

nn.init.xavier_uniform_(self.weight)

nn.init.zero_(self.bias)

nn.init.constant_(m, initm)

# nn.init.kaiming_uniform_()

# self.weight.data.normal_(std=0.001)

2. 模型参数分组weight_decay

def separate_bn_prelu_params(model, ignored_params=[]):

bn_prelu_params = []

for m in model.modules():

if isinstance(m, nn.BatchNorm2d):

ignored_params += list(map(id, m.parameters()))

bn_prelu_params += m.parameters()

if isinstance(m, nn.BatchNorm1d):

ignored_params += list(map(id, m.parameters()))

bn_prelu_params += m.parameters()

elif isinstance(m, nn.PReLU):

ignored_params += list(map(id, m.parameters()))

bn_prelu_params += m.parameters()

base_params = list(filter(lambda p: id(p) not in ignored_params, model.parameters()))

return base_params, bn_prelu_params, ignored_params

OPTIMIZER = optim.SGD([

{'params': base_params, 'weight_decay': WEIGHT_DECAY},

{'params': fc_head_param, 'weight_decay': WEIGHT_DECAY * 10},

{'params': bn_prelu_params, 'weight_decay': 0.0}

], lr=LR, momentum=MOMENTUM ) # , nesterov=True

Note 1:PReLU(x) = max(0,x) + a * min(0,x). Here a is a learnable parameter. When called without arguments, nn.PReLU() uses a single parameter a across all input channels. If called with nn.PReLU(nChannels), a separate a is used for each input channel.

Note 2: weight decay should not be used when learning a for good performance.

Note 3: The default number of a to learn is 1, the default initial value of a is 0.25.

3. 参数分组weight_decay–其他

第2节中的内容可以满足一般的参数分组需求,此部分可以满足更个性化的分组需求。参考:face_evoLVe_Pytorch-master

自定义schedule

def schedule_lr(optimizer):

for params in optimizer.param_groups:

params['lr'] /= 10.

print(optimizer)

方法一:利用model.modules()和obj.__class__ (更普适)

# model.modules()和model.children()的区别:model.modules()会迭代地遍历模型的所有子层,而model.children()只会遍历模型下的一层

# 下面的关键词if 'model',源于模型定义文件。如model_resnet.py中自定义的所有nn.Module子类,都会前缀'model_resnet',所以可通过这种方式一次性筛选出自定义的模块

def separate_irse_bn_paras(model):

paras_only_bn = []

paras_no_bn = []

for layer in model.modules():

if 'model' in str(layer.__class__): # eg. a=[1,2] type(a): a.__class__:

continue

if 'container' in str(layer.__class__): # 去掉Sequential型的模块

continue

else:

if 'batchnorm' in str(layer.__class__):

paras_only_bn.extend([*layer.parameters()])

else:

paras_no_bn.extend([*layer.parameters()]) # extend()用于在列表末尾一次性追加另一个序列中的多个值(用新列表扩展原来的列表)

return paras_only_bn, paras_no_bn

方法二:调用modules.parameters和named_parameters()

但是本质上,parameters()是根据named_parameters()获取,named_parameters()是根据modules()获取。使用此方法的前提是,须按下文1,2中的方式定义模型,或者利用Sequential+OrderedDict定义模型。

def separate_resnet_bn_paras(model):

all_parameters = model.parameters()

paras_only_bn = []

for pname, p in model.named_parameters():

if pname.find('bn') >= 0:

paras_only_bn.append(p)

paras_only_bn_id = list(map(id, paras_only_bn))

paras_no_bn = list(filter(lambda p: id(p) not in paras_only_bn_id, all_parameters))

return paras_only_bn, paras_no_bn

两种方法的区别

参数分组的区别,其实对应了模型构造时的区别。举例:

1、构造ResNet的basic block,在__init__()函数中定义了

self.conv1 = conv3x3(inplanes, planes, stride)

self.bn1 = BatchNorm2d(planes)

self.relu = ReLU(inplace = True)

2、在forward()中定义

out = self.conv1(x)

out = self.bn1(out)

out = self.relu(out)

3、对ResNet取model.name_parameters()返回的pname形如:

‘layer1.0.conv1.weight'

‘layer1.0.bn1.weight'

‘layer1.0.bn1.bias'

# layer对应conv2_x, …, conv5_x; '0'对应各layer中的block索引,比如conv2_x有3个block,对应索引为layer1.0, …, layer1.2; 'conv1'就是__init__()中定义的self.conv1

4、若构造model时采用了Sequential(),则model.name_parameters()返回的pname形如:

‘body.3.res_layer.1.weight',此处的1.weight实际对应了BN的weight,无法通过pname.find(‘bn')找到该模块。

self.res_layer = Sequential(

Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),

BatchNorm2d(depth),

ReLU(depth),

Conv2d(depth, depth, (3, 3), stride, 1, bias=False),

BatchNorm2d(depth)

)

5、针对4中的情况,两种解决办法:利用OrderedDict修饰Sequential,或利用方法一

downsample = Sequential( OrderedDict([

(‘conv_ds', conv1x1(self.inplanes, planes * block.expansion, stride)),

(‘bn_ds', BatchNorm2d(planes * block.expansion)),

]))

# 如此,相应模块的pname将会带有'conv_ds',‘bn_ds'字样

以上这篇pytorch 网络参数 weight bias 初始化详解就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

初始化模型参数 python_pytorch 网络参数 weight bias 初始化详解_python_脚本之家相关推荐

  1. [Pytorch系列-61]:循环神经网络 - 中文新闻文本分类详解-3-CNN网络训练与评估代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  2. 【GCN】图卷积网络(GCN)入门详解

    机器学习算法与自然语言处理出品 @公众号原创专栏作者 Don.hub 单位 | 京东算法工程师 学校 | 帝国理工大学 图卷积网络(GCN)入门详解 什么是GCN GCN 概述 模型定义 数学推导 G ...

  3. 怎样在两个局域网内共享一台打印机 。常用网络命令及命令实例详解

    怎样在两个局域网内共享一台打印机 怎样在两个局域网内共享一台打印机 我们公司有两间办公室,原先布线的时候用一个路由器延伸出多个接口预埋在墙里并做上插头,IP地址是自动分配的,网关是192.168.0. ...

  4. ModelMatrix、ModelViewMatrix、ProjectionMatrix、NormalMatrix模型矩阵、模型视图矩阵、投影矩阵、正规矩阵详解

    ModelMatrix.ModelViewMatrix.ProjectionMatrix.ModelMatrix模型矩阵.模型视图矩阵.投影矩阵.正规矩阵详解 1. 前言 在openGL经常用到Mod ...

  5. Windows 网络服务架构系列课程详解(一) ----DHCP服务器的搭建与配置

    Windows 网络服务架构系列课程详解(一) ---------DHCP服务器的搭建与配置   实验背景: 企业网络环境中在没有配置DHCP服务器时,经常会遇到这样的情况,用户不懂怎么去配置IP地址 ...

  6. 网络IO和磁盘IO详解

    网络IO和磁盘IO详解 1. 缓存IO 缓存I/O又被称作标准I/O,大多数文件系统的默认I/O操作都是缓存I/O.在Linux的缓存I/O机制中,数据先从磁盘复制到内核空间的缓冲区,然后从内核空间缓 ...

  7. 《Android 网络开发与应用实战详解》——2.3节Android系统架构

    本节书摘来自异步社区<Android 网络开发与应用实战详解>一书中的第2章,第2.3节Android系统架构,作者 王东华,更多章节内容可以访问云栖社区"异步社区"公 ...

  8. [指标应用]乖离率(BIAS)应用详解

    [指标应用]乖离率(BIAS)应用详解   乖离率(BIAS)是描述股价与股价的移动平均线的相距的远近程度.BIAS指的是相对距离. 1.BIAS的计算公式及参数. N日乖离率=(当日收盘价-N日移动 ...

  9. P2P网络节点间如何互访——详解STUN方式NAT穿透

    P2P网络节点间如何互访--详解STUN方式NAT穿透 转载请注明出处:https://www.jzgwind.com/?p=973  by joey 一.背景 P2P网络的核心原理,是将分布在网络上 ...

最新文章

  1. PSVR周年庆开始,大量游戏巨幅促销
  2. linux中统计java数量,linux 统计当前目录下文件数
  3. 语音识别学习日志 2019-7-17 语音识别基础知识准备6 {维特比算法(Viterbi Algorithm)}
  4. PL/SQL第三课(学习笔记)
  5. linux下各文件夹的作用
  6. Redis基础(二)——通用命令和配置
  7. async spring 默认线程池_SpringBoot中Async异步方法和定时任务介绍
  8. 锁、threading.local、线程池
  9. 这几道Redis面试题都不懂,怎么拿到阿里后端offer?
  10. Ubuntu系统上使用锐捷客户端有线连接校园网
  11. 一个简单的软件测试流程(附带流程详解)流程图
  12. 西门子界面官方精美触摸屏+WINCC程序模板 西门子官方触摸屏程序模板,炫酷的扁平式动画效果
  13. SQL中NOW() 函数
  14. 360修复高危漏洞可以修复吗_Win7系统360安全卫士提示“进行漏洞的修复”是否该修复?...
  15. 彼时彼刻恰似此时此刻
  16. CSS——web字体与CSS字体图标
  17. 百度网盘中直接双击编辑的PPT关闭后,找不到了。
  18. 三菱ST言语编程梳理
  19. 带你入门多目标跟踪(一)领域概述
  20. 机器人唱歌bgm_Soul app里面机器人匹配的那首bgm是什么呀?好好听!!!求玩过soul的大神告知!!...

热门文章

  1. Qt--在.pro文件中添加链接库的写法
  2. sqlite3数据库最大可以是多大?可以存放多少数据?读写性能怎么样?详述
  3. MFC程序提示 0xC0000005: 读取位置 0x00000020 时发生访问冲突。
  4. 窗口类、窗口类对象与窗口 三者之间关系
  5. C#中的继承与多态还有接口
  6. IOS开发基础之SQLite3数据库的使用增删改查
  7. Java连接Mysql数据库增删改查实现
  8. ubuntu安装 rust nightly_Rust 嵌入式开发环境搭建指南 (一):让世界闪烁吧
  9. 企业中书写css,web前端开发企业级CSS常用命名,书写规范总结(示例代码)
  10. golang sdk后端怎么用_Python比Golang慢多少?实际上两者差异并不大