如果对 torch.unsqueeze 了解不多,建议先阅读博主的这篇文章。

1. Tensor 和 Scalar

import torchX = torch.tensor([[1,2,3],[4,5,6]]
)
print("X:\n",X)
print("X's shape:\n",X.shape)Y = torch.tensor(7
)
print("\nY:\n",Y)
print("Y's shape:\n",Y.shape)Z1 = X*Y
print("\nZ:\n",Z1)
print("Z1's shape:\n",Z1.shape)Z2 = Y*X
print("\nZ:\n",Z2)
print("Z2's shape:\n",Z1.shape)

输出:

X:tensor([[1, 2, 3],[4, 5, 6]])
X's shape:torch.Size([2, 3])Y:tensor(7)
Y's shape:torch.Size([])Z:tensor([[ 7, 14, 21],[28, 35, 42]])
Z1's shape:torch.Size([2, 3])Z:tensor([[ 7, 14, 21],[28, 35, 42]])
Z2's shape:torch.Size([2, 3])

直接将 scalar 广播成 shape 为 [2,3] 的 tensor 了。

2. Tensor与Tensor: [2,3]*[3]

仅修改下列代码

Y = torch.tensor([7,8,9]
)

结果

X:tensor([[1, 2, 3],[4, 5, 6]])X's shape:torch.Size([2, 3])Y:tensor([7, 8, 9])
Y's shape:torch.Size([3])Z:tensor([[ 7, 16, 27],[28, 40, 54]])
Z1's shape:torch.Size([2, 3])Z:tensor([[ 7, 16, 27],[28, 40, 54]])
Z2's shape:torch.Size([2, 3])

将 shape 为 [3] 的tensor 广播为 [2,3]

3. Tensor与Tensor [2,3]*[3,1], 报错

仅仅修改下列代码:

Y = torch.tensor([[7],[8],[9]]
)

输出:

Y:tensor([[7],[8],[9]])
Y's shape:torch.Size([3, 1])
Traceback (most recent call last):File "D:/Venv/Test/0710Test/Test.py", line 16, in <module>Z = X*Y
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0

这个报错我暂时还没搞懂。

4. X.shape: [5,3] Y.shape [5,6] 如何 broadcast ?

import torchX = torch.tensor([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15]]
)
print("X:\n",X)
print("X's shape:\n",X.shape)X1 = X.unsqueeze(1)
print("X1:\n",X1)
print("X1's shape:\n",X1.shape,"\n")Y = torch.ones((5,6))
print("\nY:\n",Y)
print("Y's shape:\n",Y.shape)Y1 = Y.unsqueeze(2)
print("Y1:\n",Y1)
print("Y1's shape:\n",Y1.shape,"\n")Z1 = X.unsqueeze(1)*Y.unsqueeze(2);
print("\nZ1:\n",Z1)
print("Z1's shape:",Z1.shape,"\n\n\n")
Z2 = X*Y

输出:

X:tensor([[ 1,  2,  3],[ 4,  5,  6],[ 7,  8,  9],[10, 11, 12],[13, 14, 15]])
X's shape:torch.Size([5, 3])
X1:tensor([[[ 1,  2,  3]],[[ 4,  5,  6]],[[ 7,  8,  9]],[[10, 11, 12]],[[13, 14, 15]]])
X1's shape:torch.Size([5, 1, 3]) Y:tensor([[1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1.]])
Y's shape:torch.Size([5, 6])
Y1:tensor([[[1.],[1.],[1.],[1.],[1.],[1.]],[[1.],[1.],[1.],[1.],[1.],[1.]],[[1.],[1.],[1.],[1.],[1.],[1.]],[[1.],[1.],[1.],[1.],[1.],[1.]],[[1.],[1.],[1.],[1.],[1.],[1.]]])
Y1's shape:torch.Size([5, 6, 1]) Z1:tensor([[[ 1.,  2.,  3.],[ 1.,  2.,  3.],[ 1.,  2.,  3.],[ 1.,  2.,  3.],[ 1.,  2.,  3.],[ 1.,  2.,  3.]],[[ 4.,  5.,  6.],[ 4.,  5.,  6.],[ 4.,  5.,  6.],[ 4.,  5.,  6.],[ 4.,  5.,  6.],[ 4.,  5.,  6.]],[[ 7.,  8.,  9.],[ 7.,  8.,  9.],[ 7.,  8.,  9.],[ 7.,  8.,  9.],[ 7.,  8.,  9.],[ 7.,  8.,  9.]],[[10., 11., 12.],[10., 11., 12.],[10., 11., 12.],[10., 11., 12.],[10., 11., 12.],[10., 11., 12.]],[[13., 14., 15.],[13., 14., 15.],[13., 14., 15.],[13., 14., 15.],[13., 14., 15.],[13., 14., 15.]]])
Z1's shape: torch.Size([5, 6, 3]) Traceback (most recent call last):File "D:/Venv/Test/0710Test/Test.py", line 29, in <module>Z2 = X*Y
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 1Process finished with exit code 1

