1、transpose与permute

transpose() 和 permute() 都是返回转置后矩阵,在pytorch中转置用的函数就只有这两个 ,这两个函数都是交换维度的操作

transpose用法:tensor.transpose(dim0, dim1) → Tensor
只能操作2D矩阵的转置, transpose每次只能交换两个维度, 这是相比于permute的一个不同点,每次输入两个index,实现转置,,参数顺序无所谓。
permute用法:tensor.permute(dim0, dim1, ...., dimn)
permute可以进行多维度转置, permute每次可以交换多个维度,且必须传入所有维度数,参数顺序表示交换结果是原值的哪个维。

permute操作可以有1至多步的Transpose操作实现

注意:使用transpose或permute之后,若要使用view,必须先contiguous()

# 创造二维数据x,dim=0时候2,dim=1时候3
x = torch.randn(2,3)       'x.shape  →  [2,3]'
# 创造三维数据y,dim=0时候2,dim=1时候3,dim=2时候4
y = torch.randn(2,3,4)   'y.shape  →  [2,3,4]'"""
操作dim不同:
transpose()只能一次操作两个维度;permute()可以一次操作多维数据,
且必须传入所有维度数,因为permute()的参数是int*。
"""# 对于transpose
x.transpose(0,1)     'shape→[3,2] '
x.transpose(1,0)     'shape→[3,2] '
y.transpose(0,1)     'shape→[3,2,4]'
y.transpose(0,2,1)  'error,操作不了多维'# 对于permute()
x.permute(0,1)     'shape→[2,3]'
x.permute(1,0)     'shape→[3,2], 注意返回的shape不同于x.transpose(1,0) '
y.permute(0,1)     "error 没有传入所有维度数"
y.permute(1,0,2)  'shape→[3,2,4]'"""
操作dim不同:
transpose()只能一次操作两个维度, 维度的顺序不影响结果;permute()可以一次操作多维数据,
且必须传入所有维度数,因为permute()的参数是int*。
"""
# 对于transpose, (0,1) 和 (1,0) 都是指变换 维度 0 和 1, 输入顺序不影响
x1 = x.transpose(0,1)   'shape→[3,2] '
x2 = x.transpose(1,0)   '也变换了,shape→[3,2] '  # 对于permute(),
x1 = x.permute(0,1)     '保持原理tensor不变, 不同transpose,shape→[2,3] '
x2 = x.permute(1,0)     'shape→[3,2] '  y1 = y.permute(0,1,2)     '保持不变,shape→[2,3,4] '
y2 = y.permute(1,0,2)     'shape→[3,2,4] '
y3 = y.permute(1,2,0)     'shape→[3,4,2] ' 

2、关于连续contiguous()

用view()函数改变通过转置后的数据结构,导致报错

RuntimeError: invalid argument 2: view size is not compatible with input tensor's....

这是因为tensor经过转置后数据的内存地址不连续导致的,也就是tensor . is_contiguous()==False。
虽然在torch里面,view函数相当于numpy的reshape,但是这时候reshape()可以改变该tensor结构,但是view()不可以

x = torch.rand(3,4)
x = x.transpose(0,1)
print(x.is_contiguous()) # 是否连续
'False'
# 会发现
x.view(3,4)
'''
RuntimeError: invalid argument 2: view size is not compatible with input tensor's....
就是不连续导致的
'''
# 但是这样是可以的。
x = x.contiguous()
x.view(3,4)x = torch.rand(3,4)
x = x.permute(1,0) # 等价x = x.transpose(0,1)
x.reshape(3,4)
'''这就不报错了
说明x.reshape(3,4) 这个操作
等于x = x.contiguous().view()
尽管如此,但是我们还是不推荐使用reshape
除非为了获取完全不同但是数据相同的克隆体
'''

调用contiguous()时,会强制拷贝一份tensor,让它的布局从头到尾创建的一毛一样。
只需要记住了,每次在使用view()之前,该tensor只要使用了transpose()和permute()这两个函数一定要contiguous().

transpose与permute会实实在在的根据需求(要交换的dim)把相应的Tensor元素的位置进行调整, 而view 会将Tensor所有维度拉平成一维 (即按行,这也是为什么view操作要求Tensor是contiguous的原因),然后再根据传入的的维度(只要保证各维度的乘积=总元素个数即可)信息重构出一个Tensor。

