autograd 实现了自动微分系统,然而对深度学习来说过于底层,而本节将介绍 nn 模块,是构建于 autograd 之上的神经网络模块。

1. 简单感知机

使用 autograd 可实现深度学习模型,但其抽象程度较低,如果用其来实现深度学习模型,则需要编写的代码量极大。在这种情况下,torch.nn 应运而生,其是专门为深度学习设计的模块。

torch.nn 的核心数据结构是 Module ,它是一个抽象的概念,既可以表示神经网络中的某个层(layer),也可以表示一个包含很多层的神经网络。

在实际使用中,最常见的做法继承 nn.Module ,撰写自己的网络层。

下面先来看看如何使用 nn.Module 实现自己的全连接层。全连接层,又名仿射层,输入 y 和输入 x 满足y=xW+bWb 是可学习的参数。

import torch as t
from torch import  nnclass Linear(nn.Module):def __init__(self, input_features, out_features):super(Linear, self).__init__() # 等价于 nn.Module.__init__(self)self.w = nn.Parameter(t.randn(input_features, out_features))self.b = nn.Parameter(t.randn(out_features))def forward(self, x):x = x.mm(self.w)return x + self.b.expand_as(x)layer = Linear(4, 3)x = t.randn(2, 4)
output = layer(x)
print outputfor name, parameter in layer.named_parameters():print name, parameter

output 输出为 :

tensor([[ 1.5752,  0.6730, -0.0763],[-0.7037, -0.6641, -2.3261]], grad_fn=<ThAddBackward>)

name, parameter 输出为:

w Parameter containing:
tensor([[-1.0459, -0.1899,  0.2202],[ 1.5751,  0.0613,  1.7350],[-0.2644,  0.7728,  1.4141],[-0.3739, -0.4349, -0.0984]], requires_grad=True)
b Parameter containing:
tensor([1.3054, 0.3063, 0.4375], requires_grad=True)

可见,全连接层的实现非常简单,但需注意以下几点:

  • 自定义层 Linear 必须继承 nn.Module ,并且在其构造函数中需调用 nn.Module 的构造函数,即super(Linear,self).__init()__nn.Module.__init(self)__
  • 在构造函数 __init__ 中必须自己定义可学习的参数,并封装成 Parameter,如在本例中我们把 wb 封装成 ParameterParameter 是一种特殊的 Variable ,但其默认需要求导(requires_grad=True );
  • forward 函数实现前向传播过程,其输入可以是一个或多个 variable,对 x 的任何操作也必须是 variable 支持的操作。
  • 无须写反向传播函数,因其前向传播都是对 variable 进行操作,nn.Module 能够利用 autograd 自动实现反向传播,这一点比 Function 简单许多。
  • 使用时,直观上可将 layer 看成数学概念中的函数,调用 layer(input) 即可得到 input 对应的结果。它等价于 layers.__call(input)__ ,在 __call__ 函数中,主要调用的是 layer.forward(x) 。所以在实际使用中应尽量使用layer(x) 而不是使用 layer.forward(x)
  • Module 中的可学习参数可以通过 named_parameters() 或者 parameters() 返回迭代器,前者会给每个parameter 附上名字,使其更具有辨识度。

可见,利用 Module 实现的全连接层,比利用 Function 实现的更简单,因其不再需要写反向传播函数。

2. 多层感知机

Module 能够自动检测到自己的 parameter ,并将其作为学习参数。除了 parameterModule 还包含子Module ,主 Module 能够递归查找子 Module 中的 parameter 。下面再来看看稍微复杂一点的网络:多层感知机。

多层感知机的网络结构如图所示。它由两个全连接层组成,采用 sigmoid 函数作为激活函数(图中没有画出)。

实现代码如下:

import torch as t
from torch import  nnclass Linear(nn.Module):def __init__(self, input_features, out_features):super(Linear, self).__init__() # 等价于 nn.Module.__init__(self)self.w = nn.Parameter(t.randn(input_features, out_features))self.b = nn.Parameter(t.randn(out_features))def forward(self, x):x = x.mm(self.w)return x + self.b.expand_as(x)class Perceptron(nn.Module):def __init__(self, in_features, hidden_features, out_features):nn.Module.__init__(self)self.layer1 = Linear(in_features, hidden_features) # 此处的 Linear 前面自定义的全连接层self.layer2 = Linear(hidden_features, out_features)def forward(self, x):x = self.layer1(x)x = t.sigmoid(x)return self.layer2(x)perception = Perceptron(3,4,1)
for name, param in perception.named_parameters():print(name, param.size())

输出结果:

layer1.w torch.Size([3, 4])
layer1.b torch.Size([4])
layer2.w torch.Size([4, 1])
layer2.b torch.Size([1])

可见,即使是稍复杂的多层感知机,其实现依旧很简单。这里需要注意以下两个知识点。

  • 构造函数 __init__ 中,可利用前面自定义的 Linear 层( Module )作为当前 Module 对象的一个子Module ,它的可学习参数,也会成为当前 Module 的可学习参数。

  • 在前向传播函数中,我们有意识地将输出变量都命名为 x,是为了能让 Python 回收一些中间层的输出,从而节省内存。但并不是所有的中间结果都会被回收,有些 variable 虽然名字被覆盖,但其在反向传播时仍需要用到,此时 Python 的内存回收模块将通过检查引用计数,不会回收这一部分内存。

Moduleparameter 的全局命名规范如下:

  • Parameter 直接命名。例如 self.param_name = nn.Parameter(t.randn(3,4)) ,命名为 param_name
  • Module 中的 parameter ,会在其名字之前加上当前 Module 的名字。例如 self.sub_module = SubModule()SubModule 中有个 parameter 的名字也叫作 param_name ,那么二者拼接而成的 parameter name 就是sub_module.param_name

为了方便用户使用,PyTorch 实现了神经网络中绝大多数的 layer ,这些 layer 都继承于 nn.Module ,封装了可学习参数 parameter ,并实现了 forward 函数,且专门针对 GPU 运算进行了 CuDNN 优化,其速度和性能都十分优异。

  • 构造函数的参数,如 nn.Linear(in_features,out_features,bias),需关注这三个参数的作用。
  • 属性、可学习参数和子 Module 。如 nn.Linear 中有 weightbias 两个可学习参数,不包含子 Module
  • 输入输出的形状,如 nn.Linear 的输入形状是(Ninput_features),输出形状为(N,output_features),Nbatch_size

这些自定义 layer 对输入形状都有假设:输入的不是单个数据,而是一个 batch 。若想输入一个数据,必须调用 unsqueeze(0) 函数将数据伪装成 batch_size=1batch

