目录

  • 1. 原理详解
    • 1.1. 线性回归
    • 1.2. 回归系数
  • 2. 公式推导
    • 2.1. 单元线性回归
    • 2.2. 多元线性回归
  • 3. 简单实例
    • 3.1. 实例1:一元线性回归
    • 实例2: 多元线性回归
    • 3.3. 实例3:房价预测

1. 原理详解

1.1. 线性回归

  假设一个空间中有一堆散点,线性回归的目的就是希望用一条直线,最大程度地“概括”这些散点。它不要求经过每一个散点,但是希望能考虑到每个散点的特点。按照西瓜书的例子就是,好瓜的评判标准y可以由 x i x_i xi​表示,也就是说, f g o o d ( x ) = w 1 x 色泽 + w 1 x 根蒂 + w 1 x 敲声 + b f_{good}(x)=w_1x_{色泽}+w_1x_{根蒂}+w_1x_{敲声}+b fgood​(x)=w1​x色泽​+w1​x根蒂​+w1​x敲声​+b。
  那么我们不难发现,线性回归需要考虑的几个问题:

  • 确定系数 w i w_i wi​以及偏置 b b b
  • 如何确定 f g o o d ( x ) f_{good}(x) fgood​(x)能很好地概括瓜的特点

1.2. 回归系数

  关于这点,我们需要确定,我们算出来的回归系数一定是当前最优的结果,怎么确定呢?

  • 均方误差(西瓜书)
  • R^2(用于模型评估)

均方误差(MSE)

  这个其实就是残差平方和的平均值。
M S E = ∑ i = 0 n y i − f ( x i ) n MSE=\frac{\sum_{i=0}^ny_i-f(x_i)}{n} MSE=n∑i=0n​yi​−f(xi​)​

R^2

R 2 = S S R S S T = S S T − S S E S S T = 1 − S S E S S T R^2=\frac{SSR}{SST}=\frac{SST-SSE}{SST}=1-\frac{SSE}{SST} R2=SSTSSR​=SSTSST−SSE​=1−SSTSSE​

  其中,SST是总偏差平方和
S S T = ∑ i = 0 n ( y i − y ˉ ) 2 SST=\sum_{i=0}^n(y_i-\bar y)^2 SST=i=0∑n​(yi​−yˉ​)2
  SSR是回归平方和
S S R = ∑ i = 0 n ( f ( x i ) − y ˉ ) 2 SSR=\sum_{i=0}^n(f(x_i)-\bar y)^2 SSR=i=0∑n​(f(xi​)−yˉ​)2
  SSE是残差平方和
S S E = ∑ i = 0 n ( y i − f ( x i ) ) 2 SSE=\sum_{i=0}^n(y_i-f(x_i))^2 SSE=i=0∑n​(yi​−f(xi​))2

2. 公式推导

2.1. 单元线性回归

这里我们跟西瓜书一样采取均方误差。

计算得w与b。

2.2. 多元线性回归

多元线性回归涉及到矩阵运算。

若X为m * n的矩阵,则 X T X X^TX XTX为n * n的方阵。 X T X X^TX XTX的意义在于保持其为可逆矩阵,因为若它不可逆,则导致其行列式为0,就会导致w趋向无穷。

3. 简单实例

3.1. 实例1:一元线性回归

计算这个二元线性回归

index x y
1 6 2
2 8 1
3 10 0
4 14 2
5 18 0

我们这里采用几种解法

  1. 西瓜书内的公式
  2. 最小二乘估计w, b
  3. linalg直接解
# -*- coding:utf-8 -*-
# 2022.09.05
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3Ddef task1_vis(x, y, w, b):fig = plt.figure()ax = fig.add_subplot(1, 1, 1)ax.scatter(x, y)x = np.linspace(0, 20, 100)y = w * x + bax.plot(x, y)# plt.title('Pizza price plotted against diameter')ax.set_xlabel('x', fontdict={'size': 10, 'color': 'black'})ax.set_ylabel('y', fontdict={'size': 10, 'color': 'black'})plt.show()def task1_way1(x, y):w = np.dot(y, (x - x.mean())).sum() / (sum(np.square(x)) - np.square(sum(x)) / x.shape[0])b = sum(y - np.multiply(w, x)) / x.shape[0]print("方法一:\t\tw:{}\tb:{}".format(w, b))def task1_way2(x, y):x_bar = x.mean()y_bar = y.mean()# 计算协方差cov = np.multiply((x - x_bar).transpose(), (y - y_bar)).sum() / (x.shape[0] - 1)var = np.var(x, ddof=1)w = cov / var# w = (y_bar - w * x_bar) / (x.shape[0])b = y_bar - w * x_barprint("方法二:\t\tw:{}\tb:{}".format(w, b))def task1_way3(x, y):from numpy.linalg import lstsqx = np.vstack([x, [1 for i in range(x.shape[0])]])w = lstsq(x.T, y.reshape(-1, 1))[0][0][0]b = lstsq(x.T, y.reshape(-1, 1))[0][1][0]print("方法三:\t\tw:{}\tb:{}".format(w, b))return w, bdef task1():x = np.array([6, 8, 10, 14, 18])y = np.array([7, 9, 13, 17.5, 18])task1_way1(x, y)task1_way2(x, y)w, b = task1_way3(x, y)task1_vis(x, y, w, b)if __name__ == '__main__':task1()

