机器学习中qa测试_如何对机器学习做单元测试
作者:Chase Roberts
编译:ronghuaiyang
导读
养成良好的单元测试的习惯,真的是受益终身的,特别是机器学习代码,有些bug真不是看看就能看出来的。
在过去的一年里,我把大部分的工作时间都花在了深度学习研究和实习上。那一年,我犯了很多大错误,这些错误不仅帮助我了解了ML,还帮助我了解了如何正确而稳健地设计这些系统。我在谷歌Brain学到的一个主要原则是,单元测试可以决定算法的成败,可以为你节省数周的调试和训练时间。
然而,在如何为神经网络代码编写单元测试方面,似乎没有一个可靠的在线教程。即使是像OpenAI这样的地方,也只是通过盯着他们代码的每一行,并试着思考为什么它会导致bug来发现bug的。显然,我们大多数人都没有这样的时间,所以希望本教程能够帮助你开始理智地测试你的系统!
让我们从一个简单的例子开始。试着找出这段代码中的错误。
def make_convnet(input_image): net = slim.conv2d(input_image, 32, [11, 11], scope="conv1_11x11") net = slim.conv2d(input_image, 64, [5, 5], scope="conv2_5x5") net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1') net = slim.conv2d(input_image, 64, [5, 5], scope="conv3_5x5") net = slim.conv2d(input_image, 128, [3, 3], scope="conv4_3x3") net = slim.max_pool2d(net, [2, 2], scope='pool2') net = slim.conv2d(input_image, 128, [3, 3], scope="conv5_3x3") net = slim.max_pool2d(net, [2, 2], scope='pool3') net = slim.conv2d(input_image, 32, [1, 1], scope="conv6_1x1") return net
你看到了吗?网络实际上并没有堆积起来。在编写这段代码时,我复制并粘贴了slim.conv2d(…)行,并且只修改了内核大小,而没有修改实际的输入。
我很不好意思地说,这件事在一周前就发生在我身上了……但这是很重要的一课!由于一些原因,这些bug很难捕获。
- 这段代码不会崩溃,不会产生错误,甚至不会变慢。
- 这个网络仍在运行,损失仍将下降。
- 几个小时后,这些值就会收敛,但结果却非常糟糕,让你摸不着头脑,不知道需要修复什么。
当你唯一的反馈是最终的验证错误时,你惟一需要搜索的地方就是你的整个网络体系结构。不用说,你需要一个更好的系统。
那么,在我们进行完整的多日训练之前,我们如何真正抓住这个机会呢?关于这个最容易注意到的是层的值实际上不会到达函数外的任何其他张量。假设我们有某种类型的损失和一个优化器,这些张量永远不会得到优化,所以它们总是有它们的默认值。
我们可以通过简单的训练步骤和前后对比来检测它。
def test_convnet(): image = tf.placeholder(tf.float32, (None, 100, 100, 3) model = Model(image) sess = tf.Session() sess.run(tf.global_variables_initializer()) before = sess.run(tf.trainable_variables()) _ = sess.run(model.train, feed_dict={ image: np.ones((1, 100, 100, 3)), }) after = sess.run(tf.trainable_variables()) for b, a, n in zip(before, after): # Make sure something changed. assert (b != a).any()
在不到15行代码中,我们现在验证了至少我们创建的所有变量都得到了训练。
这个测试超级简单,超级有用。假设我们修复了前面的问题,现在我们要开始添加一些批归一化。看看你能否发现这个bug。
def make_convnet(image_input): # Try to normalize the input before convoluting net = slim.batch_norm(image_input) net = slim.conv2d(net, 32, [11, 11], scope="conv1_11x11") net = slim.conv2d(net, 64, [5, 5], scope="conv2_5x5") net = slim.max_pool2d(net, [4, 4], stride=4, scope='pool1') net = slim.conv2d(net, 64, [5, 5], scope="conv3_5x5") net = slim.conv2d(net, 128, [3, 3], scope="conv4_3x3") net = slim.max_pool2d(net, [2, 2], scope='pool2') net = slim.conv2d(net, 128, [3, 3], scope="conv5_3x3") net = slim.max_pool2d(net, [2, 2], scope='pool3') net = slim.conv2d(net, 32, [1, 1], scope="conv6_1x1") return net
你看到了吗?这个非常微妙。您可以看到,在tensorflow batch_norm中,is_training的默认值是False,所以添加这行代码并不能使你在训练期间的输入正常化!值得庆幸的是,我们编写的最后一个单元测试将立即发现这个问题!(我知道,因为这是三天前发生在我身上的事。)
再看一个例子。这实际上来自我一天看到的一篇文章(https://www.reddit.com/r/MachineLearning/comments/6qyvvg/p_tensorflow_response_is_making_no_sense/)。我不会讲太多细节,但是基本上这个人想要创建一个输出范围为(0,1)的分类器。
class Model: def __init__(self, input, labels): """Classifier model Args: input: Input tensor of size (None, input_dims) label: Label tensor of size (None, 1). Should be of type tf.int32. """ prediction = self.make_network(input) # Prediction size is (None, 1). self.loss = tf.nn.softmax_cross_entropy_with_logits( logits=prediction, labels=labels) self.train_op = tf.train.AdamOptimizer().minimize(self.loss)
注意到这个错误吗?这是真的很难提前发现,并可能导致超级混乱的结果。基本上,这里发生的是预测只有一个输出,当你将softmax交叉熵应用到它上时,它的损失总是0。
一个简单的测试方法是确保损失不为0。
def test_loss(): in_tensor = tf.placeholder(tf.float32, (None, 3)) labels = tf.placeholder(tf.int32, None, 1)) model = Model(in_tensor, labels) sess = tf.Session() loss = sess.run(model.loss, feed_dict={ in_tensor:np.ones(1, 3), labels:[[1]] }) assert loss != 0
另一个很好的测试与我们的第一个测试类似,但是是反向的。你可以确保只有你想训练的变量得到了训练。以GAN为例。出现的一个常见错误是在进行优化时不小心忘记设置要训练的变量。这样的代码经常发生。
class GAN: def __init__(self, z_vector, true_images): # Pretend these are implemented. with tf.variable_scope("gen"): self.make_geneator(z_vector) with tf.variable_scope("des"): self.make_descriminator(true_images) opt = tf.AdamOptimizer() train_descrim = opt.minimize(self.descrim_loss) train_gen = opt.minimize(self.gen_loss)
这里最大的问题是优化器有一个默认设置来优化所有变量。在像GANs这样的高级架构中,这是对你所有训练时间的死刑判决。但是,你可以通过编写这样的测试来轻松地发现这些错误:
def test_gen_training(): model = Model sess = tf.Session() gen_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='gen') des_vars = tf.get_collection(tf.GraphKeys.VARIABLES, scope='des') before_gen = sess.run(gen_vars) before_des = sess.run(des_vars) # Train the generator. sess.run(model.train_gen) after_gen = sess.run(gen_vars) after_des = sess.run(des_vars) # Make sure the generator variables changed. for b,a in zip(before_gen, after_gen): assert (a != b).any() # Make sure descriminator did NOT change. for b,a in zip(before_des, after_des): assert (a == b).all()
可以为鉴别器编写一个非常类似的测试。同样的测试也可以用于许多强化学习算法。许多行为-批评模型有单独的网络,需要根据不同的损失进行优化。
下面是一些我推荐你进行测试的模式。
- 让测试具有确定性。如果一个测试以一种奇怪的方式失败,却永远无法重现这个错误,那就太糟糕了。如果你真的想要随机输入,确保使用种子随机数,这样你就可以轻松地重新运行测试。
- 保持测试简短。不要使用单元测试来训练收敛性并检查验证集。这样做是在浪费自己的时间。
- 确保你在每个测试之间重置了计算图。
总之,这些黑箱算法仍然有很多方法需要测试!花一个小时写一个测试可以节省你几天的重新运行训练模型,并可以大大提高你的研究效率。因为我们的实现有缺陷而不得不放弃完美的想法,这不是很糟糕吗?
这个列表显然不全面,但它是一个坚实的开始!
英文原文:https://medium.com/@keeper6928/how-to-unit-test-machine-learning-code-57cf6fd81765
机器学习中qa测试_如何对机器学习做单元测试相关推荐
- 机器学习中qa测试_学会区分人工智能和机器学习,并了解QA测试方法
点击上方关注,All in AI中国 智能手机.智能音箱.智能汽车.智能咖啡机--这样的例子不胜枚举.似乎我们周围的一切都变得鲜活起来,聪明起来.尽管科幻小说的繁荣源于我们对机器人恶意接管的恐惧,但智 ...
- 机器学习中qa测试_机器学习项目测试怎么做?(看实例)
机器学习交付项目通常包含两部分产物,一部分是机器学习模型,另一部分是机器学习应用系统.机器学习模型是嫁接在应用之上产生价值的.比如:一款预测雷雨天气的APP,它的雷雨预测功能就是由机器学习模型完成的. ...
- 机器学习中qa测试_如何使用AI和机器学习的QA测试软件
智能手机,智能扬声器,智能汽车,智能咖啡机......这个名单还在继续.看起来我们周围的一切都变得生机勃勃,变得聪明起来.尽管科幻类型依赖于我们对敌对机器人接管的恐惧,但智能设备绝不是反乌托邦 - 它 ...
- 机器学习中qa测试_机器学习自动化单元测试平台
机器学习自动化单元测试平台.零代码.全方位.自动化测试方法/函数的正确性和可用性. 原理 后端不需要写任何单元测试代码(逻辑代码.注解代码等全都不要), 这个工具会自动生成测试参数,并执行方法,拿到返 ...
- 范数在机器学习中的作用_设计在机器学习中的作用
范数在机器学习中的作用 Today, machine learning (ML) is a component of practically all new software products. Fo ...
- 02.PyTorch基础操作(3-1 机器学习中的分类与回归问题-机器学习基本构成元素)
@[TOC](02.PyTorch基础操作(3-1 机器学习中的分类与回归问题-机器学习基本构成元素)) 来自慕课网 一.3-1 机器学习中的分类与回归问题-机器学习基本构成元素
- 机器学习 文本分类 代码_无需担心机器学习-如何在少于10行代码中对文本进行分类
机器学习 文本分类 代码 This article builds upon my previous two articles where I share some tips on how to get ...
- 机器学习中的随机过程_机器学习过程
机器学习中的随机过程 If you would like to get a general introduction to Machine Learning before this, check ou ...
- 机器学习 训练验证测试_测试前验证| 机器学习
机器学习 训练验证测试 In my previous article, we have discussed about the need to train and test our model and ...
最新文章
- git shanchu stash_git stash用法
- vim+快捷键+常用+命令
- 【转载】学习嵌入式系统需要具备的条件、方法及步骤
- php如何删除数据库中的数据库文件夹,学习猿地-php数据库如何删除数据
- 0122 - EOS 编程学习日志(1)
- Dogleg“狗腿”最优化算法
- 搜推广遇上用户画像:Lookalike相似人群拓展算法
- 理解 OpenStack 高可用(HA) (6): MySQL HA
- 用ByteArrayOutputStream解决IO流乱码问题
- 计算机专业论文选题网站方面,5大网站汇总,搞定新颖的计算机专业毕业设计网站汇总...
- 使用springboot + druid + mybatisplus完成多数据源配置
- python画三维图-Python基于matplotlib实现绘制三维图形功能示例
- 使用Easy Duplicate Photo Finder for Mac如何查找重复的图片?
- 深度学习还是鼠标搞定,零基础建网站必备技能
- 非递归获取二叉树中叶子结点的个数
- 一步一步搭建11gR2 rac+dg之配置单实例的DG(八)
- html敲tab键无法新建,sublime按tab键无法补全html页面模板解决办法
- 配置Skype for business 2015混合部署
- 阿里云服务器的80端口被封了么?
- 未来的计算机将是半导体,硅的未来岌岌可危?未来计算机或迎来钻石芯
热门文章
- Android Studio Gradle两种更新方式
- 优秀程序员必备的15大技能
- C程序员要学C++吗?
- C语言指针与数组之间的恩恩怨怨
- 解决:single failed: For artifact {null:null:null:jar}: The groupId cannot be empty. 把工程依赖的jar包打到入jar中
- IDEA 中的.iml文件和.idea文件夹 ( 隐藏方式 )
- springCloud - 第7篇 - 配置文件管理中心 ( SpringCloud Config )
- win2008r2 AD用户账户的批量导入方法
- [译]git revert
- umask命令:设置文件的默认权限掩码