一、gather/gather_nd(已知元素的位置,从张量中提取该元素)

1、tf.gather()函数

tf.gather(params,                    # 传入的tensorindices,             # 指定的索引validate_indices=None,  # 不重要name=None,                # 命名axis=None,             # 指定轴batch_dims=0)

功能:就是抽取出params的第axis维度上在indices里面所有的index

  • params是要查找的张量,indices是要查找值的索引(int32或int64),axis是查找轴,name是操作名。
  • 如果indices是标量,output[a0,...,an,b0,...,bn]=params[a0,...an,indices,b0,...,bn]output[a_0,...,a_n,b_0,...,b_n] = params[a_0,...a_n,indices,b_0,...,b_n]output[a0​,...,an​,b0​,...,bn​]=params[a0​,...an​,indices,b0​,...,bn​];
  • 如果indices是向量,output[a0,...,an,i,b0,...,bn]=params[a0,...an,indices[i],b0,...,bn]output[a_0,...,a_n,i,b_0,...,b_n] = params[a_0,...a_n,indices[i],b_0,...,b_n]output[a0​,...,an​,i,b0​,...,bn​]=params[a0​,...an​,indices[i],b0​,...,bn​];
  • 如果indices是高阶张量,output[a0,...,an,i,...,j,b0,...,bn]=params[a0,...an,indices[i,...,j],b0,...,bn]output[a_0,...,a_n,i,...,j,b_0,...,b_n] = params[a_0,...a_n,indices[i,...,j],b_0,...,b_n]output[a0​,...,an​,i,...,j,b0​,...,bn​]=params[a0​,...an​,indices[i,...,j],b0​,...,bn​]

需要注意的是indices里面最大值需要小等于params在指定的axis下ndim的长度。

如上图所示,params一共6个维度,indices为[2,1,3,4]被取了出来。

该函数返回值类型与params相同,具体值是从params中收集过来的,形状为: params.shape[:axis]+indices.shape+params.shape[axis+1:]params.shape[:axis]+indices.shape+params.shape[axis+1:]params.shape[:axis]+indices.shape+params.shape[axis+1:]

1.1 indices是标量

import numpy as np
import tensorflow as tfc1 = tf.constant(np.random.randint(low=1, high=9, size=6))
print("c1 = ", c1)
print("-" * 100)g1 = tf.gather(c1, indices=2)  # 获取索引为 2 的值
print("g1 = tf.gather(c1, indices=2) = ", g1)
print("-" * 200)

打印结果:

c1 =  tf.Tensor([7 2 8 8 5 4], shape=(6,), dtype=int32)
----------------------------------------------------------------------------------------------------
g1 = tf.gather(c1, indices=2) =  tf.Tensor(8, shape=(), dtype=int32)

1.2 indices是向量

import tensorflow as tfa = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
index_a = tf.Variable([2, 4, 6, 8])b = tf.Variable([[1, 2, 3, 4, 5],[6, 7, 8, 9, 10],[11, 12, 13, 14, 15]
])
index_b = tf.Variable([0, 2])print("a = \n", a)
print("-" * 100)
print("b = \n", b)
print("-" * 200)g_a = tf.gather(a, indices=index_a)
print("g_a = tf.gather(a, indices=index_a) = ", g_a)
print("-" * 200)g1 = tf.gather(b, indices=index_b, axis=0)
print("g1 = tf.gather(b, indices=index_b, axis=0) = ", g1)
print("-" * 100)g2 = tf.gather(b, indices=index_b, axis=1)
print("g2 = tf.gather(b, indices=index_b, axis=1) = ", g2)
print("-" * 200)

打印结果:

a =
<tf.Variable 'Variable:0' shape=(10,) dtype=int32, numpy=array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10])>
----------------------------------------------------------------------------------------------------
b = <tf.Variable 'Variable:0' shape=(3, 5) dtype=int32, numpy=
array([[ 1,  2,  3,  4,  5],[ 6,  7,  8,  9, 10],[11, 12, 13, 14, 15]])>
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
g_a = tf.gather(a, indices=index_a) =  tf.Tensor([3 5 7 9], shape=(4,), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
g1 = tf.gather(b, indices=index_b, axis=0) =  tf.Tensor(
[[ 1  2  3  4  5][11 12 13 14 15]], shape=(2, 5), dtype=int32)
----------------------------------------------------------------------------------------------------
g2 = tf.gather(b, indices=index_b, axis=1) =  tf.Tensor(
[[ 1  3][ 6  8][11 13]], shape=(3, 2), dtype=int32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------Process finished with exit code 0

2、tf.gather_nd()函数

根据定义, 其主要功能是根据indices描述的索引,提取params上的元素, 重新构建一个tensor

tf.gather_nd(params,             # 被收集的张量indices,            # 索引张量。必须是以下类型之一:int32,int64。name=None,  # 操作的名称(可选)batch_dims=0)

indices是 KKK 阶张量,包含 K−1K-1K−1 阶的索引值。它最后一阶是索引,最后一阶维度必须小于等于params的秩。

  • indices最后一阶的维数等于params的秩时,我们得到params的某些元素;
  • indices最后一阶的维数小于params的秩时,我们得到params的切片。

在一维数组中,元素的索引即该元素在数组中序号,通常序号从0开始标记。

如数组 ary=[1,2,3,4]:

  • 元素2的索引 为 1, 元素的引用可表示为 [1];
  • 元素3的索引为 2, 元素的引用可表示为 [2];

那么二维数组呢? 类似地,对于二维 ary=[ [1,2], [3,4] ],

  • 元素 [1,2] 在一维中的索引为 [0],
  • 元素 1 的索引 则表示为 [0,0],
  • 元素 2 的索引 则表示为 [0,1],

因此 gather_nd 实现了根据指定的 参数 indices 来提取params 的元素重建出一个tensor,还是以上面的二维数组为例:

  • [0,0] 表示 的是 1;
  • [0,1] 表示的是 2;

当 indices=[[0,0],[0,1]]indices = [[0,0],[0,1]]indices=[[0,0],[0,1]] 时, 该函数的输出则为 [1,2][1,2][1,2],即 indices 中 表示索引的 部分 被提取到的值替换。

那么当indices 为[ [ [ [ [1,1] ] ] ] ] 时 函数输出是什么呢 ? 用元素 替换掉 表示索引的那一部分, 即可得到 [ [ [ [ 4 ] ] ] ]

例如:output[i0,...,iK−2]=params[indices[i0,...iK−2]]\color{blue}{output[i_0,...,i_{K-2}]=params[indices[i_0,...i_{K-2}]]}output[i0​,...,iK−2​]=params[indices[i0​,...iK−2​]]。输出张量的形状由indices的 K−1K-1K−1 阶和 params 索引到的形状拼接而成,形状为:indices.shape[:−1]+params.shape[indices.shape[−1]:]\color{blue}{indices.shape[:-1]+params.shape[indices.shape[-1]:]}indices.shape[:−1]+params.shape[indices.shape[−1]:]

tf.gather和tf.gather_nd都是从tensor中取出index标注的部分,不同之处在于,gather一般只使用一个index来标注,而gather_nd可以使用多个index。

import tensorflow as tfparams = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])gather = tf.constant([0, 2])
gather_nd = tf.constant([[0, 0], [1, 1]])gather_result = tf.gather(params=params, indices=gather)
gather_nd_result = tf.gather_nd(params=params, indices=gather_nd)print("gather_result = ", gather_result)
print("-" * 200)
print("gather_nd_result = ", gather_nd_result)

打印结果:

gather_result =  tf.Tensor(
[[b'a' b'b'][b'e' b'f']], shape=(2, 2), dtype=string)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
gather_nd_result =  tf.Tensor([b'a' b'd'], shape=(2,), dtype=string)

如何直观理解gather_nd的indices呢?

  • 在上例中,直观的理解就是,gather_nd取出params中位于[0,0]和[1,1]处的tensor,放入index中对应的位置。
  • 换句话说,除去tensor维之外,返回值的形状和indices相同,值由indices标注。

如果理解了这一点,就可以用gather_nd实现gather:

import tensorflow as tfparams = tf.constant([['a', 'b'], ['c', 'd'], ['e', 'f']])gather_nd = tf.constant([[0], [2]])
gather_nd_result = tf.gather_nd(params=params, indices=gather_nd)print("gather_nd_result = ", gather_nd_result)

打印结果:

gather_nd_result =  tf.Tensor(
[[b'a' b'b'][b'e' b'f']], shape=(2, 2), dtype=string)

2.1 案例01

import tensorflow as tfdata = tf.constant([[1, 2], [3, 4], [5, 6]])
indices = tf.constant([[1], [0], [1]])print('data =\n', data)
print("-" * 50)
print('indices =\n', indices)
print("-" * 100)res = tf.gather_nd(data, indices)print('res =\n', res)

打印结果:

data =tf.Tensor(
[[1 2][3 4][5 6]], shape=(3, 2), dtype=int32)
--------------------------------------------------
indices =tf.Tensor(
[[1][0][1]], shape=(3, 1), dtype=int32)
----------------------------------------------------------------------------------------------------
res =tf.Tensor(
[[3 4][1 2][3 4]], shape=(3, 2), dtype=int32)Process finished with exit code 0

2.2 案例02

import tensorflow as tfdata = tf.constant([[1, 2, 3], [3, 4, 5], [5, 6, 7]])
indices = tf.constant([[1, 0], [0, 2], [1, 2]])print('data =\n', data)
print("-" * 50)
print('indices =\n', indices)
print("-" * 100)res = tf.gather_nd(data, indices)print('res =\n', res)

打印结果:

data =tf.Tensor(
[[1 2 3][3 4 5][5 6 7]], shape=(3, 3), dtype=int32)
--------------------------------------------------
indices =tf.Tensor(
[[1 0][0 2][1 2]], shape=(3, 2), dtype=int32)
----------------------------------------------------------------------------------------------------
res =tf.Tensor([3 3 5], shape=(3,), dtype=int32)Process finished with exit code 0

三、tf.scatter_nd()函数:已知赋值位置,向0张量中赋值

根据indices索引位置将updates中的元素 散布 到新的(初始为零)张量shape中去。

  • 根据索引对给定shape的零张量中的单个值或切片应用稀疏updates来创建新的张量。
  • scatter_nd运算符是 tf.gather_nd 运算符的反函数,tf.gather_nd 运算符是从给定的张量中提取值或切片。

scatter_nd(indices,updates,shape,name=None)

  • indices:一个Tensor;必须是以下类型之一:int32,int64;指数张量。
  • updates:一个Tensor;分散到输出的更新。
  • shape:一个Tensor;必须与indices具有相同的类型;1-d;得到的张量的形状。
  • name:操作的名称(可选)。

警告:更新应用的顺序是非确定性的,所以如果indices包含重复项的话,则输出将是不确定的。

indices是一个整数张量,其中含有索引形成一个新的形状shape张量。indices的最后的维度可以是shape的最多的秩:

indices.shape[-1] <= shape.rank

indices的最后一个维度对应于沿着shape的indices.shape[-1]维度的元素的索引(if indices.shape[-1] = shape.rank)或切片(if indices.shape[-1] < shape.rank)的索引。updates是一个具有如下形状的张量:

indices.shape[:-1] + shape[indices.shape[-1]:]

1、案例01

最简单的分散形式是通过索引将单个元素插入到张量中。例如,假设我们想要在8个元素的1级张量中插入4个分散的元素。

import tensorflow as tfindices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
shape = tf.constant([8])output= tf.scatter_nd(indices, updates, shape)
print("output = ", output)

打印结果:

output =  tf.Tensor([ 0 11  0 10  9  0  0 12], shape=(8,), dtype=int32)

2、案例02

我们也可以一次插入一个更高阶张量的整个片。例如,如果我们想要在具有两个新值的矩阵的第三维张量中插入两个切片。

在Python中,这个分散操作看起来像这样:

import tensorflow as tfindices = tf.constant([[0], [2]])
updates = tf.constant([[[5, 5, 5, 5], [6, 6, 6, 6],[7, 7, 7, 7], [8, 8, 8, 8]],[[5, 5, 5, 5], [6, 6, 6, 6],[7, 7, 7, 7], [8, 8, 8, 8]]])
shape = tf.constant([4, 4, 4])
print("shape_zero = ", tf.zeros([4, 4, 4]))
print("-" * 200)output = tf.scatter_nd(indices, updates, shape)
print("output = ", output)

打印结果:

shape_zero =  tf.Tensor(
[[[0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.]][[0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.]][[0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.]][[0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.][0. 0. 0. 0.]]], shape=(4, 4, 4), dtype=float32)
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
output =  tf.Tensor(
[[[5 5 5 5][6 6 6 6][7 7 7 7][8 8 8 8]][[0 0 0 0][0 0 0 0][0 0 0 0][0 0 0 0]][[5 5 5 5][6 6 6 6][7 7 7 7][8 8 8 8]][[0 0 0 0][0 0 0 0][0 0 0 0][0 0 0 0]]], shape=(4, 4, 4), dtype=int32)Process finished with exit code 0



参考资料:
[tensorflow] tf.gather使用方法
tf.gather()函数详解
tf.gather( )的用法
tf.gather_nd和tf.gather的区别与联系
TensorFlow中gather, gather_nd, scatter, scatter_nd用法浅析
TensorFlow学习(三):tf.scatter_nd函数
Tensorflow (一): scatter_nd 与 gather_nd
Python tensorflow.scatter_nd方法代码示例
TensorFlow中tf.gather()函数的使用讲解
tf.gather_nd 用法
深度理解tf.gather和tf.gather_nd的用法
Python tensorflow.gather_nd()用法及代码示例

TensorFlow2-高阶操作(八):gather/gather_nd(已知元素的位置,从张量中提取该元素)、scatter_nd/scatter_nd_update(已知赋值位置,向0张量中赋值)相关推荐

  1. TensorFlow2 入门指南 | 06 TensorFLow2 高阶操作汇总

    前言: 本专栏在保证内容完整性的基础上,力求简洁,旨在让初学者能够更快地.高效地入门TensorFlow2 深度学习框架.如果觉得本专栏对您有帮助的话,可以给一个小小的三连,各位的支持将是我创作的最大 ...

  2. 深度学习(17)TensorFlow高阶操作六: 高阶OP

    深度学习(17)TensorFlow高阶操作六: 高阶OP 1. Where(tensor) 2. where(cond, A, B) 3. 1-D scatter_nd 4. 2-D scatter ...

  3. 深度学习(14)TensorFlow高阶操作三: 张量排序

    深度学习(14)TensorFlow高阶操作三: 张量排序 一. Sort, argsort 1. 一维Tensor 2. 多维Tensor 二. Top_k 三. Top-k accuracy(To ...

  4. Tensorflow学习四---高阶操作

    Tensorflow学习四-高阶操作 Merge and split 1.tf.concat 拼接 a = tf.ones([4,32,8]) b = tf.ones([2,32,8]) print( ...

  5. 深度学习(16)TensorFlow高阶操作五: 张量限幅

    深度学习(16)TensorFlow高阶操作五: 张量限幅 1. clip_by_value 2. relu 3. clip_by_norm 4. Gradient clipping 5. 梯度爆炸实 ...

  6. 深度学习(15)TensorFlow高阶操作四: 填充与复制

    深度学习(15)TensorFlow高阶操作四: 填充与复制 1. Pad 2. 常用于Image Padding 3. tile 4. tile VS broadcast_to Outline pa ...

  7. 深度学习(12)TensorFlow高阶操作一: 合并与分割

    深度学习(12)TensorFlow高阶操作一: 合并与分割 1. concat 2. stack: create new dim 3. Dim mismatch 4. unstuck 5. spli ...

  8. zotero文献管理高阶操作|全网最新最全的zotero高效运用技巧,quicker动作大盘点

    大家能都知道quicker是电脑端的效率神器,可以然你解放双手,提高你的办公效率,当然,也是科研效率神器,截止2022年10月16日,与zotero相关的动作已经有66个.今天就让让我盘点几个可以让你 ...

  9. PS高阶操作之字体特效

    PS高阶操作之字体特效 字体冰封效果 字体金属样式 字体冰封效果 新建一个白色的像素画布. 新建好画布后,用油漆桶刷成深蓝色. 打开通道,新建一个通道,选择文字工具. 调整文字大小和位置.在菜单栏中选 ...

最新文章

  1. PHP简单封装MysqlHelper类
  2. 《ADO.NET 2.0高级程序设计》读书随笔(1)使用连接池connection pool
  3. 用 Hadoop 进行分布式并行编程, 第 3 部分 部署到分布式环境
  4. 关于input的change事件触发多次发解决
  5. python 数据库表结构转为类_顺序表数据结构在python中的应用
  6. python梯度下降法实现线性回归_梯度下降法的python代码实现(多元线性回归)
  7. POJChallengeRound2 Guideposts 【单位根反演】【快速幂】
  8. Log4net数据表
  9. mysql connector net 6.9.3_MySQL Connector/Net 6.9.3 发布 MySQL Connector/Net 6.9.3下载
  10. linux 如何查看属性,linux 下查看系统属性
  11. 这样查看告警邮件要慢一点……
  12. java提前多久显示,Java当前日期/时间比原始时间提前1小时显示
  13. python程序内存分析_Python中使用MELIAE分析程序内存占用实例
  14. 贪心算法3——加油站问题
  15. 爱分享 IE地址栏显示空白?360电脑救援巧修复
  16. AppScan安全扫描问题解决方案
  17. 自编小程序,保持编程达人眼睛
  18. 基于OpenCV 的车道线检测方法
  19. 在控制台,打印出某个具体的变量,并监听其变化
  20. 《开源软件开发导论》作业1

热门文章

  1. 网络走红,小偷反扒秘籍
  2. 为数据披上隐形“斗篷”,如何收回部分数据隐私?
  3. python爬虫之-斗图网爬取
  4. SLAM小车系统配置与软件安装过程
  5. AV视频输出接口类型
  6. 一次线程被挂起问题排查
  7. Windows系统时间不能修改的解决办法!
  8. C#MVC中Controler的自定义属性使用
  9. 图像分割-综述2020.3.1
  10. 怎么改m4r格式?我来教你几招