tensor是深度学习运算的基本数据结构,本文主要归纳总结了Pytorch中的tensor对象的基础知识,包括它的常用属性、创建方法以及类型转化。

1. Tensor属性

Pytorch中定义了许多类,类就包括属性和行为(方法),Tensor是最基本的类,是用来运算的基本单位。tensor的大多属性都不是基本数据类型,而是Pytorch中定义的类,比如torch.dtype、torch.device等

1.1 torch.dtype

torch.dtype 属性标识了 tensor的数据类型。Pytorch中定义了8种CPU张量类型和对应的GPU张量类型,CPU类型(如torch.FloatTensor)中间加一个cuda即为GPU类型(如torch.cuda.FloatTensor),八种数据类型如下(左一列是Python中的元素数据类型,中间一列是pytorch中tensor的元素数据类型,最后一列是pytorch中tensor的数据类型,有点绕hhh):

Data type dtype Tensor types
32-bit floating point torch.float32 or torch.float torch.*.FloatTensor
64-bit floating point torch.float64 or torch.double torch.*.DoubleTensor
16-bit floating point torch.float16 or torch.half torch.*.HalfTensor
8-bit integer (unsigned) torch.uint8 torch.*.ByteTensor
8-bit integer (signed) torch.int8 torch.*.CharTensor
16-bit integer (signed) torch.int16 or torch.short torch.*.ShortTensor
32-bit integer (signed) torch.int32 or torch.int torch.*.IntTensor
64-bit integer (signed) torch.int64 or torch.long torch.*.LongTensor

注意:type()是方法,它返回的是数据本身的数据类型(右侧那列),detype是属性,返回的是其中的数据元素的数据类型(中间一列),在pytorch它两是有一一对应关系的,比如:

x = torch.Tensor([1, 2])
print(x.dtype) # 返回x中的元素的数据类型
print(x.type()) # 返回x的数据类型#
torch.float32 #tensor中的元素的数据类型
torch.FloatTensor #tensor的数据类型

需要特别注意tensor的类型,特别是在默认(不指定dtype)情况下创建的类型,只有相同数据类型的tensor才能做运算

1.2 torch.device

torch.device 属性标识了torch.Tensor对象所存储在的设备名称,包含了两种类型 :cpu和cuda ,表示是存在CPU还是GPU,如果是GPU还可以指定具体的卡号,如果没有指定设备编号,则默认将对象存储于current_device()当前设备中。

torch.device 对象支持使用字符串或者字符串加设备编号这两种方式来创建:

cuda0= torch.device('cuda:0')
cuda1 = torch.device('cuda', 1)
cudax = torch.device('cuda') # 默认当前设备中
cpu = torch.device('cpu')

但通常我们不会用得这么麻烦,都是直接用字符串来代替,比如:

cuda1 = torch.device('cuda:1')
a = torch.randn((2,3), device=cuda1)
b = torch.randn((2,3), device='cuda:1')

1.3 torch.size

torch.size标识了tensor的形状,虽然是pytorch中的类,但本质上是tuple,获取tensor的size可以通过 tensor.shape或tensor.size(),一个是属性一个是方法,结果都一样。torch.size用法和tuple一样,比如:

a = torch.randn(size=(1,2))
b = a.shape
c = a.size()
print(b, type(b), b[0])
print(c, type(c), c[1])
#
torch.Size([1, 2]) <class 'torch.Size'> 1
torch.Size([1, 2]) <class 'torch.Size'> 2

2. Tensor创建

在PyTorch中创建一个tensor从宏观上看主要有2种:一种是使用torch提供的方法,一种是实例化Tensor类

2.1 torch方法创建

为了满足多样化的需求,torch提供了许多方法来创建tensor。

  • 大多方法都有size、dtype、device、requires_grad(默认False)几个参数
  • 创建浮点数的方法默认(不设置dtype)情况下都是torch.float32(torch.FloatTensor),创建整形的方法默认是torch.int64(torch.LongTensor),这个需要特别注意!!不同类型的tensor不能作运算
  • size就是元组,可以通过tensor.size()得到,创建一维tensor最好写成(m, ),要是写m有些方法会报错*

