加入极市专业CV交流群,与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度 等名校名企视觉开发者互动交流!

同时提供每月大咖直播分享、真实项目需求对接、干货资讯汇总,行业技术交流。关注 极市平台 公众号 ,回复 加群,立刻申请入群~

编译|McGL,https://zhuanlan.zhihu.com/p/147723652来源|https://towardsdatascience.com/efficient-pytorch-part-1-fe40ed5db76c

高效的 PyTorch 训练pipeline是怎样的呢?是产生准确率最高模型?还是跑得最快?或是容易理解和扩展?还是很容易并行计算?嗯,以上都是!

在研究和生产领域, PyTorch都是一个很好用的工具。从斯坦福大学、 Udacity、 SalelsForce和Tesla等都采用这个深度学习框架清楚的表明了这一点。然而,每个工具都需要投入时间来最高效地使用和掌握它。在使用 PyTorch 两年多之后,我决定总结一下我使用这个深度学习库的经验。

高效 ——(系统或机器)达到最大生产力,而浪费的精力或费用却最少。(牛津辞典)

高效 PyTorch 系列的这一部分展示了识别和消除 I/O 和 CPU 、GPU瓶颈的一般技巧。第二部分将揭露一些高效张量运算的技巧。第三部分——高效的模型调试技术。

声明: 本文假设你至少了解了 PyTorch 的基本知识和概念。

从最明显的一个开始:

建议0: 了解代码中的瓶颈在哪里

nvidia-smi, htop, iotop, nvtop, py-spy, strace 等命令行工具应该成为你最好的朋友。你的训练pipeline是CPU-bound? IO-bound 还是GPU-bound? 这些工具将帮助你找到答案。

你可能甚至没有听说过它们,或者听说过但没有使用过。没关系。如果你不马上开始使用它们也没关系。只要记住,别人可能用它们来训练模型,比你快5%-10%-15%...这最终可能会决定是否赢得或失去市场,得到或失去理想工作岗位的offer。

数据预处理

几乎所有的训练pipeline都是从 Dataset 类开始的。它负责提供数据样本。任何必要的数据转换和增强都可能在这里发生。简而言之,Dataset 是一个抽象,它报告数据的大小,并通过给定的索引返回数据样本。

如果使用类图像的数据(2D、3D 扫描) ,磁盘I/O可能会成为瓶颈。要获得原始像素数据,代码需要从磁盘读取数据并将图像解码到内存中。每个任务都很快,但当你需要尽快处理成千上万个任务时,这可能会成为一个挑战。像 NVidia Dali 这样的库提供 GPU加速的 JPEG 解码。如果在数据处理pipeline中遇到 IO 瓶颈,这绝对值得一试。

还有一个选择。SSD 磁盘的存取时间约为0.08-0.16毫秒。RAM 的访问时间为纳秒。我们可以把我们的数据直接放到内存中!

建议1: 如果可能的话,将所有或部分数据移动到 RAM。

如果你有足够的内存来加载和保存你所有的训练数据,这是从pipeline中消除最慢的数据读取步骤的最简单的方法。

这个建议对于云实例特别有用,比如 Amazon 的 p3.8 xlarge。此实例有 EBS 磁盘,其默认设置的性能非常有限。然而,这个实例配备了惊人的248Gb 内存。这足够将所有 ImageNet 数据集保存在内存中了!以下是实现这个目标的方法:

class RAMDataset(Dataset):  def __init__(image_fnames, targets):    self.targets = targets    self.images = []    for fname in tqdm(image_fnames, desc="Loading files in RAM"):      with open(fname, "rb") as f:        self.images.append(f.read())  def __len__(self):    return len(self.targets)  def __getitem__(self, index):    target = self.targets[index]    image, retval = cv2.imdecode(self.images[index], cv2.IMREAD_COLOR)    return image, target

