TensorFlow2-高阶操作(八):gather/gather_nd(已知元素的位置,从张量中提取该元素)、scatter_nd/scatter_nd_update(已知赋值位置,向0张量中赋值)
一、gather/gather_nd(已知元素的位置,从张量中提取该元素)
1、tf.gather()函数
tf.gather(params, # 传入的tensorindices, # 指定的索引validate_indices=None, # 不重要name=None, # 命名axis=None, # 指定轴batch_dims=0)
功能:就是抽取出params的第axis维度上在indices里面所有的index
- params是要查找的张量,indices是要查找值的索引(int32或int64),axis是查找轴,name是操作名。
- 如果indices是标量,output[a0,...,an,b0,...,bn]=params[a0,...an,indices,b0,...,bn]output[a_0,...,a_n,b_0,...,b_n] = params[a_0,...a_n,indices,b_0,...,b_n]output[a0,...,an,b0,...,bn]=params[a0,...an,indices,b0,...,bn];
- 如果indices是向量,output[a0,...,an,i,b0,...,bn]=params[a0,...an,indices[i],b0,...,bn]output[a_0,...,a_n,i,b_0,...,b_n] = params[a_0,...a_n,indices[i],b_0,...,b_n]output[a0,...,an,i,b0,...,bn]=params[a0,...an,indices[i],b0,...,bn];
- 如果indices是高阶张量,output[a0,...,an,i,...,j,b0,...,bn]=params[a0,...an,indices[i,...,j],b0,...,bn]output[a_0,...,a_n,i,...,j,b_0,...,b_n] = params[a_0,...a_n,indices[i,...,j],b_0,...,b_n]output[a0,...,an,i,...,j,b0,...,bn]=params[a0,...an,indices[i,...,j],b0,...,bn]
需要注意的是indices里面最大值需要小等于params在指定的axis下ndim的长度。
如上图所示,params一共6个维度,indices为[2,1,3,4]被取了出来。
该函数返回值类型与params相同,具体值是从params中收集过来的,形状为: params.shape[:axis]+indices.shape+params.shape[axis+1:]params.shape[:axis]+indices.shape+params.shape[axis+1:]params.shape[:axis]+indices.shape+params.shape[axis+1:]
1.1 indices是标量
import numpy as np
import tensorflow as tfc1 = tf.constant(np.random.randint(low=1, high=9, size=6))
print("c1 = ", c1)
print("-" * 100)g1 = tf.gather(c1, indices=2) # 获取索引为 2 的值
print("g1 = tf.gather(c1, indices=2) = ", g1)
print("-" * 200)
打印结果:
c1 = tf.Tensor([7 2 8 8 5 4], shape=(6,), dtype=int32)
----------------------------------------------------------------------------------------------------
g1 = tf.gather(c1, indices=2) = tf.Tensor(8, shape=(), dtype=int32)
1.2 indices是向量
import tensorflow as tfa = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_a = tf.Variable([2, 4, 6, 8])b = tf.Variable([[1, 2, 3, 4, 5],[6, 7, 8, 9, 10],[11, 12, 13, 14, 15]
])
index_b = tf.Variable([0, 2])print("a = \n", a)
print("-" * 100)
print("b = \n", b)
print("-" * 200)g_a = tf.gather(a, indices=index_a)
print("g_a = tf.gather(a, indices=index_a) = ", g_a)
print("-" * 200)g1 = tf.gather(b, indices=index_b, axis=0)
print("g1 = tf.gather(b, indices=index_b, axis=0) = ", g1)
print("-" * 100)g2 = tf.gather(b, indices=index_b, axis=1)
print("g2 = tf.gather(b, indices=index_b, axis=1) = ", g2)
print("-" * 200)
打印结果:
a =
<tf.Variable 'Variable:0' shape=(10,) dtype=int32, numpy=array([ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])>
----------------------------------------------------------------------------------------------------
b = <tf.Variable 'Variable:0' shape=(3, 5) dtype=int32, numpy=
array([[ 1, 2, 3, 4, 5],[ 6, 7, 8, 9, 10],[11, 12, 13, 14, 15]])>
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
g_a = tf.gather(a, indices=index_a) = tf.Tensor([3 5 7 9], shape=(4,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
g1 = tf.gather(b, indices=index_b, axis=0) = tf.Tensor(
[[ 1 2 3 4 5][11 12 13 14 15]], shape=(2, 5), dtype=int32)
----------------------------------------------------------------------------------------------------
g2 = tf.gather(b, indices=index_b, axis=1) = tf.Tensor(
[[ 1 3][ 6 8][11 13]], shape=(3, 2), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------Process finished with exit code 0
2、tf.gather_nd()函数
根据定义, 其主要功能是根据indices描述的索引,提取params上的元素, 重新构建一个tensor
tf.gather_nd(params, # 被收集的张量indices, # 索引张量。必须是以下类型之一:int32,int64。name=None, # 操作的名称(可选)batch_dims=0)
indices是 KKK 阶张量,包含 K−1K-1K−1 阶的索引值。它最后一阶是索引,最后一阶维度必须小于等于params的秩。
- indices最后一阶的维数等于params的秩时,我们得到params的某些元素;
- indices最后一阶的维数小于params的秩时,我们得到params的切片。
在一维数组中,元素的索引即该元素在数组中序号,通常序号从0开始标记。
如数组 ary=[1,2,3,4]:
- 元素2的索引 为 1, 元素的引用可表示为 [1];
- 元素3的索引为 2, 元素的引用可表示为 [2];
那么二维数组呢? 类似地,对于二维 ary=[ [1,2], [3,4] ],
- 元素 [1,2] 在一维中的索引为 [0],
- 元素 1 的索引 则表示为 [0,0],
- 元素 2 的索引 则表示为 [0,1],
因此 gather_nd 实现了根据指定的 参数 indices 来提取params 的元素重建出一个tensor,还是以上面的二维数组为例:
- [0,0] 表示 的是 1;
- [0,1] 表示的是 2;
当 indices=[[0,0],[0,1]]indices = [[0,0],[0,1]]indices=[[0,0],[0,1]] 时, 该函数的输出则为 [1,2][1,2][1,2],即 indices 中 表示索引的 部分 被提取到的值替换。
那么当indices 为[ [ [ [ [1,1] ] ] ] ] 时 函数输出是什么呢 ? 用元素 替换掉 表示索引的那一部分, 即可得到 [ [ [ [ 4 ] ] ] ]
例如:output[i0,...,iK−2]=params[indices[i0,...iK−2]]\color{blue}{output[i_0,...,i_{K-2}]=params[indices[i_0,...i_{K-2}]]}output[i0,...,iK−2]=params[indices[i0,...iK−2]]。输出张量的形状由indices的 K−1K-1K−1 阶和 params 索引到的形状拼接而成,形状为:indices.shape[:−1]+params.shape[indices.shape[−1]:]\color{blue}{indices.shape[:-1]+params.shape[indices.shape[-1]:]}indices.shape[:−1]+params.shape[indices.shape[−1]:]
tf.gather和tf.gather_nd都是从tensor中取出index标注的部分,不同之处在于,gather一般只使用一个index来标注,而gather_nd可以使用多个index。
import tensorflow as tfparams = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])gather = tf.constant([0, 2])
gather_nd = tf.constant([[0, 0], [1, 1]])gather_result = tf.gather(params=params, indices=gather)
gather_nd_result = tf.gather_nd(params=params, indices=gather_nd)print("gather_result = ", gather_result)
print("-" * 200)
print("gather_nd_result = ", gather_nd_result)
打印结果:
gather_result = tf.Tensor(
[[b'a' b'b'][b'e' b'f']], shape=(2, 2), dtype=string)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
gather_nd_result = tf.Tensor([b'a' b'd'], shape=(2,), dtype=string)
如何直观理解gather_nd的indices呢?
- 在上例中,直观的理解就是,gather_nd取出params中位于[0,0]和[1,1]处的tensor,放入index中对应的位置。
- 换句话说,除去tensor维之外,返回值的形状和indices相同,值由indices标注。
如果理解了这一点,就可以用gather_nd实现gather:
import tensorflow as tfparams = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])gather_nd = tf.constant([[0], [2]])
gather_nd_result = tf.gather_nd(params=params, indices=gather_nd)print("gather_nd_result = ", gather_nd_result)
打印结果:
gather_nd_result = tf.Tensor(
[[b'a' b'b'][b'e' b'f']], shape=(2, 2), dtype=string)
2.1 案例01
import tensorflow as tfdata = tf.constant([[1, 2], [3, 4], [5, 6]])
indices = tf.constant([[1], [0], [1]])print('data =\n', data)
print("-" * 50)
print('indices =\n', indices)
print("-" * 100)res = tf.gather_nd(data, indices)print('res =\n', res)
打印结果:
data =tf.Tensor(
[[1 2][3 4][5 6]], shape=(3, 2), dtype=int32)
--------------------------------------------------
indices =tf.Tensor(
[[1][0][1]], shape=(3, 1), dtype=int32)
----------------------------------------------------------------------------------------------------
res =tf.Tensor(
[[3 4][1 2][3 4]], shape=(3, 2), dtype=int32)Process finished with exit code 0
2.2 案例02
import tensorflow as tfdata = tf.constant([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
indices = tf.constant([[1, 0], [0, 2], [1, 2]])print('data =\n', data)
print("-" * 50)
print('indices =\n', indices)
print("-" * 100)res = tf.gather_nd(data, indices)print('res =\n', res)
打印结果:
data =tf.Tensor(
[[1 2 3][3 4 5][5 6 7]], shape=(3, 3), dtype=int32)
--------------------------------------------------
indices =tf.Tensor(
[[1 0][0 2][1 2]], shape=(3, 2), dtype=int32)
----------------------------------------------------------------------------------------------------
res =tf.Tensor([3 3 5], shape=(3,), dtype=int32)Process finished with exit code 0
三、tf.scatter_nd()函数:已知赋值位置,向0张量中赋值
根据indices索引位置将updates中的元素 散布 到新的(初始为零)张量shape中去。
- 根据索引对给定shape的零张量中的单个值或切片应用稀疏updates来创建新的张量。
- scatter_nd运算符是 tf.gather_nd 运算符的反函数,tf.gather_nd 运算符是从给定的张量中提取值或切片。
scatter_nd(indices,updates,shape,name=None)
- indices:一个Tensor;必须是以下类型之一:int32,int64;指数张量。
- updates:一个Tensor;分散到输出的更新。
- shape:一个Tensor;必须与indices具有相同的类型;1-d;得到的张量的形状。
- name:操作的名称(可选)。
警告:更新应用的顺序是非确定性的,所以如果indices包含重复项的话,则输出将是不确定的。
indices是一个整数张量,其中含有索引形成一个新的形状shape张量。indices的最后的维度可以是shape的最多的秩:
indices.shape[-1] <= shape.rank
indices的最后一个维度对应于沿着shape的indices.shape[-1]维度的元素的索引(if indices.shape[-1] = shape.rank)或切片(if indices.shape[-1] < shape.rank)的索引。updates是一个具有如下形状的张量:
indices.shape[:-1] + shape[indices.shape[-1]:]
1、案例01
最简单的分散形式是通过索引将单个元素插入到张量中。例如,假设我们想要在8个元素的1级张量中插入4个分散的元素。
import tensorflow as tfindices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])output= tf.scatter_nd(indices, updates, shape)
print("output = ", output)
打印结果:
output = tf.Tensor([ 0 11 0 10 9 0 0 12], shape=(8,), dtype=int32)
2、案例02
我们也可以一次插入一个更高阶张量的整个片。例如,如果我们想要在具有两个新值的矩阵的第三维张量中插入两个切片。
在Python中,这个分散操作看起来像这样:
import tensorflow as tfindices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],[7, 7, 7, 7], [8, 8, 8, 8]],[[5, 5, 5, 5], [6, 6, 6, 6],[7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
print("shape_zero = ", tf.zeros([4, 4, 4]))
print("-" * 200)output = tf.scatter_nd(indices, updates, shape)
print("output = ", output)
打印结果:
shape_zero = tf.Tensor(
[[[0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.]][[0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.]][[0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.]][[0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.]]], shape=(4, 4, 4), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
output = tf.Tensor(
[[[5 5 5 5][6 6 6 6][7 7 7 7][8 8 8 8]][[0 0 0 0][0 0 0 0][0 0 0 0][0 0 0 0]][[5 5 5 5][6 6 6 6][7 7 7 7][8 8 8 8]][[0 0 0 0][0 0 0 0][0 0 0 0][0 0 0 0]]], shape=(4, 4, 4), dtype=int32)Process finished with exit code 0
参考资料:
[tensorflow] tf.gather使用方法
tf.gather()函数详解
tf.gather( )的用法
tf.gather_nd和tf.gather的区别与联系
TensorFlow中gather, gather_nd, scatter, scatter_nd用法浅析
TensorFlow学习(三):tf.scatter_nd函数
Tensorflow (一): scatter_nd 与 gather_nd
Python tensorflow.scatter_nd方法代码示例
TensorFlow中tf.gather()函数的使用讲解
tf.gather_nd 用法
深度理解tf.gather和tf.gather_nd的用法
Python tensorflow.gather_nd()用法及代码示例
TensorFlow2-高阶操作(八):gather/gather_nd(已知元素的位置,从张量中提取该元素)、scatter_nd/scatter_nd_update(已知赋值位置,向0张量中赋值)相关推荐
- TensorFlow2 入门指南 | 06 TensorFLow2 高阶操作汇总
前言: 本专栏在保证内容完整性的基础上,力求简洁,旨在让初学者能够更快地.高效地入门TensorFlow2 深度学习框架.如果觉得本专栏对您有帮助的话,可以给一个小小的三连,各位的支持将是我创作的最大 ...
- 深度学习(17)TensorFlow高阶操作六: 高阶OP
深度学习(17)TensorFlow高阶操作六: 高阶OP 1. Where(tensor) 2. where(cond, A, B) 3. 1-D scatter_nd 4. 2-D scatter ...
- 深度学习(14)TensorFlow高阶操作三: 张量排序
深度学习(14)TensorFlow高阶操作三: 张量排序 一. Sort, argsort 1. 一维Tensor 2. 多维Tensor 二. Top_k 三. Top-k accuracy(To ...
- Tensorflow学习四---高阶操作
Tensorflow学习四-高阶操作 Merge and split 1.tf.concat 拼接 a = tf.ones([4,32,8]) b = tf.ones([2,32,8]) print( ...
- 深度学习(16)TensorFlow高阶操作五: 张量限幅
深度学习(16)TensorFlow高阶操作五: 张量限幅 1. clip_by_value 2. relu 3. clip_by_norm 4. Gradient clipping 5. 梯度爆炸实 ...
- 深度学习(15)TensorFlow高阶操作四: 填充与复制
深度学习(15)TensorFlow高阶操作四: 填充与复制 1. Pad 2. 常用于Image Padding 3. tile 4. tile VS broadcast_to Outline pa ...
- 深度学习(12)TensorFlow高阶操作一: 合并与分割
深度学习(12)TensorFlow高阶操作一: 合并与分割 1. concat 2. stack: create new dim 3. Dim mismatch 4. unstuck 5. spli ...
- zotero文献管理高阶操作|全网最新最全的zotero高效运用技巧,quicker动作大盘点
大家能都知道quicker是电脑端的效率神器,可以然你解放双手,提高你的办公效率,当然,也是科研效率神器,截止2022年10月16日,与zotero相关的动作已经有66个.今天就让让我盘点几个可以让你 ...
- PS高阶操作之字体特效
PS高阶操作之字体特效 字体冰封效果 字体金属样式 字体冰封效果 新建一个白色的像素画布. 新建好画布后,用油漆桶刷成深蓝色. 打开通道,新建一个通道,选择文字工具. 调整文字大小和位置.在菜单栏中选 ...
最新文章
- PHP简单封装MysqlHelper类
- 《ADO.NET 2.0高级程序设计》读书随笔(1)使用连接池connection pool
- 用 Hadoop 进行分布式并行编程, 第 3 部分 部署到分布式环境
- 关于input的change事件触发多次发解决
- python 数据库表结构转为类_顺序表数据结构在python中的应用
- python梯度下降法实现线性回归_梯度下降法的python代码实现(多元线性回归)
- POJChallengeRound2 Guideposts 【单位根反演】【快速幂】
- Log4net数据表
- mysql connector net 6.9.3_MySQL Connector/Net 6.9.3 发布 MySQL Connector/Net 6.9.3下载
- linux 如何查看属性,linux 下查看系统属性
- 这样查看告警邮件要慢一点……
- java提前多久显示,Java当前日期/时间比原始时间提前1小时显示
- python程序内存分析_Python中使用MELIAE分析程序内存占用实例
- 贪心算法3——加油站问题
- 爱分享 IE地址栏显示空白?360电脑救援巧修复
- AppScan安全扫描问题解决方案
- 自编小程序,保持编程达人眼睛
- 基于OpenCV 的车道线检测方法
- 在控制台,打印出某个具体的变量,并监听其变化
- 《开源软件开发导论》作业1