引言:看了关于torch.cat函数的文章,有点乱,自己总结一篇,关于四维tensor合并。

  1. 一张图像在计算机中的表示通常为三维tensor(张量),即[channels,height,width] 。也就是一张彩色图片通常有三色通道(R,G,B)组成,高和宽也就是常说的照片大小,比如224x224
  2. 在图像处理的时候会增加一个变量batch_size,也就是把多少张图片作为一批进行处理。所以就变成了四维张量,即[batch_size,channels,heigth,width],也即是[批量大小,通道数,高,宽]
  3. 如何判断一个tensor是几维张量最简单的办法就是看中括号数。例如  [[[[1,2,3]]]],是四维张量。
  4. torch.cat()函数,官方文档是这样写的torch.cat(tensorsdim=0*out=None),也就是有两个参数,一个是要合并的张量,一个是在哪个维度上进行合并。


    废话少说开始演示。
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])

定义了两个四维张量。维度都为[1,1,2,3],即批量大小为1,通道为1,高为2,宽为3

import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)#torch.Size([2, 1, 2, 3])

在维度0上进行合并,然后输出维度为[2,1,2,3],所以得出结论 四维张量在0维合并的时候 其实是在批量大小维度上进行合并。

import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)#torch.Size([2, 1, 2, 3])#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)#torch.Size([1, 2, 2, 3])

在1维度上进行合并,输出维度为[1,2,2,3],即在1维上合并是在通道维度上进行合并。

import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)#torch.Size([2, 1, 2, 3])#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)#torch.Size([1, 2, 2, 3])#在维度2上进行合并
x=torch.cat((a,b),dim=2)
print(x.shape)#torch.Size([1, 1, 4, 3])

在维度2上进行合并,输出维度为[1,1,4,3]。即在2维上进行合并是在高上进行合并(也可以说是在行维度进行合并)

import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)#torch.Size([2, 1, 2, 3])#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)#torch.Size([1, 2, 2, 3])#在维度2上进行合并
x=torch.cat((a,b),dim=2)
print(x.shape)#torch.Size([1, 1, 4, 3])#在维度3上进行合并
x=torch.cat((a,b),dim=3)
print(x.shape)#torch.Size([1, 1, 2, 6])

在维度3上进行合并,输出维度为[1,1,2,6],即在3维上进行合并是在宽维度进行合并(也可以说是列)

注:在拼接时 除了选择拼接的维度可以不同,其他维度要相同。什么意思?看代码

import torch#定义两个变量[batch_size,channel,height,width]
a=torch.randn(size=(1,1,2,3))
b=torch.randn(size=(1,2,2,3))
#选择在1维度进行合并(也就是通道维度),注意a,b的通道维度不同,其他维度都相同。
x=torch.cat((a,b),dim=1)
print(x.shape)#torch.Size([1, 3, 2, 3])

也就是选择合并的那个维度可以不同,其他维度要相同

如果不同,报错,如下。

import torch#定义两个变量[batch_size,channel,height,width]
a=torch.randn(size=(1,1,2,3))
b=torch.randn(size=(2,2,2,3))
#选择在1维度进行合并(也就是通道维度),注意a,b的批量大小不同,维度不同,其他维度都相同。
x=torch.cat((a,b),dim=1)
print(x.shape)#RuntimeError: Sizes of tensors must match except in dimension 1. Got 1 and 2 in dimension 0

可以看到当我们选择在通道维度合并时(通道数可以不同),但是其他的维度要相同(下面的a,b的批量大小也不同)。所以直接报错。

总结的有不足之错还望各位大佬指正。

torch.cat()函数 ,关于四维tensor维度合并。相关推荐

  1. 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数

    文章目录 前言 一.torch.cat()函数 拼接只存在h,w(高,宽)的图像 二.torch.cat() 拼接存在c,h,w(通道,高,宽)的图像 三.torch.add()使张量对应元素直接相加 ...

  2. torch.cat()函数的官方解释,详解以及例子

    可以直接看最下面的例子,再回头看前面的解释,就很明白了. 在pytorch中,常见的拼接函数主要是两个,分别是: stack() cat() 一般torch.cat()是为了把多个tensor进行拼接 ...

  3. torch.max()函数==》返回该维度的最大值以及该维度最大值对应的索引

    今天在学习TTSR的过程总遇到了一行代码,我发现max()函数竟然可以返回两个值,于是我决定重新学习一下这个函数 R_lv3_star, R_lv3_star_arg = torch.max(R_lv ...

  4. torch.cat() 函数用法

    torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起. 使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数 ...

  5. pytorch 中 torch.cat 函数的使用

    1. 字面理解:torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起. 2. 例子理解 >>> import torch ...

  6. torch的拼接函数_Pytorch中的torch.cat()函数

    cat( )的用法 按维数0拼接(竖着拼) C = torch.cat( (A,B),0 ) 按维数1拼接(横着拼) C = torch.cat( (A,B),1 ) 按维数0拼接 A=torch.o ...

  7. Pytorch中的torch.cat()函数

    转载自:https://www.cnblogs.com/JeasonIsCoding/p/10162356.html 1. 字面理解:torch.cat是将两个张量(tensor)拼接在一起,cat是 ...

  8. 【pytorch】torch.cat()函数

    欢迎移步我的个人博客 例子 import torchA=torch.ones(2,3) #2x3的张量(矩阵) Atensor([[ 1., 1., 1.],[ 1., 1., 1.]])B=2*to ...

  9. pytorch拼接函数:torch.stack()和torch.cat()--详解及例子

    原文链接: https://blog.csdn.net/xinjieyuan/article/details/105205326 https://blog.csdn.net/xinjieyuan/ar ...

最新文章

  1. Mysql压测工具mysqlslap 讲解
  2. 代码规范之华为公司代码规范
  3. 人类无法抗拒的10种心理(转)
  4. 地址总线是单向还是双向_三端双向交流开关(TRIAC)
  5. Oracle学习(十五)PLSQL安装
  6. axis2手动设置命名空间targetNamespace
  7. [HTTP] 重定向的302,301
  8. CPU高速缓存SRAM命中问题的总结与实验
  9. 使用/调用 函数的时候, 前面加不加 对象或 this?
  10. PHP中使用CURL(三)
  11. django mysql处理_利用Django去操作数据库并完成简易的登录及编辑功能
  12. 物流管理系统需要的服务器,物流业务管理系统
  13. vue 生成PDF(A4标准PDF分页)
  14. Digi Digimesh无线自组网协议和模块介绍
  15. ShuffleNet神经网络
  16. 【聚类】算法及其评估指标
  17. android手机最低内存,安卓想用很久不卡顿?12GB内存是最低标准,这6款硬核配置还便宜...
  18. 【车辆计数】基于matlab GUI背景差分法道路行驶多车辆检测【含Matlab源码 1911期】
  19. TJA1050国产替代DP1050T高速 CAN 总线收发器
  20. 用大白话说说JavaWeb相关技术

热门文章

  1. np.random.random()系列函数
  2. 马云说的到底对不对,京东到底行不行?
  3. Java开源GIS系统
  4. 【leetcode刷题】 64.数组的度——Java版
  5. SEXTANTE中调用任意C++控制台程序的简单例子
  6. 苏格拉底《临死前的演说》
  7. HTML5-页面加载动画
  8. android车载娱乐系统场景,复合式娱乐综合体,共享设备集成场景化空间-迷你ktv官网...
  9. STP生成树算法广播风暴的产生
  10. js根据日期往前或者往后多少月,推算出日期