点上方计算机视觉联盟获取更多干货

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:机器之心

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

近日,有用户在自己的项目中发现了一个微小的 bug,在 PyTorch 同时使用 NumPy 的随机数生成器和多进程数据加载会导致相同的扩充数据,只有专门设置 seed 才可以解决这个 bug,否则会降低模型的准确率。不过,有人认为这并不是一个 bug,而是预期功能,是「按预期工作的」。

行内人都知道,机器学习(ML)代码中的 bug 很难修复,并且它们不会造成编译错误,而是悄悄地降低准确率。这些 bug 简直防不胜防。最近,一位专注于机器学习的用户遇到了一个非常熟悉的 bug,修复了之后性能有了大幅度提升。这是一个什么样的 bug 呢?

根据用户的描述,bug 是这样的:除非你在 DataLoader 中使用 worker_init_fn 选项专门设置 seed,否则在 PyTorch 同时使用 NumPy 的随机数生成器和多进程数据加载会导致相同的扩充数据。用户没有这样做,因而这个 bug 悄悄地降低了模型的准确率。

该 bug 非常小并且很容易出现。所以,这位用户很好奇会不会也对其他项目造成损害呢?ta 从 GitHub 上下载了 10 万个导入 PyTorch 的库,并分析了这些库的源代码。之后,ta 保留了那些具有自定义数据集、同时使用 NumPy 的随机数生成器和多进程数据加载以及或多或少使用抽象语法树进行分析的项目。

结果显示,95% 以上的库存在着这个 bug,如 PyTorch 的官方教程、OpenAI 的代码以及 NVIDIA 的项目。甚至特斯拉 AI 负责人 Andrej Karpathy 也曾遭受过该 bug 的困扰。

OpenAI 的 ebm_code_release 项目。

这个 bug 究竟怎样影响模型的准确率?这位用户从以下两个示例中进行了简要描述。

bug 描述

在 PyTorch 中加载、预处理和扩充数据的标准方法是子类化 torch.utils.data.Dataset 并重写 __getitem__方法。要应用扩充方法(如随机裁剪、图像翻转),__getitem__方法经常使用 NumPy 来生成随机数,然后将 map-styled 数据集传递给 DataLoader 来创建 batch。这种训练 pipeline 可能会受到数据预处理的阻碍,因此并行加载数据是有意义的。可以通过增加 DataLoader 对象中的 num_workers 参数来实现。

问题是,这个工作流导致了相同的数据扩充。

PyTorch 使用多进程并行加载数据,worker 进程是使用 fork start 方法创建的。这意味着每个工作进程继承父进程的所有资源,包括 NumPy 的随机数生成器的状态。

示例 1

为了更加形象地描述问题,用户从以下两个示例中进行了简要概述。

示例 1 为一个示例数据集,它返回三个元素的随机向量。示例使用两个和四个工作进程的 batch 大小。

代码返回如下结果:每个进程返回的随机数都是相同的。

示例 2

示例 2 演示了如何在 face-landmarks 数据集上使用 Dataset 和 DataLoader 类。此外,还提到了数据扩充的重要性,并提供了一个随机裁剪扩充的例子。这是使用 NumPy 的随机数生成器实现的。

通过增加 num_workers 来加速数据加载,可以得到相同的裁剪结果:

batch 大小为 8, num_workers 为 2,random crop augmentation(随机裁剪扩充)

这个 bug 很容易产生。在某些情况下,它对最终性能的影响很小。在另一些情况下,相同的扩充会导致严重的退化。

基于对开放源码 PyTorch 项目的分析,发现 bug 的这位用户担心这个问题在许多支持真实产品的代码库中都存在。

究竟是 bug,还是预期功能或特征?

这位用户描述的 bug 也引起了众多网友的热议,其中一些人并不认为这是 bug。

用户「amasterblaster」认为,这不是一个 bug,而是所有种子随机函数的预期功能。这是因为即使在随机实验中,有时你想要对比静态参数的变化,并得到相同的随机数。只有当你被读为真随机(true random)时,才会根据 OS time 设置 seed。

