PyTorch张量常用的创建、变形及数学运算总结

目录

  • PyTorch张量常用的创建、变形及数学运算总结
    • 1. 张量(tensor)的创建
      • 1.1 torch.Tensor()与torch.tensor()
      • 1.2 torch.rand(), torch.ones(), torch.zeros()
      • 1.3 torch.FloatTensor()、torch.DouleTensor()、torch.cuda.FloatTebsor()等
    • 2 基础运算或变换
      • 2.1 size()函数或shape属性获取张量尺寸
      • 2.2 变形(view()、transpose()、flatten())
      • 2.3 加减乘除(+、-、*、\0)及矩阵乘法(@/matnul)
      • 2.3 张量的拆分、连接、拼接与维度变化(split/chunk、cat、stack、sequeeze、unsequeeze)

1. 张量(tensor)的创建

1.1 torch.Tensor()与torch.tensor()

torch.Tensor(3,2)返回一个尺寸为3*2的未初始化的张量
torch.Tensor(data)返回一个torch.FloatTensor类型的data
torch.tensor(3, 2)报错,tensor函数参数必须为一个基本数据类型或引用数据类型数据如整型、浮点型或列表
torch.tensor(data)返回一个给定数值的张量
区别:

  • torch.Tensor()是一个python类,拥有很多可以对数据加以处理的方法,而torch.tensor()是一个函数
  • torc.tensor()的参数只能是data而不能是size
  • torch.Tensor()相当于torch.FloatTensor()的别名,返回单精度浮点型张量。torch.tensor()自动拷贝传入的数据,在没有给dtype参数时,返回torch.FloatTensor、torch.DoubleTensor、torch.LongTensor三种数据类型张量。给定dtype时返回指定dtype。
    PyTorch官方文档对torch.tensor()函数的解释如下
torch.tensor(data, dtype=None, device=None, requires_grad=False) → Tensor
Constructs a tensor with data.

示例代码:

data1 = torch.Tensor([2, 3])
data2 = torch.Tensor(2, 3)
data3 = torch.tensor([2, 3])
print(data1)
print(data2)
print(data3)# 运行结果
tensor([2., 3.])
tensor([[0., 0., 0.],[0., 0., 0.]])
tensor([2, 3])

1.2 torch.rand(), torch.ones(), torch.zeros()

rand()函数用于创建一个指定尺寸的随机矩阵
ones()函数用于创建一个指定尺寸的元素全为1的矩阵
zeros()函数用于创建一个指定尺寸的元素全为0的矩阵
示例代码

data4 = torch.rand(3, 2)
data5 = torch.ones(3, 2)
data6 = torch.zeros(3, 2)print(data4)
print(data5)
print(data6)# 运行结果
tensor([[0.6134, 0.3304],[0.6353, 0.1127],[0.9612, 0.8309]])
tensor([[1., 1.],[1., 1.],[1., 1.]])
tensor([[0., 0.],[0., 0.],[0., 0.]])

1.3 torch.FloatTensor()、torch.DouleTensor()、torch.cuda.FloatTebsor()等

以下表格列出了PyTorch支持的数据类型及其创建函数

数据类型 CPU张量 GPU张量
32位浮点型 torch.FloatTensor torch.cuda.FloatTensor
64位浮点型 torch.DouleTensor torch.cuda.FloatTensor
16位浮点型 torch.HalfTensor torch.cuda.HalfTensor
8位整型(无符号) torch.ByteTensor torch.cuda.ByteTensor
8位整型(有符号) torch.CharTensor torch.cuda.CharTensor
16位整型(有符号) torch.ShortTensor torch.cuda.ShortTensor
32位整型(有符号) torch.IntTensor torch.cuda.IntTensor
64位整型(有符号) torch.LongTensor torch.cuda.LongTensor

2 基础运算或变换

2.1 size()函数或shape属性获取张量尺寸

为了PyTorch API更加接近于NumpyAPI 以取得近似于Numpy数据处理包的强大功能,PyTorch在0.2版本后引入shape属性,用户不但可以像以前一样通过size()函数获取张量尺寸,也可以通过shape属性直接获取。
示例代码:

data2 = torch.Tensor(2, 3)
print(data2.size())
print(data2.shape)# 运行结果
torch.Size([2, 3])
torch.Size([2, 3])

2.2 变形(view()、transpose()、flatten())

有时我们需要对已有张量进行一些变形,比如,三维张量平展成一维张量, 或者交换矩阵的某两个维度(转置)

  • view()函数可对原张量进行任意形式的变形(前提是张量的参数个数保持不变)
  • transpose()函数可对原张量任意两个维度进行转置(交换)
  • flatten()函数可以指定任意连续维度进行“平铺”

示例代码

