1. 关于单机多卡的处理:

在pytorch官网上有一个简单的示例:函数使用为:torch.nn.DataParallel(model, deviceids, outputdevice, dim)关键的在于model、device_ids这两个参数。

DATA PARALLELISM​pytorch.org

但是官网的例子中没有讲到一个核心的问题:即所有的tensor必须要在同一个GPU上。这是网络运行的前提。这篇文章给了我很大的帮助,里面的例子也很好懂,很直观:

pytorch: 一机多卡训练的尝试​www.jianshu.com

一般来说有两种数据迁移的方法:

1)是先定义一个device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')这里面已经定义了device在卡0上“cuda:0”

然后将model = torch.nn.DataParallel(model,devices_ids=[0, 1, 2])(假设有三张卡)

此后需要将tensor 也迁移到GPU上去。注意所有的tensor必须要在同一张GPU上面

即:tensor1 = tensor1.to(device), tensor2 = tensor2.to(device)等等

可能有人会问了,我并没有指定那一块GPU啊,怎么这样也没有出错啊?

原因很简单,因为一开始的device中已经指定了那一块卡了(卡的id为0)

2)第二中方法就是直接用tensor.cuda()的方法

即先model = torch.nn.DataParallel(model, device_ids=[0, 1, 2]) (假设有三块卡, 卡的ID 为0, 1, 2)

然后tensor1 = tensor1.cuda(0), tensor2=tensor2.cuda(0)等等。(我这里面把所有的tensor全放进ID 为 0 的卡里面,也可以将全部的tensor都放在ID 为1 的卡里面)

2 关于DataParallel的封装问题

在DataParallel中,没有和nn.Module一样多的特性。但是有些时候我们可能需要使用到如.fc这样的性质(.fc性质在nn.Module中有, 但是在DataParallel中没有)这个时候我们需要一个.Module属性来进行过渡。操作如下:

model 

3 Pytorch中的数据导入潜规则

All pre-trained models expect input images normalized in the same way, i.e. mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be at least 224. The images have to be loaded in to a range of [0, 1] and then normalized using mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225]

所以我们在transform的时候可以先定义:normalized = torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), 然后用的时候直接调用normalized就行了。

4 python中的某些包的版本不同也会导致程序运行失败。

如,今天遇到一个pillow包的问题。原先装的包的6.0.0版本的,但是在制作数据集的时候,训练集跑的好好的,一到验证集就开始无端报错。在确定程序无误之后,将程序放在别的环境中跑(也是pytorch环境),正常运行。于是经过几番查找,发现是pillow出了问题,于是乎卸载了原来的版本,重新装一个低一点的版本问题就解决了。这种版本问题的坑其实很多,而且每个人遇到的还都不尽相同,所以需要慢慢的去摸索才能发现问题所在。

5 关于CUDA 内存溢出的问题。

这个一般是因为batch_size 设置的比较大。(8G显存的话大概batch_size < = 64都ok, 如果还是报错的话,就在对半分 64, 32, 16, 8, 4等等)。而且这个和你的数据大小没什么太大的关系。因为我刚刚开始也是想可能是我训练集太大了,于是将数据集缩小了十倍,还是同样的报错。所以就想可能 batch_size的问题。最后果然是batchsize的问题。

6 关于模型导入

一般来说如果你的模型是再GPU上面训练的,那么如果你继续再GPU上面进行其他的后续操作(如迁移学习等)那么直接使用:

import torch
from torchvision import modelspre_trained_weight = torch.load('pre_trained_weight.pt') # pre_trained_weight.pt 是我在resnet18上面训练好的模型
resnet18 = models.resnet18(pretrained=False) # 导入框架
resnet18.load_state_dict(pre_trained_weight)  # load_state_dict()函数表示导入当前权值,因为这个权值都是以字典的形式保存的

# 如果你模型在GPU上训练的,而且后续操作也在GPU上进行,那上面的操作就没啥毛病。但是…………

