一、继承Module类来构造模型

Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类,如下所示,继承这个类需要重载Module类中的__init__函数和forward函数,它们分别用于创建模型参数和定义前向计算(正向传播)。

继承时无需定义反向传播函数,因为系统将通过自动求梯度而自动生成反向传播所需的backward函数。

如下的代码定义了一个MLP类:

import torch
from torch import nnclass MLP(nn.Module):def __init__(self,**kwargs):super(MLP,self).__init__(**kwargs)self.hidden=nn.Linear(784,256)self.act=nn.ReLU()self.output=nn.Linear(256,10)def forward(self,x):a=self.act(self.hidden(x))return self.output(a)

我们可以实例化MLP类得到模型变量net。下面的代码初始化net并传入输入数据x进行一次前向计算。其中,net(X)会调用MLP继承自Module类的__call__函数,这个函数将调用MLP类定义的forward函数来完成前向计算。

x=torch.rand(2,784)
net=MLP()
print(net)
net(X)

输出:

MLP((hidden): Linear(in_features=784, out_features=256, bias=True)(act): ReLU()(output): Linear(in_features=256, out_features=10, bias=True)
)

二、Module的子类

Module类是一个通用的部件,Pytorch还实现了继承自Module的可以方便构建模型的类,如Sequential、ModuleList和ModuleDict等等。

2.1 Sequential类

当模型的前向计算为简单串联各个层的计算时, Sequential 类可以通过更加简单的⽅式定义模型。这正是 Sequential 类的⽬的:它可以接收⼀个⼦模块的有序字典(OrderedDict)或者⼀系列⼦模块作为参数来逐⼀添加 Module 的实例,⽽模型的前向计算就是将这些实例按添加的顺序逐⼀计算。

2.2 ModuleList类

ModuleList接收一个子模块的列表作为输入,然后也可以类似List那样进行append和extend操作:

net=nn.ModuleList([nn.Linear(784,256),nn.ReLU()])
net.append(nn.Linear(256,10))
print(net[-1])
print(net)

输出:

ModuleList((0): Linear(in_features=784, out_features=256, bias=True)(1): ReLU()(2): Linear(in_features=256, out_features=10, bias=True)
)

2.3 ModuleDict类

ModuleDict接收一个子模块的字典作为输入,然后也可以类似字典那样进行添加访问操作:

net=nn.ModuleDict({'linear':nn.Linear(784,256),'act':nn.ReLU(),
})
net['output']==nn.Linear(256,10)
print(net['linear'])
print(net.output)
print(net)

输出:

Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict((linear): Linear(in_features=784, out_features=256, bias=True)(act): ReLU()(output): Linear(in_features=256, out_features=10, bias=True)
)

三、构造复杂的模型

虽然上面介绍的这些类可以使模型构造更加简单,且不需要定义forward函数,但直接继承Module类可以极大地拓展模型构造的灵活性。下面我们创建一个稍微复杂些的网络FancyMLP。在这个网络中,我们通过get_constant函数创建训练中不被迭代的参数,即常数参数。在前向计算中,除了使用创建的常数参数外,我们还使用Tensor的函数和Python的控制流,并多次调用相同的层。

class FancyMLP(nn.Module):def __init__(self,**kwargs):super(FancyMLP,self).__init__(**kwargs)#这里定义了不可训练的参数(常数参数)self.rand_weight=torch.rand((20,20),requires_grad=False)self.linear=nn.Linear(20,20)def forward(self,x):x=self.linear(x)#使用创建的常数参数,以及nn.functional中的relu函数以及mm函数x=nn.functional.relu(torch.mm(x,self.rand_weight.data)+1)#复用全连接层,等价于两个全连接层共享参数x=self.linear(x)#控制流,这里我们需要调用item函数来返回标量进行比较while x.norm().item()>1:x/=2if x.norm().item()<0.8:x*=10return x.sum()

测试:

X=torch.rand(2,20)
net=FancyMLP()
print(net)
print(net(X))

输出:

FancyMLP((linear): Linear(in_features=20, out_features=20, bias=True)
)
tensor(1.5884, grad_fn=<SumBackward0>)

因为FancyMLP和Sequential类都是Module类的子类,所以我们可以嵌套调用它们。

class NestMLP(nn.Module):def __init__(self,**kwargs):super(NestMLP,self).__init__(**kwargs)self.net=nn.Sequential(nn.Linear(40,30),nn.ReLU())def forward(self,x):return self.net(x)net=nn.Sequential(NestMLP(),nn.Linear(30,20),FancyMLP())X=torch.rand(2,40)
print(net)
print(net(X))

