点击上方“算法猿的成长“,关注公众号,选择加“星标“或“置顶”

总第 132 篇文章,本文大约 7000 字,阅读大约需要 20 分钟

原文:https://github.com/vahidk/EffectivePyTorch

作者:vahidk

前言

这是一份 PyTorch 教程和最佳实践笔记,目录如下所示:

  1. PyTorch 基础

  2. 将模型封装为模块

  3. 广播机制的优缺点

  4. 使用好重载的运算符

  5. 采用 TorchScript 优化运行时间

  6. 构建高效的自定义数据加载类

  7. PyTorch 的数值稳定性

因为原文太长所以分为上下两篇文章进行介绍,本文介绍前四点,从基础开始介绍到使用重载的运算符。

首先 PyTorch 的安装可以根据官方文档进行操作:

https://pytorch.org/

pip install torch torchvision

1. PyTorch 基础

PyTorch 是数值计算方面其中一个最流行的库,同时也是机器学习研究方面最广泛使用的框架。在很多方面,它和 NumPy 都非常相似,但是它可以在不需要代码做多大改变的情况下,在 CPUs,GPUs,TPUs 上实现计算,以及非常容易实现分布式计算的操作。PyTorch 的其中一个最重要的特征就是自动微分。它可以让需要采用梯度下降算法进行训练的机器学习算法的实现更加方便,可以更高效的自动计算函数的梯度。我们的目标是提供更好的 PyTorch 介绍以及讨论使用 PyTorch 的一些最佳实践。

对于 PyTorch 第一个需要学习的就是张量(Tensors)的概念,张量就是多维数组,它和 numpy 的数组非常相似,但多了一些函数功能。

一个张量可以存储一个标量数值、一个数组、一个矩阵:

import torch
# 标量数值
a = torch.tensor(3)
print(a)  # tensor(3)
# 数组
b = torch.tensor([1, 2])
print(b)  # tensor([1, 2])
# 矩阵
c = torch.zeros([2, 2])
print(c)  # tensor([[0., 0.], [0., 0.]])
# 任意维度的张量
d = torch.rand([2, 2, 2])

张量还可以高效的执行代数的运算。机器学习应用中最常见的运算就是矩阵乘法。例如希望将两个随机矩阵进行相乘,维度分别是 和 ,这个运算可以通过矩阵相乘运算实现(@):

import torchx = torch.randn([3, 5])
y = torch.randn([5, 4])
z = x @ yprint(z)

对于向量相加,如下所示:

z = x + y

将张量转换为 numpy 数组,可以调用 numpy() 方法:

print(z.numpy())

当然,反过来 numpy 数组转换为张量是可以的:

x = torch.tensor(np.random.normal([3, 5]))

自动微分

PyTorch 中相比 numpy  最大优点就是可以实现自动微分,这对于优化神经网络参数的应用非常有帮助。下面通过一个例子来帮助理解这个优点。

假设现在有一个复合函数:g(u(x)) ,为了计算 gx 的导数,这里可以采用链式法则,即

而 PyTorch 可以自动实现这个求导的过程。

为了在 PyTorch 中计算导数,首先要创建一个张量,并设置其 requires_grad = True ,然后利用张量运算来定义函数,这里假设 u 是一个二次方的函数,而 g 是一个简单的线性函数,代码如下所示:

x = torch.tensor(1.0, requires_grad=True)def u(x):return x * xdef g(u):return -u

在这个例子中,复合函数就是 ,所以导数是 ,如果 x=1 ,那么可以得到 -2

在 PyTorch 中调用梯度函数:

dgdx = torch.autograd.grad(g(u(x)), x)[0]
print(dgdx)  # tensor(-2.)

拟合曲线

为了展示自动微分有多么强大,这里介绍另一个例子。

首先假设我们有一些服从一个曲线(也就是函数 )的样本,然后希望基于这些样本来评估这个函数 f(x) 。我们先定义一个带参数的函数:

函数的输入是 x,然后 w 是参数,目标是找到合适的参数使得下列式子成立:

实现的一个方法可以是通过优化下面的损失函数来实现:

尽管这个问题里有一个正式的函数(即 f(x) 是一个具体的函数),但这里我们还是采用一个更加通用的方法,可以应用到任何一个可微分的函数,并采用随机梯度下降法,即通过计算 L(w) 对于每个参数 w 的梯度的平均值,然后不断从相反反向移动。

利用 PyTorch 实现的代码如下所示:

import numpy as np
import torch# Assuming we know that the desired function is a polynomial of 2nd degree, we
# allocate a vector of size 3 to hold the coefficients and initialize it with
# random noise.
w = torch.tensor(torch.randn([3, 1]), requires_grad=True)# We use the Adam optimizer with learning rate set to 0.1 to minimize the loss.
opt = torch.optim.Adam([w], 0.1)def model(x):# We define yhat to be our estimate of y.f = torch.stack([x * x, x, torch.ones_like(x)], 1)yhat = torch.squeeze(f @ w, 1)return yhatdef compute_loss(y, yhat):# The loss is defined to be the mean squared error distance between our# estimate of y and its true value. loss = torch.nn.functional.mse_loss(yhat, y)return lossdef generate_data():# Generate some training data based on the true functionx = torch.rand(100) * 20 - 10y = 5 * x * x + 3return x, ydef train_step():x, y = generate_data()yhat = model(x)loss = compute_loss(y, yhat)opt.zero_grad()loss.backward()opt.step()for _ in range(1000):train_step()print(w.detach().numpy())

运行上述代码,可以得到和下面相近的结果:

[4.9924135, 0.00040895029, 3.4504161]

这和我们的参数非常接近。

上述只是 PyTorch 可以做的事情的冰山一角。很多问题,比如优化一个带有上百万参数的神经网络,都可以用 PyTorch 高效的用几行代码实现,PyTorch 可以跨多个设备和线程进行拓展,并且支持多个平台。


2. 将模型封装为模块

在之前的例子中,我们构建模型的方式是直接实现张量间的运算操作。但为了让代码看起来更加有组织,推荐采用 PyTorch 的 modules 模块。一个模块实际上是一个包含参数和压缩模型运算的容器。

比如,如果想实现一个线性模型 ,那么实现的代码可以如下所示:

import torchclass Net(torch.nn.Module):def __init__(self):super().__init__()self.a = torch.nn.Parameter(torch.rand(1))self.b = torch.nn.Parameter(torch.rand(1))def forward(self, x):yhat = self.a * x + self.breturn yhat

使用的例子如下所示,需要实例化声明的模型,并且像调用函数一样使用它:

x = torch.arange(100, dtype=torch.float32)net = Net()
y = net(x)

参数都是设置 requires_gradtrue 的张量。通过模型的 parameters() 方法可以很方便的访问和使用参数,如下所示:

for p in net.parameters():print(p)

现在,假设是一个未知的函数 y=5x+3+n ,注意这里的 n 是表示噪音,然后希望优化模型参数来拟合这个函数,首先可以简单从这个函数进行采样,得到一些样本数据:

x = torch.arange(100, dtype=torch.float32) / 100
y = 5 * x + 3 + torch.rand(100) * 0.3

和上一个例子类似,需要定义一个损失函数并优化模型的参数,如下所示:

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)for i in range(10000):net.zero_grad()yhat = net(x)loss = criterion(yhat, y)loss.backward()optimizer.step()print(net.a, net.b) # Should be close to 5 and 3

在 PyTorch 中已经实现了很多预定义好的模块。比如 torch.nn.Linear 就是一个类似上述例子中定义的一个更加通用的线性函数,所以我们可以采用这个函数来重写我们的模型代码,如下所示:

class Net(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(1, 1)def forward(self, x):yhat = self.linear(x.unsqueeze(1)).squeeze(1)return yhat

这里用到了两个函数,squeezeunsqueeze ,主要是torch.nn.Linear 会对一批向量而不是数值进行操作。

同样,默认调用 parameters() 会返回其所有子模块的参数:

net = Net()
for p in net.parameters():print(p)

当然也有一些预定义的模块是作为包容其他模块的容器,最常用的就是 torch.nn.Sequential ,它的名字就暗示了它主要用于堆叠多个模块(或者网络层),例如堆叠两个线性网络层,中间是一个非线性函数 ReLU ,如下所示:

model = torch.nn.Sequential(torch.nn.Linear(64, 32),torch.nn.ReLU(),torch.nn.Linear(32, 10),
)

3. 广播机制的优缺点

优点

PyTorch 支持广播的元素积运算。正常情况下,当想执行类似加法和乘法操作的时候,你需要确认操作数的形状是匹配的,比如无法进行一个 [3, 2] 大小的张量和 [3, 4] 大小的张量的加法操作。

但是存在一种特殊的情况:只有单一维度的时候,PyTorch 会隐式的根据另一个操作数的维度来拓展只有单一维度的操作数张量。因此,实现 [3,2] 大小的张量和 [3,1] 大小的张量相加的操作是合法的。

如下代码展示了一个加法的例子:

import torcha = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[1.], [2.]])
# c = a + b.repeat([1, 2])
c = a + bprint(c)

广播机制可以实现隐式的维度复制操作(repeat 操作),并且代码更短,内存使用上也更加高效,因为不需要存储复制的数据的结果。这个机制非常适合用于结合多个维度不同的特征的时候。

为了拼接不同维度的特征,通常的做法是先对输入张量进行维度上的复制,然后拼接后使用非线性激活函数。整个过程的代码实现如下所示:

a = torch.rand([5, 3, 5])
b = torch.rand([5, 1, 6])linear = torch.nn.Linear(11, 10)# concat a and b and apply nonlinearity
tiled_b = b.repeat([1, 3, 1]) # b shape:  [5, 3, 6]
c = torch.cat([a, tiled_b], 2) # c shape: [5, 3, 11]
d = torch.nn.functional.relu(linear(c))print(d.shape)  # torch.Size([5, 3, 10])

但实际上通过广播机制可以实现得更加高效,即 f(m(x+y)) 是等同于 f(mx+my) 的,也就是我们可以先分别做线性操作,然后通过广播机制来做隐式的拼接操作,如下所示:

a = torch.rand([5, 3, 5])
b = torch.rand([5, 1, 6])linear1 = torch.nn.Linear(5, 10)
linear2 = torch.nn.Linear(6, 10)pa = linear1(a) # pa shape: [5, 3, 10]
pb = linear2(b) # pb shape: [5, 1, 10]
d = torch.nn.functional.relu(pa + pb)print(d.shape)  # torch.Size([5, 3, 10])

实际上这段代码非常通用,可以用于任意维度大小的张量,只要它们之间是可以实现广播机制的,如下所示:

class Merge(torch.nn.Module):def __init__(self, in_features1, in_features2, out_features, activation=None):super().__init__()self.linear1 = torch.nn.Linear(in_features1, out_features)self.linear2 = torch.nn.Linear(in_features2, out_features)self.activation = activationdef forward(self, a, b):pa = self.linear1(a)pb = self.linear2(b)c = pa + pbif self.activation is not None:c = self.activation(c)return c

缺点

到目前为止,我们讨论的都是广播机制的优点。但它的缺点是什么呢?原因也是出现在隐式的操作,这种做法非常不利于进行代码的调试。

这里给出一个代码例子:

a = torch.tensor([[1.], [2.]])
b = torch.tensor([1., 2.])
c = torch.sum(a + b)print(c)

所以上述代码的输出结果 c 是什么呢?你可能觉得是 6,但这是错的,正确答案是 12 。这是因为当两个张量的维度不匹配的时候,PyTorch 会自动将维度低的张量的第一个维度进行拓展,然后在进行元素之间的运算,所以这里会将b  先拓展为 [[1, 2], [1, 2]],然后 a+b 的结果应该是 [[2,3], [3, 4]] ,然后sum 操作是将所有元素求和得到结果 12。

那么避免这种结果的方法就是显式的操作,比如在这个例子中就需要指定好想要求和的维度,这样进行代码调试会更简单,代码修改后如下所示:

