文章目录

  • 一、PyTorch环境检查
  • 二、查看张量类型
  • 三、查看张量尺寸和所占内存大小
  • 四、创建张量
    • 4.1 创建值全为1的张量
    • 4.2 创建值全为0的张量
    • 4.3 创建值全为指定值的张量
    • 4.4 通过 list 创建张量
    • 4.5 通过 ndarray 创建张量
    • 4.6 创建指定范围和间距的有序张量
    • 4.7 创建单位矩阵(对角线为1)
  • 五、生成随机张量
    • 5.1 按均匀分布生成
    • 5.2 按标准正态分布生成
    • 5.3 生成指定区间的整型随机张量
    • 5.4 获取随机序列
  • 六、张量的索引与切片
    • 6.1 索引
    • 6.2 切片
      • 6.2.1 获取张量的前/后N个元素
      • 6.2.2 根据指定步长获取张量的前/后N个元素
      • 6.2.3 根据特殊索引获取张量值
      • 6.2.4 根据 mask 选取张量值
      • 6.2.5 根据展平的索引获取张量值
  • 七、张量的维度变换
    • 7.1 view 和 reshape 尺寸变换
    • 7.2 unsqueeze 升维
    • 7.3 squeeze 降维
    • 7.4 expand 扩展
    • 7.5 repeat 复制
    • 7.6 .t() 转置
    • 7.7 transpose 维度变换
    • 7.8 permute 维度变换
  • 八、张量的拼接和拆分
    • 8.1 cat
    • 8.2 stack
    • 8.3 split
    • 8.4 chunk
  • 九、基本运算
    • 9.1 广播机制
    • 9.2 matmul 矩阵/张量乘法
    • 9.3 pow 次方运算
    • 9.4 sqrt 平方根运算
    • 9.5 exp 指数幂运算
    • 9.6 log 对数运算(相当于 ln)
    • 9.7 取整
    • 9.8 clamp 控制张量的取值范围
  • 十、统计属性
    • 10.1 norm 求范数
    • 10.2 mean、median、sum、min、max、prod、argmax、argmin
    • 10.3 topk 获取最大的k个值
    • 10.4 kthvalue 获取第k大的值
    • 10.5 比较运算函数
  • 十一、高级操作
    • 11.1 where
    • 11.2 gather

一、PyTorch环境检查

import torch
# 输出PyTorch版本
print(torch.__version__)
# 检查PyTorch是否支持GPU加速
print("cuda:", torch.cuda.is_available())

输出:

1.8.0+cu101
cuda: True

二、查看张量类型

import torcha = torch.randn(2, 3)
b = torch.randint(0, 1, (2, 3))
print(a.type())
print(b.type())
print(type(a))
print(type(b))
print(isinstance(a, torch.FloatTensor))
print(isinstance(b, torch.FloatTensor))

输出:

torch.FloatTensor
torch.LongTensor
<class 'torch.Tensor'>
<class 'torch.Tensor'>
True
False

三、查看张量尺寸和所占内存大小

import torcha = torch.randn(2, 3)
print(a.size(), type(a.size()))
print(a.shape, type(a.shape))
print("维度数:", a.dim())
print("所占内存大小:", a.numel())

输出:

torch.Size([2, 3]) <class 'torch.Size'>
torch.Size([2, 3]) <class 'torch.Size'>
维度数: 2
所占内存大小: 6

四、创建张量

4.1 创建值全为1的张量

import torcha = torch.ones(2, 3)
print(a)

输出:

tensor([[1., 1., 1.],[1., 1., 1.]])

4.2 创建值全为0的张量

import torcha = torch.zeros(2, 3)
print(a)

输出:

tensor([[0., 0., 0.],[0., 0., 0.]])

4.3 创建值全为指定值的张量

import torcha = torch.full([2, 3], 6.6)
print(a)
print(a.shape)a = torch.full([], 6.6)
print(a)
print(a.shape)

输出:

tensor([[6.6000, 6.6000, 6.6000],[6.6000, 6.6000, 6.6000]])
torch.Size([2, 3])
tensor(6.6000)
torch.Size([])

4.4 通过 list 创建张量

import torchprint(torch.LongTensor([[1, 2], [3, 4]]))
print(torch.Tensor([[1, 2], [3, 4]]))
print(torch.FloatTensor([[1, 2], [3, 4]]))

