参考文档:https://mp.weixin.qq.com/s/kmed_E4MaDwN-oIqDh8-tg

上篇文章我们完成了一个 vgg 网络的实现,那么现在已经掌握了一些基础的网络结构的实现,距离一个入门炼丹师还有两个小问题需要注意一下:GPU 和保存模型。

提起炼丹大家经常可以听到显卡如何如何的,也就是 GPU 在炼丹的过程中起到重要的作用。另一方面,训练了一个模型后,我们肯定要用它来进行一些预测,前面的代码中都是将训练好的模型直接进行预测,但是如果代码每次预测都要训练一次岂不是麻烦死了,所以将训练好的模型保存下来也是一个关键环节。

1、GPU在炼丹中如何使用

GPU 擅长并行式的图像计算,而张量本身和图像一样都是矩阵计算,所以对于 tensor 的计算,GPU 本身就有得天独厚的优势。再加上并行式的训练方式,可以有效的节约训练时候的时间消耗。
那么我们怎么在训练过程中使用 GPU 呢?

首先安装 pytorch 的时候在官网需要选择 cuda 的版本,这里我们不多赘述。然后通过下面的代码查看你的电脑是否支持 GPU,以及是否安装成功。

torch.cuda.is_available()

如果显示为 True,则说明安装成功过,我们就可以接下来学习如何使用你的这张显卡了。

在使用 GPU 的过程中,其实主要就是两部分需要放到 GPU 上:模型和数据。所以我们现在来学习如何将这两部分放到 GPU 上。

device = "cuda:0"
cnn.to(device)
images = images.to(device)
labels = labels.to(device)

这里的 cuda:0 是我们指定了使用哪张显卡,在单卡情况下,一般这里的标号就是 0,所以 device 变量相当于我们指定了显卡。接下来 cnn 是我们定义的网络模型,直接加后缀的 .to(device) 就可以将网络放到 GPU 中了。严格的讲是将所有网络的参数和缓存放到 GPU 中。

最后的 images 和 labels 也是同样的道理,将对应的数据放到 GPU 中。这也就是为什么我们在选择显卡的时候需要注意显存的大小,它决定了我们可以放进去多少数据。

在测试的时候也是同样的道理,需要将数据存到 GPU 中,最后预测的结果也是在 GPU 中,如果需要将其拿到 CPU 上来和其它数据进行交互的话,我们就需要按照下面代码的样子进行处理:

pred_y = pred_y.cpu()

2、如何保存模型

在 pytorch 中,所有的网络参数信息都保存在状态字典中,也就是 python 中的 dict 结构。本身我们对 dict 就有很多保存的方法,所以 pytorch 在这里的实现也就不复杂了。

在保存模型的时候,我们有两种常见做法:第一种是只保存网络的参数,那么在使用的时候创建好网络结构,然后将参数加载进去;第二种是保存整个网络。下面我们分别给出示例:

torch.save(cnn.state_dict(), PATH)

通过这个方法,就可以将一个名为 cnn 的网络,保存到目标路径 PATH 中。在使用的时候按照如下方法:

cnn = CNN(*args, **kwargs)
cnn.load_state_dict(torch.load(PATH))
cnn.eval()

我们可以看到先通过设定好的网络结构 CNN 先实例化一个 cnn 网络出来,然后用 load_state_dict() 函数去读取保存好的网络参数。

还有一种办法是保存整个网络,具体的方法如下:

torch.save(cnn, PATH)

而读取的方法也不能用 load_state_dict() 了,需要直接使用 load() 方法。

cnn = torch.load(PATH)
cnn.eval()

这样子我们就完成了对模型的导入。

除此之外,还值得注意的几个点是 checkpoint 的保存和冻结部分参数进行 fine-tuning。

checkpoint 保存也就是保存下模型在某个时刻的状态,那么就需要保存当时的所有信息,包括当时的 epoch,loss,优化器的 state_dict 等等。

而冻结部分参数进行 fine-tuning 也是比较常见的一种行为,尤其是在涉及到迁移学习等方面时,我们经常需要将一个在其它数据集上进行训练过的模型,再到当前数据集上继续训练。那么在这一步的时候一般就是将对应的 tensor 的 requires_grid 设置为 True 或者 False,然后按照如下的命令进行训练:

optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)

3、总结

今天的文章总结了两个比较常见的功能,我们在实际场景中也的确会经常用到这两个方法。GPU 对炼丹的帮助不言而喻,而模型的保存在我们的实际使用中也是不可或缺的。

