文章目录

  • 1.前言
  • 2.初识如何更新梯度
  • 3.手动更新梯度
  • 4.自动更新梯度

1.前言

大体分为三步:
(1)前向传播,计算loss
(2)计算局部梯度
(3)反向传播,用链式求导法则计算梯度

2.初识如何更新梯度

import torchx = torch.tensor(1.0)   #指定输入x
y = torch.tensor(2.0)   #指定输出y
w = torch.tensor(1.0, requires_grad=True) #初始化待求的w, requires_grad=True代表需要求导y_predicted = w*x     #预测值
loss = (y - y_predicted)**2     #计算预测值与真实值间的loss
print(loss)loss.backward()       #反向传播
print(w.grad)        #打印此时的梯度with torch.no_grad():w -= 0.01*w.grad    #更新梯度print(w)          #打印更新后的梯度
w.grad.zero_()      #梯度清零,防止梯度累加

3.手动更新梯度

import numpy as npX = np.array([1,2,3,4], dtype=np.float32)    #输出X
Y = np.array([2,4,6,8], dtype=np.float32)    #输出Yw = 0.0     #初始化权重def forward(x):return w*x      #预测值def loss(y, y_pred):return ((y_pred-y)**2).mean()    #计算loss#Loss = 1/N*(w*x-y)**2
#dLoss/dw = 1/N*2x(w*x-y)
def gradient(x,y,y_pred):          #计算梯度return np.dot(2*x, y_pred-y).mean()print(f'Prediction before training:f(5)={forward(5):.3f}')    #当权重为0是,预测值为0learning_rate = 0.01    #学习率
n_iters = 20      #迭代次数for epoch in range(n_iters): y_pred = forward(X)    #预测值l = loss(Y,y_pred)     #真实值与预测值的误差dw = gradient(X,Y,y_pred)    #求梯度w -= learning_rate * dw      #更新参数if epoch % 2 == 0:          #每隔两个epoch更新参数print(f'epoch {epoch+1}: w={w:.3f},loss={l:.8f}')
print(f'Prediction after training:f(5)={forward(5):.3f}')    #运行最后的权重算f(5)

4.自动更新梯度

import torchX = torch.tensor([1,2,3,4], dtype=torch.float32)
Y = torch.tensor([2,4,6,8], dtype=torch.float32)w = torch.tensor(0.0, dtype=torch.float32, requires_grad=True)    #初始化权重,并设置需要求导def forward(x):return w * x      #预测def loss(y, y_pred):return ((y_pred-y)**2).mean()   #算loss
print(f'Prediction before training:f(5)={forward(5).item():.3f}')   #打印初始的预测learning_rate = 0.1
n_iters = 20for epoch in range(n_iters):y_pred = forward(X)l = loss(Y, y_pred)l.backward()   #等同于计算梯度with torch.no_grad():w -= learning_rate*w.grad     #更新参数w.grad.zero_()     #梯度清零,防止梯度累计if epoch % 2==0:print(f'epoch {epoch+1}: w={w.item():.3f}, loss={l.item():.8f}')
print(f'Prediction after training:f(5)={forward(5).item():.3f}')

