朋友们,如需转载请标明出处:https://blog.csdn.net/jiangjunshow

在tensorflow1.x的时候,代码默认的执行方式是graph execution(图执行),而从tensorflow2.0开始,改为了eager execution(饥饿执行)。正如翻译的意思一样,eager execution会立即执行每一步代码,非常的饥渴。而graph execution会将所有代码组合成一个graph(图)后再执行。这里打一个不太恰当的比喻来帮助大家理解:eager execution就像搞一夜情,认识后就立即“执行”,而graph execution就像婚恋,认识后先憋着,不会立即“执行”,要经过了长时间的“积累”后,再一次性“执行”。

在eager 模式下,代码的编写变得很自然很简单,而且因为代码会被立即执行,所以调试时也变得很方便。而graph 模式下,代码的执行效率要高一些;而且由于graph其实就是一个由操作指令和数据组成的一个数据结构,所以graph可以很方便地被导出并保存起来,甚至之后可以运行在其它非python的环境下(因为graph就是个数据结构,里面定义了一些操作指令和数据,所以任何地方只要能解释这些操作和数据,那么就能运行这个模型);也正因为graph是个数据结构,所以不同的运行环境可以按照自己的喜好来解释里面的操作和数据,这样一来,解释后生成的代码会更加符合当前运行的环境,这里一来代码的执行效率就更高了。

可能有些同学还无法理解上面所说的“graph是个数据结构…”。这里我打个比方来帮助大家理解。假设graph里面包含了两个数据x和y,另外还包含了一个操作指令“将x和y相加”。当C++的环境要运行这个graph时,“将x和y相加”这个操作就会被翻译成相应的C++代码,当Java环境下要运行这个graph时,就会被解释成相应的Java代码。graph里面只是一些数据和指令,具体怎么执行命令,要看当前运行的环境。

除了上面所说的,graph还有很多内部机制使代码更加高效运行。总之,graph execution可以让tensorflow模型运行得更快,效率更高,更加并行化,更好地适配不同的运行环境和运行设备。

graph 虽然运行很高效,但是代码却没有eager 的简洁,为了兼顾两种模式的优点,所以出现了tf.function。使用tf.function可以将eager 代码一键封装成graph。

既然是封装成graph,那为什么名字里使用function这个单词内,不应该是tf.graph吗?因为tf.function的作用就是将python function转化成包含了graph的tensorflow function。所以使用function这个单词也说得通。下面的代码可以帮助大家更好地理解。

import tensorflow as tf
import timeit
from datetime import datetime# 定义一个 Python function.
def a_regular_function(x, y, b):x = tf.matmul(x, y)x = x + breturn x# `a_function_that_uses_a_graph` 是一个 TensorFlow `Function`.
a_function_that_uses_a_graph = tf.function(a_regular_function)# 定义一些tensorflow tensors.
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)orig_value = a_regular_function(x1, y1, b1).numpy()
# 在python中可以直接调用tenforflow Function。就像使用python自己的function一样。
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy()
assert(orig_value == tf_function_value)

tf.function不仅仅只作用于顶层的python function,它也作用于内嵌的python function。看下面的代码你就能明白了。

def inner_function(x, y, b):x = tf.matmul(x, y)x = x + breturn x# 使用tf.function将`outer_function`变成一个tensorflow `Function`。注意,之前的代码是将tf.function当作是函数来使用,这样是被当作了修饰符来使用。这两种方式都是被支持的。
@tf.function
def outer_function(x):y = tf.constant([[2.0], [3.0]])b = tf.constant(4.0)return inner_function(x, y, b)# tf.function构建的graph中不仅仅包含了 `outer_function`还包含了它里面调用的`inner_function`。
outer_function(tf.constant([[1.0, 2.0]])).numpy()

输出结果:

array([[12.]], dtype=float32)

如果你之前使用过tenforflow 1.x,你会察觉到,在2.x中构建graph再也不需要tf.Session和Placeholder了。使代码大大地简洁了。

我们的代码里经常会将python代码和tensorflow代码混在一起。在使用tf.function进行graph转化时,tensorflow的代码会被直接进行转化,而python代码会被一个叫做AutoGraph (tf.autograph)的库来负责进行转化。

