tensor数据类型

Tensor在使用时可以有不同的数据类型,官方给出了 7种CPU Tensor类型与8种GPU Tensor类型。16位半精度浮点是专为GPU模型设计的,以尽可能地节省GPU显存占用,但这种节省显存空间的方式也缩小了所能表达数据的大小。PyTorch中默认的数据类型是 torch.FloatTensor,即torch.Tensor等同于torch.FloatTensor。

PyTorch可以通过set_default_tensor_type函数设置默认使用的Tensor 类型,在局部使用完后如果需要其他类型,则还需要重新设置回所需的 类型。

torch.set_default_tensor_type('torch.DoubleTensor')

类型转换

对于Tensor之间的类型转换,可以通过type(new_type)、type_as()、 int()等多种方式进行操作,尤其是type_as()函数,在后续的模型学习中 可以看到,我们想保持Tensor之间的类型一致,只需要使用type_as()即 可,并不需要明确具体是哪种类型。下面分别举例讲解这几种方法的使 用方式。

# 创建新Tensor,默认类型为torch.FloatTensor
>>> a = torch.Tensor(2, 2)
>>> a tensor(1.00000e-36 * [[-4.0315, 0.0000], [ 0.0700, 0.0000]])
# 使用int()、float()、double()等直接进行数据类型转换
>>> b = a.double()
>>> b tensor(1.00000e-36 * [[-4.0315, 0.0000], [ 0.0700, 0.0000]], dtype=torch.float64)
# 使用type()函数 >>> c = a.type(torch.DoubleTensor)
>>> c
tensor(1.00000e-36 * [[-4.0315, 0.0000], [ 0.0700, 0.0000]], dtype=torch.float64)
# 使用type_as()函数
>>> d = a.type_as(b)
>>> dtensor(1.00000e-36 *[[-4.0315, 0.0000],[ 0.0700, 0.0000]], dtype=torch.float64)

Tensor的创建与维度查看

Tensor有多种创建方法,如基础的构造函数Tensor(),还有多种与 NumPy十分类似的方法,如ones()、eye()、zeros()和randn()等,图2.1列 举了常见的Tensor创建方法。

# 最基础的Tensor()函数创建方法,参数为Tensor的每一维大小
>>> a=torch.Tensor(2,2)
>>> a
tensor(1.00000e-18 *[[-8.2390, 0.0000],[ 0.0000, 0.0000]])
>>> b = torch.DoubleTensor(2,2)
>>> b
tensor(1.00000e-310 *[[ 0.0000, 0.0000], [ 6.9452, 0.0000]], dtype=torch.float64)
# 使用Python的list序列进行创建
>>> c = torch.Tensor([[1, 2], [3, 4]])
>>> c
tensor([[ 1., 2.], [ 3., 4.]])
# 使用zeros()函数,所有元素均为0
>>> d = torch.zeros(2, 2)
>>> d
tensor([[ 0., 0.], [ 0., 0.]])
# 使用ones()函数,所有元素均为1 >>> e = torch.ones(2, 2)
>>> e
tensor([[ 1., 1.], [ 1., 1.]])
# 使用eye()函数,对角线元素为1,不要求行列数相同,生成二维矩阵 >>> f = torch.eye(2, 2)
>>> ftensor([[ 1., 0.], [ 0., 1.]])
# 使用randn()函数,生成随机数矩阵
>>> g = torch.randn(2, 2)
>>> gtensor([[-0.3979, 0.2728], [ 1.4558, -0.4451]])
# 使用arange(start, end, step)函数,表示从start到end,间距为step,一维向量
>>> h = torch.arange(1, 6, 2)
>>> htensor([ 1., 3., 5.])
# 使用linspace(start, end, steps)函数,表示从start到end,一共steps份,一维向量
>>> i = torch.linspace(1, 6, 2)
>>> itensor([ 1., 6.]) .
# 使用randperm(num)函数,生成长度为num的随机排列向量
>>> j = torch.randperm(4)
>>> j
tensor([ 1, 2, 0, 3])
# PyTorch 0.4中增加了torch.tensor()方法,参数可以为Python的list、NumPy的ndarray等
>>> k = torch.tensor([1, 2, 3]) tensor([ 1, 2, 3])

对于Tensor的维度,可使用Tensor.shape或者size()函数查看每一维 的大小,两者等价。

>>> a=torch.randn(2,2)
# 使用shape查看Tensor维度
>>> a.shape
torch.Size([2, 2])
# 使用size()函数查看Tensor维度 torch.Size([2, 2])>>> a.size() torch.Size([2, 2])

