推荐一个快速定位深度学习代码bug的炼丹神器!
文 | McGL
源 | 知乎
写深度学习网络代码,最大的挑战之一,尤其对新手来说,就是把所有的张量维度正确对齐。如果以前就有TensorSensor这个工具,相信我的头发一定比现在更浓密茂盛!
TensorSensor,码痴教授 Terence Parr 出品,他也是著名 parser 工具 ANTLR 的作者。
在包含多个张量和张量运算的复杂表达式中,张量的维数很容易忘了。即使只是将数据输入到预定义的 TensorFlow 网络层,维度也要弄对。当你要求进行错误的计算时,通常会得到一些没啥用的异常消息。为了帮助自己和其他程序员调试张量代码,Terence Parr 写了一个名叫 TensorSensor 的库(pip install tensor-sensor 直接安装) 。TensorSensor 通过增加消息和可视化 Python 代码来展示张量变量的形状,让异常更清晰(见下图)。它可以兼容 TensorFlow、PyTorch 和 Numpy以及 Keras 和 fastai 等高级库。
在张量代码中定位问题令人抓狂!
即使是专家,执行张量操作的 Python 代码行中发生异常,也很难快速定位原因。调试过程通常是在有问题的行前面添加一个 print 语句,以打出每个张量的形状。这需要编辑代码添加调试语句并重新运行训练过程。或者,我们可以使用交互式调试器手动单击或键入命令来请求所有张量形状。(这在像 PyCharm 这样的 IDE 中不太实用,因为在调试模式很慢。)下面将详细对比展示看了让人贫血的缺省异常消息和 TensorSensor 提出的方法,而不用调试器或 print 大法。
调试一个简单的线性层
让我们来看一个简单的张量计算,来说明缺省异常消息提供的信息不太理想。下面是一个包含张量维度错误的硬编码单(线性)网络层的简单 NumPy 实现。
import numpy as npn = 200 # number of instances
d = 764 # number of instance features
n_neurons = 100 # how many neurons in this layer?W = np.random.rand(d,n_neurons) # Ooops! Should be (n_neurons,d)
b = np.random.rand(n_neurons,1)
X = np.random.rand(n,d) # fake input matrix with n rows of d-dimensionsY = W @ X.T + b # pass all X instances through layer
10 Y = W @ X.T + b
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)
执行该代码会触发一个异常,其重要元素如下:
...
---> 10 Y = W @ X.T + b
ValueError: matmul: Input operand 1 has a mismatch in its core dimension 0, with gufunc signature (n?,k),(k,m?)->(n?,m?) (size 764 is different from 100)
异常显示了出错的行以及是哪个操作(matmul: 矩阵乘法),但是如果给出完整的张量维数会更有用。此外,这个异常也无法区分在 Python 的一行中的多个矩阵乘法。
接下来,让我们看看 TensorSensor 如何使调试语句更加容易的。如果我们使用 Python with 和tsensor 的 clarify()包装语句,我们将得到一个可视化和增强的错误消息。
import tsensor
with tsensor.clarify():Y = W @ X.T + b
...
ValueError: matmul: Input operand ...
Cause: @ on tensor operand W w/shape (764, 100) and operand X.T w/shape (764, 200)
从可视化中可以清楚地看到,W 的维度应该翻转为 n _ neurons x d; W 的列必须与 X.T 的行匹配。您还可以检查一个完整的带有和不带阐明()的并排图像,以查看它在笔记本中的样子。下面是带有和没有 clarify() 的例子在notebook 中的比较。
clarify() 功能在没有异常时不会增加正在执行的程序任何开销。有异常时, clarify():
增加由底层张量库创建的异常对象消息。
给出出错操作所涉及的张量大小的可视化表示; 只突出显示异常涉及的操作对象和运算符,而其他 Python 元素则不突出显示。
TensorSensor 还区分了 PyTorch 和 TensorFlow 引发的与张量相关的异常。下面是等效的代码片段和增强的异常错误消息(Cause: @ on tensor ...)以及 TensorSensor 的可视化:
PyTorch 消息没有标识是哪个操作触发了异常,但 TensorFlow 的消息指出了是矩阵乘法。两者都显示操作对象维度。
调试复杂的张量表达式
缺省消息缺乏具体细节,在包含大量操作符的更复杂的语句中,识别出有问题的子表达式很难。例如,下面是从一个门控循环单元(GRU)实现的内部提取的一个语句:
h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
这是什么计算或者变量代表什么不重要,它们只是张量变量。有两个矩阵乘法,两个向量加法,还有一个向量逐元素修改(r*h)。如果没有增强的错误消息或可视化,我们就无法知道是哪个操作符或操作对象导致了异常。为了演示 TensorSensor 在这种情况下是如何分清异常的,我们需要给语句中使用的变量(为 h _ 赋值)一些伪定义,以得到可执行代码:
nhidden = 256
Whh_ = torch.eye(nhidden, nhidden) # Identity matrix
Uxh_ = torch.randn(d, nhidden)
bh_ = torch.zeros(nhidden, 1)
h = torch.randn(nhidden, 1) # fake previous hidden state h
r = torch.randn(nhidden, 1) # fake this computation
X = torch.rand(n,d) # fake inputwith tsensor.clarify():h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
同样,你可以忽略代码执行的实际计算,将重点放在张量变量的形状上。
对于我们大多数人来说,仅仅通过张量维数和张量代码是不可能识别问题的。当然,默认的异常消息是有帮助的,但是我们中的大多数人仍然难以定位问题。以下是默认异常消息的关键部分(注意对 C++ 代码的不太有用的引用) :
---> 10 h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
RuntimeError: size mismatch, m1: [764 x 256], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
我们需要知道的是哪个操作符和操作对象出错了,然后我们可以通过维数来确定问题。以下是 TensorSensor 的可视化和增强的异常消息:
---> 10 h_ = torch.tanh(Whh_ @ (r*h) + Uxh_ @ X.T + bh_)
RuntimeError: size mismatch, m1: [764 x 256], m2: [764 x 200] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
Cause: @ on tensor operand Uxh_ w/shape [764, 256] and operand X.T w/shape [764, 200]
人眼可以迅速锁定在指示的算子和矩阵相乘的维度上。哎呀, Uxh 的列必须与 X.T的行匹配,Uxh_的维度翻转了,应该为:
Uxh_ = torch.randn(nhidden, d)
现在,我们只在 with 代码块中使用我们自己直接指定的张量计算。那么在张量库的内置预建网络层中触发的异常又会如何呢?
理清预建层中触发的异常
TensorSensor 可视化进入你选择的张量库前的最后一段代码。例如,让我们使用标准的 PyTorch nn.Linear 线性层,但输入一个 X 矩阵维度是 n x n,而不是正确的 n x d:
L = torch.nn.Linear(d, n_neurons)
X = torch.rand(n,n) # oops! Should be n x d
with tsensor.clarify():Y = L(X)
增强的异常信息
RuntimeError: size mismatch, m1: [200 x 200], m2: [764 x 100] at /tmp/pip-req-build-as628lz5/aten/src/TH/generic/THTensorMath.cpp:41
Cause: L(X) tensor arg X w/shape [200, 200]
TensorSensor 将张量库的调用视为操作符,无论是对网络层还是对 torch.dot(a,b) 之类的简单操作的调用。在库函数中触发的异常会产生消息,消息标示了函数和任何张量参数的维数。
后台回复关键词【入群】
加入卖萌屋NLP/IR/Rec与求职讨论群
后台回复关键词【顶会】
获取ACL、CIKM等各大顶会论文集!
[1] https://explained.ai/tensor-sensor/index.html
推荐一个快速定位深度学习代码bug的炼丹神器!相关推荐
- 推荐 | 一个机器学习与深度学习的优质公众号
学习资源推荐 推荐人 榛果 俗话说,一个人走得快,但一群人可以走的远.在数据科学和机器学习的道路上,相信每个人都不是闭门造车的人.技术学习除了在个人努力外,交流和分享也是很重要的一部分. 今天给大家推 ...
- python深度学习代码列子
以下是一个简单的深度学习代码的例子,可以帮助你了解深度学习的基本概念和实现方法. # Import necessary libraries import tensorflow as tf import ...
- 如何才能信任你的深度学习代码?
深度学习是一门很难评估代码正确性的学科.随机初始化.庞大的数据集和权重的有限可解释性意味着,要找到模型为什么不能训练的确切问题,大多数时候都需要反复试验.在传统的软件开发中,自动化单元测试是确定代码是 ...
- 如何快速入门深度学习写论文?
原文作者:月来客栈 https://www.zhihu.com/people/the_lastest 最快的方式: 第一,选择一篇有代码的论文,记住一定要有代码: 第二,大致弄清楚论文里所提出 ...
- 论文合集 | 李飞飞新论文:深度学习代码搜索综述;Adobe用GAN生成动画(附地址)...
来源:机器之心 本文约3200字,建议阅读7分钟. 本文介绍了李飞飞新论文,深度学习代码搜索综述,Adobe用GAN生成动画. 本周有李飞飞.朱玉可等的图像因果推理和吴恩达等的 NGBoost 新论文 ...
- 快速入门深度学习,其实并不难!
深度学习的概念源于人工神经网络的研究,而深度学习的过程就是使用多个处理层对数据进行高层抽象,得到多重非线性变换函数的过程. 虽然深度学习的概念看似高大上,让人有种莫名的距离感,实际上它在日常生活中随处 ...
- 如何提高深度学习代码能力
个人经历:一个正在努力提高自身代码能力和实践能力的求职人员. 背景:想通过几个实践项目提高工程实践能力和底层代码能力.提出这个问题是因为自己对深度学习的探究更多停留在能根据文章看懂代码,能根据代码更深 ...
- 新手如何快速入门 深度学习
如何快速入门深度学习 深度学习入门必备基础 避开常见误区 学习路线图 干货分享 深度学习必备基础 深度学习发展至今已然有几个年头了,上个世纪九十年代的美国银行率先使用深度学习技术做为手写字体识别,但深 ...
- 新手如何快速入门深度学习领域
如何快速入门深度学习 本篇学习笔记对应深度学习入门视频课程 博客地址:http://blog.csdn.net/tangyudi 欢迎转载 深度学习入门必备基础 避开常见误区 学习路线图 干货分享 深 ...
最新文章
- python m什么意思_Python -m参数原理及使用方法解析
- 【android】og
- 很多网站,软件对自定义的dpi支持不好
- jpa初学 hibernate学习
- Guava的介绍与使用示例
- 发现个好玩的,去页面敲键盘,页面键变色
- ssh: connect to host port 22: Connection refused
- 单例设计模式全局缓存accessToken
- tomcat集成activeMq 简单例子
- VMware Workstation 12 Pro的安装
- js实现全国省份下拉
- HI3798MV200驱动移植
- excel两列数据对比找不同_技巧不求人168期 Excel两列数据找不同的3种方法 Word快速更改文本排序...
- 今夏流行的十大避暑胜地
- 几个实用的生活服务网站和APP
- 三次握手的过程、四次挥手、为什么要进行第三次握手、为什么要进行四次挥手
- 鸿蒙系统如何进入语音助手,原来华为手机的语音助手还可以这么玩,九个实用技能分享给你...
- 【缺陷检测】基于形态学实现印刷电路板缺陷检测技术附matlab代码
- pip install mysqlclient安装
- 16天进入“已问询”状态,上市进程神速,这家芯片设计企业凭什么?
热门文章
- 排序算法之冒泡排序(C/C++)
- of_property_read_string 剖析~
- Python3——文件与异常
- ubuntu常见问题
- python中import os_Python常用模块os--与操作系统交互
- 十、关于MySQL 标识列,你该了解这些!
- LeetCode 2000. 反转单词前缀
- LeetCode 1750. 删除字符串两端相同字符后的最短长度(双指针)
- [Kaggle] Digit Recognizer 手写数字识别(神经网络)
- LeetCode 1198. 找出所有行中最小公共元素(二分/合并有序链表)