加快深度学习模型训练速度@tf.function
TensorFlow 2 默认的即时执行模式(Eager Execution)为我们带来了灵活及易调试的特性,但为了追求更快的速度与更高的性能,我们依然希望使用 TensorFlow 1.X 中默认的图执行模式(Graph Execution)。此时,TensorFlow 2 为我们提供了 tf.function 模块,结合 AutoGraph 机制,使得我们仅需加入一个简单的 @tf.function 修饰符,就能轻松将模型以图执行模式运行。
import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import timenp.random.seed(42) # 设置numpy随机数种子
tf.random.set_seed(42) # 设置tensorflow随机数种子# 生成训练数据
x = np.linspace(-1, 1, 100)
x = x.astype('float32')
y = x * x + 1 + np.random.rand(100)*0.1 # y=x^2+1 + 随机噪声
x_train = np.expand_dims(x, 1) # 将一维数据扩展为二维
y_train = np.expand_dims(y, 1) # 将一维数据扩展为二维
plt.plot(x, y, '.') # 画出训练数据def create_model():inputs = keras.Input((1,))x = keras.layers.Dense(10, activation='relu')(inputs)outputs = keras.layers.Dense(1)(x)model = keras.Model(inputs=inputs, outputs=outputs)return modelmodel = create_model() # 创建一个模型
loss_fn = keras.losses.MeanSquaredError() # 定义损失函数
optimizer = keras.optimizers.SGD() # 定义优化器@tf.function # 将训练过程转化为图执行模式
def train():with tf.GradientTape() as tape:y_pred = model(x_train, training=True) # 前向传播,注意不要忘了training=Trueloss = loss_fn(y_train, y_pred) # 计算损失tf.summary.scalar("loss", loss, epoch+1) # 将损失写入tensorboardgrads = tape.gradient(loss, model.trainable_variables) # 计算梯度optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 使用优化器进行反向传播return lossepochs = 1000
begin_time = time.time() # 训练开始时间
for epoch in range(epochs):loss = train()print('epoch:', epoch+1, '\t', 'loss:', loss.numpy()) # 打印训练信息
end_time = time.time() # 训练结束时间print("训练时长:", end_time-begin_time)# 预测
y_pre = model.predict(x_train)# 画出预测值
plt.plot(x, y_pre.squeeze())
plt.show()
测试结果:不使用@tf.function,那么训练时间大约为3秒。如果使用@tf.function,训练时间仅需要0.5秒。快了很多倍。
解决无法调试的问题:
用于将函数转换为图里面的节点,从而可以加速运算,我们可以给任何的函数定义上加入@tf.function,但是在调试的时候,无法直接去调试被@tf.function修饰的函数,即我们如果给函数里面打上断点,但是我们无法进入到断点的位置。
为了进行调试,需要在调用函数之前加上下面的语句。
tf.config.experimental_run_functions_eagerly(True)
测试案例:
tf.config.experimental_run_functions_eagerly(True)
# 完整的例子如下
@tf.function
def f(x):if x > 0:# Try setting a breakpoint here!# Example:# import pdb# pdb.set_trace()x = x + 1return xtf.config.experimental_run_functions_eagerly(True)# You can now set breakpoints and run the code in a debugger.
f(tf.constant(1))
参考原文在这里:使用@tf.function加快训练速度
其他参考:https://blog.csdn.net/weixin_43824178/article/details/99297237
加快深度学习模型训练速度@tf.function相关推荐
- 图像处理深度学习模型训练速度的硬件影响因素
深度学习训练速度的影响因素 1 数据流通路径 2影响速率的因素 2.1硬盘读取速度 2.2PCle传输速度 2.3内存读写速度 2.4cpu频率 2.5 GPU 其他名词 以图象训练任务为例,从CPU ...
- 深度学习模型训练过程
深度学习模型训练过程 一.数据准备 基本原则: 1)数据标注前的标签体系设定要合理 2)用于标注的数据集需要无偏.全面.尽可能均衡 3)标注过程要审核 整理数据集 1)将各个标签的数据放于不同的文件夹 ...
- 收藏 | PyTorch深度学习模型训练加速指南2021
点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:LORENZ KUHN 编译:ronghuaiyang ...
- 笔记 | 百度飞浆AI达人创造营:深度学习模型训练和关键参数调优详解
笔记 | 百度飞浆AI达人创造营:深度学习模型训练和关键参数调优详解 针对特定场景任务从模型选择.模型训练.超参优化.效果展示这四个方面进行模型开发. 一.模型选择 从任务类型出发,选择最合适的模型. ...
- 深度学习模型训练和关键参数调优详解
深度学习模型训练和关键参数调优详解 一.模型选择 1.回归任务 人脸关键点检测 2.分类任务 图像分类 3.场景任务 目标检测 人像分割 文字识别 二.模型训练 1.基于高层API训练模型 加载数据集 ...
- 深度学习模型训练的一般方法(以DSSM为例)
向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程 公众号:datayx 本文主要用于记录DSSM模型学习期间遇到的问题及分析.处理经验.先统领性地提出深度学习模型训练 ...
- AI佳作解读系列(一)——深度学习模型训练痛点及解决方法
AI佳作解读系列(一)--深度学习模型训练痛点及解决方法 参考文章: (1)AI佳作解读系列(一)--深度学习模型训练痛点及解决方法 (2)https://www.cnblogs.com/carson ...
- dcm格式的文件里有什么,哪些对于深度学习模型训练有用
DCM格式的文件通常包含医学图像,如X射线.CT或MRI扫描.这些图像可以用来辅助医生诊断疾病,并且对于深度学习模型训练也非常有用.在医学图像分析方面,深度学习模型可以用来做图像分割.疾病诊断.肿瘤检 ...
- 深度学习模型训练的结果及改进方法
深度学习模型训练的结果及改进方法 模型在训练集上误差较大: 解决方法:1. 选择新的激活函数2. 使用自适应的学习率 在训练集上表现很好,但在测试集上表现很差(过拟合): 解决方法:1. 减少迭代次数 ...
最新文章
- Navicat导出表结构导出成Excel
- 面试题:mysql 表删除一半数据,B+树索引文件会不会变小???
- 松翰松翰c语言编程指导,松翰C程序检单例程代码下载
- jQuery的JSONP
- Git之多个用户ID适配
- 计算机c盘丢失,电脑C盘丢失的视频文件怎么恢复?方法讲解,轻松搞定
- python多标签分类_如何通过sklearn实现多标签分类?
- 阿里云centos7使用yum安装mysql的正确姿势
- javaSE学习 访问控制
- 深度强化学习之稀疏奖励(Sparse Reward)
- Web 探索之旅 | 第二部分第四课:数据库
- 使用代理服务器来连接到internet_代理服务器是什么,有什么作用?
- PXE启动错误代码一览表
- MATLAB神经网络工具箱 BP神经网络函数化表示 BP神经网络梯度\求导函数
- ERP ERP原理与应用试题(附答案)
- 通过getPixel();和通过bmp.getPixels();方法遍历整张图片的效率比较。
- 迅雷协议分析–多链接资源获取
- Android 中短信数据库的简单操作
- 更改vscode Java项目的.class文件输出路径
- Android报错installation failed with message invalid file E://.....