Datawhale

作者:尹晓丹,Datawhale优秀学习者

寄语:本文对线性回归算法的原理及模型,学习策略、算法求解和sklearn参数做了详细的讲解。同时,用例子进行Python代码实践。

线性回归是利用数理统计中的回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方法,是机器学习最基础的算法之一。

学习框架

模型建立

线性回归原理

进入一家房产网,可以看到房价、面积、厅室呈现以下数据:

将价格和面积、厅室数量的关系习得为:
































使得















, 这就是一个直观的线性回归的样式。

线性回归模型

1. 线性回归的一般形式

有数据集,其中,其中




表示变量的数量,




表示每个变量的维度。

可以用以下函数来描述y和x之间的关系:

如何来确定




的值,使得







尽可能接近




的值呢?均方误差是回归中常用的性能度量,即:

我们可以选择




,试图让均方误差最小化。

2. 极大似然估计(概率角度阐释)

下面我们用极大似然估计,来解释为什么要用均方误差作为性能度量。可以把目标值和变量写成如下等式:








































表示我们未观测到的变量的印象,即随机噪音。我们假定




是独立同分布,服从高斯分布。(根据中心极限定理)

因此,

我们建立极大似然函数,即描述数据遵从当前样本分布的概率分布函数。由于样本的数据集独立同分布,因此可以写成:

选择 ???? ,使得似然函数最大化,这就是极大似然估计的思想。为了方便计算,我们计算时通常对对数似然函数求最大值:

显然,最大化







即最小化:

这一结果即均方误差,因此用这个值作为代价函数来优化模型在统计学的角度是合理的。

学习策略

1. 损失函数(Loss Function)

度量单样本预测的错误程度,损失函数值越小,模型就越好。常用的损失函数包括:0-1损失函数、平方损失函数、绝对损失函数、对数损失函数等。

2. 代价函数(Cost Function)

度量全部样本集的平均误差。常用的代价函数包括均方误差、均方根误差、平均绝对误差等。

3. 目标函数(Object Function)

代价函数和正则化函数,最终要优化的函数。

4. 思考题

既然代价函数已经可以度量样本集的平均误差,为什么还要设定目标函数?

答:当模型复杂度增加时,有可能对训练集可以模拟的很好,但是预测测试集的效果不好,出现过拟合现象,这就出现了所谓的“结构化风险”。结构风险最小化即为了防止过拟合而提出来的策略,定义模型复杂度为 ????(????) ,目标函数可表示为:

例如有以下6个房价和面积关系的数据点,可以看到,当设定:



























时,可以完美拟合训练集数据,但是,真实情况下房价和面积不可能是这样的关系,出现了过拟合现象。当训练集本身存在噪声时,拟合曲线对未知影响因素的拟合往往不是最好的。

通常,随着模型复杂度的增加,训练误差会减少;但测试误差会先增加后减小。我们的最终目的时试测试误差达到最小,这就是我们为什么需要选取适合的目标函数的原因。

算法求解

梯度下降法

设定初始参数




,不断迭代,使得







最小化:































对其求导为:

即:

将所有的参数以向量形式表示,可得:

由于这个方法中,参数在每一个数据点上同时进行了移动,因此称为批梯度下降法,对应的,我们可以每一次让参数只针对一个数据点进行移动,即:

这个算法称为随机梯度下降法,随机梯度下降法的好处是,当数据点很多时,运行效率更高;

其缺点是,因为每次只针对一个样本更新参数,未必找到最快路径达到最优值,甚至有时候会出现参数在最小值附近徘徊而不是立即收敛。但当数据量很大的时候,随机梯度下降法经常优于批梯度下降法。






为凸函数时,梯度下降法相当于让参数




不断向




的最小值位置移动。

梯度下降法的缺陷:如果函数为非凸函数,有可能找到的并非全局最优值,而是局部最优值。

最小二乘法矩阵求解

令:

其中,

由于






































































对于向量来说,有




















因此可以把损失函数写作





























为最小化







,对




求导可得:

中间两项互为转置,由于求得的值是个标量,矩阵与转置相同,因此可以写成

令偏导数等于零,由于最后一项和 ???? 无关,偏导数为0。因此,

利用矩阵求导性质,































































令导数等于零,











































牛顿法


































可求得:






































重复迭代,可以让逼近取到











的最小值。当我们对损失函数











进行优化的时候,实际上是想要取到
















的最小值,因此迭代公式为:


































是向量值的时候,


























的偏导数,











的海森矩阵,





































问题:请用泰勒展开法推导牛顿法公式。

答:将泰勒公式展开到二阶:

对上式求导,并令导数等于0,求得




可以求得,






































牛顿法的收敛速度非常快,但海森矩阵的计算较为复杂,尤其当参数的维度很多时,会耗费大量计算成本。我们可以用其他矩阵替代海森矩阵,用拟牛顿法进行估计。

牛顿法比梯度下降法收敛速度更快,红色的牛顿法的迭代路径,绿色的是梯度下降法的迭代路径。

拟牛顿法

常用的拟牛顿法算法包括DFP,BFGS等。拟牛顿法的思路是用一个矩阵替代计算复杂的海森矩阵




,因此要找到符合H性质的矩阵。

要求得海森矩阵符合的条件,同样对泰勒公式求导

即迭代后的值,代入可得:

更一般的,

为第k个迭代值。即找到矩阵




,使得它符合上式。

线性回归的评估指标

均方误差(MSE):















































均方根误差(RMSE):

平均绝对误差(MAE):













































但以上评价指标都无法消除量纲不一致而导致的误差值差别大的问题,最常用的指标是







,可以避免量纲不一致问题。

我们可以把







理解为,回归模型可以成功解释的数据方差部分在数据固有方差中所占的比例,







越接近1,表示可解释力度越大,模型拟合的效果越好。

sklearn参数详解

1. it_intercept

默认为True,是否计算该模型的截距。如果使用中心化的数据,可以考虑设置为False,不考虑截距。一般还是要考虑截距。

2. normalize

默认为false. 当fit_intercept设置为false的时候,这个参数会被自动忽略。如果为True,回归器会标准化输入参数:减去平均值,并且除以相应的二范数。当然啦,在这里还是建议将标准化的工作放在训练模型之前。通过设置sklearn.preprocessing.StandardScaler来实现,而在此处设置为false。

3. copy_X

默认为True, 否则X会被改写

4. n_jobs

int 默认为1. 当-1时默认使用全部CPUs ??(这个参数有待尝试)

5. 可用属性

**coef_????*训练后的输入端模型系数,如果label有两个,即y值有两列。那么是一个2D的array

6. intercept_: 截距

7. 可用的methods

  • fit(X,y,sample_weight=None):

    X: array, 稀疏矩阵 [n_samples,n_features]

    y: array [n_samples, n_targets]

    sample_weight: 权重 array [n_samples] 在版本0.17后添加了sample_weight

  • get_params(deep=True):返回对regressor 的设置值

  • predict(X): 预测 基于 R^2值

  • score:评估

练习题

请用以下数据(可自行生成尝试,或用其他已有数据集)

  • 首先尝试调用sklearn的线性回归函数进行训练;

  • 用最小二乘法的矩阵求解法训练数据;

  • 用梯度下降法训练数据;

  • 比较各方法得出的结果是否一致。

1. sklearn的线性回归
生成数据:

#生成数据
import numpy as np
#生成随机数
np.random.seed(1234)
x = np.random.rand(500,3)
#构建映射关系,模拟真实的数据待预测值,映射关系为y = 4.2x1 + 5.7*x2 + 10.8*x3,可自行设置值进行尝试
y = x.dot(np.array([4.2,5.7,10.8]))
import numpy as np
from sklearn.linear_model import LinearRegression
import matplotlib.pyplot as plt
%matplotlib inlinelr = LinearRegression(fit_intercept=True)# 默认即可
#训练model
lr.fit(x, y)
print("估计的参数值:%s"%(lr.coef_))
print("估计的截距:%s"%(lr.intercept_))
#计算R方
print('R2:',(lr.score(x,y)))
#测试
x_test = np.array([4,5,7]).reshape(1,-1)
y_hat = lr.predict(x_test)
print('真实值为:',x_test.dot(np.array([4.2,5.7,10.8])))
print('预测值为:',y_hat)