输出:

tensor([[1, 2],[3, 4]])
tensor([[1., 2.],[3., 4.]])
tensor([[1., 2.],[3., 4.]])

4.5 通过 ndarray 创建张量

import torch
import numpy as npa = np.array([2, 3.3])
print(type(a))print(torch.from_numpy(a))

输出:

<class 'numpy.ndarray'>
tensor([2.0000, 3.3000], dtype=torch.float64)

4.6 创建指定范围和间距的有序张量

import torchprint("torch.arange(0,10):", torch.arange(0, 10))
print("torch.arange(0,10,2):", torch.arange(0, 10, 2))
print("torch.linspace(0,10,steps = 4):", torch.linspace(0, 10, steps=4))
print("torch.linspace(0,10,steps = 10):", torch.linspace(0, 10, steps=10))
print("torch.linspace(0,10,steps = 11):", torch.linspace(0, 10, steps=11))
print("torch.logspace(0,-1,steps = 10):", torch.logspace(0, -1, steps=10))
print("torch.logspace(0,1,steps = 10):", torch.logspace(0, 1, steps=10))

输出:

torch.arange(0,10): tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
torch.arange(0,10,2): tensor([0, 2, 4, 6, 8])
torch.linspace(0,10,steps = 4): tensor([ 0.0000,  3.3333,  6.6667, 10.0000])
torch.linspace(0,10,steps = 10): tensor([ 0.0000,  1.1111,  2.2222,  3.3333,  4.4444,  5.5556,  6.6667,  7.7778,8.8889, 10.0000])
torch.linspace(0,10,steps = 11): tensor([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.])
torch.logspace(0,-1,steps = 10): tensor([1.0000, 0.7743, 0.5995, 0.4642, 0.3594, 0.2783, 0.2154, 0.1668, 0.1292,0.1000])
torch.logspace(0,1,steps = 10): tensor([ 1.0000,  1.2915,  1.6681,  2.1544,  2.7826,  3.5938,  4.6416,  5.9948,7.7426, 10.0000])

4.7 创建单位矩阵(对角线为1)

import torch# n * n
print(torch.eye(3))
print(torch.eye(4, 4))# 非 n * n
print(torch.eye(2, 3))

输出:

tensor([[1., 0., 0.],[0., 1., 0.],[0., 0., 1.]])
tensor([[1., 0., 0., 0.],[0., 1., 0., 0.],[0., 0., 1., 0.],[0., 0., 0., 1.]])
tensor([[1., 0., 0.],[0., 1., 0.]])

五、生成随机张量

5.1 按均匀分布生成

均匀分布:0-1之间

import torch# 生成shape为(2,3,2)的Tensor
random_tensor = torch.rand(2, 3, 2)
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

输出:

tensor([[[0.3321, 0.7077],[0.8372, 0.2545],[0.5849, 0.0312]],[[0.6792, 0.8339],[0.9689, 0.5579],[0.2843, 0.6578]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])

5.2 按标准正态分布生成

标准正态分布:均值为0,方差为1

import torch# 生成shape为(2,3,2)的Tensor
random_tensor = torch.randn(2, 3, 2)
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

输出:

tensor([[[ 1.6622,  1.4002],[ 1.5145, -0.0427],[ 0.4082, -0.3527]],[[ 1.2381, -0.2409],[-1.0770, -1.1289],[-1.5798,  0.3093]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])

5.3 生成指定区间的整型随机张量

import torch# 生成shape为(2,3,2)的Tensor
# 整数范围[1,4)
random_tensor = torch.randint(1, 4, (2, 3, 2))
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)

输出:

tensor([[[3, 2],[2, 2],[3, 3]],[[2, 1],[2, 1],[1, 1]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])

5.4 获取随机序列

import torch# torch中没有random.shuffle
# y = torch.randperm(n) y是把0到n-1这些数随机打乱得到的一个数字序列
# randperm(n, out=None, dtype=torch.int64)-> LongTensor
idx = torch.randperm(3)
a = torch.Tensor(4, 2)
print(a)
print(idx, idx.type())
print(a[idx])

输出:

tensor([[0.0000e+00, 5.5491e-43],[1.8754e+28, 8.0439e+20],[4.2767e-05, 1.0413e-11],[4.2002e-08, 6.5558e-10]])
tensor([0, 1, 2]) torch.LongTensor
tensor([[0.0000e+00, 5.5491e-43],[1.8754e+28, 8.0439e+20],[4.2767e-05, 1.0413e-11]])

六、张量的索引与切片

6.1 索引

import torcha = torch.rand(4, 3, 28, 28)
print("a[0].shape:", a[0].shape)
print("a[0,0].shape:", a[0, 0].shape)
print("a[0,0,2,4]:", a[0, 0, 2, 4])

输出:

a[0].shape: torch.Size([3, 28, 28])
a[0,0].shape: torch.Size([28, 28])
a[0,0,2,4]: tensor(0.2935)

6.2 切片

6.2.1 获取张量的前/后N个元素

import torcha = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)print("a[:2].shape:", a[:2].shape)
print("a[:2,:1,:,:].shape:", a[:2, :1, :, :].shape)
print("a[:2,1:,:,:].shape:", a[:2, 1:, :, :].shape)
print("a[:2,-1:,:,:].shape:", a[:2, -1:, :, :].shape)

输出:

a.shape: torch.Size([4, 3, 28, 28])
a[:2].shape: torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape: torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape: torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape: torch.Size([2, 1, 28, 28])

6.2.2 根据指定步长获取张量的前/后N个元素

import torcha = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)print("a[:,:,0:28:2,0:28:2].shape:", a[:, :, 0:28:2, 0:28:2].shape)
print("a[:,:,::2,::2].shape:", a[:, :, ::2, ::2].shape)

输出:

a.shape: torch.Size([4, 3, 28, 28])
a[:,:,0:28:2,0:28:2].shape: torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape: torch.Size([4, 3, 14, 14])

6.2.3 根据特殊索引获取张量值

import torcha = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)print("a.index_select(0,torch.tensor([0,2])).shape:", a.index_select(0, torch.tensor([0, 2])).shape)
print("a.index_select(1,torch.tensor([1,2])).shape:", a.index_select(1, torch.tensor([1, 2])).shape)
print("a.index_select(2,torch.arange(28)).shape:", a.index_select(2, torch.arange(28)).shape)
print("a.index_select(2,torch.arange(8)).shape:", a.index_select(2, torch.arange(8)).shape)print("a[...].shape:", a[...].shape)
print("a[0,...].shape:", a[0, ...].shape)
print("a[:,1,...].shape:", a[:, 1, ...].shape)
print("a[...,:2].shape:", a[..., :2].shape)

输出:

a.shape: torch.Size([4, 3, 28, 28])
a.index_select(0,torch.tensor([0,2])).shape: torch.Size([2, 3, 28, 28])
a.index_select(1,torch.tensor([1,2])).shape: torch.Size([4, 2, 28, 28])
a.index_select(2,torch.arange(28)).shape: torch.Size([4, 3, 28, 28])
a.index_select(2,torch.arange(8)).shape: torch.Size([4, 3, 8, 28])
a[...].shape: torch.Size([4, 3, 28, 28])
a[0,...].shape: torch.Size([3, 28, 28])
a[:,1,...].shape: torch.Size([4, 28, 28])
a[...,:2].shape: torch.Size([4, 3, 28, 2])

6.2.4 根据 mask 选取张量值

import torcha = torch.rand(3,4)
print(a)mask = a.ge(0.5)
print(mask)b = torch.masked_select(a, mask)
print(b)
print(b.shape)

输出:

tensor([[0.6119, 0.3231, 0.8763, 0.6680],[0.5421, 0.0359, 0.2040, 0.2894],[0.5961, 0.7953, 0.2759, 0.7808]])
tensor([[ True, False,  True,  True],[ True, False, False, False],[ True,  True, False,  True]])
tensor([0.6119, 0.8763, 0.6680, 0.5421, 0.5961, 0.7953, 0.7808])
torch.Size([7])

6.2.5 根据展平的索引获取张量值

import torcha = torch.Tensor([[4, 3, 5], [6, 7, 8]])
print(a)
print(torch.take(a, torch.tensor([0, 2, -1])))

输出:

tensor([[4., 3., 5.],[6., 7., 8.]])
tensor([4., 5., 8.])

七、张量的维度变换

7.1 view 和 reshape 尺寸变换

