深度之眼 PyTorch 训练营第 4 期(5):构建模型 torch.nn.Module
本文中,我们看一看如何构建模型。
创造一个模型分两步:构建模型和权值初始化。而构建模型又有“定义单独的网络层”和“把它们拼在一起”两步。
1. torch.nn.Module
torch.nn.Module
是所有 torch.nn
中的类的父类。我们来看一个非常简单的神经网络:
class SimpleNet(nn.Module):def __init__(self, x):super(SimpleNet,self).__init__()self.fc = nn.Linear(x.shape[0], 1)def forward(self, x):x = self.fc(x)return x
我们随便喂给它一个张量,打印它的网络:
>>> simpleNet = SimpleNet(torch.tensor((10, 2)))
>>> print(simpleNet)
SimpleNet((fc): Linear(in_features=2, out_features=1, bias=True)
)
所有自定义的神经网络都要继承 torch.nn.Module
。定义单独的网络层在 __init__
函数中实现,把定义好的网络层拼接在一起在 forward
函数中实现。网络类有两个重要的函数:parameters
存储了模型的权重;modules
存储了模型的结构。
>>> list(simpleNet.modules())
[SimpleNet((fc): Linear(in_features=2, out_features=1, bias=True)),Linear(in_features=2, out_features=1, bias=True)]>>> list(simpleNet.parameters())
[Parameter containing:tensor([[ 0.1533, -0.2574]], requires_grad=True),Parameter containing:tensor([-0.1589], requires_grad=True)]
2. torch.nn.Sequential
这是一个序列容器,既可以放在模型外面单独构建一个模型,也可以放在模型里面成为模型的一部分。
# 单独成为一个模型
model1 = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())
# 成为模型的一部分
class LeNetSequential(nn.Module):def __init__(self, classes):super(LeNetSequential, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes),)def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x
放在模型里面的话,模型还是需要 __init__
和 forward
函数。
这样构建出来的模型的层没有名字:
>>> model2 = nn.Sequential(
... nn.Conv2d(1,20,5),
... nn.ReLU(),
... nn.Conv2d(20,64,5),
... nn.ReLU()
... )
>>> model2
Sequential((0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(1): ReLU()(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(3): ReLU()
)
为了方便区分不同的层,我们可以使用 collections
里的 OrderedDict
函数:
>>> from collections import OrderedDict
>>> model3 = nn.Sequential(OrderedDict([
... ('conv1', nn.Conv2d(1,20,5)),
... ('relu1', nn.ReLU()),
... ('conv2', nn.Conv2d(20,64,5)),
... ('relu2', nn.ReLU())
... ]))
>>> model3
Sequential((conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(relu1): ReLU()(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(relu2): ReLU()
)
3. torch.nn.ModuleList
将网络层存储进一个列表,可以使用列表生成式快速生成网络,生成的网络层可以被索引,也拥有列表的方法 append
,extend
或 insert
。
>>> class MyModule(nn.Module):
... def __init__(self):
... super(MyModule, self).__init__()
... self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
... self.linears.append(nn.Linear(10, 1)) # append
... def forward(self, x):
... for i, l in enumerate(self.linears):
... x = self.linears[i // 2](x) + l(x)
... return x>>> myModeul = MyModule()
>>> myModeul
MyModule((linears): ModuleList((0): Linear(in_features=10, out_features=10, bias=True)(1): Linear(in_features=10, out_features=10, bias=True)(2): Linear(in_features=10, out_features=10, bias=True)(3): Linear(in_features=10, out_features=10, bias=True)(4): Linear(in_features=10, out_features=10, bias=True)(5): Linear(in_features=10, out_features=10, bias=True)(6): Linear(in_features=10, out_features=10, bias=True)(7): Linear(in_features=10, out_features=10, bias=True)(8): Linear(in_features=10, out_features=10, bias=True)(9): Linear(in_features=10, out_features=10, bias=True)(10): Linear(in_features=10, out_features=1, bias=True) # append 进的层)
)
4. torch.nn.ModuleDict
这个函数与上面的 torch.nn.Sequential(OrderedDict(...))
的行为非常类似,并且拥有 keys
,values
,items
,pop
,update
等词典的方法:
>>> class MyDictDense(nn.Module):
... def __init__(self):
... super(MyDictDense, self).__init__()
... self.params = nn.ModuleDict({... 'linear1': nn.Linear(512, 128),
... 'linear2': nn.Linear(128, 32)
... })
... self.params.update({'linear3': nn.Linear(32, 10)}) # 添加层... def forward(self, x, choice='linear1'):
... return torch.mm(x, self.params[choice])>>> net = MyDictDense()
>>> print(net)
MyDictDense((params): ModuleDict((linear1): Linear(in_features=512, out_features=128, bias=True)(linear2): Linear(in_features=128, out_features=32, bias=True)(linear3): Linear(in_features=32, out_features=10, bias=True))
)>>> print(net.params.keys())
odict_keys(['linear1', 'linear2', 'linear3'])>>> print(net.params.items())
odict_items([('linear1', Linear(in_features=512, out_features=128, bias=True)), ('linear2', Linear(in_features=128, out_features=32, bias=True)), ('linear3', Linear(in_features=32, out_features=10, bias=True))])
欢迎关注我的微信公众号“花解语 NLP”:
深度之眼 PyTorch 训练营第 4 期(5):构建模型 torch.nn.Module相关推荐
- Pytorch的自定义拓展:torch.nn.Module和torch.autograd.Function
参考链接:pytorch的自定义拓展之(一)--torch.nn.Module和torch.autograd.Function_LoveMIss-Y的博客-CSDN博客_pytorch自定义backw ...
- 深度之眼Pytorch打卡(十三):Pytorch全连接神经网络部件——线性层、非线性激活层与Dropout层(即全连接层、常用激活函数与失活 )
前言 无论是做分类还是做回归,都主要包括数据.模型.损失函数和优化器四个部分.数据部分在上一篇笔记中已经基本完结,从这篇笔记开始,将学习深度学习模型.全连接网络MLP是最简单.最好理解的神经网络, ...
- 深度之眼Pytorch打卡(九):Pytorch数据预处理——预处理过程与数据标准化(transforms过程、Normalize原理、常用数据集均值标准差与数据集均值标准差计算)
前言 前段时间因为一些事情没有时间或者心情学习,现在两个多月过去了,事情结束了,心态也调整好了,所以又来接着学习Pytorch.这篇笔记主要是关于数据预处理过程.数据集标准化与数据集均值标准差计算 ...
- 关于pytorch官网教程中的What is torch.nn really?(三)
文章目录 Switch to CNN `nn.Sequential` Wrapping `DataLoader` Using your GPU Closing thoughts 原文在这里. 因为MN ...
- pytorch torch.nn.Module.register_buffer
API register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None 注册buf ...
- pytorch中的神经网络模块基础类——torch.nn.Module
1.torch.nn.Module概要 pytorch官网对torch.nn.Module的描述如下. torch.nn.Module是所有的神经网络模块的基类,且所有的神经网络模块都可以包含其他的子 ...
- Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()
Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...
- 【深度之眼PyTorch框架班第五期】作业打卡01:PyTorch简介及环境配置;PyTorch基础数据结构——张量
文章目录 任务名称 任务简介 详细说明 作业 1. 安装anaconda,pycharm, CUDA+CuDNN(可选),虚拟环境,pytorch,并实现hello pytorch查看pytorch的 ...
- 深度学习 实验六 卷积神经网络(1)卷积 torch.nn
目录 第5章 卷积神经网络 5.1 卷积 5.1.1 二维卷积运算 5.1.2 二维卷积算子 5.1.3 二维卷积的参数量和计算量 5.1.4 感受野 5.1.5 卷积的变种 5.1.5.1 步长(S ...
最新文章
- JavaScript设计模式系列四之外观模式(附案例源码)
- 如何将Windows 10帐户还原为本地帐户(在Windows Store劫持它之后)
- Python获取电脑CPU序列号、主板序列号、BIOS序列号、硬盘序列号列表、网卡MAC地址
- Oracle创建表空间及用户
- erlang 读取confg文件异常 could not start kernel pid error in config file
- EtherDream:在 JavaScript 中使用 C 程序
- 斐讯k2路由器刷华硕固件做桥接中继
- 三面,字节跳动电商Java面经(已过)
- linux下dbf是什么文件,dbf是什么文件?dbf文件怎么读取
- 彻底解决Vista狂读硬盘,硬盘灯狂闪的问题
- 塔望食业洞察|人参饮料行业环境 市场现状及发展思考
- AIGC神器 Midjourney 强势更新!逼真到令人发指!文心一言紧跟其后
- linux调试工具ipcs的深入分析
- 一对数的和等于给定的数
- 月赚千刀的国外联盟Lead项目,实操拆解
- python 爬取淘宝第一弹(淘宝登录)
- 物联网交互创新的探讨
- 看看好妻子什么样面相
- python rot13解密_用Python实现的rot13编解码
- bitwarden自建服务器,自建bitwarden的密码服务
热门文章
- Android SOS功能模块开发
- linux下su,su-,sudo命令的区别和用法
- 微信清除缓存数据方法
- 【解题报告】2017-2018 8th BSUIR Open Programming Contest-C Good subset 线性基+线段树
- 三极管电路必懂的几种分析方法
- android 程序根据重力感应切换程序的方向
- ASP.NET MVC 音乐商店 - 0 概览
- android初学者_初学者:如何在Android设备上的打开的应用程序之间切换
- 斐讯(Phicomm)空气检测仪(悟空 M1)通过 EasyLink
- tiny core linux网络连接,用Tiny Core Linux打造纯Firefox上网系统(概要)