Pytorch专题实战——反向传播(Backpropagation)
文章目录
- 1.前言
- 2.初识如何更新梯度
- 3.手动更新梯度
- 4.自动更新梯度
1.前言
大体分为三步:
(1)前向传播,计算loss
(2)计算局部梯度
(3)反向传播,用链式求导法则计算梯度
2.初识如何更新梯度
import torchx = torch.tensor(1.0) #指定输入x
y = torch.tensor(2.0) #指定输出y
w = torch.tensor(1.0, requires_grad=True) #初始化待求的w, requires_grad=True代表需要求导y_predicted = w*x #预测值
loss = (y - y_predicted)**2 #计算预测值与真实值间的loss
print(loss)loss.backward() #反向传播
print(w.grad) #打印此时的梯度with torch.no_grad():w -= 0.01*w.grad #更新梯度print(w) #打印更新后的梯度
w.grad.zero_() #梯度清零,防止梯度累加
3.手动更新梯度
import numpy as npX = np.array([1,2,3,4], dtype=np.float32) #输出X
Y = np.array([2,4,6,8], dtype=np.float32) #输出Yw = 0.0 #初始化权重def forward(x):return w*x #预测值def loss(y, y_pred):return ((y_pred-y)**2).mean() #计算loss#Loss = 1/N*(w*x-y)**2
#dLoss/dw = 1/N*2x(w*x-y)
def gradient(x,y,y_pred): #计算梯度return np.dot(2*x, y_pred-y).mean()print(f'Prediction before training:f(5)={forward(5):.3f}') #当权重为0是,预测值为0learning_rate = 0.01 #学习率
n_iters = 20 #迭代次数for epoch in range(n_iters): y_pred = forward(X) #预测值l = loss(Y,y_pred) #真实值与预测值的误差dw = gradient(X,Y,y_pred) #求梯度w -= learning_rate * dw #更新参数if epoch % 2 == 0: #每隔两个epoch更新参数print(f'epoch {epoch+1}: w={w:.3f},loss={l:.8f}')
print(f'Prediction after training:f(5)={forward(5):.3f}') #运行最后的权重算f(5)
4.自动更新梯度
import torchX = torch.tensor([1,2,3,4], dtype=torch.float32)
Y = torch.tensor([2,4,6,8], dtype=torch.float32)w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True) #初始化权重,并设置需要求导def forward(x):return w * x #预测def loss(y, y_pred):return ((y_pred-y)**2).mean() #算loss
print(f'Prediction before training:f(5)={forward(5).item():.3f}') #打印初始的预测learning_rate = 0.1
n_iters = 20for epoch in range(n_iters):y_pred = forward(X)l = loss(Y, y_pred)l.backward() #等同于计算梯度with torch.no_grad():w -= learning_rate*w.grad #更新参数w.grad.zero_() #梯度清零,防止梯度累计if epoch % 2==0:print(f'epoch {epoch+1}: w={w.item():.3f}, loss={l.item():.8f}')
print(f'Prediction after training:f(5)={forward(5).item():.3f}')
Pytorch专题实战——反向传播(Backpropagation)相关推荐
- pytorch 入门学习反向传播-4
pytorch 入门学习反向传播 反向传播 import numpy as np import matplotlib.pyplot as plt import torchdef forward(x): ...
- Pytorch(5)-梯度反向传播
自动求梯度 1. 函数对自变量x求梯度--ax^2+b 2. 网络对参数w求梯度- loss(w,x) 3. 自动求梯度的底层支持--torch.autograd 3.1 Variable 3.1.1 ...
- 损失函数与优化器理解+【PyTorch】在反向传播前为什么要手动将梯度清零?optimizer.zero_grad()
目录 回答一: 回答二: 回答三: 传统的训练函数,一个batch是这么训练的: 使用梯度累加是这么写的: 回答一: 一句话,用来更新和计算影响模型训练和模型输出的网络参数,使其逼近或达到最优值,从而 ...
- pytorch 正向与反向传播的过程 获取模型的梯度(gradient),并绘制梯度的直方图
记录一下怎样pytorch框架下怎样获得模型的梯度 文章目录 引入所需要的库 一个简单的函数 模型梯度获取 先定义一个model 如下定义两个获取梯度的函数 定义一些过程与调用上述函数的方法 可视化一 ...
- Pytorch专题实战——交叉熵损失函数(CrossEntropyLoss )
文章目录 1.用CrossEntropyLoss预测单个目标 2.用CrossEntropyLoss预测多个目标 3.二分类使用BCELoss损失函数 4.多分类使用CrossEntropyLoss损 ...
- Pytorch专题实战——批训练数据(DataLoader)
文章目录 1.计算流程 2.Pytorch构造批处理数据 2.1.导入必要模块 2.2.定义数据类 2.3.定义DataLoader 2.4.打印效果 1.计算流程 # Implement a cus ...
- Pytorch专题实战——逻辑回归(Logistic Regression)
文章目录 1.计算流程 2.Pytorch搭建线性逻辑模型 2.1.导入必要模块 2.2.数据准备 2.3.构建模型 2.4.训练+计算准确率 1.计算流程 1)设计模型: Design model ...
- Pytorch专题实战——线性回归(Linear Regression)
文章目录 1.计算流程 2.Pytorch搭建线性回归模型 2.1.导入必要模块 2.2.构造训练数据 2.3.测试数据及输入输出神经元个数 2.4.搭建模型并实例化 2.5.训练 1.计算流程 1) ...
- Pytorch专题实战——前馈神经网络(Feed-Forward Neural Network)
文章目录 1.导入必要模块 2.超参数设置 3.数据准备 4.打印部分加载的数据 5.模型建立 6.训练 1.导入必要模块 import torch import torch.nn as nn imp ...
最新文章
- 惊天大谎:让穷人都能上网是Facebook的殖民阴谋?
- SpringMVC使用及知识点提炼
- POJ2718【DFS】
- 渗透之cookie截取
- oracle 源代码输出,oracle-如何将DBMS_OUTPUT.PUT_LINE的输出重定向到文件?
- $_SERVER[HTTP_HOST]
- registerModule: 动态注册vuex模块,对于自定义生成组件很有用
- 解决办法:安装cuda时一直失败(如提示Reboot required to continue)
- python订餐系统简单版
- Java学习网站推荐
- java8.0安装教程_jdk8安装教程详解
- Linux使用Jstack查看Java堆栈快照脚本
- ios视频播放器封装(全屏播放,锁屏、手势调节亮度、音量、进度)
- 公众号推送长图最佳尺寸_微信公众平台图片尺寸是多少
- qqc什么梗_网络语cpdd是什么意思 王者荣耀QQ飞车里很常见
- 骁龙888发布,小米11首发,有14家厂商首批搭载!
- Jenkins高级篇之Pipeline语法篇-7-Declarative Pipeline指令:triggers/stage/tool
- 2.系统测试流程规范
- 设计部门领导必备能力
- android录音频谱动画,android获取和展示音乐的频谱
热门文章
- 「管理数学基础」4.2 模糊数学:扩张原理、模糊数、可能性分布与模糊概率
- 【数据结构笔记18】堆中的路径与C实现(堆元素到根的路)径)
- java httpclient 下载文件_httpclient 上传文件、下载文件
- 异常关闭MyEclipse 8.6后,不能重启
- idea中编辑*.vue文件没有任何提示
- oracle中lag()函数和lead()函数的用法(图文)
- MySql数据库导出完整版(导出数据库,导出表,导出数据库结构)
- log4j配置以及logback配置
- 超强1000个jquery极品插件!
- 在Web.Config中指定页面的基类