如果你模型在GPU上训练的,后续操作是在CPU上进行的话。那么还用上面的代码的话就会报错了。因为你模型在GPU上训练,其实其内部的某些数据格式和CPU上的不大一样。所以需要一个函数将GPU上的模型转化为CPU上的模型。这个工作在pytorch里面其实很简单。只要把上面的代码简单修改一下即可:(在torch.load函数里面加一个map_location='cpu'即可!)

import torch
from torchvision import modelspre_trained_weight = torch.load('pre_trained_weight.pt',map_location='cpu') # pre_trained_weight.pt 是我在resnet18上面训练好的模型
resnet18 = models.resnet18(pretrained=False) # 导入框架
resnet18.load_state_dict(pre_trained_weight)  # load_state_dict()函数表示导入当前权值,因为这个权值都是以字典的形式保存的

7. 关于两次sort操作:

参考:https://blog.csdn.net/LXX516/article/details/78804884

前几天看SSD pytorch的源码发现了,有这样的一步操作,不得解,

于是查阅了一下资料和动手操作后发现了两次sort操作的神奇之处。

首先 sort操作没什么好说。它接收两个参数:dim和descending参数。dim表示的是从哪个维度进行排列,descending参数接收布尔类型的输入,表示结果是否按降序排列。两次sort操作的具体实施为。

import 

从上面的分析中可以看到,两次sort操作得到的idx的意义是: 在保证原始元素的位置不变的情况下,可以表示排序情况(升序or降序)。

以上是原理,那么两次sort究竟用在什么地方呢?

还是上面哪个例子:

>>>x
tensor([[-0.1361,  0.4076, -0.8244,  0.9163],[-0.0997, -1.1689, -2.3145,  1.2334],[-0.4384, -1.6083,  1.7621, -0.9648]])

我想取x的第一行元素的前1个最小值, 第二行元素的前2个最小值,第三行元素的前3个最小值。该怎么操作呢?

根据上面的两次sort操作,我们得到idx

tensor([[1, 2, 0, 3],[2, 1, 0, 3],[2, 0, 3, 1]])
# 定义criterion
criterion = torch.tensor([1, 2, 3]).view(3, -1)
criterion = criterion.expand_as(idx)
>>>criterion
tensor([[1, 1, 1, 1],[2, 2, 2, 2],[3, 3, 3, 3]])
mask = idx < criterion
>>>mask
tensor([[0, 0, 1, 0],[0, 1, 1, 0],[1, 1, 0, 1]], dtype=torch.uint8)
# 可以看到,mask得到的就是我们所需要的索引。可以看到mask第一行只有一个1, 第二行有两个1,第三行有三个1.这里的1表示的True的意思,即得到这个数
>>>x[mask]
tensor([-0.8244, -1.1689, -2.3145, -0.4384, -1.6083, -0.9648]) # 最终结果

8. log_sum_exp的trick:

机器学习常见模式LogSumExp解密人工智能_机器人之家​www.jqr.com

参考这篇文章,写的通俗易懂。大概介绍一下问题:

发现这个问题是前几天,这里面在进行exp操作的时候用x-x_max。当时很是疑惑。后来一看上面这篇文章才明白了。

一般来说

是有一个确切的值与之对应的。但是在计算里面却不是这样的。输入torch.exp(1000), 结果是:

这样的结果并不意外,因为计算机的存储阶段误差导致的。基于这种情况的存在,所以人们想到了一个比较好的解决方法。具体怎么实现看看上面的链接便清楚了。

