2019独角兽企业重金招聘Python工程师标准>>>

本文讲述如下问题:

1.如何定义list类型的placeholder?

2.如何将普通python函数包装成TensorFlow算子,加入到NN网络中?

具体见代码:

import tensorflow as tf
import numpy as npdef gen_tfrecords():with tf.python_io.TFRecordWriter(r"D:\my.tfrecords") as tf_writer:features = {}features['scale'] = tf.train.Feature(int64_list=tf.train.Int64List(value=[15]))xmin = []xmax = []ymin = []ymax = []for i in range(2):xmin.append(float(i))xmax.append(float(i+500))ymin.append(float(i))ymax.append(float(i+500))# 变长数据以list形式存储features['xmin'] = tf.train.Feature(float_list=tf.train.FloatList(value=xmin))features['xmax'] = tf.train.Feature(float_list=tf.train.FloatList(value=xmax))features['ymin'] = tf.train.Feature(float_list=tf.train.FloatList(value=ymin))features['ymax'] = tf.train.Feature(float_list=tf.train.FloatList(value=ymax))tf_features = tf.train.Features(feature=features)tf_example = tf.train.Example(features=tf_features)tf_serialized = tf_example.SerializeToString()tf_writer.write(tf_serialized)
gen_tfrecords()
def parse_tf(example_proto):dics = {}#定长数据解析dics['scale'] = tf.FixedLenFeature(shape=[], dtype=tf.int64)#列表数据解析dics['xmin'] = tf.VarLenFeature(tf.float32)dics['xmax'] = tf.VarLenFeature(tf.float32)dics['ymin'] = tf.VarLenFeature(tf.float32)dics['ymax'] = tf.VarLenFeature(tf.float32)parse_example = tf.parse_single_example(serialized=example_proto,features=dics)xmin = parse_example['xmin']xmax = parse_example['xmax']ymin = parse_example['ymin']ymax = parse_example['ymax']scale = parse_example['scale']return scale,xmin,xmax,ymin,ymaxdef scale_image(scale):w = 10h = 10w = w*scaleh = h*scalereturn w,h
def scale_image2(scale):w = 10h = 10w = w*scaleh = h*scalelst = [w,h]#如果想要返回一个list,需要将其封装为一个ndarrayreturn np.array(lst)
def image_s(scale):w = 10h = 10w = w*scaleh = h*scalereturn w*hdef calc_image_s(xmin,xmax,ymin,ymax):ss = []for i in range(len(xmin)):s = (xmax[i]-xmin[i])*(ymax[i]-ymin[i])ss.append(s)return np.array(ss)scale_p = tf.placeholder(dtype=tf.int64)
#如果placeholder的shape不写,则可表示各种类型的数据,这里可用于表示list类型的数据
x_min_p = tf.placeholder(dtype=tf.float32)
x_max_p = tf.placeholder(dtype=tf.float32)
y_min_p = tf.placeholder(dtype=tf.float32)
y_max_p = tf.placeholder(dtype=tf.float32)#tf.py_func用来将不同python函数包裹成TensorFlow算子,返回值是tensor,Tout表示函数返回值的类型,单个返回值不用[],多个返回值,要用[]
nw,nh = tf.py_func(scale_image,inp=[scale_p],Tout=[tf.int64,tf.int64])
nw_nh = tf.py_func(scale_image2,inp=[scale_p],Tout=tf.int64)
s = tf.py_func(image_s,inp=[scale_p],Tout=tf.int64)
ss= tf.py_func(calc_image_s,inp=[x_min_p,x_max_p,y_min_p,y_max_p],Tout=tf.float32)two = tf.constant(value=2,dtype=tf.float32)
s2 = tf.multiply(ss,two)dataset = tf.data.TFRecordDataset(r"D:\my.tfrecords")
dataset = dataset.map(parse_tf).batch(1).repeat(1)iterator = dataset.make_one_shot_iterator()next_element = iterator.get_next()
with tf.Session() as session:scale, xmin, xmax, ymin, ymax = session.run(fetches=next_element)w,h = session.run(fetches=[nw,nh],feed_dict={scale_p:scale})print(w,h)w_h = session.run(fetches=[nw_nh], feed_dict={scale_p: scale})print(w_h)s1 = session.run(fetches=[s], feed_dict={scale_p: scale})print(s1)s1 = session.run(fetches=[ss], feed_dict={x_min_p:xmin.values,x_max_p:xmax.values,y_min_p:ymin.values,y_max_p:ymax.values})print(s1)s22 = session.run(fetches=[s2], feed_dict={x_min_p: xmin.values, x_max_p: xmax.values, y_min_p: ymin.values, y_max_p: ymax.values})print(s22)

结果如下:

[150] [150]
[array([[150],[150]], dtype=int64)]
[array([22500], dtype=int64)]
[array([250000., 250000.], dtype=float32)]
[array([500000., 500000.], dtype=float32)]

定义新的op时会用到该方法,据官网介绍,这种做法,目前不支持分布式与模型保存。但是对于辅助op,基本上够用了,例如faster r-cnn中anchor的生成与RPN训练时label的生成。