PyTorch 笔记(14)— nn.module 实现简单感知机和多层感知机相关推荐

  1. pytorch教程之nn.Module类详解——使用Module类来自定义网络层

    前言:前面介绍了如何自定义一个模型--通过继承nn.Module类来实现,在__init__构造函数中申明各个层的定义,在forward中实现层之间的连接关系,实际上就是前向传播的过程. 事实上,在p ...

  2. pytorch教程之nn.Module类详解——使用Module类来自定义模型

    pytorch教程之nn.Module类详解--使用Module类来自定义模型_MIss-Y的博客-CSDN博客_nn是什么意思前言:pytorch中对于一般的序列模型,直接使用torch.nn.Se ...

  3. 05_多层感知机_多层感知机笔记

    4. 多层感知机 多层感知机:最简单的深度网络,由多层神经元组成,每一层都与下面一层(从中接收输入)和上面一层(反过来影响当前层的神经元)完全相连 训练大容量模型时,面临着过拟合的风险 4.1. 多层 ...

  4. 深度学习——感知机:多层感知机(multi-layered perceptron)图文详解

    多层感知机 一,多层感知机 1.1 现在已有的门电路组合 1.2 异或门的实现 二,从与非门到计算机 三,总结 一,多层感知机   在上一篇深度学习--感知机(perceptron)图文详解中我们已经 ...

  5. 感知机和多层感知机详细学习

    1. 感知机的前向推理? 感知机其实就是类似神经网络的一个神经元 w0相当于bias,也就是偏置 w1-wn是权重 step fuction是sign 前向推理的公式 2. 感知机的loss func ...

  6. torch的拼接函数_从零开始深度学习Pytorch笔记(13)—— torch.optim

    前文传送门: 从零开始深度学习Pytorch笔记(1)--安装Pytorch 从零开始深度学习Pytorch笔记(2)--张量的创建(上) 从零开始深度学习Pytorch笔记(3)--张量的创建(下) ...

  7. PyTorch学习笔记2:nn.Module、优化器、模型的保存和加载、TensorBoard

    文章目录 一.nn.Module 1.1 nn.Module的调用 1.2 线性回归的实现 二.损失函数 三.优化器 3.1.1 SGD优化器 3.1.2 Adagrad优化器 3.2 分层学习率 3 ...

  8. 深度学习笔记其三:多层感知机和PYTORCH

    深度学习笔记其三:多层感知机和PYTORCH 1. 多层感知机 1.1 隐藏层 1.1.1 线性模型可能会出错 1.1.2 在网络中加入隐藏层 1.1.3 从线性到非线性 1.1.4 通用近似定理 1 ...

  9. 【小白学习PyTorch教程】四、基于nn.Module类实现线性回归模型

    「@Author:Runsen」 上次介绍了顺序模型,但是在大多数情况下,我们基本都是以类的形式实现神经网络. 大多数情况下创建一个继承自 Pytorch 中的 nn.Module 的类,这样可以使用 ...

最新文章

  1. Linux下通过txt文件导入数据到MySQL数据库
  2. 2020图灵年度好书大赏 | 15周年视频纪念版
  3. wifi协议栈_一文读懂米家部分智能硬件:米家Zigbee及WiFi模块拆解分析
  4. 【SpringBoot】在SpringBoot中使用Ehcache
  5. 下列属于计算机人工智能应用领域的是多选题,每天五道选择题(10)
  6. SQL server 2005安装问题汇总
  7. linux 备份mysql_linux下备份MYSQL数据库的方法
  8. android 底部去除list渐变,layer-list渐变色的处理
  9. 纯前端实现人脸识别-提取-合成
  10. ICCV2021|性能优于何恺明团队MoCo v2,DetCo:为目标检测定制任务的对比学习
  11. [译] 深度学习的未来
  12. html自定义字体,css怎么自定义字体?
  13. git 和gitHup工具笔记的详细教程
  14. echarts地图整体渐变色
  15. Android P+通过反射调用系统API实现高级功能
  16. NoSQL 一致性[详解]更新一致性
  17. 椭圆曲线ECC倍点运算forJava
  18. Windows设置/去除C盘的写保护
  19. 手机python怎么画图_无所不能的python编程是怎么快速画图的呢?5分钟学会!
  20. Python+Eclipse配置`PyDev`完整教程

热门文章

  1. 在kotlin companion object中读取spring boot配置文件,静态类使用@Value注解配置
  2. docker一步安装mysql,docker的魅力就在于此
  3. 2022-2028年中国高密度聚乙烯(HDPE)行业市场发展调研及投资前景分析报告
  4. 德国最受欢迎的程序员技能排行
  5. 【VS实践】VS解决方案中出现无法生成DLL文件
  6. Springboot前后端分离上传、下载压缩包、查看文件
  7. 基于javaGUI的文档识别工具制作
  8. GAAFET与FinFET架构
  9. 如何使用Nsight Compute?
  10. 2021年大数据Hadoop(二十六):YARN三大组件介绍