声明:本文大部分内容是从知乎、博客等知识分享站点摘录而来,以方便查阅学习。具体摘录地址已在文章底部引用部分给出。


1. 查看模型每层输出详情

from torchsummary import summary
summary(your_model, input_size=(channels, H, W))

2. 梯度裁减

import torch.nn as nnoutputs = model(inputs)
loss= criterion(outputs, target)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)  # max_norm:梯度的最大范数;norm_type:规定范数的类型,默认为L2
optimizer.step()

3. 扩展图片维度

因为训练时数据维度一般为(batch_size, c, h,, w),而测试时如果只输入一张图片,则需要进行维度扩展。

方法一:(h, w, c) -> (1, h, w, c)

import cv2
import torchimage = cv2.imread(img_path)
image = torch.tensor(image)img = image.view(1, *image.size())

方法二:(h, w, c) -> (1, h, w, c)

import cv2
import numpy as npimage = cv2.imread(img_path)
img = image[np.newaxis, :, :, :]

方法三:

import cv2
import torchimage = cv2.imread(img_path)
image = torch.tensor(image)img = image.unsqueeze(dim=0)  # 扩展维度,dim指定扩展哪个维度;torch.Size([(h, w, c)]) -> torch.Size([(1, h, w, c)])
img = img.squeeze(dim=0) # 去除dim指定的且size为1的维度,维度大于1时,squeeze()不起作用,不指定dim时,去除所有size为1的维度; torch.Size([(1, h, w, c)]) -> torch.Size([(h, w, c)]) 

4. 独热编码

在PyTorch中使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码。

import torch
class_num = 8
batch_size = 4def one_hot(label):"""将一维列表转换为独热编码"""label = label.resize_(batch_size, 1)m_zeros = torch.zeros(batch_size, class_num)# 从 value 中取值,然后根据 dim 和 index 给相应位置赋值onehot = m_zeros.scatter_(1, label, 1)  # (dim,index,value)return onehot.numpy()  # Tensor -> Numpy

label = torch.LongTensor(batch_size).random_() % class_num  # 对随机数取余
print(one_hot(label))

# output:# label = tensor([3, 7, 0, 6])# [[0. 0. 0. 1. 0. 0. 0. 0.]#  [0. 0. 0. 0. 0. 0. 0. 1.]#  [1. 0. 0. 0. 0. 0. 0. 0.]#  [0. 0. 0. 0. 0. 0. 1. 0.]]

5. 防止验证模型时爆显存

验证模型时不需要求导,即不需要梯度计算,关闭autograd,可以提高速度,节约内存。如果不关闭可能会爆显存。

with torch.no_grad():# 使用model进行预测的代码pass

6. torch.cuda.empty_cache()的用处

由于 PyTorch 的缓存分配器会事先分配一些固定的显存,即使实际上 tensors 并没有使用完这些显存,这些显存也不能被其他应用使用。因此 torch.cuda.empty_cache() 的作用就是释放缓存分配器当前持有的且未占用的缓存显存,以便这些显存可以被其他GPU应用程序中使用,并且通过 nvidia-smi命令可见。注意使用此命令不会释放 tensors 占用的显存。对于不用的数据变量,Pytorch 可以自动进行回收从而释放相应的显存。

7. 学习率衰减

import torch.optim as optim
from torch.optim import lr_scheduler# 训练前的初始化
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, 10, 0.1)  # # 每过10个epoch,学习率乘以0.1# 训练过程中
for n in n_epoch:scheduler.step()

8. 冻结某些层的参数

在加载预训练模型的时候,我们有时想冻结前面几层,使其参数在训练过程中不发生变化。

1) 我们首先需要知道每一层的名字,通过如下代码打印:

model = Network()  # 获取自定义网络结构
for name, value in model.named_parameters():print('name: {0},\t grad: {1}'.format(name, value.requires_grad))

假设前几层信息如下:

name: cnn.VGG_16.convolution1_1.weight,     grad: True
name: cnn.VGG_16.convolution1_1.bias,     grad: True
name: cnn.VGG_16.convolution1_2.weight,     grad: True
name: cnn.VGG_16.convolution1_2.bias,     grad: True
name: cnn.VGG_16.convolution2_1.weight,     grad: True
name: cnn.VGG_16.convolution2_1.bias,     grad: True
name: cnn.VGG_16.convolution2_2.weight,     grad: True
name: cnn.VGG_16.convolution2_2.bias,     grad: True

2) 定义一个要冻结的层的列表

no_grad = ['cnn.VGG_16.convolution1_1.weight','cnn.VGG_16.convolution1_1.bias','cnn.VGG_16.convolution1_2.weight','cnn.VGG_16.convolution1_2.bias'
]

3) 冻结方法如下

net = Net.CTPN()  # 获取网络结构
for name, value in net.named_parameters():if name in no_grad:value.requires_grad = Falseelse:value.requires_grad = True

再打印每层信息:

name: cnn.VGG_16.convolution1_1.weight,     grad: False
name: cnn.VGG_16.convolution1_1.bias,     grad: False
name: cnn.VGG_16.convolution1_2.weight,     grad: False
name: cnn.VGG_16.convolution1_2.bias,     grad: False
name: cnn.VGG_16.convolution2_1.weight,     grad: True
name: cnn.VGG_16.convolution2_1.bias,     grad: True
name: cnn.VGG_16.convolution2_2.weight,     grad: True
name: cnn.VGG_16.convolution2_2.bias,     grad: True

4) 最后在定义优化器时,只对requires_grad为True的层的参数进行更新。

optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)

9. 对不同层使用不同的学习率

1)首先获取网络结构每一层的名字