import torchdata1 = torch.rand(2, 3, 4, 5)
data2 = data1.view(2, 4, 3, 5)
data3 = data1.view(-1) # -1 让python自己计算张量长度,按照整体长度不变的原则进行填充
data4 = data1.view(-1, 10)print(data1[0, :, :, 0]) # 输出第2,3维张量
print(data2[0, :, :, 0])
print(data3.size())
print(data4.size())data5 = data1.transpose(1, 2)
print(data5[0, :, :, 0])
print(data2.size())
print(data5.size())# 运行结果
tensor([[0.6149, 0.0561, 0.6514, 0.8010],[0.4063, 0.5112, 0.1584, 0.5826],[0.0581, 0.5838, 0.8439, 0.1180]])
tensor([[0.6149, 0.0561, 0.6514],[0.8010, 0.4063, 0.5112],[0.1584, 0.5826, 0.0581],[0.5838, 0.8439, 0.1180]])
torch.Size([120])
torch.Size([12, 10])
tensor([[0.6149, 0.4063, 0.0581],[0.0561, 0.5112, 0.5838],[0.6514, 0.1584, 0.8439],[0.8010, 0.5826, 0.1180]])
torch.Size([2, 4, 3, 5])
torch.Size([2, 4, 3, 5])# flatten
data1 = torch.rand(2, 3, 4, 5)
data2 = data1.flatten(1)
data3 = data1.flatten(1, 2)
print(data2.size())
print(data3.size())
# 运行结果:
torch.Size([2, 60])
torch.Size([2, 12, 5])

可以发现,view()的变形是将原张量按顺序填入新的尺寸张量,这容易改变原张量的数学意义。transpose()有点类似于“转置”,一定程度上保留了原张量的数学意义。
此外,注意到view(2, 4, 3, 5)最后和transpose(1, 2)取得了同样的维度结果,但是两者的数值结果是完全不一样的,原因正如第一点所述。

2.3 加减乘除(+、-、*、\0)及矩阵乘法(@/matnul)

示例代码:

data10 = data6 + 5
data11 = data10 - 3
data12 = data10 * data11
data13 = torch.tensor([[1.],[1.]])
data14 = data12 @ data13 # 注意*  和 @ 的区别
data15 = data12.matmul(data13)
print(data10)
print(data11)
print(data12)
print(data14) # 运行结果
tensor([[5., 5.],[5., 5.],[5., 5.]])
tensor([[2., 2.],[2., 2.],[2., 2.]])
tensor([[10., 10.],[10., 10.],[10., 10.]])
tensor([[20.],[20.],[20.]])
tensor([[20.],[20.],[20.]])

2.3 张量的拆分、连接、拼接与维度变化(split/chunk、cat、stack、sequeeze、unsequeeze)

  1. split()函数和chunk()函数用于在指定维度下按指定步长拆分张量
  2. cat()函数用于连接只有一个维度大小不同的两个张量
  3. stack()用于给张量添加新的维度,要求两个张量除了待添加的维度不同,其余维度相同
  4. sequeeze()称为“挤压”,用于删除尺寸为1的维度,不仅可以用于匹配维度还可以大大减少运算时间
  5. unsequeeze()称为“反挤压”,用于给张量增加一个虚拟的维度,并且与satck不同是,unsequeeze增加的虚拟维度不参与张量运算

示例代码:

import torchdata1 = torch.rand(3, 2)
# 拆分
data2 = data1.split(1, 0) # 返回一个张量元组
print("data1:   ", data1)
print("data2:   ", data2)
print("data2[0]:    ", data2[0])# 连接
data3 = torch.ones(1, 2)
data4 = torch.cat((data1, data3), 0)
print("data3:   ", data3)
print("data4:   ", data4)# 拼接
data5 = torch.ones(3, 2)
data6 = torch.stack((data1, data5), 0)
print("data5:   ", data5)
print("data6:   ", data6)data7 = data1.unsqueeze(0)
print("data7:   ", data7)
print("data1.size:  ", data1.size())
print("data7.size:  ", data7.size())data8 = data7.squeeze()
print("data8.size:   ", data8.shape)# 运行结果
data1:    tensor([[0.3280, 0.8911],[0.4691, 0.9570],[0.5510, 0.6604]])
data2:    (tensor([[0.3280, 0.8911]]), tensor([[0.4691, 0.9570]]), tensor([[0.5510, 0.6604]]))
data2[0]:     tensor([[0.3280, 0.8911]])
data3:    tensor([[1., 1.]])
data4:    tensor([[0.3280, 0.8911],[0.4691, 0.9570],[0.5510, 0.6604],[1.0000, 1.0000]])
data5:    tensor([[1., 1.],[1., 1.],[1., 1.]])
data6:    tensor([[[0.3280, 0.8911],[0.4691, 0.9570],[0.5510, 0.6604]],[[1.0000, 1.0000],[1.0000, 1.0000],[1.0000, 1.0000]]])
data7:    tensor([[[0.3280, 0.8911],[0.4691, 0.9570],[0.5510, 0.6604]]])
data1.size:   torch.Size([3, 2])
data7.size:   torch.Size([1, 3, 2])
data8.size:    torch.Size([3, 2])

