TensorFlow 2 默认的即时执行模式(Eager Execution)为我们带来了灵活及易调试的特性,但为了追求更快的速度与更高的性能,我们依然希望使用 TensorFlow 1.X 中默认的图执行模式(Graph Execution)。此时,TensorFlow 2 为我们提供了 tf.function 模块,结合 AutoGraph 机制,使得我们仅需加入一个简单的 @tf.function 修饰符,就能轻松将模型以图执行模式运行。

import tensorflow as tf
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import timenp.random.seed(42)  # 设置numpy随机数种子
tf.random.set_seed(42)  # 设置tensorflow随机数种子# 生成训练数据
x = np.linspace(-1, 1, 100)
x = x.astype('float32')
y = x * x + 1 + np.random.rand(100)*0.1  # y=x^2+1 + 随机噪声
x_train = np.expand_dims(x, 1)  # 将一维数据扩展为二维
y_train = np.expand_dims(y, 1)  # 将一维数据扩展为二维
plt.plot(x, y, '.')  # 画出训练数据def create_model():inputs = keras.Input((1,))x = keras.layers.Dense(10, activation='relu')(inputs)outputs = keras.layers.Dense(1)(x)model = keras.Model(inputs=inputs, outputs=outputs)return modelmodel = create_model()  # 创建一个模型
loss_fn = keras.losses.MeanSquaredError()  # 定义损失函数
optimizer = keras.optimizers.SGD()  # 定义优化器@tf.function  # 将训练过程转化为图执行模式
def train():with tf.GradientTape() as tape:y_pred = model(x_train, training=True)  # 前向传播,注意不要忘了training=Trueloss = loss_fn(y_train, y_pred)  # 计算损失tf.summary.scalar("loss", loss, epoch+1)  # 将损失写入tensorboardgrads = tape.gradient(loss, model.trainable_variables)  # 计算梯度optimizer.apply_gradients(zip(grads, model.trainable_variables))  # 使用优化器进行反向传播return lossepochs = 1000
begin_time = time.time()  # 训练开始时间
for epoch in range(epochs):loss = train()print('epoch:', epoch+1, '\t', 'loss:', loss.numpy())  # 打印训练信息
end_time = time.time()  # 训练结束时间print("训练时长:", end_time-begin_time)# 预测
y_pre = model.predict(x_train)# 画出预测值
plt.plot(x, y_pre.squeeze())
plt.show()


测试结果:不使用@tf.function,那么训练时间大约为3秒。如果使用@tf.function,训练时间仅需要0.5秒。快了很多倍。

解决无法调试的问题:
用于将函数转换为图里面的节点,从而可以加速运算,我们可以给任何的函数定义上加入@tf.function,但是在调试的时候,无法直接去调试被@tf.function修饰的函数,即我们如果给函数里面打上断点,但是我们无法进入到断点的位置。

为了进行调试,需要在调用函数之前加上下面的语句。
tf.config.experimental_run_functions_eagerly(True)

测试案例:

tf.config.experimental_run_functions_eagerly(True)
# 完整的例子如下
@tf.function
def f(x):if x > 0:# Try setting a breakpoint here!# Example:#   import pdb#   pdb.set_trace()x = x + 1return xtf.config.experimental_run_functions_eagerly(True)# You can now set breakpoints and run the code in a debugger.
f(tf.constant(1))

参考原文在这里:使用@tf.function加快训练速度

其他参考:https://blog.csdn.net/weixin_43824178/article/details/99297237

