Tensorflow函数映射:py_func和map_fn
tf.map_fn
[tf.map_fn]:map on the list of tensors unpacked from elems on dimension 0. 接受一个函数对象,然后用该函数对象对集合(elems)中的每一个元素分别处理,
tf.map_fn(
fn,
elems,
dtype=None,
parallel_iterations=None,
back_prop=True,
swap_memory=False,
infer_shape=True,
name=None
)
参数解析
dtype: (optional) The output type(s) of fn. If fn returns a structure of Tensors differing from the structure of elems, then dtype is not optional and must have the same structure as the output of fn. 如果输入输出类型不一样,这个一定要加上,否则输出类型自动和输入一样,报错。
-柚子皮-
tf.py_func
[tf.py_func]: Wraps a python function and uses it as a TensorFlow op. 用来将 一个 python 函数打包成一个 op。
Note: 杯具的是,如果在estimator,如predict阶段使用py_func,它是不能在后续server中使用的。即export_saved_model后,在predictor.from_saved_model时会出错,说ValueError: callback pyfunc_0 is not found,即estimator不会将py_func拉进静态图中?这时只能放弃py_func,类似下面示例中的解决方案2了。
Tensorflow还是有不足的地方。第一体现在Tensorflow的数据机制,由于tensor只是占位符,在没有用tf.Session().run接口填充值之前是没有实际值的。因此,在网络搭建的时候,是不能对tensor进行判值操作的,即不能插入if...else...之类的代码。第二,相较于numpy array,Tensorflow中对tensor的操作接口灵活性并没有那么高,使得Tensorflow的灵活性减弱。扩展Tensorflow程序的灵活性,有一个重要的手段,就是使用tf.py_func接口。tf.py_func()运算符使您可以在TensorFlow图的中间运行任意Python代码。包装自定义NumPy运算符特别方便,因为没有等效的TensorFlow运算符(尚未存在)。添加tf.py_func()是在图中使用sess.run()调用的替代方法,这样就可以使用任意python函数操作tensor了(一般的py函数不能操作tensor的,只能数值)。
tf.py_func(
func,
inp,
Tout,
stateful=True,
name=None
)
tf.py_func的原理:首先,tf.py_func接收的是tensor,然后将其转化为numpy array送入func函数,最后再将func函数输出的numpy array转化为tensor返回。
参数解析
func: 一个python函数,它将一个Numpy数组组成的list作为输入,该list中的元素的数据类型和inp参数中的tf.Tensor对象的数据类型相对应,同时该函数返回一个Numpy数组组成的list或者单一的Numpy数组,其数据类型和参数Tout中的值相对应
inp: Tensor队形组成的list,即使只有一个tensor也需要使用[tensor]。
Tout: 该函数的返回对象的数据类型。一个tensorflow数据类型组成的list或者tuple(如[tf.string, tf.string]),(如果只有一个返回值,需要单独一个tensorflow数据类型,如tf.string,不要写成[tf.string],这样返回时也会多一维)。
stateful:布尔值,如果该值为True,该函数应被视为与状态有关的。如果一个函数与状态无关,则相同的输入会产生相同的输出,并不会产生明显的副作用。有些优化操作如common subexpression elimination只能在与状态无关的操作中进行。
注意:
1 func函数的返回值类型一定要和Tout指定的tensor类型一致。
2 The body of the function (i.e. func) will not be serialized in a GraphDef. Therefore, you should not use this function if you need to serialize your model and restore it in a different environment. tf.py_func中的func是脱离Graph的,在func中不能定义可训练的参数参与网络训练(反传),或者说无法求导。
3 如果python_func()函数有 string 参数的话,tf会把这个string参数 转换成 bytes 类型。
函数示例
(可能可以使用np解决)一个不好的地方是,如果返回多个数据,有n个数据那必须指定Tout = [tf.string]*n。[Returning mutiple values in the input function for `tf.py_func`]如果输入m个数据,但是Tout = [tf.string]*n就会报错:InvalidArgumentError (see above for traceback): pyfunc returns m values, but expects to see n values.出现这种情况一般发生在tf batch训练时,因为最后一个batch_size是不固定且很可能不等于指定的params['batch_size'],而且这里的n不能直接设置成输入inp.shape[0],因为它不固定,也是返回一个None值,而不是实际的batch大小。
pred_strings = tf.py_func(mlb.inverse_transform, [pred_ids], [tf.string] * params['batch_size'])
pred_strings = tf.convert_to_tensor(pred_strings, dtype=tf.string)
1 一种解决方案是将tf.py_func外加一层tf.map_fn,这样tf.py_func每次都只执行一个数据,Tout = tf.string就可以。
pred_strings = tf.map_fn(lambda x: tf.py_func(mlb.inverse_transform, [tf.expand_dims(x, 0)], tf.string), pred_ids, tf.string)
2 还有一种是在estimator外层的sess中执行,这样pred_ids就不是tensor,而是数值,可以直接使用python函数操作。
from: -柚子皮-
ref:
Tensorflow函数映射:py_func和map_fn相关推荐
- tensorflow函数API总结
tensorflow函数API总结: 首推官网查询 tf.keras.Input:创建输入层 别名: tf.keras.Input tf.keras.layers.Input tf.keras.Inp ...
- C++ 函数映射使用讲解
想想我们在遇到多语句分支时是不是首先想到的是 switc case 和 if else if ... 这2种方式在编码方面确实简单少,但是当分支达到一定数量后,特别是分支内部有嵌套大段代码或者再嵌套分 ...
- tensorflow函数记录
tensorflow函数记录 函数类型一 tf.stack() axis=0,1 tf.reduce_sum() axis=0,1 tf.concat() axis=0,1 功能快捷键 合理的创建标题 ...
- 【TensorFlow】TensorFlow函数精讲之tf.nn.max_pool()和tf.nn.avg_pool()
tf.nn.max_pool()和tf.nn.avg_pool()是TensorFlow中实现最大池化和平均池化的函数,在卷积神经网络中比较核心的方法. 有些和卷积很相似,可以参考TensorFlow ...
- 【TensorFlow】TensorFlow函数精讲之 tf.random_normal()
tf.trandom_normal()函数是生成正太分布随机值 此函数有别于tf.truncated_normal()正太函数,请参考本博客关于tf.truncated_normal()函数的介绍 ( ...
- 【TensorFlow】TensorFlow函数精讲之tf.truncated_normal()
tf.truncated_normal()函数是一种"截断"方式生成正太分布随机值,"截断"意思指生成的随机数值与均值的差不能大于两倍中误差,否则会重新生成. ...
- Linux系统mmap函数映射物理地址
Linux系统mmap函数映射物理地址 代码 64位报错 代码 在某些特殊情况下,我们只是想要读取某个寄存器的值或者某个地址的值,不需要去专门写一个驱动模块来实现,可以使用mmap函数配合/dev/m ...
- 从函数映射的角度理解矩阵
从函数映射的角度理解矩阵 0.预备知识 函数 是把一个集 "A" 的元素与另一个集 "B" 的元素配对的方法: 一般函数从 "A" 的每个元 ...
- tensorflow函数一览
tensorflow函数一览 注意: 本文介绍的是version 1.3.0下的函数,有些函数可能在旧版本中没有. 转载请注明出处变天式的博客 前言:名词解释 tensor 张量,可以是一个数,也可以 ...
- 利用BP网络实现非线性函数映射(基于matlab工具箱)
利用BP网络实现非线性函数映射(基于matlab工具箱) 目录 利用BP网络实现非线性函数映射(基于matlab工具箱) 一.网络结构 二.学习过程 三.学习结果 四.误差分析 五.实验总结 附录(源 ...
最新文章
- 存储过程如何处理异常
- 作业调度算法--短作业优先 操作系统_处理器管理_编程题
- 大数据Java基础第十九天作业
- 纪念9.11十周年 奥巴马诵读圣经原文
- 数据库备份还原顺序关系(环境:Microsoft SQL Server 2008 R2)
- 优酷电视剧爬虫代码实现一:下载解析视频网站页面(3)补充知识点:htmlcleaner使用案例...
- iou画 yolov3_专栏 | 【从零开始学习YOLOv3】4. YOLOv3中的参数进化
- 猜拳游戏php中Computer类,人机猜拳 (玩家、电脑、游戏、测试)四个类写法
- mysql报表慢_mysql慢查询日志报表工具mysqlsla
- Talos实验室深入我国DDoS黑市DuTe 揭露各种DDoS团伙、平台、工具及攻击
- j2ee和mysql怎么连接_Eclipse下配置j2ee开发环境及与MySQL数据库的连接
- 【肿瘤分割】基于matlab聚类乳腺肿瘤图像分割【含Matlab源码 1471期】
- 购物也能乐开花 淘宝搞笑评价集萃--2
- php返回token什么意思,token什么意思
- Nginx的rewrite(地址重定向)剖析
- 算法成华纳旗下歌手?背景音乐经济
- mysql 插入缓冲_innodb insert buffer 插入缓冲区的理解
- 我看技术人的成长路径
- MIPI介绍(CSI DSI接口)
- Android性能优化之APK瘦身详解(瘦身73%)