查看Tensor中的元素总个数,可使用Tensor.numel()或者 Tensor.nelement()函数,两者等价。

# 查看Tensor中总的元素个数
>> a.numel()
4
>>> a.nelement()
4

Tensor的组合与分块

组合与分块是将Tensor相互叠加或者分开,是十分常用的两个功 能,PyTorch提供了多种操作函数,如图2.2所示。

组合操作是指将不同的Tensor叠加起来,主要有torch.cat()和 torch.stack()两个函数。cat即concatenate的意思,是指沿着已有的数据的 某一维度进行拼接,操作后数据的总维数不变,在进行拼接时,除了拼 接的维度之外,其他维度必须相同。而torch.stack()函数指新增维度,并 按照指定的维度进行叠加,具体示例如下:

# 创建两个2×2的Tensor
>>> a=torch.Tensor([[1,2],[3,4]])
>>> a tensor([[ 1., 2.], [ 3., 4.]])
>>> b = torch.Tensor([[5,6], [7,8]])
>>> btensor([[ 5., 6.], [ 7., 8.]])
# 以第一维进行拼接,则变成4×2的矩阵
>>> torch.cat([a,b], 0)
tensor([[ 1., 2.], [ 3., 4.], [ 5., 6.], [ 7., 8.]]) # 以第二维进行拼接,则变成24的矩阵
>>> torch.cat([a,b], 1)
tensor([[ 1., 2., 5., 6.], [ 3., 4., 7., 8.]])
# 以第0维进行stack,叠加的基本单位为序列本身,即a与b,因此输出[a, b],输出维度为2×2×2
>>> torch.stack([a,b], 0)
tensor([[[ 1., 2.], [ 3., 4.]], [[ 5., 6.], [ 7., 8.]]]) # 以第1维进行stack,叠加的基本单位为每一行,输出维度为2×2×2
>>> torch.stack([a,b], 1)
tensor([[[ 1., 2.], [ 5., 6.]], [[ 3., 4.], [ 7., 8.]]]) # 以第2维进行stack,叠加的基本单位为每一行的每一个元素,输出维度为2×2×2
>>> torch.stack([a,b], 2)
tensor([[[ 1., 5.], [ 2., 6.]], [[ 3., 7.], [ 4., 8.]]])

分块则是与组合相反的操作,指将Tensor分割成不同的子Tensor, 主要有torch.chunk()与torch.split()两个函数,前者需要指定分块的数量, 而后者则需要指定每一块的大小,以整型或者list来表示。具体示例如 下:

>>> a=torch.Tensor([[1,2,3],[4,5,6]])
>>> a tensor([[ 1., 2., 3.], [ 4., 5., 6.]])
# 使用chunk,沿着第0维进行分块,一共分两块,因此分割成两个1×3的Tensor >>> torch.chunk(a, 2, 0)
(tensor([[ 1., 2., 3.]]), tensor([[ 4., 5., 6.]]))
# 沿着第1维进行分块,因此分割成两个Tensor,当不能整除时,最后一个的维数会小于前面的因此第一个Tensor为2×2,第二个为2×1 >>> torch.chunk(a, 2, 1)
(tensor([[ 1., 2.], [ 4., 5.]]), tensor([[ 3.], [ 6.]])) # 使用split,沿着第0维分块,每一块维度为2,由于第一维维度总共为2,因此相当于没有分割
>>> torch.split(a, 2, 0)
(tensor([[ 1., 2., 3.], [ 4., 5., 6.]]),)
# 沿着第1维分块,每一块维度为2,因此第一个Tensor为2×2,第二个为2×1
>>>> torch.split(a, 2, 1)
(tensor([[ 1., 2.], [ 4., 5.]]), tensor([[ 3.], [ 6.]]))
# split也可以根据输入的list进行自动分块,list中的元素代表了每一个块占的维度
>>> torch.split(a, [1,2], 1)
(tensor([[ 1.], [ 4.]]), tensor([[ 2., 3.], [ 5., 6.]]))