同一个tensorflow function可能会生成不同的graph。因为每一个tf.Graph的input输入类型必须是固定的,所以如果在调用tensorflow function时传入了新的数据类型,那么这次的调用就会生成一个新的graph。输入的类型以及维度被称为signature(签名),tensorflow function就是根据签名来生成graph的,遇到新的签名就会生成新的graph。下面的代码可以帮助你理解。

@tf.function
def my_relu(x):return tf.maximum(0., x)# 下面对`my_relu` 的3次调用的数据类型都不同,所以生成了3个graph。这3个graph都被保存在my_relu这个tenforflow function中。
print(my_relu(tf.constant(5.5)))
print(my_relu([1, -1])) #python数组
print(my_relu(tf.constant([3., -3.])))  # tf数组

输出结果:

tf.Tensor(5.5, shape=(), dtype=float32)
tf.Tensor([1. 0.], shape=(2,), dtype=float32)
tf.Tensor([3. 0.], shape=(2,), dtype=float32)

如果相同的输入类型被调用了,那么不会再重新生成新的类型。

# 下面这两个调用就不会生成新的graph.
print(my_relu(tf.constant(-2.5))) # 这个数据类型与上面的 `tf.constant(5.5)`一样.
print(my_relu(tf.constant([-1., 1.]))) # 这个数据类型与上面的 `tf.constant([3., -3.])`一样。

因为一个tensorflow function里面可以包含多个graph,所以说tensorflow function是具备多态性的。这种多态性使得tensorflow function可以任意支持不同的输入类型,非常的灵活;并且由于对每一个输入类型会生成一个特定的graph,这也会让代码执行时更加高效!

下面的代码打印出了3种不同的签名

print(my_relu.pretty_printed_concrete_signatures())

输出结果:

my_relu(x)Args:x: float32 Tensor, shape=()Returns:float32 Tensor, shape=()my_relu(x=[1, -1])Returns:float32 Tensor, shape=(2,)my_relu(x)Args:x: float32 Tensor, shape=(2,)Returns:float32 Tensor, shape=(2,)

上面你已经学会了如何使用tf.function将python function转化为tenforflow function。但要想在实际开发中正确地使用tf.function,还需要学习更多知识。下面我就带领大家来学习学习它们。八十八师的弟兄们,不要退缩,跟着我一起冲啊啊啊!

默认情况下,tenforflow function里面的代码会以graph的模式被执行,但是也可以让它们以eager的模式来执行。大家看下面的代码。

@tf.function
def get_MSE():print("Calculating MSE!")#这条语句就是让下面的代码以eager的模式来执行
tf.config.run_functions_eagerly(True)
get_MSE(y_true, y_pred)
#这条代码就是取消前面的设置
tf.config.run_functions_eagerly(False)

某些情况下,同一个tensorflow function在graph与eager模式下会有不同的运行效果。python的print函数就是其中一个特殊情况。看下面的代码。

@tf.function
def get_MSE(y_true, y_pred):print("Calculating MSE!")sq_diff = tf.pow(y_true - y_pred, 2)return tf.reduce_mean(sq_diff)y_true = tf.random.uniform([5], maxval=10, dtype=tf.int32)
y_pred = tf.random.uniform([5], maxval=10, dtype=tf.int32)error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)

输出结果:

Calculating MSE!

看到输出结果你是不是很惊讶?get_MSE被调用了3次,但是里面的python print函数只被执行了一次。这是为什么呢?因为python print函数只在创建graph时被执行,而上面的3次调用中输入参数的类型都是一样的,所以只有一个graph被创建了一次,所以python print函数也只会被调用一次。

为了将graph和eager进行对比,下面我们在eager模式下看看输出结果。

# 开启强制eager模式
tf.config.run_functions_eagerly(True)error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)
error = get_MSE(y_true, y_pred)# 取消eager模式
tf.config.run_functions_eagerly(False)

输出结果:

Calculating MSE!
Calculating MSE!
Calculating MSE!

看!在eager模式下,print被执行了3次。PS:如果使用tf.print,那么在graph和eager模式下都会打印3次。

graph execution模式还有一个特点,就是它会不执行那些无用的代码。看下面的代码。