net = Network()  # 获取自定义网络结构
for name, value in net.named_parameters():print('name: {}'.format(name))
# 输出:
# name: cnn.VGG_16.convolution1_1.weight
# name: cnn.VGG_16.convolution1_1.bias
# name: cnn.VGG_16.convolution1_2.weight
# name: cnn.VGG_16.convolution1_2.bias
# name: cnn.VGG_16.convolution2_1.weight
# name: cnn.VGG_16.convolution2_1.bias
# name: cnn.VGG_16.convolution2_2.weight
# name: cnn.VGG_16.convolution2_2.bias

2)对 convolution1 和 convolution2 设置不同的学习率,首先将它们分开,即放到不同的列表里:

conv1_params = []
conv2_params = []for name, parms in net.named_parameters():if "convolution1" in name:conv1_params += [parms]else:conv2_params += [parms]# 然后在优化器中进行如下操作:
optimizer = optim.Adam([{"params": conv1_params, 'lr': 0.01},{"params": conv2_params, 'lr': 0.001},],weight_decay=1e-3,
)

我们将模型划分为两部分,存放到一个列表里,每部分就对应上面的一个字典,在字典里设置不同的学习率。当这两部分有相同的其他参数时,就将该参数放到列表外面作为全局参数,如上面的 weight_decay。

我们也可以在列表外设置一个全局学习率,当各部分字典里设置了局部学习率时,就使用该学习率,否则就使用列表外的全局学习率。

References:

[1] PyTorch trick 集锦

转载于:https://www.cnblogs.com/xxxxxxxxx/p/11582657.html

【PyTorch】Tricks 集锦相关推荐

  1. PyTorch Tricks 集锦

    点击上方"Datawhale",选择"星标"公众号 第一时间获取价值内容 目录: 1 指定GPU编号 2 查看模型每层输出详情 3 梯度裁剪 4 扩展单张图片维 ...

  2. 目标检测比赛中的tricks集锦

    ↑ 点击蓝字 关注视学算法 作者丨初识CV@知乎 来源丨https://zhuanlan.zhihu.com/p/102817180 编辑丨极市平台 极市导读 本文总结了目标检测比赛中的8点技巧,包含 ...

  3. 数据竞赛Tricks集锦

    点击上方"Datawhale",选择"星标"公众号 第一时间获取价值内容 本文将对数据竞赛的『技巧』进行全面的总结,同时还会分享下个人对比赛方法论的思考.前者比 ...

  4. PyTorch Trick集锦

    点击蓝字  关注我们 作者丨z.defying@知乎 来源丨https://zhuanlan.zhihu.com/p/76459295 极市导读 本文整理了13则PyTorch使用的小窍门,包括了指定 ...

  5. Pytorch错误集锦

    RuntimeError: Expected one of cpu, cuda, mkldnn, opengl, opencl, ideep, hip, msnpu device type at st ...

  6. pytorch tricks

    您的鼓励是我前进的源动力

  7. 干货收集和整理:Pytorch,Keras,数据分析

    深度学习框架:Keras .Pytorch https://github.com/huggingface/transformers   (Keras作者推荐开源项目transformers) http ...

  8. 【pytorch】torch.cuda.empty_cache()==>释放缓存分配器当前持有的且未占用的缓存显存

    Pytorch 训练时无用的临时变量可能会越来越多,导致 out of memory ,可以使用下面语句来清理这些不需要的变量. torch.cuda.empty_cache() 官网 上的解释为: ...

  9. PyTorch超级资源列表(Github 2.4K星)包罗万象

    PyTorch超级资源列表,包罗万象 PyTorch超级资源列表(Github 2.4K星)包罗万象 -v7.x 1 Pytorch官方工程 2 自然语言处理和语音处理(NLP & Speec ...

最新文章

  1. Ubuntu14.04 64位机上安装OpenCV2.4.13(CUDA8.0)版操作步骤
  2. 设计模式(六)命令模式
  3. Log4net中的RollingFileAppender解析
  4. 互联网协议 — 数据交换技术
  5. 软件项目经理需具备什么样的技术水平?
  6. 034_jQuery Ajax的getJSON和getScript方法
  7. sql 时间范围查询_Excel中使用SQL查询,单元格范围最多支持65536行?
  8. NoSQL Databases - CouchDB
  9. 与国际接轨,中国人慎用这些汉字取名(最后一段对话,笑到喷饭!!)
  10. 赛锐信息:FlexBroswer,一劳永逸解决业务系统Flash问题
  11. 为什么华为5G手机要设置一个5G开关?
  12. 键盘上每个键作用!!! (史上最全的)­
  13. L1-041 寻找250 (10 分)—团体程序设计天梯赛
  14. NGSL + NAWL 单词表 以及学习网站
  15. 计算机应用免费课件,计算机应用基础ppt课件 免费版
  16. 限时秒杀┃“探月计划”来袭,美国米德天文望远镜助孩子观月赏月
  17. 计算机系统基础——我与袁春风不得不说的知识——入门必看
  18. Openstack基础架构
  19. vue项目文件夹介绍
  20. Jerry Ma:为什么我更喜欢做空?

热门文章

  1. 让您的Eclipse具有千变万化的外观
  2. 2018年东北农业大学春季校赛 E 阶乘后的0【数论】
  3. 项目中常见错误总结一
  4. winpcap 发送数据包
  5. xpath IE 7
  6. CentOS系统搭建OpenERP
  7. Juniper静态路由之no-advertise和qualified-next-hop
  8. Google退出中国 谁最受伤
  9. ADO与ADO.NET
  10. 剑指Offer:剪绳子(动态规划、贪婪算法)