【深度学习】超详细的 PyTorch 学习笔记(上)
文章目录
- 一、PyTorch环境检查
- 二、查看张量类型
- 三、查看张量尺寸和所占内存大小
- 四、创建张量
- 4.1 创建值全为1的张量
- 4.2 创建值全为0的张量
- 4.3 创建值全为指定值的张量
- 4.4 通过 list 创建张量
- 4.5 通过 ndarray 创建张量
- 4.6 创建指定范围和间距的有序张量
- 4.7 创建单位矩阵(对角线为1)
- 五、生成随机张量
- 5.1 按均匀分布生成
- 5.2 按标准正态分布生成
- 5.3 生成指定区间的整型随机张量
- 5.4 获取随机序列
- 六、张量的索引与切片
- 6.1 索引
- 6.2 切片
- 6.2.1 获取张量的前/后N个元素
- 6.2.2 根据指定步长获取张量的前/后N个元素
- 6.2.3 根据特殊索引获取张量值
- 6.2.4 根据 mask 选取张量值
- 6.2.5 根据展平的索引获取张量值
- 七、张量的维度变换
- 7.1 view 和 reshape 尺寸变换
- 7.2 unsqueeze 升维
- 7.3 squeeze 降维
- 7.4 expand 扩展
- 7.5 repeat 复制
- 7.6 .t() 转置
- 7.7 transpose 维度变换
- 7.8 permute 维度变换
- 八、张量的拼接和拆分
- 8.1 cat
- 8.2 stack
- 8.3 split
- 8.4 chunk
- 九、基本运算
- 9.1 广播机制
- 9.2 matmul 矩阵/张量乘法
- 9.3 pow 次方运算
- 9.4 sqrt 平方根运算
- 9.5 exp 指数幂运算
- 9.6 log 对数运算(相当于 ln)
- 9.7 取整
- 9.8 clamp 控制张量的取值范围
- 十、统计属性
- 10.1 norm 求范数
- 10.2 mean、median、sum、min、max、prod、argmax、argmin
- 10.3 topk 获取最大的k个值
- 10.4 kthvalue 获取第k大的值
- 10.5 比较运算函数
- 十一、高级操作
- 11.1 where
- 11.2 gather
一、PyTorch环境检查
import torch
# 输出PyTorch版本
print(torch.__version__)
# 检查PyTorch是否支持GPU加速
print("cuda:", torch.cuda.is_available())
输出:
1.8.0+cu101
cuda: True
二、查看张量类型
import torcha = torch.randn(2, 3)
b = torch.randint(0, 1, (2, 3))
print(a.type())
print(b.type())
print(type(a))
print(type(b))
print(isinstance(a, torch.FloatTensor))
print(isinstance(b, torch.FloatTensor))
输出:
torch.FloatTensor
torch.LongTensor
<class 'torch.Tensor'>
<class 'torch.Tensor'>
True
False
三、查看张量尺寸和所占内存大小
import torcha = torch.randn(2, 3)
print(a.size(), type(a.size()))
print(a.shape, type(a.shape))
print("维度数:", a.dim())
print("所占内存大小:", a.numel())
输出:
torch.Size([2, 3]) <class 'torch.Size'>
torch.Size([2, 3]) <class 'torch.Size'>
维度数: 2
所占内存大小: 6
四、创建张量
4.1 创建值全为1的张量
import torcha = torch.ones(2, 3)
print(a)
输出:
tensor([[1., 1., 1.],[1., 1., 1.]])
4.2 创建值全为0的张量
import torcha = torch.zeros(2, 3)
print(a)
输出:
tensor([[0., 0., 0.],[0., 0., 0.]])
4.3 创建值全为指定值的张量
import torcha = torch.full([2, 3], 6.6)
print(a)
print(a.shape)a = torch.full([], 6.6)
print(a)
print(a.shape)
输出:
tensor([[6.6000, 6.6000, 6.6000],[6.6000, 6.6000, 6.6000]])
torch.Size([2, 3])
tensor(6.6000)
torch.Size([])
4.4 通过 list 创建张量
import torchprint(torch.LongTensor([[1, 2], [3, 4]]))
print(torch.Tensor([[1, 2], [3, 4]]))
print(torch.FloatTensor([[1, 2], [3, 4]]))
输出:
tensor([[1, 2],[3, 4]])
tensor([[1., 2.],[3., 4.]])
tensor([[1., 2.],[3., 4.]])
4.5 通过 ndarray 创建张量
import torch
import numpy as npa = np.array([2, 3.3])
print(type(a))print(torch.from_numpy(a))
输出:
<class 'numpy.ndarray'>
tensor([2.0000, 3.3000], dtype=torch.float64)
4.6 创建指定范围和间距的有序张量
import torchprint("torch.arange(0,10):", torch.arange(0, 10))
print("torch.arange(0,10,2):", torch.arange(0, 10, 2))
print("torch.linspace(0,10,steps = 4):", torch.linspace(0, 10, steps=4))
print("torch.linspace(0,10,steps = 10):", torch.linspace(0, 10, steps=10))
print("torch.linspace(0,10,steps = 11):", torch.linspace(0, 10, steps=11))
print("torch.logspace(0,-1,steps = 10):", torch.logspace(0, -1, steps=10))
print("torch.logspace(0,1,steps = 10):", torch.logspace(0, 1, steps=10))
输出:
torch.arange(0,10): tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
torch.arange(0,10,2): tensor([0, 2, 4, 6, 8])
torch.linspace(0,10,steps = 4): tensor([ 0.0000, 3.3333, 6.6667, 10.0000])
torch.linspace(0,10,steps = 10): tensor([ 0.0000, 1.1111, 2.2222, 3.3333, 4.4444, 5.5556, 6.6667, 7.7778,8.8889, 10.0000])
torch.linspace(0,10,steps = 11): tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
torch.logspace(0,-1,steps = 10): tensor([1.0000, 0.7743, 0.5995, 0.4642, 0.3594, 0.2783, 0.2154, 0.1668, 0.1292,0.1000])
torch.logspace(0,1,steps = 10): tensor([ 1.0000, 1.2915, 1.6681, 2.1544, 2.7826, 3.5938, 4.6416, 5.9948,7.7426, 10.0000])
4.7 创建单位矩阵(对角线为1)
import torch# n * n
print(torch.eye(3))
print(torch.eye(4, 4))# 非 n * n
print(torch.eye(2, 3))
输出:
tensor([[1., 0., 0.],[0., 1., 0.],[0., 0., 1.]])
tensor([[1., 0., 0., 0.],[0., 1., 0., 0.],[0., 0., 1., 0.],[0., 0., 0., 1.]])
tensor([[1., 0., 0.],[0., 1., 0.]])
五、生成随机张量
5.1 按均匀分布生成
均匀分布:0-1之间
import torch# 生成shape为(2,3,2)的Tensor
random_tensor = torch.rand(2, 3, 2)
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)
输出:
tensor([[[0.3321, 0.7077],[0.8372, 0.2545],[0.5849, 0.0312]],[[0.6792, 0.8339],[0.9689, 0.5579],[0.2843, 0.6578]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])
5.2 按标准正态分布生成
标准正态分布:均值为0,方差为1
import torch# 生成shape为(2,3,2)的Tensor
random_tensor = torch.randn(2, 3, 2)
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)
输出:
tensor([[[ 1.6622, 1.4002],[ 1.5145, -0.0427],[ 0.4082, -0.3527]],[[ 1.2381, -0.2409],[-1.0770, -1.1289],[-1.5798, 0.3093]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])
5.3 生成指定区间的整型随机张量
import torch# 生成shape为(2,3,2)的Tensor
# 整数范围[1,4)
random_tensor = torch.randint(1, 4, (2, 3, 2))
print(random_tensor)
print(type(random_tensor))
print(random_tensor.shape)
输出:
tensor([[[3, 2],[2, 2],[3, 3]],[[2, 1],[2, 1],[1, 1]]])
<class 'torch.Tensor'>
torch.Size([2, 3, 2])
5.4 获取随机序列
import torch# torch中没有random.shuffle
# y = torch.randperm(n) y是把0到n-1这些数随机打乱得到的一个数字序列
# randperm(n, out=None, dtype=torch.int64)-> LongTensor
idx = torch.randperm(3)
a = torch.Tensor(4, 2)
print(a)
print(idx, idx.type())
print(a[idx])
输出:
tensor([[0.0000e+00, 5.5491e-43],[1.8754e+28, 8.0439e+20],[4.2767e-05, 1.0413e-11],[4.2002e-08, 6.5558e-10]])
tensor([0, 1, 2]) torch.LongTensor
tensor([[0.0000e+00, 5.5491e-43],[1.8754e+28, 8.0439e+20],[4.2767e-05, 1.0413e-11]])
六、张量的索引与切片
6.1 索引
import torcha = torch.rand(4, 3, 28, 28)
print("a[0].shape:", a[0].shape)
print("a[0,0].shape:", a[0, 0].shape)
print("a[0,0,2,4]:", a[0, 0, 2, 4])
输出:
a[0].shape: torch.Size([3, 28, 28])
a[0,0].shape: torch.Size([28, 28])
a[0,0,2,4]: tensor(0.2935)
6.2 切片
6.2.1 获取张量的前/后N个元素
import torcha = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)print("a[:2].shape:", a[:2].shape)
print("a[:2,:1,:,:].shape:", a[:2, :1, :, :].shape)
print("a[:2,1:,:,:].shape:", a[:2, 1:, :, :].shape)
print("a[:2,-1:,:,:].shape:", a[:2, -1:, :, :].shape)
输出:
a.shape: torch.Size([4, 3, 28, 28])
a[:2].shape: torch.Size([2, 3, 28, 28])
a[:2,:1,:,:].shape: torch.Size([2, 1, 28, 28])
a[:2,1:,:,:].shape: torch.Size([2, 2, 28, 28])
a[:2,-1:,:,:].shape: torch.Size([2, 1, 28, 28])
6.2.2 根据指定步长获取张量的前/后N个元素
import torcha = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)print("a[:,:,0:28:2,0:28:2].shape:", a[:, :, 0:28:2, 0:28:2].shape)
print("a[:,:,::2,::2].shape:", a[:, :, ::2, ::2].shape)
输出:
a.shape: torch.Size([4, 3, 28, 28])
a[:,:,0:28:2,0:28:2].shape: torch.Size([4, 3, 14, 14])
a[:,:,::2,::2].shape: torch.Size([4, 3, 14, 14])
6.2.3 根据特殊索引获取张量值
import torcha = torch.rand(4, 3, 28, 28)
print("a.shape:", a.shape)print("a.index_select(0,torch.tensor([0,2])).shape:", a.index_select(0, torch.tensor([0, 2])).shape)
print("a.index_select(1,torch.tensor([1,2])).shape:", a.index_select(1, torch.tensor([1, 2])).shape)
print("a.index_select(2,torch.arange(28)).shape:", a.index_select(2, torch.arange(28)).shape)
print("a.index_select(2,torch.arange(8)).shape:", a.index_select(2, torch.arange(8)).shape)print("a[...].shape:", a[...].shape)
print("a[0,...].shape:", a[0, ...].shape)
print("a[:,1,...].shape:", a[:, 1, ...].shape)
print("a[...,:2].shape:", a[..., :2].shape)
输出:
a.shape: torch.Size([4, 3, 28, 28])
a.index_select(0,torch.tensor([0,2])).shape: torch.Size([2, 3, 28, 28])
a.index_select(1,torch.tensor([1,2])).shape: torch.Size([4, 2, 28, 28])
a.index_select(2,torch.arange(28)).shape: torch.Size([4, 3, 28, 28])
a.index_select(2,torch.arange(8)).shape: torch.Size([4, 3, 8, 28])
a[...].shape: torch.Size([4, 3, 28, 28])
a[0,...].shape: torch.Size([3, 28, 28])
a[:,1,...].shape: torch.Size([4, 28, 28])
a[...,:2].shape: torch.Size([4, 3, 28, 2])
6.2.4 根据 mask 选取张量值
import torcha = torch.rand(3,4)
print(a)mask = a.ge(0.5)
print(mask)b = torch.masked_select(a, mask)
print(b)
print(b.shape)
输出:
tensor([[0.6119, 0.3231, 0.8763, 0.6680],[0.5421, 0.0359, 0.2040, 0.2894],[0.5961, 0.7953, 0.2759, 0.7808]])
tensor([[ True, False, True, True],[ True, False, False, False],[ True, True, False, True]])
tensor([0.6119, 0.8763, 0.6680, 0.5421, 0.5961, 0.7953, 0.7808])
torch.Size([7])
6.2.5 根据展平的索引获取张量值
import torcha = torch.Tensor([[4, 3, 5], [6, 7, 8]])
print(a)
print(torch.take(a, torch.tensor([0, 2, -1])))
输出:
tensor([[4., 3., 5.],[6., 7., 8.]])
tensor([4., 5., 8.])
七、张量的维度变换
7.1 view 和 reshape 尺寸变换
view 和 reshape 的用法一致
import torcha = torch.rand(4, 1, 28, 28)
print(a.shape)print(a.view(4, 28 * 28).shape)
print(a.view(4 * 28, 28).shape)
print(a.view(4, 28, 28, 1).shape)
输出:
torch.Size([4, 1, 28, 28])
torch.Size([4, 784])
torch.Size([112, 28])
torch.Size([4, 28, 28, 1])
7.2 unsqueeze 升维
import torcha = torch.rand(4, 1, 28, 28)
print("a.shape:", a.shape)print("a.unsqueeze(0).shape:", a.unsqueeze(0).shape)
print("a.unsqueeze(-1).shape:", a.unsqueeze(-1).shape)b = torch.rand(32)
print("b.shape:", b.shape)
print("b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape:", b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape)
输出:
a.shape: torch.Size([4, 1, 28, 28])
a.unsqueeze(0).shape: torch.Size([1, 4, 1, 28, 28])
a.unsqueeze(-1).shape: torch.Size([4, 1, 28, 28, 1])
b.shape: torch.Size([32])
b.unsqueeze(1).unsqueeze(2).unsqueeze(0).shape: torch.Size([1, 32, 1, 1])
7.3 squeeze 降维
import torchb = torch.rand(4, 1, 28, 28)
print("b.shape:", b.shape)print("b.squeeze().shape:", b.squeeze().shape)
print("b.squeeze(0).shape:", b.squeeze(0).shape)
print("b.squeeze(-1).shape:", b.squeeze(-1).shape)
输出:
b.shape: torch.Size([4, 1, 28, 28])
b.squeeze().shape: torch.Size([4, 28, 28])
b.squeeze(0).shape: torch.Size([4, 1, 28, 28])
b.squeeze(-1).shape: torch.Size([4, 1, 28, 28])
7.4 expand 扩展
import torchb = torch.rand(1, 32, 1, 1)
print("b.shape:", b.shape)print("b.expand(4,32,14,14).shape:", b.expand(4, 32, 14, 14).shape)
print("b.expand(-1,32,-1,-1).shape:", b.expand(-1, 32, -1, -1).shape)
print("b.expand(-1,32,-1,4).shape:", b.expand(-1, 32, -1, 4).shape)
输出:
b.shape: torch.Size([1, 32, 1, 1])
b.expand(4,32,14,14).shape: torch.Size([4, 32, 14, 14])
b.expand(-1,32,-1,-1).shape: torch.Size([1, 32, 1, 1])
b.expand(-1,32,-1,4).shape: torch.Size([1, 32, 1, 4])
7.5 repeat 复制
import torchb = torch.rand(1, 32, 1, 1)
print("b.shape:", b.shape)print("b.repeat(4,32,1,1).shape:", b.repeat(4, 32, 1, 1).shape)
print("b.repeat(4,1,1,1).shape:", b.repeat(4, 1, 1, 1).shape)
print("b.repeat(4,1,32,32).shape:", b.repeat(4, 1, 32, 32).shape)
输出:
b.shape: torch.Size([1, 32, 1, 1])
b.repeat(4,32,1,1).shape: torch.Size([4, 1024, 1, 1])
b.repeat(4,1,1,1).shape: torch.Size([4, 32, 1, 1])
b.repeat(4,1,32,32).shape: torch.Size([4, 32, 32, 32])
7.6 .t() 转置
import torchb = torch.rand(3, 4)
print(b)
print(b.t())
输出:
tensor([[0.3598, 0.3820, 0.9488, 0.2987],[0.7339, 0.2339, 0.5251, 0.2017],[0.8442, 0.6528, 0.2914, 0.5034]])
tensor([[0.3598, 0.7339, 0.8442],[0.3820, 0.2339, 0.6528],[0.9488, 0.5251, 0.2914],[0.2987, 0.2017, 0.5034]])
7.7 transpose 维度变换
import torcha = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a.transpose(1, 3).shape)
输出:
torch.Size([4, 3, 28, 28])
torch.Size([4, 28, 28, 3])
7.8 permute 维度变换
import torcha = torch.rand(4, 3, 28, 28)
print(a.shape)
print(a.permute(0,2,3,1).shape)
输出:
torch.Size([4, 3, 28, 28])
torch.Size([4, 28, 28, 3])
八、张量的拼接和拆分
8.1 cat
import torcha = torch.rand(4, 32, 8)
b = torch.rand(5, 32, 8)
c = torch.cat([a, b], dim=0)
print(c.shape) # torch.Size([9, 32, 8])
8.2 stack
import torcha1 = torch.rand(4,3,16,32)
a2 = torch.rand(4,3,16,32)
c = torch.stack([a1,a2],dim = 2)
print(c.shape) # torch.Size([4, 3, 2, 16, 32])
8.3 split
import torcha = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print(c.shape) # torch.Size([2, 32, 8])aa, bb = c.split([1, 1], dim=0)
print(aa.shape, bb.shape) # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])aa, bb = c.split([20, 12], dim=1)
print(aa.shape, bb.shape) # torch.Size([2, 20, 8]) torch.Size([2, 12, 8])
8.4 chunk
import torcha = torch.rand(32, 8)
b = torch.rand(32, 8)
c = torch.stack([a, b], dim=0)
print(c.shape) # torch.Size([2, 32, 8])aa, bb = c.chunk(2, dim=0)
print(aa.shape, bb.shape) # torch.Size([1, 32, 8]) torch.Size([1, 32, 8])aa, bb = c.chunk(2, dim=1)
print(aa.shape, bb.shape) # torch.Size([2, 16, 8]) torch.Size([2, 16, 8])aa, bb, cc, dd = c.chunk(4, dim=1)
print(aa.shape, bb.shape, cc.shape,dd.shape) #torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8]) torch.Size([2, 8, 8])
九、基本运算
9.1 广播机制
import torcha = torch.rand(2,2)
print(a)
b = torch.rand(2)
print(b)
print(a+b)
输出:
tensor([[0.4668, 0.6053],[0.5321, 0.8734]])
tensor([0.7595, 0.6517])
tensor([[1.2263, 1.2570],[1.2916, 1.5251]])
9.2 matmul 矩阵/张量乘法
import torcha = torch.ones(2, 2) * 3
b = torch.ones(2, 2)
print(a)
print(b)
print(torch.matmul(a, b))
输出:
tensor([[3., 3.],[3., 3.]])
tensor([[1., 1.],[1., 1.]])
tensor([[6., 6.],[6., 6.]])
9.3 pow 次方运算
import torcha = torch.ones(2, 2) * 3
print(a)
print(torch.pow(a, 3))
输出:
tensor([[3., 3.],[3., 3.]])
tensor([[27., 27.],[27., 27.]])
9.4 sqrt 平方根运算
import torcha = torch.ones(2, 2) * 9
print(a)
print(torch.pow(a, 0.5))
print(torch.sqrt(a))
输出:
tensor([[9., 9.],[9., 9.]])
tensor([[3., 3.],[3., 3.]])
tensor([[3., 3.],[3., 3.]])
9.5 exp 指数幂运算
import torcha = torch.ones(2, 2)
print(a)
print(torch.exp(a))
输出:
tensor([[1., 1.],[1., 1.]])
tensor([[2.7183, 2.7183],[2.7183, 2.7183]])
9.6 log 对数运算(相当于 ln)
import torcha = torch.ones(2, 2) * 3
print(a)
print(torch.log(a))
输出:
tensor([[3., 3.],[3., 3.]])
tensor([[1.0986, 1.0986],[1.0986, 1.0986]])
9.7 取整
- floor():向下取整
- ceil():向上取整
- round():四舍五入
- trunc():截取整数部分
- frac():截取小数部分
import torcha = torch.tensor(3.14)
print(a) # tensor(3.1400)
print(torch.floor(a)) #tensor(3.)
print(torch.ceil(a)) #tensor(4.)
print(torch.round(a)) #tensor(3.)
print(torch.trunc(a)) #tensor(3.)
print(torch.frac(a)) #tensor(0.1400)
9.8 clamp 控制张量的取值范围
import torcha = torch.rand(2, 3) * 15print(a)
# 将大于8的值设置为8;小于4的值设置为4
print(torch.clamp(a, 4, 8))
输出:
tensor([[ 8.8872, 5.6534, 14.3027],[ 0.8305, 12.6266, 13.9683]])
tensor([[8.0000, 5.6534, 8.0000],[4.0000, 8.0000, 8.0000]])
十、统计属性
10.1 norm 求范数
import torcha = torch.ones(2, 3)
b = torch.norm(a) # 默认求2范数
c = torch.norm(a, p=1) # 指定求1范数
print(a)
print(b)
print(c)
输出:
tensor([[1., 1., 1.],[1., 1., 1.]])
tensor(2.4495)
tensor(6.)
10.2 mean、median、sum、min、max、prod、argmax、argmin
- prod():返回张量里所有元素的乘积
- armax():返回张量中最大元素的展平索引
- argmin():返回张量中最小元素的展平索引
import torcha = torch.arange(8).view(2, 4).float()
print(a)
'''
tensor([[0., 1., 2., 3.],[4., 5., 6., 7.]])
'''
print(a.mean()) #tensor(3.5000)
print(a.median()) #tensor(3.)
print(a.sum()) #tensor(28.)
print(a.min()) #tensor(0.)
print(a.max()) #tensor(7.)
print(a.prod()) #tensor(0.)
print(a.argmax()) #tensor(7)
print(a.argmin()) #tensor(0)
import torcha = torch.rand(2,4)
print(a)print(a.max(dim = 1))
print(a.max(dim = 1,keepdim = True))
输出:
tensor([[0.7239, 0.9412, 0.7602, 0.2131],[0.6277, 0.1033, 0.8300, 0.9909]])
torch.return_types.max(
values=tensor([0.9412, 0.9909]),
indices=tensor([1, 3]))
torch.return_types.max(
values=tensor([[0.9412],[0.9909]]),
indices=tensor([[1],[3]]))
10.3 topk 获取最大的k个值
import torcha = torch.rand(2,4)
print(a)
print(a.topk(2,dim=1))
'''
tensor([[0.3247, 0.9220, 0.4314, 0.8123],[0.7133, 0.2471, 0.0281, 0.3595]])
torch.return_types.topk(
values=tensor([[0.9220, 0.8123],[0.7133, 0.3595]]),
indices=tensor([[1, 3],[0, 3]]))
'''
10.4 kthvalue 获取第k大的值
import torcha = torch.rand(2, 4)
print(a)
print(a.kthvalue(3,dim=1))
'''
tensor([[0.0980, 0.0479, 0.9298, 0.5638],[0.9095, 0.9071, 0.4913, 0.6144]])
torch.return_types.kthvalue(
values=tensor([0.5638, 0.9071]),
indices=tensor([3, 1]))
'''
10.5 比较运算函数
import torcha = torch.rand(2, 3)
print(a)
'''
tensor([[0.1196, 0.5068, 0.9272],[0.6395, 0.2433, 0.9702]])
'''
# a >= 0.5
print(a.ge(0.5))
'''
tensor([[False, True, True],[ True, False, True]])
'''
# a > 0.5
print(a.gt(0.5))
'''
tensor([[False, True, True],[ True, False, True]])
'''
# a <= 0.5
print(a.le(0.5))
'''
tensor([[ True, False, False],[False, True, False]])
'''
# a < 0.5
print(a.lt(0.5))
'''
tensor([[ True, False, False],[False, True, False]])
'''
# a = 0.5
print(a.eq(0.5))
'''
tensor([[False, False, False],[False, False, False]])
'''
十一、高级操作
11.1 where
import torchcond = torch.rand(2, 2)
a = torch.zeros(2, 2)
b = torch.ones(2, 2)
print(cond)
'''
tensor([[0.3622, 0.9658],[0.1774, 0.6670]])
'''
print(a)
'''
tensor([[0., 0.],[0., 0.]])
'''
print(b)
'''
tensor([[1., 1.],[1., 1.]])
'''
# 满足条件cond.ge(0.5)的按照a的对应元素赋值,否则按照b的对应元素赋值
print(torch.where(cond.ge(0.5), a, b))
'''
tensor([[1., 0.],[1., 0.]])
'''
11.2 gather
帮助我们从批量tensor中取出指定乱序索引下的数据
import torcha = torch.arange(3, 12).view(3, 3)
print(a)
'''
tensor([[ 3, 4, 5],[ 6, 7, 8],[ 9, 10, 11]])
'''index = torch.tensor([[2, 1, 0]])
print(a.gather(1, index)) # tensor([[5, 4, 3]])index = torch.tensor([[2, 1, 0]]).t()
print(a.gather(1, index))
'''
tensor([[5],[7],[9]])
'''index = torch.tensor([[0, 2],[1, 2]])
print(a.gather(1, index))
'''
tensor([[3, 5],[7, 8]])
'''
参考链接:图解PyTorch中的torch.gather函数
在强化学习DQN中的使用 gather() 函数
【深度学习】超详细的 PyTorch 学习笔记(上)相关推荐
- 超详细!Vue-coderwhy个人学习笔记(二)(Day3)
前言 本文章接上一篇笔记 超详细!Vue-coderwhy个人学习笔记(一)(Day1-Day2) 这篇主要是Day3笔记,组件化,组件通信,插槽 四.组件化开发 (一).内容概述 认识组件化 注册组 ...
- 超详细的Git学习记录(Git基础内容/IDEA集成Git/GitHub/Gitee/GitLab及Centos7部署GitLab)
超详细的Git学习笔记 从B站搜到的尚硅谷视频学习了Git,记录了一下学习的内容,收获很大 学习地址: https://www.bilibili.com/video/BV1vy4y1s7k6?p=11 ...
- 【libuv高效编程】libuv学习超详细教程3——libuv事件循环
文章目录 libuv系列文章 libuv事件循环 uv_loop_t demo uv_loop_init() uv_run() uv_loop_close() 参考 例程代码获取 libuv系列文章 ...
- LiteFlow学习(超详细)
LiteFlow学习(超详细) 文章目录 LiteFlow学习(超详细) 1. LiteFlow简介 1.1 前言 1.2 LiteFlow框架的优势 1.3 LiteFlow的设计原则 1.4 Li ...
- 【网速】Visual Studio 下载太慢的问题的解决办法【超详细,来源于学习笔记】
Visual Studio 下载太慢的问题的解决办法[详细,来源于学习的笔记] Visual Studio 下载太慢的解决办法两个步骤即可: 一.测试DNS 二.修改host 做完以上工作后,VS的下 ...
- 【超详细】嵌入式软件学习大纲
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/qq_34981463/article/ ...
- 适用于任意模糊内核的深度即插即用超分辨率(DPSR论文笔记-2019CVPR)
Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels (适用于任意模糊内核的深度即插即用超分辨率) 源码包:https://gi ...
- 一篇超详细的pytorch基础语法讲解及理论推导(一)
张量 - 线性回归 - 自动求导 - 逻辑回归 来源:投稿 来源:阿克西 编辑:学姐 1 pytorch简介 PyTorch是2017年1月FAIR(Facebook AI Research)发布的一 ...
- 超详细的计算机视觉学习书籍pdf汇总(涉及CV、深度学习、多视图几何、SLAM、点云处理等)
计算机视觉入门的一些pdf书籍,[计算机视觉工坊]按照不同领域帮大家划分了下,涉及深度学习基础.目标检测.Opencv.SLAM.点云.多视图集合.三维重建等~ 计算机视觉 1. 计算机视觉算法与应用 ...
最新文章
- Node和java和php,服务端I/O性能大比拼:Node、PHP、Java和Go(三)
- dtree.js树的使用
- VB6.0连接MySQL数据库
- NYOJ15-括号匹配(二)-区间DP
- Linux系统下软件包管理六
- oracle查询用户下所有表名称
- Redis Cluster集群的搭建与实践
- DX使用随笔--NavBarControl
- 复制百度文库的文字加什么后缀_下载百度文库文档 怎么快速提取百度文库中可以完整阅读的文档...
- PDF软件有这么好用的打印机,你知道吗?
- 物理 常见力与牛顿三定律
- 太牛了!B 站 UP 主开发会写高考作文的 AI
- oracle显示连接超时,Oracle 12179:tns:连接超时的问题
- c语言编写图书检索系统,求C语言编写图书管理系统
- 同花顺_代码解析_技术指标_P、Q
- π=4*atan(1.0);
- 模块电路选型(6)----存储模块
- OSChina 周四乱弹 ——PM是这样学程序的
- 70个python项目代码_python项目实例源码
- 2021-12-6 《聪明的投资者》学习笔记-3.一个世纪的股市历史:1972年年初的股价水平-股市周期性。股价、利润和股息
热门文章
- ShareSDK Android端权限说明
- $(this).addClass('class').siblings('class').removeClass('class')的作用
- QQ聊天记录备份BAK文件的修复方法
- WannaCry席卷全球 软件作者到底赚了多少钱?
- Win10实现窗口AeroGlass化
- 【PTA】求交错序列前N项和
- vue项目点击后,从左边或右边滑出组件,再次点击原路滑回。<transition>、transform
- 计算机软件由程序数据和文档组成其中主体是,chap03 计算机软件
- VRTK抓取触碰交互
- NG-ZORRO1.x自定义主题