Tensor

自从张量(Tensor)计算这个概念出现后,神经网络的算法就可以看作是一系列的张量计算。所谓的张量,它原本是个数学概念,表示各种向量或者数值之间的关系。PyTorch的张量(torch.Tensor)表示的是N维矩阵与一维数组的关系。

http://web.mit.edu/~ezyang/Public/pytorch-internals.pdf

torch.Tensor的使用方法和numpy很相似(https://pytorch.org/...tensor-tutorial-py),两者唯一的区别在于torch.Tensor可以使用GPU来计算,这就比用CPU的numpy要快很多。

张量计算的种类有很多,比如加法、乘法、矩阵相乘、矩阵转置等,这些计算被称为算子(Operator),它们是PyTorch的核心组件。

算子的backend一般是C/C++的拓展程序,PyTorch的backend是称为"ATen"的C/C++库,ATen是"A Tensor"的缩写。

Operator

PyTorch所有的Operator都定义在Declarations.cwrap和native_functions.yaml这两个文件中,前者定义了从Torch那继承来的legacy operator(aten/src/TH),后者定义的是native operator,是PyTorch的operator。

相比于用C++开发的native code,legacy code是在PyTorch编译时由gen.py根据Declarations.cwrap的内容动态生成的。因此,如果你想要trace这些code,需要先编译PyTorch。

legacy code的开发要比native code复杂得多。如果可以的话,建议你尽量避开它们。

aten/src/ATen/Declarations.cwrap

MatMul

本文会以矩阵相乘--torch.matmul()为例来分析PyTorch算子的工作流程。

我在深入浅出全连接层(fully connected layer)中有讲在GPU层面是如何进行矩阵相乘的。Nvidia、AMD等公司提供了优化好的线性代数计算库--cuBLAS/rocBLAS/openBLAS,PyTorch只需要调用它们的API即可。

Figure 1: function flow of torch.matmul()

Figure 1是torch.matmul()在ATen中的function flow。可以看到,这个flow可不短,这主要是因为不同类型的tensor(2d or Nd, batched gemm or not,with or without bias,cuda or cpu)的操作也不尽相同。

at::matmul()主要负责将Tensor转换成cuBLAS需要的格式。前面说过,Tensor可以是N维矩阵,如果tensor A是3d矩阵,tensor B是2d矩阵,就需要先将3d转成2d;如果它们都是>=3d的矩阵,就要考虑batched matmul的情况;如果bias=True,后续就应该交给at::addmm()来处理;总之,matmul要考虑的事情比想象中要多。

除此之外,不同的dtype、device和layout需要调用不同的操作函数,这部分工作交由c10::dispatcher来完成。

Dispatcher

dispatcher主要用于动态调用dtype、device以及layout等方法函数。用过numpy的都知道,np.array()的数据类型有:float32, float16,int8,int32,.... 如果你了解C++就会知道,这类程序最适合用模板(template)来实现。

很遗憾,由于ATen有一部分operator是用C语言写的(从Torch继承过来),不支持模板功能,因此,就需要dispatcher这样的动态调度器。

类似地,PyTorch的tensor不仅可以运行在GPU上,还可以跑在CPU、mkldnn和xla等设备,Figure 1中的dispatcher4就根据tensor的device调用了mm的GPU实现。

layout是指tensor中元素的排布。一般来说,矩阵的排布都是紧凑型的,也就是strided layout。而那些有着大量0的稀疏矩阵,相应地就是sparse layout。

Figure 2: strided layout example

Figure 2是strided layout的演示实例,这里创建了一个2行2列的矩阵a,它的数据实际存放在一维数组(a.storage)里,2行2列只是这个数组的视图。

stride充当了从数组到视图的桥梁,比如,要打印第2行第2列的元素时,可以通过公式:来计算该元素在数组中的索引。

除了dtype、device、layout之外,dispatcher还可以用来调用legacy operator。比如说addmm这个operator,它的GPU实现就是通过dispatcher来跳转到legacy::cuda::_th_addmm。

aten/src/ATen/native/native_functions.yaml

END

到此,就完成了对PyTorch算子的学习。如果你要学习其他算子,可以先从aten/src/ATen/native目录的相关函数入手,从native_functions.yaml中找到dispatch目标函数,详情可以参考Figure 1。

pytorch 矩阵相乘_深入浅出PyTorch(算子篇)相关推荐

  1. pytorch 矩阵相乘_编译PyTorch静态库

    背景 众所周知,PyTorch项目作为一个C++工程,是基于CMake进行构建的.然而当你想基于CMake来构建PyTorch静态库时,你会发现: 静态编译相关的文档不全: CMake文件bug太多, ...

  2. pytorch 矩阵相乘_深度学习 — — PyTorch入门(三)

    点击关注我哦 autograd和动态计算图可以说是pytorch中非常核心的部分,我们在之前的文章中提到:autograd其实就是反向求偏导的过程,而在求偏导的过程中,链式求导法则和雅克比矩阵是其实现 ...

  3. Pytorch 矩阵相乘

    torch.bmm() torch.matmul() torch.bmm()强制规定维度和大小相同 torch.matmul()没有强制规定维度和大小,可以用利用广播机制进行不同维度的相乘操作 当进行 ...

  4. pytorch卷积可视化_使用Pytorch可视化卷积神经网络

    pytorch卷积可视化 Filter and Feature map Image by the author 筛选和特征图作者提供的图像 When dealing with image's and ...

  5. matlab非同秩矩阵相乘_线性代数精华——讲透矩阵的初等变换与矩阵的秩

    这篇文章和大家聊聊矩阵的初等变换和矩阵的秩. 矩阵的初等变换这个概念可能在很多人听来有些陌生,但其实我们早在初中的解多元方程组的时候就用过它.只不过在课本当中,这种方法叫做消元法.我们先来看一个课本里 ...

  6. numpy pytorch 接口对应_拆书分享篇深度学习框架PyTorch入门与实践

    <<深度学习框架PyTorch入门与实践>>读书笔记 <深度学习框架PyTorch入门与实践>读后感 小作者:马苗苗  读完<<深度学习框架PyTorc ...

  7. julia有 pytorch包吗_用 PyTorch 实现基于字符的循环神经网络 | Linux 中国

    导读:在过去的几周里,我花了很多时间用 PyTorch 实现了一个 char-rnn 的版本.我以前从未训练过神经网络,所以这可能是一个有趣的开始. 本文字数:7201,阅读时长大约: 9分钟 htt ...

  8. bert pytorch源码_【PyTorch】梯度爆炸、loss在反向传播变为nan

    点击上方"MLNLP",选择"星标"公众号 重磅干货,第一时间送达 作者丨CV路上一名研究僧 知乎专栏丨深度图像与视频增强 地址丨https://zhuanla ...

  9. pytorch保存准确率_初学Pytorch:MNIST数据集训练详解

    前言 本文讲述了如何使用Pytorch(一种深度学习框架)构建一个简单的卷积神经网络,并使用MNIST数据集(28*28手写数字图片集)进行训练和测试.针对过程中的每个步骤都尽可能的给出了详尽的解释. ...

最新文章

  1. saltstack state模块-状态管理
  2. 数字图像缩放之最近邻插值与双线性插值处理效果对比
  3. sql server 根据身份证号计算出生日期和年龄的存储过程
  4. cfree运行程序错误_C/C++程序调试和内存检测
  5. 多媒体计算机接口卡,多媒体技术基础 2.2多媒体接口卡 多媒体接口卡.docx
  6. Android教程 第四章 用户界面设计基础
  7. html在线播放mp4文件,使用HTML5视频在Firefox中播放MP4文件
  8. 微型计算机原理及应用考试重点,微型计算机原理及应用考试重点.doc
  9. Tableau学习教程(万字保姆级教程)​​​​​​
  10. .Net代码检查工具 Gendarme
  11. oracle中的varchar2存储中文,varchar2存储汉字
  12. 基带传输编码方式HDB3码的快速编码步骤、原理及举例
  13. BCNF无损分解例题
  14. 【mysql数据类型】uint和int的区别
  15. 2021.2冬入京都大学修士考试复习经验贴
  16. 【译】Rust 中的错误处理
  17. FPGA series # 基于SDx的fft函数加速
  18. 家族查询系统c语言源程序,家谱管理系统(含源代码).docx
  19. BERT原理和结构详解
  20. 工作方法论: 请别跟我说“帮我解决一个问题”

热门文章

  1. 结构体在多线程中用法
  2. Django不能ip调试访问
  3. pycharm2017设置注释字体颜色
  4. oracle忘记口令
  5. 什么是视频会议?什么是H.323?SIP是什么协议?
  6. python 类和函数的区别
  7. 弹性方法计算内力例题_弹性力学重要公式汇总,还不快来强记一波【含参考答案】...
  8. 复杂个人信息输出程序python_Python高级技巧:用一行代码减少一半内存占用
  9. php图片如何让浮动,页面中用css属性怎么控制图片自定义浮动?(示例)
  10. mysql8.0.22 win7_现在还能不能下载到正版WIN 7