PyTorch学习—3.pytorch实现线性回归
线性回归是分析一个变量与另外一(多)个变量之间关系的方法。求解步骤如下:
- 确定模型
y=Wx+by = Wx + by=Wx+b - 选择损失函数
MSE=1m∑i=1m(yi−yi^)2MSE=\frac{1}{m}\sum_{i=1}^{m}(y_i-\hat{y_i})^2MSE=m1i=1∑m(yi−yi^)2 - 求解梯度并更新w、b
w=w−LR∗w.gradb=b−LR∗w.gradw=w-LR*w.grad\\ b = b-LR * w.gradw=w−LR∗w.gradb=b−LR∗w.grad
"""
@author: admin
@file: torch实现线性回归.py
@time: 2021/07/14
@desc:
"""
import torch
import matplotlib.pyplot as plttorch.manual_seed(10)# 学习率
lr = 0.05# 创建训练数据
x = torch.rand(20, 1) * 10 # x data (tensor), shape=(20, 1)
y = 2 * x + (5 + torch.randn(20, 1)) # y data (tensor), shape=(20, 1)# 构建线性回归参数
w = torch.randn((1), requires_grad=True)
b = torch.zeros((1), requires_grad=True)for iteration in range(1000):# 前向传播wx = torch.mul(w, x)y_pred = torch.add(wx, b)# 计算 MSE lossloss = (0.5 * (y - y_pred) ** 2).mean()# 反向传播—自动求导loss.backward()# 更新参数b.data.sub_(lr * b.grad)w.data.sub_(lr * w.grad)# 清零张量的梯度w.grad.zero_()b.grad.zero_()# 绘图if iteration % 20 == 0:plt.cla()plt.scatter(x.data.numpy(), y.data.numpy())plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-', lw=5)plt.text(2, 20, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 20, 'color': 'red'})plt.xlim(1.5, 10)plt.ylim(8, 28)plt.title("Iteration: {}\nw: {} b: {}".format(iteration, w.data.numpy(), b.data.numpy()))plt.pause(0.5)if loss.data.numpy() < 1:breakplt.show()
如果对您有帮助,麻烦点赞关注,这真的对我很重要!!!如果需要互关,请评论或者私信!
PyTorch学习—3.pytorch实现线性回归相关推荐
- Pytorch学习 - Task5 PyTorch卷积层原理和使用
Pytorch学习 - Task5 PyTorch卷积层原理和使用 1. 卷积层 (1)介绍 (torch.nn下的) 1) class torch.nn.Conv1d() 一维卷积层 2) clas ...
- Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用
Pytorch学习 - Task6 PyTorch常见的损失函数和优化器使用 官方参考链接 1. 损失函数 (1)BCELoss 二分类 计算公式 小例子: (2) BCEWithLogitsLoss ...
- PyTorch学习记录——PyTorch进阶训练技巧
PyTorch学习记录--PyTorch进阶训练技巧 1.自定义损失函数 1.1 以函数的方式定义损失函数 1.2 以类的方式定义损失函数 1.3 比较与思考 2.动态调整学习率 2.1 官方提供的s ...
- PyTorch学习记录——PyTorch生态
Pytorch的强大并不仅局限于自身的易用性,更在于开源社区围绕PyTorch所产生的一系列工具包(一般是Python package)和程序,这些优秀的工具包极大地方便了PyTorch在特定领域的使 ...
- 1.pytorch学习:安装pytorch
目录 安装pytorch 检查pytorch安装是否成功 总结 安装pytorch 官方网址: Start Locally | PyTorchhttps://pytorch.org/get-start ...
- PyTorch学习笔记——pytorch图像处理(transforms)
原始图像 2.图像处理.转不同格式显示 import torch import torchvision import torchvision.transforms as transforms impo ...
- add函数 pytorch_Pytorch学习记录-Pytorch可视化使用tensorboardX
Pytorch学习记录-Pytorch可视化使用tensorboardX 在很早很早以前(至少一个半月),我做过几节关于tensorboard的学习记录. https://www.jianshu.co ...
- Pytorch学习笔记总结
往期Pytorch学习笔记总结: 1.Pytorch实战笔记_GoAI的博客-CSDN博客 2.Pytorch入门教程_GoAI的博客-CSDN博客 Pytorch系列目录: PyTorch学习笔记( ...
- pytorch学习笔记(五):线性回归的简洁实现
文章目录 前言 1 生成数据集 2 读取数据 3 定义模型 4 初始化模型参数 5 定义损失函数 6 定义优化算法 7 训练模型 小结 前言 随着深度学习框架的发展,开发深度学习应用变得越来越便利.实 ...
- 【从线性回归到 卷积神经网络CNN 循环神经网络RNN Pytorch 学习笔记 目录整合 源码解读 B站刘二大人 绪论(0/10)】
深度学习 Pytorch 学习笔记 目录整合 数学推导与源码详解 B站刘二大人 目录传送门: 线性模型 Linear-Model 数学原理分析以及源码详解 深度学习 Pytorch笔记 B站刘二大人( ...
最新文章
- 【北京-知春路】这10家互联网公司值得你加入
- 在Mac配置adb命令
- Nebula3的Input系统
- 一文让你轻松了解 JAVA 开发中的四种加密方法
- 《R语言实战》第7章
- 前端学习(1042):todoList存储
- Android学习问题:关于AlertDialog中自定义布局带有的EditText无法弹出键盘
- Java中的NIO非阻塞编程
- 命令行下对apk签名
- Windows64位 python3.6安装pyHook
- redis过期策略有哪些?内存淘汰机制有哪些?
- Android系统源代码目录
- English Summary~July
- 大白菜u盘启动盘清除系统登录密码详细教程
- filezilla,怎么下载filezilla
- 渗透测试-CTF_AWD专题篇
- 正点原子DS100手持示波器测试记录
- 你的Idea还可用吗?不妨试试另一个开发神器!
- 如何注册表里修改计算机用户名,更改电脑用户名(可更改C:\Users\用户名)
- gensim 主题模型 seed