运行结果如下

实例2: 多元线性回归

# -*- coding:utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3Ddef task2():from numpy.linalg import invX = np.array([[1, 6, 2], [1, 8, 1], [1, 10, 0], [1, 14, 2], [1, 18, 0]])X[:, 2] = X[:, 1] * X[:, 1]Y = np.array([[7], [9], [13], [17.5], [18]])beita = np.dot(inv(np.dot(np.transpose(X), X)), np.dot(np.transpose(X), Y))print(beita)from numpy.linalg import lstsqprint(lstsq(X, Y)[0])if __name__ == '__main__':# task1()task2()

3.3. 实例3:房价预测

# -*- coding:utf-8 -*-
# 2022.09.05
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3Ddef task1_vis(x, y, w, b):fig = plt.figure()ax = fig.add_subplot(1, 1, 1)ax.scatter(x, y)y = w * x + bax.plot(x, y, 'r')# plt.title('Pizza price plotted against diameter')ax.set_xlabel('x', fontdict={'size': 10, 'color': 'black'})ax.set_ylabel('y', fontdict={'size': 10, 'color': 'black'})plt.show()def task1_way1(x, y):w = np.dot(y, (x - x.mean())).sum() / (sum(np.square(x)) - np.square(sum(x)) / x.shape[0])b = sum(y - np.multiply(w, x)) / x.shape[0]print("方法一:\t\tw:{}\tb:{}".format(w, b))return w, bdef task1_way2(x, y):x_bar = x.mean()y_bar = y.mean()# 计算协方差cov = np.multiply((x - x_bar).transpose(), (y - y_bar)).sum() / (x.shape[0] - 1)var = np.var(x, ddof=1)w = cov / var# w = (y_bar - w * x_bar) / (x.shape[0])b = y_bar - w * x_barprint("方法二:\t\tw:{}\tb:{}".format(w, b))def task1_way3(x, y):from numpy.linalg import lstsqx = np.vstack([x, [1 for i in range(x.shape[0])]])w = lstsq(x.T, y.reshape(-1, 1))[0][0][0]b = lstsq(x.T, y.reshape(-1, 1))[0][1][0]print("方法三:\t\tw:{}\tb:{}".format(w, b))return w, bdef task1():x = np.array([6, 8, 10, 14, 18])y = np.array([7, 9, 13, 17.5, 18])task1_way1(x, y)task1_way2(x, y)w, b = task1_way3(x, y)task1_vis(x, y, w, b)def task2():from numpy.linalg import invX = np.array([[1, 6, 2], [1, 8, 1], [1, 10, 0], [1, 14, 2], [1, 18, 0]])X[:, 2] = X[:, 1] * X[:, 1]Y = np.array([[7], [9], [13], [17.5], [18]])beita = np.dot(inv(np.dot(np.transpose(X), X)), np.dot(np.transpose(X), Y))print(beita)from numpy.linalg import lstsqprint(lstsq(X, Y)[0])def task3():x_train = np.array([77.36, 116.74, 116.7, 100.68, 116.1, 115.81, 104.24, 106.73, 115.86])y_train = np.array([470, 730, 760, 680, 700, 720, 700, 690, 730])x_test = np.array([56.6, 78.4, 58, 123.5, 56.8, 77, 150.6])w, b = task1_way1(x_train, y_train)y_pre = x_test * w + bprint(y_pre)task1_vis(x_train, y_train, w, b)if __name__ == '__main__':# task1()# task2()task3()

