PySpark线性回归与广义线性模型

  • 1.线性回归
  • 2.岭回归(Ridge Regression)与LASSO回归(LASSO Regression)
  • 3.广义线性模型 (GLM)

本文为销量预测第7篇:线性回归与广义线性模型
1篇:PySpark与DataFrame简介
2篇:PySpark时间序列数据统计描述,分布特性与内部特性
3篇:缺失值填充与异常值处理
4篇:时间序列特征工程
5篇:特征选择
6篇:简单预测模型
8篇:机器学习调参方法
9篇:销量预测建模中常用的损失函数与模型评估指标

本节从原理和代码上讲解销量预测任务中使用到的Spark.ML内置线性回归模型和广义线性模型。

1.线性回归

回归分析是预测建模中最为基础的技术,通过拟合回归线(regression line)建立输入变量与目标变量之间的线性关系,构建损失函数,求解损失函数最小时的参数w和b。表达式如下:

y^=wx+b其中,y是因变量,x是自变量,w是斜率,b为截距;\hat{y}=w x+b\\ 其中,y是因变量,\ x是自变量, \ w是斜率, \ b为截距; y^​=wx+b其中,y是因变量, x是自变量, w是斜率, b为截距;
损失函数为:
L(w,b)=1n∑i=1n(wxi+b−yi)2L(w, b)=\frac{1}{n} \sum_{i=1}^{n}\left(w x_{i}+b-y_{i}\right)^{2} L(w,b)=n1​i=1∑n​(wxi​+b−yi​)2

利用梯度下降(gradient descent)迭代更新(针对w和b求偏导)。
W←W−α∂J(W)∂Wb←b−α∂L∂bW \leftarrow W-\alpha \frac{\partial J(W)}{\partial W} \\ b \leftarrow b-\alpha \frac{\partial L}{\partial b} W←W−α∂W∂J(W)​b←b−α∂b∂L​

线性回归模型计算效率高,可解释强,具有完备的数量统计理论支撑等优点,作为基础模型广泛使用在回归建模任务中,但也存在针对非正态分布的数据,线性表达能力较弱,若多个自变量间存在共线性,求解的模型参数不稳定,导致预测能力下降等问题,故有下文提到的Ridge Regression、LASSO Regression以及广义线性回归来解决或补充。

2.岭回归(Ridge Regression)与LASSO回归(LASSO Regression)

当使用最小二乘法计算线性回归模型参数时,如果数据特征之间存在多重共线性,那么最小二乘法对输入变量中的噪声非常敏感,其解会变得极为不稳定。为了解决这个问题,就有了岭回归(Ridge Regression )。

岭回归(Ridge Regression)是在一般线性回归的基础上加入L2正则项,通过限制参数权重(回归)系数,使weight不会变得特别大,则模型对输入特征中的噪声敏感度就会降低,故在保证最佳拟合误差的同时,参数尽可能的“简单”,模型的泛化能力得以增强,岭回归公式如下:

J(θ)=1m∑i=1m(y(i)−(wx(i)+b))2+λ∥w∥22=MSE(θ)+λ∑i=1nθi2J(\theta)=\frac{1}{m} \sum_{i=1}^{m}\left(y^{(i)}-\left(w x^{(i)}+b\right)\right)^{2}+\lambda\|w\|_{2}^{2}\\ =M S E(\theta)+\lambda \sum_{i=1}^{n} \theta_{i}^{2} J(θ)=m1​i=1∑m​(y(i)−(wx(i)+b))2+λ∥w∥22​=MSE(θ)+λi=1∑n​θi2​

迭代优化函数如下
θj:=θj−α∑i=1m(hθ(x(i))−y(i))xj(i)−2λθj\theta_{j}:=\theta_{j}-\alpha \sum_{i=1}^{m}\left(h_{\theta}\left(x^{(i)}\right)-y^{(i)}\right) x_{j}^{(i)}-2 \lambda \theta_{j} θj​:=θj​−αi=1∑m​(hθ​(x(i))−y(i))xj(i)​−2λθj​

LASSO(Least absolute shrinkage and selection operator, Tibshirani(1996))方法是一种压缩估计。其基本思想是在构建L1正则化,在回归系数的绝对值之和小于一个常数的约束条件下,使残差平方和最小化,将一些作用较小的特征的权重系数置为0,从而获得稀疏解,实现了降维(特征筛选)的同时解决多重共线性的问题。

LASSO的代价函数为:

J(θ)=12m∑i=1m(y(i)−(wx(i)+b))2+λ∥w∥1=12MSE(θ)+λ∑i=1n∣θi∣J(\theta)=\frac{1}{2 m} \sum_{i=1}^{m}\left(y^{(i)}-\left(w x^{(i)}+b\right)\right)^{2}+\lambda\|w\|_{1} \\ = \frac{1}{2} M S E(\theta)+\lambda \sum_{i=1}^{n}\left|\theta_{i}\right| \quad J(θ)=2m1​i=1∑m​(y(i)−(wx(i)+b))2+λ∥w∥1​=21​MSE(θ)+λi=1∑n​∣θi​∣

示例代码

import sys
import datetime
from pyspark.sql import SparkSession
from pyspark.ml.regression import LinearRegression
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import Normalizer
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoderimport warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings('ignore')spark = SparkSession. \Builder(). \config("spark.sql.crossJoin.enabled", "true"). \config("spark.sql.execution.arrow.enabled", "false"). \enableHiveSupport(). \getOrCreate()class linear_predict(object):def __init__(self,data,importance_feature,reg, inter, elastic):self.data=dataself.importance_feature=importance_featureself.reg=regself.inter=interself.elastic=elasticdef prediction(self):reg, inter, elastic=self.reg,self.inter,self.elasticdf=self.datainputCols = self.importance_featuredf = df.na.fill(0)df = df.withColumn('dayofweek', dayofweek('dt'))df = df.withColumn("dayofweek", df["dayofweek"].cast(StringType()))dayofweek_ind = StringIndexer(inputCol='dayofweek', outputCol='dayofweek_index')dayofweek_ind_model = dayofweek_ind.fit(df)dayofweek_ind_ = dayofweek_ind_model.transform(df)onehotencoder = OneHotEncoder(inputCol='dayofweek_index', outputCol='dayofweek_Vec')df = onehotencoder.transform(dayofweek_ind_)feature_vector = VectorAssembler(inputCols=inputCols, outputCol="features")output = feature_vector.transform(df)features_label = output.select("shop_number", "item_number", "dt", "features", "label")train_set =features_label.where(features_label['dt'] <='2020-08-28')train_data, val_data = train_set.randomSplit([0.8, 0.2])pred_data=features_label.where(features_label['dt']>'2020-08-28').where(features_label['dt']<'2020-09-01')lr = LinearRegression(regParam=reg, fitIntercept=inter, elasticNetParam=elastic,solver="normal")model = lr.fit(train_data)print('{}{}'.format('model_intercept:', model.intercept))print('{}{}'.format('model_coeff:', model.coefficients))feature_map=dict(zip(inputCols, model.coefficients))print("feature_map",feature_map)#model predictionpredictions = model.transform(pred_data)predictions.select("shop_number", "item_number", "dt","prediction").createOrReplaceTempView('linear_predict_out')insert_sql="""insert overwrite table scm.linear_regression_prediction partition (dt='{dt}')selectstore_code,goods_code,dt,predictionfrom linear_predict_out"""spark.sql(insert_sql)spark.stop()def read_importance_feature():""":return: list of importance of feature"""importance_feature = spark.sql("""select feature from app.selection_result_v1 where cum_sum<0.95 and update_date in (select max(update_date) as update_date from app.selection_result_v1)""").select("feature").collect()importance_list = [row.feature for row in importance_feature]print('..use'+str(len(importance_list))+'numbers of feature...')return importance_listdef main():data=spark.sql("""select * from app.dataset_input_df_v2 where dt>='2020-08-04'""")importance_feature=read_importance_feature()reg, inter, elastic = 0.5,False,1.0linear_predict(data, importance_feature, reg, inter, elastic).prediction()if __name__ == '__main__':main()

其中LinearRegression的主要参数含义如下:

  • *regParam:正则化参数。用于防止过拟合
  • elasticNetParam:取值范围[0,1]。取 0时,采用L2。取 1时,采用L1正则化。
  • fitIntercept:是否拟合截距项 True(默认)/False
  • standardization:模型拟合前是否对训练特征进行标准化处理
  • solver:求解算法的优化。支持的选项:auto, normal, l-bfgs
  • aggregationDepth:树栅建议深度(>= 2)
  • loss:模型待优化的损失函数。选项有:squaredError, huber。
  • epsilon:对形状参数进行鲁棒性控制。必须是> 1.0。只有在损失函数是huber时才有效

弹性网络(Elastic Net),是在岭回归和Lasso回归中进行了折中,elasticNetParam=0时为Ridge Regression,elasticNetParam=1时为LASSO Regression。

