第1章 PyTorch和神经网络

1.1 PyTorch入门

1.1.2 PyTorch张量

1.1.3 PyTorch的自动求导机制

1.1.4 计算图

自动梯度计算看似很神奇,但它其实并不是魔术。
它背后的原理很值得深入了解,这些知识将帮助我们构建更大规模的网络。
看看下面这个非常简单的网络。它甚至不算一个神经网络,而只是一系列计算。

在上图中,我们看到输入x被用于计算y,y再被用于计算输出z。
假设y和z的计算过程如下:

如果我们希望知道输出zzz如何随xxx变化,我们需要知道梯度 dy/dxdy/dxdy/dx。下面我们来逐步计算。

第一行是微积分的链式法则(chain rule),对我们非常重要。
我们刚刚算出,zzz随xxx的变化可表示为4x4x4x。如果x=3.5x = 3.5x=3.5,则dz/dx=4×3.5=14dz/dx = 4 × 3.5 = 14dz/dx=4×3.5=14。

当yyy以xxx的形式定义,而zzz以yyy的形式定义时,PyTorch便将这些张量连成一幅图,以展示这些张量是如何连接的。这幅图叫计算图(computation graph)。

在我们的例子中,计算图看起来可能是下面这样的:

我们可以看到yyy是如何从xxx计算得到的,zzz是如何从yyy计算得到的。此外,PyTorch还增加了几个反向箭头,表示yyy如何随着xxx变化,zzz如何随着yyy变化。这些就是梯度,在训练过程中用来更新神经网络。微积分的过程由PyTorch完成,无须我们自己动手计算。

为了计算出zzz如何随着xxx变化,我们合并从zzz经由yyy回到xxx的路径中的所有梯度。这便是微积分的链式法则。

PyTorch先构建一个只有正向连接的计算图。我们需要通过backward()函数,使PyTorch计算出反向的梯度。

梯度dz/dx在张量x中被存储为x.grad。

值得注意的是,张量xxx内部的梯度值与z的变化有关。这是因为我们要求PyTorch使用z.backward ()从zzz反向计算。因此,x.gradx.gradx.grad是dz/dxdz/dxdz/dx,而不是dy/dxdy/dxdy/dx。

大多数有效的神经网络包含多个节点,每个节点有多个连进该节点的链接,以及从该节点出发的链接。让我们来看一个简单的例子,例子中的节点有多个进入的链接。

可见,输入aaa和bbb同时对xxx和yyy有影响,而输出zzz是由xxx和yyy计算出来的。
这些节点之间的关系如下。

我们按同样的方法计算梯度。

接着,把这些信息添加到计算图中。

现在,我们可以轻易地通过z到a的路径计算出梯度dz/da。实际上,从z到a有两条路径,一条通过x,另一条通过y,我们只需要把两条路径的表达式相加即可。这么做是合理的,因为从a到z的两条路径都影响了z的值,这也与我们用微积分的链式法则计算出的dz/da的结果一致。

dz/da=dz/dx+dx/da+dz/dy+dy/dadz/da = dz/dx+dx/da +dz/dy+dy/dadz/da=dz/dx+dx/da+dz/dy+dy/da

第一条路径经过xxx,表示为2×22 × 22×2;第二条路径经过yyy,表示为3×10a3×10a3×10a。所以,zzz随aaa变化的速率是4+30a4 + 30a4+30a。
如果aaa是2,则dz/dadz/dadz/da是4 + 30 × 2 = 64。

我们来检验一下用PyTorch是否也能得出这个值。首先,我们定义PyTorch构建计算图所需要的关系。


接着,我们触发梯度计算并查询张量aaa里面的值。


有效的神经网络通常比这个小型网络规模大得多。但是PyTorch构建计算图的方式以及沿着路径向后计算梯度的过程是一样的。

1.1.5 学习要点

  • Colab服务允许我们在谷歌的服务器上运行Python代码。Colab使用Python笔记本,我们只需要一个Web浏览器即可使用。
  • PyTorch是一个领先的Python机器学习架构。它与numpy类似,允许我们使用数字数组。同时,它也提供了丰富的工具集和函数,使机器学习更容易上手。
  • 在PyTorch中,数据的基本单位是张量(tensor)。张量可以是多维数组、简单的二维矩阵、一维列表,也可以是单值。
  • PyTorch的主要特性是能够自动计算函数的梯度(gradient)。梯度的计算是训练神经网络的关键。为此,PyTorch需要构建一张计算图(computationgraph),图中包含多个张量以及它们之间的关系。在代码中,该过程在我们以一个张量定义另一个张量时自动完成。

