pytorch里面nn.Module讲解
nn.Module
是在pytorch使用非常广泛的类,搭建网络基本都需要用到这个。
当我们搭建自己的网络时,可以继承官方写好的nn.Module
模块,为什么要用这个呢?好处如下:
nn.Module作用
- 1.可以提供一些现成的基本模块比如:
- 2. 容器
- 3.参数管理
- 4. 所有modules的节点 孩子节点都是直系的
- 5.to(device)
- 6.保存和加载模型
- 7.训练/测试
- 8.实现自己的类
- 8.1举一个自己写的线性层的例子
1.可以提供一些现成的基本模块比如:
Linear、ReLU、Sigmoid、Conv2d、Dropout
不用自己一个一个的写这些函数了,这也是为什么我们用框架的原因之一吧。
2. 容器
比如我们经常用到的 nn.Sequential()
,顾名思义,将网络模块封装在一个容器中,可以方面网络搭建
如下面一个例子:
class TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),nn.MaxPool2d(2, 2),Flatten(),nn.Linear(1*14*14, 10))def forward(self, x):return self.net(x)
3.参数管理
参数名字可以自动生成(想想如果自己去命名,百万参数的网络没法搭建),然后这些参数都可以传到优化器里面去优化
4. 所有modules的节点 孩子节点都是直系的
class BasicNet(nn.Module):def __init__(self):super(BasicNet, self).__init__()self.net = nn.Linear(4, 3)def forward(self, x):return self.net(x)class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.net = nn.Sequential(BasicNet(),nn.ReLU(),nn.Linear(3, 2))def forward(self, x):return self.net(x)
比如上面的代码,我们可以看出Net网络中有5个孩子节点:nn.Sequential,BasicNet, nn.ReLU,nn.Linear,BasicNet里面的nn.Linear
5.to(device)
nn.Module
还有一个功能是将某个网络所有成员、函数、操作都搬移到GPU上面。
采用代码如下:
device = torch.device('cuda')net = Net()net.to(device)
上面device代表当前的设备是GPU还是CPU,需要注意的是为什么我们不写
net = net.to(device)
其实效果是一样的,采用nn.Module
模块,net加上.to(device)
,还是net。如果是变量则不是一样的,即如果对于tensor bias
,那么bias
和bias.to(device)
不是一样的,则需要重新命名。
6.保存和加载模型
可以方面我们保存和加载模型
加载模型:
net.load_state_dict(torch.load('ckpt.mdl'))
保存模型:
torch.save(net.state_dict(), 'ckpt.mdl')
7.训练/测试
方便训练和测试进行切换,为什么?因为网络中Dropout和BN在训练和测试是不一样的,需要切换
如果不切换效果就会很差,这个是容易犯的一个错误。
net.train()net.eval()
8.实现自己的类
官方给的模块还是基础操作的,如果自己要搭建复杂的操作也容易实现,一个典型的例子就是可以自己设计一个新的损失函数。
下面给出将tensor压平的例子(nn.Module
没有这个操作):
class Flatten(nn.Module):def __init__(self):super(Flatten, self).__init__()def forward(self, input):return input.view(input.size(0), -1)class TestNet(nn.Module):def __init__(self):super(TestNet, self).__init__()self.net = nn.Sequential(nn.Conv2d(1, 16, stride=1, padding=1),nn.MaxPool2d(2, 2),Flatten(), #自己定义的nn.Linear(1*14*14, 10))def forward(self, x):return self.net(x)
Flatten
压平的操作则是我们自己构建的类,可以方便后续BasicNet类使用,注意nn.Sequential
里面必须是类。
且在上面例子中Flatten
不需要接任何参数。
8.1举一个自己写的线性层的例子
class MyLinear(nn.Module):def __init__(self, inp, outp):super(MyLinear, self).__init__()# requires_grad = Trueself.w = nn.Parameter(torch.randn(outp, inp))self.b = nn.Parameter(torch.randn(outp))def forward(self, x):x = x @ self.w.t() + self.breturn x
在上面自己写的线性层 y=wx+by=wx+by=wx+b,可以看出www和bbb必须要使用nn.Parameter
这个模块。原因是只用加上了nn.Parameter
后,www和bbb才可以用优化器SGD等进行优化。
如果不写nn.Parameter
那么则需要写requires_grad = True
,还要自己写优化器,就很麻烦。用了Parameter
可以方便我们优化网络:
model = MyLinear.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# backprop
optimizer.zero_grad()
loss.backward()
optimizer.step()
pytorch里面nn.Module讲解相关推荐
- PyTorch中nn.Module类中__call__方法介绍
在PyTorch源码的torch/nn/modules/module.py文件中,有一条__call__语句和一条forward语句,如下: __call__ : Callable[-, Any] = ...
- pytorch:nn.Sequential讲解
接下来想讲一下参数初始化方式对训练的影响,但是必须要涉及到pytorch的自定义参数初始化,然而参数初始化又包括在不同结构定义中初始化方式,因而先讲一下pytorch中的nn.Sequential n ...
- pytorch torch.nn.Module.register_buffer
API register_buffer(name: str, tensor: Optional[torch.Tensor], persistent: bool = True) → None 注册buf ...
- Pytorch中nn.Module中的self.register_buffer解释
self.register_buffer作用解释 今天遇到了这样一种用法,self.register_buffer('name',Tensor),该方法的作用在于定义一组参数.该组参数在模型训练时不会 ...
- Pytorch中nn.Module和nn.Sequencial的简单学习
文章目录 前言 1.Python 类 2.nn.Module 和 nn.Sequential 2.1 nn.Module 2.1.1 torch.nn.Module类 2.1.2 nn.Sequent ...
- PyTorch中nn.Module类简介
torch.nn.Module类是所有神经网络模块(modules)的基类,它的实现在torch/nn/modules/module.py中.你的模型也应该继承这个类,主要重载__init__.for ...
- pytorch torch.nn.Module
应用 >>> for name, param in self.named_parameters(): >>> if name in ['bias']: >&g ...
- PyTorch nn.Module 一些疑问
在阅读书籍时,遇到了一些不太理解,或者介绍的不太详细的点. 从代码角度学习理解Pytorch学习框架03: 神经网络模块nn.Module的了解. Pytorch 03: nn.Module模块了解 ...
- 详解Pytorch的nn.DataParallel
↑ 点击蓝字 关注视学算法 作者丨Mario@知乎 来源丨https://zhuanlan.zhihu.com/p/102697821 编辑丨极市平台 极市导读 在Pytorch中,nn.DataPa ...
最新文章
- Hibernate学习总结【比较与Mybatis框架之间的区别】
- ETL异构数据源Datax_日期增量同步_13
- Django------多表操作
- sqlserver中某列转成以逗号连接的字符串及逆转、数据行转列列转行
- #大数加减乘除#校赛D题solve
- diskgenius创建efi分区_无损分区大小调整
- 基于ESP-IDF环境的ESP32-C3开发之No such file or directory
- 计算机再带word打不开,电脑word文档打不开怎么办(附:4种解决办法)
- 中国撸串指北:13万家烧烤店的吃货最爱
- #打卡day1 ROS talker/listener
- MyHDL中文手册(六)—— RTL建模
- Android动态化方案
- TopOpt | 针对99行改进的88行拓扑优化程序完全注释
- 人非生而知之者,智之者,孰能无惑,无过
- 磐石云服务器_磐石云双十二高防ip、海外服务器限量1元秒杀
- 【Js逆向】__jsl_clearance ob混淆加速乐
- mapbox-gl开发:集成deck.gl
- 基于python+django+vue的大学生租房系统pycharm源码
- 轻松学习JavaScript十一:JavaScript基本类型(包含类型转换)和引用类型
- 在地图上,如何对一个地区进行矩形划分