机器学习之线性回归原理详解、公式推导(手推)、简单实例相关推荐

  1. 《设计模式详解》手写简单的 Spring 框架

    自定义 Spring 框架 自定义 Spring 框架 Spring 使用回顾 Spring 核心功能结构 bean 概述 Spring IOC 相关接口 BeanFactory 接口 BeanDef ...

  2. ES读写原理详解和hive推送ES案例

    目录 一.ES使用场景 1.1 存储数据(基础) 1.2 搜索(核心能力) 1.3 数据分析和可视化(核心能力) 二.ES的原理 2.1 ES如何实现分布式? 2.2 ES读写数据的原理 2.2.1 ...

  3. 机器学习之K-means原理详解、公式推导、简单实例(python实现,sklearn调包)

    目录 1. 聚类原理 1.1. 无监督与聚类 1.2. K均值算法 2. 公式推导 2.1. 距离 2.2. 最小平方误差 3. 实例 3.1. python实现 3.2. sklearn实现 4. ...

  4. 机器学习-线性回归 原理详解

    一.什么是线性回归 回归算法是一种有监督算法. 回归算法可以看作是用来建立"解释"变量(自变量X)和因变量(Y)之间的关系.从机器学习的角度讲,就是构建一个算法模型来做属性(X)与 ...

  5. springmvc原理详解(手写springmvc)

    最近在复习框架 在快看小说网搜了写资料 和原理 今天总结一下 希望能加深点映像  不足之处请大家指出 我就不画流程图了 直接通过代码来了解springmvc的运行机制和原理 回想用springmvc用 ...

  6. 【机器学习】XgBoost 原理详解 数学推导

    XgBoost   (Xtreme Gradient Boosting 极限 梯度 增强) 1.基本描述:             假设Xg-模型有 t 颗决策树数,t棵树有序串联构成整个模型,各决策 ...

  7. python压缩算法_LZ77压缩算法编码原理详解(结合图片和简单代码)

    前言 LZ77算法是无损压缩算法,由以色列人Abraham Lempel发表于1977年.LZ77是典型的基于字典的压缩算法,现在很多压缩技术都是基于LZ77.鉴于其在数据压缩领域的地位,本文将结合图 ...

  8. python图片压缩原理_LZ77无损压缩算法原理详解(结合图片和简单代码)

    LZ77算法是无损压缩算法,由以色列人Abraham Lempel发表于1977年.LZ77是典型的基于字典的压缩算法,现在很多压缩技术都是基于LZ77.鉴于其在数据压缩领域的地位,本文将结合图片和源 ...

  9. 拉格朗日对偶性详解(手推笔记)

    个人原创笔记,转载请附上本文链接. 拉格朗日对偶性其实也没有那么难理解,在我梳理过后你会发现也就是那一回事罢了. 围绕着拉格朗日对偶性探讨的整个流程下来,实际上牵扯到 三个问题: 原始问题,我们记作 ...

最新文章

  1. 《深度探索C++对象模型》--3 Data语意学
  2. junit rule_Tomcat上下文JUnit @Rule
  3. 全国高等学校计算机水平考试总结,参加全国计算机等级考试的经历和总结
  4. servlet上传文件接收工具
  5. java mdpi_如何使用drawable兼容所有屏幕尺寸(idpi,mdpi,hdpi,xhdpi,xxhdpi)
  6. python读取一个文件的大小_Python-读取文件的大小
  7. memcached mysql 同步,mysql中使用UDF自动同步memcached效率笔记
  8. 文件服务器 测试,python-文件服务器测试
  9. IEC104规约学习笔记
  10. libcef-框架架构中概念介绍-命令行参数-元素布局-应用程序结构(二)
  11. 用html制作ps,ps制作图片的步骤
  12. h5调用第三方app (项目开发思路)
  13. 程序员为什么多数秃头?看完这15个瞬间,终于懂了
  14. 詹姆斯高斯林_詹姆斯·高斯林接下来要做什么?
  15. c语言公交查询系统,公交路线查询系统(基于数据结构和C语言)完整
  16. 手机修图软件测试,10款好用的手机图片编辑器软件排行榜
  17. Linux操作系统———李纳斯
  18. 《康熙王朝》剧情分集介绍【全】
  19. php开发his软件,HIS系统(his管理系统)V3.0.1 官网版
  20. python 查看处理器架构

热门文章

  1. linux系统读取plc状态,Linux系统下上位机通讯协议及PLC冗余系统组态-工业支持中心-西门子中国...
  2. BP神经网络的Java实现
  3. Flutter 2.2 更新详解
  4. 反击爬虫,前端工程师的脑洞可以有多大?
  5. MySQL系列之Natural Join用法
  6. 基于Flexlive.CQP.Framework的C# 酷Q UDP实现
  7. 生产制造业ERP系统模块
  8. 入门人工智能该读哪些书?五份AI经典书单
  9. 树的深度 递归非递归实现
  10. 试解leetcode算法题--求解方程