我个人就遇到过这个瓶颈问题。我有一台配了4x1080Ti GPUs 的家用电脑。有一次我用一个 p3.8xlarge 实例,它有四个 NVidia Tesla V100,我把我的训练代码移到了那里。考虑到 V100比我的老款1080Ti 更新更快,我希望看到快15-30% 的训练速度。令我惊讶的是,每个epoch的训练时间都在增加!这是我学到的教训,要注意基础设施和环境的细微差别,而不仅仅是 CPU 和 GPU 的速度。

根据你的场景,你可以在内存中保持每个文件的二进制内容不变,并“动态”解码它,或者保留未压缩的图像的原始像素。无论你选择哪种方式,这里有第二个建议:

建议2: 性能分析。测量。比较。每次你对pipeline进行任何改动时,都要仔细评估它对整体的影响。

这个建议仅仅关注训练速度,假设你不对模型、超参数、数据集等进行更改。你可以拥有一个魔术般的命令行参数(魔术开关) ,如果指定了,它将运行一些合理数量的数据样本的训练。有了这个功能,你可以快速的对你的pipeline进行性能分析:

# Profile CPU bottleneckspython -m cProfile training_script.py --profiling# Profile GPU bottlenecksnvprof --print-gpu-trace python train_mnist.py# Profile system calls bottlenecksstrace -fcT python training_script.py -e trace=open,close,read

建议3: 线下预处理所有数据

如果你正在训练512x512大小的图像,这些图像是由2048 × 2048的图片转换的,那么事先调整它们的大小。如果你使用灰度图像作为模型的输入,请离线进行颜色转换。如果你正在做 NLP ——事先做tokenization并保存到磁盘。没有必要在训练期间一遍又一遍地重复同样的操作。就渐进式学习而言,你可以保存多种分辨率的训练数据——这仍然比在线调整目标分辨率要快。

对于表格数据,请考虑在 Dataset 创建时将 pd.DataFrame 对象转换为 PyTorch 张量。

建议4: 调整 DataLoader 的workers数量

PyTorch 使用 DataLoader 类来简化为训练模型生成batches的过程。为了加快速度,它可以并行执行,使用 python 的multiprocessing。大多数情况下,直接用就很好了。以下是一些需要记住的事情:

每个进程生成一批数据,这些batches通过互斥同步(mutex synchronization)提供给主进程。如果你有 N 个workers,那么你的脚本将需要 N 倍的内存才能在系统内存中存储这些batches。你究竟需要多少RAM?让我们计算一下:

  1. 假设我们用batch size 32 来训练 Cityscapes 图像分割模型,RGB 图像大小为512x512x3 (高, 宽, 通道). 我们在 CPU 端进行图像标准化(稍后我将解释为什么它很重要)。在这种情况下,我们的最终图像张量512 * 512 * 3 * sizeof(float32) = 3,145,728 bytes. 乘以batch size得到的结果是100,663,296 bytes ,即大约100mb
  2. 除了图像,我们需要提供ground-truth masks。它们各自的大小为(默认情况下,掩码的类型为 long,为8字节)ー 512 * 512 * 1 * 8 * 32 = 67,108,864 , 即大约67mb
  3. 因此,一批数据所需的总内存为167 Mb。如果我们有8个workers,所需的总内存量将是167 Mb * 8 = 1,336 Mb.

听起来还不算太糟,对吧?当你的硬件配置能够处理超过8个workers所能提供的batches时,问题就出现了。你可以简单的设置64个workers,但这至少会消耗11gb 的内存。

如果你的数据是3D的,情况会变得更糟; 在这种情况下,即使是一个单通道512x512x512的样本也将占用134 Mb,而batch size为32将占用4.2 Gb,如果有8个workers,则需要32gb 的内存来保存中间数据。

这个问题有一个部分解决方案ーー你可以尽可能地减少输入数据的通道深度:

  1. 保持 RGB 图像在每个通道深度为8位。图像转换为float和标准化可以很容易地在 GPU 上完成
  2. 在数据集中使用 uint8或 uint16数据类型代替 long
