本文参考连接:https://www.jianshu.com/p/c982d55db463
个人认为是讲的比较通俗易懂的一篇好文。

针对于不同层类型定制化初始化
举个栗子:

def weights_init(m):    ##定义参数初始化函数                  classname = m.__class__.__name__    # m作为一个形参,原则上可以传递很多的内容,为了实现多实参传递,每一个moudle要给出自己的name. 所以这句话就是返回m的名字。具体例子下边会详细说明。if classname.find('Conv') != -1:#find()函数,实现查找classname中是否含有conv字符,没有返回-1;有返回0.nn.init.normal_(m.weight.data, 0.0, 0.02)#m.weight.data表示需要初始化的权重。 nn.init.normal_()表示随机初始化采用正态分布,均值为0,标准差为0.02.elif classname.find('BatchNorm') != -1:           nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0) # nn.init.constant_()表示将偏差定义为常量0 netG.apply(weights_init)#netG是我们给写的神经网络定义的类实例。apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。也就是说apply函数,会一层一层的去拜访Generator网络层。

具体例子见下:

class Generator(nn.Module):  # 创建Generator子类,括号内指定父类的名称def __init__(self):  # 初始化父类的属性super(Generator, self).__init__()  # 将父类和子类关联,调用父类nn.Moudle的方法__init__(),让Generator实例包含父类的所有属self.main = nn.Sequential(  # 按照顺序构造神经层,序列容器nn.ConvTranspose2d(5, 10, 4, 1, 0, bias=False),  # 转置卷积,输出nn.BatchNorm2d(10)  # 对每个特征图上的点,进行减均值除标准差的操作,affine设置为true(默认),引入权重w和b两个可学习的参数)netG = Generator()        # 定义类实例

在这个例子中:
1、第一个代码中的classname=ConvTranspose2d,classname=BatchNorm2d。2、第一个代码中的netG.apply(weights_init),会按顺序:先看看nn.ConvTranspose2d这一层,需不需要初始化,然后再看看nn.BatchNorm2d这一层需不需要初始化。
将以上两个例子放在一块,得到如下的代码。个人认为可以这样来理解。对于有些地方还不是很明白,欢迎各位大神指点!!!!!

class Generator(nn.Module):  # 创建Generator子类,括号内指定父类的名称def __init__(self):  # 初始化父类的属性super(Generator, self).__init__()  # 将父类和子类关联,调用父类nn.Moudle的方法__init__(),让Generator实例包含父类的所有属self.main = nn.Sequential(  # 按照顺序构造神经层,序列容器nn.ConvTranspose2d(5, 10, 4, 1, 0, bias=False),  # 转置卷积,输出nn.BatchNorm2d(10)  # 对每个特征图上的点,进行减均值除标准差的操作,affine设置为true(默认),引入权重w和b两个可学习的参数)netG = Generator()        # 定义类实例def weights_init(m):    ##定义参数初始化函数                  classname = m.__class__.__name__    # m作为一个形参,原则上可以传递很多的内容,为了实现多实参传递,每一个moudle要给出自己的name. 所以这句话就是返回m的名字。具体例子下边会详细说明。if classname.find('Conv') != -1:#find()函数,实现查找classname中是否含有conv字符,没有返回-1;有返回0.nn.init.normal_(m.weight.data, 0.0, 0.02)#m.weight.data表示需要初始化的权重。 nn.init.normal_()表示随机初始化采用正态分布,均值为0,标准差为0.02.elif classname.find('BatchNorm') != -1:           nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0) # nn.init.constant_()表示将偏差定义为常量0 netG.apply(weights_init)#netG是我们给写的神经网络定义的类实例。apply函数会递归地搜索网络内的所有module并把参数表示的函数应用到所有的module上。也就是说apply函数,会一层一层的去拜访Generator网络层。

