从零开始学Pytorch(第5天)

  • 前言
  • 一、模块类的构建
    • 1. nn.Module
    • 2.构建一个线性回归类
  • 二、计算图和自动求导机制
    • 1.计算图
    • 2.自动求导
  • 总结

前言

今天主要了解和学习Pytorch中的模块类和计算图、自动求导机制


一、模块类的构建

1. nn.Module

Pytorch模型通过继承nn.Module,在类的内部定义子模块实例化,
通过前向计算调用子模块,最后实现深度学习模型的搭建。
import torch.nn as nnclass Model(nn.Module):def __init__(self,...): #定义类的初始化函数,...是传入的参数super(Model, self).__init__()...#根据传入的参数来定义子模块def forward(self,...):#定义前向计算的输入参数,...一般是张量或者其他的参数。ret                #根据传入的张量和子模块计算返回张量return ret

这个基本框架记住之后套用就好。

2.构建一个线性回归类

import torch
import torch.nn as nnclass LinearModel(nn.Module):def __init__(self, ndim):super(LinearModel, self).__init__()self.ndim = ndimself.weight = nn.Parameter(torch.randn(ndim, 1))  # 定义权重,这里的nn.Parameter可以理解为将张量转换为可训练的参数类self.bias = nn.Parameter(torch.randn(1))  # 定义偏置def forward(self, x):# y=Wx+breturn x.mm(self.weight) + self.biaslm=LinearModel(5)   #初始化一个线性模型实例,特征数为5
x=torch.randn(4,5)  #定义输入,可以理解为mini-batch=4lm(x)
tensor([[-3.0978],[ 4.5808],[ 0.2038],[-1.6345]], grad_fn=<AddBackward0>)

可以看到,我们要做的就是把计算过程变成代码输入到模型中即可。

二、计算图和自动求导机制

1.计算图

计算图(Computational Graph)是用来描述运算的有向无环图,其中的节点表示数据,如张量等;边表示运算,如加减乘除卷积等。

深度学习框架采用的是两种策略:静态图动态图
Tensorflow 1 和Caffe支持静态图,即提前构造好运算,再根据输入的张量进行计算得出结果。其优点是,减少了计算图构建的时间消耗,效率高;但却无法进行修改且上手较难。

Pytorch使用的是动态计算图,运算与搭建同时进行,可以实时输出深度学习模型的中间张量,便于调试。

2.自动求导

自动求导的过程是:从计算图输出的损失函数标量值,利用反向传播算法,反推计算图中权重张量的梯度。

t1=torch.randn(3,3,requires_grad=True)
t1tensor([[ 0.1071, -0.6140, -1.0037],[ 0.3234, -1.4746, -1.7091],[-1.0635,  1.3680,  1.4820]], requires_grad=True)t2=t1.pow(2).sum() #计算张量所有张量平方和
t2.backward()t1.grad   #梯度是原分量的两倍
tensor([[ 0.2142, -1.2281, -2.0074],[ 0.6468, -2.9492, -3.4183],[-2.1270,  2.7360,  2.9641]])t2=t1.pow(2).sum() #计算张量所有张量平方和
t2.backward()
t1.grad   #梯度累计tensor([[ 0.4285, -2.4562, -4.0147],[ 1.2937, -5.8984, -6.8366],[-4.2539,  5.4720,  5.9282]])t1.grad.zero_() #梯度清零tensor([[0., 0., 0.],[0., 0., 0.],[0., 0., 0.]])

假如一个函数是f(x)=x2f(x)=x^2f(x)=x2则它的导数是f′(x)=2xf'(x)=2xf′(x)=2x,在代码中可以看到求导之后的梯度变化。

总结

今天只要掌握如何构建(套用)模块类,和使用backward()函数即可。明天将学习Pytorch的损失函数和优化器,此后的学习就越来越贴近于实战了。