Pytorch专题实战——反向传播(Backpropagation)相关推荐

  1. pytorch 入门学习反向传播-4

    pytorch 入门学习反向传播 反向传播 import numpy as np import matplotlib.pyplot as plt import torchdef forward(x): ...

  2. Pytorch(5)-梯度反向传播

    自动求梯度 1. 函数对自变量x求梯度--ax^2+b 2. 网络对参数w求梯度- loss(w,x) 3. 自动求梯度的底层支持--torch.autograd 3.1 Variable 3.1.1 ...

  3. 损失函数与优化器理解+【PyTorch】在反向传播前为什么要手动将梯度清零?optimizer.zero_grad()

    目录 回答一: 回答二: 回答三: 传统的训练函数,一个batch是这么训练的: 使用梯度累加是这么写的: 回答一: 一句话,用来更新和计算影响模型训练和模型输出的网络参数,使其逼近或达到最优值,从而 ...

  4. pytorch 正向与反向传播的过程 获取模型的梯度(gradient),并绘制梯度的直方图

    记录一下怎样pytorch框架下怎样获得模型的梯度 文章目录 引入所需要的库 一个简单的函数 模型梯度获取 先定义一个model 如下定义两个获取梯度的函数 定义一些过程与调用上述函数的方法 可视化一 ...

  5. Pytorch专题实战——交叉熵损失函数(CrossEntropyLoss )

    文章目录 1.用CrossEntropyLoss预测单个目标 2.用CrossEntropyLoss预测多个目标 3.二分类使用BCELoss损失函数 4.多分类使用CrossEntropyLoss损 ...

  6. Pytorch专题实战——批训练数据(DataLoader)

    文章目录 1.计算流程 2.Pytorch构造批处理数据 2.1.导入必要模块 2.2.定义数据类 2.3.定义DataLoader 2.4.打印效果 1.计算流程 # Implement a cus ...

  7. Pytorch专题实战——逻辑回归(Logistic Regression)

    文章目录 1.计算流程 2.Pytorch搭建线性逻辑模型 2.1.导入必要模块 2.2.数据准备 2.3.构建模型 2.4.训练+计算准确率 1.计算流程 1)设计模型: Design model ...

  8. Pytorch专题实战——线性回归(Linear Regression)

    文章目录 1.计算流程 2.Pytorch搭建线性回归模型 2.1.导入必要模块 2.2.构造训练数据 2.3.测试数据及输入输出神经元个数 2.4.搭建模型并实例化 2.5.训练 1.计算流程 1) ...

  9. Pytorch专题实战——前馈神经网络(Feed-Forward Neural Network)

    文章目录 1.导入必要模块 2.超参数设置 3.数据准备 4.打印部分加载的数据 5.模型建立 6.训练 1.导入必要模块 import torch import torch.nn as nn imp ...

最新文章

  1. 惊天大谎:让穷人都能上网是Facebook的殖民阴谋?
  2. SpringMVC使用及知识点提炼
  3. POJ2718【DFS】
  4. 渗透之cookie截取
  5. oracle 源代码输出,oracle-如何将DBMS_OUTPUT.PUT_LINE的输出重定向到文件?
  6. $_SERVER[HTTP_HOST]
  7. registerModule: 动态注册vuex模块,对于自定义生成组件很有用
  8. 解决办法:安装cuda时一直失败(如提示Reboot required to continue)
  9. python订餐系统简单版
  10. Java学习网站推荐
  11. java8.0安装教程_jdk8安装教程详解
  12. Linux使用Jstack查看Java堆栈快照脚本
  13. ios视频播放器封装(全屏播放,锁屏、手势调节亮度、音量、进度)
  14. 公众号推送长图最佳尺寸_微信公众平台图片尺寸是多少
  15. qqc什么梗_网络语cpdd是什么意思 王者荣耀QQ飞车里很常见
  16. 骁龙888发布,小米11首发,有14家厂商首批搭载!
  17. Jenkins高级篇之Pipeline语法篇-7-Declarative Pipeline指令:triggers/stage/tool
  18. 2.系统测试流程规范
  19. 设计部门领导必备能力
  20. android录音频谱动画,android获取和展示音乐的频谱

热门文章

  1. 「管理数学基础」4.2 模糊数学:扩张原理、模糊数、可能性分布与模糊概率
  2. 【数据结构笔记18】堆中的路径与C实现(堆元素到根的路)径)
  3. java httpclient 下载文件_httpclient 上传文件、下载文件
  4. 异常关闭MyEclipse 8.6后,不能重启
  5. idea中编辑*.vue文件没有任何提示
  6. oracle中lag()函数和lead()函数的用法(图文)
  7. MySql数据库导出完整版(导出数据库,导出表,导出数据库结构)
  8. log4j配置以及logback配置
  9. 超强1000个jquery极品插件!
  10. 在Web.Config中指定页面的基类