输出:

Sequential((0): NestMLP((net): Sequential((0): Linear(in_features=40, out_features=30, bias=True)(1): ReLU()))(1): Linear(in_features=30, out_features=20, bias=True)(2): FancyMLP((linear): Linear(in_features=20, out_features=20, bias=True))
)
tensor(4.6094, grad_fn=<SumBackward0>)

Pytorch模型构造方法相关推荐

  1. TensorFlow与PyTorch模型部署性能比较

    TensorFlow与PyTorch模型部署性能比较 前言 2022了,选 PyTorch 还是 TensorFlow?之前有一种说法:TensorFlow 适合业界,PyTorch 适合学界.这种说 ...

  2. PyTorch 深度剖析:如何保存和加载PyTorch模型?

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨科技猛兽 编辑丨极市平台 导读 本文详解了PyTorch 模型 ...

  3. TensorRT和PyTorch模型的故事

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 作者丨伯恩legacy 来源丨https://zhuanlan.zh ...

  4. 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...

    (文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...

  5. 基于C++的PyTorch模型部署

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言 PyTorch作为一款端到端的深度学习框架,在1.0版本之后 ...

  6. 在C++平台上部署PyTorch模型流程+踩坑实录

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 本文主要讲解如何将pytorch的模型部署到c++平台上的模 ...

  7. 如何使用TensorRT对训练好的PyTorch模型进行加速?

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨伯恩legacy@知乎 来源丨https://zhuanlan.zhihu.com/p/8831 ...

  8. pytorch模型转onnx-量化rknn(bisenet)

    1.pytorch模型转化onnx 先把pytorch的.pth模型转成onnx,例如我这个是用Bisenet转的,执行export_onnx.py import argparse import os ...

  9. sklearn与pytorch模型的保存与读取

    当我们花了很长时间训练了一个模型,需要用该模型做其他事情(比如迁移学习),或者我们想把自己的机器学习模型分享出去的时候,我们这时候需要将我们的ML模型持久化到硬盘中去. 1.sklearn中模型的保存 ...

  10. 浅谈pytorch 模型 .pt, .pth, .pkl的区别及模型保存方式 pth中的路径加载使用

    首先xxx.pth文件里面会书写一些路径,一行一个. 将xxx.pth文件放在特定位置,则可以让python在加载模块时,读取xxx.pth中指定的路径. Python客栈送红包.纸质书 有时,在用i ...

最新文章

  1. SAP MM 采购申请评估价格不能为0?
  2. SQL必知必会——插入数据(十五)
  3. JS键盘字母相应的keyCode值
  4. php 反射 调用私有方法,PHP通过反射方法调用执行类中的私有方法
  5. python软件下载中文版-PyCharm中文版
  6. html论坛页面怎么做_用php怎么做一个简单的留言页面?
  7. Ubuntu4.04 安装Mesos
  8. FZU 1502 Letter Deletion
  9. 电压比较器的介绍和工作原理
  10. mac trace traceroute 简要使用
  11. 10年软件测试工程师 常用八大测试用例设计方法
  12. 经纬财富:亳州炒白银操作方法,谨防亏损
  13. Android resource compilation failed
  14. Python collections模块之Counter()详解
  15. 如何PDF转Excel,手机和电脑都能用的方法
  16. 软件测试中的黑盒测试和白盒测试和灰盒测试
  17. 模拟电子技术-模拟集成电路
  18. PHP打包下载多文件
  19. 半角与全角、简繁体中文字符串互相转化
  20. 如何在 macOS 下安装 QGIS

热门文章

  1. 广色域图片Android,Android Q将支持广色域照片
  2. 实习踩坑之路:一个诡异的SQL?PageHelper莫名多了一个Limit子句,导致SQL执行错误?
  3. JavaSE学习--集合02
  4. android+祖玛游戏源码,Flash祖玛游戏源代码
  5. 基于 SurfaceView 的直播点亮心形效果
  6. updatepanel失效怎么办_[转]jquery与updatepanel二次失效问题解决方案-阿里云开发者社区...
  7. 查看linux文件的日期格式,5个在Linux中管理文件类型和系统时间的有用命令
  8. 红旗Linux职称考试模块,计算机职称考试红旗Linux Desktop 6.0考试大纲
  9. matlab中求解非线性方程组的函数,利用solve函数求解非线性方程组的问题
  10. java案例代码12--随机码--静态类的使用