Pytorch的基本数据结构是张量Tensor。张量即多维数组。Pytorch的张量和numpy中的array很类似。

本节我们主要介绍张量的数据类型、张量的维度、张量的尺寸、张量和numpy数组等基本概念。

一,张量的数据类型

张量的数据类型和numpy.array基本一一对应,但是不支持str类型。

包括:

torch.float64(torch.double),

torch.float32(torch.float),

torch.float16,

torch.int64(torch.long),

torch.int32(torch.int),

torch.int16,

torch.int8,

torch.uint8,

torch.bool

一般神经网络建模使用的都是torch.float32类型。

i = torch.IntTensor(1);print(i,i.dtype)
x = torch.Tensor(np.array(2.0));print(x,x.dtype) #等价于torch.FloatTensor
b = torch.BoolTensor(np.array([1,0,2,0])); print(b,b.dtype)
# 不同类型进行转换i = torch.tensor(1); print(i,i.dtype)
x = i.float(); print(x,x.dtype) #调用 float方法转换成浮点类型
y = i.type(torch.float); print(y,y.dtype) #使用type函数转换成浮点类型
z = i.type_as(x);print(z,z.dtype) #使用type_as方法转换成某个Tensor相同类型

二,张量的维度

不同类型的数据可以用不同维度(dimension)的张量来表示。

标量为0维张量,向量为1维张量,矩阵为2维张量。

彩色图像有rgb三个通道,可以表示为3维张量。

视频还有时间维,可以表示为4维张量。

可以简单地总结为:有几层中括号,就是多少维的张量。

tensor3 = torch.tensor([[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]])  # 3维张量
print(tensor3)
print(tensor3.dim())

三,张量的尺寸

可以使用 shape属性或者 size()方法查看张量在每一维的长度.

可以使用view方法改变张量的尺寸。

如果view方法改变尺寸失败,可以使用reshape方法.

scalar = torch.tensor(True)
print(scalar.size())
print(vector.shape)
# 使用view可以改变张量尺寸vector = torch.arange(0,12)
print(vector)
print(vector.shape)matrix34 = vector.view(3,4)
print(matrix34)
print(matrix34.shape)matrix43 = vector.view(4,-1) #-1表示该位置长度由程序自动推断
print(matrix43)
print(matrix43.shape)
# 有些操作会让张量存储结构扭曲,直接使用view会失败,可以用reshape方法matrix26 = torch.arange(0,12).view(2,6)
print(matrix26)
print(matrix26.shape)# 转置操作让张量存储结构扭曲
matrix62 = matrix26.t()
print(matrix62.is_contiguous())# 直接使用view方法会失败,可以使用reshape方法
#matrix34 = matrix62.view(3,4) #error!
matrix34 = matrix62.reshape(3,4) #等价于matrix34 = matrix62.contiguous().view(3,4)
print(matrix34)

四,张量和numpy数组

可以用numpy方法从Tensor得到numpy数组,也可以用torch.from_numpy从numpy数组得到Tensor。

两种方法关联的Tensor和numpy数组是共享数据内存的

如果改变其中一个,另外一个的值也会发生改变。

如果有需要,可以用张量的clone方法拷贝张量,中断这种关联

此外,还可以使用item方法从标量张量得到对应的Python数值

使用tolist方法从张量得到对应的Python数值列表。

import numpy as np
import torch
#torch.from_numpy函数从numpy数组得到Tensorarr = np.zeros(3)
tensor = torch.from_numpy(arr)
print("before add 1:")
print(arr)
print(tensor)print("\nafter add 1:")
np.add(arr,1, out = arr) #给 arr增加1,tensor也随之改变
print(arr)
print(tensor)

# numpy方法从Tensor得到numpy数组tensor = torch.zeros(3)
arr = tensor.numpy()
print("before add 1:")
print(tensor)
print(arr)print("\nafter add 1:")#使用带下划线的方法表示计算结果会返回给调用 张量
tensor.add_(1) #给 tensor增加1,arr也随之改变
#或: torch.add(tensor,1,out = tensor)
print(tensor)
print(arr)

