从零实现深度学习框架——实现Debug功能与no_grad
引言
本着“凡我不能创造的,我就不能理解”的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导。
要深入理解深度学习,从零开始创建的经验非常重要,从自己可以理解的角度出发,尽量不适用外部完备的框架前提下,实现我们想要的模型。本系列文章的宗旨就是通过这样的过程,让大家切实掌握深度学习底层实现,而不是仅做一个调包侠。
本系列文章首发于微信公众号:JavaNLP
有时候我们编写一个复杂的模型,想知道模型耗时的瓶颈在哪里,或者想知道模型是如何反向传播的。这时候就需要DEBUG功能,本文就来为我们的metagrad实现debug功能。
创建上下文管理器
class Config:debug = False@contextlib.contextmanager
def using_config(name, value):# 保存旧值old_value = getattr(Config, name)# 设置新值setattr(Config, name, value)try:yieldfinally:# 最终设回旧值setattr(Config, name, old_value)
首先创建一个Config
类,它有一个debug
属性,用来表示当前是否为DEBUG模式。
contextmanager
这个装饰器(decorator)接收一个生成器(generator),该generator必须只yield
一个值出来,该值会被用在with
语句中,绑定到as
后面的变量。
我们这里只需要修改Config
内部状态,不需要返回任何值,可以只加一个yield
。
创建操作包装类
class OpWrapper:'''支持反向传播的Debug'''def __init__(self, name, xs, backward=False):self.name = f"back_{name}" if backward else nameself.xs = xsself.output = Nonedef __enter__(self):if Config.debug:self.start = time.time()return selfdef __exit__(self, *junk):if Config.debug:end = (time.time() - self.start) * 1000print(f"{self.name:>20} : {end:>7.2f} ms {str([y.shape for y in self.xs]):>40} "f"{'-> ' + str(self.output.shape) if self.output is not None else ''}")
创建一个操作包装类,实现魔法方法__enter__
和__exit_
。当用在with
语句中时,会根据Config.debug
值来决定是否记录时间,以及打印DEBUG信息。
应用操作包装类
def debug_mode():return using_config("debug", True)
首先创建一个函数,修改debug
为True
。
修改Tensor#backward
方法:
with OpWrapper(t._ctx.__class__.__name__, [t.grad], backward=True):# 以逆序计算梯度,调用t相关运算操作的backward静态方法# 计算流向其依赖节点上的梯度(流向其下游)grads = t._ctx.backward(t._ctx, t.grad.data)
我们只需要将调用backward
方法的代码放进OpWrapper
的上下文中即可。
测试DEBUG
修改test_sigmoid
函数:
def test_sigmoid():x = np.array([[0, 1, 2], [0, 2, 4]], np.float32)with debug_mode():mx = Tensor(x, requires_grad=True)y = F.sigmoid(mx)tx = torch.tensor(x, requires_grad=True)ty = torch.sigmoid(tx)assert np.allclose(y.data, ty.data)y.sum().backward()ty.sum().backward()assert np.allclose(mx.grad.data, tx.grad.data)
这里演示了debug_mode
的使用,输出如下:
============================= test session starts =============================
collecting ... collected 1 itemtest_sigmoid.py::test_sigmoid PASSED [100%] back_Sum : 0.00 ms [()] back_TrueDiv : 0.00 ms [(2, 3)] back_Add : 0.00 ms [(2, 3)] back_Exp : 0.00 ms [(2, 3)] back_Neg : 0.00 ms [(2, 3)] ======================== 1 passed, 1 warning in 0.46s =========================
这里以打印出了反向传播中调用的方法、耗时以及操作的维度。
y = F.sigmoid(mx)
y.sum().backward()
首选调用了sigmoid
函数,实际上为:
σ(z)=11+exp(−z)\sigma(z) = \frac{1}{1 + \exp(-z) } σ(z)=1+exp(−z)1
然后为了方便求梯度,我们调用了sum
函数。所以反向传播时先经过Sum
的backward
方法,然后是σ\sigmaσ中的除法,再然后是1+exp(−z)1 + \exp(-z)1+exp(−z)中的加法,再是exp\expexp,最后是−z-z−z。
可以看到,整个反向传播过程都打印了出来,而且还有对应的维度,方便我们进行调试。
除此之外,我们还实现类似PyTorch中的no_grad()
方法。
实现no_grad
no_grad
的意思是,该上下文中的代码不需要计算梯度,常用于推理阶段或者在验证集上验证。
有了上面的工作,我们实现起来就非常简单:
class Config:debug = Falsebackprop = True # 是否需要计算并反向传播梯度
首先修改Config
增加一个backprop
属性,用于判断是否需要计算梯度。
def no_grad():return using_config("backprop", False)
然后增加no_grad
函数,用于修改backprop
属性。表示当前上下文不需要计算梯度。
def backward(self, grad: "Tensor" = None) -> None:'''实现Tensor的反向传播Args:grad: 如果该Tensor不是标量,则需要传递梯度进来Returns:'''# 只能在requires_grad=True的Tensor上调用此方法assert self.requires_grad, "called backward on tensor do not require grad"if not Config.backprop:return
修改backward
函数,如果Config.backprop
为False
,那么该函数直接返回。
有了此方法,我们以后就可以拆分训练、测试或验证集了。
从零实现深度学习框架——实现Debug功能与no_grad相关推荐
- python学习框架图-从零搭建深度学习框架(二)用Python实现计算图和自动微分
我们在上一篇文章<从零搭建深度学习框架(一)用NumPy实现GAN>中用Python+NumPy实现了一个简单的GAN模型,并大致设想了一下深度学习框架需要实现的主要功能.其中,不确定性最 ...
- 从零实现深度学习框架——GloVe从理论到实战
引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.
- 从零实现深度学习框架——Seq2Seq从理论到实战【实战】
引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.
- 从零实现深度学习框架——RNN从理论到实战【理论】
引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.
- 从零实现深度学习框架——深入浅出Word2vec(下)
引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导. 要深入理解深度学 ...
- 从零实现深度学习框架——从共现矩阵到点互信息
引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.
- 从零实现深度学习框架——LSTM从理论到实战【理论】
引言 本着"凡我不能创造的,我就不能理解"的思想,本系列文章会基于纯Python以及NumPy从零创建自己的深度学习框架,该框架类似PyTorch能实现自动求导.
- python深度学习include框架_《用Python实现深度学习框架》上市
朋友们,<用Python实现深度学习框架>已经由人民邮电出版社出版上市了.在这本书中,我们带领读者仅用Python+Numpy实现一个基于计算图的深度学习框架MatrixSlow.本书讲解 ...
- python深度学习include框架_搞事情。《用Python实现深度学习框架》已出版上架。...
我和 @张觉非 合作的<用Python实现深度学习框架>一书已经由人民邮电出版社出版上市了.写作本书的缘由,是2017年11月我加入了360,开始负责以机器学习平台为中心的AI技术设施的研 ...
- 基于python的深度学习框架有_《用Python实现深度学习框架》上市
朋友们,<用Python实现深度学习框架>已经由人民邮电出版社出版上市了.在这本书中,我们带领读者仅用Python+Numpy实现一个基于计算图的深度学习框架MatrixSlow.本书讲解 ...
最新文章
- 动态给H5页面绑定数据,基本万能无错误!
- C++函数和类的封装
- ElasticSearch - 嵌套对象 nested
- python安装error: Unable to find vcvarsall.bat
- JAVA动漫论坛BBS系统的设计与实现
- php算法-输出100以内能被3整除的整数
- 大华相机RTSP获取视频方式
- 干线公路交叉口右转车辆与非机动车冲突精细化治理实例
- 驱动精灵w8ndows xp sp2,惠普HP LaserJet 1020打印机驱动官方正式版下载,适用于winxp,winvista,win7,win8,win10-驱动精灵...
- 学校计算机网络教室,关元学校计算机网络教室使用管理制度
- 在Micrium uC/Probe中添加IAR生成的.out文件的问题
- 双机热备——上下层交换机负载分担
- [实用电脑技术]Google Chrome谷歌浏览器下载完整离线安装版本
- VirtualBox管理工具Vboxmanage
- Ribbon负载均衡策略初步解读
- FM FFM:深入理解FM与FFM
- 检查suse是否安装ftp服务,安装:SuSE Linux FTP版安装指南(转)
- CSS3干货14:自定义页面滚动条
- 成功解决:You are using pip version 9.0.3, however version 20.3.3 is available. You should consider upgra
- 二代测序linux软件,二代测序数据分析软件包大全