深度学习框架PyTorch:入门与实践 学习(二)
Tensor和Autograd------Tensor
Tensor
创建:
- Tensor(*size),创建时不会立马分配空间,使用时才会分配,而其他均创建时立马就分配。
- ones(*sizes)
- zeros(*sizes)
- eye(*sizes)对角线为1其他为0
- arange(s,e,step)从s到e步长为step
- linspace(s,e,steps)从s到e,均匀切分成steps份
- rand/randn(*sizes)均匀/标准整体分布
- normal(mean,std)/uniform(from,to)
- randperm(m)随机排列
利用List创建tensor
a = torch.Tensor([[1, 2, 3], [4, 5, 6]])
print(a)
b = a.tolist()
print(b)
print(a.size())
print(a.numel())#输出the num of element
# 创建一个和a一样大小的Tensor
c = torch.Tensor(a.size())
print(c.size())
# 查看c的形状
print(c.shape)
常用tensor操作:
- Tensor.view()可以调整tensor的形状,但不会修改自身的数据,两者共享内存
a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) print(a) b = a.view(3, 2) print(b)
# -1自动计算该维度的大小 c = a.view(-1, 6) print(c)
squeeze,unsqueeze减少维度增加维度
a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) b = a.view(3, 2) # 在1维增加一维 b = b.unsqueeze(1) print(b.size()) # 减去倒数第二维 c = b.squeeze(-2) print(c.size())
a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) print(a.size()) b = a.view(1, 1, 1, 2, 3) print(b.size()) c = b.squeeze(0) print(c.size()) # 把所有维度为1的都删去 c = c.squeeze() print(c.size())
- resize:调整size,可以修改tensor的尺寸,如果新尺寸大于原尺寸会分配新的内存空间,如果小于原尺寸,数据依旧会被保存。
a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) print(a) print(a.size()) b = a.resize_(1, 3) print(b.size()) print(b) b = a.resize_(3, 3) print(b.size()) print(b)
索引:
a = torch.Tensor([[1, 2, 3], [4, 5, 6]]) print(a) # 输出每个元素是否满足条件,满足为1否则为0 print(a > 1) # 输出满足条件的元素 print(a[a>1])
gather(input, dim, index):根据index在dim维选取数据,选取的大小跟index一样。dim=0时,out[i][j]=input[index[i][j]][j],dim=1时,out[i][j]=input[i][index[i][j]]
a = torch.arange(0, 16).view(4, 4) print(a) index = torch.LongTensor([[0, 1, 2, 3]]) b = a.gather(0, index) print(b) index = torch.LongTensor([[3], [2], [1], [0]]) c = a.gather(1, index) print(c)
scatter:gather的逆操作,
c = torch.zeros((4, 4)) c.scatter_(0, index, b) print(c.size()) print(c)
高级索引:
x[[1,2],[1,2],[2,0]]# x[1,1,2],x[2,2,0] x[[2,1,0],[0],[1]] #x[2,0,1],x[1,0,1],x[0,0,1]
线性回归
import torch as t
from matplotlib import pyplot as plt
from IPython import displayt.manual_seed(100)def get_fake_data(batch_size=8):x = t.rand(batch_size, 1) * 20y = x * 2 + (1 + t.randn(batch_size, 1)) * 3return x, yx, y = get_fake_data()
plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())w = t.rand(1, 1)
b = t.zeros(1, 1)
lr = 0.001for ii in range(20000):x, y = get_fake_data()y_pred = x.mul(w) + b.expand_as(y)loss = 0.5 * (y_pred - y) ** 2loss = loss.sum()dloss = 1dy_pred = dloss * (y_pred - y)dw = x * dy_preddb = dy_pred.sum()w.sub_((lr * dw).sum())b.sub_(lr * db)if ii % 1000 == 0:display.clear_output(wait=True)x = t.arange(0, 20).view(-1, 1)y = x.mul(w) + b.expand_as(x)plt.plot(x.numpy(), y.numpy())x2, y2 = get_fake_data(batch_size=20)plt.scatter(x2.numpy(), y2.numpy())plt.xlim(0, 20)plt.ylim(0, 41)plt.show()plt.pause(0.5)
Tensor 和Autograd------Autograd
Variable不支持部分Inplace 函数,因为这些函数会修改tensor自身,但在反向传播中,variable需要缓存原来的tensor计算梯度。
torch.autograd.grad(z, y)
输出z对y的梯度
- autograd根据用户对variable的操作构建计算图,对variable的操作抽象为function
- 由用户创建的节点称为叶子节点,叶子节点的grad_fn为None,叶子节点中需要求导的variable具有accumulateGrad,因为其梯度是累加的。
- variable默认requires_grad=false。当一个节点的rrquires_grad设置为true时,其他依赖它的节点的requires_grad均为true
- volatile=True,将所有依赖它的节点全部设置为vllatile=true,优先级比require_grad=True高,volatile的节点不会求导,也无法进行反向传播。
用Variable实现线性回归
import torch as t
from torch.autograd import Variable as V
from matplotlib import pyplot as plt
from IPython import displayt.manual_seed(1000)def get_fake_data(batch_size=16):x = t.rand(batch_size, 1) * 20y = x * 2 + (1 + t.randn(batch_size, 1)) * 3return x, yx, y = get_fake_data()
plt.scatter(x.squeeze().numpy(), y.squeeze().numpy())
w = V(t.rand(1, 1), requires_grad=True)
b = V(t.zeros(1, 1), requires_grad=True)
lr = 0.0001
for ii in range(8000):x, y = get_fake_data()x = V(x)y = V(y)y_pred = x.mul(w) + b.expand_as(y)loss = 0.5 * (y_pred - y) ** 2loss = loss.sum()loss.backward()dloss = 1w.data = w.data - lr * w.grad.datab.data = b.data - lr * b.grad.dataw.grad.data.zero_()b.grad.data.zero_()if ii % 1000 == 0:display.clear_output(wait=True)x = t.arange(0, 20).view(-1, 1)y = x.mul(w.data) + b.data.expand_as(x)plt.plot(x.numpy(), y.numpy())x2, y2 = get_fake_data(batch_size=20)plt.scatter(x2.numpy(), y2.numpy())plt.xlim(0, 20)plt.ylim(0, 41)plt.show()plt.pause(0.5)
print(w.data.squeeze()[0], b.data.squeeze()[0])
深度学习框架PyTorch:入门与实践 学习(二)相关推荐
- numpy pytorch 接口对应_拆书分享篇深度学习框架PyTorch入门与实践
<<深度学习框架PyTorch入门与实践>>读书笔记 <深度学习框架PyTorch入门与实践>读后感 小作者:马苗苗 读完<<深度学习框架PyTorc ...
- 深度学习框架Pytorch入门与实践——读书笔记
2 快速入门 2.1 安装和配置 pip install torch pip install torchvision#IPython魔术命令 import torch as t a=t.Tensor( ...
- 深度学习框架PyTorch入门与实践:第二章 快速入门
本章主要介绍两个内容,2.1节介绍如何安装PyTorch,以及如何配置学习环境:2.2节将带领读者快速浏览PyTorch中主要内容,给读者一个关于PyTorch的大致印象. 2.1 安装与配置 2.1 ...
- 深度学习框架PyTorch入门与实践:第八章 AI艺术家:神经网络风格迁移
本章我们将介绍一个酷炫的深度学习应用--风格迁移(Style Transfer).近年来,由深度学习引领的人工智能技术浪潮越来越广泛地应用到社会各个领域.这其中,手机应用Prisma,尝试为用户的照片 ...
- 深度学习框架PyTorch入门与实践:第七章 AI插画师:生成对抗网络
生成对抗网络(Generative Adversarial Net,GAN)是近年来深度学习中一个十分热门的方向,卷积网络之父.深度学习元老级人物LeCun Yan就曾说过"GAN is t ...
- 深度学习框架PyTorch入门与实践:第九章 AI诗人:用RNN写诗
我们先来看一首诗. 深宫有奇物,璞玉冠何有. 度岁忽如何,遐龄复何欲. 学来玉阶上,仰望金闺籍. 习协万壑间,高高万象逼. 这是一首藏头诗,每句诗的第一个字连起来就是"深度学习". ...
- 深度学习框架pytorch入门之张量Tensor(一)
文章目录 一.简介 二.查看帮助文档 三.Tensor常用方法 1.概述 2.新建方法 (1)Tensor(*sizes) tensor基础构造函数 (2)ones(*sizes) 构造一个全为1的T ...
- 深度学习框架PyTorch一书的学习-第四章-神经网络工具箱nn
参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 本章介绍的nn模块是构建与autogr ...
- 深度学习框架PyTorch一书的学习-第三章-Tensor和autograd-1-Tensor
参考https://github.com/chenyuntc/pytorch-book/tree/v1.0 希望大家直接到上面的网址去查看代码,下面是本人的笔记 Tensor Tensor可以是一个数 ...
- 好书分享——《深度学习框架PyTorch:入门与实践》
内容简介 : <深度学习框架PyTorch:入门与实践>从多维数组Tensor开始,循序渐进地带领读者了解PyTorch各方面的基础知识.结合基础知识和前沿研究,带领读者从零开始完成几个经 ...
最新文章
- 50篇经典珍藏 | Docker、Mesos、微服务、云原生技术干货
- ocr识别技术-车牌识别一体机的核心关键
- 确保 PHP 应用程序的安全
- delphi listview怎么自动宽度_自动门日常使用出现这些问题应尽快维修以免因小失大...
- Apache Tez介绍,术语,安装,监控等
- 51Nod 1105 第K大的数 二分答案
- 两道二分coming~
- Python学习笔记[5]---else语句和with语句
- linux访问不了apache页面,nginx做前端,apache部分页面不能访问
- Windows 10原创知识题(第三版)
- 鸟哥的Linux私房菜_服务器架设篇 第三版
- 数据分析师为什么能拿高薪
- 惠普服务器故障代码_HP服务器常见代码
- 百度地图开发入门(6):3D建筑
- 透明显示屏(隐形显示屏)简述
- 小样本学习之关系网络:让机器自己学习如何度量
- 微信联盟链接不到服务器怎么,LOL微信绑定方法及无法登录处理方案推荐
- 遇到VerifyError束手无策?
- 从键盘读入一个字符,如果该字符是大写字母则转小写,如果该字符是小写字母则转大写,如果不是字符则输出不是字母。
- mongo-go-driver 踩坑心得 server selection error