简介

在之前专栏的两篇文章中我主要介绍了数据的准备以及模型的构建,模型构建完成的下一步就是模型的训练优化,训练完成的模型用于实际应用中。

损失函数

损失函数用于衡量预测值与目标值之间的误差,通过最小化损失函数达到模型的优化目标。不同的损失函数其衡量效果不同,未必都是出于最好的精度而设计的。PyTorch对于很多常用的损失函数进行了封装,均在torch.nn模块下,它们的使用方法类似,实例化损失计算对象,然后用实例化的对象对预测值和目标值进行损失计算即可。

  • L1损失

    • nn.L1Loss(reduction)
    • 计算L1损失即绝对值误差。
    • reduce参数表示是否返回标量,默认返回标量,否则返回同维张量。
    • size_average参数表示是否返回的标量为均值,默认为均值,否则为求和结果。
    • reduction参数取代了上述两个参数,meansumNone的取值对应上面的结果。
    • 下面代码可以演示损失的计算流程。
      import torch
      from torch import nn
      pred = torch.ones(100, 1) * 0.5
      label = torch.ones(100, 1)l1_mean = nn.L1Loss()
      l1_sum = nn.L1Loss(reduction='sum')print(l1_mean(pred, label))
      print(l1_sum(pred, label))
      
  • MSE损失
    • nn.MSELoss(reduction='mean')
    • 计算均方误差,常用于回归问题。
    • 参数同上。
  • CE损失
    • nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')
    • 计算交叉熵损失,常用于分类问题。并非标准的交叉熵,而是结合了Softmax的结果,也就是说会将结果先进行softmax计算为概率分布。
    • weight参数是每个类别的权重,用于解决样本不均衡问题。
    • reduction参数类似上面的损失函数。
    • ignore_index参数表示忽略某个类别,不计算其损失。
  • KL散度
    • nn.KLDivLoss(reduction='mean')
    • 计算KL散度。
    • 参数同上。
  • 二分交叉熵
    • nn.BCELoss(reduction='mean')
    • 计算二分交叉熵损失,一般用于二分类问题。
  • 逻辑二分交叉熵
    • nn.BCEWithLogitsLoss()
    • 输入先经过sigmoid变换再计算损失,类似CE损失。

上述只是提到了几个常用的简单损失函数,更加复杂的可以查看官方文档,一共封装了近20个损失,当然,也可以自定义损失函数,返回一个张量或者标量即可(事实上这些损失函数就是这么干的)。

优化器

数据、模型、损失函数都确定了,那这个深度模型任务其实已经完成了大半,接下来就是选择合适的优化器对模型进行优化训练。

首先,要了解PyTorch中优化器的机制,其所有优化器都是继承自Optimizer类,该类封装了一套基础的方法如state_dict()load_state_dict()等。

参数组(param_groups)

任何优化器都有一个属性为param_groups,这是因为优化器对参数的管理是基于组进行的,为每一组参数配置特定的学习率、动量比例、衰减率等等,该属性为一个列表,里面多个字典,对应不同的参数及其配置。

例如下面的代码中只有一个组。

import torch
import torch.optim as optimw1 = torch.randn(2, 2)
w2 = torch.randn(2, 2)optimizer = optim.SGD([w1, w2], lr=0.1)
print(optimizer.param_groups)

梯度清零

事实上,PyTorch不会在一次优化完成后清零之前计算得到的梯度,所以需要每次优化完成后手动清零,即调用优化器的zero_grad()方法。

参数组添加

通过调用优化器的add_param_group()方法可以添加一组定制的参数。

常用优化器

PyTorch将这些优化算法均封装于torch.optim模块下,其实现时对原论文有所改动,具体参见源码。

  • 随机梯度下降

    • optim.SGD(params, lr, momentum, weight_decay)
    • 随机梯度下降优化器。
    • params参数表示需要管理的参数组。
    • lr参数表示初始学习率,可以按需调整学习率。
    • momentum参数表示动量SGD中的动量值,一般为0.9。
    • weight_decay参数表示权重衰减系数,也是L2正则系数。
  • 随机平均梯度下降
    • optim.Adam(params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False)
    • Adam优化算法的实现。
    • 参数类似上面。

下图演示了各种算法相同情境下的收敛效果。

学习率调整策略

合适的学习率可以使得模型迅速收敛,这也是Adam等算法的初衷,一般我们训练时会在开始给一个较大的学习率,随着训练的进行逐渐下调这个学习率。那么何时下调、下调多少,相关的问题就是学习率调整策略,PyTorch提供了6中策略以供使用,它们都在torch.optim.lr_scheduler中,分为有序调整(较为死板)、自适应调整(较为灵活)和自定义调整(适合各种情况)。

下面介绍最常用的自动学习率调整机制。它封装为optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=False, threshold=0.0001,threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8)

当指标不再变化时即调整学习率,这是一种非常实用的学习率调整策略。例如,当验证集的损失不再下降即即将陷入过拟合,进行学习率调整。

  • mode参数由两种为minmax,当指标不再变低或者变高时调整。
  • factor参数表示学习率调整比例。
  • patience参数表示等待耐心,当patience个step指标不变即调整学习率。
  • verbose参数表示调整学习率是否可见。
  • cooldown参数表示冷却时间,调整后冷却时间内不再调整。
  • min_lr参数表示学习率下限。
  • eps参数表示学习率衰减最小值,学习率变化小于该值不调整。

训练流程实战

下面的代码演示了数据的导入、模型构建、损失函数使用以及优化器的优化整个流程,大部分时候我们使用PyTorch进行模型训练都是这个思路

