文章来自微信公众号:【机器学习炼丹术】,是个人的学习心得分享基地。

文章目录:

  • 1 模型构建函数

    • 1.1 add_module
    • 1.2 ModuleList
    • 1.3 Sequential
    • 1.4 小总结
  • 2 遍历模型结构
    • 2.1 modules()
    • 2.2 named_modules()
    • 2.3 parameters()
  • 3 保存与载入

本文是对一些函数的学习。函数主要包括下面四个方面:

  • 模型构建的函数:add_module,add_module,add_module
  • 访问子模块:add_module,add_module,add_module,add_module
  • 网络遍历:add_module,add_module
  • 模型的保存与加载:add_module,add_module,add_module

1 模型构建函数

torch.nn.Module是所有网络的基类,在PyTorch实现模型的类中都要继承这个类(这个在之前的课程中已经提到)。在构建Module中,Module是一个包含其他的Module的,类似于,你可以先定义一个小的网络模块,然后把这个小模块作为另外一个网络的组件。因此网络结构是呈现树状结构

我们先简单定义一个网络:

import torch.nn as nnimport torch class MyNet(nn.Module):    def __init__(self):        super(MyNet,self).__init__()        self.conv1 = nn.Conv2d(3,64,3)        self.conv2 = nn.Conv2d(64,64,3)    def forward(self,x):        x = self.conv1(x)        x = self.conv2(x)        return xnet = MyNet()print(net)

输出结果:

MyNet中有两个属性conv1和conv2是两个卷积层,在正向传播forward的过程中,依次调用这两个卷积层实现网络的功能。

1.1 add_module

这种是最常见的定义网络的功能,在有些项目中,会看到这样的方法add_module。我们用这个方法来重写上面的网络:

class MyNet(nn.Module):    def __init__(self):        super(MyNet,self).__init__()        self.add_module('conv1',nn.Conv2d(3,64,3))        self.add_module('conv2',nn.Conv2d(64,64,3))    def forward(self,x):        x = self.conv1(x)        x = self.conv2(x)        return x

其实add_module(name,layer)和self.name=layer实现了相同的功能,个人感觉也许是因为add_module可以使用字符串来定义变量名字,所以可以放在循环中?反正这个先了解熟悉熟悉

上面的两种方法都是一层一层的添加layer,如果网络复杂的话,那就需要写很多重复的代码了。因此接下来来讲解一下网络模块的构建,torch.nn.ModuleList和torch.nn.Sequential

1.2 ModuleList

ModuleList按照字面意思是用list的形式保存网络层的。这样就可以先将网络需要的layer构建好,保存到一个list,然后通过ModuleList方法添加到网络中.

class MyNet(nn.Module):    def __init__(self):        super(MyNet,self).__init__()        self.linears = nn.ModuleList(            [nn.Linear(10,10) for i in range(5)]        )    def forward(self,x):        for l in self.linears:            x = l(x)        return xnet = MyNet()print(net)

输出结果是:

这个ModuleList主要是用在读取config文件来构建网络模型中的,下面用VGG模型的构建为例子:

vgg_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',           512, 512, 512, 'M']def vgg(cfg, i, batch_norm=False):    layers = []    in_channels = i    for v in cfg:        if v == 'M':            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]        elif v == 'C':            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]        else:            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)            if batch_norm:                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]            else:                layers += [conv2d, nn.ReLU(inplace=True)]            in_channels = v    return layersclass Model1(nn.Module):    def __init__(self):        super(Model1,self).__init__()        self.vgg = nn.ModuleList(vgg(vgg_cfg,3))    def forward(self,x):        for l in self.vgg:            x = l(x)m1 = Model1()print(m1)

先读取网络结构的配置文件vgg_cfg然后根据这个文件创建对应的Layer list,然后使用ModuleList添加到网络中,这样可以快速创建不同的网络(用上面为例子的话,可以通过修改配置文件,然后快速修改网络结构 )

1.3 Sequential

在一些自己做的小项目中,Sequential其实用的更为频繁。依然重写最初最简单的例子:

class MyNet(nn.Module):    def __init__(self):        super(MyNet,self).__init__()        self.conv = nn.Sequential(            nn.Conv2d(3,64,3),            nn.Conv2d(64,64,3)        )    def forward(self,x):        x = self.conv(x)        return xnet = MyNet()print(net)

