Pytorch-Lightning–Tuner

lr_find()

参数详解

参数名称 含义 默认值
model LightningModule实例
train_dataloaders 训练数据加载器 None
val_dataloaders 验证数据加载器 None
datamodule LightningDataModule实例 None
min_lr 学习率最小值 1e-08
max_lr 学习率最大值 1
num_training 测试学习率的训练轮数 100
mode 学习率寻找策略,分为指数(默认)和线性(linear) exponential
early_stop_threshold 当任意一点的loss>=early_stop_threshold*best_loss时停止搜索,设置为None禁用该项 4.0
update_attr 将搜索到的学习率更新到模型参数中 False

使用注意

  • 暂时只支持单个优化器
  • 暂不支持DDP

用法

使用self.learing_rateself.lr作为学习率参数

class LitModel(LightningModule):def __init__(self, learning_rate):self.learning_rate = learning_ratedef configure_optimizers(self):return Adam(self.parameters(), lr=(self.lr or self.learning_rate))model = LitModel()# 开启 auto_lr_find标志
trainer = Trainer(auto_lr_find=True)
# 寻找合适的学习率
trainer.tune(model)

使用其他的学习率变量名称

model = LitModel()# 设置为自己的学习率超参数名称 my_value
trainer = Trainer(auto_lr_find="my_value")trainer.tune(model)

使用lr_find()查看自动搜索学习率的结果

model = MyModelClass(hparams)
trainer = Trainer()# 运行学习率搜索
lr_finder = trainer.tuner.lr_find(model)# 查看搜索结果
lr_finder.results# 绘制学习率搜索图,suggest参数指定是否显示建议的学习率点
fig = lr_finder.plot(suggest=True)
fig.show()# 获取最佳学习率或建议的学习率
new_lr = lr_finder.suggestion()# 更新模型的学习率
model.hparams.lr = new_lr# 训练模型
trainer.fit(model)

scale_batch_size()

参数详解

参数名称 含义 默认值
model LightningModule实例
train_dataloaders 训练数据加载器 None
val_dataloaders 验证数据加载器 None
datamodule LightningDataModule实例 None
mode 学习率寻找策略,分为幂次方(默认)和二分(binsearch) power
steps_per_trial 每次测试当前batch_size的训练step数量 3
init_val 初始batch_size大小 2
max_trials 算法结束前batch_size最大增量 25
batch_arg_name 存储batch_size的属性名 'batch_size'
  • Returns:搜索结果

将在如下地方寻找batch_arg_name

  • model
  • model.hparams
  • trainer.datamodule (如果datamodule传递给了tune())

使用注意

  • 暂时不支持DDP模式

  • 由于需要使用模型的batch_arg_name属性,因此不能直接将dataloader直接传递给trainer.fit(),否则此功能将失效,需要在模型中加载数据

  • 原来模型中的batch_arg_name属性将被覆盖

  • train_dataloader()应该依赖于batch_arg_name属性

    def train_dataloader(self):return DataLoader(train_dataset, batch_size=self.batch_size | self.hparams.batch_size)
    

用法

使用Trainer中的auto_scale_batch_size属性

# 默认不执行缩放
trainer = Trainer(auto_scale_batch_size=None)# 设置搜索策略
trainer = Trainer(auto_scale_batch_size=None | "power" | "binsearch")# 寻找最佳batch_szie,并自动设置到模型的batch_size属性中
trainer.tune(model)

使用scale_batch_size()

# 返回搜索结果
new_batch_size = tuner.scale_batch_size(model, *extra_parameters_here)# 覆盖原来的属性(这个过程是自动的)
model.hparams.batch_size = new_batch_size