加快深度学习模型训练速度@tf.function相关推荐

  1. 图像处理深度学习模型训练速度的硬件影响因素

    深度学习训练速度的影响因素 1 数据流通路径 2影响速率的因素 2.1硬盘读取速度 2.2PCle传输速度 2.3内存读写速度 2.4cpu频率 2.5 GPU 其他名词 以图象训练任务为例,从CPU ...

  2. 深度学习模型训练过程

    深度学习模型训练过程 一.数据准备 基本原则: 1)数据标注前的标签体系设定要合理 2)用于标注的数据集需要无偏.全面.尽可能均衡 3)标注过程要审核 整理数据集 1)将各个标签的数据放于不同的文件夹 ...

  3. 收藏 | PyTorch深度学习模型训练加速指南2021

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者:LORENZ KUHN 编译:ronghuaiyang ...

  4. 笔记 | 百度飞浆AI达人创造营:深度学习模型训练和关键参数调优详解

    笔记 | 百度飞浆AI达人创造营:深度学习模型训练和关键参数调优详解 针对特定场景任务从模型选择.模型训练.超参优化.效果展示这四个方面进行模型开发. 一.模型选择 从任务类型出发,选择最合适的模型. ...

  5. 深度学习模型训练和关键参数调优详解

    深度学习模型训练和关键参数调优详解 一.模型选择 1.回归任务 人脸关键点检测 2.分类任务 图像分类 3.场景任务 目标检测 人像分割 文字识别 二.模型训练 1.基于高层API训练模型 加载数据集 ...

  6. 深度学习模型训练的一般方法(以DSSM为例)

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 本文主要用于记录DSSM模型学习期间遇到的问题及分析.处理经验.先统领性地提出深度学习模型训练 ...

  7. AI佳作解读系列(一)——深度学习模型训练痛点及解决方法

    AI佳作解读系列(一)--深度学习模型训练痛点及解决方法 参考文章: (1)AI佳作解读系列(一)--深度学习模型训练痛点及解决方法 (2)https://www.cnblogs.com/carson ...

  8. dcm格式的文件里有什么,哪些对于深度学习模型训练有用

    DCM格式的文件通常包含医学图像,如X射线.CT或MRI扫描.这些图像可以用来辅助医生诊断疾病,并且对于深度学习模型训练也非常有用.在医学图像分析方面,深度学习模型可以用来做图像分割.疾病诊断.肿瘤检 ...

  9. 深度学习模型训练的结果及改进方法

    深度学习模型训练的结果及改进方法 模型在训练集上误差较大: 解决方法:1. 选择新的激活函数2. 使用自适应的学习率 在训练集上表现很好,但在测试集上表现很差(过拟合): 解决方法:1. 减少迭代次数 ...

最新文章

  1. Navicat导出表结构导出成Excel
  2. 面试题:mysql 表删除一半数据,B+树索引文件会不会变小???
  3. 松翰松翰c语言编程指导,松翰C程序检单例程代码下载
  4. jQuery的JSONP
  5. Git之多个用户ID适配
  6. 计算机c盘丢失,电脑C盘丢失的视频文件怎么恢复?方法讲解,轻松搞定
  7. python多标签分类_如何通过sklearn实现多标签分类?
  8. 阿里云centos7使用yum安装mysql的正确姿势
  9. javaSE学习 访问控制
  10. 深度强化学习之稀疏奖励(Sparse Reward)
  11. Web 探索之旅 | 第二部分第四课:数据库
  12. 使用代理服务器来连接到internet_代理服务器是什么,有什么作用?
  13. PXE启动错误代码一览表
  14. MATLAB神经网络工具箱 BP神经网络函数化表示 BP神经网络梯度\求导函数
  15. ERP ERP原理与应用试题(附答案)
  16. 通过getPixel();和通过bmp.getPixels();方法遍历整张图片的效率比较。
  17. 迅雷协议分析–多链接资源获取
  18. Android 中短信数据库的简单操作
  19. 更改vscode Java项目的.class文件输出路径
  20. Android报错installation failed with message invalid file E://.....

热门文章

  1. 03.spring framework的AOP
  2. html中qq号码怎么写,qq号码免费申请6位号的方法
  3. metasequoia :Summoner
  4. 计算机出现蓝屏怎么解决,教你电脑出现蓝屏是怎么回事
  5. 房卡麻将分析系列之断线重连
  6. 山外K66连接TLL注意事项
  7. 登录安全----双重MD5加密实现安全登录
  8. 计算机内存卡插哪里,电脑内存卡在哪个位置
  9. 大数据BI解决方案:医疗行业的数据治理
  10. 豆瓣最新API-python