运行结果:

观察细致的朋友可以发现这个问题,Seqential内的网络层是默认用数字进行标号的,而一开始我们使用self.conv1和self.conv2的时候,使用conv1和conv2作为标号的。

我们如何修改Sequential中网络层的名称呢?这里需要使用到collections.OrderedDict有序字典。Sequential是支持有序字典构建的。

from collections import OrderedDict class MyNet(nn.Module):    def __init__(self):        super(MyNet,self).__init__()        self.conv = nn.Sequential(OrderedDict([            ('conv1',nn.Conv2d(3,64,3)),            ('conv2',nn.Conv2d(64,64,3))        ]))    def forward(self,x):        x = self.conv(x)        return xnet = MyNet()print(net)

输出结果:

1.4 小总结

  • 单独增加一个网络层或者子模块,可以用add_module或者直接赋予属性;
  • ModuleList可以将一个Module的List增加到网络中,自由度较高。
  • Sequential按照顺序产生一个Module模块。这里推荐习惯使用OrderedDict的方法进行构建。对网络层加上规范的名称,这样有助于后续查找与遍历

2 遍历模型结构

本章节使用下面的方法进行遍历之前提到的Module。(个人理解,Module是多个layer的合并,但是一个layer可以说成Module。 ) 先定义一个网络吧,随便写一个:

import torch.nn as nnimport torch from collections import OrderedDictclass MyNet(nn.Module):    def __init__(self):        super(MyNet,self).__init__()        self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)        self.conv2 = nn.Conv2d(64,64,3)        self.maxpool1 = nn.MaxPool2d(2,2)        self.features = nn.Sequential(OrderedDict([            ('conv3', nn.Conv2d(64,128,3)),            ('conv4', nn.Conv2d(128,128,3)),            ('relu1', nn.ReLU())        ]))    def forward(self,x):        x = self.conv1(x)        x = self.conv2(x)        x = self.maxpool1(x)        x = self.features(x)        return xnet = MyNet()print(net)

输出结果是:

2.1 modules()

在第四课中初始化模型各个层的参数的时候,用到了这个方法,现在我们再来理解一下:

for idx,m in enumerate(net.modules()):    print(idx,"-",m)

运行结果:

上面那个网络构建的时候用到了Sequential,所以网络中其实是嵌套了一个小的Module,这就是之前提到的树状结构,然后上面便利的时候也是树状结构的便利过程,可以看出来应该是一个深度遍历的过程。

  • 首先第一个输出的是最大的那个Module,也就是整个网络,0-Model整个网络模块;
  • 1-2-3-4是网络的四个子模块,4-Sequential中间仍然包含子模块
  • 5-6-7是模块4-Sequential的子模块。

【总结】

modules()是递归的返回网络的各个module(深度遍历),从最顶层直到最后的叶子的module。

2.2 named_modules()

named_modules()和module()类似,只是同时返回name和module。

for idx,(name,m) in enumerate(net.named_modules()):    print(idx,"-",name)

输出结果:

2.3 parameters()

for p in net.parameters():    print(type(p.data),p.size())

运行结果:

输出的是四个卷积层的权重矩阵参数和偏置参数。值得一提的是,对网络进行训练时需要将parameters()作为优化器optimizer的参数。

optimizer = torch.optim.SGD(net.parameters(),                            lr = 0.001,                            momentum=0.9)

总之呢,这个parameters()是返回网络所有的参数,主要用在给optimizer优化器用的。而要对网络的某一层的参数做处理的时候,一般还是使用named_parameters()方便一些。

for idx,(name,m) in enumerate(net.named_parameters()):    print(idx,"-",name,m.size())

输出结果:

【小扩展】

我个人有时会使用下面的方法来获取参数:

for idx,(name,m) in enumerate(net.named_modules()):    if isinstance(m,nn.Conv2d):        print(m.weight.shape)        print(m.bias.shape)

先判断是否是卷积层,然后获取其参数,输出结果:

3 保存与载入

PyTorch使用torch.save和torch.load方法来保存和加载网络,而且网络结构和参数可以分开的保存和加载。