a = torch.Tensor([[[1,2,3,4,5], [6,7,8,9,10], [11,12,13,14,15]], [[-1,-2,-3,-4,-5], [-6,-7,-8,-9,-10], [-11,-12,-13,-14,-15]]])
>>> a.shape
torch.Size([2, 3, 5])
# 还是上面的Tensor a
>>> print(a.shape)
torch.Size([2, 3, 5])
>>> print(a.view(2,5,3))
tensor([[[  1.,   2.,   3.],[  4.,   5.,   6.],[  7.,   8.,   9.],[ 10.,  11.,  12.],[ 13.,  14.,  15.]],[[ -1.,  -2.,  -3.],[ -4.,  -5.,  -6.],[ -7.,  -8.,  -9.],[-10., -11., -12.],[-13., -14., -15.]]])
>>> c = a.transpose(1,2)
>>> print(c, c.shape)
(tensor([[[  1.,   6.,  11.],[  2.,   7.,  12.],[  3.,   8.,  13.],[  4.,   9.,  14.],[  5.,  10.,  15.]],[[ -1.,  -6., -11.],[ -2.,  -7., -12.],[ -3.,  -8., -13.],[ -4.,  -9., -14.],[ -5., -10., -15.]]]),
torch.Size([2, 5, 3]))

即使view()transpose()最终得到的Tensor的shape是一样的,但二者内容并不相同。view函数只是按照给定的(2,5,3)的Tensor维度,将元素按顺序一个个填进去;而transpose函数,则的确是在进行第一个第二维度的转置

3、view与reshape的区别

view()具有跟reshape()相同的功能,都能去重塑矩阵的形状

不同点:

reshape()方法不受此限制;如果对 tensor 调用过 transpose, permute等操作的话会使该 tensor 在内存中变得不再连续。

view():

作用:将tensor转换为指定的shape,原始的data不改变。返回的tensor与原始的tensor共享存储区。view()方法只适用于满足连续性(contiguous)条件的tensor,并且该操作不会开辟新的内存空间,只是产生了对原存储空间的一个新别称和引用,返回值是视图。也就是说view不会改变原来数据的存放方式,并且,也不会产生数据的副本,view返回的是视图。

如果tensor 不满足连续性条件,需要先调用 contiguous()方法,但这种方法变换后的tensor就不是与原始tensor共享内存了,而是被重新开辟了一个空间。

view()可以通过在某一维度输入为-1,来动态调整这个矩阵的维度的size, 而 reshape且无动态调整的功能。而且 view()用于pytorch中对张量进行处理,

view方法可以调整tensor的形状,但必须保证调整前后元素总数一致。view不会修改自身的数据,返回的新tensor与源tensor共享内存,即更改其中一个,另外一个也会跟着改变

reshape():

作用:与view方法类似,将输入tensor转换为新的shape格式。

reshape方法更强大,可以认为a.reshape = a.view() + a.contiguous().view()

reshape()方法的返回值既可以是视图,也可以是副本。即:在满足tensor连续性条件时,a.reshape返回的结果与a.view()相同,否则返回的结果与a.contiguous().view()相同。

PyTorch:view() 与 reshape() 区别详解_地球被支点撬走啦的博客-CSDN博客_reshape view

4、torch.flatten()

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

input (Tensor) – 输入为Tensor 
start_dim (int) – 展平的开始维度 
end_dim (int) – 展平的结束维度

展平一个连续范围的维度,输出类型为Tensor, flatten函数就是对tensor类型进行扁平化处理,也就是在不同维度上进行堆叠操作,a.flatten(m),这个意思是将a这个tensor,从第m维度开始堆叠,一直堆叠到最后一个维度

import torch
# t 是三维张量 torch.Size([3, 2, 2])
t = torch.tensor([[[1, 2],[3, 4]],[[5, 6],[7, 8]],[[9, 10],[11, 12]]])
#如果不传入参数,默认开始维度为0,最后维度为-1,展开为一维
result_0 = torch.flatten(t)
print(result_0)
'''
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])
'''#当开始维度为1,最后维度为-1,展开为3x4,也就是说第一维度不变,后面的压缩
result_1 = torch.flatten(t, start_dim=1)
print(result_1)
'''
tensor([[ 1,  2,  3,  4],[ 5,  6,  7,  8],[ 9, 10, 11, 12]])
'''torch.flatten(t, start_dim=1).size()
# torch.Size([3, 4])
#下面的和上面进行对比应该就能看出是,当锁定最后的维度的时候
#前面的就会合并
result_3 = torch.flatten(t, start_dim=0, end_dim=1)
print(result_3)
'''
tensor([[ 1,  2],[ 3,  4],[ 5,  6],[ 7,  8],[ 9, 10],[11, 12]])
'''torch.flatten(t, start_dim=0, end_dim=1).size()
# torch.Size([6, 2])

示例:

import torch
# 随机产生了一个tensor,它的Batchsize是2,C是3,H是2,W是3
a=torch.rand(2,3,2,3)
print(a)
'''
tensor([[[[0.5521, 0.2547, 0.5242],[0.8248, 0.4500, 0.2413]],[[0.7759, 0.1261, 0.0090],[0.0197, 0.6191, 0.0422]],[[0.0896, 0.1731, 0.5484],[0.7927, 0.0752, 0.2176]]],[[[0.0118, 0.3865, 0.9587],[0.6599, 0.2464, 0.0728]],[[0.2858, 0.3772, 0.8215],[0.3267, 0.2859, 0.4329]],[[0.7329, 0.4436, 0.4246],[0.4162, 0.8688, 0.5286]]]])
'''
##########################################################################
result_0 = a.flatten(0)
print(result_0.shape)
print(result_0)
'''
torch.Size([36])
tensor([0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413, 0.7759, 0.1261, 0.0090,0.0197, 0.6191, 0.0422, 0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176,0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728, 0.2858, 0.3772, 0.8215,0.3267, 0.2859, 0.4329, 0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286])
'''
##########################################################################
result_1 = a.flatten(1)
print(result_1.shape)
print(result_1)
'''
torch.Size([2, 18])
tensor([[0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413, 0.7759, 0.1261, 0.0090,0.0197, 0.6191, 0.0422, 0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176],[0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728, 0.2858, 0.3772, 0.8215,0.3267, 0.2859, 0.4329, 0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286]])
'''
##########################################################################
result_2 = a.flatten(2)
print(result_2.shape)
print(result_2)
'''
torch.Size([2, 3, 6])
tensor([[[0.5521, 0.2547, 0.5242, 0.8248, 0.4500, 0.2413],[0.7759, 0.1261, 0.0090, 0.0197, 0.6191, 0.0422],[0.0896, 0.1731, 0.5484, 0.7927, 0.0752, 0.2176]],[[0.0118, 0.3865, 0.9587, 0.6599, 0.2464, 0.0728],[0.2858, 0.3772, 0.8215, 0.3267, 0.2859, 0.4329],[0.7329, 0.4436, 0.4246, 0.4162, 0.8688, 0.5286]]])
'''
##########################################################################
result_3 = a.flatten(3)
print(result_3.shape)
print(result_3)
'''
torch.Size([2, 3, 2, 3])
tensor([[[[0.5521, 0.2547, 0.5242],[0.8248, 0.4500, 0.2413]],[[0.7759, 0.1261, 0.0090],[0.0197, 0.6191, 0.0422]],[[0.0896, 0.1731, 0.5484],[0.7927, 0.0752, 0.2176]]],[[[0.0118, 0.3865, 0.9587],[0.6599, 0.2464, 0.0728]],[[0.2858, 0.3772, 0.8215],[0.3267, 0.2859, 0.4329]],[[0.7329, 0.4436, 0.4246],[0.4162, 0.8688, 0.5286]]]])
'''##########################################################################
result_4 = a.flatten(0, 1)
print(result_4.shape)
print(result_4)'''
torch.Size([6, 2, 3])
tensor([[[0.5521, 0.2547, 0.5242],[0.8248, 0.4500, 0.2413]],[[0.7759, 0.1261, 0.0090],[0.0197, 0.6191, 0.0422]],[[0.0896, 0.1731, 0.5484],[0.7927, 0.0752, 0.2176]],[[0.0118, 0.3865, 0.9587],[0.6599, 0.2464, 0.0728]],[[0.2858, 0.3772, 0.8215],[0.3267, 0.2859, 0.4329]],[[0.7329, 0.4436, 0.4246],[0.4162, 0.8688, 0.5286]]])
'''

a.flatten(0)的意思就是从batchsize这个维度开始堆叠,直到W结束,那最后就是成一维的了,也就是只剩W这个维度,那当然就是只有一条这样子

a.flatten(1)的意思就是从C(channel)这个维度开始堆叠,直到W结束,Batchsize这个维度没有参与运算,因此还是有B这个维度的,这样的话就是相当于将三维的数据堆叠成只有一个维度W的数据,那当然就变成了两条

a.flatten(2)的意思就是从H(Height)这个维度开始堆叠,直到W结束,B和C这两个维度都没有参与运算,因此将H这个维度堆叠到W上去,就是将原本的平面变成了一个长条

最后a.flatten(3)的意思就是将H这个维度堆叠到H这个维度上去,自己堆叠自己就是没有堆叠

a.flatten(0,1), 将B的维度叠加到C的维度上,就是将两个batch叠加合并了