从零开始学Pytorch(第5天)相关推荐

  1. 从零开始学Pytorch(零)之安装Pytorch

    本文首发于公众号"计算机视觉cv" Pytorch优势   聊聊为什么使用Pytorch,个人觉得Pytorch比Tensorflow对新手更为友善,而且现在Pytorch在学术界 ...

  2. mpandroidchart y轴从0开始_从零开始学Pytorch(十七)之目标检测基础

    目标检测和边界框 %matplotlib inline from PIL import Imageimport sys sys.path.append('/home/input/') #数据集路径 i ...

  3. 从零开始学Pytorch(十七)之目标检测基础

    目标检测和边界框 %matplotlib inline from PIL import Imageimport sys sys.path.append('/home/input/') #数据集路径 i ...

  4. 从零开始学Pytorch(五)之欠拟合和过拟合

    本文首发于微信公众号"计算机视觉cv" 模型选择.过拟合和欠拟合 训练误差和泛化误差 训练误差(training error)指模型在训练数据集上表现出的误差,泛化误差(gener ...

  5. 建议收藏!从零开始学PyTorch

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自:机器学习算法那些事 PyTorch 是一个深度学习框架, ...

  6. 从零开始学Pytorch(十六)之模型微调

    微调 在前面的一些章节中,我们介绍了如何在只有6万张图像的Fashion-MNIST训练数据集上训练模型.我们还描述了学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像 ...

  7. 从零开始学Pytorch(十)之循环神经网络基础

    本节介绍循环神经网络,下图展示了如何基于循环神经网络实现语言模型.我们的目的是基于当前的输入与过去的输入序列,预测序列的下一个字符.循环神经网络引入一个隐藏变量HHH,用HtH_{t}Ht​表示HHH ...

  8. 从零开始学Pytorch(四)之softmax与分类模型

    softmax的基本概念 分类问题 一个简单的图像分类问题,输入图像的高和宽均为2像素,色彩为灰度. 图像中的4像素分别记为x1,x2,x3,x4x_1, x_2, x_3, x_4x1​,x2​,x ...

  9. 从零开始学Pytorch(三)之多层感知机的实现

    多层感知机的基本知识 我们将以多层感知机(multilayer perceptron,MLP)为例,介绍多层神经网络的概念. 隐藏层 下图展示了一个多层感知机的神经网络图,它含有一个隐藏层,该层中有5 ...

最新文章

  1. bzoj3270 博物馆
  2. 【内网安全】域横向网络传输应用层隧道技术
  3. 手机投屏到电视的5种方法_安卓手机、苹果手机投屏到电视史上最全的方法
  4. CodeForces - 1208F Bits And Pieces(SOSdp+贪心)
  5. 浪潮K1 Power通过ISO/IEC 20243标准认证
  6. kafka自带的zk启动_Centos上将zookeeper和kafka设置为开机自启
  7. 政企上云网络适配复杂,看华为云Stack有妙招
  8. pycharm的background task一直更新index,速度慢的解决方法
  9. TypeScript 中的 SOLID 原则
  10. linux2.6.34编译安装,busybox linux-2.6.2 编译安装中碰到的若干问题
  11. python简明教程_06
  12. 就国内而言,读大学的意义是什么?
  13. PHP简单的学生管理系统的代码
  14. MeterSphere一站式开源持续测试平台
  15. RK3588平台开发系列讲解(USB篇)UAC初识
  16. Chrome插件实现GitHub代码离线翻译v0.0.4 2018-10-19
  17. UI-Bootstrap 模态对话框示例
  18. [RabbitMQ--1] MQ简介
  19. springcloud使用RestTemplate进行接口调用
  20. fast unfolding 算法——论文总结

热门文章

  1. 计算机维修工具和仪器,421常用测量仪器和维修工具.ppt
  2. 对于app触控屏幕触发音效的延迟与杂音测试
  3. 永磁同步电机转子磁链_无轴承永磁同步电机研究现状和未来发展趋势
  4. 手机上搭建Linux服务器
  5. 【图文详细 】Scala——编程练习
  6. C程序翻译成汇编语言
  7. Solaris 10 学习笔记
  8. snmp同步端口号_SNMP端口号教程及其示例
  9. 根据AutoCAD地形图建立ANSYS和Flac3D实体模型
  10. 点云IO篇之las文件读写