torch方法创建tensor主要分三类:
1.按照特定规则创建,比如随机和特定数值

常见的随机创建方法有:

torch.randn(size):均值为0,方差为1的正态分布N(0,1)torch.rand(size):产生[0,1]均匀分布的数据torch.randint(low = 0, high, size):随机生成整数值tensor,范围 [min,max):左闭右开每种方法都有一个like(input)方法,表示创建同input一样size的tensor,比如torch.rand_like(input)

常见的特定数值创建方法有:

 torch.arange(start, end, step) :生成一维的**整形**等差数列tensor, [start,end) 左闭右开,间隔为step(公差)默认为1torch.linspace(start, end, steps) : 生成一维的**浮点**等差数列tensor,包括end, steps 是点的个数,注意和arrange的区别torch.full(size, fill_value) :使用相同元素创建tensor,注意默认类型是**torch.float32**torch.zeros(size):全0tensor,类似还有全1, 注意默认类型是**torch.float32**

2.从python 对象中创建

torch.tensor(data, dtype):data可以是list、tuple,如果data的元素是浮点数,
默认创建出来的就是torch.float32(torch.FloatTensor),如果是整形,默认是torch.int64(torch.LongTensor)

3.从numpy对象中创建

这个好像用的少一点

torch.from_numpy(x):x是ndarry对象,创建出来的tensor类型和x保持一致

2.2 实例化torch.Tensor创建

另一种方法是通过实例化Tensor类来创建tensor,结果都是一样的。因为tensor有不同的数据类型,所以Tensor也有为不同的类,可以指定特定类型的类来创建特定类型的tensor,比如:

torch.FloatTensor(data):data可以是list、tuple,它的另一个别名是torch.Tensor() 这个更常用些类似的还有:
torch.LongTensor()
torch.DoubleTensor()
torch.IntTensor()
…
其它类型见八中数据类型表

3. 类型转化

上面我们讲了tensor对象它有device属性,这区分了它是存在GPU还是CPU,有dtype属性,这区分了它的元素是什么类型,也有从不同对象中(Python、numpy)创建而来的,这就涉及一个这些不同对象之间的转化问题。

1.数据类型转化
在tensor后加 .long(), .int(), .float(), .double()等即可,比如

a = torch.randn((4,), dtype=torch.float)
b = a.int()
c = a.double()
print( a.dtype)
print(b.dtype)
print(c.dtype)
#
torch.float32
torch.int32
torch.float64

2.数据存储位置转换
CPU张量 ----> GPU张量:data.cuda()
GPU张量 ----> CPU张量:data.cpu()

3.与Python数据类型转换
Tensor ----> Python list,使用data.tolist(),返回shape相同的嵌套的list
如果data是一个一维的而且只有一个数(1, ),可以用data.item()来得到Python单个数据

4.与numpy数据类型转换
Tensor---->Numpy : data.numpy()
Numpy ---->Tensor:torch.from_numpy(data)

如果对你有帮助,请点个赞:-D

