Pytorch-lightning
Pytorch-lightning
简介
目前好像大多AI训练学习框架都使用的pytorch-lightning,因此今天也来了解一番,以后也要熟练使用,官方的定义为:构建和训练Pytorch 模型,并使用Lightning Apps模板将它们连接到ML 的生命周期,无需处理DIY基础设施,成本管理,扩展和其他令人头疼的问题。
- github地址:Lightning-AI/lightning
- 官方API 文档:Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 1.8.0dev documentation )
How to Use
- Install
pip install pytorch-lightning
Add the imports
import os import torch from torch import nn import torch.nn.functional as F from torchvision.datasets import MNIST from torch.utils.data import DataLoader,random_split from torchvision import transforms import pytorch_lightning as pl
Define a LightningModule (nn.Module)
class LitAutoEncoder(pl.LightningModuel):def __init__(self):super().__init__()self.encoder=nn.Sequential(nn.Linear(28*28,128),nn.ReLU(),nn.Linear(128,3))self.decoder=nn.Sequential(nn.Linear(3,128),nn.ReLU(),nn.Linear(128,28*28))def forward(self,x):embedding=self.encoder(x)return embeddingdef training_step(self,batch,batch_idx):x,y=batchx=x.view(x.size(0),-1)z=self.encoder(x)x_hat=self.decoder(z)loss=F.mse_loss(x_hat,x)self.log('train_loss',loss)return lossdef configure_optimizers(self):optimizer=torch.optim.Adam(self.parameters(),lr=1e-3)return optimizer
Train
dataset=MNIST(os.getcwd(),download=True,transform=transforms.ToTensor()) train,val=random_split(dataset,[55000,5000])autoencoder=LitAutoEncoder() trainer=pl.Trainer() trainer.fit(autoencoder,DataLoader(train),DataLoader(val))
Advanced feature
多GPU
trainer=Trainer(max_epochs=1,accelerator='gpu',device=8)
TPU
16 位精度
实验logging
early_stopping
es=EarlyStopping(monitor='val_loss') trainer=Trainer(callbacks=[checkpointing])
model checkpoint
checkpointing=ModelCheckpoint(monitor='val_loss') trainer=Trainer(callbacks=[checkpointing])
torchscript
# torchscript autoencoder = LitAutoEncoder() torch.jit.save(autoencoder.to_torchscript(), "model.pt")
ONNX
# onnx with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmpfile:autoencoder = LitAutoEncoder()input_sample = torch.randn((1, 64))autoencoder.to_onnx(tmpfile.name, input_sample, export_params=True)os.path.isfile(tmpfile.name)
training tricks
40+的training trick供我们选择
Advantages
- 模型与硬件无关
- 代码简化
- 已于重构
- 犯更少的mistakes
- 保存了灵活性,但移除了大量样本
- 与流行的机器学习工具有集成
- 不同Python,Pytorch版本,操作系统,GPT进行支持
- 加快运行速度
手动控制训练过程
class LitAntoEncoder(pl.LightningModule):def __init__(self):super().__init__()self.automatic_optimization=Falsedef training_step(self,batch,batch_idx):# access your optimizers with use_pl_optimizer=False. Default is Trueopt_a, opt_b = self.optimizers(use_pl_optimizer=True)loss_a = ...self.manual_backward(loss_a, opt_a)opt_a.step()opt_a.zero_grad()loss_b = ...self.manual_backward(loss_b, opt_b, retain_graph=True)self.manual_backward(loss_b, opt_b)opt_b.step()opt_b.zero_grad()
Example
Hello world
- MNIST
Contrastive Learning
- BYOL
- CPC v2
- Moco v2
- SIMCLR
NLP
- GPT-2
- BERT
Reinforcement Learning
- DQN
- Dueling-DQN
- Reinforce
Vision
- GAN
Classic ML
- Logistic Regression
- Linear Regression
官方API教程
Welcome to ⚡ PyTorch Lightning — PyTorch Lightning 1.8.0dev documentation (pytorch-lightning.readthedocs.io)
总结
Pytorch-lightning 作为2w star的github项目一定是很有用的,目前我仅仅尝试了一些example,需要完全掌握pytorch-ligthning中的简单语法,然后确实可以帮助我们减少重复AI代码的编写。
Pytorch-lightning相关推荐
- 有bug!用Pytorch Lightning重构代码速度更慢,修复后速度倍增
选自Medium 作者:Florian Ernst 机器之心编译 编辑:小舟.陈萍 用了 Lightning 训练速度反而更慢,你遇到过这种情况吗? PyTorch Lightning 是一种重构 P ...
- 用上Pytorch Lightning的这六招,深度学习pipeline提速10倍!
点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 金磊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 面对数以 ...
- 分离硬件和代码、稳定 API,PyTorch Lightning 1.0.0 版本正式发布
机器之心报道 机器之心编辑部 还记得那个看起来像 Keras 的轻量版 PyTorch 框架 Lightning 吗?它终于出了 1.0.0 版本,并增添了很多新功能,在度量.优化.日志记录.数据流. ...
- 使用PyTorch Lightning自动训练你的深度神经网络
点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Erfandi Maula Yusnu, Lalu 编译:ronghuai ...
- GitHub高赞!PyTorch Lightning 你值得拥有!
(给机器学习算法与Python学习加星标,提升AI技能) 本文转自AI新媒体量子位(公众号 ID: QbitAI) 一直以来,PyTorch就以简单又好用的特点,广受AI研究者的喜爱.但是,一旦任务复 ...
- 模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解
文章目录 SWA简介 SWA公式 SWA常见参数 Pytorch Lightning的SWA源码分析 SWALR 参考资料 SWA简介 SWA,全程为"Stochastic Weight A ...
- pytorch lightning
背景 众所周知,pytorch是近年热门的深度学习框架之一,与tensorflow相比,普遍认识是pytorch更适合学界,方便学者快速实践深度模型,各类研究论文中,pytorch的算法实现更多.但是 ...
- 0.pytorch lightning 入门
15分钟了解Pytorch Lightning 翻译自官方文档 前置知识:推荐pytorch 目标:通过PL中7个关键步骤了解PL工作流程 PL是基于pytorch的高层API,自带丰富的工具为AI学 ...
- Pytorch Lightning框架:使用笔记【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】
pytorch是有缺陷的,例如要用半精度训练.BatchNorm参数同步.单机多卡训练,则要安排一下Apex,Apex安装也是很烦啊,我个人经历是各种报错,安装好了程序还是各种报错,而pl则不同,这些 ...
- 16、Pytorch Lightning入门
资源 官方手册 GitHub地址 GItHub案例:Pytorch-Lightning-Template项目 pytorch也是有缺陷的,例如要用半精度训练.BatchNorm参数同步.单机多卡训练, ...
最新文章
- 智能手机背面玻璃的缺陷检测,分割网络的应用
- 如何重新安装TCP/IP协议
- 对AD资料复制的理解
- springboot 直接转发调用_springboot-过滤器的页面跳转【重定向与请求转发】-异常报错...
- 动态链接MFC引发的血案
- 大神把可视化放进数据地图里,原来不敲代码一样能做
- 【华为云技术分享】云小课 | 迁移第三方云厂商数据至OBS,两种方式任你选
- private public php,[php]private public protected 三者区别
- 电大计算机网络模拟题及答案,最新国家开放大学电大本科《计算机网络》期末题库及答案...
- opencv 鼠标点击处视频的坐标和rgbw值
- 使用RDP报表工具实现多级表头动态列
- 爬取百度贴吧发帖信息并保存到scv文件中
- 用word模仿手写字体
- 结合运动流的时间先验在微创手术视频中的器械分割
- rk3288 linux 编译,RK3288系统编译及环境搭建
- msfvenom生成木马的简单利用
- CSMA/CD 和 CSMA/CA 之原理
- SeleniumWebDriver之FindElement和FindElements
- 五子棋程序设计(C语言、人机对战、禁手)
- Microsoft Sql Server Studio 2019 没有配置管理器解决办法