PyTorch常用的张量创建、变形及运算总结(速查表)相关推荐

  1. pandas常用函数说明及速查表

    pandas常用函数说明及速查表 如果你用python做开发,那么几乎肯定会使用pandas库. Pandas 是 Python 语言的一个扩展程序库,用于数据分析. Pandas 是一个开放源码.B ...

  2. Git 常用命令速查表(图文+表格)

    一. Git 常用命令速查 git branch 查看本地所有分支 git status 查看当前状态  git commit 提交  git branch -a 查看所有的分支 git branch ...

  3. Git 常用命令速查表(图文+表格)【转】

    转自:http://www.jb51.net/article/55442.htm 一. Git 常用命令速查 git branch 查看本地所有分支 git status 查看当前状态  git co ...

  4. 常用Python标准库对象速查表(1)

    封面图片:<Python程序设计基础(第2版)>,董付国,清华大学出版社 =============== 常用Python标准库对象速查表(1) 标准库 对象 简要说明 math sin( ...

  5. MySQL 常用命令速查表:日常开发、求职面试必备良方!

    备注:PDF 版本点此下载. 文章目录 连接服务器 查看帮助 查看连接 退出连接 账户和权限 创建用户 查看用户 修改密码 锁定/解锁用户 用户授权 查看权限 撤销权限 管理角色 删除用户 管理数据库 ...

  6. 全套Python数据分析常用命令速查表!PDF文档限时分享

    当下利用python学习数据分析的热度越来越高,对于很多新手而言,大量要学习的库和工具的命令繁杂,用起来不是很顺手. 今天给大家分享一份python数据分析常用命令速查表. 一共6张表,包括:Jupy ...

  7. Linux系统运维人员常用速查表

    Linux系统运维人员常用速查表 walkingcloud 2020-08-09 19:55:41 Linux系统运维人员常用速查表 1.awk速查表 2.bash速查表 3.firewall-cmd ...

  8. Python常用网络爬虫速查表下载

    Python常用网络爬虫速查表下载 Post方法: Get方法: css选择器 beautiful soup选择器 xpath选择器 可以将图片打印出来,放在桌面看 下载地址: 一天掌握python网 ...

  9. 转收藏:Git常用命令速查表

    一. Git 常用命令速查 git branch 查看本地所有分支 git status 查看当前状态  git commit 提交  git branch -a 查看所有的分支 git branch ...

最新文章

  1. python 删除变量_DAY1-step4 Python变量:声明,连接变量,全局和局部
  2. 在 C++ 中实现一个轻量的标记清除 gc 系统
  3. .NET Core实战项目之CMS 第十章 设计篇-系统开发框架设计
  4. cannot find output in imported module librosa报错解决
  5. Android下常见的内存泄露 经典
  6. python模块的函数_python模块内置函数
  7. 决定你人生命运的10年,你做了什么?
  8. swiper.js插件的使用
  9. CentOS 如何修改mysql 用户root的密码
  10. 使用ES6的Promis完美解决ajax的回调(优化代码)
  11. java获取异常信息
  12. init mysql db error_Python mysql curs错误
  13. python万年历差农历程序_Python实现公历(阳历)转农历(阴历)的方法示例
  14. 那些年,我们一起做过的 Java 课后练习题(61 - 65)
  15. 升级coda_提高生产力:Coda的快速提示
  16. pdf.net sod oracle,SOD: 原PDF.NET框架将成为一个全功能的企业开发框架,而 SOD框架将是PDF.NET开发框架下面的 “数据开发框架...
  17. 隐私加密系列|Grin与BEAM之间技术公开对比
  18. linux显示全部字符集,linux 字符集 页面显示乱码
  19. dubbo实现SOA架构
  20. 企业微信小程序_集成腾讯地图实现精准定位考勤打卡

热门文章

  1. 程序员什么专业毕业算是科班出身?这个回答与你想的是否一样?
  2. 基于Python/Capl脚本 对通信矩阵报文(Flexray/Can)的周期检测(二)
  3. 如何获取屏幕DPI/PPI并计算A4纸在屏幕的大小
  4. JDK、JER、JVM是什么
  5. 音频文件wav转gsm
  6. java计算机毕业设计汽车票订购系统源码+程序+lw文档+mysql数据库
  7. 基于数据库的企业内部邮件系统的设计
  8. 如何解决Beyond Compare中文乱码问题
  9. 《时代三部曲》感悟三
  10. 有人问我为什么不买iphon12,我为什么要买iphone12 pro max