view 和 reshape 的用法一致

import torcha = torch.rand(4, 1, 28, 28)
print(a.shape)print(a.view(4, 28 * 28).shape)
print(a.view(4 * 28, 28).shape)
print(a.view(4, 28, 28, 1).shape)

输出:

torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
torch.Size([112, 28])
torch.Size([4, 28, 28, 1])

7.2 unsqueeze 升维

import torcha = torch.rand(4, 1, 28, 28)
print("a.shape:", a.shape)print("a.unsqueeze(0).shape:", a.unsqueeze(0).shape)
print("a.unsqueeze(-1).shape:", a.unsqueeze(-1).shape)b = torch.rand(32)
print("b.shape:", b.shape)
print("b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape:", b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)

输出:

a.shape: torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape: torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape: torch.Size([4, 1, 28, 28, 1])
b.shape: torch.Size([32])
b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape: torch.Size([1, 32, 1, 1])

7.3 squeeze 降维

import torchb = torch.rand(4, 1, 28, 28)
print("b.shape:", b.shape)print("b.squeeze().shape:", b.squeeze().shape)
print("b.squeeze(0).shape:", b.squeeze(0).shape)
print("b.squeeze(-1).shape:", b.squeeze(-1).shape)

输出:

b.shape: torch.Size([4, 1, 28, 28])
b.squeeze().shape: torch.Size([4, 28, 28])
b.squeeze(0).shape: torch.Size([4, 1, 28, 28])
b.squeeze(-1).shape: torch.Size([4, 1, 28, 28])

7.4 expand 扩展

import torchb = torch.rand(1, 32, 1, 1)
print("b.shape:", b.shape)print("b.expand(4,32,14,14).shape:", b.expand(4, 32, 14, 14).shape)
print("b.expand(-1,32,-1,-1).shape:", b.expand(-1, 32, -1, -1).shape)
print("b.expand(-1,32,-1,4).shape:", b.expand(-1, 32, -1, 4).shape)

输出:

b.shape: torch.Size([1, 32, 1, 1])
b.expand(4,32,14,14).shape: torch.Size([4, 32, 14, 14])
b.expand(-1,32,-1,-1).shape: torch.Size([1, 32, 1, 1])
b.expand(-1,32,-1,4).shape: torch.Size([1, 32, 1, 4])

7.5 repeat 复制

import torchb = torch.rand(1, 32, 1, 1)
print("b.shape:", b.shape)print("b.repeat(4,32,1,1).shape:", b.repeat(4, 32, 1, 1).shape)
print("b.repeat(4,1,1,1).shape:", b.repeat(4, 1, 1, 1).shape)
print("b.repeat(4,1,32,32).shape:", b.repeat(4, 1, 32, 32).shape)

输出:

b.shape: torch.Size([1, 32, 1, 1])
b.repeat(4,32,1,1).shape: torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape: torch.Size([4, 32, 1, 1])
b.repeat(4,1,32,32).shape: torch.Size([4, 32, 32, 32])

7.6 .t() 转置

import torchb = torch.rand(3, 4)
print(b)
print(b.t())

输出:

tensor([[0.3598, 0.3820, 0.9488, 0.2987],[0.7339, 0.2339, 0.5251, 0.2017],[0.8442, 0.6528, 0.2914, 0.5034]])
tensor([[0.3598, 0.7339, 0.8442],[0.3820, 0.2339, 0.6528],[0.9488, 0.5251, 0.2914],[0.2987, 0.2017, 0.5034]])

7.7 transpose 维度变换

import torcha = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a.transpose(1, 3).shape)

输出:

torch.Size([4, 3, 28, 28])
torch.Size([4, 28, 28, 3])

7.8 permute 维度变换

import torcha = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a.permute(0,2,3,1).shape)

输出:

torch.Size([4, 3, 28, 28])
torch.Size([4, 28, 28, 3])

八、张量的拼接和拆分

8.1 cat

import torcha = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.cat([a, b], dim=0)
print(c.shape) # torch.Size([9, 32, 8])

8.2 stack

import torcha1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32)
c = torch.stack([a1,a2],dim = 2)
print(c.shape) # torch.Size([4, 3, 2, 16, 32])

8.3 split

