在PyTorch中,张量属于一种基本的数据类型,和Numpy库中的ndarry类似,无论是标量、向量、矩阵还是高维数组都是以张量(Tensor)这种数据类型来表示。因此,有必要对该基本数据类型有所了解。(在PyTorch 0.4版本以前,还有个Variable类,主要是用于计算梯度,在后续版本中已经和Tensor合并,本专栏主要以最新的1.7.0版本做实验)

目录

  • 1. 张量的数据类型
  • 2. 张量的创建
    • (1) 使用torch.tensor()函数创建张量或者使用torch.Tensor()创建张量
    • (2) 使用torch.xx_like()创建张量
    • (3) 随机数生成张量
    • (4) 按照数值内容创建张量
    • (5) 按照某种规则生成张量
    • (6) 与Numpy数据相互转化
  • 3. 张量的基本操作
    • (1)改变张量的形状
    • (2) 获取张量元素
    • (3) 张量的拼接与拆分
  • 4.张量的计算
    • (1) 张量的逐元素操作
    • (2) 张量的矩阵操作
    • (3) 张量的矩阵操作
    • (3) 张量的统计量计算
  • Reference

1. 张量的数据类型

在PyTorch中,张量也就是Tensor,分别有两大类型:浮点型和整型。其中浮点型按照精度不同,又分为16位、32位以及64位;整型则根据有无符号位以及精度,又分为8位无符号整型、8位有符号整型、16位有符号整型、32位有符号整型以及64位有符号整型。每个类型又可以区分为CPU类型以及GPU类型。具体如下图所示:

一般常用的就是32位浮点型,而且只有浮点型才可以计算梯度。训练用的标签则一般是整型中的32或64位。 我们可以使用tensor.dtype来查看当前tensor的数据类型。

2. 张量的创建

(1) 使用torch.tensor()函数创建张量或者使用torch.Tensor()创建张量

例如:

In [2]: torch.tensor([1.0, 2.0]).type()
Out[2]: 'torch.FloatTensor'
In [3]: torch.Tensor([1.0, 2.0]).type()
Out[3]: 'torch.FloatTensor'

两者都可以通过传入一个类array的data进行tensor的创建。那两者有什么区别呢?

torch.tensor()是一个函数,而torch.Tensor()是一个类,是默认tensor 类型的一个别名,它实际上是调用的__init__函数去构建的张量。

因此,torch.tensor()可以通过dtype参数指定生成的tensor的类型并且通过requires_grad参数去指定是否计算梯度,而torch.Tensor()都不可以,还有就是torch.Tensor()可以通过传入形状来创建张量。具体见下图示例:

In [4]: torch.Tensor([1, 2]).type()  # 无论传入的data是什么类型,都只能生成默认类型的tensor
Out[4]: 'torch.FloatTensor'In [5]: torch.tensor([1, 2]).type()
Out[5]: 'torch.LongTensor'In [7]: torch.Tensor(2,3)  # 指定shape
Out[7]:
tensor([[0., 0., 0.],[0., 0., 0.]])

(2) 使用torch.xx_like()创建张量

torch.ones_like(X), torch.zeros_like(X), torch.rand_like(X)可以分别用于创建一个与张量X具有相同维度的全1、全0或者是服从[0,1]区间上均匀分布的张量。

(3) 随机数生成张量

torch.normal(mean, std)可以通过传入指定的均值张量和方差张量,从而生成一个对应满足该分布的随机数张量,当mean和std传入多个值时,则会相应得到多个随机数张量。均值和方差向量的维度需要满足广播机制。 torch.rand(shape), torch.randn(shape)则是用于生成服从[0,1]区间上均匀分布的张量以及服从标准正态分布的张量。(randn多了个n表示normal distribution,也很好记忆)

(4) 按照数值内容创建张量

torch.zeros(shape),torch.ones(shape),torch.eye(shape),torch.full(shape, fill_value),torch.empty(shape)可以通过指定shape来创建一个全0、全1、全为fill_value或是完全随机的一个张量。

(5) 按照某种规则生成张量

torch.arange(start, end, step), torch.linspace(start, end, step), torch.logspace(start, end, step)可以通过指定start,end以及step参数来在某个范围内基于固定步长、等长间隔或对数间隔的张量。

(6) 与Numpy数据相互转化

这是PyTorch中非常常用的一个张量生成方式。

  • 将Numpy数组转化为PyTorch张量:torch.as_tensor(ndarray), torch.from_numpy(ndarray)
  • 将PyTorch张量转化为Numpy数组:tensor.numpy()

3. 张量的基本操作

(1) 改变张量的形状

