Pytorch-lightning

简介

目前好像大多AI训练学习框架都使用的pytorch-lightning,因此今天也来了解一番,以后也要熟练使用,官方的定义为:构建和训练Pytorch 模型,并使用Lightning Apps模板将它们连接到ML 的生命周期,无需处理DIY基础设施,成本管理,扩展和其他令人头疼的问题。

  • github地址:Lightning-AI/lightning
  • 官方API 文档:Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 1.8.0dev documentation )

How to Use

  1. Install
pip install pytorch-lightning
  1. Add the imports

    import os
    import torch
    from torch import nn
    import torch.nn.functional as F
    from torchvision.datasets import MNIST
    from torch.utils.data import DataLoader,random_split
    from torchvision import transforms
    import pytorch_lightning as pl
    
  2. Define a LightningModule (nn.Module)

    class LitAutoEncoder(pl.LightningModuel):def __init__(self)super().__init__()self.encoder=nn.Sequential(nn.Linear(28*28,128),nn.ReLU(),nn.Linear(128,3))self.decoder=nn.Sequential(nn.Linear(3,128),nn.ReLU(),nn.Linear(128,28*28))def forward(self,x):embedding=self.encoder(x)return embeddingdef training_step(self,batch,batch_idx):x,y=batchx=x.view(x.size(0),-1)z=self.encoder(x)x_hat=self.decoder(z)loss=F.mse_loss(x_hat,x)self.log('train_loss',loss)return lossdef configure_optimizers(self):optimizer=torch.optim.Adam(self.parameters(),lr=1e-3)return optimizer
    
  3. Train

    dataset=MNIST(os.getcwd(),download=True,transform=transforms.ToTensor())
    train,val=random_split(dataset,[55000,5000])autoencoder=LitAutoEncoder()
    trainer=pl.Trainer()
    trainer.fit(autoencoder,DataLoader(train),DataLoader(val))
    

Advanced feature

  • 多GPU

    trainer=Trainer(max_epochs=1,accelerator='gpu',device=8)
    
  • TPU

  • 16 位精度

  • 实验logging

  • early_stopping

    es=EarlyStopping(monitor='val_loss')
    trainer=Trainer(callbacks=[checkpointing])
    
  • model checkpoint

    checkpointing=ModelCheckpoint(monitor='val_loss')
    trainer=Trainer(callbacks=[checkpointing])
    
  • torchscript

    # torchscript
    autoencoder = LitAutoEncoder()
    torch.jit.save(autoencoder.to_torchscript(), "model.pt")
    
  • ONNX

    # onnx
    with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:autoencoder = LitAutoEncoder()input_sample = torch.randn((1, 64))autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)os.path.isfile(tmpfile.name)
    
  • training tricks
    40+的training trick供我们选择

Advantages

  • 模型与硬件无关
  • 代码简化
  • 已于重构
  • 犯更少的mistakes
  • 保存了灵活性,但移除了大量样本
  • 与流行的机器学习工具有集成
  • 不同Python,Pytorch版本,操作系统,GPT进行支持
  • 加快运行速度

手动控制训练过程

class LitAntoEncoder(pl.LightningModule):def __init__(self):super().__init__()self.automatic_optimization=Falsedef training_step(self,batch,batch_idx):# access your optimizers with use_pl_optimizer=False. Default is Trueopt_a, opt_b = self.optimizers(use_pl_optimizer=True)loss_a = ...self.manual_backward(loss_a, opt_a)opt_a.step()opt_a.zero_grad()loss_b = ...self.manual_backward(loss_b, opt_b, retain_graph=True)self.manual_backward(loss_b, opt_b)opt_b.step()opt_b.zero_grad()

Example

Hello world
  • MNIST
Contrastive Learning
  • BYOL
  • CPC v2
  • Moco v2
  • SIMCLR
NLP
  • GPT-2
  • BERT

Reinforcement Learning

  • DQN
  • Dueling-DQN
  • Reinforce
Vision
  • GAN
Classic ML
  • Logistic Regression
  • Linear Regression

官方API教程

Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 1.8.0dev documentation (pytorch-lightning.readthedocs.io)

总结

Pytorch-lightning 作为2w star的github项目一定是很有用的,目前我仅仅尝试了一些example,需要完全掌握pytorch-ligthning中的简单语法,然后确实可以帮助我们减少重复AI代码的编写。