import torcha = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 32, 8])aa, bb = c.split([1, 1], dim=0)
print(aa.shape, bb.shape)  # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])aa, bb = c.split([20, 12], dim=1)
print(aa.shape, bb.shape)  # torch.Size([2, 20, 8]) torch.Size([2, 12, 8])

8.4 chunk

import torcha = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print(c.shape)  # torch.Size([2, 32, 8])aa, bb = c.chunk(2, dim=0)
print(aa.shape, bb.shape)  # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])aa, bb = c.chunk(2, dim=1)
print(aa.shape, bb.shape)  # torch.Size([2, 16, 8]) torch.Size([2, 16, 8])aa, bb, cc, dd = c.chunk(4, dim=1)
print(aa.shape, bb.shape, cc.shape,dd.shape)  #torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8])

九、基本运算

9.1 广播机制

import torcha = torch.rand(2,2)
print(a)
b = torch.rand(2)
print(b)
print(a+b)

输出:

tensor([[0.4668, 0.6053],[0.5321, 0.8734]])
tensor([0.7595, 0.6517])
tensor([[1.2263, 1.2570],[1.2916, 1.5251]])

9.2 matmul 矩阵/张量乘法

import torcha = torch.ones(2, 2) * 3
b = torch.ones(2, 2)
print(a)
print(b)
print(torch.matmul(a, b))

输出:

tensor([[3., 3.],[3., 3.]])
tensor([[1., 1.],[1., 1.]])
tensor([[6., 6.],[6., 6.]])

9.3 pow 次方运算

import torcha = torch.ones(2, 2) * 3
print(a)
print(torch.pow(a, 3))

输出:

tensor([[3., 3.],[3., 3.]])
tensor([[27., 27.],[27., 27.]])

9.4 sqrt 平方根运算

import torcha = torch.ones(2, 2) * 9
print(a)
print(torch.pow(a, 0.5))
print(torch.sqrt(a))

输出:

tensor([[9., 9.],[9., 9.]])
tensor([[3., 3.],[3., 3.]])
tensor([[3., 3.],[3., 3.]])

9.5 exp 指数幂运算

import torcha = torch.ones(2, 2)
print(a)
print(torch.exp(a))

输出:

tensor([[1., 1.],[1., 1.]])
tensor([[2.7183, 2.7183],[2.7183, 2.7183]])

9.6 log 对数运算(相当于 ln)

import torcha = torch.ones(2, 2) * 3
print(a)
print(torch.log(a))

输出:

tensor([[3., 3.],[3., 3.]])
tensor([[1.0986, 1.0986],[1.0986, 1.0986]])

9.7 取整

  • floor():向下取整
  • ceil():向上取整
  • round():四舍五入
  • trunc():截取整数部分
  • frac():截取小数部分
import torcha = torch.tensor(3.14)
print(a)  # tensor(3.1400)
print(torch.floor(a))  #tensor(3.)
print(torch.ceil(a))  #tensor(4.)
print(torch.round(a))  #tensor(3.)
print(torch.trunc(a))  #tensor(3.)
print(torch.frac(a))  #tensor(0.1400)

9.8 clamp 控制张量的取值范围

import torcha = torch.rand(2, 3) * 15print(a)
# 将大于8的值设置为8;小于4的值设置为4
print(torch.clamp(a, 4, 8))

输出:

tensor([[ 8.8872,  5.6534, 14.3027],[ 0.8305, 12.6266, 13.9683]])
tensor([[8.0000, 5.6534, 8.0000],[4.0000, 8.0000, 8.0000]])

十、统计属性

10.1 norm 求范数

import torcha = torch.ones(2, 3)
b = torch.norm(a)  # 默认求2范数
c = torch.norm(a, p=1)  # 指定求1范数
print(a)
print(b)
print(c)

输出:

tensor([[1., 1., 1.],[1., 1., 1.]])
tensor(2.4495)
tensor(6.)

10.2 mean、median、sum、min、max、prod、argmax、argmin

  • prod():返回张量里所有元素的乘积
  • armax():返回张量中最大元素的展平索引
  • argmin():返回张量中最小元素的展平索引