2. 最小二乘法

class LR_LS():def __init__(self):self.w = None      def fit(self, X, y):# 最小二乘法矩阵求解#============================= show me your code =======================self.w = np.linalg.inv(X.T.dot(X)).dot(X.T).dot(y)#============================= show me your code =======================def predict(self, X):# 用已经拟合的参数值预测新自变量#============================= show me your code =======================y_pred = X.dot(self.w)#============================= show me your code =======================return y_predif __name__ == "__main__":lr_ls = LR_LS()lr_ls.fit(x,y)print("估计的参数值:%s" %(lr_ls.w))x_test = np.array([4,5,7]).reshape(1,-1)print('真实值为:',x_test.dot(np.array([4.2,5.7,10.8])))print("预测值为: %s" %(lr_ls.predict(x_test)))

3. 梯度下降法

class LR_GD():def __init__(self):self.w = None     def fit(self,X,y,alpha=0.002,loss = 1e-10): # 设定步长为0.002,判断是否收敛的条件为1e-10y = y.reshape(-1,1) #重塑y值的维度以便矩阵运算[m,d] = np.shape(X) #自变量的维度self.w = np.zeros((d)) #将参数的初始值定为0tol = 1e5#============================= show me your code =======================while tol > loss:h_f = X.dot(self.w).reshape(-1,1) theta = self.w + alpha*np.mean(X*(y - h_f),axis=0) #计算迭代的参数值tol = np.sum(np.abs(theta - self.w))self.w = theta#============================= show me your code =======================def predict(self, X):# 用已经拟合的参数值预测新自变量y_pred = X.dot(self.w)return y_pred  if __name__ == "__main__":lr_gd = LR_GD()lr_gd.fit(x,y)print("估计的参数值为:%s" %(lr_gd.w))x_test = np.array([4,5,7]).reshape(1,-1)print('真实值为:',x_test.dot(np.array([4.2,5.7,10.8])))print("预测值为:%s" %(lr_gd.predict(x_test)))

4. 测试

在3维数据上测试sklearn线性回归和最小二乘法的结果相同,梯度下降法略有误差;又在100维数据上测试了一下最小二乘法的结果比sklearn线性回归的结果更好一些。

往期精彩回顾适合初学者入门人工智能的路线及资料下载机器学习在线手册深度学习在线手册AI基础下载(pdf更新到25集)本站qq群1003271085,加入微信群请回复“加群”获取一折本站知识星球优惠券,复制链接直接打开:https://t.zsxq.com/yFQV7am喜欢文章,点个在看