Pytorch tensor基础知识相关推荐

  1. pytorch基础知识+构建LeNet对Cifar10进行训练+PyTorch-OpCounter统计模型大小和参数量+模型存储与调用

    整个环境的配置请参考我另一篇博客.ubuntu安装python3.5+pycharm+anaconda+opencv+docker+nvidia-docker+tensorflow+pytorch+C ...

  2. PyTorch学习笔记(二):PyTorch简介与基础知识

    往期学习资料推荐: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 本系列目录: PyTorch学习笔记(一):PyTorch环境安 ...

  3. Datawhale-深入浅出pytorch简介安装和基础知识

    目录 1.1 PyTorch简介 1.1.1 PyTorch的介绍 1.1.2 PyTorch的发展 1.1.3 PyTorch的优势 1.2 PyTorch的安装 1.2.1 Anaconda的安装 ...

  4. 【Datawhale 组队学习Pytorch】Task01 Pytorch安装和基础知识

    [项目简介] PyTorch是利用深度学习进行数据科学研究的重要工具,在灵活性.可读性和性能上都具备相当的优势,近年来已成为学术界实现深度学习算法最常用的框架.考虑到PyTorch的学习兼具理论储备和 ...

  5. 深入浅出Pytorch:02 PyTorch基础知识

    深入浅出Pytorch 02 PyTorch基础知识 内容属性:深度学习(实践)专题 航路开辟者:李嘉骐.牛志康.刘洋.陈安东 领航员:叶志雄 航海士:李嘉骐.牛志康.刘洋.陈安东 开源内容:http ...

  6. 第02章 PyTorch基础知识

    文章目录 第02章 Pytorch基础知识 2.1 张量 2.2 自动求导 2.3 并行计算简介 2.3.1 为什么要做并行计算 2.3.2 CUDA是个啥 2.3.3 做并行的方法 补充:通过股票数 ...

  7. [源码解析] PyTorch 流水线并行实现 (1)--基础知识

    [源码解析] PyTorch 流水线并行实现 (1)–基础知识 文章目录 [源码解析] PyTorch 流水线并行实现 (1)--基础知识 0x00 摘要 0x01 历史 1.1 GPipe 1.2 ...

  8. 【Pytorch神经网络理论篇】 25 基于谱域图神经网络GNN:基础知识+GNN功能+矩阵基础+图卷积神经网络+拉普拉斯矩阵

    图神经网络(Graph Neural Network,GNN)是一类能够从图结构数据中学习特征规律的神经网络,是解决图结构数据(非欧氏空间数据)机器学习问题的最重要的技术. 1 图神经网络的基础知识 ...

  9. Pytorch基础知识(15)基于PyTorch的多标签图像分类

    早在 2012 年,神经网络就首次赢得了 ImageNet 大规模视觉识别挑战.Alex Krizhevsky,Ilya Sutskever 和 Geoffrey Hinton 彻底改变了图像分类领域 ...

  10. 一文读懂PyTorch张量基础(附代码)

    作者:Matthew Mayo, KDnuggets 翻译:和中华 校对:丁楠雅 本文约1000字,建议阅读5分钟. 本文介绍了PyTorch Tensor最基础的知识以及如何跟Numpy的ndarr ...

最新文章

  1. pr如何处理音效_Pr基础全通关:从0到1,进阶剪辑大神
  2. 图像处理学习--前篇--像素相关
  3. Error No matching provisioning profiles found
  4. Chrome浏览器插件Tab Groups Extension使用方法
  5. JQuery如何与数据库交互
  6. 苹果推送iOS13.3正式版 联通用户有惊喜!
  7. 漫画算法python版下载_用 Python 下载漫画
  8. px4讲解(一)历史起源
  9. windows server 2003 远程拨号服务器
  10. 三星S3 939/9300 android 4.3 如何打开开发者模式
  11. 笔记本电脑进入BIOS设置快捷键大全
  12. wintc下为什么有getchar()但屏幕却还是没有输出?
  13. 阿里80亿贷款细节曝光 并购背后图谋大数据
  14. iOS高性能Model转换框架----YYModel学习
  15. java 写作速度_GMAT写作提高速度4条实用经验分享
  16. ifm电感式传感器IE5238
  17. 攻防世界 web NaNNaNNaNNaN-Batman
  18. Android性能优化之图片压缩综合解决方案
  19. Python·.·.print()函数格式化输出-超详解
  20. 负片与水印效果(OpenCV)

热门文章

  1. 豫西大数据项目_大数据AI+智能雷达,上海公寓项目选址
  2. Linux 执行 Shell脚本报错,“syntax error: unexpected end of file” 原因及处理
  3. 什么是Prettier?
  4. 指数函数在c语言中怎么输入,指数函数如果想得到整型的值怎样做?
  5. 微信浪漫告白小程序java_厉害了,微信小程序可以这样表白,还怕他(她)拒绝你?...
  6. 如何删除git本地分支
  7. maven工程打包老是报错_Maven 项目打包及启动时的报错解决
  8. 计算机网络作用范围网络分为,【填空题】从不同作用范围分类,计算机网络可以分为广域网、_______、_________、_________四种...
  9. c语言迷宫算法坐标怎么定义,[原创]递归随机迷宫生成算法详解
  10. springboot源码解析autoconfigure之AopAutoConfiguration