L=min⁡(MSE+λ(1−αα∥ω∥2+α∥ω∥))=min⁡w12n∑i=1n(Xiw−yi)2+λ[1−α2∥w∥22+α∥w∥1]L=\min \left(M S E+\lambda\left(\frac{1-\alpha}{\alpha}\|\omega\|^{2}+\alpha\|\omega\|\right)\right)\\=\min _{w} \frac{1}{2 n} \sum_{i=1}^{n}\left(X_{i} w-y_{i}\right)^{2}+\lambda\left[\frac{1-\alpha}{2}\|w\|_{2}^{2}+\alpha\|w\|_{1}\right] L=min(MSE+λ(α1−α​∥ω∥2+α∥ω∥))=wmin​2n1​i=1∑n​(Xi​w−yi​)2+λ[21−α​∥w∥22​+α∥w∥1​]

另外,也可以按照前文所阐述的多项式特征生成方法,放入多项式特征,从而构建多项式回归。

3.广义线性模型 (GLM)

广义线性模型,是为了克服线性回归模型缺点而出现,是对一般线性模型的扩展。首先自变量可以是离散的,也可以是连续的。离散的可以是0-1变量,也可以是多种取值的计数变量。与线性回归模型相比较,主要有以下推广:

(1)随机误差项不一定服从正态分布,可以服从二项、泊松、负二项、正态、伽马、逆高斯等指数分布族。

(2)引入联接函数。因变量和自变量通过联接函数产生影响,联接函数满足单调可导。

链接函数描述了线性预测XβXβXβ与分布期望值E[Y]E[Y]E[Y]的关系:E[Y]=μ=g−1(Xβ)E[Y]=\mu=g^{-1}(X \beta)E[Y]=μ=g−1(Xβ),其中ggg表示链接函数,μμμ表示均值函数。 一般情况下,高斯分布对应于恒等式,泊松分布对应于自然对数函数等。常用的联接函数有:
Y=X∗βY=ln(X∗β)Y=(X∗β)ln(Y/(1−Y))=X∗β\\ Y= X*β\\ Y=ln(X*β)\\ Y= \sqrt{(X*β)}\\ ln(Y/(1-Y))=X*β \\ Y=X∗βY=ln(X∗β)Y=(X∗β)​ln(Y/(1−Y))=X∗β

针对广义线性回归在销量预测上应用,本节以泊松回归为实例进行介绍,泊松(Poisson)回归是广义线性模型中常用的一种,因变量服从Poisson分布。

在很多情况下销量数据经常计数的形式出现,如每天进店客流,或部分商品销售量会有 0, 1, 2…等计数的形式出现。如果计数值很大,销售数量分布于连续的样本空间 [100,∞),则[100,∞)与离散样本空间 {100,101,102,…}之间的差异对预测没有显著的影响,如销售数据以较小的计数值 (0,1,2,…)出现,那么就需要使用更适合非负整数的预测方法,也就是泊松回归。

在正式使用泊松回归之前,先解决一个疑问,那就是针对非正态的数据通过取对数处理把y值转成正态或者接近正态以后放入模型是否是合理的处理方式呢?而这种方法也是大多数人常用的解决办法,这里特意阐述一下取对数的弊端。

针对数据分布呈现偏态的问题,通常也有直接转对数处理,但是这种做法可能并不合理,因为在使用对数线性模型的时,隐含的一个假设就是y服从正态分布。虽然取对数能够缓解因 yyy的波动性大带来的异方差和极端值影响。针对销量数据, yyy显然只能取非负数。但销量为0或者1是经常出现的,而如果想要取对数,必须保证y>0且不等于1才有意义。当然,也可以对yyy取对数之前加上一个值,使yyy都大于0,比如所以的销售量加2,但是这样会导致估计量的不一致性(Santos Silva,J. M. C and S. Tenreyro ( 2006 )),并且如果0值多,那么因变量yyy微小的调整就会导致模型估计系数波动,也降低了模型的解释力。

如图所示,销量分布呈现,右偏。如下图。

spark.ml中的广义线性回归,通过以下方式调用,其他主体部分和以上线性回归代码相同。

from pyspark.ml.regression import GeneralizedLinearRegression
glr = GeneralizedLinearRegression(family="poisson", link="log", regParam=1.0)

在广义线性回归模型中,其分布于对应的link function可选择如下:

  • gaussian : identity, log, inverse

  • binomial : logit, probit, cloglog

  • poisson : log, identity, sqrt

  • gamma:inverse, identity, log

  • tweedie : 使用tweedie分布时,需指定linkPower参数,默认为1

在销量预测领域还有一个值得关注和尝试的广义线性是tweedie分布,Tweedie分布是一种泊松分布和伽马分布的复合。

以上就是关于线下回归和广义线性回归在预测任务中的实战示例。