用户「xicor7017」表示自己也遇到了相同的问题,也认为它并不是一个 bug,而是一个可能不为人所知的特征。如果忽略它的话,调试问题时会很麻烦。

与此同时,另一些人表达出了不同的观点,认为既然「如果事情朝着人们不希望的方向发展,那么它就不应该这样,也就构成了 bug。」

用户「IntelArtiGen」称自己意识到了这个 bug,认为它是不正常的,并且对自己的项目造成了一些小问题。用户「gwern」赞同这种观点,认为如果 95% 以上的用户使用时出现错误,则代码就是错的。

用户「synonymous1964」进一步解读了这个 bug。ta 认为,人们可能误解了这个问题,问题不在于设置特定的随机种子会导致每次训练过程中生成相同序列的随机数,这显然是按预期工作的。相反,问题在于多个数据下载进程中(由 PyTorch 中的 num_workers 设置)的每个进程都会在某个特定的训练过程中输出相同序列的随机数。毫无疑问,这当然会对项目造成影响,具体取决于你如何进行数据加载和扩充。所以,即使这个 bug 是「按预期工作的」,但向更多其他用户指出来也挺好的。

不知道机器之心的读者,有没有遇到过类似的 bug 呢?如果有,可以在评论中发表自己对该 bug 的观点。

参考链接:

https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects/

https://www.reddit.com/r/MachineLearning/comments/mocpgj/p_using_pytorch_numpy_a_bug_that_plagues/

end

我是王博Kings,一名985AI博士,华为云专家/CSDN博客专家,单个AI项目在Github上获得了2000标星,为了方便大家交流,附上了联系方式。

这是我的私人微信,还有少量坑位,可与相关学者研究人员交流学习 

目前开设有人工智能、机器学习、计算机视觉、自动驾驶(含SLAM)、Python、求职面经、综合交流群扫描添加CV联盟微信拉你进群,备注:CV联盟

王博Kings 的公众号,欢迎关注,干货多多

王博Kings的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(上)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(下)

博士笔记 | 周志华《机器学习》手推笔记第九章聚类

博士笔记 | 周志华《机器学习》手推笔记第十章降维与度量学习

博士笔记 | 周志华《机器学习》手推笔记第十一章特征选择与稀疏学习

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论(上)

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论(下)

博士笔记 | 周志华《机器学习》手推笔记第十三章半监督学习

博士笔记 | 周志华《机器学习》手推笔记第十四章概率图模型

点个在看支持一下吧