import torcha = torch.arange(8).view(2, 4).float()
print(a)
'''
tensor([[0., 1., 2., 3.],[4., 5., 6., 7.]])
'''
print(a.mean())  #tensor(3.5000)
print(a.median())  #tensor(3.)
print(a.sum())  #tensor(28.)
print(a.min())  #tensor(0.)
print(a.max())  #tensor(7.)
print(a.prod())  #tensor(0.)
print(a.argmax())  #tensor(7)
print(a.argmin())  #tensor(0)
import torcha = torch.rand(2,4)
print(a)print(a.max(dim = 1))
print(a.max(dim = 1,keepdim = True))

输出:

tensor([[0.7239, 0.9412, 0.7602, 0.2131],[0.6277, 0.1033, 0.8300, 0.9909]])
torch.return_types.max(
values=tensor([0.9412, 0.9909]),
indices=tensor([1, 3]))
torch.return_types.max(
values=tensor([[0.9412],[0.9909]]),
indices=tensor([[1],[3]]))

10.3 topk 获取最大的k个值

import torcha = torch.rand(2,4)
print(a)
print(a.topk(2,dim=1))
'''
tensor([[0.3247, 0.9220, 0.4314, 0.8123],[0.7133, 0.2471, 0.0281, 0.3595]])
torch.return_types.topk(
values=tensor([[0.9220, 0.8123],[0.7133, 0.3595]]),
indices=tensor([[1, 3],[0, 3]]))
'''

10.4 kthvalue 获取第k大的值

import torcha = torch.rand(2, 4)
print(a)
print(a.kthvalue(3,dim=1))
'''
tensor([[0.0980, 0.0479, 0.9298, 0.5638],[0.9095, 0.9071, 0.4913, 0.6144]])
torch.return_types.kthvalue(
values=tensor([0.5638, 0.9071]),
indices=tensor([3, 1]))
'''

10.5 比较运算函数

import torcha = torch.rand(2, 3)
print(a)
'''
tensor([[0.1196, 0.5068, 0.9272],[0.6395, 0.2433, 0.9702]])
'''
# a >= 0.5
print(a.ge(0.5))
'''
tensor([[False,  True,  True],[ True, False,  True]])
'''
# a > 0.5
print(a.gt(0.5))
'''
tensor([[False,  True,  True],[ True, False,  True]])
'''
# a <= 0.5
print(a.le(0.5))
'''
tensor([[ True, False, False],[False,  True, False]])
'''
# a < 0.5
print(a.lt(0.5))
'''
tensor([[ True, False, False],[False,  True, False]])
'''
# a = 0.5
print(a.eq(0.5))
'''
tensor([[False, False, False],[False, False, False]])
'''

十一、高级操作

11.1 where

import torchcond = torch.rand(2, 2)
a = torch.zeros(2, 2)
b = torch.ones(2, 2)
print(cond)
'''
tensor([[0.3622, 0.9658],[0.1774, 0.6670]])
'''
print(a)
'''
tensor([[0., 0.],[0., 0.]])
'''
print(b)
'''
tensor([[1., 1.],[1., 1.]])
'''
# 满足条件cond.ge(0.5)的按照a的对应元素赋值,否则按照b的对应元素赋值
print(torch.where(cond.ge(0.5), a, b))
'''
tensor([[1., 0.],[1., 0.]])
'''

11.2 gather

帮助我们从批量tensor中取出指定乱序索引下的数据

import torcha = torch.arange(3, 12).view(3, 3)
print(a)
'''
tensor([[ 3,  4,  5],[ 6,  7,  8],[ 9, 10, 11]])
'''index = torch.tensor([[2, 1, 0]])
print(a.gather(1, index)) # tensor([[5, 4, 3]])index = torch.tensor([[2, 1, 0]]).t()
print(a.gather(1, index))
'''
tensor([[5],[7],[9]])
'''index = torch.tensor([[0, 2],[1, 2]])
print(a.gather(1, index))
'''
tensor([[3, 5],[7, 8]])
'''

参考链接:图解PyTorch中的torch.gather函数

在强化学习DQN中的使用 gather() 函数