class MySegmentationDataset(Dataset):  ...  def __getitem__(self, index):    image = cv2.imread(self.images[index])    target = cv2.imread(self.masks[index])    # No data normalization and type casting here    return torch.from_numpy(image).permute(2,0,1).contiguous(),           torch.from_numpy(target).permute(2,0,1).contiguous()class Normalize(nn.Module):    # https://github.com/BloodAxe/pytorch-toolbelt/blob/develop/pytorch_toolbelt/modules/normalize.py    def __init__(self, mean, std):        super().__init__()        self.register_buffer("mean", torch.tensor(mean).float().reshape(1, len(mean), 1, 1).contiguous())        self.register_buffer("std", torch.tensor(std).float().reshape(1, len(std), 1, 1).reciprocal().contiguous())    def forward(self, input: torch.Tensor) -> torch.Tensor:        return (input.to(self.mean.type) - self.mean) * self.stdclass MySegmentationModel(nn.Module):  def __init__(self):    self.normalize = Normalize([0.221 * 255], [0.242 * 255])    self.loss = nn.CrossEntropyLoss()  def forward(self, image, target):    image = self.normalize(image)    output = self.backbone(image)    if target is not None:      loss = self.loss(output, target.long())      return loss    return output

这样做,你可以大大降低内存需求。对于上面的示例,用于内存高效数据表示的内存使用量为每批33 Mb,而不是167 Mb。那是5倍的缩减!当然,这需要在模型本身中执行额外的步骤,将数据normalize/cast为适当的数据类型。然而,张量越小,CPU 到 GPU 的传输时间越快。

应该理性的选择 DataLoader 的workers数量。你应该检查你的 CPU 和 IO 系统有多快,有多少内存,GPU 处理这些数据有多快。

多GPU训练与推理

神经网络模型变得越来越大。今天的趋势是使用多个 GPU 来提速训练。由于更大的batch size,它还经常可以改善模型的性能。PyTorch 只需要几行代码就可以实现多GPU功能。然而,一些说明乍一看并不明显。

model = nn.DataParallel(model) # Runs model on all available GPUs

使用多 GPU 最简单的方法是将模型包在nn.DataParallel类中。在大多数情况下,它工作得不错,除非你训练一些图像分割模型(或任何其它模型,产生大型张量作为输出)。在前传结束的时候, nn.DataParallel从所有 GPU 收集输出回到主 GPU 上,通过输出后向传播并进行梯度更新。

有两个问题:

  • GPUs负载不平衡
  • 在主 GPU 上收集需要额外的内存

首先,只有主 GPU 在进行损失计算、后向传播和梯度更新,而其它 GPU 则在60C 凉快处等待下一组数据。

其次,主 GPU 收集所有输出所需的额外内存通常会迫使你减少batch size。问题是nn.DataParallel 将一批数据均匀的分给各个GPU。假设你有4个GPUs,总batch size为32。那么每个 GPU 将得到8个样本。但问题是,虽然所有非主 GPU 都可以轻松地将这些batch放入相应的 VRAM 中,但主 GPU 必须分配额外的空间,以保持batch size为32的其它卡的输出。

GPU使用率不均衡有两种解决方案:

  1. 继续使用nn.DataParallel并在训练前传中计算损失。在这种情况下,你不会将密集预测masks返回给主 GPU,而只返回单个标量损失
  2. 使用分布式训练,也就是nn.DistributedDataParallel. 在分布式训练的帮助下,你可以解决上面这两个问题,同时享受观看所有GPU的100% 负载的乐趣

建议5: 如果你有超过2个 GPU ——考虑使用分布式训练模式

节省多少时间很大程度上取决于你的场景,根据我的观察,在4x1080Ti 上训练图像分类pipeline时,时间减少了20% 。

