【小白学PyTorch】6.模型的构建访问遍历存储(附代码)
<<小白学PyTorch>>
小白学PyTorch | 5 torchvision预训练模型与数据集全览
小白学PyTorch | 4 构建模型三要素与权重初始化
小白学PyTorch | 3 浅谈Dataset和Dataloader
小白学PyTorch | 2 浅谈训练集验证集和测试集
小白学PyTorch | 1 搭建一个超简单的网络
小白学PyTorch | 动态图与静态图的浅显理解
文章目录:
1 模型构建函数
1.1 add_module
1.2 ModuleList
1.3 Sequential
1.4 小总结
2 遍历模型结构
2.1 modules()
2.2 named_modules()
2.3 parameters()
3 保存与载入
本文是对一些函数的学习。函数主要包括下面四个方便:
模型构建的函数:
add_module
,add_module
,add_module
访问子模块:
add_module
,add_module
,add_module
,add_module
网络遍历:
add_module
,add_module
模型的保存与加载:
add_module
,add_module
,add_module
1 模型构建函数
torch.nn.Module
是所有网络的基类,在PyTorch实现模型的类中都要继承这个类(这个在之前的课程中已经提到)。在构建Module中,Module是一个包含其他的Module的,类似于,你可以先定义一个小的网络模块,然后把这个小模块作为另外一个网络的组件。因此网络结构是呈现树状结构。
我们先简单定义一个网络:
import torch.nn as nn
import torch
class MyNet(nn.Module):def __init__(self):super(MyNet,self).__init__()self.conv1 = nn.Conv2d(3,64,3)self.conv2 = nn.Conv2d(64,64,3)def forward(self,x):x = self.conv1(x)x = self.conv2(x)return x
net = MyNet()
print(net)
输出结果:MyNet
中有两个属性conv1
和conv2
是两个卷积层,在正向传播forward
的过程中,依次调用这两个卷积层实现网络的功能。
1.1 add_module
这种是最常见的定义网络的功能,在有些项目中,会看到这样的方法add_module
。我们用这个方法来重写上面的网络:
class MyNet(nn.Module):def __init__(self):super(MyNet,self).__init__()self.add_module('conv1',nn.Conv2d(3,64,3))self.add_module('conv2',nn.Conv2d(64,64,3))def forward(self,x):x = self.conv1(x)x = self.conv2(x)return x
其实add_module(name,layer)
和self.name=layer
实现了相同的功能,个人感觉也许是因为add_module可以使用字符串来定义变量名字,所以可以放在循环中?反正这个先了解熟悉熟悉。
上面的两种方法都是一层一层的添加layer,如果网络复杂的话,那就需要写很多重复的代码了。因此接下来来讲解一下网络模块的构建,torch.nn.ModuleList
和torch.nn.Sequential
1.2 ModuleList
ModuleList
按照字面意思是用list
的形式保存网络层的。这样就可以先将网络需要的layer构建好,保存到一个list,然后通过ModuleList
方法添加到网络中.
class MyNet(nn.Module):def __init__(self):super(MyNet,self).__init__()self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(5)])def forward(self,x):for l in self.linears:x = l(x)return x
net = MyNet()
print(net)
输出结果是:
这个ModuleList主要是用在读取config文件来构建网络模型中的,下面用VGG模型的构建为例子:
vgg_cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',512, 512, 512, 'M']def vgg(cfg, i, batch_norm=False):layers = []in_channels = ifor v in cfg:if v == 'M':layers += [nn.MaxPool2d(kernel_size=2, stride=2)]elif v == 'C':layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)if batch_norm:layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]else:layers += [conv2d, nn.ReLU(inplace=True)]in_channels = vreturn layersclass Model1(nn.Module):def __init__(self):super(Model1,self).__init__()self.vgg = nn.ModuleList(vgg(vgg_cfg,3))def forward(self,x):for l in self.vgg:x = l(x)
m1 = Model1()
print(m1)
先读取网络结构的配置文件vgg_cfg
然后根据这个文件创建对应的Layer list,然后使用ModuleList
添加到网络中,这样可以快速创建不同的网络(用上面为例子的话,可以通过修改配置文件,然后快速修改网络结构 )
1.3 Sequential
在一些自己做的小项目中,Sequential
其实用的更为频繁。依然重写最初最简单的例子:
class MyNet(nn.Module):def __init__(self):super(MyNet,self).__init__()self.conv = nn.Sequential(nn.Conv2d(3,64,3),nn.Conv2d(64,64,3))def forward(self,x):x = self.conv(x)return x
net = MyNet()
print(net)
运行结果:
观察细致的朋友可以发现这个问题,Seqential内的网络层是默认用数字进行标号的,而一开始我们使用self.conv1
和self.conv2
的时候,使用conv1和conv2作为标号的。
我们如何修改Sequential
中网络层的名称呢?这里需要使用到collections.OrderedDict
有序字典。Sequential
是支持有序字典构建的。
from collections import OrderedDict
class MyNet(nn.Module):def __init__(self):super(MyNet,self).__init__()self.conv = nn.Sequential(OrderedDict([('conv1',nn.Conv2d(3,64,3)),('conv2',nn.Conv2d(64,64,3))]))def forward(self,x):x = self.conv(x)return x
net = MyNet()
print(net)
输出结果:
1.4 小总结
单独增加一个网络层或者子模块,可以用
add_module
或者直接赋予属性;ModuleList
可以将一个Module的List增加到网络中,自由度较高。Sequential
按照顺序产生一个Module模块。这里推荐习惯使用OrderedDict的方法进行构建。对网络层加上规范的名称,这样有助于后续查找与遍历
2 遍历模型结构
本章节使用下面的方法进行遍历之前提到的Module
。(个人理解,Module是多个layer的合并,但是一个layer可以说成Module。 ) 先定义一个网络吧,随便写一个:
import torch.nn as nn
import torch
from collections import OrderedDict
class MyNet(nn.Module):def __init__(self):super(MyNet,self).__init__()self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=3)self.conv2 = nn.Conv2d(64,64,3)self.maxpool1 = nn.MaxPool2d(2,2)self.features = nn.Sequential(OrderedDict([('conv3', nn.Conv2d(64,128,3)),('conv4', nn.Conv2d(128,128,3)),('relu1', nn.ReLU())]))def forward(self,x):x = self.conv1(x)x = self.conv2(x)x = self.maxpool1(x)x = self.features(x)return x
net = MyNet()
print(net)
输出结果是:
2.1 modules()
在第四课中初始化模型各个层的参数的时候,用到了这个方法,现在我们再来理解一下:
for idx,m in enumerate(net.modules()):print(idx,"-",m)
运行结果:
上面那个网络构建的时候用到了Sequential
,所以网络中其实是嵌套了一个小的Module,这就是之前提到的树状结构,然后上面便利的时候也是树状结构的便利过程,可以看出来应该是一个深度遍历的过程。
首先第一个输出的是最大的那个Module,也就是整个网络,
0-Model
整个网络模块;1-2-3-4
是网络的四个子模块,4-Sequential
中间仍然包含子模块5-6-7
是模块4-Sequential
的子模块。
【总结】
modules()
是递归的返回网络的各个module(深度遍历),从最顶层直到最后的叶子的module。
2.2 named_modules()
named_modules()
和module()
类似,只是同时返回name和module。
for idx,(name,m) in enumerate(net.named_modules()):print(idx,"-",name)
输出结果:
2.3 parameters()
for p in net.parameters():print(type(p.data),p.size())
运行结果:
输出的是四个卷积层的权重矩阵参数和偏置参数。值得一提的是,对网络进行训练时需要将parameters()作为优化器optimizer的参数。
optimizer = torch.optim.SGD(net.parameters(),lr = 0.001,momentum=0.9)
总之呢,这个parameters()
是返回网络所有的参数,主要用在给optimizer优化器用的。而要对网络的某一层的参数做处理的时候,一般还是使用named_parameters()方便一些。
for idx,(name,m) in enumerate(net.named_parameters()):print(idx,"-",name,m.size())
输出结果:
【小扩展】
我个人有时会使用下面的方法来获取参数:
for idx,(name,m) in enumerate(net.named_modules()):if isinstance(m,nn.Conv2d):print(m.weight.shape)print(m.bias.shape)
先判断是否是卷积层,然后获取其参数,输出结果:
3 保存与载入
PyTorch使用torch.save
和torch.load
方法来保存和加载网络,而且网络结构和参数可以分开的保存和加载。
torch.save(model,'model.pth') # 保存
model = torch.load("model.pth") # 加载
pytorch中网络结构和模型参数是可以分开保存的。上面的方法是两者同时保存到了.pth文件中,当然,你也可以仅仅保存网络的参数来减小存储文件的大小。注意:如果你仅仅保存模型参数,那么在载入的时候,是需要通过运行代码来初始化模型的结构的。
torch.save(model.state_dict(),"model.pth") # 保存参数
model = MyNet() # 代码中创建网络结构
params = torch.load("model.pth") # 加载参数
model.load_state_dict(params) # 应用到网络结构中
至此,我们今天已经学习了不少的内容,大家对PyTorch的掌握更近一步了呢~
- END -
往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑获取一折本站知识星球优惠券,复制链接直接打开:https://t.zsxq.com/662nyZF本站qq群1003271085。加入微信群请扫码进群(如果是博士或者准备读博士请说明):
【小白学PyTorch】6.模型的构建访问遍历存储(附代码)相关推荐
- c++list遍历_小白学PyTorch | 6 模型的构建访问遍历存储(附代码)
关注一下不迷路哦~喜欢的点个星标吧~<> 小白学PyTorch | 5 torchvision预训练模型与数据集全览 小白学PyTorch | 4 构建模型三要素与权重初始化 小白学PyT ...
- c++ 遍历list_小白学PyTorch | 6 模型的构建访问遍历存储(附代码
文章来自微信公众号:[机器学习炼丹术],是个人的学习心得分享基地. 文章目录: 1 模型构建函数 1.1 add_module 1.2 ModuleList 1.3 Sequential 1.4 小总 ...
- 【小白学PyTorch】18.TF2构建自定义模型
[机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 17 TFrec文件的创建与读取 扩展之Tensorflow2.0 | 1 ...
- pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构
[机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 8 实战之MNIST小试牛刀 小白学PyTorch | 7 最新版本torchvision.transforms常用API翻 ...
- pytorch dataset_【小白学PyTorch】16.TF2读取图片的方法
<> 扩展之tensorflow2.0 | 15 TF2实现一个简单的服装分类任务 小白学PyTorch | 14 tensorboardX可视化教程 小白学PyTorch | 13 Ef ...
- data后缀文件解码_小白学PyTorch | 17 TFrec文件的创建与读取
[机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 16 TF2读取图片的方法 小白学PyTorch | 15 TF2实现一个简单的服装分类任务 小白学PyTorch | 14 ...
- 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(下)池化、Normalization
<<小白学PyTorch>> 扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积.激活.初始化.正则 扩展之Tensorflow2.0 | 20 TF ...
- 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...
[机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...
- 【小白学PyTorch】扩展之Tensorflow2.0 | 20 TF2的eager模式与求导
[机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 19 TF2模型的存储与载入 扩展之Tensorflow2.0 | 18 ...
最新文章
- Maven学习 使用Nexus搭建Maven私服(转)
- 在Docker中运行Dubbo应用
- Navicat怎样筛选数据
- js 正则 显示千分号 支持整数和小数
- linux死锁的例子,操作系统教程—Linux实例分析 孟庆昌 第8章 死锁new.ppt
- C++语言vector容器介绍和示例
- java爬虫 webcollector_Java爬虫-WebCollector | 学步园
- qca9535 tftp32 刷机_【U-Boot】U-Boot 刷机方法大全
- 怎么使用biopython_什么是Biopython? 你能用Biopython做什么? Biopython功能概。
- Python批量将MP3音频转为WAV格式(附代码) | Python工具
- json parser类的使用
- 唯物史观在高中历史教学中的具体运用
- 30系列NVIDIA显卡安装tensorflow 极简
- maya如何导入多片段动画文件查看和编辑
- Subversive or Subclipse
- Emlog采集插件 刀网资源采集 一键显示资源1.1
- PM、RD、FE、UE...等等这些互联网相关的缩写
- 观察者模式解读厦门建国以来最强台风
- 如何形象地理解涌现?
- TBTCOIN硬分叉
热门文章
- iphone 使用委托(delegate)在不同的窗口之间传递数据
- oracle怎么捕获表上的DML语句(不包括select)语句)
- MySQL数据库Innodb储存引擎----储存页的结构
- js数组对象的常用方法
- R语言处理Web数据
- SimpleAdapter的用法
- 标准配置的UBUNTU 11.10 RUBY VMWARE 镜像,手工MOD(ZSH_RVM_RAILS_VIM)
- php发送gmail,使用GMail SMTP服务器从PHP页面发送电子邮件
- java 字符串 加密_如何用JAVA实现字符串简单加密解密?
- Image Pro Plus测量组织平均厚度