我感觉这是目前讲得最明白的线性回归的文章了相关推荐

  1. 可能是全网把 ZooKeeper 概念讲的最清楚的一篇文章

    前言 相信大家对 ZooKeeper 应该不算陌生.但是你真的了解 ZooKeeper 是个什么东西吗?如果别人/面试官让你给他讲讲  ZooKeeper 是个什么东西,你能回答到什么地步呢? 我本人 ...

  2. 可能是把Docker的概念讲的最清楚的一篇文章

    本文只是对Docker的概念做了较为详细的介绍,并不涉及一些像Docker环境的安装以及Docker的一些常见操作和命令. 阅读本文大概需要15分钟,通过阅读本文你将知道一下概念: 容器 什么是Doc ...

  3. 这可能是把Docker的概念讲的最清楚的一篇文章

    转载自  这可能是把Docker的概念讲的最清楚的一篇文章 Docker 是世界领先的软件容器平台,本文主要来介绍下关于Docker的那些事儿,主要包含以下内容: 容器 什么是Docker? Dock ...

  4. 【转载】可能是把Docker的概念讲的最清楚的一篇文章

    本文只是对Docker的概念做了较为详细的介绍,并不涉及一些像Docker环境的安装以及Docker的一些常见操作和命令.我觉得是很适合当做睡前读物了~~~~? 阅读本文大概需要15分钟,通过阅读本文 ...

  5. java eden区_(转)可能是把Java内存区域讲的最清楚的一篇文章

    写在前面 本节常见面试题: 问题答案在文中都有提到 如何判断对象是否死亡(两种方法). 简单的介绍一下强引用.软引用.弱引用.虚引用(虚引用与软引用和弱引用的区别.使用软引用能带来的好处). 如何判断 ...

  6. 冲刺!这篇 1658 页的《Java 面试突击核心讲》学明白保底年薪 30w

    前言 2022 年已经到了七月中旬了,又快要到一年一度的 "金九银十" 秋招大热门,为助力广大程序员朋友 "面试造火箭",小编今天给大家分享的便是这份--165 ...

  7. java主要内存区域_可能是把Java内存区域讲的最清楚的一篇文章

    该楼层疑似违规已被系统折叠 隐藏此楼查看此楼 介绍下 Java 内存区域(运行时数据区) Java 对象的创建过程(五步,建议能默写出来并且要知道每一步虚拟机做了什么) 对象的访问定位的两种方式(句柄 ...

  8. [No0000187]可能是把Java内存区域讲的最清楚的一篇文章

    写在前面(常见面试题) 基本问题: 介绍下 Java 内存区域(运行时数据区) Java 对象的创建过程(五步,建议能默写出来并且要知道每一步虚拟机做了什么) 对象的访问定位的两种方式(句柄和直接指针 ...

  9. 设计模式:Abstract Factory和Builder(比较区别,个人认为讲得很明白)

    如果说 Factory和Prototype是同一个层次的话,那么Abstract Factory和Builder就是更高一级的层次. 1 Abstact Factory 在上面的Factory类型中, ...

最新文章

  1. 散列表(也叫哈希表),
  2. 使用sumlime text有感
  3. asp.net 二级域名(路由方式实现)
  4. 【Java多线程】并发时的线程安全:快乐影院示例
  5. 那个拒绝北大教授,却坚持留在美国做服务员的数学天才,现状如何
  6. Ubuntu 12.04 安装配置 Apache2
  7. Java基础学习总结(154)——Synchronized与Volatile、Synchronized与ReentrantLock概念及区别
  8. Word Cookbook by Eric
  9. oracle分区区别,oracle范围分区表和INTERVAL分区表对于SPLIT分区的区别
  10. Oracle 11g 数据恢复 数据误删除后的恢复 0、执行 select log_mode from v$database;查看是否为归档模式 1、确定删除时间和被删除的表 04-23,GR
  11. Mysql 索引的学习
  12. 【Delta并联机器人Simscape仿真(正运动学、逆运动学)】
  13. 树梅派应用27:通过USB蓝牙适配器连接BLE设备
  14. win7更换锁屏壁纸(操作步骤)
  15. Latex中在字母上加上波浪线
  16. Executor框架的使用
  17. 1 python编程基础学习
  18. 提高生产力:Web开发基础平台WebCommon的设计和实现
  19. Android-谷歌语音识别之离线识别(二)
  20. 人脸与关键点检测:YOLO5Face实战

热门文章

  1. 同步和异步GET,POST请求
  2. Android中Bitmap和Drawable
  3. 设置Windbg符号文件路径
  4. element ui分页怎么做_element ui里面table分页,页数从0开始的怎么做?
  5. Spring Boot 打成war包部署到tomcat8.5.20报无法访问
  6. 基于MFCC系数的欧氏距离测量
  7. 数据结构学习笔记(四):重识数组(Array)
  8. java 4 7怎么算术运算_java四则运算
  9. android小米计算器布局,小米这8个逆天小功能一定用起来!不会用,手机简直白买...
  10. carsim学习笔记4——路面的一些设置1