PyTorch + NumPy这么做会降低模型准确率?相关推荐

  1. PyTorch + NumPy这么做会降低模型准确率,这是bug还是预期功能?

    作者|维度 来源|机器之心 近日,有用户在自己的项目中发现了一个微小的 bug,在 PyTorch 同时使用 NumPy 的随机数生成器和多进程数据加载会导致相同的扩充数据,只有专门设置 seed 才 ...

  2. 深度学习提高模型准确率方法

    这里写目录标题 深度学习 数据 使用更多数据 更改图像大小 减少颜色通道 算法 模型改进 增加训练轮次 迁移学习 添加更多层 调整超参数 总结 深度学习 我们已经收集好了一个数据集,建立了一个神经网络 ...

  3. 【深度学习】Pytorch实现CIFAR10图像分类任务测试集准确率达95%

    文章目录 前言 CIFAR10简介 Backbone选择 训练+测试 训练环境及超参设置 完整代码 部分测试结果 完整工程文件 Reference 前言 分享一下本人去年入门深度学习时,在CIFAR1 ...

  4. Hinton等人最新研究:大幅提升模型准确率,标签平滑技术到底怎么用?

    作者 | Rafael Müller , Simon Kornblith, Geoffrey Hinton 译者 | Rachel 责编 | Jane 出品 | AI科技大本营(ID: rgznai1 ...

  5. Github标星10.4k:用 NumPy 实现所有主流机器学习模型

    用 NumPy 手写所有主流 ML 模型,普林斯顿博士后 David Bourgin 最近开源了一个非常剽悍的项目.超过 3 万行代码.30 多个模型,这也许能打造「最强」的机器学习基石?(编辑:机器 ...

  6. 内涝预测过程的噪音_提高人工智能模型准确率的测试过程中需要注意什么?

    黑马程序员视频库 播妞微信号:boniu236 传智播客旗下互联网资讯.学习资源免费分享平台 现在人工智能行业发展迅猛,那么人工智能产品特别是使用分类算法实现的产品中判断其能否上线通常是通过算法自带的 ...

  7. 一步步读懂Pytorch Chatbot Tutorial代码(四) - 为模型准备数据

    文章目录 自述 有用的工具 代码出处 目录 头大 代码及说明 Prepare Data for Models 重点关注 indexesFromSentence zeroPadding binaryMa ...

  8. 用pyinstaller打包pytorch环境下的深度学习模型,实现通过exe程序实现界面显示模型的分类效果

    用pyinstaller打包pytorch环境下的深度学习模型,实现通过exe应用实现界面显示模型的分类效果 训练深度学习模型和界面显示,看我之前的博客,链接在下面: 通过残差网络实现CLFAR-10 ...

  9. 深度学习中学习率和batchsize对模型准确率的影响

    本内容来自其他的人解析,参考链接在最后的注释. 1. 前言 目前深度学习模型多采用批量随机梯度下降算法进行优化,随机梯度下降算法的原理如下: n是批量大小(batchsize),η是学习率(learn ...

最新文章

  1. Java项目:图书管理系统(java+SSM+jsp+mysql+maven)
  2. Flatten Nested Arrays(展平嵌套数组)
  3. word2vec (一) 简介与训练过程概要
  4. 拯救react的hooks:react的问题和hooks的作用
  5. 2020-11-22(工作集与常驻集)
  6. mysql 加号的作用_MySQL学习笔记(一)
  7. 可转债数据一览表集思录_可转债股票数据一览表
  8. codevs1521 华丽的吊灯
  9. 【ExtJS实践】之五 :常用语句及脚本备忘
  10. Go语言的flag库、os库、strconv库
  11. android+影子系统,神器再升级,手机影子系统来啦
  12. 关于最近几次给客户做系统 DEMO的感悟总结
  13. flash人物原地走路_Flash怎么制作一个行走的小人动画?
  14. 毕业生登记表特长填写计算机,大学生毕业登记表中有何特长该怎么填啊。
  15. 密码学基础之对称密钥的分发和存储
  16. QT的下载与安装(QT5.9.1)
  17. 《诗经·陈风·月出》presentation
  18. 计算机 澳洲 博士后 要考雅思么,博士后移民澳大利亚(澳洲做科研博士后)
  19. 在word中一个符号怎么打,这个符号是上边一个白三角,下边一个黑三角,两个三角对称形成一个向右的箭头。
  20. 关于传奇自动触发的几个常用脚本OnKillMob、StdModeFunc、等触发事件

热门文章

  1. sqlyog同步mysql_大坑:用SQLyog连mysql的部分操作不能同步到从库
  2. python模拟登录webspare_全面解读python web 程序的9种部署方式
  3. 关闭mysql方法_启动和关闭MySQL的方法
  4. php 解包二进制,workerman的二进制怎么玩啊,怎么封包,怎么解包啊
  5. Android自定义事件总线,android事件总线EventBus3.0使用方法详解
  6. 中职读计算机什么专业好,读职校选择什么专业好一些
  7. 期待鸿蒙是什么意思,如何看待华为将于 6月2 日举办鸿蒙发布会?你对此有哪些期待?...
  8. ASP.NET MVC Controller Overview摘录
  9. 优盘提示插入多卷集的最后一卷解决办法(5)
  10. Django环境搭建及学前准备