一元线性回归模型

回归分析

一元线性回归模型如下所示:(我们只需确定此方程的两个参数即可)

第一个参数为截距,第二个参数为斜率

为了求解上述参数,我们在这里引入代价函数(cost function),在这里以一元线性回归模型为例:

上述式子为真实值与预测值之间的差值的平方(当然也可以取绝对值,但为便于后续数值操作取其平方),最后取所有训练集个数的平均值,结合下图直观理解。

上述由一些参数构成的函数称为代价函数,我们的目标就是求解对应的参数使得代价函数达到最小值,最后确定模型:

相关系数(了解)

相关系数用于衡量线性相关性的强弱:

决定系数(了解)

梯度下降:用于确定所需参数

其具体步骤如下图所示:

此方法在一元线性回归中可以用于确定代价函数,因为代价函数是二维的(两个未知量,凸函数),所以理论上可以收敛于全局最小值(对于高次函数就可能达到局部最小值),如下图:

对于初始化操作,一般情况下赋值为0即可。所谓的梯度优化,就是不断的更改参数值,使之最后到达一个全局(局部)最小值。

这里的参数更新要求是同步更新,即最后在对参数进行更新,这里的α为学习率,通常取值为0.01,0.001,0.03,0.003等学习率不易过高也不能过低,当过高是会导致永远到达不了收敛点(发散),当过低时会导致收敛过慢,影响收敛速度;在这里的一元线性回归模型我们确定的参数只有两个,其参数求解方式如下图所示:


实战

方法一:梯度下降法

步骤一:载入库

import numpy as np  #导入numpy
import matplotlib.pyplot as plt #导入图像绘制库

步骤二:读取要进行回归的数据(以CSV为例)

#载入数据
#两列数据,逗号为分隔符
#数据和代码在同一目录下可以直接写文件名,否则要加路径
data=np.getfromtxt("data.csv",delimiter=",")
x_data=data[:,0]
y_data=data[:,1]

步骤三:算法编写

# 设置学习率(步长)和初始化线性参数
lr = 0.0001
b = 0
k = 0
# 最大迭代次数
epochs = 50# 最小二乘法
def compute_error(b, k, x_data, y_data):totalError = 0for i in range(0, len(x_data)):totalError += (y_data[i] - (k * x_data[i] + b)) ** 2return totalError / float(len(x_data)) / 2.0def gradient_descent_runner(x_data, y_data, b, k, lr, epochs):# 计算总数据量m = float(len(x_data))# 循环epochs次for i in range(epochs):b_grad = 0k_grad = 0# 计算梯度的总和再求平均for j in range(0, len(x_data)):b_grad += (1/m) * (((k * x_data[j]) + b) - y_data[j])k_grad += (1/m) * x_data[j] * (((k * x_data[j]) + b) - y_data[j])# 更新b和kb = b - (lr * b_grad)k = k - (lr * k_grad)# 每迭代5次,输出一次图像if i % 5==0:print("epochs:",i)plt.plot(x_data, y_data, 'b.')plt.plot(x_data, k*x_data + b, 'r')plt.show()return b, k

第四步:调用,得到结果

b, k = gradient_descent_runner(x_data, y_data, b, k, lr, epochs)
print("After {0} iterations b = {1}, k = {2}, error = {3}".format(epochs, b, k, compute_error(b, k, x_data, y_data)))

方法二:sklearn

步骤一:调用库

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

步骤二:载入数据

#载入数据
#这里的x_data和y_data需要增加维度,从(100,)变为(100,1),这是回归算法要求的
data=np.getfromtxt("data.csv",delimiter=",")
x_data = data[:,0,np.newaxis]
y_data = data[:,1,np.newaxis]

步骤三:创建并拟合模型

model = LinearRegression()
model.fit(x_data, y_data)
a=model.coef_   #系数
b=model.intercept_   #bias
# 画图
plt.plot(x_data, y_data, 'b.')
plt.plot(x_data, model.predict(x_data), 'r')
plt.show()


这是我学习 覃秉丰老师的《机器学习算法基础》的自学笔记,课程在B站中的地址为:

机器学习算法基础-覃秉丰_哔哩哔哩_bilibili