上面的 X 相当于 rays_direction 向量,维度是 [num_rays, 3],表示所有光线的方向向量;
上面的 Y 相当于 z_val 向量,维度是 [num_rays,num_sample], 表示所有光线的采样点的距离。
那么针对每条光线的每个采样点,都需要乘上 X 中的某一行,相当于是 zx⃗z\vec{x}zx 的操作。

X:[5,3]
Y[5,6]
X.unsqueeze(1)*Y.unsqueeze(2):[5,6,3]
X.unsqueeze(1): [5,3]→[5,1,3][5,3] \rightarrow [5,1,\mathbf{3}][5,3]→[5,1,3]
Y.unsqueeze(2) : [5,6]→[5,6,1][5,6] \rightarrow [5,\mathbf{6},1][5,6]→[5,6,1]

在执行 X.unsqueeze(1)*Y.unsqueeze(2): 的时候:
X.unsqueeze(1) 、Y.unsqueeze(2)都会被 broadcast 成 [5,6,3] 的 shape
看上面两行,X.unsqueeze(1)与Y.unsqueeze(2)相乘, 结果 tensor 的 sahpe 的每一个维度的值取它们两个对应值的较大值。即:
dim 1: max⁡{1,6}=6\max\{1,6\}=6max{1,6}=6
dim 2: max⁡{3,1}=3\max\{3,1\}=3max{3,1}=3

理解 unsqueeze 的很好的一个例子

经过 element-wise 乘法后,即 “*” 乘法后,
得到的结果的 shape 和 参与乘法运算的tensor的较大 dim 的 tensor的shape一致
NeRF 里面一个 weight 矩阵,它的 shape 是:N_rays * N_samples
还有一个 RGB tensor,它的shape 是:N_rays * N_samples*3
根据 Volume Rendering,同一个点的 RGB 的系数是一样的
见下述代码:

nRays = 2
nSamples = 3weight = torch.tensor([[1,2,3],[7,4,8]])RGB = torch.tensor([[[0,3,0],[4,5,6],[7,8,9]],[[10,11,18],[8,58,6],[70,82,9]]
])print("weight's shape",weight.shape)
print("weight",weight)print("\n", "RGB's shape",RGB.shape)
print("RGB",RGB,"\n")weight_after_unsqueeze = weight.unsqueeze(-1)
print("weight_after_unsqueeze's shape: ",weight_after_unsqueeze.shape)
print("weight_after_unsqueeze",weight_after_unsqueeze)res2 = weight_after_unsqueeze*RGB
print("shape: ",res2.shape)
print(res2)

输出:

weight's shape torch.Size([2, 3])
weight tensor([[1, 2, 3],[7, 4, 8]])RGB's shape torch.Size([2, 3, 3])
RGB tensor([[[ 0,  3,  0],[ 4,  5,  6],[ 7,  8,  9]],[[10, 11, 18],[ 8, 58,  6],[70, 82,  9]]]) weight_after_unsqueeze's shape:  torch.Size([2, 3, 1])
weight_after_unsqueeze tensor([[[1],[2],[3]],[[7],[4],[8]]])
shape:  torch.Size([2, 3, 3])
tensor([[[  0,   3,   0],[  8,  10,  12],[ 21,  24,  27]],[[ 70,  77, 126],[ 32, 232,  24],[560, 656,  72]]])

上述代码中 N_rays 和 N_weights 分别取 2、3.
在 weight 经过 unsqueeze 后,得到的tensor的形式,竖着排的形式,会给人提示,是一个数占了一列。RGB三个分量都会乘上同一个数字,同一列的。

