点击上方关注,All in AI中国

本文将介绍如何使用Keras和Google CoLaboratory与TPU一起训练LSTM模型,与本地计算机上的GPU相比,这样训练能大大缩短训练时间。

很长一段时间以来,我都在单张GTX 1070显卡上训练我的模型,它的单精度大约为8.18 TFlops。后来Google的Colab开放了免费的Tesla K80显卡,配备12GB RAM,8.73TFlops。直到最近,Colab的运行时类型选择器中还会弹出带有180 TFlops的Cloud TPU选项。这篇教程将简要介绍如何将现有的Keras模型转换为TPU模型,然后在Colab上训练。与在GTX1070上训练相比,TPU能够加速20倍。

我们将构建一个易于理解,但训练起来非常复杂的Keras模型,这样我们就可以稍微"预热"一下Cloud TPU。在IMDB情感分类任务上训练LSTM模型可能是一个很好的例子,因为相比密集层和卷积层来说,训练LSTM对算力要求更高。

工作流程概述:

  • 使用静态输入batch_size构建用于功能API训练的Keras模型
  • 将Keras模型转换为TPU模型
  • 使用静态batch_size * 8训练TPU模型,并将权重保存到文件
  • 创建一个结构相同,但输入批大小可变的Keras模型,用于推理
  • 加载模型权重
  • 基于推理模型进行预测