Pytorch-lightning相关推荐

  1. 有bug!用Pytorch Lightning重构代码速度更慢,修复后速度倍增

    选自Medium 作者:Florian Ernst 机器之心编译 编辑:小舟.陈萍 用了 Lightning 训练速度反而更慢,你遇到过这种情况吗? PyTorch Lightning 是一种重构 P ...

  2. 用上Pytorch Lightning的这六招,深度学习pipeline提速10倍!

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 金磊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 面对数以 ...

  3. 分离硬件和代码、稳定 API,PyTorch Lightning 1.0.0 版本正式发布

    机器之心报道 机器之心编辑部 还记得那个看起来像 Keras 的轻量版 PyTorch 框架 Lightning 吗?它终于出了 1.0.0 版本,并增添了很多新功能,在度量.优化.日志记录.数据流. ...

  4. 使用PyTorch Lightning自动训练你的深度神经网络

    点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Erfandi Maula Yusnu, Lalu 编译:ronghuai ...

  5. GitHub高赞!PyTorch Lightning 你值得拥有!

    (给机器学习算法与Python学习加星标,提升AI技能) 本文转自AI新媒体量子位(公众号 ID: QbitAI) 一直以来,PyTorch就以简单又好用的特点,广受AI研究者的喜爱.但是,一旦任务复 ...

  6. 模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解

    文章目录 SWA简介 SWA公式 SWA常见参数 Pytorch Lightning的SWA源码分析 SWALR 参考资料 SWA简介 SWA,全程为"Stochastic Weight A ...

  7. pytorch lightning

    背景 众所周知,pytorch是近年热门的深度学习框架之一,与tensorflow相比,普遍认识是pytorch更适合学界,方便学者快速实践深度模型,各类研究论文中,pytorch的算法实现更多.但是 ...

  8. 0.pytorch lightning 入门

    15分钟了解Pytorch Lightning 翻译自官方文档 前置知识:推荐pytorch 目标:通过PL中7个关键步骤了解PL工作流程 PL是基于pytorch的高层API,自带丰富的工具为AI学 ...

  9. Pytorch Lightning框架:使用笔记【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】

    pytorch是有缺陷的,例如要用半精度训练.BatchNorm参数同步.单机多卡训练,则要安排一下Apex,Apex安装也是很烦啊,我个人经历是各种报错,安装好了程序还是各种报错,而pl则不同,这些 ...

  10. 16、Pytorch Lightning入门

    资源 官方手册 GitHub地址 GItHub案例:Pytorch-Lightning-Template项目 pytorch也是有缺陷的,例如要用半精度训练.BatchNorm参数同步.单机多卡训练, ...

最新文章

  1. 智能手机背面玻璃的缺陷检测,分割网络的应用
  2. 如何重新安装TCP/IP协议
  3. 对AD资料复制的理解
  4. springboot 直接转发调用_springboot-过滤器的页面跳转【重定向与请求转发】-异常报错...
  5. 动态链接MFC引发的血案
  6. 大神把可视化放进数据地图里,原来不敲代码一样能做
  7. 【华为云技术分享】云小课 | 迁移第三方云厂商数据至OBS,两种方式任你选
  8. private public php,[php]private public protected 三者区别
  9. 电大计算机网络模拟题及答案,最新国家开放大学电大本科《计算机网络》期末题库及答案...
  10. opencv 鼠标点击处视频的坐标和rgbw值
  11. 使用RDP报表工具实现多级表头动态列
  12. 爬取百度贴吧发帖信息并保存到scv文件中
  13. 用word模仿手写字体
  14. 结合运动流的时间先验在微创手术视频中的器械分割
  15. rk3288 linux 编译,RK3288系统编译及环境搭建
  16. msfvenom生成木马的简单利用
  17. CSMA/CD 和 CSMA/CA 之原理
  18. SeleniumWebDriver之FindElement和FindElements
  19. 五子棋程序设计(C语言、人机对战、禁手)
  20. Microsoft Sql Server Studio 2019 没有配置管理器解决办法

热门文章

  1. java读取mdb类型的数据
  2. JS_02_函数_运算符_循环
  3. 戴尔服务器开机无显示器,戴尔液晶显示器开机无显示原因是电容问题?
  4. 0基础学RS(十四)VTP(VLAN中继协议)作用及配置
  5. Java毕设项目在线招投标系统(java+VUE+Mybatis+Maven+Mysql)
  6. MATLAB 图片三角风格化(low poly)
  7. plc远程监控.plc远程通讯
  8. 【徐云流浪中国】地图版
  9. Java zip 解压缩
  10. iOS端的UI设计文档