作者:Eugene Khvedchenya   编译:ronghuaiyang

导读

只报告模型的Top-1准确率往往是不够的。

将train.py脚本转换为具有一些附加特性的强大pipeline

每一个深度学习项目的最终目标都是为产品带来价值。当然,我们想要最好的模型。什么是“最好的” —— 取决于特定的用例,我将把这个讨论放到这篇文章之外。我想谈谈如何从你的train.py脚本中得到最好的模型。

在这篇文章中,我们将介绍以下技巧:

  1. 用高级框架代替自己写的循环

  2. 使用另外的度量标准监控训练的进展

  3. 使用TensorBoard

  4. 使模型的预测可视化

  5. 使用Dict作为数据集和模型的返回值

  6. 检测异常和解决数值的不稳定性

免责声明:在下一节中,我将引用一些源代码。大多数都是为[Catalyst](https://github.com/catalysts -team/catalyst)框架(20.08版)定制的,可以在pytorch-toolbelt中使用。

不要重复造轮子

建议1 — 利用PyTorch生态系统的高级训练框架

PyTorch在从头开始编写训练循环时提供了极佳的灵活性和自由度。理论上,这为编写任何训练逻辑提供了无限可能。在实践中,你很少会为训练CycleGAN、distilling BERT或3D物体检测从头开始实现编写训练循环。

从头编写一个完整的训练循环是学习PyTorch基本原理的一个很好的方法。不过,我强烈建议你在掌握了一些知识之后,转向高级框架。有很多选择:Catalyst, PyTorch-Lightning, Fast.AI, Ignite,以及其他。高级框架通过以下方式节省你的时间:

  • 提供经过良好测试的训练循环

  • 支持配置文件

  • 支持多gpu和分布式训练

  • 管理检查点/实验

  • 自动记录训练进度

从这些高级库中获得最大效果需要一些时间。然而,这种一次性的投资从长期来看是有回报的。

优点

  • 训练pipeline变得更小 —— 代码越少 —— 出错的机会就越少。

  • 易于进行实验管理。

  • 简化分布式和混合精度训练。

缺点

  • 通常,当使用一个高级框架时,我们必须在框架特定的设计原则和范例中编写代码。

  • 时间投资,学习额外的框架需要时间。

给我看指标

建议2 —— 在训练期间查看其他指标

几乎每一个用于在MNIST或CIFAR甚至ImageNet中对图像进行分类的快速启动示例项目都有一个共同点 —— 它们在训练期间和训练之后都报告了一组最精简的度量标准。通常情况下,包括Top-1和Top-5准确度、错误率、训练/验证损失,仅此而已。虽然这些指标是必要的,但它只是冰山一角!

现代图像分类模型有数千万个参数。你想只使用一个标量值来计算它吗?

Top-1准确率最好的CNN分类模型在泛化方面可能不是最好的。根据你的领域和需求,你可能希望保存具有最 false-positive/false-negative的模型,或者具有最高平均精度的模型。

让我给你一些建议,在训练过程中你可以记录哪些数据:

  • Grad-CAM heat-map —— 看看图像的哪个部分对某一特定类的贡献最大。

可视化Grad-CAM heat-maps有助于识别模型是否基于真实病理或图像伪影做出预测

  • Confusion Matrix — 显示了对你的模型来说哪两个类最具挑战性。

混淆矩阵揭示了一个模型对特定类型进行不正确分类的频率

  • Distribution of predictions — 让你了解最优决策边界。

该模型的negative和positive 预测的分布表明,有很大一部分数据模型无法确定地分类

  • Minimum/Average/Maximum 跨所有层的梯度值,允许识别是否在模型中存在消失/爆炸的梯度或初始化不好的层。

使用面板工具来监控训练

建议3 — 使用TensorBoard或任何其他解决方案来监控训练进度

在训练模型时,你可能最不愿意做的事情就是查看控制台输出。通过一个功能强大的仪表板,你可以在其中一次看到所有的度量标准,这是检查训练结果的更有效的方法。

Tensorboard可以快速的检查和比较你运行的训练

对于少量实验和非分布式环境,TensorBoard是一个黄金标准。自版本1.3以来,PyTorch就完全支持它,并提供了一组丰富的特性来管理试用版。还有一些更先进的基于云的解决方案,比如Weights&Biases、[Alchemy](https://github.com/catalyst team/alchemy)和TensorBoard.dev,这些解决方案使得在多台机器上监控和比较训练变得更容易。

当使用Tensorboard时,我通常记录这样一组指标:

  • 学习率和其他可能改变的优化参数(动量,重量衰减,等等)

  • 用于数据预处理和模型内部的时间

  • 贯穿训练和验证的损失(每个batch和每个epoch的平均值)

  • 跨训练和验证的度量

  • 训练session的超参数最终值

  • 混淆矩阵,Precision-Recall曲线,AUC(如果适用)

  • 模型预测的可视化(如适用)

一图胜千言

直观地观察模型的预测是非常重要的。有时训练数据是有噪声的;有时,模型会过拟合图像的伪影。通过可视化最好的和最差的batch(基于损失或你感兴趣的度量),你可以对模型执行良好和糟糕的情况进行有价值的洞察。

建议5 — 可视化每个epoch中最好和最坏的batch。它可能会给你宝贵的见解。

Catalyst用户提示:这里是使用可视化回调的示例:https://github.com/BloodAxe/Catalyst-Inria-Segmentation-Example/blob/master/fit_predict.py#L258

例如,在全球小麦检测挑战中,我们需要在图像上检测小麦头。通过可视化最佳batch的图片(基于mAP度量),我们看到模型在寻找小物体方面做得近乎完美。

最佳模型预测的可视化显示了模型在小物体上的良好表现

相反,当我们查看最差一批的第一个样本时,我们看到模型很难对大物体做出准确的预测。可视化分析为任何数据科学家都提供了宝贵的见解。

最差模型预测的可视化揭示了模型在大物体上的性能很差

查看最差的batch也有助于发现数据标记中的错误。通常情况下,贴错标签的样本损失更大,因此会成为最差的batch。通过在每个epoch对最糟糕的batch做一个视觉检查,你可以消除这些错误:

标记错误的例子。绿色像素表示true positives,红色像素表示false negative。在这个示例中,ground-truth掩模标在了它实际上不存在的位置上。

使用Dict作为Dataset和Model的返回值

建议4 — 如果你的模型返回一个以上的值,使用Dict来返回结果,不要使用tuple

在复杂的模型中,返回多个输出并不少见。例如,目标检测模型通常返回边界框及其标签,在图像分割CNN-s中,我们经常返回中间层的mask进行深度监督,多任务学习最近也很常用。

在许多开源实现中,我经常看到这样的东西:

# Bad practice, don't return tuple
class RetinaNet(nn.Module):...def forward(self, image):x = self.encoder(image)x = self.decoder(x)bboxes, scores = self.head(x)return bboxes, scores...

对于作者来说,我认为这是一种非常糟糕的从模型返回结果的方法。下面是我推荐的替代方法:

class RetinaNet(nn.Module):RETINA_NET_OUTPUT_BBOXES = "bboxes"RETINA_NET_OUTPUT_SCORES = "scores"...def forward(self, image):x = self.encoder(image)x = self.decoder(x)bboxes, scores = self.head(x)return { RETINA_NET_OUTPUT_BBOXES: bboxes, RETINA_NET_OUTPUT_SCORES: scores }...

这个建议在某种程度上与“The Zen of Python”的设定产生了共鸣 —— “明确的比含蓄的更好”。遵循这一规则将使你的代码更清晰、更容易维护。

那么为什么我认为第二种选择更好呢?有几个原因:

  • 返回值有一个显式的名称与它关联。你不需要记住元组中元素的确切顺序。

  • 如果你需要访问返回的字典的一个特定元素,你可以通过它的名字来访问。

  • 从模型中添加新的输出不会破坏代码。

使用Dict,你甚至可以更改模型的行为,以按需返回额外的输出。例如,这里有一个简短的片段,演示了如何返回多个“主”输出和两个“辅助”输出来进行度量学习:

# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/models/timm.py#L104def forward(self, **kwargs):x = kwargs[self.input_key]x = self.rgb_bn(x)x = self.encoder.forward_features(x)embedding = self.pool(x)result = {OUTPUT_PRED_MODIFICATION_FLAG: self.flag_classifier(self.drop(embedding)),OUTPUT_PRED_MODIFICATION_TYPE: self.type_classifier(self.drop(embedding)),}if self.need_embedding:result[OUTPUT_PRED_EMBEDDING] = embeddingif self.arc_margin is not None:result[OUTPUT_PRED_EMBEDDING_ARC_MARGIN] = self.arc_margin(embedding)return result

同样的建议也适用于Dataset类。对于Cifar-10玩具示例,可以将图像及其对应的标签作为元组返回。但当处理多任务或多输入模型,你想从数据集返回Dict类型的样本:

# https://github.com/BloodAxe/Kaggle-2020-Alaska2/blob/master/alaska2/dataset.py#L373
class TrainingValidationDataset(Dataset):def __init__(self,images: Union[List, np.ndarray],targets: Optional[Union[List, np.ndarray]],quality: Union[List, np.ndarray],bits: Optional[Union[List, np.ndarray]],transform: Union[A.Compose, A.BasicTransform],features: List[str],):""":param obliterate - Augmentation that destroys embedding."""if targets is not None:if len(images) != len(targets):raise ValueError(f"Size of images and targets does not match: {len(images)} {len(targets)}")self.images = imagesself.targets = targetsself.transform = transformself.features = featuresself.quality = qualityself.bits = bitsdef __len__(self):return len(self.images)def __repr__(self):return f"TrainingValidationDataset(len={len(self)}, targets_hist={np.bincount(self.targets)}, qf={np.bincount(self.quality)}, features={self.features})"def __getitem__(self, index):image_fname = self.images[index]try:image = cv2.imread(image_fname)if image is None:raise FileNotFoundError(image_fname)except Exception as e:print("Cannot read image ", image_fname, "at index", index)print(e)qf = self.quality[index]data = {}data["image"] = imagedata.update(compute_features(image, image_fname, self.features))data = self.transform(**data)sample = {INPUT_IMAGE_ID_KEY: os.path.basename(self.images[index]), INPUT_IMAGE_QF_KEY: int(qf)}if self.bits is not None:# OKsample[INPUT_TRUE_PAYLOAD_BITS] = torch.tensor(self.bits[index], dtype=torch.float32)if self.targets is not None:target = int(self.targets[index])sample[INPUT_TRUE_MODIFICATION_TYPE] = targetsample[INPUT_TRUE_MODIFICATION_FLAG] = torch.tensor([target > 0]).float()for key, value in data.items():if key in self.features:sample[key] = tensor_from_rgb_image(value)return sample

当你的代码中有Dictionaries时,你可以在任何地方使用名称常量引用输入/输出。遵循这条规则将使你的训练管道非常清晰和容易遵循:

# https://github.com/BloodAxe/Kaggle-2020-Alaska2callbacks += [CriterionCallback(input_key=INPUT_TRUE_MODIFICATION_FLAG,output_key=OUTPUT_PRED_MODIFICATION_FLAG,criterion_key="bce"),CriterionCallback(input_key=INPUT_TRUE_MODIFICATION_TYPE,output_key=OUTPUT_PRED_MODIFICATION_TYPE,criterion_key="ce"),CompetitionMetricCallback(input_key=INPUT_TRUE_MODIFICATION_FLAG,output_key=OUTPUT_PRED_MODIFICATION_FLAG,prefix="auc",output_activation=binary_logits_to_probas,class_names=class_names,),OutputDistributionCallback(input_key=INPUT_TRUE_MODIFICATION_FLAG,output_key=OUTPUT_PRED_MODIFICATION_FLAG,output_activation=binary_logits_to_probas,prefix="distribution/binary",),BestMetricCheckpointCallback(target_metric="auc", target_metric_minimize=False, save_n_best=3),
]

在训练中检测异常

就像人类可以阅读含有许多错误的文本一样,深度学习模型也可以在训练过程中出现错误时学习“一些合理的东西”。作为一名开发人员,你要负责搜索异常并对其表现进行推理。

建议5 — 在训练期间使用 torch.autograd.detect_anomaly()查找算术异常

如果你在训练过程中在损失/度量中看到NaNs或Inf,你的脑海中就会响起一个警报。它是你的管道中有问题的指示器。通常情况下,它可能由以下原因引起:

  • 模型或特定层的初始化不好(你可以通过观察梯度大小来检查哪些层)

  • 数学上不正确的运算(负数的torch.sqrt(),非正数的torch.log(),等等)

  • 不当使用torch.mean()torch.sum() 的reduction(zero-sized张量上的均值会得到nan,大张量上的sum容易导致溢出)

  • 在loss中使用x.sigmoid()(如果你需要在loss函数中使用概率,更好的方法是x.sigmoid().clamp(eps,1-eps )以防止梯度消失)

  • 在Adam-like的优化器中的低epsilon值

  • 在使用fp16的训练的时候没有使用动态损失缩放

为了找到你代码中第一次出现Nan/Inf的确切位置,PyTorch提供了一个简单易用的方法torch. autograde .detect_anomaly()

import torchdef main():torch.autograd.detect_anomaly()...# Rest of the training code# OR
class MyNumericallyUnstableLoss(nn.Module):def forward(self, input, target):with torch.autograd.set_detect_anomaly(True):loss = input * targetreturn loss

将其用于调试目的,否则就禁用它,异常检测会带来计算开销,并将训练速度降低10-15% 。

—END—

英文原文:https://towardsdatascience.com/efficient-pytorch-supercharging-training-pipeline-19a26265adae

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习及深度学习笔记等资料打印机器学习在线手册深度学习笔记专辑《统计学习方法》的代码复现专辑
AI基础下载机器学习的数学基础专辑获取一折本站知识星球优惠券,复制链接直接打开:https://t.zsxq.com/662nyZF本站qq群1003271085。加入微信群请扫码进群(如果是博士或者准备读博士请说明):

【深度学习】高效使用Pytorch的6个技巧:为你的训练Pipeline提供强大动力相关推荐

  1. pytorch 模型可视化_【深度学习】高效使用Pytorch的6个技巧:为你的训练Pipeline提供强大动力...

    作者:Eugene Khvedchenya   编译:ronghuaiyang 导读 只报告模型的Top-1准确率往往是不够的. 将train.py脚本转换为具有一些附加特性的强大pipeline 每 ...

  2. pytorch 模型可视化_高效使用Pytorch的6个技巧:为你的训练Pipeline提供强大动力

    点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Eugene Khvedchenya 编译:ronghuaiyang 导读 ...

  3. 深度学习入门之PyTorch学习笔记:深度学习框架

    深度学习入门之PyTorch学习笔记 绪论 1 深度学习介绍 2 深度学习框架 2.1 深度学习框架介绍 2.1.1 TensorFlow 2.1.2 Caffe 2.1.3 Theano 2.1.4 ...

  4. Python深度学习:基于PyTorch [Deep Learning with Python and PyTorch]

    作者:吴茂贵,郁明敏,杨本法,李涛,张粤磊 著 出版社:机械工业出版社 品牌:机工出版 出版时间:2019-11-01 Python深度学习:基于PyTorch [Deep Learning with ...

  5. 干货|《深度学习入门之Pytorch》资料下载

    深度学习如今已经成为了科技领域中炙手可热的技术,而很多机器学习框架也成为了研究者和业界开发者的新宠,从早期的学术框架Caffe.Theano到如今的Pytorch.TensorFlow,但是当时间线来 ...

  6. 【深度学习】基于Pytorch进行深度神经网络计算(一)

    [深度学习]基于Pytorch进行深度神经网络计算(一) 文章目录 1 层和块 2 自定义块 3 顺序块 4 在正向传播函数中执行代码 5 嵌套块 6 参数管理(不重要) 7 参数初始化(重要) 8 ...

  7. 【深度学习】基于Pytorch进行深度神经网络计算(二)

    [深度学习]基于Pytorch进行深度神经网络计算(二) 文章目录 1 延后初始化 2 Pytorch自定义层2.1 不带参数的层2.2 带参数的层 3 基于Pytorch存取文件 4 torch.n ...

  8. 【深度学习】基于Pytorch的卷积神经网络概念解析和API妙用(一)

    [深度学习]基于Pytorch的卷积神经网络API妙用(一) 文章目录 1 不变性 2 卷积的数学分析 3 通道 4 互相关运算 5 图像中目标的边缘检测 6 基于Pytorch的卷积核 7 特征映射 ...

  9. 【深度学习】基于Pytorch的卷积神经网络概念解析和API妙用(二)

    [深度学习]基于Pytorch的卷积神经网络API妙用(二) 文章目录1 Padding和Stride 2 多输入多输出Channel 3 1*1 Conv(笔者在看教程时,理解为降维和升维) 4 池 ...

最新文章

  1. 如何用JS获取页面上的所有标签
  2. 报告!钉钉宜搭的8月总结,请查收~
  3. sql2008 获取表结构说明
  4. 深度可分离卷积Depthwise Separable Convolution
  5. C# 字符串string的基本操作
  6. Linux安装caffe问题汇总
  7. SSH(六)hibernate持久层模板于事务管理
  8. hbuilderX连接雷电模拟器
  9. OISPT 内网安全项目组A1-渗透测试基础项目训练文档
  10. 一句话,读懂首席架构师、CTO和技术总监的区别
  11. html中小星星打分,折腾:2颗星星+纯CSS实现星星评分交互效果
  12. APK瘦身优化检测工具-Matrix ApkChecker 使用
  13. VSTO中Word的查找方式
  14. APL开发日志--2012-11-28
  15. VS2019中C#开发的bottom按钮在哪里?
  16. C语言人机大战之决战三子棋之巅
  17. java计算机毕业设计智慧农业水果销售系统MyBatis+系统+LW文档+源码+调试部署
  18. Microsoft visual c++2017 X64 Minimum Runtime等vc++运行库问题的解决记录
  19. STM32启动文件的分析
  20. U盘FAT32格式如何转换成NTFS格式

热门文章

  1. 多对多关联查询sql语句
  2. 浅析Facebook LibraBFT与比原链Bystack BBFT共识
  3. [luogu3231 HNOI2013] 消毒 (二分图最小点覆盖)
  4. 漫扯:从polling到Websocket(ZZ)
  5. Linux下批量添加用户的两种方法
  6. TextBox中的KeyDown 时间不能响应的问题!
  7. Button的点击事件
  8. EndNote批量实现文献标题首字母大写 附最新版endnote下载
  9. 导入jar包到Maven本地仓库(maven install jar)
  10. Matlab | Matlab从入门到放弃(5)——矩阵与format