pytorch load state dict_学习Pytorch过程遇到的坑(持续更新中)相关推荐

  1. 图谱(学习地图)系列总结,持续更新中

    目录 1.2022年最新前端学习路线图 2.2022年最新大数据学习路线图 3.2022年最新javaEE学习路线图 4.2022年最新UI/UE学习路线图 5.2022年java学习路线指南 6.J ...

  2. pytorch load state dict_PyTorch 学习笔记(五):Finetune和各层定制学习率

    本文截取自<PyTorch 模型训练实用教程>,获取全文pdf请点击:https://github.com/tensor-yu/PyTorch_Tutorial @[toc] 我们知道一个 ...

  3. pytorch load state dict_Pytorch学习记录-使用Pytorch进行深度学习,保存和加载模型

    新建 Microsoft PowerPoint 演示文稿 (2).jpg 保存和加载模型 在完成60分钟入门之后,接下来有六节tutorials和五节关于文本处理的tutorials.争取一天一节.不 ...

  4. linux的学习之旅(初学者)--持续更新中

    具体Linux的信息建议先自己百度了解,本博客是笔者自己的学习记录,因此本博客是按照本人的学习进度及内容而写的,如有错误或者忘记的,欢迎留言告知. 进入Linux系统 Linux系统与Windows系 ...

  5. 《python编程从入门到实践》python入门级-学习笔记(1-2章)——持续更新中

    CSDN的小伙伴们你们好~从今天起我开始自学编程了. 恭喜你关注到一个成长型账号. 一以来作为美术出身的TA,我无数次的向往能打出几行属于自己的代码.为了跟上时代的步伐,也为了能更加深入TA这个职业, ...

  6. 【Vue全家桶+SSR+Koa2全栈开发】项目搭建过程 整合 学习目录(持续更新中)

    写在开头 大家好,这里是lionLoveVue,基础知识决定了编程思维,学如逆水行舟,不进则退.金三银四,为了面试也还在慢慢积累知识,Github上面可以直接查看所有前端知识点梳理,github传送门 ...

  7. Go语言开发学习笔记(持续更新中)

    Go语言开发学习笔记(持续更新中) 仅供自我学习 更好的文档请选择下方 https://studygolang.com/pkgdoc https://www.topgoer.com/go%E5%9F% ...

  8. 【我的OpenGL学习进阶之旅】【持续更新】关于学习OpenGL的一些资料

    目录 一.相关书籍 OpenGL 方面 C方面 NDK 线性代数 二.相关博客 2.0 一些比较官方的链接 2.1 OpenGL着色器语言相关 2.2 [[yfan]](https://segment ...

  9. C语言学习笔记Day3——持续更新中... ...

    上一篇文章C语言学习笔记Day2--持续更新中- - 八. 容器 1. 一维数组 1.1 什么是一维数组 当数组中每个元素都只带有一个下标(第一个元素的下标为0, 第二个元素的下标为1, 以此类推)时 ...

最新文章

  1. Java SE 6之GUI:让界面更加绚丽(上)
  2. Java Vector
  3. python读文件去除空行_「34」Python文件操作经典案例:CSV文件的读与写
  4. “蜥蜴之尾”——长老木马四代分析报告
  5. 【ARM】Tiny4412裸板编程之MMU封装
  6. 循环斐波那契数列_每日一课 | 斐波那契数列的第n个项
  7. Bootstrap进度条的颜色
  8. Asp.Net SignalR - 简单聊天室实现
  9. mysql集群 自增_为什么我们要从MySQL迁移到TiDB?
  10. ISO 9001是什么?ISO 9001 质量管理体系详细介绍
  11. WORD目录三级标题行间距太大 目录标题行间距
  12. js获取url一级域名的方法
  13. 离线语音合成使用——科大讯飞or云知音or百度语音
  14. WAF防火墙有什么用
  15. 最简便的方法搭建Hexo+Github博客,基于Next主题
  16. 刘文文:603001操盘手坐庄内幕(转)
  17. asp.net网上商城系统VS开发sqlserver数据库web结构c#编程计算机网页源码项目
  18. 百度前端技术学院-斌斌学院-任务五
  19. python爬取前程无忧_Python爬取前程无忧网址,并保存为txt文件
  20. 水利工程资料管理软件

热门文章

  1. 【sql那些事】时间处理的一揽子事
  2. C#开发笔记之19-如何用C#实现优雅的Json解析(序列化/反序列化)方案?
  3. C#LeetCode刷题之#874-模拟行走机器人​​​​​​​(Walking Robot Simulation)
  4. MySQL常用数据类型
  5. mysql创建表语句和修改表语句
  6. sql中聚合函数和分组函数_SQL选择计数聚合函数-语法示例解释
  7. 让我们揭穿有关学习编码的主要神话
  8. 121_Power Query之R.Execute的read.xlsxODBC
  9. Django表单form
  10. ROS仿真-记一次错误 gazebo-2 process has died exit code 2