详解 torch.unsqueeze 和 torch.squeeze

  • 1. 入门测试
  • 2. 深入研究
    • 2.1 torch.unsqueeze 详解
    • 2.2 unsqueeze_和 unsqueeze 的区别
    • 2.3 torch.squeeze 详解

1. 入门测试

  • torch.squeeze(input, dim = None, out = None): 返回一个tensor

    • 当dim不设值时,去掉输入的tensor的所有维度为1的维度;
    • 当dim为某一整数(0<=dim<input.dim())时,判断dim维的维度是否为1,若是则去掉,否则不变。
    • 另外,当input是一维的时候,squeeze不变
>>> x = torch.zeros(1,1,2,1,3)
>>> x.dim()
5
>>> torch.squeeze(x).size() # 去掉dim=1的维度
torch.Size([2, 3])
>>> torch.squeeze(x,0).size()  # dim=0表示第一维,且第一维的维度为1,所以去掉
torch.Size([1, 2, 1, 3])
>>> torch.squeeze(x,3).size()
torch.Size([1, 1, 2, 3])
>>> torch.squeeze(x,2).size()  # dim=2,第三维的维度为2!=1,所以不变
torch.Size([1, 1, 2, 1, 3])
  • torch.unqueeze(input, dim, out=None): 和squeeze作用相反,unsqueeze()在dim维插入一个维度为1的维,例如原来x是n×m维的,torch.unqueeze(x,0)这返回1×n×m的tensor
>>> x = torch.tensor([1,2,3])  # dim=1,即(3)
>>> torch.unsqueeze(x, 1)  # 变为(3,1)的矩阵
tensor([[ 1],[ 2],[ 3]])
  1. squeeze:压缩(降维)
  2. unqueeze:解压缩(升维)

2. 深入研究

2.1 torch.unsqueeze 详解

torch.unsqueeze(input, dim, out=None)
  • 作用:扩展维度

返回一个新的张量,对输入的既定位置插入维度 1

  • 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。

如果dim为负,则将会被转化dim+input.dim()+1

import torchx = torch.tensor([1, 2, 3])
print(x)
print(x.size())
print(torch.unsqueeze(x, -1))  # dim=-1作用效果等同dim=1
print(torch.unsqueeze(x, -1).size())
>>>
tensor([1, 2, 3])
torch.Size([3])
------------
tensor([[1],[2],[3]])
torch.Size([3, 1])
import torchx = torch.tensor([1, 2, 3])
print(x)
print(x.size())
print(torch.unsqueeze(x, -2))  # dim=-2:转化为(1, input_dim)
print(torch.unsqueeze(x, -2).size())
>>>
tensor([1, 2, 3])
torch.Size([3])
------------
tensor([[1, 2, 3]])
torch.Size([1, 3])

IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)

  • 参数:

    • tensor (Tensor) – 输入张量
    • dim (int) – 插入维度的索引
    • out (Tensor, optional) – 结果张量
