点击上方“AI公园”,关注公众号,选择加“星标“或“置顶”


作者:Erfandi Maula Yusnu, Lalu

编译:ronghuaiyang

导读

对使用PyTorch Lightning的训练代码和原始的PyTorch代码进行了对比,展示了其简单,干净,灵活的优点,相信你会喜欢的。

PyTorch Lightning是为ML研究人员设计的轻型PyTorch封装。它帮助你扩展模型并编写更少的样板文件,同时维护代码干净和灵活同时进行扩展。它帮助研究人员更多地专注于解决问题,而不是编写工程代码。

我从两年前就开始使用PyTorch了,我从0.3.0版本开始使用。在我使用PyTorch之前,我使用Keras作为我的深度学习框架,但后来我开始切换到PyTorch,原因有几个。如果你想知道我的原因,看看下面这篇文章:https://medium.com/swlh/why-i-switch-from-keras-to-pytorch-e48922f5846。

由于我一直在使用PyTorch,所以我需要牺牲在Keras中只用几行简单的行代码就可以进行训练的乐趣,而编写自己的训练代码。它有优点也有缺点,但是我选择PyTorch编写代码的方式来获得对我的训练代码的更多控制。但每当我想在深度学习中尝试一些新的模型时,就意味着我每次都需要编写训练和评估代码。

所以,我决定建立我自己的库,我称之为torchwisdom,但我陷入了困境,因为我仍在为我的公司构建OCR全pipeline系统。所以,我试图找到另一个解决方案,然后我找到了PyTorch Lightning,在我看到代码后,它让我一见钟情。

因此,我将在本文中介绍的内容是安装、基本的代码比较以及通过示例进行比较,这些示例是我自己通过从pytorch lightning site获取的,一些代码自己创建的。最后是本文的结论。

安装

好的,让我们从安装pytorch-lighting开始,这样你就可以跟着我一起做了。你可以使用pip或者conda安装pytorch lightning。

pip install

pip install pytorch-lightning

conda install

conda install pytorch-lightning -c conda-forge

对我来说,我更喜欢用anaconda作为我的python解释器,它对于深度学习和数据科学的人来说更完整。从第一次安装开始,它就自带了许多标准机器学习和数据处理库包。

基本代码的比较

在我们进入代码之前,我想让你看看下面的图片。下面有2张图片解释了pytorch和pytorch lightning在编码、建模和训练上的区别。在左边,你可以看到,pytorch需要更多的代码行来创建模型和训练。

有了pytorch lightning,代码就变成了Lightning模块的内部,所有的训练工程代码都被pytorch lightning解决了。但是你需要在一定程度上定制你的训练步骤,如下面的示例代码所示。

对于训练代码,你只需要3行代码,第一行是用于实例化模型类,第二行是用于实例化Trainer类,第三行是用于训练模型。

这个例子是用pytorch lightning训练的一种方法。当然,你可以对pytorch进行自定义风格的编码,因为pytorch lightning具有不同程度的灵活性。你想看吗?让我们继续。

通过例子进行比较

好了,在完成安装之后,让我们开始编写代码。要做的第一件事是导入需要使用的所有库。在此之后,你需要构建将用于训练的数据集和数据加载器。

# import all you need
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision import datasets, transforms
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule# transforms
# prepare transforms standard to MNIST
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])# data
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train_loader = DataLoader(mnist_train, batch_size=64)

正如上面看到的代码,我们使用来自torchvision的MNIST数据集,并使用torch.utils.DataLoader创建数据加载器。现在,在下面的代码中,我们使网络与28x28像素的MNIST数据集想匹配。第一层有128个隐藏节点,第二层有256个隐藏节点,第三层为输出层,有10个类作为输出。

# build your model
class CustomMNIST(LightningModule):def __init__(self):super().__init__()# mnist images are (1, 28, 28) (channels, width, height)self.layer1 = torch.nn.Linear(28 * 28, 128)self.layer2 = torch.nn.Linear(128, 256)self.layer3 = torch.nn.Linear(256, 10)def forward(self, x):batch_size, channels, width, height = x.size()# (b, 1, 28, 28) -> (b, 1*28*28)x = x.view(batch_size, -1)x = self.layer1(x)x = torch.relu(x)x = self.layer2(x)x = torch.relu(x)x = self.layer3(x)x = torch.log_softmax(x, dim=1)return xdef training_step(self, batch, batch_idx):data, target = batchlogits = self.forward(data)loss = F.nll_loss(logits, target)return {'loss': loss}def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=1e-3)# train your model
model = CustomMNIST()
trainer = Trainer(max_epochs=5, gpus=1)