动手学pytorch之tensor数据(一)相关推荐

  1. pytorch默认初始化_小白学PyTorch | 9 tensor数据结构与存储结构

    [机器学习炼丹术]的学习笔记分享<> 小白学PyTorch | 8 实战之MNIST小试牛刀 小白学PyTorch | 7 最新版本torchvision.transforms常用API翻 ...

  2. 动手学PyTorch | (5) Softmax回归实验

    目录 1. 图像分类数据集(Fashion-Mnist) 2. Softmax回归从0开始实现 3. Softmax回归的简洁实现 1. 图像分类数据集(Fashion-Mnist) 在介绍softm ...

  3. 动手学pytorch之通俗易懂何为卷积-深度AI科普团队

    文章目录 简介 为什么要用卷积 卷积神经网络的由来 什么是卷积 定义 解释 卷积运算 信号分析 离散卷积例子:丢骰子 图像处理卷积操作 简介 为什么要用卷积 卷积操作是机器视觉,乃至整个深度学习的核心 ...

  4. 【TL第二期】动手学数据分析-第二章 数据预处理

    文章目录 第二章 第一节 数据清洗及特征处理 第二节 数据重构1 第三节 数据重构2 第四节 数据可视化 第二章 第一节 数据清洗及特征处理 数据清洗:对于原始数据中的缺失值.异常值进行处理.相当于数 ...

  5. 动手学pytorch笔记整理12

    conv-layer 二维卷积层 二维互相关运算 特征图和感受野 填充和步幅 填充 步幅 二维卷积层 卷积神经网络(convolutional neural network)是含有卷积层(convol ...

  6. 【TL第二期】动手学数据分析-第一章 数据基本操作

    文章目录 第一章 第一节 数据载入与初步观察 0 导库 1 载入数据 2 查看数据基本信息 第二节 pandas基础 1 数据类型DataFrame 和 Series 2 对文件数据的基本操作 3 数 ...

  7. 动手学PyTorch | (35) 长短期记忆(LSTM)

    本节将介绍另一种常⽤的⻔控循环神经网络:长短期记忆(long short-term memory,LSTM).它⽐⻔控循环单元的结构稍微复杂一点. 目录 1. 长短期记忆 2. 读取数据集 3. 从0 ...

  8. 动手学PyTorch | (41) Adagrad算法

    在之前介绍过的优化算法中,⽬标函数⾃变量的每一个元素在相同时间步都使用同一个学习率来⾃我迭代.举个例子,假设⽬标函数为f,⾃变量为一个二维向量,该向量中每一个元素在迭代时都使⽤相同的学习率.例如,在学 ...

  9. 【小白学PyTorch】扩展之Tensorflow2.0 | 21 Keras的API详解(上)卷积、激活、初始化、正则...

    [机器学习炼丹术]的学习笔记分享 <<小白学PyTorch>> 扩展之Tensorflow2.0 | 20 TF2的eager模式与求导 扩展之Tensorflow2.0 | ...

最新文章

  1. 【每日一算法】1比特与2比特字符
  2. properties 配置回车_非常全面的讲解SpringCloud中Zuul网关原理及其配置,看它就够了!...
  3. IT团队之非正式沟通
  4. 【哈理工实验二】HTML+CSS3 旋转齿轮特效
  5. 小学认识计算机硬件ppt,认识计算机硬件课件.ppt
  6. 终于,腾讯也要造车了
  7. 一体机硬盘被格式化了的资料恢复法子
  8. python基础教程-Python基础
  9. python windows api截图_Winapi快速截图并打开
  10. 三分钟介绍什么是前端开发框架
  11. 【车辆识别】基于卷积神经网络yolov3识别车辆和车辆速度附matlab代码
  12. 通过telnet命令使用SMTP、POP3协议收发邮件(以QQ邮箱为例)
  13. 微信公众号三方平台开发【代微信公众号接收消息事件并响应】
  14. 学php收获与体会,实习心得体会及收获
  15. 英语字母c的语言教案,[小班英语教案认识字母]幼儿园小班英语教案:字母C.doc...
  16. Android 设置壁纸流程
  17. DOS发包攻击软件下载(需Python环境)
  18. 将office 的文件,word,xlsx,ppt,txt 转成pdf 供预览
  19. 解密Blob加密的src拼接的url视频资源
  20. 【python基础】——python 复数运算

热门文章

  1. python wms_webGIS实践:4_2_python django整合geoserver wms服务
  2. 3蛋白wb_WB常见问题原因分析及解决办法
  3. java jni调用dll_浅谈JNI的使用--java调用dll(原创)
  4. android异步加载视频缩略图,swift-如何将视频URL的缩略图异步加载到tableview列表中...
  5. 查找树的指定层级_阿里面试,问了B+树,这个回答让我通过了
  6. 读取图像矩阵维度必须一致_深度学习在放射治疗中的应用——工具篇(二)矩阵基本操作...
  7. 如何抓取一个网站的分页_如何设计一个吸引人的网站
  8. Django框架 day02
  9. odoo中页面跳转相关
  10. node.js 端口号被占用解决方法