import torchx = torch.Tensor([1, 2, 3, 4])  # torch.Tensor是默认的tensor类型(torch.FlaotTensor)的简称。print('-' * 50)
print(x)  # tensor([1., 2., 3., 4.])
print(x.size())  # torch.Size([4])
print(x.dim())  # 1
print(x.numpy())  # [1. 2. 3. 4.]print('-' * 50)
print(torch.unsqueeze(x, 0))  # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, 0).size())  # torch.Size([1, 4])
print(torch.unsqueeze(x, 0).dim())  # 2
print(torch.unsqueeze(x, 0).numpy())  # [[1. 2. 3. 4.]]print('-' * 50)
print(torch.unsqueeze(x, 1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
print(torch.unsqueeze(x, 1).size())  # torch.Size([4, 1])
print(torch.unsqueeze(x, 1).dim())  # 2print('-' * 50)
print(torch.unsqueeze(x, -1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
print(torch.unsqueeze(x, -1).size())  # torch.Size([4, 1])
print(torch.unsqueeze(x, -1).dim())  # 2print('-' * 50)
print(torch.unsqueeze(x, -2))  # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, -2).size())  # torch.Size([1, 4])
print(torch.unsqueeze(x, -2).dim())  # 2# 边界测试
# 说明:A dim value within the range [-input.dim() - 1, input.dim() + 1) (左闭右开)can be used.
# print('-' * 50)
# print(torch.unsqueeze(x, -3))
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got -3)# print('-' * 50)
# print(torch.unsqueeze(x, 2))
# IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)# 为何取值范围要如此设计呢?
# 原因:方便操作
# 0(-2)-行扩展
# 1(-1)-列扩展
# 正向:我们在0,1位置上扩展
# 逆向:我们在-2,-1位置上扩展
# 维度扩展:1维->2维,2维->3维,...,n维->n+1维
# 维度降低:n维->n-1维,n-1维->n-2维,...,2维->1维# 以 1维->2维 为例,# 从【正向】的角度思考:# torch.Size([4])
# 最初的 tensor([1., 2., 3., 4.]) 是 1维,我们想让它扩展成 2维,那么,可以有两种扩展方式:# 一种是:扩展成 1行4列 ,即 tensor([[1., 2., 3., 4.]])
# 针对第一种,扩展成 [1, 4]的形式,那么,在 dim=0 的位置上添加 1# 另一种是:扩展成 4行1列,即
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
# 针对第二种,扩展成 [4, 1]的形式,那么,在dim=1的位置上添加 1# 从【逆向】的角度思考:
# 原则:一般情况下, "-1" 是代表的是【最后一个元素】
# 在上述的原则下,
# 扩展成[1, 4]的形式,就变成了,在 dim=-2 的的位置上添加 1
# 扩展成[4, 1]的形式,就变成了,在 dim=-1 的的位置上添加 1

dim值对应增加维度的方式:

2.2 unsqueeze_和 unsqueeze 的区别

unsqueeze_unsqueeze 实现一样的功能,区别在于 unsqueeze_in_place 操作,即 unsqueeze 不会对使用 unsqueeze 的 tensor 进行改变,想要获取 unsqueeze 后的值必须赋予个新值, unsqueeze_ 则会对自己改变

print("-" * 50)
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])b = torch.unsqueeze(a, 1)
print(b)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])print(a)
# tensor([1., 2., 3., 4.])print("-" * 50)
a = torch.Tensor([1, 2, 3, 4])
print(a)
# tensor([1., 2., 3., 4.])print(a.unsqueeze_(1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])print(a)
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])

2.3 torch.squeeze 详解

torch.squeeze(input, dim=None, out=None)
  • 作用:降维

1. 将输入张量形状中的1 去除并返回。

如果输入是形如(A×1×B×1×C×1×D),那么输出形状就为: (A×B×C×D)

2. 当给定dim时,那么挤压操作只在给定维度上。

例如,输入形状为: (A×1×B), squeeze(input, 0) 将会保持张量不变,只有用 squeeze(input, 1),形状会变成 (A×B)。

  • 注意: 返回张量与输入张量共享内存,所以改变其中一个的内容会改变另一个。
  • 参数:
    • input (Tensor) – 输入张量
    • dim (int, optional) – 如果给定,则input只会在给定维度挤压
    • out (Tensor, optional) – 输出张量

为何只去掉 1 呢?

多维张量本质上就是一个变换,如果维度是 1 ,那么,1 仅仅起到扩充维度的作用,而没有其他用途,因而,在进行降维操作时,为了加快计算,是可以去掉这些 1 的维度

print("*" * 50)m = torch.zeros(2, 1, 2, 1, 2)
print(m.size())  # torch.Size([2, 1, 2, 1, 2])n = torch.squeeze(m)
print(n.size())  # torch.Size([2, 2, 2])n = torch.squeeze(m, 0)  # 当给定dim时,那么挤压操作只在给定维度上
print(n.size())  # torch.Size([2, 1, 2, 1, 2])n = torch.squeeze(m, 1)
print(n.size())  # torch.Size([2, 2, 1, 2])n = torch.squeeze(m, 2)
print(n.size())  # torch.Size([2, 1, 2, 1, 2])n = torch.squeeze(m, 3)
print(n.size())  # torch.Size([2, 1, 2, 2])print("@" * 50)
p = torch.zeros(2, 1, 1)
print(p)
# tensor([[[0.]],
#         [[0.]]])
print(p.numpy())
# [[[0.]]
#  [[0.]]]print(p.size())
# torch.Size([2, 1, 1])q = torch.squeeze(p)
print(q)
# tensor([0., 0.])print(q.numpy())
# [0. 0.]print(q.size())
# torch.Size([2])print(torch.zeros(3, 2).numpy())
# [[0. 0.]
#  [0. 0.]
#  [0. 0.]]

参考:Link


加油!

感谢!

努力!