如果你在上面的gist代码中看到第27和33行,你会看到training_stepconfigure_optimators方法,它覆盖了在第2行中扩展的类LightningModule中的方法。这使得pytorch中标准的nn.Module不同于LightningModule,它有一些方法使它与第39行中的Trainer类兼容。

现在,让我们尝试另一种方法来编写代码。假设你必须编写一个库,或者希望其他人使用纯pytorch编写的库。你该怎样使用pytorch lightning?

下面的代码有两个类,第一个类使用标准的pytorch的nn.Module作为其父类。它是按照标准pytorch模块中通常编写的方式编写的,但是看第30行,有一个名为ExtendMNIST的类继承了两个类。这两个类由StandardMNIST类和LightningModule类组合在一起。这就是我喜欢python的地方,一个类可以有多个父类。

# build your model
class StandardMNIST(nn.Module):def __init__(self):super().__init__()# mnist images are (1, 28, 28) (channels, width, height)self.layer1 = torch.nn.Linear(28 * 28, 128)self.layer2 = torch.nn.Linear(128, 256)self.layer3 = torch.nn.Linear(256, 10)def forward(self, x):batch_size, channels, width, height = x.size()# (b, 1, 28, 28) -> (b, 1*28*28)x = x.view(batch_size, -1)x = self.layer1(x)x = torch.relu(x)x = self.layer2(x)x = torch.relu(x)x = self.layer3(x)x = torch.log_softmax(x, dim=1)return x# extend StandardMNIST and LightningModule at the same time
# this is what I like from python, extend two class at the same time
class ExtendMNIST(StandardMNIST, LightningModule):def __init__(self):super().__init__()  def training_step(self, batch, batch_idx):data, target = batchlogits = self.forward(data)loss = F.nll_loss(logits, target)return {'loss': loss}def configure_optimizers(self):return torch.optim.Adam(self.parameters(), lr=1e-3)# run the training
model = ExtendMNIST()
trainer = Trainer(max_epochs=5, gpus=1)
trainer.fit(model, mnist_train_loader)

如果你看到ExtendMNIST类中的代码,你会看到它只是覆盖了LightningModule类。使用这种编写代码的方法,你可以扩展以前编写的任何其他模型,而无需更改它,并且仍然可以使用pytorch lightning库。

那么,你能在训练时给我看一下结果吗?好,让我们看看它在训练时是什么样子。

这样你就有了它在训练时的屏幕截图。它有一个很好的进度条,显示了网络的损失,这不是让你更容易训练一个模型吗?

如果你想查看实际运行的代码,可以单击下面的链接。第一个是pytorch lightning的标准方式,第二个是自定义方式。

PyTorch Lightning StandardStandard waycolab.research.google.com

PyTorch Lightning CustomCustom Waycolab.research.google.com

总结

PyTorch Lightning已经开发出了一个很好的标准代码,它有229个贡献者,并且它的开发非常活跃。现在,它甚至有风险投资,因为它达到了版本0.7。

在这种情况下(风险投资),我相信pytorch lightning将足够稳定,可以用作你编写pytorch代码的标准库,而不必担心将来开发会停止。

对于我来说,我选择在我的下一个项目中使用pytorch lighting,我喜欢它的灵活性,简单和干净的方式来编写用于深度学习研究的代码。

好了,今天就到这里,祝你愉快。记住要去尝试,不会有什么损失。

—END—

英文原文:https://medium.com/swlh/automate-your-neural-network-training-with-pytorch-lightning-1d7a981322d1

请长按或扫描二维码关注本公众号

喜欢的话,请给我个好看吧