a = torch.tensor([[1.], [2.]])
b = torch.tensor([1., 2.])
c = torch.sum(a + b, 0)print(c)

这里得到的 c 的结果是 [5, 7],而我们基于结果的维度可以知道出现了错误。

这有个通用的做法,就是在做累加( reduction )操作或者使用 torch.squeeze 的时候总是指定好维度。


4. 使用好重载的运算符

和 NumPy 一样,PyTorch 会重载 python 的一些运算符来让 PyTorch 代码更简短和更有可读性。

例如,切片操作就是其中一个重载的运算符,可以更容易的对张量进行索引操作,如下所示:

z = x[begin:end]  # z = torch.narrow(0, begin, end-begin)

但需要谨慎使用这个运算符,它和其他运算符一样,也有一些副作用。正因为它是一个非常常用的运算操作,如果过度使用可以导致代码变得低效。

这里给出一个例子来展示它是如何导致代码变得低效的。这个例子中我们希望对一个矩阵手动实现行之间的累加操作:

import torch
import timex = torch.rand([500, 10])z = torch.zeros([10])start = time.time()
for i in range(500):z += x[i]
print("Took %f seconds." % (time.time() - start))

上述代码的运行速度会非常慢,因为总共调用了 500 次的切片操作,这就是过度使用了。一个更好的做法是采用 torch.unbind 运算符在每次循环中将矩阵切片为一个向量的列表,如下所示:

z = torch.zeros([10])
for x_i in torch.unbind(x):z += x_i

这个改进会提高一些速度(在作者的机器上是提高了大约30%)。

但正确的做法应该是采用 torch.sum 来一步实现累加的操作:

z = torch.sum(x, dim=0)

这种实现速度就非常的快(在作者的机器上提高了100%的速度)。

其他重载的算数和逻辑运算符分别是:

z = -x  # z = torch.neg(x)
z = x + y  # z = torch.add(x, y)
z = x - y
z = x * y  # z = torch.mul(x, y)
z = x / y  # z = torch.div(x, y)
z = x // y
z = x % y
z = x ** y  # z = torch.pow(x, y)
z = x @ y  # z = torch.matmul(x, y)
z = x > y
z = x >= y
z = x < y
z = x <= y
z = abs(x)  # z = torch.abs(x)
z = x & y
z = x | y
z = x ^ y  # z = torch.logical_xor(x, y)
z = ~x  # z = torch.logical_not(x)
z = x == y  # z = torch.eq(x, y)
z = x != y  # z = torch.ne(x, y)

还可以使用这些运算符的递增版本,比如 x += yx **=2 都是合法的。

另外,Python 并不允许重载 andornot 三个关键词。


精选AI文章

1. 10个实用的机器学习建议

2. 深度学习算法简要综述(上)

3. 深度学习算法简要综述(上)

4. 常见的数据增强项目和论文介绍

5. 实战|手把手教你训练一个基于Keras的多标签图像分类器

精选python文章

1.  python数据模型

2. python版代码整洁之道

3. 快速入门 Jupyter notebook

4. Jupyter 进阶教程

5. 10个高效的pandas技巧

精选教程资源文章

1. [资源分享] TensorFlow 官方中文版教程来了

2. [资源]推荐一些Python书籍和教程,入门和进阶的都有!

3. [Github项目推荐] 推荐三个助你更好利用Github的工具

4. Github上的各大高校资料以及国外公开课视频

5. GitHub上有哪些比较好的计算机视觉/机器视觉的项目?

欢迎关注我的微信公众号--算法猿的成长,或者扫描下方的二维码,大家一起交流,学习和进步!

如果觉得不错,在看、转发就是对小编的一个支持!

