线性回归原理

一元线性回归

一元线性回归其实就是从一堆训练集中去算出一条直线,使数据集到直线之间的距离差最小。

举个栗子:

唯一特征X,共有m = 500个数据数量,Y是实际结果,要从中找到一条直线,使数据集到直线之间的距离差最小,如下图所示:

那要如何去完成这个操作呢?

线性回归所提供的思路是,先假设一条直线:

可以将特征X中每一个值都带入其中,得到对应的,定义可以将损失定义为 之间的差值平方的和:

而为了之后计算将其修改为

接下来问题就简单了,只需要求出最小的 就可以了。


梯度下降

可以看出,现在的问题已经变成了一个求极值的问题,这里面有很多种方法,有最小二乘,标准方程方法,以及梯度下降等等,在这里只简要分析一下梯度下降法。

分别从(-50,50)中取值,算出其的损失值,可以做出如下图:

其中x、y轴就是,z轴是的值,现在目标就变成了从图中找到最低点,而最低点的xy轴坐标就是我们要的

而梯度下降,顾名思义,就是要往下走,而这个往下走,方向并不随意,而是要沿着梯度最大的位置往下走,而最大,不久是它的偏导数吗,接下来要做的就是对求偏导:

之后更新得到新的

在这个公式里突然多了一个,它表示的是学习率,用来限定步长的大小,也很容易理解,毕竟得到的偏导是类似斜率的东西,只是一个方向,总得加个数值,才能表示向这个方向移动的距离。

对于的取值,一般都取的比较小,但也不要太小,太小就意味着迭代步数要增加,运算时间边长。

另外,迭代的步数要自己设置。

python实现

因为并没有数据集,就自己做了一个,代码只是简单的实现了一下。

# -*- coding: utf-8 -*-
"""
Created on Wed Jun 20 17:09:13 2018@author: 96jie
"""#导入cv模块
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt#数据
a = np.random.standard_normal((1, 500))
x = np.arange(0,50,0.1)
y = np.arange(20,120,0.2)
y = y - a*10
y = y[0]#梯度下降
def Optimization(x,y,theta,learning_rate):for i in range(iter):theta = Updata(x,y,theta,learning_rate)return thetadef Updata(x,y,theta,learning_rate):m = len(x)sum = 0.0sum1 = 0.0alpha = learning_rateh = 0for i in range(m):h = theta[0] + theta[1] * x[i]sum += (h - y[i])sum1 += (h - y[i]) * x[i]theta[0] -= alpha * sum / m theta[1] -= alpha * sum1 / m return theta#数据初始化
learning_rate = 0.001
theta = [0,0]
iter = 1000
theta = Optimization(x,y,theta,learning_rate)plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
'''
plt.figure(figsize=(35,35))
plt.scatter(x,y,marker='o')
plt.xticks(fontsize=40)
plt.yticks(fontsize=40)
plt.xlabel('特征X',fontsize=40)
plt.ylabel('Y',fontsize=40)
plt.title('样本',fontsize=40)
plt.savefig("样本.jpg")
'''
#可视化
b = np.arange(0,50)
c = theta[0] + b * theta[1]plt.figure(figsize=(35,35))
plt.scatter(x,y,marker='o')
plt.plot(b,c)
plt.xticks(fontsize=40)
plt.yticks(fontsize=40)
plt.xlabel('特征X',fontsize=40)
plt.ylabel('Y',fontsize=40)
plt.title('结果',fontsize=40)
plt.savefig("结果.jpg")

