分位数回归(Quantile regression)是在给定 X \mathbf{X} X的条件下估计 y \mathbf{y} y的中位数或其他分位数, 这是与最小二乘法估计条件均值最大的不同。

分位数回归也是一种线性回归,它为第 q q q个分位数( q ∈ ( 0 , 1 ) q\in (0, 1) q∈(0,1))训练得到线性预测 y ^ ( w , X ) = X w \hat{y}(w, \mathbf{X})=\mathbf{Xw} y^​(w,X)=Xw, 权重 w w w通过最小化下面的公式得到
min ⁡ w 1 n samples ∑ i P B q ( y i − X i w ) + α ∣ ∣ w ∣ ∣ 1 . \min_{w} {\frac{1}{n_{\text{samples}}} \sum_i PB_q(y_i - X_i w) + \alpha ||w||_1}. wmin​nsamples​1​i∑​PBq​(yi​−Xi​w)+α∣∣w∣∣1​.
其中的 P B PB PB 是pinball loss(也被称为linear loss), 定义如下式, α \alpha α 调整L1损失的大小。
P B q ( t ) = q max ⁡ ( t , 0 ) + ( 1 − q ) max ⁡ ( − t , 0 ) = { q t , t > 0 0 , t = 0 ( q − 1 ) t , t < 0 P B_q(t)=q \max (t, 0)+(1-q) \max (-t, 0)= \begin{cases}q t, & t>0 \\ 0, & t=0 \\ (q-1) t, & t<0\end{cases} PBq​(t)=qmax(t,0)+(1−q)max(−t,0)=⎩ ⎨ ⎧​qt,0,(q−1)t,​t>0t=0t<0​

分位数回归的特点:

  • 分位数回归对于异常点没有那么敏感
  • 对于数据也不要求完全符合正态分布,不用假设数据的分布服从方差固定的分布。
  • 当我们的预测结果是一个区间而不是一个点的时候更有用
  • 分位数回归相比于线性回归减少的是MAE

当然相比于普通的线性回归,分位数要求更多的训练数据,同时也比线性回归的计算更复杂。

在下面的图片是对加了符合帕累托分布( Pareto Distribution)噪声的数据进行分位数回归学习的结果

相应的示例代码如下:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils.fixes import sp_version, parse_version
from sklearn.linear_model import QuantileRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import cross_validate# 为了避免老版本SciPy的不兼容问题
solver = "highs" if sp_version >= parse_version("1.6.0") else "interior-point"# 生成样本
rng = np.random.RandomState(42)
x = np.linspace(start=0, stop=10, num=100)
X = x[:, np.newaxis]
y_true_mean = 10 + 0.5 * x# 生成满足pareto distribution的数据
y_pareto = y_true_mean + 10 * (rng.pareto(a, size=x.shape[0]) - 1 / (a - 1))# 生成分位数为0.05, 0.5, 0.95的分位数回归
quantiles = [0.05, 0.5, 0.95]
predictions = {}
out_bounds_predictions = np.zeros_like(y_true_mean, dtype=np.bool_)
for quantile in quantiles:qr = QuantileRegressor(quantile=quantile, alpha=0, solver=solver)y_pred = qr.fit(X, y_pareto).predict(X)predictions[quantile] = y_predif quantile == min(quantiles):out_bounds_predictions = np.logical_or(out_bounds_predictions, y_pred >= y_pareto)elif quantile == max(quantiles):out_bounds_predictions = np.logical_or(out_bounds_predictions, y_pred <= y_pareto)plt.plot(X, y_true_mean, color="black", linestyle="dashed", label="True mean")for quantile, y_pred in predictions.items():plt.plot(X, y_pred, label=f"Quantile: {quantile}")plt.scatter(x[out_bounds_predictions],y_pareto[out_bounds_predictions],color="black",marker="+",alpha=0.5,label="Outside interval",
)
plt.scatter(x[~out_bounds_predictions],y_pareto[~out_bounds_predictions],color="black",alpha=0.5,label="Inside interval",
)
plt.legend()
plt.xlabel("x")
plt.ylabel("y")
_ = plt.title("Quantiles of asymmetric Pareto distributed target")# 使用交叉验证比较线性回归与分位数回归
linear_regression = LinearRegression()
quantile_regression = QuantileRegressor(quantile=0.5, alpha=0, solver=solver)cv_results_lr = cross_validate(linear_regression,X,y_pareto,cv=3,scoring=["neg_mean_absolute_error", "neg_mean_squared_error"],
)
cv_results_qr = cross_validate(quantile_regression,X,y_pareto,cv=3,scoring=["neg_mean_absolute_error", "neg_mean_squared_error"],
)
print(f"""Test error (cross-validated performance){linear_regression.__class__.__name__}:MAE = {-cv_results_lr["test_neg_mean_absolute_error"].mean():.3f}MSE = {-cv_results_lr["test_neg_mean_squared_error"].mean():.3f}{quantile_regression.__class__.__name__}:MAE = {-cv_results_qr["test_neg_mean_absolute_error"].mean():.3f}MSE = {-cv_results_qr["test_neg_mean_squared_error"].mean():.3f}"""
)#Test error (cross-validated performance)
#    LinearRegression:
#    MAE = 1.732
#    MSE = 6.690
#    QuantileRegressor:
#    MAE = 1.679
#    MSE = 7.129

keras 定义pinball loss 示例

# multiple quantiles损失定义示例
def quantile_regression_loss0(y_true, y_pred, qs=[0.025, 0.1, 0.5, 0.9, 0.975]):q = tf.constant(np.array([qs]), dtype=tf.float32)e = y_true - y_predv = tf.maximum(q*e, (q-1)*e)return K.mean(v)# 如果是同时预测M个multiple quantiles
def quantile_regression_loss(y_true, y_pred, taus=tf.constant([0.025, 0.1, 0.5, 0.9, 0.975])):"""Function that computes the quantile regression lossArguments:y_pred : Shape (B x M x N) model regression predictionsy_true : Shape (B x M) ground truth targetstaus : Shape (N, ) Vector of used quantilesReturns:loss (float): value of quantile regression loss"""# 这里是因为y_true 定义的是只有一个值,如果与taus的个数一样就不要广播了y_true = tf.expand_dims(y_true, 2)# print(y_pred.shape, y_true.shape)iy = tf.broadcast_to(y_true, tf.shape(y_pred))  # 这里使用y_hat.shape会报错 ValueError: Tried to convert 'shape' to a tensor and failed. Error: Cannot convert a partially known TensorShape (None, 3, 5) to a Tensor.error = (iy - y_pred)loss = tf.maximum(taus * error, (taus - 1.) * error)return K.mean(loss)

参考资料:

  1. scikit-learn 相关的文档: 1, 2.
  2. kaggle 实现分位数回归

分位数回归(Quantile regression)笔记相关推荐

  1. R语言分位数回归Quantile Regression分析租房价格

    全文链接:http://tecdat.cn/?p=18422 本文想在R软件中更好地了解分位数回归优化.在查看分位数回归之前,让我们从样本中计算中位数或分位数(点击文末"阅读原文" ...

  2. 【regression】分位数回归 quantile regression

    quantile regression --python实现 前言 分位数回归可调用的库 1. scikit-learn 2. statsmodels quantile loss function - ...

  3. Stata分位数回归I:理解边际效应和条件边际效应

    全文阅读:Stata分位数回归I:理解边际效应和条件边际效应| 连享会主页 目录 1. 简介 2. 从线性回归模型开始 3. 三种边际效应解释 3.1 个体效应--对 "我" 来说 ...

  4. R中怎么做加权最小二乘_Stata+R:分位数回归一文读懂

    NEW!连享会·推文专辑: Stata资源 | 数据处理 | Stata绘图 | Stata程序 结果输出 | 回归分析 | 时间序列 | 面板数据 | 离散数据 交乘调节 | DID | RDD   ...

  5. 用R语言的quantreg包进行分位数回归

    什么是分位数回归 分位数回归(Quantile Regression)是计量经济学的研究前沿方向之一,它利用解释变量的多个分位数(例如四分位.十分位.百分位等)来得到被解释变量的条件分布的相应的分位数 ...

  6. R构建分位数回归模型(Quantile Regression)

    R构建分位数回归模型(Quantile Regression) 目录 R构建分位数回归模型(Quantile Regression) 数据集 分位数回归模型

  7. 多项式回归、分位数回归(Quantile Regression)、保序回归(Isotonic Regression)、RANSAC回归、核岭回归、基准回归模型(baseline)

    多项式回归.分位数回归(Quantile Regression).保序回归(Isotonic Regression).RANSAC回归.核岭回归.基准回归模型(baseline) 目录

  8. R语言构建分位数回归(Quantile Regression)并计算R方指标实战

    R语言构建分位数回归(Quantile Regression)并计算R方指标实战 目录 R语言构建分位数回归(Quantile Regression)并计算R方指标实战 R方指标 调整的R方指标

  9. 分位数回归(Quantile Regression)代码解析

    实验代码 本文采用python sklearn库中,作为quantile regression的示例代码.以下为详细解析: import numpy as np import matplotlib.p ...

最新文章

  1. 9. 混合模型和EM(3)
  2. wcf 远程终结点已终止该序列 可靠会话出错
  3. webpack使用教程
  4. vue - blog开发学7
  5. java依赖和约束有啥区别_Java – Maven依赖关系太多了
  6. 【渝粤题库】陕西师范大学292961 会计学 作业 (高起专)
  7. 协议簇:TCP 解析: Sequence Number
  8. django使用mysql原始语句,Django中使用mysql数据库并使用原生sql语句操作
  9. mysql查询结果插原表_新建表需要原表的数据,mysql 如何把查询到的结果插入到新表中...
  10. 系统架构设计_分布式、服务化的ERP系统架构设计
  11. MATLAB中的命令行输出
  12. python的ctypes模块详解数组_如何使用Python的ctypes和readinto读取包含数组的结构?...
  13. 编译原理初学者入门指南
  14. argis怎么关掉对象捕捉_ArcGIS ArcMap编辑捕捉教程
  15. 泛微e9隐藏明细表_泛微协同 泛微OA e-cology产品功能清单 模块列表
  16. 区块链入门教程(1)--概述
  17. IBM InfoSphere Optim数据增长解决方案:在Optim归档文件上启用安全性
  18. 2022-2028全球针织捆包网行业调研及趋势分析报告
  19. 解决命令提示符已被系统管理员停用的问题
  20. elasticsearch 使用词干提取器处理英语语言

热门文章

  1. 什么是人脸识别,人脸识别的主要分为哪几步?
  2. GPIO的8种工作模式——基于STM32F767IGT6
  3. camunda modeler 汉化方法
  4. sklearn决策树(Decision Trees)模型
  5. Java计算机毕业设计电子竞技赛事管理系统源码+系统+数据库+lw文档
  6. nginx配置错误页面,处理tomat版本号泄露问题
  7. Linux(Centos7) 运行脚本程序,终端只返回 “已杀死”
  8. 【Visual c++】+【EasyX】游戏组件1 移动的小人
  9. mysql修改user表密码_修改MySQL数据库中表的用户名和密码
  10. 【梳理】简明操作系统原理:银行家算法(内附文档高清截图)