点上方计算机视觉联盟获取更多干货

仅作学术分享,不代表本公众号立场,侵权联系删除

转载于:作者丨Caliber@知乎(已授权)

来源丨https://zhuanlan.zhihu.com/p/370185203

编辑丨极市平台

AI博士笔记系列推荐

周志华《机器学习》手推笔记正式开源!可打印版本附pdf下载链接

Pytorch-lightning(以下简称pl)可以非常简洁得构建深度学习代码。但是其实大部分人用不到很多复杂得功能。而pl有时候包装得过于深了,用的时候稍微有一些不灵活。通常来说,在你的模型搭建好之后,大部分的功能都会被封装在一个叫trainer的类里面。一些比较麻烦但是需要的功能通常如下:

  1. 保存checkpoints

  2. 输出log信息

  3. resume training 即重载训练,我们希望可以接着上一次的epoch继续训练

  4. 记录模型训练的过程(通常使用tensorboard)

  5. 设置seed,即保证训练过程可以复制

好在这些功能在pl中都已经实现。

由于doc上的很多解释并不是很清楚,而且网上例子也不是特别多。下面分享一点我自己的使用心得。

首先关于设置全局的种子:

from pytorch_lightning import seed_everything
# Set seedseed = 42seed_everything(seed)

只需要import如上的seed_everything函数即可。它应该和如下的函数是等价的:

def seed_all(seed_value):    random.seed(seed_value) # Python    np.random.seed(seed_value) # cpu vars    torch.manual_seed(seed_value) # cpu vars        if torch.cuda.is_available():         print ('CUDA is available')        torch.cuda.manual_seed(seed_value)        torch.cuda.manual_seed_all(seed_value) # gpu vars        torch.backends.cudnn.deterministic = True  #needed        torch.backends.cudnn.benchmark = False
seed=42seed_all(seed)

但经过我的测试,好像pl的seed_everything函数应该更全一点。

下面通过一个具体的例子来说明一些使用方法:

先下载、导入必要的包和下载数据集:

!pip install pytorch-lightning!wget https://download.pytorch.org/tutorial/hymenoptera_data.zip!unzip -q hymenoptera_data.zip!rm hymenoptera_data.zip
import pytorch_lightning as plimport osimport numpy as np import randomimport matplotlib.pyplot as plt
import torchimport torch.nn.functional as Fimport torchvisionimport torchvision.transforms as transforms

以下代码种加入!的代码是在terminal中运行的。在google colab中运行linux命令需要在之前加!

如果是使用google colab,由于它创建的是一个虚拟机,不能及时保存,所以如果需要保存,挂载自己google云盘也是有必要的。使用如下的代码:

from google.colab import drivedrive.mount('./content/drive')
import osos.chdir("/content/drive/My Drive/")

先如下定义如下的LightningModule和main函数。

class CoolSystem(pl.LightningModule):def __init__(self, hparams):super(CoolSystem, self).__init__()self.params = hparamsself.data_dir = self.params.data_dirself.num_classes = self.params.num_classes ########## define the model ########## arch = torchvision.models.resnet18(pretrained=True)num_ftrs = arch.fc.in_featuresmodules = list(arch.children())[:-1] # ResNet18 has 10 childrenself.backbone = torch.nn.Sequential(*modules) # [bs, 512, 1, 1]self.final = torch.nn.Sequential(torch.nn.Linear(num_ftrs, 128),torch.nn.ReLU(inplace=True),torch.nn.Linear(128, self.num_classes),torch.nn.Softmax(dim=1))def forward(self, x):x = self.backbone(x)x = x.reshape(x.size(0), -1)x = self.final(x)return xdef configure_optimizers(self):# REQUIREDoptimizer = torch.optim.SGD([{'params': self.backbone.parameters()},{'params': self.final.parameters(), 'lr': 1e-2}], lr=1e-3, momentum=0.9)exp_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)return [optimizer], [exp_lr_scheduler]def training_step(self, batch, batch_idx):# REQUIREDx, y = batchy_hat = self.forward(x)loss = F.cross_entropy(y_hat, y)_, preds = torch.max(y_hat, dim=1)acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)self.log('train_loss', loss)self.log('train_acc', acc)return {'loss': loss, 'train_acc': acc}def validation_step(self, batch, batch_idx):# OPTIONALx, y = batchy_hat = self.forward(x)loss = F.cross_entropy(y_hat, y)_, preds = torch.max(y_hat, 1)acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)self.log('val_loss', loss)self.log('val_acc', acc)return {'val_loss': loss, 'val_acc': acc}def test_step(self, batch, batch_idx):# OPTIONALx, y = batchy_hat = self.forward(x)loss = F.cross_entropy(y_hat, y)_, preds = torch.max(y_hat, 1)acc = torch.sum(preds == y.data) / (y.shape[0] * 1.0)return {'test_loss': loss, 'test_acc': acc}def train_dataloader(self):# REQUIREDtransform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])train_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'train'), transform)train_loader = torch.utils.data.DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)return train_loaderdef val_dataloader(self):transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=True, num_workers=4)return val_loaderdef test_dataloader(self):transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])val_set = torchvision.datasets.ImageFolder(os.path.join(self.data_dir, 'val'), transform)val_loader = torch.utils.data.DataLoader(val_set, batch_size=8, shuffle=True, num_workers=4)return val_loaderdef main(hparams):model = CoolSystem(hparams)trainer = pl.Trainer(max_epochs=hparams.epochs,gpus=1,accelerator='dp')  trainer.fit(model)
下面是run的部分:
from argparse import Namespace
args = {    'num_classes': 2,    'epochs': 5,    'data_dir': "/content/hymenoptera_data",}
hyperparams = Namespace(**args)if __name__ == '__main__':    main(hyperparams)

