Tensor unsqueeze 以 broadcast
如果对 torch.unsqueeze 了解不多,建议先阅读博主的这篇文章。
1. Tensor 和 Scalar
import torchX = torch.tensor([[1,2,3],[4,5,6]]
)
print("X:\n",X)
print("X's shape:\n",X.shape)Y = torch.tensor(7
)
print("\nY:\n",Y)
print("Y's shape:\n",Y.shape)Z1 = X*Y
print("\nZ:\n",Z1)
print("Z1's shape:\n",Z1.shape)Z2 = Y*X
print("\nZ:\n",Z2)
print("Z2's shape:\n",Z1.shape)
输出:
X:tensor([[1, 2, 3],[4, 5, 6]])
X's shape:torch.Size([2, 3])Y:tensor(7)
Y's shape:torch.Size([])Z:tensor([[ 7, 14, 21],[28, 35, 42]])
Z1's shape:torch.Size([2, 3])Z:tensor([[ 7, 14, 21],[28, 35, 42]])
Z2's shape:torch.Size([2, 3])
直接将 scalar 广播成 shape 为 [2,3] 的 tensor 了。
2. Tensor与Tensor: [2,3]*[3]
仅修改下列代码
Y = torch.tensor([7,8,9]
)
结果
X:tensor([[1, 2, 3],[4, 5, 6]])X's shape:torch.Size([2, 3])Y:tensor([7, 8, 9])
Y's shape:torch.Size([3])Z:tensor([[ 7, 16, 27],[28, 40, 54]])
Z1's shape:torch.Size([2, 3])Z:tensor([[ 7, 16, 27],[28, 40, 54]])
Z2's shape:torch.Size([2, 3])
将 shape 为 [3] 的tensor 广播为 [2,3]
3. Tensor与Tensor [2,3]*[3,1], 报错
仅仅修改下列代码:
Y = torch.tensor([[7],[8],[9]]
)
输出:
Y:tensor([[7],[8],[9]])
Y's shape:torch.Size([3, 1])
Traceback (most recent call last):File "D:/Venv/Test/0710Test/Test.py", line 16, in <module>Z = X*Y
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
这个报错我暂时还没搞懂。
4. X.shape: [5,3] Y.shape [5,6] 如何 broadcast ?
import torchX = torch.tensor([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15]]
)
print("X:\n",X)
print("X's shape:\n",X.shape)X1 = X.unsqueeze(1)
print("X1:\n",X1)
print("X1's shape:\n",X1.shape,"\n")Y = torch.ones((5,6))
print("\nY:\n",Y)
print("Y's shape:\n",Y.shape)Y1 = Y.unsqueeze(2)
print("Y1:\n",Y1)
print("Y1's shape:\n",Y1.shape,"\n")Z1 = X.unsqueeze(1)*Y.unsqueeze(2);
print("\nZ1:\n",Z1)
print("Z1's shape:",Z1.shape,"\n\n\n")
Z2 = X*Y
输出:
X:tensor([[ 1, 2, 3],[ 4, 5, 6],[ 7, 8, 9],[10, 11, 12],[13, 14, 15]])
X's shape:torch.Size([5, 3])
X1:tensor([[[ 1, 2, 3]],[[ 4, 5, 6]],[[ 7, 8, 9]],[[10, 11, 12]],[[13, 14, 15]]])
X1's shape:torch.Size([5, 1, 3]) Y:tensor([[1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1.],[1., 1., 1., 1., 1., 1.]])
Y's shape:torch.Size([5, 6])
Y1:tensor([[[1.],[1.],[1.],[1.],[1.],[1.]],[[1.],[1.],[1.],[1.],[1.],[1.]],[[1.],[1.],[1.],[1.],[1.],[1.]],[[1.],[1.],[1.],[1.],[1.],[1.]],[[1.],[1.],[1.],[1.],[1.],[1.]]])
Y1's shape:torch.Size([5, 6, 1]) Z1:tensor([[[ 1., 2., 3.],[ 1., 2., 3.],[ 1., 2., 3.],[ 1., 2., 3.],[ 1., 2., 3.],[ 1., 2., 3.]],[[ 4., 5., 6.],[ 4., 5., 6.],[ 4., 5., 6.],[ 4., 5., 6.],[ 4., 5., 6.],[ 4., 5., 6.]],[[ 7., 8., 9.],[ 7., 8., 9.],[ 7., 8., 9.],[ 7., 8., 9.],[ 7., 8., 9.],[ 7., 8., 9.]],[[10., 11., 12.],[10., 11., 12.],[10., 11., 12.],[10., 11., 12.],[10., 11., 12.],[10., 11., 12.]],[[13., 14., 15.],[13., 14., 15.],[13., 14., 15.],[13., 14., 15.],[13., 14., 15.],[13., 14., 15.]]])
Z1's shape: torch.Size([5, 6, 3]) Traceback (most recent call last):File "D:/Venv/Test/0710Test/Test.py", line 29, in <module>Z2 = X*Y
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 1Process finished with exit code 1
上面的 X 相当于 rays_direction 向量,维度是 [num_rays, 3],表示所有光线的方向向量;
上面的 Y 相当于 z_val 向量,维度是 [num_rays,num_sample], 表示所有光线的采样点的距离。
那么针对每条光线的每个采样点,都需要乘上 X 中的某一行,相当于是 zx⃗z\vec{x}zx 的操作。
X:[5,3]
Y[5,6]
X.unsqueeze(1)*Y.unsqueeze(2):[5,6,3]
X.unsqueeze(1): [5,3]→[5,1,3][5,3] \rightarrow [5,1,\mathbf{3}][5,3]→[5,1,3]
Y.unsqueeze(2) : [5,6]→[5,6,1][5,6] \rightarrow [5,\mathbf{6},1][5,6]→[5,6,1]
在执行 X.unsqueeze(1)*Y.unsqueeze(2): 的时候:
X.unsqueeze(1) 、Y.unsqueeze(2)都会被 broadcast 成 [5,6,3] 的 shape
看上面两行,X.unsqueeze(1)与Y.unsqueeze(2)相乘, 结果 tensor 的 sahpe 的每一个维度的值取它们两个对应值的较大值。即:
dim 1: max{1,6}=6\max\{1,6\}=6max{1,6}=6
dim 2: max{3,1}=3\max\{3,1\}=3max{3,1}=3
理解 unsqueeze 的很好的一个例子
经过 element-wise 乘法后,即 “*” 乘法后,
得到的结果的 shape 和 参与乘法运算的tensor的较大 dim 的 tensor的shape一致
NeRF 里面一个 weight 矩阵,它的 shape 是:N_rays * N_samples
还有一个 RGB tensor,它的shape 是:N_rays * N_samples*3
根据 Volume Rendering,同一个点的 RGB 的系数是一样的
见下述代码:
nRays = 2
nSamples = 3weight = torch.tensor([[1,2,3],[7,4,8]])RGB = torch.tensor([[[0,3,0],[4,5,6],[7,8,9]],[[10,11,18],[8,58,6],[70,82,9]]
])print("weight's shape",weight.shape)
print("weight",weight)print("\n", "RGB's shape",RGB.shape)
print("RGB",RGB,"\n")weight_after_unsqueeze = weight.unsqueeze(-1)
print("weight_after_unsqueeze's shape: ",weight_after_unsqueeze.shape)
print("weight_after_unsqueeze",weight_after_unsqueeze)res2 = weight_after_unsqueeze*RGB
print("shape: ",res2.shape)
print(res2)
输出:
weight's shape torch.Size([2, 3])
weight tensor([[1, 2, 3],[7, 4, 8]])RGB's shape torch.Size([2, 3, 3])
RGB tensor([[[ 0, 3, 0],[ 4, 5, 6],[ 7, 8, 9]],[[10, 11, 18],[ 8, 58, 6],[70, 82, 9]]]) weight_after_unsqueeze's shape: torch.Size([2, 3, 1])
weight_after_unsqueeze tensor([[[1],[2],[3]],[[7],[4],[8]]])
shape: torch.Size([2, 3, 3])
tensor([[[ 0, 3, 0],[ 8, 10, 12],[ 21, 24, 27]],[[ 70, 77, 126],[ 32, 232, 24],[560, 656, 72]]])
上述代码中 N_rays 和 N_weights 分别取 2、3.
在 weight 经过 unsqueeze 后,得到的tensor的形式,竖着排的形式,会给人提示,是一个数占了一列。RGB三个分量都会乘上同一个数字,同一列的。
Tensor unsqueeze 以 broadcast相关推荐
- tensor.squeeze函数和tensor.unsqueeze函数的使用
tensor.squeeze() 和 tensor.unsqueeze() 是 PyTorch 中用于改变 tensor 形状的两个函数,它们的作用如下: tensor.squeeze(dim=Non ...
- pytorch: Tensor的创建与调整
测试环境版本: torch1.7.1 + CPU python 3.6 Tensor是pytorch中的"张量",可以看作是类似numpy的矩阵 本文介绍如何创建与调整Tensor ...
- Pytorch squeeze() 和 unsqueeze() 方法区别
1 增加维度 unsqueeze() tensor = tensor.unsqueeze(0) 2 减少维度 squeeze() tensor = tensor.squeeze(0)
- torch对于tensor的常规操作
前言 使用pytorch框架,会常操作tensor,以下则是对tensor常规操作的汇总. import torch torch.Tensor会继承某些torch的某些数学运算,例如sort, min ...
- 详解Tensor用法
Tensor的操作 如果本文对你有帮助,欢迎点赞.订阅以及star我的项目. 你的支持是我创作的最大动力! 张量的数据属性与 NumPy 数组类似,如下所示: 张量的操作主要包括张量的结构操作和张量的 ...
- 给tensor增加维度 或 减少维度
import torch import tensorflow as tf import numpy as np#tf.expand_dims(input, axis=1) <-> tf.s ...
- OneFlow 的 Global Tensor 学习笔记和实习总结
文章目录 1 前言 2 关于 Global Tensor 2.1 OneFlow 分布式全局视角的基础保证 2.2 SBP 自动转换 2.3 to_global 方法 2.4 GlobalTensor ...
- pytorch 定义torch类型数据_PyTorch官方中文文档:torch.Tensor
torch.Tensor torch.Tensor是一种包含单一数据类型元素的多维矩阵. Torch定义了七种CPU tensor类型和八种GPU tensor类型: Data tyoe CPU te ...
- Pytorch中torch.unsqueeze()和torch.squeeze()函数解析
一. torch.squeeze()函数解析 1. 官网链接 torch.squeeze(),如下图所示: 2. torch.squeeze()函数解析 torch.squeeze(input, di ...
最新文章
- Windows程序的基本结构(转)
- 在python中如何有效的比较两个无序的列表是否包含完全同样的元素(不是set)?
- 基于高阶累积量的数字调制信号分类(Hierarchical Digital Modulation Classification Using Cumulants例1复现)
- linux性能监控sar命令详解
- maven 打包数据库加密_SpringBoot项目application.yml文件数据库配置密码加密的方法...
- python mss_Python实现的连接mssql数据库操作示例
- mysql5.7安装
- Python:pip下载库后导入Pycharm的方法
- 怎样在VS2013/MFC中使用TeeChart绘图控件
- 服务器创建新文件夹权限设置密码,在服务器上修改文件夹权限设置密码
- python实现kmeans图像分割、一只遥望大海的小狗_【Python】爬虫+ K-means 聚类分析电影海报主色...
- 移动web网页开发——动画
- 如何删除计算机中常用列表,清除右键多余菜单,鼠标右键菜单清理的方法(一) -电脑资料...
- linux下RTNETLINK answers: File exists的解决方案 慎重
- 叮咚小区官网新闻已不更新
- 前端获取验证码、手机号登录、注册功能
- 如何批量重命名图片,文档,文件夹名字 Windows CMD 批量修改文件名字 内含修改路径的操作,想改哪里改哪里!
- Ansible安装使用
- 客户流失?来看看大厂如何基于spark+机器学习构建千万数据规模上的用户留存模型 ⛵
- 大数据岗位面试失败的经历总结,这些面试的坑莫在踩
热门文章
- Idea快速选中一行的四种方式
- flutter_bloc使用及部分源码分析
- 经典版 树上有十只鸟,开枪打死了一只,还剩几只!
- python优势与劣势-python的优点和缺点是什么?
- SAP770系统FI模块配置(维护会计年度变式)
- VC++之 CreateEvent和SetEvent及WaitForSingleObject的用法
- ApacheCN 翻译活动进度公告 2019.3.17
- xp iis mysql php,XP下IIS配置PHP 和MySQL
- vivado2018.3创建一个流水灯(基于创龙k7核心开发板)
- 计算机网络购物支付说课,电子商务说课稿--网络支付