在神经网络运算过程中,维度变换是最核心的张量操作,通过维度变换可以将数据任意地切换形式,满足不同场合的运算需求。

维度变换的一个例子:

Y = X@W + b

X 的 shape 为[2,4]
W 的 shape 为[4,3]
X@W的运算张量shape 为[2,3]
偏置b的张量为[3]

不同 shape 的 2 个张量怎么直接相加呢?

需要将shape为[3]的偏置b按样本数量复制一份,变成矩阵形式B,

这样就可以将偏置b和X@W相加了。

基本的维度变换包含了改变视图 reshape,插入新维度 expand_dims,删除维度squeeze,交换维度 transpose,复制数据 tile 等。

1.Reshape

import tensorflow as tf
x = tf.range(96)
x = tf.reshape(x, [2, 4, 4, 3])
x<tf.Tensor: id=17, shape=(2, 4, 4, 3), dtype=int32, numpy=
array([[[[ 0,  1,  2],[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]],[[12, 13, 14],[15, 16, 17],[18, 19, 20],[21, 22, 23]],[[24, 25, 26],[27, 28, 29],[30, 31, 32],[33, 34, 35]],[[36, 37, 38],[39, 40, 41],[42, 43, 44],[45, 46, 47]]],[[[48, 49, 50],[51, 52, 53],[54, 55, 56],[57, 58, 59]],[[60, 61, 62],[63, 64, 65],[66, 67, 68],[69, 70, 71]],[[72, 73, 74],[75, 76, 77],[78, 79, 80],[81, 82, 83]],[[84, 85, 86],[87, 88, 89],[90, 91, 92],[93, 94, 95]]]])>

在存储数据时,内存并不支持这个维度层级概念,只能以平铺方式按序写入内存,数据在创建时按着初始的维度顺序写入,改变张量的视图仅仅是改变了张量的理解方式,并不会改变张量的存储顺序,这在一定程度上是从计算效率考虑的。

在 TensorFlow 中,可以通过张量的 ndim 和 shape 成员属性获得张量的维度数和形状:

x.ndim, x.shape(4, TensorShape([2, 4, 4, 3]))

通过 tf.reshape(x, new_shape),可以将张量的视图进行任意的合法的改变:

tf.reshape(x, [2, -1])<tf.Tensor: id=25, shape=(2, 48), dtype=int32, numpy=
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15,16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31,32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],[48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95]])>

其中的参数-1 表示当前轴上长度需要根据视图总元素不变的法则自动推导,从而方便用户书写。

尽管我们可以有很多种方式去任意合法地改变维度,但是需要意识到,张量的存储顺序始终没有改变,数据在内存中仍然是按着初始写入的顺序0,1,2, … ,95保存的。

2.增删维度

增加维度 增加一个长度为 1 的维度相当于给原有的数据增加一个新维度的概念,维度长度为 1,故数据并不需要改变,仅仅是改变数据的理解方式。

