pytorch学习笔记(6):GPU和如何保存加载模型
参考文档:https://mp.weixin.qq.com/s/kmed_E4MaDwN-oIqDh8-tg
上篇文章我们完成了一个 vgg 网络的实现,那么现在已经掌握了一些基础的网络结构的实现,距离一个入门炼丹师还有两个小问题需要注意一下:GPU 和保存模型。
提起炼丹大家经常可以听到显卡如何如何的,也就是 GPU 在炼丹的过程中起到重要的作用。另一方面,训练了一个模型后,我们肯定要用它来进行一些预测,前面的代码中都是将训练好的模型直接进行预测,但是如果代码每次预测都要训练一次岂不是麻烦死了,所以将训练好的模型保存下来也是一个关键环节。
1、GPU在炼丹中如何使用
GPU 擅长并行式的图像计算,而张量本身和图像一样都是矩阵计算,所以对于 tensor 的计算,GPU 本身就有得天独厚的优势。再加上并行式的训练方式,可以有效的节约训练时候的时间消耗。
那么我们怎么在训练过程中使用 GPU 呢?
首先安装 pytorch 的时候在官网需要选择 cuda 的版本,这里我们不多赘述。然后通过下面的代码查看你的电脑是否支持 GPU,以及是否安装成功。
torch.cuda.is_available()
如果显示为 True,则说明安装成功过,我们就可以接下来学习如何使用你的这张显卡了。
在使用 GPU 的过程中,其实主要就是两部分需要放到 GPU 上:模型和数据。所以我们现在来学习如何将这两部分放到 GPU 上。
device = "cuda:0"
cnn.to(device)
images = images.to(device)
labels = labels.to(device)
这里的 cuda:0 是我们指定了使用哪张显卡,在单卡情况下,一般这里的标号就是 0,所以 device 变量相当于我们指定了显卡。接下来 cnn 是我们定义的网络模型,直接加后缀的 .to(device) 就可以将网络放到 GPU 中了。严格的讲是将所有网络的参数和缓存放到 GPU 中。
最后的 images 和 labels 也是同样的道理,将对应的数据放到 GPU 中。这也就是为什么我们在选择显卡的时候需要注意显存的大小,它决定了我们可以放进去多少数据。
在测试的时候也是同样的道理,需要将数据存到 GPU 中,最后预测的结果也是在 GPU 中,如果需要将其拿到 CPU 上来和其它数据进行交互的话,我们就需要按照下面代码的样子进行处理:
pred_y = pred_y.cpu()
2、如何保存模型
在 pytorch 中,所有的网络参数信息都保存在状态字典中,也就是 python 中的 dict 结构。本身我们对 dict 就有很多保存的方法,所以 pytorch 在这里的实现也就不复杂了。
在保存模型的时候,我们有两种常见做法:第一种是只保存网络的参数,那么在使用的时候创建好网络结构,然后将参数加载进去;第二种是保存整个网络。下面我们分别给出示例:
torch.save(cnn.state_dict(), PATH)
通过这个方法,就可以将一个名为 cnn 的网络,保存到目标路径 PATH 中。在使用的时候按照如下方法:
cnn = CNN(*args, **kwargs)
cnn.load_state_dict(torch.load(PATH))
cnn.eval()
我们可以看到先通过设定好的网络结构 CNN 先实例化一个 cnn 网络出来,然后用 load_state_dict() 函数去读取保存好的网络参数。
还有一种办法是保存整个网络,具体的方法如下:
torch.save(cnn, PATH)
而读取的方法也不能用 load_state_dict() 了,需要直接使用 load() 方法。
cnn = torch.load(PATH)
cnn.eval()
这样子我们就完成了对模型的导入。
除此之外,还值得注意的几个点是 checkpoint 的保存和冻结部分参数进行 fine-tuning。
checkpoint 保存也就是保存下模型在某个时刻的状态,那么就需要保存当时的所有信息,包括当时的 epoch,loss,优化器的 state_dict 等等。
而冻结部分参数进行 fine-tuning 也是比较常见的一种行为,尤其是在涉及到迁移学习等方面时,我们经常需要将一个在其它数据集上进行训练过的模型,再到当前数据集上继续训练。那么在这一步的时候一般就是将对应的 tensor 的 requires_grid 设置为 True 或者 False,然后按照如下的命令进行训练:
optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
3、总结
今天的文章总结了两个比较常见的功能,我们在实际场景中也的确会经常用到这两个方法。GPU 对炼丹的帮助不言而喻,而模型的保存在我们的实际使用中也是不可或缺的。
pytorch学习笔记(6):GPU和如何保存加载模型相关推荐
- Tensorflow学习(二)之——保存加载模型、Saver的用法
1. Saver的背景介绍 我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试.Tensorflow针对这一需求提供了Saver类. Saver类 ...
- TensorFlow学习笔记——使用TFRecord进行数据保存和加载
本篇文章主要介绍如何使用TensorFlow构建自己的图片数据集TFRecord的方法,并使用最新的数据处理Dataset API进行操作. TFRecord TFRecord数据文件是一种对任何数据 ...
- Objective-C学习笔记第十五章文件加载与保存
第十五章文件加载与保存 Cocoa提供了Core Data,他能在后台处理所有文件内容 Cocoa提供了两个通用的文件处理类:属性列表和对象编码 一.属性列表类 在Cocoa中,有一类名为属性列表的对 ...
- Unity学习笔记(5):动态加载Prefab
第一种方法,从Resources文件夹读取Prefab Assets/Resources文件夹是Unity中的一个特殊文件夹,在博主当前的认知里,放在这个文件夹里的Prefab可以被代码动态加载 直接 ...
- contiki学习笔记(六)contiki程序加载器和多线程库
六.contiki程序加载器 contiki程序加载器是一个用于加载和启动程序的抽象接口. Data Structures struct dsc//DSC程序描述结构. ModulesThe Cont ...
- JVM学习笔记之-类加载子系统,类的加载与类的加载过程,双亲委派机制
一 类加载器与类加载过程 类加载子系统作用 类加载器子系统负责从文件系统或者网络中加载class文件,class文件在文件开头有特定的文件标识. ClassLoader只负责class文件的加载,至于 ...
- scrapy学习笔记(三)-关于动态加载网页的爬取(序)
一.尝试 对于我要爬取的网站内容,按照网上普遍的步骤:直接通过xpath获取到对于数据,再记录到item中,只适用于静态html网页,但是如今的互联网大部分的web页面都是动态的,经常逛的网站例如京东 ...
- 深度实践SPARK机器学习_学习笔记_第二章2.3加载数据
2.3加载数据 1.下载数据文件u.user head -3 u.user ##查看文件前几行 cat u.user |wc -l 或者 more u.user |wc -l ##数文件记录数 ...
- Pytorch 保存和加载模型
当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...
最新文章
- 如何进行机器学习框架选择
- Spark Shuffle两种Manager
- python dataframe列数值相加,python合并dataframe中的行并将值相加
- pythongui程序,python第一个GUI程序
- WebKit 内核源码分析 (一) Frame
- SpringBoot实用小技巧之动态设置SpringBoot日志级别 1
- springboot aop记录日志
- jQuery新的事件绑定机制on()
- 用python操作浏览器的三种方式_经验 | python 操作浏览器的三种方式
- 是否优化更新主题浏览量:_主题306:能力规划
- 当Apple TV+的生态化反梦,撞上一个“日渐昂贵”的流媒体市场
- 阿里云服务器租用收费标准(精准费用报价更新)
- 金彩教育:详情页文案怎么写
- Swift不深入只浅出入门教程-孟祥月-专题视频课程
- mavoneditor 显示html,mavonEditor
- 使用scp把另外一台服务器上的文件夹/文件拷贝到当前服务器
- PCA变换与KL变换
- 《python3廖雪峰》正则表达式匹配Email地址练习题答案
- [附源码]JAVA毕业设计酒店订房系统(系统+LW)
- 量子纠缠暗示了:我们这个世界很诡异!它到底纠缠了个啥?