Pytorch常用技巧记录
Pytorch常用技巧记录
目录
文章目录
- Pytorch常用技巧记录
- 1、指定GPU编号
- 2、查看模型每层输出详情
- 3、梯度裁剪(Gradient Clipping)
- 4、扩展单张图片维度
- 5、独热编码
- 6、防止验证模型时爆显存
- 7、学习率衰减
- 8、冻结某些层的参数
- 9、对不同层使用不同学习率
- 10、模型相关操作
- 11、Pytorch内置one_hot函数
- 转载
1、指定GPU编号
- 设置当前使用的GPU设备仅为0号设备,设备名称为
/gpu:0
:os.environ["CUDA_VISIBLE_DEVICES"] = "0"
- 设置当前使用的GPU设备为0,1号两个设备,名称依次为
/gpu:0
、/gpu:1
:os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
,根据顺序表示优先使用0号设备,然后使用1号设备。
指定GPU的命令需要放在和神经网络相关的一系列操作的前面。
2、查看模型每层输出详情
Keras有一个简洁的API来查看模型的每一层输出尺寸,这在调试网络时非常有用。现在在PyTorch中也可以实现这个功能。
使用很简单,如下用法:
from torchsummary import summary
summary(your_model, input_size=(channels, H, W))
input_size
是根据你自己的网络模型的输入尺寸进行设置。
pytorch-summary
3、梯度裁剪(Gradient Clipping)
import torch.nn as nnoutputs = model(data)
loss= loss_fn(outputs, target)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
optimizer.step()
nn.utils.clip_grad_norm_
的参数:
- parameters – 一个基于变量的迭代器,会进行梯度归一化
- max_norm – 梯度的最大范数
- norm_type – 规定范数的类型,默认为L2
知乎用户 不椭的椭圆
提出:梯度裁剪在某些任务上会额外消耗大量的计算时间,可移步评论区查看详情。
4、扩展单张图片维度
因为在训练时的数据维度一般都是 (batch_size, c, h, w),而在测试时只输入一张图片,所以需要扩展维度,扩展维度有多个方法:
import cv2
import torchimage = cv2.imread(img_path)
image = torch.tensor(image)
print(image.size())img = image.view(1, *image.size())
print(img.size())# output:
# torch.Size([h, w, c])
# torch.Size([1, h, w, c])
或
import cv2
import numpy as npimage = cv2.imread(img_path)
print(image.shape)
img = image[np.newaxis, :, :, :]
print(img.shape)# output:
# (h, w, c)
# (1, h, w, c)
或(感谢知乎用户 coldleaf
的补充)
import cv2
import torchimage = cv2.imread(img_path)
image = torch.tensor(image)
print(image.size())img = image.unsqueeze(dim=0)
print(img.size())img = img.squeeze(dim=0)
print(img.size())# output:
# torch.Size([(h, w, c)])
# torch.Size([1, h, w, c])
# torch.Size([h, w, c])
tensor.unsqueeze(dim)
:扩展维度,dim指定扩展哪个维度。
tensor.squeeze(dim)
:去除dim指定的且size为1的维度,维度大于1时,squeeze()不起作用,不指定dim时,去除所有size为1的维度。
5、独热编码
在PyTorch中使用交叉熵损失函数的时候会自动把label转化成onehot,所以不用手动转化,而使用MSE需要手动转化成onehot编码。
import torch
class_num = 8
batch_size = 4def one_hot(label):"""将一维列表转换为独热编码"""label = label.resize_(batch_size, 1)m_zeros = torch.zeros(batch_size, class_num)# 从 value 中取值,然后根据 dim 和 index 给相应位置赋值onehot = m_zeros.scatter_(1, label, 1) # (dim,index,value)return onehot.numpy() # Tensor -> Numpylabel = torch.LongTensor(batch_size).random_() % class_num # 对随机数取余
print(one_hot(label))# output:
[[0. 0. 0. 1. 0. 0. 0. 0.][0. 0. 0. 0. 1. 0. 0. 0.][0. 0. 1. 0. 0. 0. 0. 0.][0. 1. 0. 0. 0. 0. 0. 0.]]
Convert int into one-hot format
注:第11条有更简单的方法。
6、防止验证模型时爆显存
验证模型时不需要求导,即不需要梯度计算,关闭autograd,可以提高速度,节约内存。如果不关闭可能会爆显存。
with torch.no_grad():# 使用model进行预测的代码pass
感谢知乎用户zhaz
的提醒,我把 torch.cuda.empty_cache()
的使用原因更新一下。
这是原回答:
Pytorch 训练时无用的临时变量可能会越来越多,导致
out of memory
,可以使用下面语句来清理这些不需要的变量。
官网 上的解释为:
Releases all unoccupied cached memory currently held by the caching allocator so that those can be used in other GPU application and visible innvidia-smi.
torch.cuda.empty_cache()
意思就是PyTorch的缓存分配器会事先分配一些固定的显存,即使实际上tensors并没有使用完这些显存,这些显存也不能被其他应用使用。这个分配过程由第一次CUDA内存访问触发的。
而 torch.cuda.empty_cache()
的作用就是释放缓存分配器当前持有的且未占用的缓存显存,以便这些显存可以被其他GPU应用程序中使用,并且通过 nvidia-smi
命令可见。注意使用此命令不会释放tensors占用的显存。
对于不用的数据变量,Pytorch 可以自动进行回收从而释放相应的显存。
更详细的优化可以查看 优化显存使用 和 显存利用问题。
7、学习率衰减
import torch.optim as optim
from torch.optim import lr_scheduler# 训练前的初始化
optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, 10, 0.1)# 训练过程中
for n in n_epoch:scheduler.step()...
关键语句为lr_scheduler.StepLR(optimizer, 10, 0.1)
,表示每过10个epoch,学习率乘以0.1。
8、冻结某些层的参数
参考:Pytorch 冻结预训练模型的某一层
在加载预训练模型的时候,我们有时想冻结前面几层,使其参数在训练过程中不发生变化。
我们需要先知道每一层的名字,通过如下代码打印:
net = Network() # 获取自定义网络结构
for name, value in net.named_parameters():print('name: {0},\t grad: {1}'.format(name, value.requires_grad))
假设前几层信息如下:
name: cnn.VGG_16.convolution1_1.weight, grad: True
name: cnn.VGG_16.convolution1_1.bias, grad: True
name: cnn.VGG_16.convolution1_2.weight, grad: True
name: cnn.VGG_16.convolution1_2.bias, grad: True
name: cnn.VGG_16.convolution2_1.weight, grad: True
name: cnn.VGG_16.convolution2_1.bias, grad: True
name: cnn.VGG_16.convolution2_2.weight, grad: True
name: cnn.VGG_16.convolution2_2.bias, grad: True
后面的True表示该层的参数可训练,然后我们定义一个要冻结的层的列表:
no_grad = ['cnn.VGG_16.convolution1_1.weight','cnn.VGG_16.convolution1_1.bias','cnn.VGG_16.convolution1_2.weight','cnn.VGG_16.convolution1_2.bias'
]
冻结方法如下:
net = Net.CTPN() # 获取网络结构
for name, value in net.named_parameters():if name in no_grad:value.requires_grad = Falseelse:value.requires_grad = True
冻结后我们再打印每层的信息:
name: cnn.VGG_16.convolution1_1.weight, grad: False
name: cnn.VGG_16.convolution1_1.bias, grad: False
name: cnn.VGG_16.convolution1_2.weight, grad: False
name: cnn.VGG_16.convolution1_2.bias, grad: False
name: cnn.VGG_16.convolution2_1.weight, grad: True
name: cnn.VGG_16.convolution2_1.bias, grad: True
name: cnn.VGG_16.convolution2_2.weight, grad: True
name: cnn.VGG_16.convolution2_2.bias, grad: True
可以看到前两层的weight和bias的requires_grad都为False,表示它们不可训练。
最后在定义优化器时,只对requires_grad为True的层的参数进行更新。
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.01)
9、对不同层使用不同学习率
我们对模型的不同层使用不同的学习率。
还是使用这个模型作为例子:
net = Network() # 获取自定义网络结构
for name, value in net.named_parameters():print('name: {}'.format(name))# 输出:
# name: cnn.VGG_16.convolution1_1.weight
# name: cnn.VGG_16.convolution1_1.bias
# name: cnn.VGG_16.convolution1_2.weight
# name: cnn.VGG_16.convolution1_2.bias
# name: cnn.VGG_16.convolution2_1.weight
# name: cnn.VGG_16.convolution2_1.bias
# name: cnn.VGG_16.convolution2_2.weight
# name: cnn.VGG_16.convolution2_2.bias
对 convolution1 和 convolution2 设置不同的学习率,首先将它们分开,即放到不同的列表里:
conv1_params = []
conv2_params = []for name, parms in net.named_parameters():if "convolution1" in name:conv1_params += [parms]else:conv2_params += [parms]# 然后在优化器中进行如下操作:
optimizer = optim.Adam([{"params": conv1_params, 'lr': 0.01},{"params": conv2_params, 'lr': 0.001},],weight_decay=1e-3,
)
我们将模型划分为两部分,存放到一个列表里,每部分就对应上面的一个字典,在字典里设置不同的学习率。当这两部分有相同的其他参数时,就将该参数放到列表外面作为全局参数,如上面的weight_decay
。
也可以在列表外设置一个全局学习率,当各部分字典里设置了局部学习率时,就使用该学习率,否则就使用列表外的全局学习率。
10、模型相关操作
这个内容比较多,我就写成了一篇文章。
PyTorch 中模型的使用
11、Pytorch内置one_hot函数
感谢 yangyangyang 补充:Pytorch 1.1后,one_hot可以直接用 torch.nn.functional.one_hot
。
然后我将Pytorch升级到1.2版本,试用了下 one_hot
函数,确实很方便。
具体用法如下:
import torch.nn.functional as F
import torchtensor = torch.arange(0, 5) % 3 # tensor([0, 1, 2, 0, 1])
one_hot = F.one_hot(tensor)# 输出:
# tensor([[1, 0, 0],
# [0, 1, 0],
# [0, 0, 1],
# [1, 0, 0],
# [0, 1, 0]])
F.one_hot
会自己检测不同类别个数,生成对应独热编码。我们也可以自己指定类别数:
tensor = torch.arange(0, 5) % 3 # tensor([0, 1, 2, 0, 1])
one_hot = F.one_hot(tensor, num_classes=5)# 输出:
# tensor([[1, 0, 0, 0, 0],
# [0, 1, 0, 0, 0],
# [0, 0, 1, 0, 0],
# [1, 0, 0, 0, 0],
# [0, 1, 0, 0, 0]])
升级 Pytorch (cpu版本)的命令:conda install pytorch torchvision -c pytorch
(希望Pytorch升级不会影响项目代码)
转载
https://github.com/zxdefying/pytorch_tricks
Pytorch常用技巧记录相关推荐
- Linux 常用技巧记录
1. 分卷压缩解压 tar -czf file | split -b 2G -d -file.tar.gz #压缩**file目录**,并且每个目录大小为2G左右. cat file.tar.gz*| ...
- PyTorch学习记录——PyTorch进阶训练技巧
PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...
- 收藏!PyTorch常用代码段合集
↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:Jack Stark,来源:极市平台 来源丨https://zhu ...
- PyTorch常用代码段合集
↑ 点击蓝字 关注视学算法 作者丨Jack Stark@知乎 来源丨https://zhuanlan.zhihu.com/p/104019160 极市导读 本文是PyTorch常用代码段合集,涵盖基本 ...
- 【转】oracle存储过程常用技巧
原文链接 http://www.cnblogs.com/chinafine/archive/2010/07/12/1776102.html 我们在进行pl/sql编程时打交道最多的就是存储过程了.存储 ...
- 【深度学习】PyTorch常用代码段合集
来源 | 极市平台,机器学习算法与自然语言处理 本文是PyTorch常用代码段合集,涵盖基本配置.张量处理.模型定义与操作.数据处理.模型训练与测试等5个方面,还给出了多个值得注意的Tips,内容非常 ...
- linux history 看更多历史记录_Linux历史记录history常用技巧
Linux历史记录history常用技巧 Pain #1 - 历史记录不带时间戳,不知道命令是什么时候发生的 默认情况下 history 命令直接显示用户执行的命令而不会输出运行命令时的日期和时间,即 ...
- pytorch list转tensor_PyTorch 52.PyTorch常用代码段合集
本文参考于: Jack Stark:[深度学习框架]PyTorch常用代码段zhuanlan.zhihu.com 1. 基本配置 导入包和版本查询: import torch import torc ...
- Javascript 常用技巧 [2]
Javascript 常用技巧 [2] /** 请问如何去掉主页右面的滚动条? <!-- <body scroll="no"> --> <!-- & ...
最新文章
- RequireJS示例
- torch.roll
- 内核对象——Windows核心编程学习手札系列之三
- Java I/O系统学习系列三:I/O流的典型使用方式
- svn提示客户端版本太旧
- 使用Java查询DynamoDB项
- java jtextarea滚动条下滑,关于JTextArea的滚动条问题
- (TOJ1531)爱的伟大意义
- 【Flink】Failed to create checkpoint storage at checkpoint coordinator side
- Mysql 异步复制
- 2018.9.28 典型for循环特殊理解及其二维数组的理解
- webstorm注释写出的提示
- MATLAB APP全局变量的使用
- JS 进阶 (六) 浏览器事件模型DOM操作
- Combo Box 组合框
- NOI.6.08石头剪刀布
- 03Roberts算子
- 中级软件设计师备考攻略
- SVG 的平移、旋转和缩放
- Python提取Word中的图片
热门文章
- 第三次握手为什么没有序列号_TCP三次握手机制-深入浅出(实例演示)
- 鸽主姓名查询成绩_SQL学习之旅-Select简单查询
- laravel ::all() 选择字段_Laravel 性能优化:优化 ORM 性能使应用程序高可用
- c ++递归算法数的计数_计数排序算法–在C / C ++中实现的想法
- scala中命名参数函数_Scala中的命名参数和默认参数值
- zookeeper入门学习之java api会话建立《四》
- C++基础知识(三)C++的输入和输出及操纵符
- 如何成为Java开发工程师?需要掌握哪些技能?
- Java常见面试题:Oracle JDK 和 OpenJDK 的区别?
- 开课吧Java教程什么是类集接口