torch.save(model,'model.pth') # 保存model = torch.load("model.pth") # 加载

pytorch中网络结构和模型参数是可以分开保存的。上面的方法是两者同时保存到了.pth文件中,当然,你也可以仅仅保存网络的参数来减小存储文件的大小。注意:如果你仅仅保存模型参数,那么在载入的时候,是需要通过运行代码来初始化模型的结构的。

torch.save(model.state_dict(),"model.pth") # 保存参数model = MyNet() # 代码中创建网络结构params = torch.load("model.pth") # 加载参数model.load_state_dict(params) # 应用到网络结构中

至此,我们今天已经学习了不少的内容,大家对PyTorch的掌握更近一步了呢~

- END -

c++ 遍历list_小白学PyTorch | 6 模型的构建访问遍历存储(附代码相关推荐

  1. c++list遍历_小白学PyTorch | 6 模型的构建访问遍历存储(附代码)

    关注一下不迷路哦~喜欢的点个星标吧~<> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...

  2. 【小白学PyTorch】18.TF2构建自定义模型

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 17 TFrec文件的创建与读取 扩展之Tensorflow2.0 | 1 ...

  3. 【小白学PyTorch】6.模型的构建访问遍历存储(附代码)

    <<小白学PyTorch>> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...

  4. pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构

    [机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 8 实战之MNIST小试牛刀 小白学PyTorch | 7 最新版本torchvision.transforms常用API翻 ...

  5. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization

    <<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...

  6. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...

  7. 【小白学PyTorch】扩展之Tensorflow2.0 | 20 TF2的eager模式与求导

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 19 TF2模型的存储与载入 扩展之Tensorflow2.0 | 18 ...

  8. 【小白学PyTorch】17.TFrec文件的创建与读取

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 小白学PyTorch | 16 TF2读取图片的方法 小白学PyTorch | 15 TF2实现一个简单的服装分 ...

  9. 【小白学PyTorch】16.TF2读取图片的方法

    <<小白学PyTorch>> 扩展之tensorflow2.0 | 15 TF2实现一个简单的服装分类任务 小白学PyTorch | 14 tensorboardX可视化教程 ...

最新文章

  1. 中国移动这个编程大赛来了!
  2. 传真休眠怎么取消_C盘满了怎么办——系统瘦身
  3. Ubuntu14.04下切换系统自带的Python和Anaconda 下的Python
  4. File Filter用法
  5. c++ builder 读取指定单个名称节点的值
  6. IOS成长之路-检测耳机插入/拔出
  7. 阿里开源软件替换指南
  8. 介绍一个免费的采用人工智能放大老照片的在线网站
  9. spring-使用配置文件完成JdbcTemplate操作数据库
  10. mysql安装前的系统准备工作(转)
  11. HTTP POST GET 区别
  12. np.concatenate 函数的使用
  13. 《精通正则表达式》笔记 --- 选择引号内的文字
  14. 《田野里的自然历史课》首发 科普中华农耕文明
  15. JN5169 NXP ZigBee PRO 无线网络应用所需的常见操作(二)
  16. python蜂鸣器_Python与硬件学习笔记:蜂鸣器(转)
  17. VC 轻松实现非客户区按钮
  18. Feburary——1052.爱生气的书店老板
  19. SAP FI 系列 (031) - 允许在会计凭证中修改统驭科目
  20. 鸿蒙3部曲先看哪部,星辰变是“鸿蒙”系列的作品,那“鸿蒙”系列到底有多少部曲?...

热门文章

  1. Assembly中Load, LoadFrom, LoadFile以及AppDomain, Activator类中相应函数的区别
  2. 现金贷风控生命周期——贷前风控
  3. 万张PubFig人脸数据实现基于python+OpenCV的人脸特征定位程序(1)
  4. pytorch 中的 split
  5. ElasticSearch入门 第五篇:使用C#查询文档
  6. UIButton @selector 想要传递多个参数
  7. 学会使用JDK API
  8. Cocos2d-x之MenuItem
  9. java 程序执行原理
  10. 具体解决VS“滴答数必须介于 DateTime.MinValue.Ticks 和 DateTime.MaxValue.Ticks 之间”奇怪问题...