def unused_return_eager(x):# 当传入的x只包含一个元素时,下面的代码会报错,因为下面的代码是要获取x的第二个元素。PS:索引是从0开始的,1代表第二个元素tf.gather(x, [1]) # unused return xtry:print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:print(f'{type(e).__name__}: {e}')

上面的代码是以eager的模式运行,所以每一行代码都会被执行,所以上面的异常会发生并且会被捕获到。而下面的代码是以graph模式运行的,则不会报异常。因为tf.gather(x, [1])这句代码其实没有任何用途(它只是获取了x的第二个元素,并没有赋值也没有改变任何变量),所以graph模式下它根本就没有被执行,所以也就不会报任何异常了。

@tf.function
def unused_return_graph(x):tf.gather(x, [1])return xtry:print(unused_return_eager(tf.constant([0.0])))
except tf.errors.InvalidArgumentError as e:print(f'{type(e).__name__}: {e}')

前面我们说graph的执行效率会比eager的要高,那到底高多少呢?其实我们可以用下面的代码来计算graph模式到底能比eager模式提升多少效率。

x = tf.random.uniform(shape=[10, 10], minval=-1, maxval=2, dtype=tf.dtypes.int32)def power(x, y):result = tf.eye(10, dtype=tf.dtypes.int32)for _ in range(y):result = tf.matmul(x, result)return result
print("Eager execution:", timeit.timeit(lambda: power(x, 100), number=1000))

输出结果:

Eager execution: 1.8983725069999764
power_as_graph = tf.function(power)
print("Graph execution:", timeit.timeit(lambda: power_as_graph(x, 100), number=1000))

输出结果:

Graph execution: 0.5891194120000023

从上面的代码可以看出graph比eager的执行时间缩短了近3倍。当然,因具体计算内容不同,效率的提升程度也是不同的。

graph虽然能提升运行效率,但是转化graph时也会有代价。对于某些代码,转化graph所需的时间可能比运行graph的还要长。所以在编写代码时要尽量避免graph的重复转化。如果你发现模型的效率很低,那么可以查查是否存在重复转化。可以通过加入print函数来判断是否存在重复转化(还记得前面我们讲过,每次转化graph时就会调用一次print函数)。看下面的代码。

@tf.function
def a_function_with_python_side_effect(x):print("Tracing!") # An eager-only side effect.return x * x + tf.constant(2)print(a_function_with_python_side_effect(tf.constant(2)))
print(a_function_with_python_side_effect(tf.constant(3)))

输出结果:

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
tf.Tensor(11, shape=(), dtype=int32)

可以看出,因为上面两次调用的参数类型是一样的,所以只转化了一次graph,print只被调用了一次。

print(a_function_with_python_side_effect(2))
print(a_function_with_python_side_effect(3))

输出结果:

Tracing!
tf.Tensor(6, shape=(), dtype=int32)
Tracing!
tf.Tensor(11, shape=(), dtype=int32)

上面print被调用了2次。啊?为什么?你可以会表示不解~~上面两个参数的类型是一样的啊,为什么还调用了两次print。因为,输入参数是python类型,对于新的python类型每次都会创建一个新的graph。所以最好是用tenforflow的数据类型作为function的输入参数。

最后我给出tf.function相关的几点建议:

  • 当需要切换eager和graph模式时,应该使用tf.config.run_functions_eagerly来进行明显的标注。

  • 应该在python function的外面创建tenforflow的变量(tf.Variables),在里面修改它们的值。这条建议同样适用于其它那些使用tf.Variables的tenforflow对象(例如keras.layers,keras.Models,tf.optimizers)。

  • 避免函数内部依赖外部定义的python变量。

  • 应该尽量将更多的计算量代码包含在一个tf.function中而不是包含在多个tf.function里,这样可以将代码执行效率最大化。

  • 最好是用tenforflow的数据类型作为function的输入参数。

