nn.Module是在pytorch使用非常广泛的类,搭建网络基本都需要用到这个。

当我们搭建自己的网络时,可以继承官方写好的nn.Module模块,为什么要用这个呢?好处如下:

nn.Module作用

  • 1.可以提供一些现成的基本模块比如:
  • 2. 容器
  • 3.参数管理
  • 4. 所有modules的节点 孩子节点都是直系的
  • 5.to(device)
  • 6.保存和加载模型
  • 7.训练/测试
  • 8.实现自己的类
    • 8.1举一个自己写的线性层的例子

1.可以提供一些现成的基本模块比如:

Linear、ReLU、Sigmoid、Conv2d、Dropout

不用自己一个一个的写这些函数了,这也是为什么我们用框架的原因之一吧。

2. 容器

比如我们经常用到的 nn.Sequential(),顾名思义,将网络模块封装在一个容器中,可以方面网络搭建
如下面一个例子:

class TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),nn.MaxPool2d(2, 2),Flatten(),nn.Linear(1*14*14, 10))def forward(self, x):return self.net(x)

3.参数管理

参数名字可以自动生成(想想如果自己去命名,百万参数的网络没法搭建),然后这些参数都可以传到优化器里面去优化

4. 所有modules的节点 孩子节点都是直系的

class BasicNet(nn.Module):def __init__(self):super(BasicNet, self).__init__()self.net = nn.Linear(4, 3)def forward(self, x):return self.net(x)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.net = nn.Sequential(BasicNet(),nn.ReLU(),nn.Linear(3, 2))def forward(self, x):return self.net(x)

比如上面的代码,我们可以看出Net网络中有5个孩子节点:nn.Sequential,BasicNet, nn.ReLU,nn.Linear,BasicNet里面的nn.Linear

5.to(device)

nn.Module还有一个功能是将某个网络所有成员、函数、操作都搬移到GPU上面。
采用代码如下:

    device = torch.device('cuda')net = Net()net.to(device)

上面device代表当前的设备是GPU还是CPU,需要注意的是为什么我们不写

net = net.to(device)

其实效果是一样的,采用nn.Module模块,net加上.to(device),还是net。如果是变量则不是一样的,即如果对于tensor bias,那么biasbias.to(device)不是一样的,则需要重新命名。

6.保存和加载模型

可以方面我们保存和加载模型

加载模型:

net.load_state_dict(torch.load('ckpt.mdl'))

保存模型:

torch.save(net.state_dict(), 'ckpt.mdl')

7.训练/测试

方便训练和测试进行切换,为什么?因为网络中Dropout和BN在训练和测试是不一样的,需要切换
如果不切换效果就会很差,这个是容易犯的一个错误

    net.train()net.eval()

8.实现自己的类

官方给的模块还是基础操作的,如果自己要搭建复杂的操作也容易实现,一个典型的例子就是可以自己设计一个新的损失函数。
下面给出将tensor压平的例子(nn.Module没有这个操作):

class Flatten(nn.Module):def __init__(self):super(Flatten, self).__init__()def forward(self, input):return input.view(input.size(0), -1)class TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),nn.MaxPool2d(2, 2),Flatten(),  #自己定义的nn.Linear(1*14*14, 10))def forward(self, x):return self.net(x)

Flatten压平的操作则是我们自己构建的类,可以方便后续BasicNet类使用,注意nn.Sequential里面必须是类。
且在上面例子中Flatten不需要接任何参数。

8.1举一个自己写的线性层的例子

class MyLinear(nn.Module):def __init__(self, inp, outp):super(MyLinear, self).__init__()# requires_grad = Trueself.w = nn.Parameter(torch.randn(outp, inp))self.b = nn.Parameter(torch.randn(outp))def forward(self, x):x = x @ self.w.t() + self.breturn x

在上面自己写的线性层 y=wx+by=wx+by=wx+b,可以看出www和bbb必须要使用nn.Parameter这个模块。原因是只用加上了nn.Parameter后,www和bbb才可以用优化器SGD等进行优化。