对张量的形状的改变是PyTorch中非常重要且常用的操作之一,因为很多操作,例如全连接层前向传播、矩阵乘法等对输入的shape都有一定的要求。

  • tensor.reshape(shape), torch.reshape(input, shape), tensor.resize_(shape), tensorA.resize_as_(tensorB), tensor.view(shape) 可以根据指定的shape对原tensor进行reshape操作。具体含义,看函数名称都能比较清晰地看出来。(PS. 在PyTorch中,以“_”结尾的函数操作一般表示原地操作,该操作没有返回值,而是直接对原tensor进行操作)(PPS. torch.reshape和tensor.view都可以用来调整tensor的形状,但view函数要求作用的tensor在内存中连续存储,如果对tensor调用过transpose,permute等操作后就无法直接使用view,而需要使用tensor.contiguous来获得一个连续的copy,而reshape则没有这个要求)
In [1]: import torch
In [2]: x = torch.arange(6)
In [3]: y = x.reshape(2, 3)
In [4]: x, y
Out[4]:
(tensor([0, 1, 2, 3, 4, 5]),tensor([[0, 1, 2],[3, 4, 5]]))
In [5]: y = x.resize_(2, 3)
In [6]: x, y
Out[6]:
(tensor([[0, 1, 2],[3, 4, 5]]),tensor([[0, 1, 2],[3, 4, 5]]))

  • torch.flatten(input, start_dim, end_dim)可以将输入的tensor的第start_dim到end_dim之间的数据拉平成一维tensor。这个操作对于全连接网络十分常用。一般都会先使用全卷积网络得到一个特征图,shape为B, C, H, W,此时若要送入全连接网络进行分类,则需要对其进行拉平,这里可以使用flatten函数——torch.flatten(tensor, start_dim=1)对该特征图tensor进行拉平而不需要去获取其shape大小,得到一个tensor,shape为B, C * H * W的二维向量。
  • torch.unsqueeze(tensor, dim)可以,可以在张量的指定维度插入一个新的维度,得到一个新的张量,当然也可以使用tensor.unsqueeze_(dim)进行原地操作。这个操作我个人一般是用于给灰度图扩展维度使用,灰度图倘若使用灰度方式读取的话,最终得到的shape为B, H, W,是不符合PyTorch四个维度的要求的,此时可以使用tensor.unsqueeze_(dim=1)来扩充一个维度,得到的新tensor的shape为B, 1, H, W即可进行后续操作了。
  • torch.squeeze(tensor, dim)功能则正好相反,它可以移除指定或者所有维度大小为1(不指定维度时)的维度,从而得到一个维度减小的新张量,类似的也可以使用tensor.squeeze_(dim)进行原地操作。这个操作我个人一般用于全卷积网络做分类时的多余维度压缩用,在使用全卷积网络得到最后分类的feature map时(shape为B, C, H, W, C即为待分类数),此时使用Global Average Pooling对特征图的空间信息进行压缩和进一步提取,得到feature map的shape为B, C, 1, 1,此时即可使用tensor.squeeze_()对维度进行压缩,得到shape为B, C的分类向量,即可用于后续loss计算。(PS. 这里有个小坑需要注意,倘若此时batchsize也是1(类别数一般不会为1),则也会被压缩,从而得到错误的shape,此时可以使用flatten函数进行压缩

(2) 获取张量元素

PyTorch中的张量是支持和Numpy数组类似的切片以及索引操作的,并且也可以使用真值mask去提取真值内容。例如:

In [1]: import torchIn [2]: x = torch.arange(6).reshape(2, 3)In [3]: x
Out[3]:
tensor([[0, 1, 2],[3, 4, 5]])In [4]: x[0], x[0, :2]
Out[4]: (tensor([0, 1, 2]), tensor([0, 1]))In [5]: y = torch.ones_like(x)In [6]: y
Out[6]:
tensor([[1, 1, 1],[1, 1, 1]])In [7]: torch.where(x > 2, x, y)
Out[7]:
tensor([[1, 1, 1],[3, 4, 5]])

(3) 张量的拼接与拆分

  • torch.cat(tensors, dim)可以将多个张量在指定维度进行拼接,从而得到新的张量。这个操作在模型前向传播过程中是非常常用的一个操作,用于对特征的concatenation,从而后续对特征进行高级地融合。例如:
In [1]: import torchIn [2]: x = torch.rand(size=(1, 3, 6, 6))In [3]: y = torch.rand(size=(1, 4, 6, 6))In [4]: x.shape, y.shape, torch.cat((x, y), dim = 1).shape
Out[4]: (torch.Size([1, 3, 6, 6]), torch.Size([1, 4, 6, 6]), torch.Size([1, 7, 6, 6]))

  • torch.stack(tensors, dim)也可以用来将多个张量在指定维度进行拼接,但与torch.cat()不同的是,该操作是沿新维度拼接张量。例如:
In [1]: import torchIn [2]: x = torch.rand(size=(1, 3, 6, 6))In [3]: y = torch.rand(size=(1, 3, 6, 6))In [4]: x.shape, y.shape, torch.stack((x, y), dim = 1).shape
Out[4]:
(torch.Size([1, 3, 6, 6]),torch.Size([1, 3, 6, 6]),torch.Size([1, 2, 3, 6, 6]))

这个维度的指定是根据生成的新张量中的新维度来确定的,也就是说,如果你希望这些shape相同的张量在拼接后的张量的第几个维度进行拼接,那么这个dim就设定为几。需要注意的是,此时要求送入的所有tensor的shape相同。

  • torch.chunk(tensor, chunks, dim)可以在指定维度上,将tensor划分为chunks块,如果指定维度上的张量的大小不能被chunks整除,则最后一个块的size将会略小。torch.split(tensor, split_size_or_sections, dim)可以在指定维度上,将tensor划分为split_size_or_sections块(split_size_or_sections为整数时),此时该函数与torch.chunk()的功能类似,当split_size_or_sections为列表时,可以将tensor划分为不同大小的块,块的大小由list中的数值指定。

4.张量的计算

(1) 张量的逐元素操作

  • 张量间的逐元素大小比较:

torch.eq(A, B),torch.equal(A, B),torch.ge(A, B),torch.gt(A, B),torch.le(A, B),torch.lt(A, B),torch.ne(A, B),torch.isnan(A, B)的功能分别是:逐元素比较张量A和张量B是否相等,判断两个张量是否具有相同的形状和元素,逐元素比较是否大于等于,逐元素比较是否大于,逐元素比较是否小于等于,逐元素比较是否小于,逐元素比较是否不相等,逐元素判断是否为缺失值。函数的功能见名知意,g代表greater,t表示than,e表示equal,l代表less等。该比较结果也可以作为一个mask去取出tensor中满足某些条件的值,例如:

In [1]: import torchIn [2]: x = torch.randn(2, 3)In [3]: x
Out[3]:
tensor([[-0.1698,  0.3462,  1.0038],[ 0.6679,  0.6578, -0.5917]])In [4]: x[torch.gt(x, torch.tensor(0.0))]
Out[4]: tensor([0.3462, 1.0038, 0.6679, 0.6578])In [5]: torch.where(torch.gt(x, torch.tensor(0)), x, torch.tensor(0.0))
Out[5]:
tensor([[0.0000, 0.3462, 1.0038],[0.6679, 0.6578, 0.0000]])

  • 张量的逐元素运算:

A + B, A - B, A * B, A / B, A // B 表示张量A和张量B逐元素相加、相减、相乘、相除以及相整除。

torch.pow(tensor, exponent), tensor ** exponent, torch.exp(tensor), torch.log(tensor), torch.sqrt(tensor), torch.rsqrt(tensor) 表示计算张量的幂、幂、指数、对数、平方根以及平方根倒数。

torch.clamp_max(tensor, max), torch.clamp_min(tensor, min), torch.clamp(tensor, min, max) 表示逐元素对tensor按照最大值、最小值以及一个范围进行裁剪。

(2) 张量的矩阵操作

torch.t(tensor), torch.matmul(A, B), torch.inverse(tensor), torch.trace(tensor) 表示矩阵的转置、张量的矩阵乘法、张量的逆以及张量的迹。

(3) 张量的统计量计算

torch.max(tensor), torch.argmax(tensor), torch.min(tensor), torch.argmin(tensor) 分别返回一个张量的最大值、最大值索引、最小值以及最小值索引,对于torch.max()和torch.min()来说,如果指定了dim,也就是说对某个维度求最大值,那么该函数会返回一个namedtuple(values, indices),第一个为最大/小的值,第二个为最大/小值的索引。

torch.topk(tensor, k) 获取张量的前k个最大的值以及其对应的索引位置,也是一样的namedtuple,也可以通过设定largest参数为FALSE,从而获得topk小的值及其索引。

torch.mean(tensor, dim), torch.sum(tensor, dim), torch.cumsum(tensor, dim), torch.median(tensor, dim), torch.cumprod(tensor, dim), torch.std(tensor, dim) 表示根据指定的维度计算均值、求和、计算累积和、计算中位数、计算累乘积以及标准差。

Reference

  1. https://pytorch.org/docs/1.7.0/
  2. PyTorch view和reshape的区别
  3. PyTorch深度学习入门与实战

- END -

pytorch flatten函数_1. PyTorch中的基本数据类型——张量相关推荐

  1. Pytorch中 permute / transpose 和 view / reshape, flatten函数

    1.transpose与permute transpose() 和 permute() 都是返回转置后矩阵,在pytorch中转置用的函数就只有这两个 ,这两个函数都是交换维度的操作 transpos ...

  2. Pytorch阅读文档之flatten函数

    pytorch中flatten函数 torch.flatten() #展平一个连续范围的维度,输出类型为Tensor torch.flatten(input, start_dim=0, end_dim ...

  3. [PyTorch] 深度学习框架PyTorch中的概念和函数

    Pytorch的概念 Pytorch最重要的概念是tensor,意为"张量". Variable是能够构建计算图的 tensor(对 tensor 的封装).借用Variable才 ...

  4. 旧版中 pytorch.rfft 函数与新版 pytorch.fft.rfft 函数对应修改问题

    旧版中 pytorch.rfft 函数与新版 pytorch.fft.rfft 函数对应修改问题 前言 一.旧版 pytorch.rfft()函数解释 二.新版pytorch.fft.rfft()函数 ...

  5. 全新开源,《Pytorch常用函数函数手册》开放下载!内含200余个函数!

    近期有很多小伙伴在后台咨询有没有关于Pytorch函数使用的学习资料.Pytorch是目前常用的深度学习框架之一,深受学生党的喜爱,小白本人也是使用的Pytorch框架.为了帮助更多小伙伴,小白学视觉 ...

  6. pytorch基础函数学习

    深度学习框架,似乎永远离不开哪个最热哪个最实用的话题,自己接触甚浅,尚不敢对齐进行大加评论,这里也只是初步接触.目前常见的有TensorFlow,pytorch,Keras等,至于目前哪个做好用,就像 ...

  7. python中tolist_高效的张量操作 Pytorch中就占5种

    PyTorch是一个基于Python的科学包,用于使用一种称为张量的特殊数据类型执行高级操作. 虽然也有其他方式可以实现相同的效果,但今天分享的这5个操作更加方便高效,值得一试. 什么是张量? 张量是 ...

  8. pytorch拼接函数:torch.stack()和torch.cat()--详解及例子

    原文链接: https://blog.csdn.net/xinjieyuan/article/details/105205326 https://blog.csdn.net/xinjieyuan/ar ...

  9. 哈工大博士历时半年整理的《Pytorch常用函数函数手册》开放下载!内含200余个函数!...

    近期有很多小伙伴在公众号后台咨询有没有关于Pytorch函数使用的学习资料.Pytorch是目前常用的深度学习框架之一,深受学生党的喜爱,小白本人也是使用的Pytorch框架.为了帮助更多小伙伴,小白 ...

最新文章

  1. IC/FPGA校招笔试题分析(二)任意切换的时钟分频电路
  2. 【千字分析】剑指 Offer 46. 把数字翻译成字符串
  3. Firefox Developer Edition已阻止此网站安装未经验证的附加组件的解决办法
  4. 字符串php手册,php知识点复习之字符串
  5. pdfplumber解析pdf文件
  6. 守护进程之PHP实现
  7. 用SVM分类模型处理iris数据集
  8. MySQL基础入门-创建表格系列操作
  9. 报考上传照片时显示服务器错误,报考上传照片所遇问题及解决方法(转载)
  10. 2020牛客寒假算法基础集训营4 - G 音乐鉴赏-全概率公式
  11. CSS3 经典教程系列:CSS3 线性渐变(linear-gradient)
  12. VBS 请求WebAPI接口_如何设计WEB API
  13. WGS84(GPS)、火星坐标系(GCJ02)、百度地图(BD09)坐标系转换案例教程(附转换工具下载)
  14. 洛谷P4683 [IOI2008] Type Printer 题解
  15. EasyExcel解析excel(合并单元格和未合并)
  16. 震撼人心的战争类背景音乐
  17. 国内的人工智能神经网络研究院有哪些
  18. 炒股杠杆放大多少合适
  19. async/await的用法
  20. 位深度怎么调_吉他大神是怎么炼成的?

热门文章

  1. k6前级效果器怎么用_P18:调制类效果器的那些事儿(Modulation)
  2. x射线直接投影成像的条件_告诉你如何区分X射线DR、CR和胶片成像?
  3. Java Excel(jxl)学习笔记
  4. 这可能是最好的RxJava 2.x 入门教程学习系列
  5. 基于JAVA+SpringBoot+Mybatis+MYSQL的宝妈购母婴用品商城
  6. 基于JAVA+SpringMVC+MYSQL的小说管理系统
  7. 基于JAVA+SpringMVC+MYSQL的实验室预约管理系统
  8. 搜狐畅游笔试题:1. 美丽的项链(动态规划) 2.多线程并发交替输出
  9. 第四周笔记 c++ Boolan
  10. 虚拟机 ubuntu10.04 安装 Mercury MW150U 无线网卡(AR9271芯片组)