另外值得一提的是,你也可以使用nn.DataParallel and nn.DistributedDataParallel来进行推理。

关于自定义损失函数

编写自定义损失函数是一个有趣和令人兴奋的练习。我建议每个人都不时地尝试一下。实现一个逻辑复杂的损失函数时,有一件事你必须记住: 它运行在 CUDA上,编写高效的 CUDA 代码是你的职责。CUDA 高效意味着“没有 python 控制流”。在 CPU 和 GPU 之间来回切换,访问 GPU 张量的值可以完成任务,但是性能会很差。

不久前,我实现了一个自定义的cosine embedding损失函数,用于实例分割,该函数来自论文“Segmenting and tracking cell instances with cosine embeddings and recurrent hourglass networks”。它的文本形式相当简单,但是实现起来有点复杂。

我编写的第一个简单的实现(除了 bugs)花了几分钟(!) 计算一个batch的损失值。为了分析 CUDA 的瓶颈,PyTorch 提供了一个非常方便的内置性能分析器。使用起来非常简单,并且给出了解决代码瓶颈的所有信息:

def test_loss_profiling():    loss = nn.BCEWithLogitsLoss()    with torch.autograd.profiler.profile(use_cuda=True) as prof:        input = torch.randn((8, 1, 128, 128)).cuda()        input.requires_grad = True        target = torch.randint(1, (8, 1, 128, 128)).cuda().float()        for i in range(10):            l = loss(input, target)            l.backward()    print(prof.key_averages().table(sort_by="self_cpu_time_total"))

建议6: 如果你设计自定义模块和损失函数——进行性能分析和测试

在分析了我的初始实现之后,我可以将实现的速度提高100倍。关于在 PyTorch 中编写高效张量表达式的更多内容将在高效PyTorch 的第二部分中解释。

时间 vs 金钱

最后但并非最不重要的一点是,有时值得投资功能更强大的硬件,而不是优化代码。软件优化总是一个具有不确定结果的高风险过程。升级 CPU,内存,GPU,或所有可能更有效。资金和工程时间都是资源,合理利用两者是成功的关键。

建议7: 有些瓶颈可以通过硬件升级更容易的解决

总结

最大限度的利用你的日常工具是熟练的关键。尽量不要走捷径,如果你不清楚某些事情,就要深入挖掘。总有机会获得新知识的。问问你自己或你的同事——“我的代码如何改进? ” 我相信这种追求完美的意识对于计算机工程师来说和其它技能一样重要。

推荐阅读

  • 给训练踩踩油门:编写高效的PyTorch代码技巧

  • Pytorch数据加载的分析

  • Pytorch有什么节省内存(显存)的小技巧?

添加极市小助手微信(ID : cv-mart),备注:研究方向-姓名-学校/公司-城市(如:目标检测-小极-北大-深圳),即可申请加入极市技术交流群,更有每月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、行业技术交流,一起来让思想之光照的更远吧~

△长按添加极市小助手

△长按关注极市平台,获取最新CV干货

觉得有用麻烦给个在看啦~  