Pytorch-Lightning--Tuner相关推荐

  1. 有bug!用Pytorch Lightning重构代码速度更慢,修复后速度倍增

    选自Medium 作者:Florian Ernst 机器之心编译 编辑:小舟.陈萍 用了 Lightning 训练速度反而更慢,你遇到过这种情况吗? PyTorch Lightning 是一种重构 P ...

  2. 用上Pytorch Lightning的这六招,深度学习pipeline提速10倍!

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 金磊 发自 凹非寺 量子位 报道 | 公众号 QbitAI 面对数以 ...

  3. 分离硬件和代码、稳定 API,PyTorch Lightning 1.0.0 版本正式发布

    机器之心报道 机器之心编辑部 还记得那个看起来像 Keras 的轻量版 PyTorch 框架 Lightning 吗?它终于出了 1.0.0 版本,并增添了很多新功能,在度量.优化.日志记录.数据流. ...

  4. 使用PyTorch Lightning自动训练你的深度神经网络

    点击上方"AI公园",关注公众号,选择加"星标"或"置顶" 作者:Erfandi Maula Yusnu, Lalu 编译:ronghuai ...

  5. GitHub高赞!PyTorch Lightning 你值得拥有!

    (给机器学习算法与Python学习加星标,提升AI技能) 本文转自AI新媒体量子位(公众号 ID: QbitAI) 一直以来,PyTorch就以简单又好用的特点,广受AI研究者的喜爱.但是,一旦任务复 ...

  6. 模型泛化技巧“随机权重平均(Stochastic Weight Averaging, SWA)”介绍与Pytorch Lightning的SWA实现讲解

    文章目录 SWA简介 SWA公式 SWA常见参数 Pytorch Lightning的SWA源码分析 SWALR 参考资料 SWA简介 SWA,全程为"Stochastic Weight A ...

  7. pytorch lightning

    背景 众所周知,pytorch是近年热门的深度学习框架之一,与tensorflow相比,普遍认识是pytorch更适合学界,方便学者快速实践深度模型,各类研究论文中,pytorch的算法实现更多.但是 ...

  8. 0.pytorch lightning 入门

    15分钟了解Pytorch Lightning 翻译自官方文档 前置知识:推荐pytorch 目标:通过PL中7个关键步骤了解PL工作流程 PL是基于pytorch的高层API,自带丰富的工具为AI学 ...

  9. Pytorch Lightning框架:使用笔记【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】

    pytorch是有缺陷的,例如要用半精度训练.BatchNorm参数同步.单机多卡训练,则要安排一下Apex,Apex安装也是很烦啊,我个人经历是各种报错,安装好了程序还是各种报错,而pl则不同,这些 ...

  10. 16、Pytorch Lightning入门

    资源 官方手册 GitHub地址 GItHub案例:Pytorch-Lightning-Template项目 pytorch也是有缺陷的,例如要用半精度训练.BatchNorm参数同步.单机多卡训练, ...

最新文章

  1. 结合实例与代码谈数字图像处理都研究什么?
  2. 龙芯2h芯片不能进入pmon_国产处理器龙芯地址空间详解
  3. java 可变长度参数/动态参数...
  4. 微信小程序之获取验证码js
  5. mac android studio 打不开adb,Android-Macbook ADB无法打开
  6. Onvif2.6.1命名空间前缀对照
  7. SQL的四种连接-左外连接、右外连接、内连接、全连接(转)
  8. mac远程redis_「实战篇」开源项目docker化运维部署-redis高速缓存(六)
  9. Linux下内存使用率、CPU使用率、以及运行原理-转
  10. 树链剖分之点剖分(点分治)讲解
  11. linux命令mysql启动,linux中mysql启动服务命令
  12. 远距离485无线传输方案
  13. 绘制业务流程图—入门篇
  14. IT新人的辛酸反省与总结
  15. [jzoj100047]【NOIP2017提高A组模拟7.14】基因变异
  16. 巨控GRM110无线通信模块
  17. uc_client 同步登陆
  18. springboot maven打包运行失败问题debug分析报告——XXX--1.0-SNAPSHOT.jar中没有主清单属性
  19. Windows中mysql使用命令行登录
  20. 私有文件服务器,私有云文件服务器

热门文章

  1. bat(batch)入门简介
  2. QGIS与网易有道词典冲突
  3. 预计每天全世界上传的短视频超过4亿条
  4. 固高机器人控制器开发笔记
  5. upnp 播放器 android,基于Android系统的UPNP媒体播放器的研究与实现
  6. SVN 安装使用--中文插件-下载项目
  7. 树模型:决策树、随机森林(RF)、AdaBoost、GBDT、XGBoost、LightGBM和CatBoost算法区别及联系
  8. 拼多多商品详情接口,拼多多详情页接口,宝贝详情页接口,商品属性接口,商品信息查询,商品详细信息接口,h5详情,拼多多APP详情
  9. long + ulong_ULONG_MAX常数,带C ++示例
  10. CDN在前端开发中的作用