pytorch学习笔记(6):GPU和如何保存加载模型相关推荐

  1. Tensorflow学习(二)之——保存加载模型、Saver的用法

    1. Saver的背景介绍 我们经常在训练完一个模型之后希望保存训练的结果,这些结果指的是模型的参数,以便下次迭代的训练或者用作测试.Tensorflow针对这一需求提供了Saver类. Saver类 ...

  2. TensorFlow学习笔记——使用TFRecord进行数据保存和加载

    本篇文章主要介绍如何使用TensorFlow构建自己的图片数据集TFRecord的方法,并使用最新的数据处理Dataset API进行操作. TFRecord TFRecord数据文件是一种对任何数据 ...

  3. Objective-C学习笔记第十五章文件加载与保存

    第十五章文件加载与保存 Cocoa提供了Core Data,他能在后台处理所有文件内容 Cocoa提供了两个通用的文件处理类:属性列表和对象编码 一.属性列表类 在Cocoa中,有一类名为属性列表的对 ...

  4. Unity学习笔记(5):动态加载Prefab

    第一种方法,从Resources文件夹读取Prefab Assets/Resources文件夹是Unity中的一个特殊文件夹,在博主当前的认知里,放在这个文件夹里的Prefab可以被代码动态加载 直接 ...

  5. contiki学习笔记(六)contiki程序加载器和多线程库

    六.contiki程序加载器 contiki程序加载器是一个用于加载和启动程序的抽象接口. Data Structures struct dsc//DSC程序描述结构. ModulesThe Cont ...

  6. JVM学习笔记之-类加载子系统,类的加载与类的加载过程,双亲委派机制

    一 类加载器与类加载过程 类加载子系统作用 类加载器子系统负责从文件系统或者网络中加载class文件,class文件在文件开头有特定的文件标识. ClassLoader只负责class文件的加载,至于 ...

  7. scrapy学习笔记(三)-关于动态加载网页的爬取(序)

    一.尝试 对于我要爬取的网站内容,按照网上普遍的步骤:直接通过xpath获取到对于数据,再记录到item中,只适用于静态html网页,但是如今的互联网大部分的web页面都是动态的,经常逛的网站例如京东 ...

  8. 深度实践SPARK机器学习_学习笔记_第二章2.3加载数据

    2.3加载数据 1.下载数据文件u.user head -3 u.user ##查看文件前几行 cat u.user |wc -l 或者 more u.user |wc -l    ##数文件记录数 ...

  9. Pytorch 保存和加载模型

    当保存和加载模型时,需要熟悉三个核心功能: 1. torch.save :将序列化对象保存到磁盘.此函数使用Python的 pickle 模块进行序列化.使 用此函数可以保存如模型.tensor.字典 ...

最新文章

  1. 如何进行机器学习框架选择
  2. Spark Shuffle两种Manager
  3. python dataframe列数值相加,python合并dataframe中的行并将值相加
  4. pythongui程序,python第一个GUI程序
  5. WebKit 内核源码分析 (一) Frame
  6. SpringBoot实用小技巧之动态设置SpringBoot日志级别 1
  7. springboot aop记录日志
  8. jQuery新的事件绑定机制on()
  9. 用python操作浏览器的三种方式_经验 | python 操作浏览器的三种方式
  10. 是否优化更新主题浏览量:_主题306:能力规划
  11. 当Apple TV+的生态化反梦,撞上一个“日渐昂贵”的流媒体市场
  12. 阿里云服务器租用收费标准(精准费用报价更新)
  13. 金彩教育:详情页文案怎么写
  14. Swift不深入只浅出入门教程-孟祥月-专题视频课程
  15. mavoneditor 显示html,mavonEditor
  16. 使用scp把另外一台服务器上的文件夹/文件拷贝到当前服务器
  17. PCA变换与KL变换
  18. 《python3廖雪峰》正则表达式匹配Email地址练习题答案
  19. [附源码]JAVA毕业设计酒店订房系统(系统+LW)
  20. 量子纠缠暗示了:我们这个世界很诡异!它到底纠缠了个啥?

热门文章

  1. redis zset usage
  2. SQL 年龄段 品牌分类 分组统计
  3. 陕西卫视《关中男人》观后感--女人之后是男人?
  4. Java之Set接口
  5. finalize作用
  6. Kubelet 源码剖析
  7. 有效的MongoDB索引
  8. python http get 请求_Python:编写HTTP Server处理GET请求
  9. git 常用操作,撤销修改
  10. unserialize用法