x = tf.random.uniform([28,28],maxval=10,dtype=tf.int32)
x<tf.Tensor: id=61, shape=(28, 28), dtype=int32, numpy=
array([[4, 2, 2, 9, 1, 7, 8, 7, 2, 3, 8, 2, 9, 9, 1, 3, 2, 0, 4, 1, 0, 2,

通过 tf.expand_dims(x, axis)可在指定的 axis 轴前可以插入一个新的维度:

x = tf.expand_dims(x,axis=2)
x<tf.Tensor: id=63, shape=(28, 28, 1), dtype=int32, numpy=
array([[[4],[2],[2],[9],[1],

tf.expand_dims 的 axis 为正时,表示在当前维度之前插入一个新维度;为负时,表示当前维度之后插入一个新的维度。

删除维度 是增加维度的逆操作,与增加维度一样,删除维度只能删除长度为 1 的维
度,也不会改变张量的存储。

axis 参数为待删除的维度的索引号。

x = tf.expand_dims(x,axis=2)
x<tf.Tensor: id=73, shape=(28, 28, 1), dtype=int32, numpy=
array([[[1],[8],[4],x = tf.squeeze(x, axis=2)
x<tf.Tensor: id=71, shape=(28, 28), dtype=int32, numpy=
array([[1, 8, 4, 2, 1, 5, 1, 9, 0, 4, 5, 6, 7, 9, 8, 0, 5, 2, 7, 0, 3, 9,

如果不指定维度参数 axis,即 tf.squeeze(x),那么他会默认删除所有长度为 1 的维度。

3.交换维度

交换维度操作是非常常见的,比如在 TensorFlow 中,图片张量的默认存储格式是通道后行格式:[b, h, w, c],但是部分库的图片格式是通道先行:[b, c, h, w],因此需要完成[b, h, w, c]到[b, c, h, w]维度交换运算。

比如一张图片地shape是[2, 32, 32, 3] 图片数量,行,列,通道数的维度索引分别为0, 1, 2, 3, 如果需要交换为[2, 3, 32, 32], 则新维度的排序为图片数量,通道数, 行,列, 对应的索引号为[0, 3, 1, 2]。

x = tf.random.normal([2, 32, 32, 3])
x<tf.Tensor: id=85, shape=(2, 32, 32, 3), dtype=float32, numpy=
array([[[[ 0.9281655 , -0.87817407, -0.7339403 ],[-1.0697994 ,  1.8129319 , -0.25105947],[ 1.0635214 ,  0.26829538, -0.578875  ],...,tf.transpose(x, perm=[0, 3, 1, 2])<tf.Tensor: id=87, shape=(2, 3, 32, 32), dtype=float32, numpy=
array([[[[ 0.9281655 , -1.0697994 ,  1.0635214 , ...,  0.04531553,

需要注意的是,通过 tf.transpose 完成维度交换后,张量的存储顺序已经改变,视图也随之改变,后续的所有操作必须基于新的存续顺序进行。

4.数据复制

可以通过tf.tile(x, multiples)函数完成数据在指定维度上的复制操作,multiples 分别指定了每个维度上面的复制倍数,对应位置为 1 表明不复制,为 2 表明新长度为原来的长度的 2 倍,即数据复制一份,以此类推。

import tensorflow as tf
x = tf.range(4)
x = tf.reshape(x, [2, 2])
x<tf.Tensor: id=125, shape=(2, 2), dtype=int32, numpy=
array([[0, 1],[2, 3]])># 在列维度复制1份数据 2就是列维度大小扩大为原来两倍 1表示行维度不变
x = tf.tile(x, multiples=[1, 2])
x<tf.Tensor: id=143, shape=(2, 4), dtype=int32, numpy=
array([[0, 1, 0, 1],[2, 3, 2, 3]])># 然后在行维度复制一份 2就是行维度大小扩大为原来两倍 1表示列维度不变
x = tf.tile(x, multiples=[2, 1])
x<tf.Tensor: id=145, shape=(4, 4), dtype=int32, numpy=
array([[0, 1, 0, 1],[2, 3, 2, 3],[0, 1, 0, 1],[2, 3, 2, 3]])>

再举一个例子:

import tensorflow as tf
x = tf.range(4)
x = tf.reshape(x, [1, 4])
x<tf.Tensor: id=231, shape=(1, 4), dtype=int32, numpy=array([[0, 1, 2, 3]])>x = tf.tile(x, multiples=[4, 1])
x<tf.Tensor: id=233, shape=(4, 4), dtype=int32, numpy=
array([[0, 1, 2, 3],[0, 1, 2, 3],[0, 1, 2, 3],[0, 1, 2, 3]])>x = tf.range(4)
x = tf.reshape(x, [1, 4])
x<tf.Tensor: id=231, shape=(1, 4), dtype=int32, numpy=array([[0, 1, 2, 3]])>x = tf.tile(x, multiples=[1, 4])
x<tf.Tensor: id=235, shape=(4, 16), dtype=int32, numpy=
array([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]])>

需要注意的是,tf.tile 会创建一个新的张量来保存复制后的张量,由于复制操作涉及到大量数据的读写 IO 运算,计算代价相对较高。

TensorFlow张量的维度变换相关推荐

  1. Tensorflow张量和维度概念的理解

    Tensorflow张量和维度概念的理解 理解tensorflow张量的概念:张量就是一个数据存储容器,一种数据结构,是人为定义的.因为在计算机内存中哪里有什么2维空间3维空间,都是一块块连续的内存区 ...

  2. 深度学习_TensorFlow2.0基础_张量创建,运算,维度变换,采样

    Tensorflow2.0 基础 一:TensorFlow特性 1.TensorFlow An end-to-end open source machine learning platform end ...

  3. 深度学习(8)TensorFlow基础操作四: 维度变换

    深度学习(8)TensorFlow基础操作四: 维度变换 1. View 2. 示例 3. Reshape操作可能会导致潜在的bug 4. tf.transpose 5. Squeeze VS Exp ...

  4. TensorFlow——维度变换与Broadcasting

    TensorFlow 维度变换 文章目录 TensorFlow 维度变换 一.Reshape 二.tf.transpose 三.Squeeze和Expand_dims Broadcasting 前言 ...

  5. 一文带你读懂深度学习中的张量(tensor)是什么,它的运算是怎样的,如何理解张量,张量的维度,浅显易懂

    深度学习的数学基础(不要被吓到,很浅显) 数据表示与张量运算 张量 在多维 Numpy 数组中,也叫张量(tensor).一般来说,当前所有机器学习系统都使用张量作为基本数据结构. 张量这一概念的核心 ...

  6. Tensorflow:张量(Tensor)的创建及其基础操作

    Tensorflow版本:V2.8.0   Tensorflow中所有的运算操作都是基于张量进行的.Tensorflow中的张量Tensor是具有统一类型的多维数组.这篇博文主要介绍张量的创建及基础操 ...

  7. 【TensorFlow2.0】(4) 维度变换、广播

    各位同学好,今天我和大家分享一下TensorFlow2.0中有关数学计算的相关操作,主要内容有: (1) 改变维度:reshape():(2) 维度转置:transpose():(3) 增加维度:ex ...

  8. Tensorflow——张量

    在Tensorflow中,所有数据都通过张量的形式来表示,从功能上看,张量可以简单的被理解为多维数组.其中零阶张量表示标量(scalar),也就是一个数:第一阶张量为向量(vector),也就是一个一 ...

  9. 目标检测——使用OpenCV读取图片要注意进行维度变换

    注意:使用cv2.imread()读取彩色图片时,OpenCV获得的张量的数据顺序为h*w*c,其中张量的最后一个维度才是通道,所以在送入到torch中之前,需要对张量的维度顺序进行变换: 一个可行的 ...

最新文章

  1. 爬取CSDN最新月份所写的文章的最高阅读量文章(以及统计整个月所写的文章的阅读量的累积和)
  2. pycharm奇技淫巧 直接通过代码输出函数 refactor —— extract method
  3. anaconda安装后只有几个文件,大量文件缺失,开始栏里没有图标的解决方法
  4. Nginx命令与配置详解
  5. 全球最大IXP为何选择与华为开展数据中心互联合作?
  6. 怎么部署_2020怎么部署新零售商城?
  7. 【渝粤题库】陕西师范大学200431综合英语(一)作业(高起专、高起本)
  8. java web项目初始化启动一个java方法
  9. python多线程_干货|理解python多线程和多进程
  10. php数组循环转为对象,php中循环实现(字符串,对象,或者数组)编码相互转换
  11. CentOS下mysql安装
  12. lintcode:线段树的构造
  13. java sort方法_Java排序方法sort用法详解
  14. 件测试专家分享III GUI自动化测试相关
  15. 2008年5月Windows Mobile Webcast预告
  16. 很抱歉,程序无法在非MBR引导分区上进行激活
  17. MySQL的子查询(二十)
  18. 基于MATLAB的特征值与特征向量(附完整代码)
  19. 2021临泉一中高考成绩查询,临泉三所省级示范高中高考成绩揭晓!
  20. keras-文本图片文字识别

热门文章

  1. 1102 Invert a Binary Tree (25point(s))
  2. 反射式5×5衍射光束分束器的分析
  3. 关于malloc和free函数的用法
  4. 场效应管 | N-mos内部结构详解
  5. 74HC595D介绍与实现(C语言与verilog实现)
  6. 信息化与系统集成技术
  7. 标准时间转换为时间戳
  8. 时间格式转换2021-08-17T16:00:00.000Z存入数据库问题
  9. 15款最好用的思维导图(心智图 )工具
  10. JS Ajax 和 jQuery Ajax : 异步自动填充