速成pytorch学习——2天相关推荐

  1. 速成pytorch学习——11天. 使用GPU训练模型

    深度学习的训练过程常常非常耗时,一个模型训练几个小时是家常便饭,训练几天也是常有的事情,有时候甚至要训练几十天. 训练过程的耗时主要来自于两个部分,一部分来自数据准备,另一部分来自参数迭代. 当数据准 ...

  2. 速成pytorch学习——7天模型层layers

    深度学习模型一般由各种模型层组合而成. torch.nn中内置了非常丰富的各种模型层.它们都属于nn.Module的子类,具备参数管理功能. 例如: nn.Linear, nn.Flatten, nn ...

  3. 速成pytorch学习——5天nn.functional 和 nn.Module

    一,nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API. 利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模 ...

  4. 速成pytorch学习——3天自动微分机制

    神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情. 而深度学习框架可以帮助我们自动地完成这种求梯度运算. Pytorch一般通过反向传播 backward 方 ...

  5. 速成pytorch学习——1天

    一.Pytorch的建模流程 使用Pytorch实现神经网络模型的一般流程包括: 1,准备数据 2,定义模型 3,训练模型 4,评估模型 5,使用模型 6,保存模型. 对新手来说,其中最困难的部分实际 ...

  6. 速成pytorch学习——10天.训练模型的3种方法

    Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异. 有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环. 下面以minist数据集的分类模型的训练 ...

  7. 速成pytorch学习——8天损失函数

    一般来说,监督学习的目标函数由损失函数和正则化项组成.(Objective = Loss + Regularization) Pytorch中的损失函数一般在训练模型时候指定. 注意Pytorch中内 ...

  8. 速成pytorch学习——6天Dataset和DataLoader

    Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道. Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素. ...

  9. 速成pytorch学习——4天中阶API示范

    使用Pytorch的中阶API实现线性回归模型和和DNN二分类模型. Pytorch的中阶API主要包括各种模型层,损失函数,优化器,数据管道等等. 一,线性回归模型 1,准备数据 import nu ...

最新文章

  1. 提高oracle查询效率
  2. 【开发环境】Ubuntu 中使用 VSCode 开发 C/C++ ④ ( 创建 tasks.json 编译器构建配置文件 | tasks.json 编译器构建配置文件分析 )
  3. python3教程廖雪峰云-python3基础教程廖雪峰云(如何规划研究生三年最后成为谷歌软件工程师)...
  4. 电脑选购:看完才明白,一体机和台式机电脑哪个好?
  5. Eclipse控制项目的访问名称
  6. 如何在Java中读取CSV文件-Iterator和Decorator的案例研究
  7. 定时器中断实验 编写程序使定时器0或者定时器1工作在方式1,定时500ms使两位数码管从00、01、02……98、99每间隔500ms加1显示。
  8. libevent源码深度剖析四
  9. 报告正在使用哪些Reporting Services数据集字段?
  10. 叶子结点和分支节点_教你玩转二叉查找树的结点插入操作
  11. Oracle的字符串转换成二进制,将二进制字符串解析为文本/字符
  12. 未找到依赖项 ‘org.apache.spark:spark-hive_2.11:2.4.5‘
  13. windows下使用批处理设置环境变量
  14. 旅游网站php源码,基于ThinkPHP框架开发的青春旅行旅游门户整站PHP源码
  15. 手机淘宝列表页面 的js调用展示
  16. 【更新】Excel控件Spire.XLS for .NET V7.12.90发布 | 支持向工作表添加形状
  17. openid与商户appid不匹配
  18. flappy+bird+c语言程序,C语言版flappy_bird实现
  19. 在 Ubuntu 16.04上安装 vsFTPd
  20. 尚学堂-HTML-CSS(基础)的学习记录

热门文章

  1. ChaiNext:ETH底部试探后反弹,测试1500关口
  2. Kava下一阶段Kava 5主网将于3月4日上线
  3. SAP License:SAP S/4HANA Cloud介绍
  4. SAP License:FI学习笔记
  5. 将一个js项目改造成vue项目
  6. HDU 3974 Assign the task(DFS序+线段树单点查询,区间修改)
  7. idea 中maven编译速度过慢的问题的解决
  8. 【ESP8266】发送HTTP请求
  9. 十一章--软件设计与实现
  10. js 去除字符串左右两边的空格