详解PyTorch中的ModuleList和Sequential
点击上方“视学算法”,选择加"星标"或“置顶”
重磅干货,第一时间送达
作者丨小占同学@知乎(已授权)
来源丨https://zhuanlan.zhihu.com/p/75206669
编辑丨极市平台
导读
本文详细讲解了PyTorch中的nn.Sequential和nn.ModuleList两个模块。
在使用PyTorch的时候,经常遇到nn.Sequential和nn.ModuleList,今天将这两个模块认真区分了一下,总结如下。PyTorch版本为1.0.0。本文也会随着本人逐渐深入Torch和有新的体会时,会进行更新。
本人才疏学浅,希望各位看官不吝赐教。
一、官方文档
首先看官方文档的解释,仅列出了容器(Containers)中几个比较常用的CLASS。
CLASS torch.nn.Module
Base class for all neural network modules.
Your models should also subclass this class.
import torch.nn as nnimport torch.nn.functional as F
class Model(nn.Module): def __init__(self): super(Model, self).__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5)def forward(self, x): x = F.relu(conv1(x)) return F.relu(conv2(x))
CLASS torch.nn.Sequential(*args)
A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.
# Example of using Sequentialmodel = nn.Sequential( nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() )# Example of using Sequential with OrderedDictmodel = nn.Sequential(OrderedDict([ ('conv1', nn.Conv2d(1, 20, 5)), ('ReLU1', nn.ReLU()), ('conv2', nn.Conv2d(20, 64, 5)), ('ReLU2', nn.ReLU()) ]))
CLASS torch.nn.ModuleList(modules=None)
Holds submodules in a list.
[ModuleList]
can be indexed like a regular Python list, but modules it contains are properly registered, and will be visible by all [Module]
methods.
ModuleList:https://pytorch.org/docs/stable/nn.html#torch.nn.ModuleList
Module:https://pytorch.org/docs/stable/nn.html#torch.nn.Module
class MyModel(nn.Module): def __init__(self): super(MyModel, self).__init__() self.linears = nn.ModuleList([nn.linear for i in range(10)])# ModuleList can act as an iterable, or be indexed using ints def forward(self, x): for i, l in enumerate(self.linears): x = self.linears[i // 2](x) + l(x) return x
二、nn.Sequential与nn.ModuleList简介
nn.Sequential
nn.Sequential里面的模块按照顺序进行排列的,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。如下面的例子所示:
#首先导入torch相关包import torchimport torch.nn as nnimport torch.nn.functional as Fclass net_seq(nn.Module): def __init__(self): super(net2, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): return self.seq(x)net_seq = net_seq()print(net_seq)#net_seq(# (seq): Sequential(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#)
nn.Sequential中可以使用OrderedDict来指定每个module的名字,而不是采用默认的命名方式(按序号 0,1,2,3...)。例子如下:
from collections import OrderedDict
class net_seq(nn.Module): def __init__(self): super(net_seq, self).__init__() self.seq = nn.Sequential(OrderedDict([ ('conv1', nn.Conv2d(1,20,5)), ('relu1', nn.ReLU()), ('conv2', nn.Conv2d(20,64,5)), ('relu2', nn.ReLU()) ])) def forward(self, x): return self.seq(x)net_seq = net_seq()print(net_seq)#net_seq(# (seq): Sequential(# (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (relu1): ReLU()# (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (relu2): ReLU()# )#)
nn.ModuleList
nn.ModuleList,它是一个储存不同 module,并自动将每个 module 的 parameters 添加到网络之中的容器。你可以把任意 nn.Module 的子类 (比如 nn.Conv2d, nn.Linear 之类的) 加到这个 list 里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于一般的 list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。若使用python的list,则会出问题。下面看一个例子:
class net_modlist(nn.Module): def __init__(self): super(net_modlist, self).__init__() self.modlist = nn.ModuleList([ nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() ])def forward(self, x): for m in self.modlist: x = m(x) return x
net_modlist = net_modlist()print(net_modlist)#net_modlist(# (modlist): ModuleList(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )#)
for param in net_modlist.parameters(): print(type(param.data), param.size())#<class 'torch.Tensor'> torch.Size([20, 1, 5, 5])#<class 'torch.Tensor'> torch.Size([20])#<class 'torch.Tensor'> torch.Size([64, 20, 5, 5])#<class 'torch.Tensor'> torch.Size([64])
可以看到,这个网络权重 (weithgs) 和偏置 (bias) 都在这个网络之内。接下来看看另一个作为对比的网络,它使用 Python 自带的 list:
class net_modlist(nn.Module): def __init__(self): super(net_modlist, self).__init__() self.modlist = [ nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() ]def forward(self, x): for m in self.modlist: x = m(x) return x
net_modlist = net_modlist()print(net_modlist)#net_modlist()for param in net_modlist.parameters(): print(type(param.data), param.size())#None
显然,使用 Python 的 list 添加的卷积层和它们的 parameters 并没有自动注册到我们的网络中。当然,我们还是可以使用 forward 来计算输出结果。但是如果用其实例化的网络进行训练的时候,因为这些层的parameters不在整个网络之中,所以其网络参数也不会被更新,也就是无法训练。
三、nn.Sequential与nn.ModuleList的区别
不同点1:
nn.Sequential内部实现了forward函数,因此可以不用写forward函数。而nn.ModuleList则没有实现内部forward函数。
对于nn.Sequential:
#例1:这是来自官方文档的例子seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() )print(seq)# Sequential(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()# )
#对上述seq进行输入input = torch.randn(16, 1, 20, 20)print(seq(input))#torch.Size([16, 64, 12, 12])
#例2:或者继承nn.Module类的话,就要写出forward函数class net1(nn.Module): def __init__(self): super(net1, self).__init__() self.seq = nn.Sequential( nn.Conv2d(1,20,5), nn.ReLU(), nn.Conv2d(20,64,5), nn.ReLU() ) def forward(self, x): return self.seq(x)#注意:按照下面这种利用for循环的方式也是可以得到同样结果的 #def forward(self, x): # for s in self.seq: # x = s(x) # return x#对net1进行输入input = torch.randn(16, 1, 20, 20)net1 = net1()print(net1(input).shape)#torch.Size([16, 64, 12, 12])
而对于nn.ModuleList:
#例1:若按照下面这么写,则会产生错误modlist = nn.ModuleList([ nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() ])print(modlist)#ModuleList(# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))# (1): ReLU()# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))# (3): ReLU()#)
input = torch.randn(16, 1, 20, 20)print(modlist(input))#产生NotImplementedError
#例2:写出forward函数class net2(nn.Module): def __init__(self): super(net2, self).__init__() self.modlist = nn.ModuleList([ nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU() ])#这里若按照这种写法则会报NotImplementedError错 #def forward(self, x): # return self.modlist(x)#注意:只能按照下面利用for循环的方式 def forward(self, x): for m in self.modlist: x = m(x) return x
input = torch.randn(16, 1, 20, 20)net2 = net2()print(net2(input).shape)#torch.Size([16, 64, 12, 12])
如果完全直接用 nn.Sequential,确实是可以的,但这么做的代价就是失去了部分灵活性,不能自己去定制 forward 函数里面的内容了。
一般情况下 nn.Sequential 的用法是来组成卷积块 (block),然后像拼积木一样把不同的 block 拼成整个网络,让代码更简洁,更加结构化。
不同点2:
nn.Sequential可以使用OrderedDict对每层进行命名,上面已经阐述过了;
不同点3:
nn.Sequential里面的模块按照顺序进行排列的,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。而nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言。见下面代码:
class net3(nn.Module): def __init__(self): super(net3, self).__init__() self.linears = nn.ModuleList([nn.Linear(10,20), nn.Linear(20,30), nn.Linear(5,10)]) def forward(self, x): x = self.linears[2](x) x = self.linears[0](x) x = self.linears[1](x)return x
net3 = net3()print(net3)#net3(# (linears): ModuleList(# (0): Linear(in_features=10, out_features=20, bias=True)# (1): Linear(in_features=20, out_features=30, bias=True)# (2): Linear(in_features=5, out_features=10, bias=True)# )#)
input = torch.randn(32, 5)print(net3(input).shape)#torch.Size([32, 30])
根据 net5 的结果,可以看出来这个 ModuleList 里面的顺序不能决定什么,网络的执行顺序是根据 forward 函数来决定的。若将forward函数中几行代码互换,使输入输出之间的大小不一致,则程序会报错。此外,为了使代码具有更高的可读性,最好把ModuleList和forward中的顺序保持一致。
不同点4:
有的时候网络中有很多相似或者重复的层,我们一般会考虑用 for 循环来创建它们,而不是一行一行地写,比如:
layers = [nn.Linear(10, 10) for i in range(5)]
那么这里我们使用ModuleList:
class net4(nn.Module): def __init__(self): super(net4, self).__init__() layers = [nn.Linear(10, 10) for i in range(5)] self.linears = nn.ModuleList(layers)def forward(self, x): for layer in self.linears: x = layer(x) return x
net = net4()print(net)# net4(# (linears): ModuleList(# (0): Linear(in_features=10, out_features=10, bias=True)# (1): Linear(in_features=10, out_features=10, bias=True)# (2): Linear(in_features=10, out_features=10, bias=True)# )# )
参考:
官方文档: Container(https://pytorch.org/docs/stable/nn.html#containers)
PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景(https://zhuanlan.zhihu.com/p/64990232)
点个在看 paper不断!
详解PyTorch中的ModuleList和Sequential相关推荐
- 收藏 | 详解PyTorch中的ModuleList和Sequential
点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨小占同学@知乎(已授权) 来源丨https://zhuanlan.zhihu.com/p/7520666 ...
- python中squeeze函数_详解pytorch中squeeze()和unsqueeze()函数介绍
squeeze的用法主要就是对数据的维度进行压缩或者解压. 先看torch.squeeze() 这个函数主要对数据的维度进行压缩,去掉维数为1的的维度,比如是一行或者一列这种,一个一行三列(1,3)的 ...
- tensor torch 构造_详解Pytorch中的网络构造
背景 在PyTroch框架中,如果要自定义一个Net(网络,或者model,在本文中,model和Net拥有同样的意思),通常需要继承自nn.Module然后实现自己的layer.比如,在下面的示例中 ...
- 详解Pytorch中的requires_grad、叶子节点与非叶子节点、with torch.no_grad()、model.eval()、model.train()、BatchNorm层
requires_grad requires_grad意为是否需要计算梯度 使用backward()函数反向传播计算梯度时,并不是计算所有tensor的梯度,只有满足下面条件的tensor的梯度才会被 ...
- 详解pytorch中的常见的Tensor数据类型以及类型转换
文章目录 概览 Tensor的构建 补充 类型转换 附录 概览 本文主要讲pytorch中的常见的Tensor数据类型,例如:float32,float64,int32,int64.构造他们分别使用如 ...
- 详解PyTorch中的contiguous
目录 前言 PyTorch中的is_contiguous是啥? 行优先 为什么需要 contiguous ? 为什么不在view 方法中默认调用contiguous方法? 前言 contiguous ...
- 一文详解Pytorch中的优化器Optimizer
本文将结合源码与代码示例详细解析Optimizer的五大方法. 1. 前言 优化器主要用在模型训练阶段,用于更新模型中可学习的参数.torch.optim提供了多种优化器接口,比如Adam.RAdam ...
- 详解PyTorch中的copy_()函数、detach()函数、detach_()函数和clone()函数
参考链接: copy_(src, non_blocking=False) → Tensor 参考链接: detach() 参考链接: detach_() 参考链接: clone() → Tensor ...
- YOLOV1详解——Pytorch版
YOLOV1详解--Pytorch版 1 YOLOV1 1 数据处理 1.1 数据集划分 1.2 读入xml文件 1.3 数据增强 2 训练 2.1 Backbone 2.2 Loss 2.3 tra ...
最新文章
- Java基础篇:泛型
- 电厂运维的cis数据_【面向运行人员的电站智能运维管家系统】
- 瞬发大量并发连接 造成MySQL连接不响应的分析
- MVC4.0网站发布和部署到IIS7.0上的方法
- BZOJ 1717 [Usaco2006 Dec]Milk Patterns 产奶的模式(后缀数组)
- mysql根据注释搜索表
- linux mpeg-4,嵌入式MPEG-4解码系统的设计与实现,嵌入式MPEG-4解码系统,嵌入式Linux,视频码流,P...
- 优秀!90后博士做出世界级成果,发32篇SCI,四拿国奖,两获国际荣誉
- C. Diverse Permutation(Codeforces Round #275(div2)
- js控制ul的显示隐藏,对象的有效范围
- 实现随着 下拉菜单中 选中值的变化 周边的值也也跟着变化。(使用【 VLOOKUP 】 函数)
- 使用timerfd实现定时器功能
- Docker最新教程 (视频地址https://www.bilibili.com/video/BV1og4y1q7M4)
- 数据库实验八--OpenGauss(数据库的备份与恢复)
- 使用python控制其他软件运行_Python实现运行其他程序的四种方式实例分析
- php解决时间超过2038年
- u盘写保护怎么才能真正去掉
- Sqlserver与Oracle 10g数据类型对照
- Golang环境及revel框架在Linux下的安装
- 解析explain执行计划