在阅读本文的同时,你可以上手试验相应的Colab Jupyter notebook:Keras_LSTM_TPU.ipynb。(https://colab.research.google.com/drive/1QZf1WeX3EQqBLeFeT4utFKBqq-ogG1FN)

首先,按照下图中的说明来激活在Colab运行中的TPU。

激活TPU

固定输入批尺寸

大多数情况下,CPU和GPU上对输入形状没有限制,但XLA/TPU环境下会强制使用固定的形状和批尺寸。

Can TPU包含8个TPU核心,作为独立的处理单元运行。如果没有使用所有八个核心,那TPU就不会得到充分利用。为了充分提高训练的矢量化速度,相比在单一GPU上训练的同样的模型,我们可以选择较大的批尺寸。总批尺寸大小为1024(每个核心128个)通常是一个很好的起点。

如果你要训练批尺寸较大的型号,请尝试慢慢减小批尺寸,以保证TPU内存放得下,只需确保总批尺寸为64的倍数(每核心批尺寸应该是8的倍数)。

值得一提,在批尺寸较大时,通常可以提高优化器的学习速率,以实现更快的收敛。你可以在本文中找到参考——"Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour"。(https://arxiv.org/pdf/1706.02677.pdf)

在Keras中,要定义静态批处理尺寸,我们使用函数API,然后为输入层指定batch_size参数。请注意,模型构建在一个带有batch_size参数的函数中,因此我们之后可以很方便地创建在CPU或GPU上运行的模型,这些模型接受可变批尺寸的输入。

此外,我们在这里使用了tf.train.Optimizer而不是标准的Keras优化器,因为TPU对Keras优化器的支持还处于实验阶段。

将Keras模型转换为TPU模型

tf.contrib.tpu.keras_to_tpu_model函数将tf.keras模型转换为等价的TPU版本。

然后,我们使用标准的Keras方法来训练,保存权重并评估模型。请注意,batch_size设置为模型输入batch_size的八倍,因为输入样本在8个TPU核心上均匀分布。

我做了一个实验,用来比较在Windows PC上运行单个GTX1070和在Colab上运行的TPU之间的训练速度,结果如下:

  • GPU和TPU都将输入批尺寸设为128。
  • GPU:每个历元179秒。20个历元后的验证准确率达到了76.9%,总计3600秒。
  • TPU:每个历元5秒(第一个历元需要49秒)。20个历元后的验证准确率达到了95.2%,总计150秒。
  • 在20个历元之后TPU的验证准确度高于在GPU上的表现,那是因为TPU上同时训练8个批的样本(每个批的大小为128)。

在CPU上进行推理

一旦我们获得了模型权重,我们就可以像往常一样加载它,然后在CPU或GPU等其他设备上进行预测。我们想要推理模型接受可变的输入批大小,这可以使用之前的make_model()函数来实现。

你可以看到推理模型现在可以接受可变输入样本数目,

然后,你可以使用标准的fit()、evaluate()函数与推理模型。

结论以及进一步阅读

这篇快速教程向你简要介绍了如何利用Google Colab上的免费Cloud TPU资源更快地训练Keras模型。

云TPU文档:https://cloud.google.com/tpu/docs/

云TPU性能指南:https://cloud.google.com/tpu/docs/performance-guide

云TPU故障排除指南:https://cloud.google.com/tpu/docs/troubleshooting

XLA概述:https://www.tensorflow.org/performance/xla/

编译出品

keras训练完以后怎么预测_还在使用“龟速”的单显卡训练模型?动动手,让TPU节省你的时间...相关推荐

  1. keras训练完以后怎么预测_使用Keras建立Wide Deep神经网络,通过描述预测葡萄酒价格...

    你能通过"优雅的单宁香"."成熟的黑醋栗香气"或"浓郁的酒香"这样的描述,预测葡萄酒的价格吗?事实证明,机器学习模型可以. 在这篇文章中,我 ...

  2. keras训练完以后怎么预测_一文告诉你如何将Keras模型保存到文件中,并再次加载它们来进行预测。...

    Keras是一个用于深度学习的简单而强大的Python库. 鉴于深度学习模式可能需要数小时.数天甚至数周的时间来培训,了解如何保存并将其从磁盘中加载是很重要的. 在本文中,您将发现如何将Keras模型 ...

  3. keras训练完以后怎么预测_农村小孩只有户口,没有承包地,以后怎么养老?看完我安心了...

    阅读本文前,请您先点击上面的蓝色字体"三农荟",再点击"关注",这样您就可以继续免费收到最新情感文章了.每天都有分享.完全是免费订阅,请放心关注. 农村小孩,只 ...

  4. 斜度符号标注_还记得机械图纸尺寸标注规则吗?通过动图详解复习一下

    1.基本规则 1.机件的真实大小应以图样上所注的尺寸数值为依据,与图形的大小及绘图的准确度无关. 2.图样中(包括技术要求和其它说明)的尺寸,以mm为单位时,不需标注计量单位的代号或名称,如采用其它单 ...

  5. 保存时间 默认_一些不起眼但又非常的实用的PPT制作技巧,大大节省PPT制作时间...

    从PPT小白到PPT大神的过程中,我们总会无数次碰壁,无数次陷入困境.今天为大家带来的是一些不起眼的PPT技巧,但是非常的实用,不信就看下文吧! 自定义访问工具栏 在PPT中我们有很多的常用操作,例如 ...

  6. layui横向时间线_一些不起眼但又非常的实用的PPT制作技巧,大大节省PPT制作时间...

    从PPT小白到PPT大神的过程中,我们总会无数次碰壁,无数次陷入困境.今天为大家带来的是一些不起眼的PPT技巧,但是非常的实用,不信就看下文吧! 自定义访问工具栏 在PPT中我们有很多的常用操作,例如 ...

  7. 【1】Keras复习之模型,层,训练,评估与预测

    本系列主要是针对文档的学习,文档地址是: www.keras.io,文档非常详细. Keras的核心数据结构就是模型,最简单的模型就是序贯模型,也就是Sequential模型,是层的线性堆砌.如果是想 ...

  8. C++ 和 OpenCV 实现卷积神经网络并加载 Keras 训练好的参数进行预测

    C++ 和 OpenCV 实现卷积神经网络并加载 Keras 训练好的参数进行预测 一. 背景 二. Keras 定义神经网络结构 channels_first 与 channels_last cha ...

  9. 使用Keras训练自动驾驶(使用Udacity自动驾驶模拟器)

    使用Keras训练自动驾驶(使用Udacity自动驾驶模拟器) 1.完成项目所需要的资源 (1)模拟器下载 • Linux • macOS • Windows (2)Unity 下载 运行Udacit ...

最新文章

  1. Android学习——R文件丢失异常原因汇总
  2. Linux技巧:一次删除一百万个文件最快方法
  3. php 随机在文章中添加锚文本_锚文本对网站SEO优化有什么帮助?
  4. Java学习笔记1.1.3 搭建Java开发环境 - 编写并运行Java程序
  5. 数据增量更新定义_技术资讯 | TiDB在准实时数据仓库中的实践
  6. case / switch语句的Python等价物是什么? [重复]
  7. Competitive Programming 3题解
  8. 【转】Add a user/Administrator to Windows Server 2008
  9. ngrok技术原理及下载使用
  10. linux服务器有电信和网通,Linux 双网关(电信与联通)
  11. python中oserror_[python] 解决OSError:
  12. python网络爬虫实战之下载笔趣看小说网小说
  13. c语言easyx改变字体大小,改变控制台字体大小
  14. 服务器文件夹怎么找回来,文件过期了怎么恢复(教你一招找回微信过期文件)...
  15. 若依(RuoYi)配置教程
  16. 单位篮球比赛结束,感想很多
  17. luoguP2711 小行星
  18. javaweb 图书管理系统完整代码_群晖 + Docker + Calibre-Web 搭建电子书管理系统
  19. Mysql常用的sql语句大全
  20. 332B. Maximum Absurdity

热门文章

  1. 最新最详细最简洁Eclipse调试PHP配置详解(Xdebug,Zend Debugger)
  2. JSP2.0中Simple Tag介绍
  3. 【哈利波特】Sherbert Lemon对HP的解读之11
  4. nike附近门店查询_不止5折!200+入手Nike、adidas,比“11.11”还便宜!
  5. js实现modbus_nodejs中使用modbus-serial库创建Modbus TCP读取设备的数据
  6. 收集Redis 经典面试题
  7. SQLServer常用的配置函数笔记
  8. Linux关于文件的权限笔记
  9. python线性回归分析看相关性_机器学习入门-相关分析之简单线性回归
  10. java局部刷新session过期_Ajax局部页面刷新和History API结合的陷阱