import torch
from torch import nn
import torch.nn.functional as F
from torch import optimclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=(3, 3))self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(64*54*54, 256)self.fc2 = nn.Linear(256, 128)self.fc3 = nn.Linear(128, 101)def forward(self, x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = x.view(-1, 64*54*54)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xnet = Net()
x = torch.randn((32, 3, 224, 224))
y = torch.ones(32, ).long()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, dampening=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)epochs = 10
losses = []
for epoch in range(epochs):correct = 0.0total = 0.0optimizer.zero_grad()outputs = net(x)loss = criterion(outputs, y)loss.backward()optimizer.step()scheduler.step()_, predicted = torch.max(outputs.data, 1)total += y.size(0)correct += (predicted == y).squeeze().sum().numpy()losses.append(loss.item())print("loss", loss.item(), "acc", correct / total)import matplotlib.pyplot as plt
plt.plot(list(range(len(losses))), losses)
plt.savefig('his.png')
plt.show()

其训练损失变化图如下,由于只是给出的demo数据,训练很快收敛,准确率一轮达到100%。

补充说明

本文介绍了PyTorch中损失函数的使用以及优化器的优化流程,这也是深度模型训练的最后步骤,比较重要。本文的所有代码均开源于我的Github,欢迎star或者fork。

PyTorch-训练相关推荐

  1. 让PyTorch训练速度更快,你需要掌握这17种方法

    选自efficientdl.com 作者:LORENZ KUHN 机器之心编译 编辑:陈萍 掌握这 17 种方法,用最省力的方式,加速你的 Pytorch 深度学习训练. 近日,Reddit 上一个帖 ...

  2. PyTorch训练加速17种技巧

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 文自 机器之心 作者:LORENZ KUHN 编辑:陈萍 掌握这 ...

  3. Pytorch 训练与测试时爆显存(cuda out of memory)的终极解决方案,使用cpu(勿喷)

    Pytorch 训练与测试时爆显存(cuda out of memory)的终极解决方案,使用cpu(勿喷) 参见了很多方法,都没有用. 简单点,直接把gpu设成-1

  4. 送你9个快速使用Pytorch训练解决神经网络的技巧(附代码)

    来源:读芯术 本文约4800字,建议阅读10分钟. 本文为大家介绍9个使用Pytorch训练解决神经网络的技巧 图片来源:unsplash.com/@dulgier 事实上,你的模型可能还停留在石器时 ...

  5. 若使用numba.cuda.jit加速pytorch训练代码会怎样

    也许没有察觉 在使用pytorch训练数据的时候cuda 显卡总是发挥不到最大性能 这就是你的cpu程序拖住了你的显卡 怎么办 目前我能想到的最好方法就是 使用numba.cuda.jit这样你也不用 ...

  6. pytorch训练过程中loss出现NaN的原因及可采取的方法

    在pytorch训练过程中出现loss=nan的情况 1.学习率太高. 2.loss函数 3.对于回归问题,可能出现了除0 的计算,加一个很小的余项可能可以解决 4.数据本身,是否存在Nan,可以用n ...

  7. 这17 种方法让 PyTorch 训练速度更快!

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:选自 | efficientdl.com   作者 | LO ...

  8. DataLoader worker (pid 2287) is killed by signal: Killed. pytorch训练解决方法

    DataLoader worker (pid 2287) is killed by signal: Killed. pytorch训练解决方法 参考文章: (1)DataLoader worker ( ...

  9. 如何用PyTorch训练图像分类器

    本文为 AI 研习社编译的技术博客,原标题 : How to Train an Image Classifier in PyTorch and use it to Perform Basic Infe ...

  10. pytorch训练GAN的代码(基于MNIST数据集)

    论文:Generative Adversarial Networks 作者:Ian J. Goodfellow 年份:2014年 从2020年3月多开始看网络,这是我第一篇看并且可以跑通代码的论文,简 ...

最新文章

  1. IntelliJ IDEA控制台输出中文乱码问题解决
  2. pytorch 安装 pip+windows10+python3.6+CUDA10.0
  3. 视频会议场景下的弱网优化
  4. 为什么要有周考?周考是用来干什么的?
  5. diff算法_vue源码解读 diff算法
  6. linux屏幕怎么放大_02|初始Linux——Windows与Linux区别
  7. cpu矿工cpuminer-multi编译与使用
  8. python产生随机数并排序_中小学python教学案例:随机数按升序排列 输出
  9. 欠采样和过采样_过采样和欠采样
  10. 硬件和软件的32位与64位区别
  11. Java岗大厂面试百日冲刺【Day52】— 数据库8 (日积月累,每日三题)
  12. 【Linux】深入解析Linux proc文件系统
  13. 数字电路74161(MN)
  14. 高性能迷你React框架 anu1.2.3 发布
  15. 计算机网络——传输层の选择题整理
  16. 强化学习——环境库OpenAI Gym
  17. 【图灵杯 J】简单的变位词
  18. 在了解VR的途中看到文章
  19. 考研复试自我介绍总结
  20. mysql查询男女平均年龄_查询计算机系学生的姓名、性别和年龄

热门文章

  1. Condition总结-CountDownLatch源码分析
  2. 原本挂起的线程继续执行
  3. ServletFileUpload API详解
  4. plsql(轻量版)_流程控制
  5. Shell变量作用域
  6. 【干货】仪器仪表常用术语汇总
  7. 2019-05-21 Java学习日记之String类型Demo
  8. flask sqlalchemy一对多关系详解
  9. 细数非对称加密与对称加密的区别
  10. Notepad++ JSON关键字自动提示