cpu消耗 pytorch_高效 PyTorch :如何消除训练瓶颈相关推荐

  1. pytorch dataset读取数据流程_高效 PyTorch :如何消除训练瓶颈

    加入极市专业CV交流群,与 10000+来自港科大.北大.清华.中科院.CMU.腾讯.百度 等名校名企视觉开发者互动交流! 同时提供每月大咖直播分享.真实项目需求对接.干货资讯汇总,行业技术交流.关注 ...

  2. PyTorch消除训练瓶颈 提速技巧

    [GiantPandaCV导读]训练大型的数据集的速度受很多因素影响,由于数据集比较大,每个优化带来的时间提升就不可小觑.硬件方面,CPU.内存大小.GPU.机械硬盘orSSD存储等都会有一定的影响. ...

  3. 深度学习PyTorch,TensorFlow中GPU利用率较低,CPU利用率很低,且模型训练速度很慢的问题总结与分析

    在深度学习模型训练过程中,在服务器端或者本地pc端,输入nvidia-smi来观察显卡的GPU内存占用率(Memory-Usage),显卡的GPU利用率(GPU-util),然后采用top来查看CPU ...

  4. 直播分享|邓文彬:如何在GPU/CPU/移动端高效训练和推断CNN网络

    | 极市线上分享 第35期 | ➤活动信息 主题:如何在GPU/CPU/移动端高效训练CNN网络 (看TEE AI算力棒在计算机视觉训练和推断的最新突破) 时间:本周四(11月15日)晚20:00~2 ...

  5. pytorch加载训练数据集dataloader操作耗费时间太久,该如何解决?

    笔者在使用pytorch加载训练数据进行模型训练的时候,发现数据加载需要耗费太多时间,该如何缩短数据加载的时间消耗呢?经过查询相关文档,总结实际操作过程如下: 1.尽量将jpg等格式的文件保存为bmp ...

  6. pytorch 多GPU训练总结(DataParallel的使用)

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_40087578/arti ...

  7. pytorch量化感知训练(QAT)示例---ResNet

    pytorch量化感知训练(QAT)示例---ResNet 训练浮点模型,测试浮点模式在CPU和GPU上的时间; BN层融合,测试融合前后精度和结果比对; 加入torch的量化感知API,训练一个QA ...

  8. PyTorch深度学习训练可视化工具tensorboardX

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 之前笔者提到了PyTorch的专属可视化工具visdom,参看Py ...

  9. pytorch 1.7训练保存的模型在1.4低版本无法加载:frame #63: <unknown function> + 0x1db3e0 (0x55ba98ddd3e0 in /data/user

    pytorch 1.7高版本训练保存的模型在1.4低版本无法加载,报错: torch.load('/home/user1/model_best_b.pth.tar') Traceback (most ...

最新文章

  1. 智能车竞赛技术报告 | 节能信标组 - 重庆大学- 赛博坦汽车人联盟
  2. centos cmake安装mysql_Centos安装mysql实例教程
  3. Numpy学习-数组的索引
  4. 关于Android定制Launcher
  5. 在拓扑引擎内检测到故障,错误代码255
  6. 硬盘接口类型简洁区别及SCSI设备和SCSI磁盘的概念区别
  7. 微信小程序上传图片到Java后端
  8. 图像处理-灰度图像转化为伪彩图像
  9. win10计算机系统优化设置,win10加速优化的方法是什么_windows10优化设置的方法
  10. 数字转人民币大写(SQL SERVER)
  11. JS,两种在页面加载完成后自动执行的方法(ready,onload)
  12. Docker报错:E: Unable to locate package python3
  13. uva 11134 Fabled Rooks
  14. Set集合及源码分析
  15. 关于iPhone的一个广告加载问题
  16. Visual Studio SVN创建分支 合并分支 切换分支 vs 插件 visualsvn
  17. 自己动手写SGD算法
  18. 木门工厂老板诉说木门行业痛点!厂家必看
  19. 物联网的特点对行业的作用
  20. 大作业-点灯机器人-记录心得(1)-----基础铺垫-文件基础知识

热门文章

  1. 制作 macOS High Sierra U盘USB启动安装盘方法教程 (全新安装 Mac 系统)
  2. 使用webpack、babel、react、antdesign配置单页面应用开发环境
  3. Qt GUI@学习日志
  4. Hive一些参数设置
  5. 12-22 挑战留给自己,积极面对
  6. SQL Server XML数据解析(1)
  7. 电子邮件地址抓取工具
  8. Flutter异步编程async与await的基本使用
  9. 精通Android自定义View(八)绘制篇Canvas分析之绘制文本
  10. Mr.J-- HTTP学习笔记(七)-- 缓存