如果不写nn.Parameter那么则需要写requires_grad = True,还要自己写优化器,就很麻烦。用了Parameter可以方便我们优化网络:

model = MyLinear.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()

pytorch里面nn.Module讲解相关推荐

  1. PyTorch中nn.Module类中__call__方法介绍

    在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...

  2. pytorch:nn.Sequential讲解

    接下来想讲一下参数初始化方式对训练的影响,但是必须要涉及到pytorch的自定义参数初始化,然而参数初始化又包括在不同结构定义中初始化方式,因而先讲一下pytorch中的nn.Sequential n ...

  3. pytorch torch.nn.Module.register_buffer

    API register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None 注册buf ...

  4. Pytorch中nn.Module中的self.register_buffer解释

    self.register_buffer作用解释 今天遇到了这样一种用法,self.register_buffer('name',Tensor),该方法的作用在于定义一组参数.该组参数在模型训练时不会 ...

  5. Pytorch中nn.Module和nn.Sequencial的简单学习

    文章目录 前言 1.Python 类 2.nn.Module 和 nn.Sequential 2.1 nn.Module 2.1.1 torch.nn.Module类 2.1.2 nn.Sequent ...

  6. PyTorch中nn.Module类简介

    torch.nn.Module类是所有神经网络模块(modules)的基类,它的实现在torch/nn/modules/module.py中.你的模型也应该继承这个类,主要重载__init__.for ...

  7. pytorch torch.nn.Module

    应用 >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >&g ...

  8. PyTorch nn.Module 一些疑问

    在阅读书籍时,遇到了一些不太理解,或者介绍的不太详细的点. 从代码角度学习理解Pytorch学习框架03: 神经网络模块nn.Module的了解. Pytorch 03: nn.Module模块了解 ...

  9. 详解Pytorch的nn.DataParallel

    ↑ 点击蓝字 关注视学算法 作者丨Mario@知乎 来源丨https://zhuanlan.zhihu.com/p/102697821 编辑丨极市平台 极市导读 在Pytorch中,nn.DataPa ...

最新文章

  1. Hibernate学习总结【比较与Mybatis框架之间的区别】
  2. ETL异构数据源Datax_日期增量同步_13
  3. Django------多表操作
  4. sqlserver中某列转成以逗号连接的字符串及逆转、数据行转列列转行
  5. #大数加减乘除#校赛D题solve
  6. diskgenius创建efi分区_无损分区大小调整
  7. 基于ESP-IDF环境的ESP32-C3开发之No such file or directory
  8. 计算机再带word打不开,电脑word文档打不开怎么办(附:4种解决办法)
  9. 中国撸串指北:13万家烧烤店的吃货最爱
  10. #打卡day1 ROS talker/listener
  11. MyHDL中文手册(六)—— RTL建模
  12. Android动态化方案
  13. TopOpt | 针对99行改进的88行拓扑优化程序完全注释
  14. 人非生而知之者,智之者,孰能无惑,无过
  15. 磐石云服务器_磐石云双十二高防ip、海外服务器限量1元秒杀
  16. 【Js逆向】__jsl_clearance ob混淆加速乐
  17. mapbox-gl开发:集成deck.gl
  18. 基于python+django+vue的大学生租房系统pycharm源码
  19. 轻松学习JavaScript十一:JavaScript基本类型(包含类型转换)和引用类型
  20. 在地图上,如何对一个地区进行矩形划分

热门文章

  1. java 刷新jtextarea,SwingPropertyChangeSupport动态更新JTextArea
  2. for循环中的三语句执行顺序
  3. 数据压缩技术简史---关于实时数据压缩的基础知识
  4. 递归算法到非递归算法的转换
  5. 趋势面法优缺点_趋势面分析法
  6. 分布式session会话Sticky Sessions
  7. 博客设计展示:25个优秀博客设计
  8. Oracle EBS R12 - Application patch可不可以reapply
  9. 微信小程序实现一个简单的倒计时效果
  10. LRUCache的C++实现