如果希望重载训练的话,可以按如下方式:

# resume training
RESUME = True
if RESUME:    resume_checkpoint_dir = './lightning_logs/version_0/checkpoints/'    checkpoint_path = os.listdir(resume_checkpoint_dir)[0]    resume_checkpoint_path = resume_checkpoint_dir + checkpoint_pathargs = {    'num_classes': 2,    'data_dir': "/content/hymenoptera_data"}hparams = Namespace(**args)model = CoolSystem(hparams)trainer = pl.Trainer(gpus=1,                 max_epochs=10,                             accelerator='dp',                resume_from_checkpoint = resume_checkpoint_path)trainer.fit(model)

如果我们想要从checkpoint加载模型,并进行使用可以按如下操作来:

import matplotlib.pyplot as pltimport numpy as np
# functions to show an imagedef imshow(inp):    inp = inp.numpy().transpose((1, 2, 0))    mean = np.array([0.485, 0.456, 0.406])    std = np.array([0.229, 0.224, 0.225])    inp = std * inp + mean    inp = np.clip(inp, 0, 1)    plt.imshow(inp)    plt.show()
classes = ['ants', 'bees']
checkpoint_dir = 'lightning_logs/version_1/checkpoints/'checkpoint_path = checkpoint_dir + os.listdir(checkpoint_dir)[0]
checkpoint = torch.load(checkpoint_path)model_infer = CoolSystem(hparams)model_infer.load_state_dict(checkpoint['state_dict'])
try_dataloader = model_infer.test_dataloader()
inputs, labels = next(iter(try_dataloader))
# print images and ground truthimshow(torchvision.utils.make_grid(inputs))print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(8)))
# inferenceoutputs = model_infer(inputs)
_, preds = torch.max(outputs, dim=1)# print (preds)print (torch.sum(preds == labels.data) / (labels.shape[0] * 1.0))
print('Predicted: ', ' '.join('%5s' % classes[preds[j]] for j in range(8)))

预测结果如上。

如果希望检测训练过程(第一部分+重载训练的部分),如下:

# tensorboard
%load_ext tensorboard%tensorboard --logdir = ./lightning_logs

训练过程在tensorboard里面记录,version0是第一次的训练,version1是重载后的结果。

完整的code在这里.

https://colab.research.google.com/gist/calibertytz/a9de31175ce15f384dead94c2a9fad4d/pl_tutorials_1.ipynb

-------------------

END

--------------------

我是王博Kings,985AI博士,华为云专家、CSDN博客专家(人工智能领域优质作者)。单个AI开源项目现在已经获得了2100+标星。现在在做AI相关内容,欢迎一起交流学习、生活各方面的问题,一起加油进步!

我们微信交流群涵盖以下方向(但并不局限于以下内容):人工智能,计算机视觉,自然语言处理,目标检测,语义分割,自动驾驶,GAN,强化学习,SLAM,人脸检测,最新算法,最新论文,OpenCV,TensorFlow,PyTorch,开源框架,学习方法...

这是我的私人微信,位置有限,一起进步!

王博的公众号,欢迎关注,干货多多

王博Kings的系列手推笔记(附高清PDF下载):

博士笔记 | 周志华《机器学习》手推笔记第一章思维导图

博士笔记 | 周志华《机器学习》手推笔记第二章“模型评估与选择”

博士笔记 | 周志华《机器学习》手推笔记第三章“线性模型”

博士笔记 | 周志华《机器学习》手推笔记第四章“决策树”

博士笔记 | 周志华《机器学习》手推笔记第五章“神经网络”

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(上)

博士笔记 | 周志华《机器学习》手推笔记第六章支持向量机(下)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(上)

博士笔记 | 周志华《机器学习》手推笔记第七章贝叶斯分类(下)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(上)

博士笔记 | 周志华《机器学习》手推笔记第八章集成学习(下)

博士笔记 | 周志华《机器学习》手推笔记第九章聚类

博士笔记 | 周志华《机器学习》手推笔记第十章降维与度量学习

博士笔记 | 周志华《机器学习》手推笔记第十一章稀疏学习

博士笔记 | 周志华《机器学习》手推笔记第十二章计算学习理论

