速成pytorch学习——2天
Pytorch的基本数据结构是张量Tensor。张量即多维数组。Pytorch的张量和numpy中的array很类似。
本节我们主要介绍张量的数据类型、张量的维度、张量的尺寸、张量和numpy数组等基本概念。
一,张量的数据类型
张量的数据类型和numpy.array基本一一对应,但是不支持str类型。
包括:
torch.float64(torch.double),
torch.float32(torch.float),
torch.float16,
torch.int64(torch.long),
torch.int32(torch.int),
torch.int16,
torch.int8,
torch.uint8,
torch.bool
一般神经网络建模使用的都是torch.float32类型。
i = torch.IntTensor(1);print(i,i.dtype)
x = torch.Tensor(np.array(2.0));print(x,x.dtype) #等价于torch.FloatTensor
b = torch.BoolTensor(np.array([1,0,2,0])); print(b,b.dtype)
# 不同类型进行转换i = torch.tensor(1); print(i,i.dtype)
x = i.float(); print(x,x.dtype) #调用 float方法转换成浮点类型
y = i.type(torch.float); print(y,y.dtype) #使用type函数转换成浮点类型
z = i.type_as(x);print(z,z.dtype) #使用type_as方法转换成某个Tensor相同类型
二,张量的维度
不同类型的数据可以用不同维度(dimension)的张量来表示。
标量为0维张量,向量为1维张量,矩阵为2维张量。
彩色图像有rgb三个通道,可以表示为3维张量。
视频还有时间维,可以表示为4维张量。
可以简单地总结为:有几层中括号,就是多少维的张量。
tensor3 = torch.tensor([[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]) # 3维张量
print(tensor3)
print(tensor3.dim())
三,张量的尺寸
可以使用 shape属性或者 size()方法查看张量在每一维的长度.
可以使用view方法改变张量的尺寸。
如果view方法改变尺寸失败,可以使用reshape方法.
scalar = torch.tensor(True)
print(scalar.size())
print(vector.shape)
# 使用view可以改变张量尺寸vector = torch.arange(0,12)
print(vector)
print(vector.shape)matrix34 = vector.view(3,4)
print(matrix34)
print(matrix34.shape)matrix43 = vector.view(4,-1) #-1表示该位置长度由程序自动推断
print(matrix43)
print(matrix43.shape)
# 有些操作会让张量存储结构扭曲,直接使用view会失败,可以用reshape方法matrix26 = torch.arange(0,12).view(2,6)
print(matrix26)
print(matrix26.shape)# 转置操作让张量存储结构扭曲
matrix62 = matrix26.t()
print(matrix62.is_contiguous())# 直接使用view方法会失败,可以使用reshape方法
#matrix34 = matrix62.view(3,4) #error!
matrix34 = matrix62.reshape(3,4) #等价于matrix34 = matrix62.contiguous().view(3,4)
print(matrix34)
四,张量和numpy数组
可以用numpy方法从Tensor得到numpy数组,也可以用torch.from_numpy从numpy数组得到Tensor。
这两种方法关联的Tensor和numpy数组是共享数据内存的。
如果改变其中一个,另外一个的值也会发生改变。
如果有需要,可以用张量的clone方法拷贝张量,中断这种关联。
此外,还可以使用item方法从标量张量得到对应的Python数值。
使用tolist方法从张量得到对应的Python数值列表。
import numpy as np
import torch
#torch.from_numpy函数从numpy数组得到Tensorarr = np.zeros(3)
tensor = torch.from_numpy(arr)
print("before add 1:")
print(arr)
print(tensor)print("\nafter add 1:")
np.add(arr,1, out = arr) #给 arr增加1,tensor也随之改变
print(arr)
print(tensor)
# numpy方法从Tensor得到numpy数组tensor = torch.zeros(3)
arr = tensor.numpy()
print("before add 1:")
print(tensor)
print(arr)print("\nafter add 1:")#使用带下划线的方法表示计算结果会返回给调用 张量
tensor.add_(1) #给 tensor增加1,arr也随之改变
#或: torch.add(tensor,1,out = tensor)
print(tensor)
print(arr)
速成pytorch学习——2天相关推荐
- 速成pytorch学习——11天. 使用GPU训练模型
深度学习的训练过程常常非常耗时,一个模型训练几个小时是家常便饭,训练几天也是常有的事情,有时候甚至要训练几十天. 训练过程的耗时主要来自于两个部分,一部分来自数据准备,另一部分来自参数迭代. 当数据准 ...
- 速成pytorch学习——7天模型层layers
深度学习模型一般由各种模型层组合而成. torch.nn中内置了非常丰富的各种模型层.它们都属于nn.Module的子类,具备参数管理功能. 例如: nn.Linear, nn.Flatten, nn ...
- 速成pytorch学习——5天nn.functional 和 nn.Module
一,nn.functional 和 nn.Module 前面我们介绍了Pytorch的张量的结构操作和数学运算中的一些常用API. 利用这些张量的API我们可以构建出神经网络相关的组件(如激活函数,模 ...
- 速成pytorch学习——3天自动微分机制
神经网络通常依赖反向传播求梯度来更新网络参数,求梯度过程通常是一件非常复杂而容易出错的事情. 而深度学习框架可以帮助我们自动地完成这种求梯度运算. Pytorch一般通过反向传播 backward 方 ...
- 速成pytorch学习——1天
一.Pytorch的建模流程 使用Pytorch实现神经网络模型的一般流程包括: 1,准备数据 2,定义模型 3,训练模型 4,评估模型 5,使用模型 6,保存模型. 对新手来说,其中最困难的部分实际 ...
- 速成pytorch学习——10天.训练模型的3种方法
Pytorch通常需要用户编写自定义训练循环,训练循环的代码风格因人而异. 有3类典型的训练循环代码风格:脚本形式训练循环,函数形式训练循环,类形式训练循环. 下面以minist数据集的分类模型的训练 ...
- 速成pytorch学习——8天损失函数
一般来说,监督学习的目标函数由损失函数和正则化项组成.(Objective = Loss + Regularization) Pytorch中的损失函数一般在训练模型时候指定. 注意Pytorch中内 ...
- 速成pytorch学习——6天Dataset和DataLoader
Pytorch通常使用Dataset和DataLoader这两个工具类来构建数据管道. Dataset定义了数据集的内容,它相当于一个类似列表的数据结构,具有确定的长度,能够用索引获取数据集中的元素. ...
- 速成pytorch学习——4天中阶API示范
使用Pytorch的中阶API实现线性回归模型和和DNN二分类模型. Pytorch的中阶API主要包括各种模型层,损失函数,优化器,数据管道等等. 一,线性回归模型 1,准备数据 import nu ...
最新文章
- 提高oracle查询效率
- 【开发环境】Ubuntu 中使用 VSCode 开发 C/C++ ④ ( 创建 tasks.json 编译器构建配置文件 | tasks.json 编译器构建配置文件分析 )
- python3教程廖雪峰云-python3基础教程廖雪峰云(如何规划研究生三年最后成为谷歌软件工程师)...
- 电脑选购:看完才明白,一体机和台式机电脑哪个好?
- Eclipse控制项目的访问名称
- 如何在Java中读取CSV文件-Iterator和Decorator的案例研究
- 定时器中断实验 编写程序使定时器0或者定时器1工作在方式1,定时500ms使两位数码管从00、01、02……98、99每间隔500ms加1显示。
- libevent源码深度剖析四
- 报告正在使用哪些Reporting Services数据集字段?
- 叶子结点和分支节点_教你玩转二叉查找树的结点插入操作
- Oracle的字符串转换成二进制,将二进制字符串解析为文本/字符
- 未找到依赖项 ‘org.apache.spark:spark-hive_2.11:2.4.5‘
- windows下使用批处理设置环境变量
- 旅游网站php源码,基于ThinkPHP框架开发的青春旅行旅游门户整站PHP源码
- 手机淘宝列表页面 的js调用展示
- 【更新】Excel控件Spire.XLS for .NET V7.12.90发布 | 支持向工作表添加形状
- openid与商户appid不匹配
- flappy+bird+c语言程序,C语言版flappy_bird实现
- 在 Ubuntu 16.04上安装 vsFTPd
- 尚学堂-HTML-CSS(基础)的学习记录