使用PyTorch Lightning自动训练你的深度神经网络相关推荐

  1. pytorch自带网络_使用PyTorch Lightning自动训练你的深度神经网络

    作者:Erfandi Maula Yusnu, Lalu 编译:ronghuaiyang 原文链接 使用PyTorch Lightning自动训练你的深度神经网络​mp.weixin.qq.com 导 ...

  2. 用上Pytorch Lightning的这六招,深度学习pipeline提速10倍!

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 金磊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 面对数以 ...

  3. NVIDIA向交通运输行业开源其自动驾驶汽车深度神经网络

    NVIDIA今日宣布,在NVIDIA GPU Cloud (NGC)容器注册上,向交通运输行业开源NVIDIA DRIVE™自动驾驶汽车开发深度神经网络. NVIDIA DRIVE已成为自动驾驶汽车开 ...

  4. Pytorch:优化器、损失函数与深度神经网络框架

    Pytorch: 优化器.损失函数与深度神经网络框架 Copyright: Jingmin Wei, Pattern Recognition and Intelligent System, Schoo ...

  5. 深度神经网络混合精度训练

    深度神经网络混合精度训练 Mixed-Precision Training of Deep Neural Networks 论文链接:https://arxiv.org/abs/1710.03740 ...

  6. 第五章 深度神经网络为何很难训练

    原文 假设你是一名工程师,接到一项从头开始设计计算机的任务.某天,你在工作室工作,设计逻辑电路,构建 AND 门,OR 门等等时,老板带着坏消息进来:客户刚刚添加了一个奇特的设计需求:整个计算机的线路 ...

  7. dncnn图像去噪_基于GANs训练去噪深度神经网络实现了良好的图像盲去噪效果

    从包含噪音的图像中去除未知噪音是一项充满挑战的工作,由于缺乏有效的训练数据使得这一领域面临许多问题.中山大学的研究人员们提出了一种"两步走"的框架,通过GANs训练输入图像的噪声分 ...

  8. 深度神经网络的深度是?,深度神经网络通俗理解

    深度学习算法是什么? 深度学习算法是学习样本数据的内在规律和表示层次,这些学习过程中获得的信息对诸如文字,图像和声音等数据的解释有很大的帮助. 它的最终目标是让机器能够像人一样具有分析学习能力,能够识 ...

  9. 深度学习阅读导航 | 05 基于光照感知深度神经网络的多光谱数据融合行人检测

    文章目录 摘要 一.引言 二.相关研究 2.1 可见光和热感行人检测 2.2 多光谱行人检测 三.我们的方法 3.1 建议模型概述 3.2 光照全连接神经网络(IFCNN) 3.3 光照感知双流深度卷 ...

最新文章

  1. yii2项目实战-用户管理之登录与注册功能实现
  2. 基于Html5的爱情主题网站–表白神器(第二版)
  3. POJ2891 Strange Way to Express Integers【扩展中国剩余定理】
  4. Java面试2018常考题目汇总
  5. day02.3-元组内置方法
  6. @Autowired注解警告Field injection is not recommended
  7. Docker教程-简介
  8. Java 14 发布了,不使用class也能定义类了?还顺手要干掉Lombok!
  9. java的this()与super()用法
  10. Win10家庭版禁用系统更新方法汇总及问题解决
  11. 软件概要设计如何写(文档恐惧症的程序猿必读)
  12. 明翰经验系列之面试篇V1.1(持续更新)
  13. uni-app 更改头部导航条背景,改成背景图
  14. 重读Ardupilot中stabilize model+MAVLINK解包过程
  15. 手机照片局部放大镜_怎样发照片才能惊艳朋友圈?
  16. UNI-APP,uni.scanCode扫码页面显示英文,uni.showActionSheet自带取消按钮显示英文问题的解决
  17. 电视剧旗舰剧情分集大结局
  18. linux 虚拟钢琴,Virtual MIDI Piano Keyboard下载-虚拟MIDI钢琴键盘 v0.8.0 官方版 - 安下载...
  19. 基于Pycharm运行李沐老师的深度学习课程代码
  20. SPI的4种采样模式

热门文章

  1. 超赞Win10日历悬停效果,爱了爱了(使用HTML、CSS和vanilla JS)
  2. 【SQL server】基础入门0——理论部分
  3. 【风马一族_气味】组成气味的基本成分探索
  4. 2021年第四季度全球消费者信心总体持平,印度仍为全球最高,中国大幅增长,日本仍远低于全球平均水平 | 美通社头条...
  5. 综合布线系统 (布线系统的一种)
  6. 王者无限火力服务器,王者荣耀无限火力
  7. 网易2019实习生招聘编程题集合 - 题解
  8. Unirech腾讯云代充-通过VNC 登录腾讯云国际版Windows云服务器实例教程
  9. ardruino控制继电器_Arduino基础入门篇24—继电器控制
  10. java11规范_京东Java编码规范V11.pdf