PyTorch常用的张量创建、变形及运算总结(速查表)
PyTorch张量常用的创建、变形及数学运算总结
目录
- PyTorch张量常用的创建、变形及数学运算总结
- 1. 张量(tensor)的创建
- 1.1 torch.Tensor()与torch.tensor()
- 1.2 torch.rand(), torch.ones(), torch.zeros()
- 1.3 torch.FloatTensor()、torch.DouleTensor()、torch.cuda.FloatTebsor()等
- 2 基础运算或变换
- 2.1 size()函数或shape属性获取张量尺寸
- 2.2 变形(view()、transpose()、flatten())
- 2.3 加减乘除(+、-、*、\0)及矩阵乘法(@/matnul)
- 2.3 张量的拆分、连接、拼接与维度变化(split/chunk、cat、stack、sequeeze、unsequeeze)
1. 张量(tensor)的创建
1.1 torch.Tensor()与torch.tensor()
torch.Tensor(3,2)
返回一个尺寸为3*2的未初始化的张量
torch.Tensor(data)
返回一个torch.FloatTensor
类型的data
torch.tensor(3, 2)
报错,tensor函数参数必须为一个基本数据类型或引用数据类型数据如整型、浮点型或列表
torch.tensor(data)
返回一个给定数值的张量
区别:
- torch.Tensor()是一个python类,拥有很多可以对数据加以处理的方法,而torch.tensor()是一个函数
- torc.tensor()的参数只能是data而不能是size
- torch.Tensor()相当于torch.FloatTensor()的别名,返回单精度浮点型张量。torch.tensor()自动拷贝传入的数据,在没有给dtype参数时,返回torch.FloatTensor、torch.DoubleTensor、torch.LongTensor三种数据类型张量。给定dtype时返回指定dtype。
PyTorch官方文档对torch.tensor()函数的解释如下
torch.tensor(data, dtype=None, device=None, requires_grad=False) → Tensor
Constructs a tensor with data.
示例代码:
data1 = torch.Tensor([2, 3])
data2 = torch.Tensor(2, 3)
data3 = torch.tensor([2, 3])
print(data1)
print(data2)
print(data3)# 运行结果
tensor([2., 3.])
tensor([[0., 0., 0.],[0., 0., 0.]])
tensor([2, 3])
1.2 torch.rand(), torch.ones(), torch.zeros()
rand()
函数用于创建一个指定尺寸的随机矩阵
ones()
函数用于创建一个指定尺寸的元素全为1的矩阵
zeros()
函数用于创建一个指定尺寸的元素全为0的矩阵
示例代码
data4 = torch.rand(3, 2)
data5 = torch.ones(3, 2)
data6 = torch.zeros(3, 2)print(data4)
print(data5)
print(data6)# 运行结果
tensor([[0.6134, 0.3304],[0.6353, 0.1127],[0.9612, 0.8309]])
tensor([[1., 1.],[1., 1.],[1., 1.]])
tensor([[0., 0.],[0., 0.],[0., 0.]])
1.3 torch.FloatTensor()、torch.DouleTensor()、torch.cuda.FloatTebsor()等
以下表格列出了PyTorch支持的数据类型及其创建函数
数据类型 | CPU张量 | GPU张量 |
---|---|---|
32位浮点型 | torch.FloatTensor | torch.cuda.FloatTensor |
64位浮点型 | torch.DouleTensor | torch.cuda.FloatTensor |
16位浮点型 | torch.HalfTensor | torch.cuda.HalfTensor |
8位整型(无符号) | torch.ByteTensor | torch.cuda.ByteTensor |
8位整型(有符号) | torch.CharTensor | torch.cuda.CharTensor |
16位整型(有符号) | torch.ShortTensor | torch.cuda.ShortTensor |
32位整型(有符号) | torch.IntTensor | torch.cuda.IntTensor |
64位整型(有符号) | torch.LongTensor | torch.cuda.LongTensor |
2 基础运算或变换
2.1 size()函数或shape属性获取张量尺寸
为了PyTorch API更加接近于NumpyAPI 以取得近似于Numpy数据处理包的强大功能,PyTorch在0.2版本后引入shape属性,用户不但可以像以前一样通过size()函数获取张量尺寸,也可以通过shape属性直接获取。
示例代码:
data2 = torch.Tensor(2, 3)
print(data2.size())
print(data2.shape)# 运行结果
torch.Size([2, 3])
torch.Size([2, 3])
2.2 变形(view()、transpose()、flatten())
有时我们需要对已有张量进行一些变形,比如,三维张量平展成一维张量, 或者交换矩阵的某两个维度(转置)
view()
函数可对原张量进行任意形式的变形(前提是张量的参数个数保持不变)transpose()
函数可对原张量任意两个维度进行转置(交换)flatten()
函数可以指定任意连续维度进行“平铺”
示例代码
import torchdata1 = torch.rand(2, 3, 4, 5)
data2 = data1.view(2, 4, 3, 5)
data3 = data1.view(-1) # -1 让python自己计算张量长度,按照整体长度不变的原则进行填充
data4 = data1.view(-1, 10)print(data1[0, :, :, 0]) # 输出第2,3维张量
print(data2[0, :, :, 0])
print(data3.size())
print(data4.size())data5 = data1.transpose(1, 2)
print(data5[0, :, :, 0])
print(data2.size())
print(data5.size())# 运行结果
tensor([[0.6149, 0.0561, 0.6514, 0.8010],[0.4063, 0.5112, 0.1584, 0.5826],[0.0581, 0.5838, 0.8439, 0.1180]])
tensor([[0.6149, 0.0561, 0.6514],[0.8010, 0.4063, 0.5112],[0.1584, 0.5826, 0.0581],[0.5838, 0.8439, 0.1180]])
torch.Size([120])
torch.Size([12, 10])
tensor([[0.6149, 0.4063, 0.0581],[0.0561, 0.5112, 0.5838],[0.6514, 0.1584, 0.8439],[0.8010, 0.5826, 0.1180]])
torch.Size([2, 4, 3, 5])
torch.Size([2, 4, 3, 5])# flatten
data1 = torch.rand(2, 3, 4, 5)
data2 = data1.flatten(1)
data3 = data1.flatten(1, 2)
print(data2.size())
print(data3.size())
# 运行结果:
torch.Size([2, 60])
torch.Size([2, 12, 5])
可以发现,view()的变形是将原张量按顺序填入新的尺寸张量,这容易改变原张量的数学意义。transpose()有点类似于“转置”,一定程度上保留了原张量的数学意义。
此外,注意到view(2, 4, 3, 5)最后和transpose(1, 2)取得了同样的维度结果,但是两者的数值结果是完全不一样的,原因正如第一点所述。
2.3 加减乘除(+、-、*、\0)及矩阵乘法(@/matnul)
示例代码:
data10 = data6 + 5
data11 = data10 - 3
data12 = data10 * data11
data13 = torch.tensor([[1.],[1.]])
data14 = data12 @ data13 # 注意* 和 @ 的区别
data15 = data12.matmul(data13)
print(data10)
print(data11)
print(data12)
print(data14) # 运行结果
tensor([[5., 5.],[5., 5.],[5., 5.]])
tensor([[2., 2.],[2., 2.],[2., 2.]])
tensor([[10., 10.],[10., 10.],[10., 10.]])
tensor([[20.],[20.],[20.]])
tensor([[20.],[20.],[20.]])
2.3 张量的拆分、连接、拼接与维度变化(split/chunk、cat、stack、sequeeze、unsequeeze)
split()
函数和chunk()函数用于在指定维度下按指定步长拆分张量cat()
函数用于连接只有一个维度大小不同的两个张量stack()
用于给张量添加新的维度,要求两个张量除了待添加的维度不同,其余维度相同sequeeze()
称为“挤压”,用于删除尺寸为1的维度,不仅可以用于匹配维度还可以大大减少运算时间unsequeeze()
称为“反挤压”,用于给张量增加一个虚拟的维度,并且与satck不同是,unsequeeze增加的虚拟维度不参与张量运算
示例代码:
import torchdata1 = torch.rand(3, 2)
# 拆分
data2 = data1.split(1, 0) # 返回一个张量元组
print("data1: ", data1)
print("data2: ", data2)
print("data2[0]: ", data2[0])# 连接
data3 = torch.ones(1, 2)
data4 = torch.cat((data1, data3), 0)
print("data3: ", data3)
print("data4: ", data4)# 拼接
data5 = torch.ones(3, 2)
data6 = torch.stack((data1, data5), 0)
print("data5: ", data5)
print("data6: ", data6)data7 = data1.unsqueeze(0)
print("data7: ", data7)
print("data1.size: ", data1.size())
print("data7.size: ", data7.size())data8 = data7.squeeze()
print("data8.size: ", data8.shape)# 运行结果
data1: tensor([[0.3280, 0.8911],[0.4691, 0.9570],[0.5510, 0.6604]])
data2: (tensor([[0.3280, 0.8911]]), tensor([[0.4691, 0.9570]]), tensor([[0.5510, 0.6604]]))
data2[0]: tensor([[0.3280, 0.8911]])
data3: tensor([[1., 1.]])
data4: tensor([[0.3280, 0.8911],[0.4691, 0.9570],[0.5510, 0.6604],[1.0000, 1.0000]])
data5: tensor([[1., 1.],[1., 1.],[1., 1.]])
data6: tensor([[[0.3280, 0.8911],[0.4691, 0.9570],[0.5510, 0.6604]],[[1.0000, 1.0000],[1.0000, 1.0000],[1.0000, 1.0000]]])
data7: tensor([[[0.3280, 0.8911],[0.4691, 0.9570],[0.5510, 0.6604]]])
data1.size: torch.Size([3, 2])
data7.size: torch.Size([1, 3, 2])
data8.size: torch.Size([3, 2])
PyTorch常用的张量创建、变形及运算总结(速查表)相关推荐
- pandas常用函数说明及速查表
pandas常用函数说明及速查表 如果你用python做开发,那么几乎肯定会使用pandas库. Pandas 是 Python 语言的一个扩展程序库,用于数据分析. Pandas 是一个开放源码.B ...
- Git 常用命令速查表(图文+表格)
一. Git 常用命令速查 git branch 查看本地所有分支 git status 查看当前状态 git commit 提交 git branch -a 查看所有的分支 git branch ...
- Git 常用命令速查表(图文+表格)【转】
转自:http://www.jb51.net/article/55442.htm 一. Git 常用命令速查 git branch 查看本地所有分支 git status 查看当前状态 git co ...
- 常用Python标准库对象速查表(1)
封面图片:<Python程序设计基础(第2版)>,董付国,清华大学出版社 =============== 常用Python标准库对象速查表(1) 标准库 对象 简要说明 math sin( ...
- MySQL 常用命令速查表:日常开发、求职面试必备良方!
备注:PDF 版本点此下载. 文章目录 连接服务器 查看帮助 查看连接 退出连接 账户和权限 创建用户 查看用户 修改密码 锁定/解锁用户 用户授权 查看权限 撤销权限 管理角色 删除用户 管理数据库 ...
- 全套Python数据分析常用命令速查表!PDF文档限时分享
当下利用python学习数据分析的热度越来越高,对于很多新手而言,大量要学习的库和工具的命令繁杂,用起来不是很顺手. 今天给大家分享一份python数据分析常用命令速查表. 一共6张表,包括:Jupy ...
- Linux系统运维人员常用速查表
Linux系统运维人员常用速查表 walkingcloud 2020-08-09 19:55:41 Linux系统运维人员常用速查表 1.awk速查表 2.bash速查表 3.firewall-cmd ...
- Python常用网络爬虫速查表下载
Python常用网络爬虫速查表下载 Post方法: Get方法: css选择器 beautiful soup选择器 xpath选择器 可以将图片打印出来,放在桌面看 下载地址: 一天掌握python网 ...
- 转收藏:Git常用命令速查表
一. Git 常用命令速查 git branch 查看本地所有分支 git status 查看当前状态 git commit 提交 git branch -a 查看所有的分支 git branch ...
最新文章
- python 删除变量_DAY1-step4 Python变量:声明,连接变量,全局和局部
- 在 C++ 中实现一个轻量的标记清除 gc 系统
- .NET Core实战项目之CMS 第十章 设计篇-系统开发框架设计
- cannot find output in imported module librosa报错解决
- Android下常见的内存泄露 经典
- python模块的函数_python模块内置函数
- 决定你人生命运的10年,你做了什么?
- swiper.js插件的使用
- CentOS 如何修改mysql 用户root的密码
- 使用ES6的Promis完美解决ajax的回调(优化代码)
- java获取异常信息
- init mysql db error_Python mysql curs错误
- python万年历差农历程序_Python实现公历(阳历)转农历(阴历)的方法示例
- 那些年,我们一起做过的 Java 课后练习题(61 - 65)
- 升级coda_提高生产力:Coda的快速提示
- pdf.net sod oracle,SOD: 原PDF.NET框架将成为一个全功能的企业开发框架,而 SOD框架将是PDF.NET开发框架下面的 “数据开发框架...
- 隐私加密系列|Grin与BEAM之间技术公开对比
- linux显示全部字符集,linux 字符集 页面显示乱码
- dubbo实现SOA架构
- 企业微信小程序_集成腾讯地图实现精准定位考勤打卡
热门文章
- 程序员什么专业毕业算是科班出身?这个回答与你想的是否一样?
- 基于Python/Capl脚本 对通信矩阵报文(Flexray/Can)的周期检测(二)
- 如何获取屏幕DPI/PPI并计算A4纸在屏幕的大小
- JDK、JER、JVM是什么
- 音频文件wav转gsm
- java计算机毕业设计汽车票订购系统源码+程序+lw文档+mysql数据库
- 基于数据库的企业内部邮件系统的设计
- 如何解决Beyond Compare中文乱码问题
- 《时代三部曲》感悟三
- 有人问我为什么不买iphon12,我为什么要买iphone12 pro max