【终于有人搞懂了】详解 torch.unsqueeze() 和 torch.squeeze()相关推荐

  1. 手机里竟然有这么多传感器!终于都搞懂了

    手机里竟然有这么多传感器!终于都搞懂了 本文来自快科技 随着技术的进步,手机已经不再是一个简单的通信工具,而是具有综合功能的便携式电子设备.手机的虚拟功能,比如交互.游戏.都是通过处理器强大的计算能力 ...

  2. 匈奴 东胡 突厥 蒙古 契丹 女真 ……终于全部搞懂了!值得看。

    匈奴 东胡 突厥 蒙古 契丹 女真 --终于全部搞懂了!值得看. 古代到现在一些小国家的形成和解体

  3. torch.unsqueeze和 torch.squeeze() 详解

    1. torch.unsqueeze 详解 torch.unsqueeze(input, dim, out=None) 作用:扩展维度 返回一个新的张量,对输入的既定位置插入维度 1 注意: 返回张量 ...

  4. 关于GaussDB(DWS)的正则表达式知多少?人人都能看得懂的详解来了!

    摘要:GaussDB(DWS)除了支持标准的POSIX正则表达式句法,还拥有一些特殊句法和选项,这些你可了解?本文便为你讲解这些特殊句法和选项. 概述 正则表达式(Regular Expression ...

  5. torch.unsqueeze()和torch.unsqueeze()

    参考:torch.squeeze() 和torch.unsqueeze()用法的通俗解释 import torch x = torch.tensor([[1, 2, 3],[1, 2, 3],[1, ...

  6. Pytorch中torch.unsqueeze()和torch.squeeze()函数解析

    一. torch.squeeze()函数解析 1. 官网链接 torch.squeeze(),如下图所示: 2. torch.squeeze()函数解析 torch.squeeze(input, di ...

  7. 看一遍就懂,详解java多线程——volatile

    多线程一直以来都是面试必考点,而volatile.synchronized也是必问点,这里我试图用容易理解的方式来解释一下volatile. 来看一下它的最大特点和作用: 一 使变量在多个线程间可见 ...

  8. 【JVM系列3】方法重载和方法重写原理分析,看完这篇终于彻底搞懂了

    深入分析Java虚拟机中方法执行流程及方法重载和方法重写原理 前言 思考 栈帧 局部变量表(Local Variables) 操作数栈(Operand Stacks) 动态连接(Dynamic Lin ...

  9. 只有20%的iOS程序员能看懂:详解intrinsicContentSize 及 约束优先级/content Hugging/content Compression Resistance

    在了解intrinsicContentSize之前,我们需要先了解2个概念: AutoLayout在做什么 约束优先级是什么意思. 如果不了解这两个概念,看intinsic content size没 ...

最新文章

  1. LeetCode简单题之最常见的单词
  2. ZooKeeper的基本原理
  3. 从责任界定和问题预警角度 解读全栈溯源对DevOps的价值
  4. mysql专区_MySQL-技术专区-详解索引原理
  5. StringBuilder的toString方法
  6. stream pipe的原理及简化源码分析
  7. CMSIS-DAP和J-Link、ST-Link是什么关系?
  8. c语言中合法转义字符,判断c语言合法转义字符
  9. u盘pe无人值守linux,从U盘无人值守安装linux操作系统(纯实践笔记
  10. Python中的具名元组类用法
  11. merge规则 python_用Python处理PDF
  12. 37. 使用accumulate或者for_each进行区间统计
  13. SCADA之父:物理隔离没什么用
  14. 腾讯云对象存储(cos) js jdk上传文件
  15. python语言程序设计 陈东_清华大学出版社-图书详情-《Python语言程序设计》
  16. python已知三角形的顶点坐标,求任一顶点角度
  17. 开源OCR文字识别软件Calamari
  18. 微信营销诀窍:有朋自各方来
  19. 区块链是什么通俗解释?
  20. IDEA创建java项目src下没有办法创建包文件/MAVEN模块名变灰且模块多道横杠

热门文章

  1. 简易ffmpeg安装
  2. Windows10 安装IIS
  3. Linux——LVM管理之文件系统新建
  4. 删除OpenStack僵尸卷
  5. Springdata_自己的小小总结02
  6. python退出程序命令
  7. Second Life第二人生 注册 登陆 常见问题解析
  8. 多家电商平台有大量三有保护动物,被指纵容犯罪
  9. 《YOLOv5/v7改进实战专栏》专栏介绍 专栏目录
  10. join是什么意思啊(join是什么意思啊)