Tensor unsqueeze 以 broadcast相关推荐

  1. tensor.squeeze函数和tensor.unsqueeze函数的使用

    tensor.squeeze() 和 tensor.unsqueeze() 是 PyTorch 中用于改变 tensor 形状的两个函数,它们的作用如下: tensor.squeeze(dim=Non ...

  2. pytorch: Tensor的创建与调整

    测试环境版本: torch1.7.1 + CPU python 3.6 Tensor是pytorch中的"张量",可以看作是类似numpy的矩阵 本文介绍如何创建与调整Tensor ...

  3. Pytorch squeeze() 和 unsqueeze() 方法区别

    1 增加维度 unsqueeze() tensor = tensor.unsqueeze(0) 2 减少维度 squeeze() tensor = tensor.squeeze(0)

  4. torch对于tensor的常规操作

    前言 使用pytorch框架,会常操作tensor,以下则是对tensor常规操作的汇总. import torch torch.Tensor会继承某些torch的某些数学运算,例如sort, min ...

  5. 详解Tensor用法

    Tensor的操作 如果本文对你有帮助,欢迎点赞.订阅以及star我的项目. 你的支持是我创作的最大动力! 张量的数据属性与 NumPy 数组类似,如下所示: 张量的操作主要包括张量的结构操作和张量的 ...

  6. 给tensor增加维度 或 减少维度

    import torch import tensorflow as tf import numpy as np#tf.expand_dims(input, axis=1) <-> tf.s ...

  7. OneFlow 的 Global Tensor 学习笔记和实习总结

    文章目录 1 前言 2 关于 Global Tensor 2.1 OneFlow 分布式全局视角的基础保证 2.2 SBP 自动转换 2.3 to_global 方法 2.4 GlobalTensor ...

  8. pytorch 定义torch类型数据_PyTorch官方中文文档:torch.Tensor

    torch.Tensor torch.Tensor是一种包含单一数据类型元素的多维矩阵. Torch定义了七种CPU tensor类型和八种GPU tensor类型: Data tyoe CPU te ...

  9. Pytorch中torch.unsqueeze()和torch.squeeze()函数解析

    一. torch.squeeze()函数解析 1. 官网链接 torch.squeeze(),如下图所示: 2. torch.squeeze()函数解析 torch.squeeze(input, di ...

最新文章

  1. Windows程序的基本结构(转)
  2. 在python中如何有效的比较两个无序的列表是否包含完全同样的元素(不是set)?
  3. 基于高阶累积量的数字调制信号分类(Hierarchical Digital Modulation Classification Using Cumulants例1复现)
  4. linux性能监控sar命令详解
  5. maven 打包数据库加密_SpringBoot项目application.yml文件数据库配置密码加密的方法...
  6. python mss_Python实现的连接mssql数据库操作示例
  7. mysql5.7安装
  8. Python:pip下载库后导入Pycharm的方法
  9. 怎样在VS2013/MFC中使用TeeChart绘图控件
  10. 服务器创建新文件夹权限设置密码,在服务器上修改文件夹权限设置密码
  11. python实现kmeans图像分割、一只遥望大海的小狗_【Python】爬虫+ K-means 聚类分析电影海报主色...
  12. 移动web网页开发——动画
  13. 如何删除计算机中常用列表,清除右键多余菜单,鼠标右键菜单清理的方法(一) -电脑资料...
  14. linux下RTNETLINK answers: File exists的解决方案 慎重
  15. 叮咚小区官网新闻已不更新
  16. 前端获取验证码、手机号登录、注册功能
  17. 如何批量重命名图片,文档,文件夹名字 Windows CMD 批量修改文件名字 内含修改路径的操作,想改哪里改哪里!
  18. Ansible安装使用
  19. 客户流失?来看看大厂如何基于spark+机器学习构建千万数据规模上的用户留存模型 ⛵
  20. 大数据岗位面试失败的经历总结,这些面试的坑莫在踩

热门文章

  1. Idea快速选中一行的四种方式
  2. flutter_bloc使用及部分源码分析
  3. 经典版 树上有十只鸟,开枪打死了一只,还剩几只!
  4. python优势与劣势-python的优点和缺点是什么?
  5. SAP770系统FI模块配置(维护会计年度变式)
  6. VC++之 CreateEvent和SetEvent及WaitForSingleObject的用法
  7. ApacheCN 翻译活动进度公告 2019.3.17
  8. xp iis mysql php,XP下IIS配置PHP 和MySQL
  9. vivado2018.3创建一个流水灯(基于创龙k7核心开发板)
  10. 计算机网络购物支付说课,电子商务说课稿--网络支付