【深度学习】超详细的 PyTorch 学习笔记(上)相关推荐

  1. 超详细!Vue-coderwhy个人学习笔记(二)(Day3)

    前言 本文章接上一篇笔记 超详细!Vue-coderwhy个人学习笔记(一)(Day1-Day2) 这篇主要是Day3笔记,组件化,组件通信,插槽 四.组件化开发 (一).内容概述 认识组件化 注册组 ...

  2. 超详细的Git学习记录(Git基础内容/IDEA集成Git/GitHub/Gitee/GitLab及Centos7部署GitLab)

    超详细的Git学习笔记 从B站搜到的尚硅谷视频学习了Git,记录了一下学习的内容,收获很大 学习地址: https://www.bilibili.com/video/BV1vy4y1s7k6?p=11 ...

  3. 【libuv高效编程】libuv学习超详细教程3——libuv事件循环

    文章目录 libuv系列文章 libuv事件循环 uv_loop_t demo uv_loop_init() uv_run() uv_loop_close() 参考 例程代码获取 libuv系列文章 ...

  4. LiteFlow学习(超详细)

    LiteFlow学习(超详细) 文章目录 LiteFlow学习(超详细) 1. LiteFlow简介 1.1 前言 1.2 LiteFlow框架的优势 1.3 LiteFlow的设计原则 1.4 Li ...

  5. 【网速】Visual Studio 下载太慢的问题的解决办法【超详细,来源于学习笔记】

    Visual Studio 下载太慢的问题的解决办法[详细,来源于学习的笔记] Visual Studio 下载太慢的解决办法两个步骤即可: 一.测试DNS 二.修改host 做完以上工作后,VS的下 ...

  6. 【超详细】嵌入式软件学习大纲

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/qq_34981463/article/ ...

  7. 适用于任意模糊内核的深度即插即用超分辨率(DPSR论文笔记-2019CVPR)

    Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels (适用于任意模糊内核的深度即插即用超分辨率) 源码包:https://gi ...

  8. 一篇超详细的pytorch基础语法讲解及理论推导(一)

    张量 - 线性回归 - 自动求导 - 逻辑回归 来源:投稿 来源:阿克西 编辑:学姐 1 pytorch简介 PyTorch是2017年1月FAIR(Facebook AI Research)发布的一 ...

  9. 超详细的计算机视觉学习书籍pdf汇总(涉及CV、深度学习、多视图几何、SLAM、点云处理等)

    计算机视觉入门的一些pdf书籍,[计算机视觉工坊]按照不同领域帮大家划分了下,涉及深度学习基础.目标检测.Opencv.SLAM.点云.多视图集合.三维重建等~ 计算机视觉 1. 计算机视觉算法与应用 ...

最新文章

  1. Node和java和php,服务端I/O性能大比拼:Node、PHP、Java和Go(三)
  2. dtree.js树的使用
  3. VB6.0连接MySQL数据库
  4. NYOJ15-括号匹配(二)-区间DP
  5. Linux系统下软件包管理六
  6. oracle查询用户下所有表名称
  7. Redis Cluster集群的搭建与实践
  8. DX使用随笔--NavBarControl
  9. 复制百度文库的文字加什么后缀_下载百度文库文档 怎么快速提取百度文库中可以完整阅读的文档...
  10. PDF软件有这么好用的打印机,你知道吗?
  11. 物理 常见力与牛顿三定律
  12. 太牛了!B 站 UP 主开发会写高考作文的 AI
  13. oracle显示连接超时,Oracle 12179:tns:连接超时的问题
  14. c语言编写图书检索系统,求C语言编写图书管理系统
  15. 同花顺_代码解析_技术指标_P、Q
  16. π=4*atan(1.0);
  17. 模块电路选型(6)----存储模块
  18. OSChina 周四乱弹 ——PM是这样学程序的
  19. 70个python项目代码_python项目实例源码
  20. 2021-12-6 《聪明的投资者》学习笔记-3.一个世纪的股市历史:1972年年初的股价水平-股市周期性。股价、利润和股息

热门文章

  1. ShareSDK Android端权限说明
  2. $(this).addClass('class').siblings('class').removeClass('class')的作用
  3. QQ聊天记录备份BAK文件的修复方法
  4. WannaCry席卷全球 软件作者到底赚了多少钱?
  5. Win10实现窗口AeroGlass化
  6. 【PTA】求交错序列前N项和
  7. vue项目点击后,从左边或右边滑出组件,再次点击原路滑回。<transition>、transform
  8. 计算机软件由程序数据和文档组成其中主体是,chap03 计算机软件
  9. VRTK抓取触碰交互
  10. NG-ZORRO1.x自定义主题