TensorFlow常用函数tf.where()、tf.gather()、tf.squeeze()详解!!
1.tf.where()
第一种用法:
- where(condition)的用法
where(condition, x=None, y=None, name=None)
condition是bool型值,True/False
返回值,是condition中元素为True对应的索引
例如:
import tensorflow as tf
a = [[1,2,3],[4,5,6]]
b = [[1,0,3],[1,5,1]]
condition1 = [[True,False,False],[False,True,True]]
condition2 = [[True,False,False],[False,True,False]]
with tf.Session() as sess:print(sess.run(tf.where(condition1)))print(sess.run(tf.where(condition2)))
结果1:
[[0 0][1 1][1 2]]
结果2:
[[0 0][1 1]]
第二种用法:
where(condition, x=None, y=None, name=None)
condition是bool型值,True/False, x, y 相同维度,
返回值是对应元素,condition中元素为True的元素替换为x中的元素,为False的元素替换为y中对应元素。
x只负责对应替换True的元素,y只负责对应替换False的元素,x,y各有分工
由于是替换,返回值的维度,和condition,x , y都是相等的。
import tensorflow as tf
x = [[1,2,3],[4,5,6]]
y = [[7,8,9],[10,11,12]]
condition3 = [[True,False,False],[False,True,True]]
condition4 = [[True,False,False],[True,True,False]]
with tf.Session() as sess:print(sess.run(tf.where(condition3,x,y)))print(sess.run(tf.where(condition4,x,y)))
结果:
[[ 1 8 9]
[10 5 6]]
[[ 1 8 9]
[ 4 5 12]]
第三种用法:
tf.where(tf.greater(A, B), a, b)
tf.greater(a,b)
功能:通过比较a、b两个值的大小来输出True和False。
where会先判断第一项是否为true,如果为true则返回a;否则返回b;而greater则是比较A是否大于B,是的话返回true;否则返回false
2.tf.gather()
我们知道,ndarray和list都可以直接通过索引进行切片,但tensor却不行。不过TensorFlow提供了多个函数来进行张量切片,tf.gather()就是其中一种,其调用形式如下:
tf.gather(params, indices, validate_indices=None, name=None, axis=0)
参数:
- params:要进行切片的ndarray或list或tensor等
- indices:索引向量,其类型可以是ndarray、list、tensor等
- axis : 对哪个轴进行切片
函数功能:
从’params’的’axis’维根据’indices’的参数值获取切片。就是在axis维根据indices取某些值,最终得到新的tensor
示例:
1. params 的维数为1
import tensorflow as tf
import numpy as np
# params = np.random.randint(1, 10, 5)
# params = [2, 3, 4, 5, 6, 7]
params = tf.constant([2, 3, 4, 5, 6, 7])
# indices = np.array([2, 1, 4, 2])
# indices = [2, 1, 4, 2]
indices = tf.constant([2, 1, 4, 2])
tensor1 = tf.gather(params, indices)
with tf.Session() as sess:# print(params)print(sess.run(params))print(sess.run(tensor1))
结果:
[2 3 4 5 6 7]
[4 3 6 4]#分析:根据indices逐一取出params中对应索引的元素,并组成新的张量。
2. params 的维数为2
import tensorflow as tf
import numpy as npparams = np.random.randint(1, 10, (4, 5))
indices = tf.constant([2, 1, 0, 2])
tensor0 = tf.gather(params, indices, axis=0)
tensor1 = tf.gather(params, indices, axis=1)
with tf.Session() as sess:print('params =', params)print('tensor0 =', sess.run(tensor0))print('tensor1 =', sess.run(tensor1))
结果:
params = [[5 1 4 7 2][1 8 9 1 7][2 1 8 7 2][8 9 5 8 7]]
tensor0 = [[2 1 8 7 2][1 8 9 1 7][5 1 4 7 2][2 1 8 7 2]]
tensor1 = [[4 1 5 4][9 8 1 9][8 1 2 8][5 9 8 5]]
对于二维params,
当indices是标量且是张量时,得到的结果不会降维;
当indices是标量且是ndarray时,得到的结果会降维。
import tensorflow as tf
import numpy as npparams = np.random.randint(1, 10, (3, 4))
indices1 = tf.constant([2])
indices2 = 2
tensor1 = tf.gather(params, indices1, axis=0)
tensor2 = tf.gather(params, indices2, axis=0)
with tf.Session() as sess:print('params =', params)print('tensor1 =', sess.run(tensor1))print('tensor2 =', sess.run(tensor2))
结果:
params = [[9 2 1 7][7 8 2 3][9 7 2 9]]
tensor1 = [[9 7 2 9]]
tensor2 = [9 7 2 9]
3.tf.squeeze()
tf.squeeze(input, axis=None, name=None, squeeze_dims=None)
该函数返回一个张量,这个张量是将原始input中所有维度为1的那些维都删掉的结果。
axis可以用来指定要删掉的为1的维度,此处要注意指定的维度必须确保其是1,否则会报错。
import tensorflow as tf
input_tensor = tf.ones((2, 1, 1, 3, 2))
new_tensor1 = tf.squeeze(input_tensor)
new_tensor2 = tf.squeeze(input_tensor, [1])
with tf.Session() as sess:print(sess.run(tf.shape(input_tensor)))print(sess.run(tf.shape(new_tensor1)))print(sess.run(tf.shape(new_tensor2)))
结果:
[2 1 1 3 2]
[2 3 2]
[2 1 3 2]
附加:
tf.less()、tf.greater()、tf.equal()等比较函数
这几个函数用于逐元素比较两个张量的大小,并返回比较结果(True or False)构成的布尔型张量。下面以tf.less()为例:
tf.less(x, y, name=None)
tf.less()返回了两个张量各元素比较(x<y)得到的真假值组成的张量。
提示:
- tf.less()支持broadcast机制;
- tf.less(x, y)中的 x 和 y 可以是tensor、ndarray、list等。
x = tf.constant([[1, 2, 3], [4, 5, 6]])
y1 = tf.constant([[2, 1, 2], [2, 6, 7]])
y2 = tf.constant([3, 6, 9])
y3 = tf.constant([3])
with tf.Session() as sess:print(sess.run(tf.less(x, y1)))print(sess.run(tf.less(x, y2)))print(sess.run(tf.less(x, y3)))
结果:
[[ True False False][False True True]]
[[ True True True][False True True]]
[[ True True False][False False False]]
总结:
- tf.less(x, y) —— x < y 为True
- tf.equal(x, y) —— x == y 为True
- tf.greater(x, y) —— x > y 为True
- tf.greater_equal(x, y) —— x >= y 为True
- tf.less_equal(x, y) —— x <= y 为True
TensorFlow常用函数tf.where()、tf.gather()、tf.squeeze()详解!!相关推荐
- viper4android io错误,golang常用库之配置文件解析库-viper使用详解
一.viper简介 viper 配置管理解析库,是由大神 Steve Francia 开发,他在google领导着 golang 的产品开发,他也是 gohugo.io 的创始人之一,命令行解析库 c ...
- python装饰器函数-Python函数装饰器常见使用方法实例详解
本文实例讲述了Python函数装饰器常见使用方法.分享给大家供大家参考,具体如下: 一.装饰器 首先,我们要了解到什么是开放封闭式原则? 软件一旦上线后,对修改源代码是封闭的,对功能的扩张是开放的,所 ...
- python跨函数调用变量_对python中不同模块(函数、类、变量)的调用详解
首先,先介绍两种引入模块的方法. 法一:将整个文件引入 import 文件名 文件名.函数名( ) / 文件名.类名 通过这个方法可以运行另外一个文件里的函数 法二:只引入某个文件中一个类/函数/变量 ...
- python中search和match的区别_Python中正则表达式match()、search()函数及match()和search()的区别详解...
match()和search()都是python中的正则匹配函数,那这两个函数有何区别呢? match()函数只检测RE是不是在string的开始位置匹配, search()会扫描整个string查找 ...
- python函数定义及调用-python函数声明和调用定义及原理详解
这篇文章主要介绍了python函数声明和调用定义及原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 函数是指代码片段,可以重复调用,比如我们前 ...
- html5走格子游戏,JS/HTML5游戏常用算法之碰撞检测 地图格子算法实例详解
JS/HTML5游戏常用算法之碰撞检测 地图格子算法实例详解 发布时间:2020-09-26 20:42:24 来源:脚本之家 阅读:112 作者:krapnik 本文实例讲述了JS/HTML5游戏常 ...
- python函数声明和调用定义及原理详解
这篇文章主要介绍了python函数声明和调用定义及原理详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 函数是指代码片段,可以重复调用,比如我们前 ...
- python函数中可变参数的传递方式_详解Python函数可变参数定义及其参数传递方式...
Python函数可变参数定义及其参数传递方式详解 python中 函数不定参数的定义形式如下 1. func(*args) 传入的参数为以元组形式存在args中,如: def func(*args): ...
- 记录 之 tensorflow 常用函数:tf.split(),tf.clip_by_value() 和 tf.cond()
1.tf.split(axis, num_or_size_splits,value) 该函数是通道拆分函数,将原来的的多通道tensor,拆分为单通道 axis:拆分的维度 num_or_size_s ...
最新文章
- Android Studio 多渠道打包
- 第08次:升级《陋习手记》完善主从UI
- Base64加密解密算法的C/C++代码实现
- ora--12154 :TNS :could not resolve the connect identifier specified 错误处理
- cisco路由器故障判断及排除 计算机管理与维护
- nginx 没有cookie_Nginx入门学习(1):一些概念
- android xml反编译原理,记一次resources.arsc文件hex修改原理分析
- java常用设计模式详解及应用
- 10个免费网络管理工具
- 在DW中如何让代码对齐?
- ZFM_RFC_FIDOC-创建财务凭证-BAPI_ACC_DOCUMENT_CHECK/BAPI_ACC_DOCUMENT_POST/POSTING_INTERFACE_DOCUMENT
- 智联招聘 'python数据分析'职位分析第一篇
- 练习题目---光照度
- macos 系统固件 路径_itunes下载固件在哪里 itunes下载固件位置【介绍】
- 快速云:云计算供应商在合同谈判时可能拒绝的三个事项以及要求
- 1.Python下载与安装教程 For Windows
- 3D重建中的相机雷达融合
- python读取文件名或路径含中文字符的图片并从中筛选出全白或者全黑的图片
- Generative Adversarial Networks(WGAN、SAGAN、BigGAN)
- 【IntelliJ IDEA】如何安装汉化插件
热门文章
- error怎么开机 fan_笔记本开机显示fan error怎么解决?
- 肖申克的救赎(转贴)
- 他,生物系毕业,刚入职连Java都没听过,却在马云的要求下,三周写出淘宝网雏形...
- icomoon 下载及使用
- 推荐几个优质的 Python 学习资料(良心推荐!非广告!)
- Spark创建DataFrame
- 栅格地图中自由区域之Bresenham算法及个人搜索算法对比
- 带您了解企业云盘,互联网大数据下的产物
- C++ OJ 出现 Wrong Answer的解决方法:如何把输出结果写入到文件中
- 金蝶eas显示连接服务器超时,金蝶EAS常见问题解答_工具及框架应用_2016