tf.py_func 要求包裹的函数,输入输出均为ndarray

转载于:https://my.oschina.net/u/3800567/blog/1794223

【TensorFlow系列】【九】利用tf.py_func自定义算子相关推荐

  1. Tensorflow之调试(Debug) tf.py_func()

    Tensorflow之调试(Debug)及打印变量 tensorflow调试tfdbg 几种常用方法: 1.通过Session.run()获取变量的值 2.利用Tensorboard查看一些可视化统计 ...

  2. tensorflow与python交互系列,tf.py_function()、tf.py_func、tf.numpy_function()(一)

    前言:前面在介绍使用tensorflow进行data pipeline的时候,遇到了一些问题,特意整理了两篇文章,请参见: tfrecord文件的map在使用的时候所踩的坑总结(map.py_func ...

  3. Tensorflow利用函数修饰符@tf.custom_gradients自定义函数梯度

    Tensorflow学习笔记(1) 利用函数修饰符@tf.custom_gradients自定义函数梯度_寂乐居士的博客-CSDN博客_tf.custom_gradient python中的修饰符以及 ...

  4. Tensorflow高级API的进阶--利用tf.contrib.learn建立输入函数

    正文共5958个字,预计阅读时间15分钟. 笔记整理者:王小草 笔记整理时间:2017年2月27日 笔记对应的官方文档:https://www.tensorflow.org/get_started/i ...

  5. tensorflow tf.py_func

    tf.py_func 在 faster  rcnn的tensorflow 实现中看到这个函数 1 rois,rpn_scores=tf.py_func(proposal_layer,[rpn_cls_ ...

  6. TF之LiR:利用TF自定义一个线性分类器LiR对乳腺癌肿瘤数据集进行二分类预测(良/恶性)

    TF之LiR:利用TF自定义一个线性分类器LiR对乳腺癌肿瘤数据集进行二分类预测(良/恶性) 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 X_train = np.float32 ...

  7. TF之VGG系列:利用预先编制好的脚本data_convert .py文件将图片格式转换为tfrecord 格式

    TF之VGG系列:利用预先编制好的脚本data_convert .py文件将图片格式转换为tfrecord 格式 目录 转换代码 转换后的结果 转换代码 python data_convert2tfr ...

  8. 【转载】使用tf.py_func函数增加Tensorflow程序的灵活性

    转自:https://blog.csdn.net/jiongnima/article/details/80555387 目录 tf.py_func函数接口 tf.py_func在Faster R-CN ...

  9. Tensorflow深度学习之二十五:tf.py_func

    一.简介 def py_func(func, inp, Tout, stateful=True, name=None)   该函数重构一个python函数,并将其作为一个TensorFlow的op使用 ...

  10. 经验干货:使用tf.py_func函数增加Tensorflow程序的灵活性

    不知不觉,笔者接触Tensorflow也满一年了.在这一年当中,笔者对Tensorflow的了解程度也逐渐加深.相比笔者接触的第一个深度学习框架Caffe而言,笔者认为Tensorflow更适合科研一 ...

最新文章

  1. PTA基础编程题目集-7-4 BCD解密
  2. 5月书讯:藏一个愿望等风来
  3. 理解JSON对象:JSON.parse、 JSON.stringify
  4. GitHub网站使用的基础入门
  5. 【机器视觉】 stop算子
  6. 无法定位程序输入点 在.exe上_win7提示explorer.exe应用程序错误的解决方法
  7. java怎么进行浮点数运算_【考试经验】Java中实现浮点数的精确运算
  8. 【算法】1282. 用户分组(多语言实现)
  9. ESXi虚拟机磁盘格式转换与减小硬盘容量的方法
  10. 基金指数温度怎么算_10分钟学会计算指数温度,挑选指数基金
  11. 非主流闪图头像教程:扩散粒子效果
  12. SAP ABAP——SAP简介(二)【SAP主要产品时间线】
  13. 3.2_backpack_背包问题
  14. 没有伪装和欺骗才能活在当下
  15. JavaScript框架比较:AngularJS vs ReactJS vs EmberJS
  16. Go语言——Json处理
  17. matlab编程实现自适应均值滤波和自适应中值滤波
  18. 计算机网络期末复习总结
  19. 统计字符串中数字字符的个数
  20. Mybatis常见低级错误

热门文章

  1. k8s apollo_AI增强的Apollo 16素材让您以4K登上月球
  2. ai人工智能有哪些_进入AI有多么简单
  3. ai带来的革命_AI革命就在这里。 这与我们预期的不同。
  4. ai/ml_十大ML / AI现实世界项目,以增强您的产品组合
  5. 上海市二级c语言软件环境,上海市2019年9月计算机二级考试复习教程:(C语言)上机考试新版题库+全真模拟试卷(2本装)...
  6. 计算机数字媒体学什么以后,数字媒体设计是学什么的?以后的发展方向是什么?...
  7. 大专适合学习php么_中专毕业上大专好还是出来工作?
  8. java打印jsp_在java中实现对FORM的打印功能
  9. 智能翻译android,离线翻译SDK,让智能小设备如虎添翼
  10. Android中获取IMEI码及其它相关信息的源码