本文中,我们看一看如何构建模型。
创造一个模型分两步:构建模型和权值初始化。而构建模型又有“定义单独的网络层”和“把它们拼在一起”两步。

1. torch.nn.Module

torch.nn.Module 是所有 torch.nn 中的类的父类。我们来看一个非常简单的神经网络:

class SimpleNet(nn.Module):def __init__(self, x):super(SimpleNet,self).__init__()self.fc = nn.Linear(x.shape[0], 1)def forward(self, x):x = self.fc(x)return x

我们随便喂给它一个张量,打印它的网络:

>>> simpleNet = SimpleNet(torch.tensor((10, 2)))
>>> print(simpleNet)
SimpleNet((fc): Linear(in_features=2, out_features=1, bias=True)
)

所有自定义的神经网络都要继承 torch.nn.Module。定义单独的网络层在 __init__ 函数中实现,把定义好的网络层拼接在一起在 forward 函数中实现。网络类有两个重要的函数:parameters 存储了模型的权重;modules 存储了模型的结构。

>>> list(simpleNet.modules())
[SimpleNet((fc): Linear(in_features=2, out_features=1, bias=True)),Linear(in_features=2, out_features=1, bias=True)]>>> list(simpleNet.parameters())
[Parameter containing:tensor([[ 0.1533, -0.2574]], requires_grad=True),Parameter containing:tensor([-0.1589], requires_grad=True)]

2. torch.nn.Sequential

这是一个序列容器,既可以放在模型外面单独构建一个模型,也可以放在模型里面成为模型的一部分。

# 单独成为一个模型
model1 = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())
# 成为模型的一部分
class LeNetSequential(nn.Module):def __init__(self, classes):super(LeNetSequential, self).__init__()self.features = nn.Sequential(nn.Conv2d(3, 6, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, 5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),)self.classifier = nn.Sequential(nn.Linear(16*5*5, 120),nn.ReLU(),nn.Linear(120, 84),nn.ReLU(),nn.Linear(84, classes),)def forward(self, x):x = self.features(x)x = x.view(x.size()[0], -1)x = self.classifier(x)return x

放在模型里面的话,模型还是需要 __init__forward 函数。

这样构建出来的模型的层没有名字:

>>> model2 = nn.Sequential(
...           nn.Conv2d(1,20,5),
...           nn.ReLU(),
...           nn.Conv2d(20,64,5),
...           nn.ReLU()
...         )
>>> model2
Sequential((0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(1): ReLU()(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(3): ReLU()
)

为了方便区分不同的层,我们可以使用 collections 里的 OrderedDict 函数:

>>> from collections import OrderedDict
>>> model3 = nn.Sequential(OrderedDict([
...           ('conv1', nn.Conv2d(1,20,5)),
...           ('relu1', nn.ReLU()),
...           ('conv2', nn.Conv2d(20,64,5)),
...           ('relu2', nn.ReLU())
...         ]))
>>> model3
Sequential((conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(relu1): ReLU()(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(relu2): ReLU()
)

3. torch.nn.ModuleList

将网络层存储进一个列表,可以使用列表生成式快速生成网络,生成的网络层可以被索引,也拥有列表的方法 appendextendinsert

>>> class MyModule(nn.Module):
...     def __init__(self):
...         super(MyModule, self).__init__()
...         self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
...         self.linears.append(nn.Linear(10, 1)) # append
...     def forward(self, x):
...         for i, l in enumerate(self.linears):
...             x = self.linears[i // 2](x) + l(x)
...         return x>>> myModeul = MyModule()
>>> myModeul
MyModule((linears): ModuleList((0): Linear(in_features=10, out_features=10, bias=True)(1): Linear(in_features=10, out_features=10, bias=True)(2): Linear(in_features=10, out_features=10, bias=True)(3): Linear(in_features=10, out_features=10, bias=True)(4): Linear(in_features=10, out_features=10, bias=True)(5): Linear(in_features=10, out_features=10, bias=True)(6): Linear(in_features=10, out_features=10, bias=True)(7): Linear(in_features=10, out_features=10, bias=True)(8): Linear(in_features=10, out_features=10, bias=True)(9): Linear(in_features=10, out_features=10, bias=True)(10): Linear(in_features=10, out_features=1, bias=True) # append 进的层)
)

4. torch.nn.ModuleDict

这个函数与上面的 torch.nn.Sequential(OrderedDict(...)) 的行为非常类似,并且拥有 keysvaluesitemspopupdate 等词典的方法:

>>> class MyDictDense(nn.Module):
...     def __init__(self):
...         super(MyDictDense, self).__init__()
...         self.params = nn.ModuleDict({...                 'linear1': nn.Linear(512, 128),
...                 'linear2': nn.Linear(128, 32)
...         })
...         self.params.update({'linear3': nn.Linear(32, 10)}) # 添加层...     def forward(self, x, choice='linear1'):
...         return torch.mm(x, self.params[choice])>>> net = MyDictDense()
>>> print(net)
MyDictDense((params): ModuleDict((linear1): Linear(in_features=512, out_features=128, bias=True)(linear2): Linear(in_features=128, out_features=32, bias=True)(linear3): Linear(in_features=32, out_features=10, bias=True))
)>>> print(net.params.keys())
odict_keys(['linear1', 'linear2', 'linear3'])>>> print(net.params.items())
odict_items([('linear1', Linear(in_features=512, out_features=128, bias=True)), ('linear2', Linear(in_features=128, out_features=32, bias=True)), ('linear3', Linear(in_features=32, out_features=10, bias=True))])

欢迎关注我的微信公众号“花解语 NLP”:

深度之眼 PyTorch 训练营第 4 期(5):构建模型 torch.nn.Module相关推荐

  1. Pytorch的自定义拓展:torch.nn.Module和torch.autograd.Function

    参考链接:pytorch的自定义拓展之(一)--torch.nn.Module和torch.autograd.Function_LoveMIss-Y的博客-CSDN博客_pytorch自定义backw ...

  2. 深度之眼Pytorch打卡(十三):Pytorch全连接神经网络部件——线性层、非线性激活层与Dropout层(即全连接层、常用激活函数与失活 )

    前言   无论是做分类还是做回归,都主要包括数据.模型.损失函数和优化器四个部分.数据部分在上一篇笔记中已经基本完结,从这篇笔记开始,将学习深度学习模型.全连接网络MLP是最简单.最好理解的神经网络, ...

  3. 深度之眼Pytorch打卡(九):Pytorch数据预处理——预处理过程与数据标准化(transforms过程、Normalize原理、常用数据集均值标准差与数据集均值标准差计算)

    前言   前段时间因为一些事情没有时间或者心情学习,现在两个多月过去了,事情结束了,心态也调整好了,所以又来接着学习Pytorch.这篇笔记主要是关于数据预处理过程.数据集标准化与数据集均值标准差计算 ...

  4. 关于pytorch官网教程中的What is torch.nn really?(三)

    文章目录 Switch to CNN `nn.Sequential` Wrapping `DataLoader` Using your GPU Closing thoughts 原文在这里. 因为MN ...

  5. pytorch torch.nn.Module.register_buffer

    API register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None 注册buf ...

  6. pytorch中的神经网络模块基础类——torch.nn.Module

    1.torch.nn.Module概要 pytorch官网对torch.nn.Module的描述如下. torch.nn.Module是所有的神经网络模块的基类,且所有的神经网络模块都可以包含其他的子 ...

  7. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

  8. 【深度之眼PyTorch框架班第五期】作业打卡01:PyTorch简介及环境配置;PyTorch基础数据结构——张量

    文章目录 任务名称 任务简介 详细说明 作业 1. 安装anaconda,pycharm, CUDA+CuDNN(可选),虚拟环境,pytorch,并实现hello pytorch查看pytorch的 ...

  9. 深度学习 实验六 卷积神经网络(1)卷积 torch.nn

    目录 第5章 卷积神经网络 5.1 卷积 5.1.1 二维卷积运算 5.1.2 二维卷积算子 5.1.3 二维卷积的参数量和计算量 5.1.4 感受野 5.1.5 卷积的变种 5.1.5.1 步长(S ...

最新文章

  1. JavaScript设计模式系列四之外观模式(附案例源码)
  2. 如何将Windows 10帐户还原为本地帐户(在Windows Store劫持它之后)
  3. Python获取电脑CPU序列号、主板序列号、BIOS序列号、硬盘序列号列表、网卡MAC地址
  4. Oracle创建表空间及用户
  5. erlang 读取confg文件异常 could not start kernel pid error in config file
  6. EtherDream:在 JavaScript 中使用 C 程序
  7. 斐讯k2路由器刷华硕固件做桥接中继
  8. 三面,字节跳动电商Java面经(已过)
  9. linux下dbf是什么文件,dbf是什么文件?dbf文件怎么读取
  10. 彻底解决Vista狂读硬盘,硬盘灯狂闪的问题
  11. 塔望食业洞察|人参饮料行业环境 市场现状及发展思考
  12. AIGC神器 Midjourney 强势更新!逼真到令人发指!文心一言紧跟其后
  13. linux调试工具ipcs的深入分析
  14. 一对数的和等于给定的数
  15. 月赚千刀的国外联盟Lead项目,实操拆解
  16. python 爬取淘宝第一弹(淘宝登录)
  17. 物联网交互创新的探讨
  18. 看看好妻子什么样面相
  19. python rot13解密_用Python实现的rot13编解码
  20. bitwarden自建服务器,自建bitwarden的密码服务

热门文章

  1. Android SOS功能模块开发
  2. linux下su,su-,sudo命令的区别和用法
  3. 微信清除缓存数据方法
  4. 【解题报告】2017-2018 8th BSUIR Open Programming Contest-C Good subset 线性基+线段树
  5. 三极管电路必懂的几种分析方法
  6. android 程序根据重力感应切换程序的方向
  7. ASP.NET MVC 音乐商店 - 0 概览
  8. android初学者_初学者:如何在Android设备上的打开的应用程序之间切换
  9. 斐讯(Phicomm)空气检测仪(悟空 M1)通过 EasyLink
  10. tiny core linux网络连接,用Tiny Core Linux打造纯Firefox上网系统(概要)