第1章 PyTorch和神经网:1.1 PyTorch和神经网络相关推荐

  1. pytorch贝叶斯网络_贝叶斯神经网络:2个在TensorFlow和Pytorch中完全连接

    pytorch贝叶斯网络 贝叶斯神经网络 (Bayesian Neural Net) This chapter continues the series on Bayesian deep learni ...

  2. PyTorch学习笔记(19) ——NIPS2019 PyTorch: An Imperative Style, High-Performance Deep Learning Library

    0. 前言 波兰小哥Adam Paszke从15年的Torch开始,到现在发表了关于PyTorch的Neurips2019论文(令我惊讶的是只中了Poster?而不是Spotlight?).中间经历了 ...

  3. Pytorch:入门指南和 PyTorch 的 GPU版本安装(非常详细)

    Pytorch: 入门指南和 PyTorch 的 GPU版本安装(非常详细) Copyright: Jingmin Wei, Pattern Recognition and Intelligent S ...

  4. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

  5. 基于PyTorch,如何构建一个简单的神经网络

    本文为 PyTorch 官方教程中:如何构建神经网络.基于 PyTorch 专门构建神经网络的子模块 torch.nn 构建一个简单的神经网络. 完整教程运行 codelab→ https://ope ...

  6. PyTorch | (1)初识PyTorch

    PyTorch | (1)初识PyTorch 介绍 PyTorch是一个非常有可能改变深度学习领域前景的Python库.我尝试使用了几星期PyTorch,然后被它的易用性所震惊,在我使用过的各种深度学 ...

  7. PyTorch or TensorFlow?强力推荐PyTorch不是没有理由的!一文学透pytorch!

    在机器学习领域,面对各类复杂多变的业务问题,构建灵活易调整的模型是高阶机器学习工程师必备的工作能力.然而,许多工程师还是有一个想法上的误区,以为只要掌握了一种深度学习的框架就能走遍天下了. 事实上,在 ...

  8. Pytorch:深度学习中pytorch/torchvision版本和CUDA版本最正确版本匹配、对应版本安装之详细攻略

    Pytorch:深度学习中pytorch/torchvision版本和CUDA版本最正确版本匹配.对应版本安装之详细攻略 目录 深度学习中pytorch/torchvision版本和CUDA版本最正确 ...

  9. Pytorch之CNN:基于Pytorch框架实现经典卷积神经网络的算法(LeNet、AlexNet、VGG、NIN、GoogleNet、ResNet)——从代码认知CNN经典架构

    Pytorch之CNN:基于Pytorch框架实现经典卷积神经网络的算法(LeNet.AlexNet.VGG.NIN.GoogleNet.ResNet)--从代码认知CNN经典架构 目录 CNN经典算 ...

  10. PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN

    PyTorch之LeNet-5:利用PyTorch实现最经典的LeNet-5卷积神经网络对手写数字图片识别CNN 目录 训练过程 代码设计 训练过程 代码设计 #PyTorch:利用PyTorch实现 ...

最新文章

  1. 使用 PHP 在站点上构建类似 Twitter 的系统
  2. 7个Debug linux程序的Strace 列子
  3. python中txt转成csv_Python-如何将JSON转换为CSV?
  4. 读《我编程,我快乐,程序员的职业规划之道》有感
  5. 网页性能优化04-函数节流
  6. 11.LNMP基础架构
  7. 如何用 SQL 做留存率分析?
  8. 详解Oracle临时表的几种用法及意义
  9. 利用Python制作微信跳一跳外挂,又来带你装一波X!
  10. vmd与ovito的对比
  11. yum install gcc报错Error: Package: glibc-2.17-260.el7_6.6.i686 (updates) Requires: glibc-common = 2.17
  12. qcnfa435_【路由知识小课堂番外篇】支持MU-MIMO技术设备一览表(2017.9.25第一版)...
  13. Servlet的原理和基础使用
  14. 如何查看虚拟机服务器ftp,如何通过FTP工具查看虚拟空间使用了多少?
  15. 详解机器学习中的梯度消失、爆炸原因及其解决方法
  16. IDC机房运维工程师需要具备哪些技能及素质
  17. Vue实现路径转二维码,并用手机扫码下载APP
  18. 云笔记有哪些好用的功能,这4款云笔记一定要试试
  19. LS1028 使用serdes mode 99BB软件修改方案
  20. java实习第二周总结

热门文章

  1. linux 检查zip是否损坏,用-v参数 unzip -v test.zip 检查zip文件是否损坏代常亮
  2. 百分比布局参照物的总结
  3. 在局域网内怎样使两台计算机共享,怎么使两台电脑共享数据?
  4. 物联网的体系结构和关键技术
  5. Rasa课程、Rasa培训、Rasa面试系列 金融银行案例Bot Step By Step学习
  6. 冲向2021 荣耀“无限”创新
  7. 【2018国赛线上比赛】知识问答题真题演练第一波
  8. 二阶魔方还原 C++ BFS
  9. 【Office】wps表格如何让后面的单元格随着下拉选项自动填充
  10. Microsoft Visual SourceSafe 2005 简体中文版