一文搞懂tf.function相关推荐

  1. 一文搞懂如何使用Node.js进行TCP网络通信

    摘要: 网络是通信互联的基础,Node.js提供了net.http.dgram等模块,分别用来实现TCP.HTTP.UDP的通信,本文主要对使用Node.js的TCP通信部份进行实践记录. 本文分享自 ...

  2. 都2021年了,再不学ES6你就out了 —— 一文搞懂ES6

    JS干货分享 -- 一文搞懂ES6 导语:ES6是什么?用来做什么? 1. let 与 const 2. 解构赋值 3. 模板字符串 4. ES6 函数(升级后更爽) 5. Class类 6. Map ...

  3. 一文搞懂什么是 PostCSS

    一文搞懂什么是 PostCSS 在 Web 应用开发中,CSS 代码的编写是重要的一部分.CSS 规范从最初的 CSS1 到现在的 CSS3,再到 CSS 规范的下一步版本,规范本身一直在不断的发展演 ...

  4. ES6学习——一文搞懂ES6

    ES6学习--一文搞懂ES6 es6介绍 ES全称EcmaScript,是脚本语言的规范,而平时经常编写的EcmaScript的一种实现,所以ES新特性其实就是指JavaScript的新特性. 为什么 ...

  5. 一文搞懂极大似然估计

    极大似然估计,通俗理解来说,就是利用已知的样本结果信息,反推最具有可能(最大概率)导致这些样本结果出现的模型参数值! 换句话说,极大似然估计提供了一种给定观察数据来评估模型参数的方法,即:" ...

  6. 一文搞懂JSON.stringify和JSON.parse(五)JSON.parse使用说明

    一 JSON.parse特性 JSON.parse()就是将JSON字符串解析成字符串描述的JavaScript值或对象(将JSON数据解析为js原生值),例如: JSON.parse('{}'); ...

  7. 一文搞懂Elasticsearch索引的mapping与setting

    目录 Elasticsearch索引结构 Mapping Setting Elasticsearch索引结构 一个Elasticsearch索引的主要结构如下: {"test_index&q ...

  8. 一文搞懂RNN(循环神经网络)

    基础篇|一文搞懂RNN(循环神经网络) https://mp.weixin.qq.com/s/va1gmavl2ZESgnM7biORQg 神经网络基础 神经网络可以当做是能够拟合任意函数的黑盒子,只 ...

  9. 一文搞懂 Python 的 import 机制

    一.前言 希望能够让读者一文搞懂 Python 的 import 机制 1.什么是 import 机制? 通常来讲,在一段 Python 代码中去执行引用另一个模块中的代码,就需要使用 Python ...

最新文章

  1. Spring Cloud云架构 - SSO单点登录之OAuth2.0登录流程(2)
  2. 【生生被气死的一周】头秃
  3. Tomcat性能调优-JVM监控与调优
  4. 学计算机专业需要买电脑么,上大学该买电脑吗?学长:买的时候以为是刚需,买了变成“鸡肋”...
  5. C++const的作用与使用
  6. Oldboy28期linux决心书
  7. xfce4的右键打开终端失效
  8. sql机器学习服务_机器学习服务–在SQL Server中配置R服务
  9. 微信小程序首支视频广告片发布
  10. 算法笔记_083:蓝桥杯练习 合并石子(Java)
  11. 全球异常高温:虾熟了我也要“熟”了
  12. eq, neq.gt,ge,lte,lt,not,mod的含义
  13. 值得您收藏的png图标第二辑
  14. 网上流传ldquo;魔方文化启示录rdquo;
  15. SIPM模拟器 MIPS汇编语言实现读取文件
  16. 什么是Principle?能做什么?
  17. Office 365导出PDF带备注页
  18. 关于a标签的基本用法和特殊用法
  19. ssm毕设项目学生宿舍管理系统15pjb(java+VUE+Mybatis+Maven+Mysql+sprnig)
  20. android studio json格式化,Android json格式化显示,可展开与折叠

热门文章

  1. 逆向思维赚钱法则 真正赚钱的暴利项目
  2. Hive 的 distribute by
  3. Bentley MicroStation CE版的颜色变换(CONNECT Edition)
  4. OpenFOAM动态加密网格的负载平衡
  5. 小谷机器人连不上wifi_小谷连不上网怎么办
  6. 数学建模-数学规划模型
  7. 香港美国CERA机房你怎么选择?
  8. 第06章 Tableau仪表板和故事
  9. 购房指南—新房交房注意事项细节有哪些
  10. narwal机器人_国货之光!云鲸NARWAL扫地机器人国外众筹获第一