pytorch:线性回归实战

  • 数学原理
  • pytorch实现

数学原理

首先我们看一下什么是线性回归,直观的看就是一张图上有很多散列的点

我们要找一个函数来描述这种关系,而线性回归的话表示我们要找一个形如:
f(x)=w∗x+bf(x)=w*x+bf(x)=w∗x+b
这样的函数来描述这样关系,也可以写成:
f(x)=w1∗x1+w2∗x2f(x)=w_1*x_1+w_2*x_2 f(x)=w1​∗x1​+w2​∗x2​
其中x1=1x_1=1x1​=1
而如果写成矩阵的形式的话,就是:
f(x)=WXTf(x)=WX^Tf(x)=WXT
其中:
W=(w2,w1),X=(x2,1)W=(w_2,w_1),X=(x_2,1)W=(w2​,w1​),X=(x2​,1)
因此问题就变成了,求最优w,使损失函数最小,损失函数用均方误差(MSE)来描述,即:
loss=12∗(y^−y)2loss=\frac{1}{2}*(\hat{y}-y)^2loss=21​∗(y^​−y)2
其中y hat 表示预测的y值,y表示真实的y值。
最优w求解是使用梯度下降算法,通过对凸函数loss求导,另一阶导等于0,得到w的最优值,详细推倒可以参考西瓜数和吴恩达的视频讲解。
最后得到如下结果图:

pytorch实现

使用pytoch实现起来很简单,按照数学的思路,我们的实现也大致分为:1.定义好函数,在pytoch中使用torch.nn.Linear实现。2.定义好损失函数,pytorch里面有MSE函数可用。3.使用梯度下降优化得到最优w,pytorch里面同样有现成函数实现。其他都是一些象加载数据,处理数据,画图显示这样的旁线操作。把握住数学推倒过程这一主线,拿pytorch给你造好的轮子快速实现。
完整代码如下:

import torch as t
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as nplr = 0.01
num_epochs = 100
in_size = 1
out_size = 1
# Toy dataset
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],[9.779], [6.182], [7.59], [2.167], [7.042],[10.791], [5.313], [7.997], [3.1]], dtype=np.float32)y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],[3.366], [2.596], [2.53], [1.221], [2.827],[3.465], [1.65], [2.904], [1.3]], dtype=np.float32)#define function,can accomplish using a line code: model= nn.Linear(in_size, out_size)
class LinerRegression(nn.Module):def __init__(self, in_size, out_size):super(LinerRegression, self).__init__()self.fc1 = nn.Linear(in_size, out_size)def forward(self, x):y_hat = self.fc1(x)return y_hatmodel = LinerRegression(in_size, out_size)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=lr)for epoch in range(num_epochs):x = t.from_numpy(x_train)y = t.from_numpy(y_train)y_hat = model(x)loss = criterion(y_hat, y)optimizer.zero_grad()loss.backward()optimizer.step()print("[{}/{}] loss:{:.4f}".format(epoch+1, num_epochs, loss))#plot graph
y_pred = model(t.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original Data')
plt.plot(x_train, y_pred, 'b-', label='Fitted Line')
plt.legend()
plt.show()

pytorch:线性回归实战相关推荐

  1. 网易云课程:深度学习与PyTorch入门实战

    网易云课程:深度学习与PyTorch入门实战 01 深度学习初见 1.1 深度学习框架简介 1.2 pytorch功能演示 2开发环境安装 3回归问题 3.1简单的回归问题(梯度下降算法) 3.3回归 ...

  2. 视频教程-深度学习与PyTorch入门实战教程-深度学习

    深度学习与PyTorch入门实战教程 新加坡国立大学研究员 龙良曲 ¥399.00 立即订阅 扫码下载「CSDN程序员学院APP」,1000+技术好课免费看 APP订阅课程,领取优惠,最少立减5元 ↓ ...

  3. PyTorch 入门实战

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/qq_36556893/article/ ...

  4. 一、pytorch搭建实战以及sequential的使用

    一.pytorch搭建实战以及sequential的使用 1.A sequential container 2.搭建cifar10 model structure 3.创建实例进行测试(可以检查网络是 ...

  5. pytorch线性回归_PyTorch中的线性回归

    pytorch线性回归 For all those amateur Machine Learning and Deep Learning enthusiasts out there, Linear R ...

  6. 【Pytorch神经网络实战案例】21 基于Cora数据集实现Multi_Sample Dropout图卷积网络模型的论文分类

    Multi-sample Dropout是Dropout的一个变种方法,该方法比普通Dropout的泛化能力更好,同时又可以缩短模型的训练时间.XMuli-sampleDropout还可以降低训练集和 ...

  7. (pytorch-深度学习系列)pytorch线性回归的便捷实现

    pytorch线性回归的便捷实现 继上一篇blog,使用更加简洁的方法实现线性回归 生成数据集: num_inputs = 2 num_examples = 1000 true_w = [2, -3. ...

  8. [机器学习-回归算法]Sklearn之线性回归实战

    Sklearn之线性回归实战 一,前言 二,热身例子 三,贸易公司的简单例子 四,Sklearn 官网里的一个例子 五,预测每月的地铁故障数 参考资料 一,前言 一元线性回归的理论片请看我这个链接 二 ...

  9. 吴恩达机器学习总结五:单变量线性回归实战

    线性回归实战总结: 单变量线性回归: 1.加载和查看数据(准备工作) data = load('ex1data1.txt'); x=data(:,1); y=data(:,2); plot(x,y,' ...

  10. Pytorch框架实战——102类花卉分类

    本篇博文为[唐宇迪]计算机视觉实训营第二天-Pytorch框架实战课程的个人笔记. 代码来自:qiuzitao深度学习之PyTorch实战(十),与视频教学流程记录一致,课程详情可参考该篇. 下文数据 ...

最新文章

  1. MaxCompute Studio使用心得系列7—作业对比
  2. Python学习笔记__13.2章 requests
  3. 2021 届校招宣讲会来啦!神策数据,与你一起「数说」未来
  4. spring cloud 概念
  5. HTML5 浏览器支持(怎么样让低版本浏览器支持html5?)
  6. android高德地图自定义带数字marker图标,自定义图标-点标记-示例中心-JS API 示例 | 高德地图API...
  7. qrcode.js 二维码生成器
  8. Markdown 基础语法与常见问题总结
  9. 自定义工作流界面开发
  10. GaussDB(DWS)中共享消息队列实现的三大功能
  11. 一加8系列新机有望亮相CES 2020:全系支持5G网络
  12. Java中Arrays类的两个方法:deepEquals和equals
  13. putty 配色方案分享
  14. 等保-机房项目验收方法
  15. AS13 facets cannot be loaded. you can mark them as ignored to suppress this error notification处理
  16. 杭电2022 海选女主角
  17. Windows无法安装,选中的磁盘为GPT分区形式 --解决办法
  18. nyoj 1239-引水工程 //并查集
  19. 演讲者模式投影到幕布也看到备注_ppt备注怎么用在放映时怎么可以不在投影仪上显示...
  20. 【python数据类型】

热门文章

  1. linux实现文件共享的方式,Linux文件共享的实现方式
  2. 分享一些小技巧吧,MATLAB中常见问题及解决方案
  3. 微信企业号回调模式 java_java微信企业号开发之开发模式的开启
  4. 亚马逊Alexa技能的创建流程
  5. hyperledger fabric 2.3.3 环境搭建教程
  6. oenwrt 进不了bios_J1900在openwrt不能正常重启的BIOS选项说明
  7. indexOf 的使用
  8. MySQL中Index与Key的区别
  9. macOS Mojave(软件篇):微信 for Mac 防撤回插件(WeChatTweak-macOS)
  10. Exoplayer的详细使用UI篇