If you are not yet familiar with torch.unsqueeze, I recommend reading my earlier article on it first.

1. Tensor and Scalar

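import torch

X = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("X:\n", X)
print("X's shape:\n", X.shape)
Y = torch.tensor(7)  # a 0-dim tensor (scalar)
print("Y:\n", Y)
print("Y's shape:\n", Y.shape)
Z1 = X * Y
print("Z:\n", Z1)
print("Z1's shape:\n", Z1.shape)
Z2 = Y * X
print("Z:\n", Z2)
print("Z2's shape:\n", Z2.shape)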


X:
tensor([[1, 2, 3],
        [4, 5, 6]])
X's shape:
torch.Size([2, 3])
Y:
tensor(7)
Y's shape:
torch.Size([])
Z:
tensor([[ 7, 14, 21],
        [28, 35, 42]])
Z1's shape:
torch.Size([2, 3])
Z:
tensor([[ 7, 14, 21],
        [28, 35, 42]])
Z2's shape:
torch.Size([2, 3])

The scalar is broadcast directly to a tensor of shape [2, 3], and the order of the operands (X*Y or Y*X) does not matter.
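To check that the scalar really behaves as if it had been expanded to shape [2, 3], the sketch below compares X * Y with a multiplication by an explicitly filled tensor. Tensor.expand_as and torch.full_like are standard PyTorch calls; the comparison itself is just an illustration.

```python
import torch

X = torch.tensor([[1, 2, 3], [4, 5, 6]])
Y = torch.tensor(7)  # 0-dim tensor, shape torch.Size([])

# Broadcasting treats Y as if it were expanded to X's shape [2, 3]
Y_expanded = Y.expand_as(X)
print(Y_expanded.shape)  # torch.Size([2, 3])

# Same result as multiplying by a [2, 3] tensor filled with 7
Z_broadcast = X * Y
Z_explicit = X * torch.full_like(X, 7)
print(torch.equal(Z_broadcast, Z_explicit))  # True
print(torch.equal(X * Y, Y * X))             # True: operand order does not matter
```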

2. Tensor and Tensor: [2,3]*[3]


Y = torch.tensor([7, 8, 9])  # only Y changes; the rest of the code is as in section 1


X:
tensor([[1, 2, 3],
        [4, 5, 6]])
X's shape:
torch.Size([2, 3])
Y:
tensor([7, 8, 9])
Y's shape:
torch.Size([3])
Z:
tensor([[ 7, 16, 27],
        [28, 40, 54]])
Z1's shape:
torch.Size([2, 3])
Z:
tensor([[ 7, 16, 27],
        [28, 40, 54]])
Z2's shape:
torch.Size([2, 3])

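The tensor of shape [3] is broadcast to [2, 3]: shapes are aligned from the last dimension, so [3] is first treated as [1, 3], and the size-1 dimension is then expanded to 2.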
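The two steps can be reproduced by hand with unsqueeze and expand, and compared with torch.broadcast_shapes (available in recent PyTorch versions). This is only a sketch to make the implicit expansion visible; PyTorch does not actually copy the data when broadcasting.

```python
import torch

X = torch.tensor([[1, 2, 3], [4, 5, 6]])  # [2, 3]
Y = torch.tensor([7, 8, 9])               # [3]

# Step 1: a size-1 dimension is prepended, [3] -> [1, 3]
Y_step1 = Y.unsqueeze(0)
# Step 2: the size-1 dimension is expanded, [1, 3] -> [2, 3]
Y_step2 = Y_step1.expand(2, 3)

print(Y_step1.shape, Y_step2.shape)              # torch.Size([1, 3]) torch.Size([2, 3])
print(torch.broadcast_shapes(X.shape, Y.shape))  # torch.Size([2, 3])

# Multiplying by the manually expanded tensor gives the same result
print(torch.equal(X * Y, X * Y_step2))  # True
```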

3. Tensor and Tensor: [2,3]*[3,1] raises an error


Y = torch.tensor([[7], [8], [9]])  # shape [3, 1]; the rest of the code is as in section 1


Y's shape:
torch.Size([3, 1])
Traceback (most recent call last):
  File "D:/Venv/Test/0710Test/Test.py", line 16, in <module>
    Z = X*Y
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
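The error follows from the broadcasting rule: shapes are compared from the last dimension backwards, and two sizes are compatible only if they are equal or one of them is 1. For [2, 3] and [3, 1], the last dimension (3 vs 1) is fine, but the next one (2 vs 3) is not, which is the "non-singleton dimension 0" in the message. The helper `broadcast_shape` below is a hypothetical pure-Python sketch of that rule (not a PyTorch API), written only to make the check explicit.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: broadcast shape of two shapes using the
    trailing-dimension rule. Raises ValueError if they are incompatible."""
    result = []
    # Walk both shapes from the last dimension; missing leading dims count as 1
    for a, b in zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1):
        if a == b or b == 1:
            result.append(a)
        elif a == 1:
            result.append(b)
        else:
            raise ValueError(f"sizes {a} and {b} cannot be broadcast")
    return tuple(reversed(result))


print(broadcast_shape((2, 3), ()))            # (2, 3)     section 1: scalar
print(broadcast_shape((2, 3), (3,)))          # (2, 3)     section 2
print(broadcast_shape((5, 1, 3), (5, 6, 1)))  # (5, 6, 3)  section 4 after unsqueeze
try:
    broadcast_shape((2, 3), (3, 1))           # section 3
except ValueError as err:
    print(err)                                # sizes 2 and 3 cannot be broadcast
```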


4. X.shape: [5,3], Y.shape: [5,6]: how to broadcast?

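import torch

X = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]])
print("X:\n", X)
print("X's shape:\n", X.shape)
X1 = X.unsqueeze(1)  # [5, 3] -> [5, 1, 3]
print("X1:\n", X1)
print("X1's shape:\n", X1.shape, "\n")
Y = torch.ones((5, 6))
print("Y:\n", Y)
print("Y's shape:\n", Y.shape)
Y1 = Y.unsqueeze(2)  # [5, 6] -> [5, 6, 1]
print("Y1's shape:\n", Y1.shape, "\n")
Z1 = X.unsqueeze(1) * Y.unsqueeze(2)  # [5, 1, 3] * [5, 6, 1] -> [5, 6, 3]
print("Z1:\n", Z1)
print("Z1's shape:", Z1.shape, "\n\n\n")
Z2 = X * Y  # [5, 3] * [5, 6]: raises RuntimeError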


X:
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [13, 14, 15]])
X's shape:
torch.Size([5, 3])
X1:
tensor([[[ 1,  2,  3]],

        [[ 4,  5,  6]],

        [[ 7,  8,  9]],

        [[10, 11, 12]],

        [[13, 14, 15]]])
X1's shape:
torch.Size([5, 1, 3])

Y:
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
Y's shape:
torch.Size([5, 6])
Y1's shape:
torch.Size([5, 6, 1])

Z1:
tensor([[[ 1.,  2.,  3.],
         [ 1.,  2.,  3.],
         [ 1.,  2.,  3.],
         [ 1.,  2.,  3.],
         [ 1.,  2.,  3.],
         [ 1.,  2.,  3.]],

        [[ 4.,  5.,  6.],
         [ 4.,  5.,  6.],
         [ 4.,  5.,  6.],
         [ 4.,  5.,  6.],
         [ 4.,  5.,  6.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [ 7.,  8.,  9.],
         [ 7.,  8.,  9.],
         [ 7.,  8.,  9.],
         [ 7.,  8.,  9.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [10., 11., 12.],
         [10., 11., 12.],
         [10., 11., 12.],
         [10., 11., 12.],
         [10., 11., 12.]],

        [[13., 14., 15.],
         [13., 14., 15.],
         [13., 14., 15.],
         [13., 14., 15.],
         [13., 14., 15.],
         [13., 14., 15.]]])
Z1's shape: torch.Size([5, 6, 3])

Traceback (most recent call last):
  File "D:/Venv/Test/0710Test/Test.py", line 29, in <module>
    Z2 = X*Y
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 1

Process finished with exit code 1

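In this example, X plays the role of rays_direction, with shape [num_rays, 3], holding the direction vector of every ray;
Y plays the role of z_val, with shape [num_rays, num_sample], holding the distance of every sample point along every ray.
So for every sample point on every ray, its distance has to be multiplied by the corresponding row of X, i.e. the operation $z\vec{x}$.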
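A minimal sketch of this z-times-direction product, assuming hypothetical names rays_d and z_vals filled with random data, checked against an explicit double loop. Indexing with None also inserts a size-1 dimension, so rays_d[:, None, :] is equivalent to rays_d.unsqueeze(1).

```python
import torch

num_rays, num_samples = 4, 5
rays_d = torch.randn(num_rays, 3)           # hypothetical: ray directions, [num_rays, 3]
z_vals = torch.rand(num_rays, num_samples)  # hypothetical: sample distances, [num_rays, num_samples]

# [num_rays, 1, 3] * [num_rays, num_samples, 1] -> [num_rays, num_samples, 3]
offsets = rays_d.unsqueeze(1) * z_vals.unsqueeze(2)
# Indexing with None inserts a size-1 dimension, same as unsqueeze
offsets_none = rays_d[:, None, :] * z_vals[:, :, None]

# Reference: explicit loops computing z * d for every (ray, sample) pair
reference = torch.empty(num_rays, num_samples, 3)
for i in range(num_rays):
    for j in range(num_samples):
        reference[i, j] = z_vals[i, j] * rays_d[i]

print(offsets.shape)  # torch.Size([4, 5, 3])
print(torch.allclose(offsets, reference), torch.equal(offsets, offsets_none))  # True True
```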

X.unsqueeze(1): $[5,3] \rightarrow [5,1,\mathbf{3}]$
Y.unsqueeze(2): $[5,6] \rightarrow [5,\mathbf{6},1]$

When X.unsqueeze(1)*Y.unsqueeze(2) is executed,
both X.unsqueeze(1) and Y.unsqueeze(2) are broadcast to shape [5, 6, 3].
Comparing the two lines above: in each dimension, the size of the result is the larger of the two corresponding sizes. This works because, wherever the two sizes differ, one of them is 1. That is:
dim 0: $\max\{5,5\}=5$
dim 1: $\max\{1,6\}=6$
dim 2: $\max\{3,1\}=3$
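torch.broadcast_tensors returns its inputs expanded to their common shape, which makes the [5, 6, 3] intermediate visible. The check below, using the same X and Y as in section 4, is an illustration that also confirms each entry follows Z1[i, j, k] = Y[i, j] * X[i, k].

```python
import torch

X = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]])  # [5, 3]
Y = torch.ones((5, 6))                                                           # [5, 6]

X1, Y1 = X.unsqueeze(1), Y.unsqueeze(2)  # [5, 1, 3] and [5, 6, 1]
X1_b, Y1_b = torch.broadcast_tensors(X1, Y1)
print(X1_b.shape, Y1_b.shape)  # torch.Size([5, 6, 3]) torch.Size([5, 6, 3])

Z1 = X1 * Y1
# Every element follows Z1[i, j, k] = Y[i, j] * X[i, k]
ok = all(
    Z1[i, j, k] == Y[i, j] * X[i, k]
    for i in range(5) for j in range(6) for k in range(3)
)
print(ok)  # True
```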

A good example for understanding unsqueeze

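After element-wise multiplication (the "*" operator),
the result has the same shape as the larger operand, provided the other operand can be broadcast to it, as in the example below. (In general each dimension of the result is the larger of the two sizes, so the result can differ from both operands, as in section 4.)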
In NeRF there is a weight matrix with shape N_rays * N_samples,
and an RGB tensor with shape N_rays * N_samples * 3.
In volume rendering, the R, G and B components of the same sample point are all multiplied by the same coefficient.
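Written out for one ray, with $w_i$ for the weight of the $i$-th sample and $\mathbf{c}_i = (r_i, g_i, b_i)$ for its color (notation chosen here for illustration; in the NeRF paper $w_i = T_i\,(1-\exp(-\sigma_i\delta_i))$), the rendered color is a weighted sum in which each scalar $w_i$ scales all three channels:

$$
\hat{C}(\mathbf{r}) = \sum_{i=1}^{N_\text{samples}} w_i\,\mathbf{c}_i,
\qquad
w_i\,\mathbf{c}_i = \begin{pmatrix} w_i\,r_i \\ w_i\,g_i \\ w_i\,b_i \end{pmatrix}
$$

This is why weight needs a trailing size-1 dimension before it is multiplied with RGB.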

import torch

nRays = 2
nSamples = 3
weight = torch.tensor([[1, 2, 3], [7, 4, 8]])  # [nRays, nSamples]
RGB = torch.tensor([[[0, 3, 0], [4, 5, 6], [7, 8, 9]],
                    [[10, 11, 18], [8, 58, 6], [70, 82, 9]]])  # [nRays, nSamples, 3]
print("weight's shape", weight.shape)
print("weight", weight)
print("\n", "RGB's shape", RGB.shape)
print("RGB", RGB, "\n")
weight_after_unsqueeze = weight.unsqueeze(-1)  # [nRays, nSamples, 1]
print("weight_after_unsqueeze's shape: ", weight_after_unsqueeze.shape)
print("weight_after_unsqueeze", weight_after_unsqueeze)
res2 = weight_after_unsqueeze * RGB  # broadcast to [nRays, nSamples, 3]
print("shape: ", res2.shape)
print(res2)


weight's shape torch.Size([2, 3])
weight tensor([[1, 2, 3],
        [7, 4, 8]])

RGB's shape torch.Size([2, 3, 3])
RGB tensor([[[ 0,  3,  0],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 11, 18],
         [ 8, 58,  6],
         [70, 82,  9]]])

weight_after_unsqueeze's shape:  torch.Size([2, 3, 1])
weight_after_unsqueeze tensor([[[1],
         [2],
         [3]],

        [[7],
         [4],
         [8]]])
shape:  torch.Size([2, 3, 3])
tensor([[[  0,   3,   0],
         [  8,  10,  12],
         [ 21,  24,  27]],

        [[ 70,  77, 126],
         [ 32, 232,  24],
         [560, 656,  72]]])

In the code above, N_rays and N_samples (nRays and nSamples) are 2 and 3 respectively.
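After the unsqueeze, weight is printed as a column (e.g. [[1],[2],[3]]), which hints that each weight now sits alone in a last dimension of size 1. During the multiplication that size-1 dimension is broadcast to 3, so all three RGB components of a sample are multiplied by the same number from that column.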
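It is worth seeing what happens without the unsqueeze. A weight of shape [N_rays, N_samples] is then aligned with the last two dimensions of RGB, i.e. with [N_samples, 3]. With N_rays = 2 and N_samples = 3 that fails (2 vs 3), but if N_rays happened to equal N_samples = 3, the multiplication would run silently and pair the weights with the wrong entries. The sketch below uses the weight from the text, placeholder RGB values, and hypothetical square sizes with random data; einsum is used only as an independent cross-check.

```python
import torch

# Case 1: sizes from the text, N_rays = 2, N_samples = 3
weight = torch.tensor([[1, 2, 3], [7, 4, 8]])   # [2, 3]
RGB = torch.arange(2 * 3 * 3).reshape(2, 3, 3)  # [2, 3, 3], placeholder colors
try:
    weight * RGB  # [2, 3] is aligned with [3, 3]: 2 vs 3 fails
except RuntimeError as err:
    print("RuntimeError:", err)

# Case 2: hypothetical N_rays = N_samples = 3, no error but wrong pairing
weight_sq = torch.rand(3, 3)  # [N_rays, N_samples]
RGB_sq = torch.rand(3, 3, 3)  # [N_rays, N_samples, 3]
wrong = weight_sq * RGB_sq                # runs, but weight_sq[j, k] scales RGB_sq[i, j, k]
right = weight_sq.unsqueeze(-1) * RGB_sq  # weight_sq[i, j] scales RGB_sq[i, j, :]
print(wrong.shape, right.shape)  # torch.Size([3, 3, 3]) torch.Size([3, 3, 3])

# Cross-check the correct version: out[i, j, c] = w[i, j] * rgb[i, j, c]
print(torch.allclose(right, torch.einsum("ij,ijc->ijc", weight_sq, RGB_sq)))  # True
```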