编写高效的PyTorch代码技巧(上)相关推荐

  1. 编写高效的PyTorch代码技巧(下)

    点击上方"算法猿的成长",关注公众号,选择加"星标"或"置顶" 总第 133 篇文章,本文大约 3000 字,阅读大约需要 15 分钟 原文 ...

  2. 编写高效的jQuery代码

    编写高效的jQuery代码 最近写了很多的js,虽然效果都实现了,但是总感觉自己写的js在性能上还能有很大的提升.本文我计划总结一些网上找的和我本人的一些建议,来提升你的jQuery和javascri ...

  3. 编写高效的Android代码

    编写高效的Android代码 转自:http://www.chinaup.org/docs/toolbox/performance.html 介绍 对于如何判断一个系统的不合理,这里有两个基本的原则: ...

  4. 编写并运行php程序,上传所编写的PHP程序代码,并上传运行后的效果截图

    上传所编写的PHP程序代码,并上传运行后的效果截图 更多相关问题 [多选] 对税务机关的下列行政行为,纳税人可以申请行政复议的有(). [多选] 纳税人收到税务机关的行政处罚决定书之后,在法定期限内可 ...

  5. 编写同时在PyTorch和Tensorflow上工作的代码

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 ❝ "库开发人员不再需要在框架之间进行选择." ...

  6. 【深度学习】编写同时在PyTorch和Tensorflow上工作的代码

    作者 | Ram Sagar 编译 | VK 来源 | Analytics In Diamag ❝ "库开发人员不再需要在框架之间进行选择." ❞ 来自德国图宾根人工智能中心的研究 ...

  7. 编写高效的java代码

    供稿人:肖华飚 概述 随着Java的广泛应用,越来越多的关键企业系统也使用Java构建.作为Java核心运行环境的Java虚拟机JVM被广泛地部署在各种系统平台上.对Java应用的性能优化也越来越受到 ...

  8. 编写高效Excel VBA代码的最佳实践(一)

    很多Excel VBA文章和图书都介绍过如何优化VBA代码,使代码运行得更快.下面搜集了一些使Excel VBA代码运行更快的技术和技巧,基本上都是实践经验的总结.如果您还有其它优化Excel VBA ...

  9. 转:编写高效的Android代码

    毫无疑问,基于Android平台的设备一定是嵌入式设备.现代的手持设备不仅仅是一部电话那么简单,它还是一个小型的手持电脑,但是,即使是最快的最高端的手持设备也远远比不上一个中等性能的桌面机. 这就是为 ...

最新文章

  1. vmrun 批量创建vmware虚拟机
  2. C++知识点22——使用C++标准库(顺序容器list的初始化、赋值、访问、交换、添加、删除与迭代器失效)
  3. buuctf [GKCTF 2021]你知道apng吗 <apng图片格式的考察>
  4. Object 标签遮挡 Div 显示
  5. matlab中solver函数_Simulink求解器(Solver)相关知识
  6. jdk说明文档_给JDK报了一个P4的Bug,结果居然……
  7. IT的2017,面临数字生态系统新挑战,该怎么办?
  8. zoom怎么解除静音_如何召开一场Zoom视频会议
  9. CakePHP:链接地址问题(不用mod_rewrite,IIS)
  10. C++ 引用的几个用法
  11. 计算机 管理 被停用,如果电脑上出现“你的账户已被停用请向系统管理员咨询”怎么办?...
  12. 2020-08-17 java实战项目汇总
  13. 【真题21套】计算机二级公共基础知识选择题真题【含解析】
  14. 映美精双目相机无法同时显示的问题
  15. 漏洞挖掘 符号执行_漏洞挖掘综述
  16. 撒金币动画android,Anime Gacha
  17. 国产系统中标麒麟安装教程
  18. 【精华帖】PS拼接图片最简单教程
  19. SIM卡在手机中的主要作用
  20. Typora 是什么?

热门文章

  1. sae mysql 同步本地_MYSQL入门之三_将本地MySQL数据导入SAE数据库_MySQL
  2. js 数组遍历符合条件跳出循环体_C++模拟面试:从数组“紧凑”操作说开来
  3. java 面试题 由浅入深_面试官由浅入深的面试套路
  4. 王道操作系统考研笔记——1.1.6 系统调用
  5. Proteus仿真单片机:PIC18单片机的仿真
  6. js rem 单位适配(手机、平板、PC)?
  7. 动画 自制弹框上滑+渐显效果
  8. windows等宽字体
  9. win7 下的 cmdhere 及其他
  10. 对vector中的数据排序