机器学习(一元线性回归)相关推荐

  1. 机器学习——一元线性回归和多元线性回归

    一元线性回归:梯度下降法 一元线性回归是线性回归的最简单的一种,即只有一个特征变量.首先是梯度下降法,这是比较经典的求法.一元线性回归通俗易懂地说,就是一元一次方程.只不过这里的斜率和截距要通过最小二 ...

  2. 吴恩达-机器学习-一元线性回归模型实现

    吴恩达<机器学习>2022版 第一周 一元线性回归 房价预测简单实现 import numpy as np import math, copy#输入数据 x_train = np.arra ...

  3. 机器学习------一元线性回归算法

    文章目录 预测数据型数据:回归 回归的含义 回归应用 线性回归 利用Sklearn做线性回归的预测 线性回归拟合原理(fit方法) 损失函数 梯度下降法 梯度下降的分类 "Batch&quo ...

  4. 机器学习——回归——一元线性回归

    目录 理论部分 1.1 回归问题 1.2 回归问题分类 1.3 线性回归 1.4 一元线性回归 1.4.1 基本形式 1.4.2 损失函数 1.4.3 训练集与测试集 1.4.4 学习目标 1.4.5 ...

  5. 机器学习:回归分析—— 一元线性回归、多元线性回归的简单实现

    回归分析 回归分析概述 基本概念 可以解决的问题 基本步骤和分类 线性回归 一元线性回归 多元线性回归 回归分析概述 基本概念 回归分析是处理多变量间相关关系的一种数学方法.相关关系不同于函数关系,后 ...

  6. 【从零开始的机器学习】-03 一元线性回归与代价函数

    1. 例子:假设我们有一些房屋的数据,然后想要通过这些数据来估计某栋房子的价格.如果我们决定只使用房屋的面积来预测,如何建立一个预测模型呢? 假设我们的数据是: 房屋面积 x( m 2 m^{2} m ...

  7. 机器学习(二)-一元线性回归算法(代码实现及数学证明)

    解决回归问题 思想简单,实现容易 许多强大的非线性模型的基础 结果具有很好的可解释性 蕴含机器学习中的很多重要思想 回归问题:连续值 如果样本 特征 只有一个 称为简单线性回归 y=ax + b 通过 ...

  8. 机器学习入门(二)一元线性回归

    目录 2.一元线性回归 2.1 什么是线性回归 2.2 代价函数 2.2.1 假设函数 2.2.2 代价函数 2.3 梯度下降法 2.3.1 引出问题 2.3.2 梯度下降法 2.3.3 梯度下降法的 ...

  9. 从统计看机器学习(一) 一元线性回归

    从统计学的角度来看,机器学习大多的方法是统计学中分类与回归的方法向工程领域的推广. "回归"(Regression)一词的滥觞是英国科学家Francis Galton(1822-1 ...

  10. 机器学习初探:(二)线性回归之一元线性回归

    (二)一元线性回归 文章目录 (二)一元线性回归 一元线性回归(Univariate Linear Regression) 模型形式和基本假设 损失函数(Cost Function) 梯度下降(Gra ...

最新文章

  1. python如何控制mysql_python如何操作mysql
  2. SVN 两种存储格式(BDB和FSFS)区别
  3. python优雅写法
  4. 鸿星尔克向河南博物院捐款100万元用于灾后重建
  5. linux云自动化运维,linux云自动化系统运维17(延时服务及定时服务)
  6. java z+_Java算法练习—— Z 字形变换
  7. php实现小说字典功能_PHP实现微信小程序人脸识别刷脸登录功能
  8. java与数据库连接odbc_详解java数据库连接之JDBC-ODBC桥连方式
  9. Tomcat怎么重启 tomcat重启命令
  10. ZXPInstaller for Mac(PS扩展安装器)免费版
  11. 5G+北斗融合定位技术介绍
  12. 试图通俗地讲一下庞加莱猜想是怎么回事
  13. 如何利用R语言处理 缺失值 数据
  14. Asp.Net MVC访问数据库实现登录
  15. 深度解析 ORA-01555 原因及解决方法
  16. 计算机网络(第8版)谢希仁第一章概述笔记
  17. 交换机的全trunk模式(native vlan)
  18. 三星正在研发智能戒指,智能戒指当然少不了Find My功能
  19. RT-Thread—FAL与EasyFlash组件移植
  20. Scrum Gathering开放分享:敏捷开发早期估算by火星人陈勇,北京,6.30!

热门文章

  1. 诺顿误杀导致系统崩溃--起因及对策
  2. 洛谷 P4099 SAO —— 树形dp
  3. CallStack获取函数堆栈
  4. ukey功能适配文档
  5. 超标量处理器设计 姚永斌 第1章 超标量处理器概览 摘录
  6. QQ在线客服代码(不需要加好友即可发起临时会话)
  7. 制作U盘DOS启动盘详细教程及工具,及DOS下升级BIOS方法,传统BIOS升级为UEFI
  8. NIO 网络编程之群聊系统
  9. 【遥感微课堂】学习ENVI5.0
  10. 【遥感数字图像处理】实验:遥感影像增强方法大全处理看过来(Erdas版)