PyTorch学习笔记(六):PyTorch进阶训练技巧
PyTorch实战:PyTorch进阶训练技巧
往期学习资料推荐:
1.Pytorch实战笔记_GoAI的博客-CSDN博客
2.Pytorch入门教程_GoAI的博客-CSDN博客
本系列目录:
PyTorch学习笔记(一):PyTorch环境安装
PyTorch学习笔记(二):简介与基础知识
PyTorch学习笔记(三):PyTorch主要组成模块
PyTorch学习笔记(四):PyTorch基础实战
PyTorch学习笔记(五):模型定义、修改、保存
PyTorch学习笔记(六):PyTorch进阶训练技巧
PyTorch学习笔记(七):PyTorch可视化
PyTorch学习笔记(八):PyTorch生态简介
后续继续更新!!!!
PyTorch进阶训练技巧
import torch
import torch.nn as nn
import torch.nn.functional as F
1 自定义损失函数
以函数方式定义:通过输出值和目标值进行计算,返回损失值
以类方式定义:通过继承
nn.Module
,将其当做神经网络的一层来看待
以DiceLoss损失函数为例,定义如下:
DSC = \frac{2|X∩Y|}{|X|+|Y|}DSC=∣X∣+∣Y∣2∣X∩Y∣
class DiceLoss(nn.Module):def __init__(self, weight=None, size_average=True):super(DiceLoss,self).__init__()def forward(self, inputs, targets, smooth=1):inputs = F.sigmoid(inputs) inputs = inputs.view(-1)targets = targets.view(-1)intersection = (inputs * targets).sum() dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth) return 1 - dice
2 动态调整学习率
Scheduler:学习率衰减策略,解决学习率选择的问题,用于提高精度
PyTorch Scheduler策略:
- lr_scheduler.LambdaLR
- lr_scheduler.MultiplicativeLR
- lr_scheduler.StepLR
- lr_scheduler.MultiStepLR
- lr_scheduler.ExponentialLR
- lr_scheduler.CosineAnnealingLR
- lr_scheduler.ReduceLROnPlateau
- lr_scheduler.CyclicLR
- lr_scheduler.OneCycleLR
- lr_scheduler.CosineAnnealingWarmRestarts
使用说明:需要将
scheduler.step()
放在optimizer.step()
后面自定义Scheduler:通过自定义函数对学习率进行修改
3 模型微调
概念:找到一个同类已训练好的模型,调整模型参数,使用数据进行训练。
模型微调的流程
- 在源数据集上预训练一个神经网络模型,即源模型
- 创建一个新的神经网络模型,即目标模型,该模型复制了源模型上除输出层外的所有模型设计和参数
- 给目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化改成的模型参数
- 使用目标数据集训练目标模型
使用已有模型结构:通过传入
pretrained
参数,决定是否使用预训练好的权重训练特定层:使用
requires_grad=False
冻结部分网络层,只计算新初始化的层的梯度def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False import torchvision.models as models # 冻结参数的梯度 feature_extract = True model = models.resnet50(pretrained=True) set_parameter_requires_grad(model, feature_extract) # 修改模型 num_ftrs = model.fc.in_features model.fc = nn.Linear(in_features=512, out_features=4, bias=True)model.fcLinear(in_features=512, out_features=4, bias=True)
注:在训练过程中,model仍会回传梯度,但是参数更新只会发生在
fc
层。
4 半精度训练
半精度优势:减少显存占用,提高GPU同时加载的数据量
设置半精度训练:
- 导入
torch.cuda.amp
的autocast
包 - 在模型定义中的
forward
函数上,设置autocast
装饰器 - 在训练过程中,在数据输入模型之后,添加
with autocast()
- 导入
适用范围:适用于数据的size较大的数据集(比如3D图像、视频等)
5 总结
- 自定义损失函数可以通过二种方式:函数方式和类方式,建议全程使用PyTorch提供的张量计算方法。
- 通过使用PyTorch中的scheduler动态调整学习率,也支持自定义scheduler
- 模型微调主要使用已有的预训练模型,调整其中的参数构建目标模型,在目标数据集上训练模型。
- 半精度训练主要适用于数据的size较大的数据集(比如3D图像、视频等)。
PyTorch学习笔记(六):PyTorch进阶训练技巧相关推荐
- PyTorch学习笔记:PyTorch初体验
PyTorch学习笔记:PyTorch初体验 一.在Anaconda里安装PyTorch 1.进入虚拟环境mlcc 2.安装PyTorch 二.在PyTorch创建张量 1.启动mlcc环境下的Spy ...
- pytorch学习笔记 1. pytorch基础 tensor运算
pytorch与tensorflow是两个近些年来使用最为广泛的机器学习模块.开个新坑记录博主学习pytorch模块的过程,不定期更新学习进程. 文章较为适合初学者,欢迎对代码和理解指点讨论,下面进入 ...
- PyTorch学习笔记(15) ——PyTorch中的contiguous
本文转载自栩风在知乎上的文章<PyTorch中的contiguous>.我觉得很好,特此转载. 0. 前言 本文讲解了pytorch中contiguous的含义.定义.实现,以及conti ...
- PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 call
您的位置 首页 PyTorch 学习笔记系列 PyTorch 学习笔记(六):PyTorch hook 和关于 PyTorch backward 过程的理解 发布: 2017年8月4日 7,195阅读 ...
- PyTorch学习笔记(六)——Sequential类、参数管理与GPU
系列文章\text{\bf 系列文章}系列文章 PyTorch学习笔记(一)--Tensor的基础语法 PyTorch学习笔记(二)--自动微分 PyTorch学习笔记(三)--Dataset和Dat ...
- Pytorch学习笔记总结
往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...
- PyTorch学习笔记(七):PyTorch可视化
PyTorch可视化 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一) ...
- PyTorch学习笔记(五):模型定义、修改、保存
往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...
- PyTorch学习笔记(四):PyTorch基础实战
PyTorch实战:以FashionMNIST时装分类为例: 往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本 ...
最新文章
- 让人造太阳更近!DeepMind强化学习算法控制核聚变登上Nature
- python往mysql存入数据_Python操作mysql之插入数据
- java字符串构造函数的应用_StringTokenizer类的使用
- register_sysctl_table实现内核数据交互
- JAVA中for循环写杨辉三角_java使用for循环输出杨辉三角
- 建立唯一索引后mysql策略_【MySQL】MySQL索引背后的之使用策略及优化【转】
- Johnson法则-流水作业调度-动态规划
- js中获得月份getmonth()+1,为什么要加1?
- 你和财务自由之间,只差洋哥的这些建议!!!
- UnityEditor代码分享导出材质贴图和Mesh本体
- android调试遇到ADB server didn't ACK以及蛋疼的sjk_daemon进程
- mysql 从从(主主)复制(故障转移)
- 求一个数各个位数之和
- proteus中仿真51单片系列之---blink点灯程序
- T检验:两样本数据的差异性
- 黑洞猝灭剂BHQ-2 acid,1214891-99-2,BHQ-2 Carboxylic Acid用作各种荧光共振能量转移,这种探针主要用于分析。
- java 动态代理实现原理
- sublime 中英文等宽字体
- 男人二十岁后应该学会的13个习惯
- 数字经济时代下,企业税务管理数字化转型如何做?