PySpark线性回归与广义线性模型相关推荐

  1. UA MATH571A 多元线性回归IV 广义线性模型

    UA MATH571A 多元线性回归IV 广义线性模型 广义线性模型 二值被解释变量 Probit模型 Logit模型 系数的最大似然估计 系数的推断 Wald检验 似然比检验 二项回归 拟合优度检验 ...

  2. 从线性回归到广义线性模型

    在谈及广义线性模型是什么之前,我想先分析一下线性回归模型有什么限制.在这里先说明一点,以下分析的线性回归模型,是考虑了随机误差项的完整的线性回归模型. 首先我想说明一下关于响应变量(y)的分布,以及响 ...

  3. ML—广义线性模型导论

    Andrew Zhang Tianjin Key Laboratory of Cognitive Computing and Application Tianjin University Nov 3, ...

  4. 指数分布族函数与广义线性模型(Generalized Linear Models,GLM)

    目录 1 综述 2 指数分布族 3 广义线性模型 3.1 定义 3.2 为什么引入GLM 3.3 连接函数的获取 4 常见连接函数求解及对应回归 4.1 伯努利分布 > Logistics回归 ...

  5. 线性回归、logistic回归、广义线性模型——斯坦福CS229机器学习个人总结(一)

    纪念我第一个博客的碎碎念 先前我花了四五个月的业余时间学习了Ng的机器学习公开课,学习的过程中我就在想,如果我能把这个课程啃完,就开始写一些博客,把自己的所得记录下来,现在是实现的时候了.也如刘未鹏的 ...

  6. 广义线性模型(Generalized Linear Models, GLM)与线性回归、逻辑回归的关系

    线性回归和逻辑回归都是广义线性模型的特例. 1 指数分布族 如果一个分布可以用如下公式表达,那么这个分布就属于指数分布族. 这是<数理统计>课本中的相关定义,大多数利用的定义如下(y不是一 ...

  7. R语言广义线性模型Logistic回归案例代码

    R语言广义线性模型Logistic回归案例代码 在实际应用中,Logistic模型主要有三大用途: 1)寻找危险因素,找到某些影响因变量的"坏因素",一般可以通过优势比发现危险因素 ...

  8. R语言广义线性模型泊松回归(Poisson Regression)模型

    R语言广义线性模型泊松回归(Poisson Regression)模型 试想一下,你现在就站在一个人流密集的马路旁,打算收集闯红灯的人群情况(?).首先,利用秒表和计数器,一分钟过去了,有5个人闯红灯 ...

  9. 广义线性模型?链接函数?sigmoid和softmax?Logistic处理多分类问题?logistic回归处理超大数据?使用logistic和randomsearch进行组合获取最优参数组合、优缺点

    广义线性模型?链接函数?sigmoid和softmax?Logistic处理多分类问题?logistic回归处理超大数据?使用logistic和randomsearch进行组合获取最优参数组合.优缺点 ...

最新文章

  1. 复习PHP-语言参考-类型
  2. ble之Transmit window offset and Transmit window size
  3. 30分钟掌握ES6/ES2015核心内容(上)
  4. [Leetcode] 第289题 生命游戏
  5. 如何优化JavaScript脚本的性能
  6. xshell使用xftp传输文件和使用pure-ftpd搭建ftp服务
  7. 2018年10月17日普级B组【模拟赛】
  8. php全部公开课,PHP公开课|这篇PHP的each()函数教学数,只为了帮你的PHP会学的更好...
  9. object detection错误之no module named nets
  10. 收藏 | Yann Lecun纽约大学《深度学习》2020课程笔记中文版
  11. 初中数学分几个模块_11.初中数学:xy4x+6y5,怎么因式分解?分组配方法再平方差...
  12. dsoframer java_基于DsoFramer控件的Office编辑控件
  13. 八种常规常用的SQL查询语句
  14. 试图执行系统不支持的操作,问题
  15. 比特大陆发布终端 AI 芯片 端云联手聚焦安防
  16. 关于RSS的聚合---OPML
  17. 最新八个免费Logo设计工具灵感网站,帮你搞定logo设计难题!
  18. 为什么我的燃尽图没更新?
  19. mac创建文件服务器,mac命令行终端怎么创建文件 mac命令行终端创建文
  20. 网络虚拟化城域网创新实践

热门文章

  1. charAt(i) 呵 charAt(i)-‘0‘的 区别
  2. 计算机网络第七版(谢希仁) 第一章 概述 1-10,1-17作业答案
  3. SVG奥林匹克五环动画
  4. 云队友丨真正限制你的,是你对潜力的一无所知
  5. Python之小波分析
  6. c# 禁用鼠标与键盘
  7. luogu P4233 射命丸文的笔记
  8. Walle多渠道打包
  9. 职场小白新建SSM项目
  10. 浅谈精益数字化工厂(Lean Digital Factory, LDF)