pytorch——weights_init(m)
本文参考连接:https://www.jianshu.com/p/c982d55db463
个人认为是讲的比较通俗易懂的一篇好文。
针对于不同层类型定制化初始化
举个栗子:
def weights_init(m): ##定义参数初始化函数 classname = m.__class__.__name__ # m作为一个形参,原则上可以传递很多的内容,为了实现多实参传递,每一个moudle要给出自己的name. 所以这句话就是返回m的名字。具体例子下边会详细说明。if classname.find('Conv') != -1:#find()函数,实现查找classname中是否含有conv字符,没有返回-1;有返回0.nn.init.normal_(m.weight.data, 0.0, 0.02)#m.weight.data表示需要初始化的权重。 nn.init.normal_()表示随机初始化采用正态分布,均值为0,标准差为0.02.elif classname.find('BatchNorm') != -1: nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0) # nn.init.constant_()表示将偏差定义为常量0 netG.apply(weights_init)#netG是我们给写的神经网络定义的类实例。apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。也就是说apply函数,会一层一层的去拜访Generator网络层。
具体例子见下:
class Generator(nn.Module): # 创建Generator子类,括号内指定父类的名称def __init__(self): # 初始化父类的属性super(Generator, self).__init__() # 将父类和子类关联,调用父类nn.Moudle的方法__init__(),让Generator实例包含父类的所有属self.main = nn.Sequential( # 按照顺序构造神经层,序列容器nn.ConvTranspose2d(5, 10, 4, 1, 0, bias=False), # 转置卷积,输出nn.BatchNorm2d(10) # 对每个特征图上的点,进行减均值除标准差的操作,affine设置为true(默认),引入权重w和b两个可学习的参数)netG = Generator() # 定义类实例
在这个例子中:
1、第一个代码中的classname=ConvTranspose2d,classname=BatchNorm2d。2、第一个代码中的netG.apply(weights_init),会按顺序:先看看nn.ConvTranspose2d这一层,需不需要初始化,然后再看看nn.BatchNorm2d这一层需不需要初始化。
将以上两个例子放在一块,得到如下的代码。个人认为可以这样来理解。对于有些地方还不是很明白,欢迎各位大神指点!!!!!
class Generator(nn.Module): # 创建Generator子类,括号内指定父类的名称def __init__(self): # 初始化父类的属性super(Generator, self).__init__() # 将父类和子类关联,调用父类nn.Moudle的方法__init__(),让Generator实例包含父类的所有属self.main = nn.Sequential( # 按照顺序构造神经层,序列容器nn.ConvTranspose2d(5, 10, 4, 1, 0, bias=False), # 转置卷积,输出nn.BatchNorm2d(10) # 对每个特征图上的点,进行减均值除标准差的操作,affine设置为true(默认),引入权重w和b两个可学习的参数)netG = Generator() # 定义类实例def weights_init(m): ##定义参数初始化函数 classname = m.__class__.__name__ # m作为一个形参,原则上可以传递很多的内容,为了实现多实参传递,每一个moudle要给出自己的name. 所以这句话就是返回m的名字。具体例子下边会详细说明。if classname.find('Conv') != -1:#find()函数,实现查找classname中是否含有conv字符,没有返回-1;有返回0.nn.init.normal_(m.weight.data, 0.0, 0.02)#m.weight.data表示需要初始化的权重。 nn.init.normal_()表示随机初始化采用正态分布,均值为0,标准差为0.02.elif classname.find('BatchNorm') != -1: nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0) # nn.init.constant_()表示将偏差定义为常量0 netG.apply(weights_init)#netG是我们给写的神经网络定义的类实例。apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。也就是说apply函数,会一层一层的去拜访Generator网络层。
pytorch——weights_init(m)相关推荐
- 纯手工搭建DCGAN,从零开始
目 录 前言 1.明确搭建流程 2.搭出大概轮廓 2.1. nn.BCELoss(x,y)交叉熵损失函数的公式 2.2. 注意 2.3. 参考 3.调通程序 3.1. 运行结果 3.2. 调试中,需要 ...
- PyTorch图像分类从模型自定义到测试
点击上方"小白学视觉",选择加"星标"或"置顶"重磅干货,第一时间送达 01.什么是 Pytorch 一句话总结 Pytorch = Pyt ...
- Pytorch:使用DCGAN实现数据复制
点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 DCGAN Ian J. Goodfellow首次提出了GAN之后 ...
- CGAN生成cifar10, cifar100, mnist, fashion_mnist,STL10,Anime图片(pytorch)
完整代码:代码地址https://www.lanzouw.com/iVadvo386ofhttps://www.lanzouw.com/iVadvo386of CGAN比DCGAN更进一步,利用标签信 ...
- DCGAN生成cifar10, cifar100, mnist, fashion_mnist,STL10,Anime图片(pytorch)
代码下载地址下载地址https://www.lanzouw.com/ipl8Yo37qxihttps://www.lanzouw.com/ipl8Yo37qxi Anime数据请在Anime Face ...
- 【pytorch速成】Pytorch图像分类从模型自定义到测试
文章首发于微信公众号<与有三学AI> [pytorch速成]Pytorch图像分类从模型自定义到测试 前面已跟大家介绍了Caffe和TensorFlow,链接如下. [caffe速成]ca ...
- celeba数据集_轻松学 Pytorch 使用DCGAN实现数据复制
点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 DCGAN Ian J. Goodfellow首次提出了GAN之后,生成对抗只是神经网络还不是深度卷积神经网络 ...
- pytorch forward_pytorch使用hook打印中间特征图、计算网络算力等
0.参考 https://oldpan.me/archives/pytorch-autograd-hook https://pytorch.org/docs/stable/search.html?q= ...
- pytorch中的参数初始化方法
参数初始化(Weight Initialization) PyTorch 中参数的默认初始化在各个层的 reset_parameters() 方法中.例如:nn.Linear 和 nn.Conv2D, ...
最新文章
- AlexeyAB DarkNet YOLOv3框架解析与应用实践(四)
- EhCache的特性
- 福利 | 2022全球敏捷运维峰会:跟技术老将畅聊时下数据库、运维、金融科技应“云”而生的技术创新...
- 2019年值得关注的人工智能技术的五大趋势
- appcan 微信支付
- 【动态规划】【线段树】 Codeforces Round #426 (Div. 1) B. The Bakery
- 通过rhel7的kvm虚拟机实现3节点Postgres-XL(包括gtm standby)
- 《团队名称》第八次团队作业:Alpha冲刺day1
- 微信小程序踩坑(1):wx.showModal模态对话框中content换行
- 更改eclipse字体
- JSP中退出登录销毁Session
- 边写SQL边学数据库入门实验2(持续更新)
- pdf在线免费去水印 以及图片去水印 方法
- Python pywin32(一)
- HDU 6069 Counting Divisors
- SMD元件尺寸大小公制英制对应说明
- vue路由文件相关配置
- 基于 Paraview 扩展与实现——(2)
- 华为手机热点无法连接_别傻了!不能只会给别人开热点,要尝试华为手机的WiFi分享功能...
- Android自定义权限CVE漏洞分析 (IEEE论文)
热门文章
- java斜体_Java可以指示字体是否为斜体字
- 小小突击队服务器维护多久,《小小突击队》08月06日更新公告
- 5.20 按照邮箱账号的域名进行排序 [原创Excel教程]
- android倒影效果,Android 设置图片倒影效果
- C#接入steam内购
- html table快捷键,超级实用且神奇的表格快捷键
- unity lookat导致物体颠倒怎么解决_在Unity 2D中如何用一行代码实现LookAt的效果,以及向量归一化小总结...
- ubuntu 取消打印队列命令
- JAVA工程师最新面试题(来源于互联网)
- 威纶通触摸屏与温控器进行MODBUS通信并通过宏指令将数据发送给PLC的具体方法