pytorch——weights_init(m)相关推荐

  1. 纯手工搭建DCGAN,从零开始

    目 录 前言 1.明确搭建流程 2.搭出大概轮廓 2.1. nn.BCELoss(x,y)交叉熵损失函数的公式 2.2. 注意 2.3. 参考 3.调通程序 3.1. 运行结果 3.2. 调试中,需要 ...

  2. PyTorch图像分类从模型自定义到测试

    点击上方"小白学视觉",选择加"星标"或"置顶"重磅干货,第一时间送达 01.什么是 Pytorch 一句话总结 Pytorch = Pyt ...

  3. Pytorch:使用DCGAN实现数据复制

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 DCGAN Ian J. Goodfellow首次提出了GAN之后 ...

  4. CGAN生成cifar10, cifar100, mnist, fashion_mnist,STL10,Anime图片(pytorch)

    完整代码:代码地址https://www.lanzouw.com/iVadvo386ofhttps://www.lanzouw.com/iVadvo386of CGAN比DCGAN更进一步,利用标签信 ...

  5. DCGAN生成cifar10, cifar100, mnist, fashion_mnist,STL10,Anime图片(pytorch)

    代码下载地址下载地址https://www.lanzouw.com/ipl8Yo37qxihttps://www.lanzouw.com/ipl8Yo37qxi Anime数据请在Anime Face ...

  6. 【pytorch速成】Pytorch图像分类从模型自定义到测试

    文章首发于微信公众号<与有三学AI> [pytorch速成]Pytorch图像分类从模型自定义到测试 前面已跟大家介绍了Caffe和TensorFlow,链接如下. [caffe速成]ca ...

  7. celeba数据集_轻松学 Pytorch 使用DCGAN实现数据复制

    点击上方蓝字关注我们 微信公众号:OpenCV学堂 关注获取更多计算机视觉与深度学习知识 DCGAN Ian J. Goodfellow首次提出了GAN之后,生成对抗只是神经网络还不是深度卷积神经网络 ...

  8. pytorch forward_pytorch使用hook打印中间特征图、计算网络算力等

    0.参考 https://oldpan.me/archives/pytorch-autograd-hook https://pytorch.org/docs/stable/search.html?q= ...

  9. pytorch中的参数初始化方法

    参数初始化(Weight Initialization) PyTorch 中参数的默认初始化在各个层的 reset_parameters() 方法中.例如:nn.Linear 和 nn.Conv2D, ...

最新文章

  1. AlexeyAB DarkNet YOLOv3框架解析与应用实践(四)
  2. EhCache的特性
  3. 福利 | 2022全球敏捷运维峰会:跟技术老将畅聊时下数据库、运维、金融科技应“云”而生的技术创新...
  4. 2019年值得关注的人工智能技术的五大趋势
  5. appcan 微信支付
  6. 【动态规划】【线段树】 Codeforces Round #426 (Div. 1) B. The Bakery
  7. 通过rhel7的kvm虚拟机实现3节点Postgres-XL(包括gtm standby)
  8. 《团队名称》第八次团队作业:Alpha冲刺day1
  9. 微信小程序踩坑(1):wx.showModal模态对话框中content换行
  10. 更改eclipse字体
  11. JSP中退出登录销毁Session
  12. 边写SQL边学数据库入门实验2(持续更新)
  13. pdf在线免费去水印 以及图片去水印 方法
  14. Python pywin32(一)
  15. HDU 6069 Counting Divisors
  16. SMD元件尺寸大小公制英制对应说明
  17. vue路由文件相关配置
  18. 基于 Paraview 扩展与实现——(2)
  19. 华为手机热点无法连接_别傻了!不能只会给别人开热点,要尝试华为手机的WiFi分享功能...
  20. Android自定义权限CVE漏洞分析 (IEEE论文)

热门文章

  1. java斜体_Java可以指示字体是否为斜体字
  2. 小小突击队服务器维护多久,《小小突击队》08月06日更新公告
  3. 5.20 按照邮箱账号的域名进行排序 [原创Excel教程]
  4. android倒影效果,Android 设置图片倒影效果
  5. C#接入steam内购
  6. html table快捷键,超级实用且神奇的表格快捷键
  7. unity lookat导致物体颠倒怎么解决_在Unity 2D中如何用一行代码实现LookAt的效果,以及向量归一化小总结...
  8. ubuntu 取消打印队列命令
  9. JAVA工程师最新面试题(来源于互联网)
  10. 威纶通触摸屏与温控器进行MODBUS通信并通过宏指令将数据发送给PLC的具体方法