pytorch:线性回归实战
pytorch:线性回归实战
- 数学原理
- pytorch实现
数学原理
首先我们看一下什么是线性回归,直观的看就是一张图上有很多散列的点
我们要找一个函数来描述这种关系,而线性回归的话表示我们要找一个形如:
f(x)=w∗x+bf(x)=w*x+bf(x)=w∗x+b
这样的函数来描述这样关系,也可以写成:
f(x)=w1∗x1+w2∗x2f(x)=w_1*x_1+w_2*x_2 f(x)=w1∗x1+w2∗x2
其中x1=1x_1=1x1=1
而如果写成矩阵的形式的话,就是:
f(x)=WXTf(x)=WX^Tf(x)=WXT
其中:
W=(w2,w1),X=(x2,1)W=(w_2,w_1),X=(x_2,1)W=(w2,w1),X=(x2,1)
因此问题就变成了,求最优w,使损失函数最小,损失函数用均方误差(MSE)来描述,即:
loss=12∗(y^−y)2loss=\frac{1}{2}*(\hat{y}-y)^2loss=21∗(y^−y)2
其中y hat 表示预测的y值,y表示真实的y值。
最优w求解是使用梯度下降算法,通过对凸函数loss求导,另一阶导等于0,得到w的最优值,详细推倒可以参考西瓜数和吴恩达的视频讲解。
最后得到如下结果图:
pytorch实现
使用pytoch实现起来很简单,按照数学的思路,我们的实现也大致分为:1.定义好函数,在pytoch中使用torch.nn.Linear实现。2.定义好损失函数,pytorch里面有MSE函数可用。3.使用梯度下降优化得到最优w,pytorch里面同样有现成函数实现。其他都是一些象加载数据,处理数据,画图显示这样的旁线操作。把握住数学推倒过程这一主线,拿pytorch给你造好的轮子快速实现。
完整代码如下:
import torch as t
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as nplr = 0.01
num_epochs = 100
in_size = 1
out_size = 1
# Toy dataset
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],[9.779], [6.182], [7.59], [2.167], [7.042],[10.791], [5.313], [7.997], [3.1]], dtype=np.float32)y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],[3.366], [2.596], [2.53], [1.221], [2.827],[3.465], [1.65], [2.904], [1.3]], dtype=np.float32)#define function,can accomplish using a line code: model= nn.Linear(in_size, out_size)
class LinerRegression(nn.Module):def __init__(self, in_size, out_size):super(LinerRegression, self).__init__()self.fc1 = nn.Linear(in_size, out_size)def forward(self, x):y_hat = self.fc1(x)return y_hatmodel = LinerRegression(in_size, out_size)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=lr)for epoch in range(num_epochs):x = t.from_numpy(x_train)y = t.from_numpy(y_train)y_hat = model(x)loss = criterion(y_hat, y)optimizer.zero_grad()loss.backward()optimizer.step()print("[{}/{}] loss:{:.4f}".format(epoch+1, num_epochs, loss))#plot graph
y_pred = model(t.from_numpy(x_train)).detach().numpy()
plt.plot(x_train, y_train, 'ro', label='Original Data')
plt.plot(x_train, y_pred, 'b-', label='Fitted Line')
plt.legend()
plt.show()
pytorch:线性回归实战相关推荐
- 网易云课程:深度学习与PyTorch入门实战
网易云课程:深度学习与PyTorch入门实战 01 深度学习初见 1.1 深度学习框架简介 1.2 pytorch功能演示 2开发环境安装 3回归问题 3.1简单的回归问题(梯度下降算法) 3.3回归 ...
- 视频教程-深度学习与PyTorch入门实战教程-深度学习
深度学习与PyTorch入门实战教程 新加坡国立大学研究员 龙良曲 ¥399.00 立即订阅 扫码下载「CSDN程序员学院APP」,1000+技术好课免费看 APP订阅课程,领取优惠,最少立减5元 ↓ ...
- PyTorch 入门实战
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/qq_36556893/article/ ...
- 一、pytorch搭建实战以及sequential的使用
一.pytorch搭建实战以及sequential的使用 1.A sequential container 2.搭建cifar10 model structure 3.创建实例进行测试(可以检查网络是 ...
- pytorch线性回归_PyTorch中的线性回归
pytorch线性回归 For all those amateur Machine Learning and Deep Learning enthusiasts out there, Linear R ...
- 【Pytorch神经网络实战案例】21 基于Cora数据集实现Multi_Sample Dropout图卷积网络模型的论文分类
Multi-sample Dropout是Dropout的一个变种方法,该方法比普通Dropout的泛化能力更好,同时又可以缩短模型的训练时间.XMuli-sampleDropout还可以降低训练集和 ...
- (pytorch-深度学习系列)pytorch线性回归的便捷实现
pytorch线性回归的便捷实现 继上一篇blog,使用更加简洁的方法实现线性回归 生成数据集: num_inputs = 2 num_examples = 1000 true_w = [2, -3. ...
- [机器学习-回归算法]Sklearn之线性回归实战
Sklearn之线性回归实战 一,前言 二,热身例子 三,贸易公司的简单例子 四,Sklearn 官网里的一个例子 五,预测每月的地铁故障数 参考资料 一,前言 一元线性回归的理论片请看我这个链接 二 ...
- 吴恩达机器学习总结五:单变量线性回归实战
线性回归实战总结: 单变量线性回归: 1.加载和查看数据(准备工作) data = load('ex1data1.txt'); x=data(:,1); y=data(:,2); plot(x,y,' ...
- Pytorch框架实战——102类花卉分类
本篇博文为[唐宇迪]计算机视觉实训营第二天-Pytorch框架实战课程的个人笔记. 代码来自:qiuzitao深度学习之PyTorch实战(十),与视频教学流程记录一致,课程详情可参考该篇. 下文数据 ...
最新文章
- MaxCompute Studio使用心得系列7—作业对比
- Python学习笔记__13.2章 requests
- 2021 届校招宣讲会来啦!神策数据,与你一起「数说」未来
- spring cloud 概念
- HTML5 浏览器支持(怎么样让低版本浏览器支持html5?)
- android高德地图自定义带数字marker图标,自定义图标-点标记-示例中心-JS API 示例 | 高德地图API...
- qrcode.js 二维码生成器
- Markdown 基础语法与常见问题总结
- 自定义工作流界面开发
- GaussDB(DWS)中共享消息队列实现的三大功能
- 一加8系列新机有望亮相CES 2020:全系支持5G网络
- Java中Arrays类的两个方法:deepEquals和equals
- putty 配色方案分享
- 等保-机房项目验收方法
- AS13 facets cannot be loaded. you can mark them as ignored to suppress this error notification处理
- 杭电2022 海选女主角
- Windows无法安装,选中的磁盘为GPT分区形式 --解决办法
- nyoj 1239-引水工程 //并查集
- 演讲者模式投影到幕布也看到备注_ppt备注怎么用在放映时怎么可以不在投影仪上显示...
- 【python数据类型】
热门文章
- linux实现文件共享的方式,Linux文件共享的实现方式
- 分享一些小技巧吧,MATLAB中常见问题及解决方案
- 微信企业号回调模式 java_java微信企业号开发之开发模式的开启
- 亚马逊Alexa技能的创建流程
- hyperledger fabric 2.3.3 环境搭建教程
- oenwrt 进不了bios_J1900在openwrt不能正常重启的BIOS选项说明
- indexOf 的使用
- MySQL中Index与Key的区别
- macOS Mojave(软件篇):微信 for Mac 防撤回插件(WeChatTweak-macOS)
- Exoplayer的详细使用UI篇