torch.cat()函数 ,关于四维tensor维度合并。
引言:看了关于torch.cat函数的文章,有点乱,自己总结一篇,关于四维tensor合并。
- 一张图像在计算机中的表示通常为三维tensor(张量),即[channels,height,width] 。也就是一张彩色图片通常有三色通道(R,G,B)组成,高和宽也就是常说的照片大小,比如224x224
- 在图像处理的时候会增加一个变量batch_size,也就是把多少张图片作为一批进行处理。所以就变成了四维张量,即[batch_size,channels,heigth,width],也即是[批量大小,通道数,高,宽]
- 如何判断一个tensor是几维张量最简单的办法就是看中括号数。例如 [[[[1,2,3]]]],是四维张量。
- torch.cat()函数,官方文档是这样写的
torch.
cat
(tensors, dim=0, *, out=None),也就是有两个参数,一个是要合并的张量,一个是在哪个维度上进行合并。
废话少说开始演示。
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])
定义了两个四维张量。维度都为[1,1,2,3],即批量大小为1,通道为1,高为2,宽为3
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)#torch.Size([2, 1, 2, 3])
在维度0上进行合并,然后输出维度为[2,1,2,3],所以得出结论 四维张量在0维合并的时候 其实是在批量大小维度上进行合并。
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)#torch.Size([2, 1, 2, 3])#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)#torch.Size([1, 2, 2, 3])
在1维度上进行合并,输出维度为[1,2,2,3],即在1维上合并是在通道维度上进行合并。
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)#torch.Size([2, 1, 2, 3])#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)#torch.Size([1, 2, 2, 3])#在维度2上进行合并
x=torch.cat((a,b),dim=2)
print(x.shape)#torch.Size([1, 1, 4, 3])
在维度2上进行合并,输出维度为[1,1,4,3]。即在2维上进行合并是在高上进行合并(也可以说是在行维度进行合并)
import torch
a=torch.tensor([[[[1,1,1],[2,2,2]]]])
b=torch.tensor([[[[3,3,3],[4,4,4]]]])
print(a.shape,b.shape)#torch.Size([1, 1, 2, 3]) torch.Size([1, 1, 2, 3])#在维度0上面进行合并
x=torch.cat((a,b),dim=0)
print(x.shape)#torch.Size([2, 1, 2, 3])#在维度1上进行合并
x=torch.cat((a,b),dim=1)
print(x.shape)#torch.Size([1, 2, 2, 3])#在维度2上进行合并
x=torch.cat((a,b),dim=2)
print(x.shape)#torch.Size([1, 1, 4, 3])#在维度3上进行合并
x=torch.cat((a,b),dim=3)
print(x.shape)#torch.Size([1, 1, 2, 6])
在维度3上进行合并,输出维度为[1,1,2,6],即在3维上进行合并是在宽维度进行合并(也可以说是列)
注:在拼接时 除了选择拼接的维度可以不同,其他维度要相同。什么意思?看代码
import torch#定义两个变量[batch_size,channel,height,width]
a=torch.randn(size=(1,1,2,3))
b=torch.randn(size=(1,2,2,3))
#选择在1维度进行合并(也就是通道维度),注意a,b的通道维度不同,其他维度都相同。
x=torch.cat((a,b),dim=1)
print(x.shape)#torch.Size([1, 3, 2, 3])
也就是选择合并的那个维度可以不同,其他维度要相同
如果不同,报错,如下。
import torch#定义两个变量[batch_size,channel,height,width]
a=torch.randn(size=(1,1,2,3))
b=torch.randn(size=(2,2,2,3))
#选择在1维度进行合并(也就是通道维度),注意a,b的批量大小不同,维度不同,其他维度都相同。
x=torch.cat((a,b),dim=1)
print(x.shape)#RuntimeError: Sizes of tensors must match except in dimension 1. Got 1 and 2 in dimension 0
可以看到当我们选择在通道维度合并时(通道数可以不同),但是其他的维度要相同(下面的a,b的批量大小也不同)。所以直接报错。
总结的有不足之错还望各位大佬指正。
torch.cat()函数 ,关于四维tensor维度合并。相关推荐
- 【深度学习】特征融合的重要方法 | 张量的拼接 | torch.cat()函数 | torch.add(函数
文章目录 前言 一.torch.cat()函数 拼接只存在h,w(高,宽)的图像 二.torch.cat() 拼接存在c,h,w(通道,高,宽)的图像 三.torch.add()使张量对应元素直接相加 ...
- torch.cat()函数的官方解释,详解以及例子
可以直接看最下面的例子,再回头看前面的解释,就很明白了. 在pytorch中,常见的拼接函数主要是两个,分别是: stack() cat() 一般torch.cat()是为了把多个tensor进行拼接 ...
- torch.max()函数==》返回该维度的最大值以及该维度最大值对应的索引
今天在学习TTSR的过程总遇到了一行代码,我发现max()函数竟然可以返回两个值,于是我决定重新学习一下这个函数 R_lv3_star, R_lv3_star_arg = torch.max(R_lv ...
- torch.cat() 函数用法
torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起. 使用torch.cat((A,B),dim)时,除拼接维数dim数值可不同外其余维数 ...
- pytorch 中 torch.cat 函数的使用
1. 字面理解:torch.cat是将两个张量(tensor)拼接在一起,cat是concatnate的意思,即拼接,联系在一起. 2. 例子理解 >>> import torch ...
- torch的拼接函数_Pytorch中的torch.cat()函数
cat( )的用法 按维数0拼接(竖着拼) C = torch.cat( (A,B),0 ) 按维数1拼接(横着拼) C = torch.cat( (A,B),1 ) 按维数0拼接 A=torch.o ...
- Pytorch中的torch.cat()函数
转载自:https://www.cnblogs.com/JeasonIsCoding/p/10162356.html 1. 字面理解:torch.cat是将两个张量(tensor)拼接在一起,cat是 ...
- 【pytorch】torch.cat()函数
欢迎移步我的个人博客 例子 import torchA=torch.ones(2,3) #2x3的张量(矩阵) Atensor([[ 1., 1., 1.],[ 1., 1., 1.]])B=2*to ...
- pytorch拼接函数:torch.stack()和torch.cat()--详解及例子
原文链接: https://blog.csdn.net/xinjieyuan/article/details/105205326 https://blog.csdn.net/xinjieyuan/ar ...
最新文章
- Mysql压测工具mysqlslap 讲解
- 代码规范之华为公司代码规范
- 人类无法抗拒的10种心理(转)
- 地址总线是单向还是双向_三端双向交流开关(TRIAC)
- Oracle学习(十五)PLSQL安装
- axis2手动设置命名空间targetNamespace
- [HTTP] 重定向的302,301
- CPU高速缓存SRAM命中问题的总结与实验
- 使用/调用 函数的时候, 前面加不加 对象或 this?
- PHP中使用CURL(三)
- django mysql处理_利用Django去操作数据库并完成简易的登录及编辑功能
- 物流管理系统需要的服务器,物流业务管理系统
- vue 生成PDF(A4标准PDF分页)
- Digi Digimesh无线自组网协议和模块介绍
- ShuffleNet神经网络
- 【聚类】算法及其评估指标
- android手机最低内存,安卓想用很久不卡顿?12GB内存是最低标准,这6款硬核配置还便宜...
- 【车辆计数】基于matlab GUI背景差分法道路行驶多车辆检测【含Matlab源码 1911期】
- TJA1050国产替代DP1050T高速 CAN 总线收发器
- 用大白话说说JavaWeb相关技术