线性回归(一):一元线性回归(附python实现)相关推荐

  1. java 一元线性回归_一元线性回归的java实现

    我们有两组数据,比如连续5年的pv与uv. 我们想预测一下,uv达到500k那么pv会是多少.当然更有意思可能是,如果销售额是500w的话,pv会是多少. 机器学习里的一元线性回归方法是比较简单的方法 ...

  2. 机器学习初探:(二)线性回归之一元线性回归

    (二)一元线性回归 文章目录 (二)一元线性回归 一元线性回归(Univariate Linear Regression) 模型形式和基本假设 损失函数(Cost Function) 梯度下降(Gra ...

  3. 一元线性回归模型及其Python案例

    回归的概念:(其实就是用曲线拟合的方式探索数据规律) 回归问题的分类: 一元线性回归: 线性回归模型是利用线性拟合的方式探寻数据背后的规律.如下图所示,先通过搭建线性回归模型寻找这些散点(也称样本点) ...

  4. 机器学习——一元线性回归和多元线性回归

    一元线性回归:梯度下降法 一元线性回归是线性回归的最简单的一种,即只有一个特征变量.首先是梯度下降法,这是比较经典的求法.一元线性回归通俗易懂地说,就是一元一次方程.只不过这里的斜率和截距要通过最小二 ...

  5. 一元线性回归决定系数_回归分析|笔记整理(1)——引入,一元线性回归(上)...

    大家好! 新学期开始了,不知道大家又是否能够适应新的一学期呢?先祝所有大学生和中小学生开学快乐! 本学期我的专业课是概率论,回归分析,偏微分方程,数值代数,数值逼近,金融时间序列分析,应用金融计量学和 ...

  6. 线性回归(一元、多元)

    目录 一元线性回归 多元线性回归 一元线性回归 在一元线性回归中,输入只有一个特征.现有输入特征为 x,需要预测的目标特征为y ,一元线性回归模型为 y=w1x+w0y=w_1x+w_0y=w1​x+ ...

  7. matlab重复线性回归,(MATLAB)一元线性回归和多元线性回归

    (MATLAB)一元线性回归和多元线性回归 (MATLAB)一元线性回归和多元线性回归 (MATLAB)一元线性回归和多元线性回归1.一元线性回归 2.多元线性回归2.1数据说明 2.2程序运行结果 ...

  8. R语言计量(一):一元线性回归与多元线性回归分析

    文章目录 一.数据调用与预处理 二.一元线性回归分析 三.多元线性回归分析 (一)解释变量的多重共线性检测 (二)多元回归 1. 多元最小二乘回归 2. 逐步回归 (三)回归诊断 四.模型评价-常用的 ...

  9. 机器学习——回归——一元线性回归

    目录 理论部分 1.1 回归问题 1.2 回归问题分类 1.3 线性回归 1.4 一元线性回归 1.4.1 基本形式 1.4.2 损失函数 1.4.3 训练集与测试集 1.4.4 学习目标 1.4.5 ...

  10. 第十一章 一元线性回归

    主要分析数值型自变量与数值型自变量之间的关系. 从变量个数上看,可分为简单相关与简单回归分析和多元相关与多元回归分析:从变量之间的关系形态上看,有线性相关与线性回归分析和非线性相关与非线性回归分析. ...

最新文章

  1. nginx安装-添加MP4播放模块
  2. php 面向对象 创建OOP
  3. Netflix Curator 使用 Zookeeper 编程
  4. 2017 ACM/ICPC Asia Regional Qingdao Online 记录
  5. UVA - 1587 Box
  6. pbs 支持 java_Linux下Java安装与配置
  7. 浅谈 多任务学习 在推荐系统中的应用
  8. Win 2012 OS 安装.Net Framework 3.5
  9. 当前企业最流行的三种软件开发模式
  10. RabbitMQ安装问题
  11. javaweb项目大概轮廓
  12. [附源码]java毕业设计景区门票系统
  13. 关于清理系统lj.bat的问题
  14. 初装vs2010旗舰版 遇到的错误
  15. Matlab根据excel数据画图
  16. javaScript 琐碎
  17. UML系列——时序图(顺序图)
  18. Oracle数据库第一天
  19. Jenkins-Slave分布式架构搭建
  20. Windows10安装Ubuntu桌面子系统WSL2

热门文章

  1. 计算机硬件技术基础教程mcs-51单片机原理及应用,mcs51单片机原理及应用
  2. 计算机xp系统恢复以前设置,电脑xp系统怎么恢复出厂设置,xp系统怎么恢复出厂设置...
  3. 【总结】EJB开发过程中遇到的几个问题
  4. python读取dat文件代码-基于python批量处理dat文件及科学计算方法详解
  5. 猎豹网校java版算法_猎豹网校JAVA语言数据结构与算法视频教程 Java语言
  6. NAND flash和NOR flash的区别详解
  7. mysql 时间查询_MYSQL按时间段查询语句大全
  8. Linux nexus3的搭建
  9. 思科路由器2811如何重设密码
  10. Android签名概述