作者|维度

来源|机器之心

近日,有用户在自己的项目中发现了一个微小的 bug,在 PyTorch 同时使用 NumPy 的随机数生成器和多进程数据加载会导致相同的扩充数据,只有专门设置 seed 才可以解决这个 bug,否则会降低模型的准确率。不过,有人认为这并不是一个 bug,而是预期功能,是「按预期工作的」。

行内人都知道,机器学习(ML)代码中的 bug 很难修复,并且它们不会造成编译错误,而是悄悄地降低准确率。这些 bug 简直防不胜防。最近,一位专注于机器学习的用户遇到了一个非常熟悉的 bug,修复了之后性能有了大幅度提升。这是一个什么样的 bug 呢?

根据用户的描述,bug 是这样的:除非你在 DataLoader 中使用 worker_init_fn 选项专门设置 seed,否则在 PyTorch 同时使用 NumPy 的随机数生成器和多进程数据加载会导致相同的扩充数据。用户没有这样做,因而这个 bug 悄悄地降低了模型的准确率。

该 bug 非常小并且很容易出现。所以,这位用户很好奇会不会也对其他项目造成损害呢?ta 从 GitHub 上下载了 10 万个导入 PyTorch 的库,并分析了这些库的源代码。之后,ta 保留了那些具有自定义数据集、同时使用 NumPy 的随机数生成器和多进程数据加载以及或多或少使用抽象语法树进行分析的项目。

结果显示,95% 以上的库存在着这个 bug,如 PyTorch 的官方教程、OpenAI 的代码以及 NVIDIA 的项目。甚至特斯拉 AI 负责人 Andrej Karpathy 也曾遭受过该 bug 的困扰。

OpenAI 的 ebm_code_release 项目。

这个 bug 究竟怎样影响模型的准确率?这位用户从以下两个示例中进行了简要描述。

bug 描述

在 PyTorch 中加载、预处理和扩充数据的标准方法是子类化 torch.utils.data.Dataset 并重写 __getitem__方法。要应用扩充方法(如随机裁剪、图像翻转),__getitem__方法经常使用 NumPy 来生成随机数,然后将 map-styled 数据集传递给 DataLoader 来创建 batch。这种训练 pipeline 可能会受到数据预处理的阻碍,因此并行加载数据是有意义的。可以通过增加 DataLoader 对象中的 num_workers 参数来实现。

问题是,这个工作流导致了相同的数据扩充。

PyTorch 使用多进程并行加载数据,worker 进程是使用 fork start 方法创建的。这意味着每个工作进程继承父进程的所有资源,包括 NumPy 的随机数生成器的状态。

示例 1

为了更加形象地描述问题,用户从以下两个示例中进行了简要概述。

示例 1 为一个示例数据集,它返回三个元素的随机向量。示例使用两个和四个工作进程的 batch 大小。

代码返回如下结果:每个进程返回的随机数都是相同的。

示例 2

示例 2 演示了如何在 face-landmarks 数据集上使用 Dataset 和 DataLoader 类。此外,还提到了数据扩充的重要性,并提供了一个随机裁剪扩充的例子。这是使用 NumPy 的随机数生成器实现的。

通过增加 num_workers 来加速数据加载,可以得到相同的裁剪结果:

batch 大小为 8, num_workers 为 2,random crop augmentation(随机裁剪扩充)

这个 bug 很容易产生。在某些情况下,它对最终性能的影响很小。在另一些情况下,相同的扩充会导致严重的退化。

基于对开放源码 PyTorch 项目的分析,发现 bug 的这位用户担心这个问题在许多支持真实产品的代码库中都存在。

究竟是 bug,还是预期功能或特征?

这位用户描述的 bug 也引起了众多网友的热议,其中一些人并不认为这是 bug。

用户「amasterblaster」认为,这不是一个 bug,而是所有种子随机函数的预期功能。这是因为即使在随机实验中,有时你想要对比静态参数的变化,并得到相同的随机数。只有当你被读为真随机(true random)时,才会根据 OS time 设置 seed。

用户「xicor7017」表示自己也遇到了相同的问题,也认为它并不是一个 bug,而是一个可能不为人所知的特征。如果忽略它的话,调试问题时会很麻烦。

与此同时,另一些人表达出了不同的观点,认为既然「如果事情朝着人们不希望的方向发展,那么它就不应该这样,也就构成了 bug。」

用户「IntelArtiGen」称自己意识到了这个 bug,认为它是不正常的,并且对自己的项目造成了一些小问题。用户「gwern」赞同这种观点,认为如果 95% 以上的用户使用时出现错误,则代码就是错的。

用户「synonymous1964」进一步解读了这个 bug。ta 认为,人们可能误解了这个问题,问题不在于设置特定的随机种子会导致每次训练过程中生成相同序列的随机数,这显然是按预期工作的。相反,问题在于多个数据下载进程中(由 PyTorch 中的 num_workers 设置)的每个进程都会在某个特定的训练过程中输出相同序列的随机数。毫无疑问,这当然会对项目造成影响,具体取决于你如何进行数据加载和扩充。所以,即使这个 bug 是「按预期工作的」,但向更多其他用户指出来也挺好的。