博士笔记 | 周志华《机器学习》手推笔记第十三章半监督学习

博士笔记 | 周志华《机器学习》手推笔记第十四章概率图模型

点分享

点收藏

点点赞

点在看

收藏 | Pytorch-lightning的使用相关推荐

  1. 有bug!用Pytorch Lightning重构代码速度更慢,修复后速度倍增

    选自Medium 作者:Florian Ernst 机器之心编译 编辑:小舟.陈萍 用了 Lightning 训练速度反而更慢,你遇到过这种情况吗? PyTorch Lightning 是一种重构 P ...

  2. 用上Pytorch Lightning的这六招,深度学习pipeline提速10倍!

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 金磊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 面对数以 ...

  3. 分离硬件和代码、稳定 API,PyTorch Lightning 1.0.0 版本正式发布

    机器之心报道 机器之心编辑部 还记得那个看起来像 Keras 的轻量版 PyTorch 框架 Lightning 吗?它终于出了 1.0.0 版本,并增添了很多新功能,在度量.优化.日志记录.数据流. ...

  4. 收藏 | PyTorch深度学习模型训练加速指南2021

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:LORENZ KUHN 编译:ronghuaiyang ...

  5. 使用PyTorch Lightning自动训练你的深度神经网络

    点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Erfandi Maula Yusnu, Lalu 编译:ronghuai ...

  6. GitHub高赞!PyTorch Lightning 你值得拥有!

    (给机器学习算法与Python学习加星标,提升AI技能) 本文转自AI新媒体量子位(公众号 ID: QbitAI) 一直以来,PyTorch就以简单又好用的特点,广受AI研究者的喜爱.但是,一旦任务复 ...

  7. 模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解

    文章目录 SWA简介 SWA公式 SWA常见参数 Pytorch Lightning的SWA源码分析 SWALR 参考资料 SWA简介 SWA,全程为"Stochastic Weight A ...

  8. pytorch lightning

    背景 众所周知,pytorch是近年热门的深度学习框架之一,与tensorflow相比,普遍认识是pytorch更适合学界,方便学者快速实践深度模型,各类研究论文中,pytorch的算法实现更多.但是 ...

  9. 0.pytorch lightning 入门

    15分钟了解Pytorch Lightning 翻译自官方文档 前置知识:推荐pytorch 目标:通过PL中7个关键步骤了解PL工作流程 PL是基于pytorch的高层API,自带丰富的工具为AI学 ...

  10. Pytorch Lightning框架:使用笔记【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】

    pytorch是有缺陷的,例如要用半精度训练.BatchNorm参数同步.单机多卡训练,则要安排一下Apex,Apex安装也是很烦啊,我个人经历是各种报错,安装好了程序还是各种报错,而pl则不同,这些 ...

最新文章

  1. position 定位
  2. C#苹果应用开发——第一讲初始Xamarin Xamarin ios 教程 Xamarin跨平台开发
  3. 如何安装rpm包?掌握rpm包管理工具就够了
  4. 5G已来,你能做些什么?
  5. CNN的发展历史(LeNet,Alexnet,VGGNet,GoogleNet,ReSNet)
  6. U盘装win7系统出现question(1808)的解决方法
  7. confluence统计用户文章_首次,Flink公众号公开一些后台统计数据
  8. 说说大型高并发高负载网站的系统架构 (转)
  9. python基础系列教程——python中的字符串和正则表达式全解
  10. 论发SCI论文和生孩子的共同点:那我这篇怀的也太久了!
  11. html点击热力图还原,网站页面点击热力图的SEO工具说明
  12. Redux Reducer 的拆分
  13. work with用法
  14. st58服务器装系统,联想 Thinksystem ST58服务器介绍
  15. 输了腾讯赢了阿里:凭借27天超强度复习Java核心知识+面试神技,三面阿里斩获P6岗offer(飞猪事业部)
  16. 2021年广西省安全员C证免费试题及广西省安全员C证考试试卷
  17. 一道积分不等式的最优估计探索
  18. nowcoder 79F 小H和圣诞树 换根 DP + 根号分治
  19. lm393 过零检测 功率因数检测
  20. Windows设置nacos自启动

热门文章

  1. python写脚本入门-学习Python的教程?:python 脚本菜鸟教程
  2. java 7 40,Java 7u40 Java SE 8 sun.reflect.Reflection.getCallerClass
  3. django+bootstrap_Django自学教程PDF高清文档下载
  4. 上银伺服驱动器说明书_威海伺服驱动器维修,诚信互利
  5. 用c语言输出1 n平方自然数魔方阵,用C语言求:打印出由1到n平方的自然数的魔方阵...
  6. docker启动mysql容器后又退出_docker容器刚运行就自动退出了
  7. oracle 类似decode,类似于ORACLE decode 的用法
  8. PHP表单提交后页面跳转,PHP在表单提交后重定向到另一个页面
  9. php socket 不能用,PHP无法用Socket方式连接MySQ
  10. zabbix对网站web监控(配置模板)