5、flatten函数的用法及其与reshape函数的区别

深度学习(PyTorch)——flatten函数的用法及其与reshape函数的区别_清泉_流响的博客-CSDN博客_flatten函数

Pytorch中 permute / transpose 和 view / reshape, flatten函数相关推荐

  1. PyTorch中permute的用法

    RuntimeError: Given groups=1, weight of size [18, 8, 8], expected input[64, 32, 8] to have 8 channel ...

  2. Pytorch之view,reshape,resize函数

    对于深度学习中的一下数据,我们通常是要变成tensor格式,并且需要对其调整形状,很多时候我们往往只关注view之后的结果(比如输出的尺寸),而不关心过程.但有时候还是要关注一下这个到底是怎么变换过来 ...

  3. pytorch中的transpose()

    维度的相互调换 运行结果 batch_es和batch_ee输出的结果是一样的.原始的transpose参数(默认的参数)为(0,1,2),这个转置相当于将第一个坐标与第二坐标进行了互换.

  4. python中permute_PyTorch中permute的用法详解

    PyTorch中permute的用法详解 permute(dims) 将tensor的维度换位. 参数:参数是一系列的整数,代表原来张量的维度.比如三维就有0,1,2这些dimension. 例: i ...

  5. permute、transpose、view、reshape、unsequeeze与flatten

    Pytorch中view, transpose, permute等方法的区别 关于张量的Flatten.Reshape和Squeeze的解释 transpose 只能交换两个维度,参数是(dim1,d ...

  6. pytorch中的reshape()、view()、nn.flatten()和flatten()

    在使用pytorch定义神经网络结构时,经常会看到类似如下的.view() / flatten()用法,这里对其用法做出讲解与演示. torch.reshape用法 reshape()可以由torch ...

  7. Pytorch中tensor.view().permute().contiguous()函数理解

    Pytorch中tensor.view().permute().contiguous()函数理解 yolov3中有一行这样的代码,在此记录一下三个函数的含义 # 例子中batch_size为整型,le ...

  8. PyTorch中contiguous、view、Sequential、permute函数的用法

    在pytorch中,tensor的实际数据以一维数组(storage)的形式存储于某个连续的内存中,以"行优先"进行存储. 1. tensor的连续性 tensor连续(conti ...

  9. **Pytorch 中view函数和reshape函数的区别*

    Pytorch 中view函数和reshape函数的区别(我是一名大一刚学计算机的学生 希望我的说法对你有帮助) 首先:要了解这个问题我们要先了解一个基本知识 张量的储存方式 跟据图片我们可以清楚的看 ...

最新文章

  1. c++中把一个函数中的语句复制到另一个语句中报错_从底层看前端(十一)—— JavaScript语法:脚本,模块和函数体。...
  2. 姿态迁移CoCosNet v2
  3. import javax.servlet.http.HttpServletRequest 提示错误
  4. Darknet_Yolov4实战(二)_安装OpenCV
  5. Oracle第三课之PLSQL
  6. 安装phpssdb扩展:
  7. 互联网人理想假期VS现实假期
  8. 复古多变“格子控”混搭 夏季继续魅力四射
  9. how to catch out of memory exception in c++
  10. 数学与编程——概率论与数理统计
  11. Python中函数的形参与按值传递之间的关系
  12. Docker教程:docker远程repository和自建本地registry
  13. 猫猫学iOS之小知识之xcode6自己主动提示图片插件 KSImageNamed的安装
  14. K3 WISE 开发插件《SQL语句WHERE查询-范围查询/模糊查询》
  15. pytorch:测试GPU是否可用
  16. 小公司代理记账报税常见问题
  17. Android出现Could not initialize class com.android.sdklib.repository.AndroidSdkHandler的解决方法
  18. 别再逐帧扒电影了 生活中处处都有彩蛋!
  19. Substrate 基础 -- 教程(Tutorials)
  20. Mysql数据库管理系统原理及基本操作

热门文章

  1. Devtools inspection is not available because it‘s in production mode or explicitly disabled by the a
  2. 软件实施-SQL基础
  3. WPS一级计算机基础知识,2017年一级计算机基础及WPS Office考试大纲
  4. Nginx 安装教程 (windows) 及详解 并通过Nginx启动项目(vue项目举例)
  5. 罗马数字(暴力破解)
  6. C++ string字符串分割
  7. JsZip+FileSaver实现打包文件并下载
  8. 电子学会青少年编程等级考试Python一级题目解析06
  9. 给QC项目瘦身的方案
  10. 内有干货!2个人3个月如何从零完成一款社区App《林卡》