不知道机器之心的读者,有没有遇到过类似的 bug 呢?如果有,可以在评论中发表自己对该 bug 的观点。

参考链接:

https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/

https://www.reddit.com/r/MachineLearning/comments/mocpgj/p_using_pytorch_numpy_a_bug_that_plagues/

????

现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧

关于PaperWeekly

PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。

PyTorch + NumPy这么做会降低模型准确率,这是bug还是预期功能?相关推荐

  1. PyTorch + NumPy这么做会降低模型准确率?

    点上方计算机视觉联盟获取更多干货 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:机器之心 AI博士笔记系列推荐 周志华<机器学习>手推笔记正式开源!可打印版本附pdf下载链接 近 ...

  2. 深度学习提高模型准确率方法

    这里写目录标题 深度学习 数据 使用更多数据 更改图像大小 减少颜色通道 算法 模型改进 增加训练轮次 迁移学习 添加更多层 调整超参数 总结 深度学习 我们已经收集好了一个数据集,建立了一个神经网络 ...

  3. 【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

    文章目录 前言 CIFAR10简介 Backbone选择 训练+测试 训练环境及超参设置 完整代码 部分测试结果 完整工程文件 Reference 前言 分享一下本人去年入门深度学习时,在CIFAR1 ...

  4. Hinton等人最新研究:大幅提升模型准确率,标签平滑技术到底怎么用?

    作者 | Rafael Müller , Simon Kornblith, Geoffrey Hinton 译者 | Rachel 责编 | Jane 出品 | AI科技大本营(ID: rgznai1 ...

  5. Github标星10.4k:用 NumPy 实现所有主流机器学习模型

    用 NumPy 手写所有主流 ML 模型,普林斯顿博士后 David Bourgin 最近开源了一个非常剽悍的项目.超过 3 万行代码.30 多个模型,这也许能打造「最强」的机器学习基石?(编辑:机器 ...

  6. 内涝预测过程的噪音_提高人工智能模型准确率的测试过程中需要注意什么?

    黑马程序员视频库 播妞微信号:boniu236 传智播客旗下互联网资讯.学习资源免费分享平台 现在人工智能行业发展迅猛,那么人工智能产品特别是使用分类算法实现的产品中判断其能否上线通常是通过算法自带的 ...

  7. 一步步读懂Pytorch Chatbot Tutorial代码(四) - 为模型准备数据

    文章目录 自述 有用的工具 代码出处 目录 头大 代码及说明 Prepare Data for Models 重点关注 indexesFromSentence zeroPadding binaryMa ...

  8. 用pyinstaller打包pytorch环境下的深度学习模型,实现通过exe程序实现界面显示模型的分类效果

    用pyinstaller打包pytorch环境下的深度学习模型,实现通过exe应用实现界面显示模型的分类效果 训练深度学习模型和界面显示,看我之前的博客,链接在下面: 通过残差网络实现CLFAR-10 ...

  9. 深度学习中学习率和batchsize对模型准确率的影响

    本内容来自其他的人解析,参考链接在最后的注释. 1. 前言 目前深度学习模型多采用批量随机梯度下降算法进行优化,随机梯度下降算法的原理如下: n是批量大小(batchsize),η是学习率(learn ...

最新文章

  1. CDN加速技术和云计算
  2. 《Linux系统初讲》学习总结(一)
  3. mysql创建表的时候不要添加drop操作
  4. (七)全半角转换(转)
  5. linux jar和zip,Linux命令———zip和jar文件压缩解压
  6. 关于PE可执行文件的修改
  7. 2字节取值范围_高中数学:构造不等式,解析几何范围题的有效解法
  8. js原生继承几种方式
  9. 网易云音乐Android一面面经
  10. 泰迪杯数据挖掘挑战赛—数据预处理(二)
  11. 【质量管理】41页PPT系统学习质量管理体系!
  12. C++移动语义及拷贝优化
  13. 关于最短剩余时间优先算法-进程调度模拟【C++】
  14. Burp Suite 实战指南
  15. 深度学习论文写作框架
  16. BI工具调研之——帆软
  17. JavaScriptJQuery_jQuery选择器
  18. Linux中的if-then语句
  19. monkey脚本执行中如何强行停止
  20. 离职了,写点什么吧~

热门文章

  1. java系统界面找不到符号,找不到符号,java找不到符号
  2. VS 团队资源管理 强制解锁锁定文件
  3. 走在程序世界道路上的我___大一篇
  4. java 气泡 提示插件_Java气泡提示功能实现
  5. flex 会使div撑满_如何讲清楚Flex弹性盒模型?(中)
  6. elasticsearch狂神说笔记_神级学习笔记!别再说不会Elasticsearch了,这位架构师都整理好了...
  7. linux下使用python3_Linux上python3的安装和使用
  8. 已经围上为何不算目_在湖人打球顺风顺水,戴维斯为何还要亏本卖掉洛杉矶豪宅?...
  9. plt图片输出 python_利用Python制作词云,wordcloud神器你值得拥有
  10. html5 json转字符串,web前端-js小记(5)-字符串及json