线性回归是分析一个变量与另外一(多)个变量之间关系的方法。求解步骤如下:

  1. 确定模型
    y=Wx+by = Wx + by=Wx+b
  2. 选择损失函数
    MSE=1m∑i=1m(yi−yi^)2MSE=\frac{1}{m}\sum_{i=1}^{m}(y_i-\hat{y_i})^2MSE=m1​i=1∑m​(yi​−yi​^​)2
  3. 求解梯度并更新w、b
    w=w−LR∗w.gradb=b−LR∗w.gradw=w-LR*w.grad\\ b = b-LR * w.gradw=w−LR∗w.gradb=b−LR∗w.grad
"""
@author: admin
@file: torch实现线性回归.py
@time: 2021/07/14
@desc:
"""
import torch
import matplotlib.pyplot as plttorch.manual_seed(10)# 学习率
lr = 0.05# 创建训练数据
x = torch.rand(20, 1) * 10  # x data (tensor), shape=(20, 1)
y = 2 * x + (5 + torch.randn(20, 1))  # y data (tensor), shape=(20, 1)# 构建线性回归参数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)for iteration in range(1000):# 前向传播wx = torch.mul(w, x)y_pred = torch.add(wx, b)# 计算 MSE lossloss = (0.5 * (y - y_pred) ** 2).mean()# 反向传播—自动求导loss.backward()# 更新参数b.data.sub_(lr * b.grad)w.data.sub_(lr * w.grad)# 清零张量的梯度w.grad.zero_()b.grad.zero_()# 绘图if iteration % 20 == 0:plt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-', lw=5)plt.text(2, 20, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})plt.xlim(1.5, 10)plt.ylim(8, 28)plt.title("Iteration: {}\nw: {} b: {}".format(iteration, w.data.numpy(), b.data.numpy()))plt.pause(0.5)if loss.data.numpy() < 1:breakplt.show()

如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!


PyTorch学习—3.pytorch实现线性回归相关推荐

  1. Pytorch学习 - Task5 PyTorch卷积层原理和使用

    Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...

  2. Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用

    Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...

  3. PyTorch学习记录——PyTorch进阶训练技巧

    PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...

  4. PyTorch学习记录——PyTorch生态

    Pytorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了PyTorch在特定领域的使 ...

  5. 1.pytorch学习:安装pytorch

    目录 安装pytorch 检查pytorch安装是否成功 总结 安装pytorch 官方网址: Start Locally | PyTorchhttps://pytorch.org/get-start ...

  6. PyTorch学习笔记——pytorch图像处理(transforms)

    原始图像 2.图像处理.转不同格式显示 import torch import torchvision import torchvision.transforms as transforms impo ...

  7. add函数 pytorch_Pytorch学习记录-Pytorch可视化使用tensorboardX

    Pytorch学习记录-Pytorch可视化使用tensorboardX 在很早很早以前(至少一个半月),我做过几节关于tensorboard的学习记录. https://www.jianshu.co ...

  8. Pytorch学习笔记总结

    往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...

  9. pytorch学习笔记(五):线性回归的简洁实现

    文章目录 前言 1 生成数据集 2 读取数据 3 定义模型 4 初始化模型参数 5 定义损失函数 6 定义优化算法 7 训练模型 小结 前言 随着深度学习框架的发展,开发深度学习应用变得越来越便利.实 ...

  10. 【从线性回归到 卷积神经网络CNN 循环神经网络RNN Pytorch 学习笔记 目录整合 源码解读 B站刘二大人 绪论(0/10)】

    深度学习 Pytorch 学习笔记 目录整合 数学推导与源码详解 B站刘二大人 目录传送门: 线性模型 Linear-Model 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人( ...

最新文章

  1. 【北京-知春路】这10家互联网公司值得你加入
  2. 在Mac配置adb命令
  3. Nebula3的Input系统
  4. 一文让你轻松了解 JAVA 开发中的四种加密方法
  5. 《R语言实战》第7章
  6. 前端学习(1042):todoList存储
  7. Android学习问题:关于AlertDialog中自定义布局带有的EditText无法弹出键盘
  8. Java中的NIO非阻塞编程
  9. 命令行下对apk签名
  10. Windows64位 python3.6安装pyHook
  11. redis过期策略有哪些?内存淘汰机制有哪些?
  12. Android系统源代码目录
  13. English Summary~July
  14. 大白菜u盘启动盘清除系统登录密码详细教程
  15. filezilla,怎么下载filezilla
  16. 渗透测试-CTF_AWD专题篇
  17. 正点原子DS100手持示波器测试记录
  18. 你的Idea还可用吗?不妨试试另一个开发神器!
  19. 如何注册表里修改计算机用户名,更改电脑用户名(可更改C:\Users\用户名)
  20. gensim 主题模型 seed

热门文章

  1. 洗衣机一边进水一边出水 更换排水阀皮碗
  2. package--math
  3. 国内10大广告联盟各自有哪些优势?
  4. Sprig 面试中 问及 DI,IOC, AOP
  5. 利用c#反射提高设计灵活性
  6. 【matlab】解决每次打开.m文件都会弹出新窗口
  7. Struts2接受页面传值过程中出现input的问题
  8. OpenJudge 2990:符号三角形 解析报告
  9. [解答]对‘’未定义的引用 collect2: 错误: ld 返回 1
  10. poj 2754 Similarity of necklaces 2