文章目录

  • 前言
  • 1、Python 类
  • 2、nn.Module 和 nn.Sequential
    • 2.1 nn.Module
      • 2.1.1 torch.nn.Module类
      • 2.1.2 nn.Sequential 类
  • 3.自己的示例

前言

   目前在学习 Pytorch 入门,很久之前进行了自定义模型的编码,但因为学业繁忙,时隔一周再来继续对 Pytorch 的学习,以及之前对 Python 的学习并不扎实,回过头再来看之前的代码需要再次理解,浪费时间,所以写下本博客对知识理解进行记录,也便后续回忆。

1、Python 类

下面介绍一些后续会用到的关于 Python 类的知识点:

  • __init__() 方法是一种特殊的方法,被称为类的构造函数或初始化方法,当创建了这个类的实例时就会调用该方法。
  • self 代表类的实例,self在定义类的方法时是必须有的,虽然在调用时不必传入相应的参数。
  • 类的方法和普通的函数只有一个特别的区别----它们必须有一个额外的第一个参数名称,按照惯例,它的名称是 self,当然换成其他名称也是可以的。

2、nn.Module 和 nn.Sequential

   该部分主要参考下面两条blog,个人感觉感jio很不错:

  • nn.Module学习
  • nn.Sequential学习

2.1 nn.Module

  Pytorch 中没有特别明显的 LayerModule 的区别,不管是自定义层、自定义块、自定义模型,都是通过继承 Module 类完成的,这一点很重要。其实 Sequential 类也是继承自 Module 类的。

  pytorch 里面一切自定义操作基本上都是继承自 nn.Module 类实现的。

2.1.1 torch.nn.Module类

先看源码:

class Module(object):def __init__(self):def forward(self, *input):def add_module(self, name, module):def cuda(self, device=None):def cpu(self):def __call__(self, *input, **kwargs):def parameters(self, recurse=True):def named_parameters(self, prefix='', recurse=True):def children(self):def named_children(self):def modules(self):  def named_modules(self, memo=None, prefix=''):def train(self, mode=True):def eval(self):def zero_grad(self):def __repr__(self):def __dir__(self):
# ...

  在自定义模型的时候,我们需要继承 nn.Module 类,并且需要重新实现构造函数 __init__() 方法和前向传播 forward() 方法。但有一些技巧需要注意:

  1. 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数 __init__() 中,当然也可以把不具有参数的层也放在里面;

  2. 一般把不具有可学习参数的层(如ReLUdropoutBatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数 __init__() 里面,则在 forward() 方法里面可以使用 nn.functional 来代替

  3. forward() 方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

  4. __init__() 方法中定义一系列层,此时层与层之间并没有连接关系,而在 forward() 方法中实现所有层的链接关系。

2.1.2 nn.Sequential 类

  nn.Sequential 类继承自nn.Module类,先来看定义:

class Sequential(Module): # 继承Moduledef __init__(self, *args):  # 重写了构造函数def _get_item_by_idx(self, iterator, idx):def __getitem__(self, idx):def __setitem__(self, idx, module):def __delitem__(self, idx):def __len__(self):def __dir__(self):def forward(self, input):  # 重写关键方法forward

  Sequential类的三种实现:

  1. 最简单的顺序模型
import torch.nn as nn
model = nn.Sequential(nn.Conv2d(1,20,5),nn.ReLU(),nn.Conv2d(20,64,5),nn.ReLU())print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential((0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(1): ReLU()(2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(3): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''
  • 每个层没有名称,默认通过0、1、2、3来命名。
  1. 给每一个层添加名称(orderedDict)
import torch.nn as nnfrom collections import OrderedDict
model = nn.Sequential(OrderedDict([('conv1', nn.Conv2d(1,20,5)),('relu1', nn.ReLU()),('conv2', nn.Conv2d(20,64,5)),('relu2', nn.ReLU())]))print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential((conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))(relu1): ReLU()(conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))(relu2): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''
  • 从结果可以看出,此时每一层都有了自己的名字,但不能以名字来进行层的索引

model[2] →正确
model["conv2]→错误

  1. 第三种实现(add_module)
import torch.nn as nn
from collections import OrderedDictmodel = nn.Sequential()
model.add_module("conv1",nn.Conv2d(1,20,5))
model.add_module('relu1', nn.ReLU())
model.add_module('conv2', nn.Conv2d(20,64,5))
model.add_module('relu2', nn.ReLU())print(model)
print(model[2]) # 通过索引获取第几个层
  • Sequential 类并没有定义 add_module() 方法,实际上这个方法是定义在它的父类 Module 里面的,Sequential 继承了该方法。它的定义如下:
def add_module(self, name, module)

3.自己的示例

  再看当初自己写的代码,便不难理解了:

import torch.nn as nnclass LinearNet(nn.Module):def __init__(self, n_feature):# 这是对继承自父类的属性进行初始化。而且是用父类的初始化方法来初始化继承的属性。# 也就是说,子类继承了父类的所有属性和方法,父类属性自然会用父类方法来进行初始化。# 当然,如果初始化的逻辑与父类的不同,不使用父类的方法,自己重新初始化也是可以的。super(LinearNet, self).__init__()self.linear = nn.Linear(n_feature, 1)# 前向传播def forward(self, x):y = self.linear(x)return ynet = LinearNet(2)
print(net)           # 使用print可以打印出网络的结构

Pytorch中nn.Module和nn.Sequencial的简单学习相关推荐

  1. nn.Module、nn.Sequential和torch.nn.parameter学习笔记

    nn.Module.nn.Sequential和torch.nn.parameter是利用pytorch构建神经网络最重要的三个函数.搞清他们的具体用法是学习pytorch的必经之路. 目录 nn.M ...

  2. nn.Module与nn.functional

    大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...

  3. [Pytorch系列-30]:神经网络基础 - torch.nn库五大基本功能:nn.Parameter、nn.Linear、nn.functioinal、nn.Module、nn.Sequentia

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  4. Pytorch —— nn.Module类(nn.sequential)

    对于前面的线性回归模型. Logistic回归模型和神经网络,我们在构建的时候定义了需要的参数.这对于比较小的模型是可行的,但是对于大的模型,比如100 层的神经网络,这个时候再去手动定义参数就显得非 ...

  5. pytorch 中pad函数toch.nn.functional.pad()的使用

    padding操作是给图像外围加像素点. 为了实际说明操作过程,这里我们使用一张实际的图片来做一下处理. 这张图片是大小是(256,256),使用pad来给它加上一个黑色的边框.具体代码如下: imp ...

  6. PyTorch中常用Module和Layer的学习笔记~

    1 前言 今天在学习PyTorch对于VGG网络的官方实现,朱老师在上课的时候也讲了, 不过感觉自己记得还是不是很牢,所以想写个笔记记录一下~ 2 常用Module和Layer nn.Conv2d 这 ...

  7. pytorch ——模型创建与nn.Module

    1.网络模型创建步骤 模型模块中分为两个部分,模型创建和权值初始化: 模型创建又分为两部分,构建网络层和拼接网络层:网络层有卷积层,池化层,激活函数等:构建网络层后,需要进行网络层的拼接,拼接成LeN ...

  8. Pytorch的自定义拓展:torch.nn.Module和torch.autograd.Function

    参考链接:pytorch的自定义拓展之(一)--torch.nn.Module和torch.autograd.Function_LoveMIss-Y的博客-CSDN博客_pytorch自定义backw ...

  9. Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()

    Pytorch 保存和加载模型后缀:.pt 和.pth 1 torch.save() [source] 保存一个序列化(serialized)的目标到磁盘.函数使用了Python的pickle程序用于 ...

最新文章

  1. 详细理解JS中的继承
  2. Genome-scale de novo assembly using ALGA 使用ALGA进行 基因组规模的从头组装
  3. Python的setuptools详解【2】find_packages()
  4. Linux CPU cache
  5. 更加安全的密钥生成方法Diffie-Hellman
  6. day38-数据库应用软件
  7. 面试java常问的问题_java面试官常问的问题
  8. 目前支持DDR3-1600(包括主板超频)最强的CPU是哪个?
  9. bind 启动redis_详解Redis开启远程登录连接
  10. LINUX 添加xp虚拟机
  11. 创科视觉软件说明书_机器视觉入门指南
  12. 学成在线-处理常见视频格式avi,mp4,mov,rmvb,flv
  13. 那些精贵的文献资源下载网址经验总结
  14. C语言指针结构体详解,结构体指针,C语言结构体指针详解
  15. dxe 如何跟smm 沟通 SMM Communication Protocol
  16. Google收购传感器公司Lumedyne
  17. Oracle海量数据清理-表空间释放
  18. 微信小程序开发-页面跳转的几种方式
  19. windows 7系统安装虚拟机及在虚拟机上安装ubuntu(linux)操作系统
  20. ICT在线测试仪优点有哪些?

热门文章

  1. java上传问题,求各位高手帮帮忙,看看是什么问题,急!急!急!
  2. Jeff·Bezos:怀念我在麦当劳的屌丝时光
  3. Android开发--便签(一)
  4. [USACO21DEC] Air Cownditioning B(差分)
  5. 阿里云丨以AI助力电力产业变革,创造有为时代——访阿里云人工智能科学家闵万里...
  6. java b2c_JAVA开源B2C系统
  7. 我的世界服务器修改后变回来,我的世界:经营8年的服务器主人去世后,城市中心变成了纪念馆...
  8. Hadoop源码分析(25)
  9. 学计算机的女生,是一种怎样的存在?
  10. Android Studio之library工程中不能使用switch-case语句访问资源ID