分位数回归(Quantile regression)笔记
分位数回归(Quantile regression)是在给定 X \mathbf{X} X的条件下估计 y \mathbf{y} y的中位数或其他分位数, 这是与最小二乘法估计条件均值最大的不同。
分位数回归也是一种线性回归,它为第 q q q个分位数( q ∈ ( 0 , 1 ) q\in (0, 1) q∈(0,1))训练得到线性预测 y ^ ( w , X ) = X w \hat{y}(w, \mathbf{X})=\mathbf{Xw} y^(w,X)=Xw, 权重 w w w通过最小化下面的公式得到
min w 1 n samples ∑ i P B q ( y i − X i w ) + α ∣ ∣ w ∣ ∣ 1 . \min_{w} {\frac{1}{n_{\text{samples}}} \sum_i PB_q(y_i - X_i w) + \alpha ||w||_1}. wminnsamples1i∑PBq(yi−Xiw)+α∣∣w∣∣1.
其中的 P B PB PB 是pinball loss(也被称为linear loss), 定义如下式, α \alpha α 调整L1损失的大小。
P B q ( t ) = q max ( t , 0 ) + ( 1 − q ) max ( − t , 0 ) = { q t , t > 0 0 , t = 0 ( q − 1 ) t , t < 0 P B_q(t)=q \max (t, 0)+(1-q) \max (-t, 0)= \begin{cases}q t, & t>0 \\ 0, & t=0 \\ (q-1) t, & t<0\end{cases} PBq(t)=qmax(t,0)+(1−q)max(−t,0)=⎩ ⎨ ⎧qt,0,(q−1)t,t>0t=0t<0
分位数回归的特点:
- 分位数回归对于异常点没有那么敏感
- 对于数据也不要求完全符合正态分布,不用假设数据的分布服从方差固定的分布。
- 当我们的预测结果是一个区间而不是一个点的时候更有用
- 分位数回归相比于线性回归减少的是MAE
当然相比于普通的线性回归,分位数要求更多的训练数据,同时也比线性回归的计算更复杂。
在下面的图片是对加了符合帕累托分布( Pareto Distribution)噪声的数据进行分位数回归学习的结果
相应的示例代码如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.utils.fixes import sp_version, parse_version
from sklearn.linear_model import QuantileRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import cross_validate# 为了避免老版本SciPy的不兼容问题
solver = "highs" if sp_version >= parse_version("1.6.0") else "interior-point"# 生成样本
rng = np.random.RandomState(42)
x = np.linspace(start=0, stop=10, num=100)
X = x[:, np.newaxis]
y_true_mean = 10 + 0.5 * x# 生成满足pareto distribution的数据
y_pareto = y_true_mean + 10 * (rng.pareto(a, size=x.shape[0]) - 1 / (a - 1))# 生成分位数为0.05, 0.5, 0.95的分位数回归
quantiles = [0.05, 0.5, 0.95]
predictions = {}
out_bounds_predictions = np.zeros_like(y_true_mean, dtype=np.bool_)
for quantile in quantiles:qr = QuantileRegressor(quantile=quantile, alpha=0, solver=solver)y_pred = qr.fit(X, y_pareto).predict(X)predictions[quantile] = y_predif quantile == min(quantiles):out_bounds_predictions = np.logical_or(out_bounds_predictions, y_pred >= y_pareto)elif quantile == max(quantiles):out_bounds_predictions = np.logical_or(out_bounds_predictions, y_pred <= y_pareto)plt.plot(X, y_true_mean, color="black", linestyle="dashed", label="True mean")for quantile, y_pred in predictions.items():plt.plot(X, y_pred, label=f"Quantile: {quantile}")plt.scatter(x[out_bounds_predictions],y_pareto[out_bounds_predictions],color="black",marker="+",alpha=0.5,label="Outside interval",
)
plt.scatter(x[~out_bounds_predictions],y_pareto[~out_bounds_predictions],color="black",alpha=0.5,label="Inside interval",
)
plt.legend()
plt.xlabel("x")
plt.ylabel("y")
_ = plt.title("Quantiles of asymmetric Pareto distributed target")# 使用交叉验证比较线性回归与分位数回归
linear_regression = LinearRegression()
quantile_regression = QuantileRegressor(quantile=0.5, alpha=0, solver=solver)cv_results_lr = cross_validate(linear_regression,X,y_pareto,cv=3,scoring=["neg_mean_absolute_error", "neg_mean_squared_error"],
)
cv_results_qr = cross_validate(quantile_regression,X,y_pareto,cv=3,scoring=["neg_mean_absolute_error", "neg_mean_squared_error"],
)
print(f"""Test error (cross-validated performance){linear_regression.__class__.__name__}:MAE = {-cv_results_lr["test_neg_mean_absolute_error"].mean():.3f}MSE = {-cv_results_lr["test_neg_mean_squared_error"].mean():.3f}{quantile_regression.__class__.__name__}:MAE = {-cv_results_qr["test_neg_mean_absolute_error"].mean():.3f}MSE = {-cv_results_qr["test_neg_mean_squared_error"].mean():.3f}"""
)#Test error (cross-validated performance)
# LinearRegression:
# MAE = 1.732
# MSE = 6.690
# QuantileRegressor:
# MAE = 1.679
# MSE = 7.129
keras 定义pinball loss 示例
# multiple quantiles损失定义示例
def quantile_regression_loss0(y_true, y_pred, qs=[0.025, 0.1, 0.5, 0.9, 0.975]):q = tf.constant(np.array([qs]), dtype=tf.float32)e = y_true - y_predv = tf.maximum(q*e, (q-1)*e)return K.mean(v)# 如果是同时预测M个multiple quantiles
def quantile_regression_loss(y_true, y_pred, taus=tf.constant([0.025, 0.1, 0.5, 0.9, 0.975])):"""Function that computes the quantile regression lossArguments:y_pred : Shape (B x M x N) model regression predictionsy_true : Shape (B x M) ground truth targetstaus : Shape (N, ) Vector of used quantilesReturns:loss (float): value of quantile regression loss"""# 这里是因为y_true 定义的是只有一个值,如果与taus的个数一样就不要广播了y_true = tf.expand_dims(y_true, 2)# print(y_pred.shape, y_true.shape)iy = tf.broadcast_to(y_true, tf.shape(y_pred)) # 这里使用y_hat.shape会报错 ValueError: Tried to convert 'shape' to a tensor and failed. Error: Cannot convert a partially known TensorShape (None, 3, 5) to a Tensor.error = (iy - y_pred)loss = tf.maximum(taus * error, (taus - 1.) * error)return K.mean(loss)
参考资料:
- scikit-learn 相关的文档: 1, 2.
- kaggle 实现分位数回归
分位数回归(Quantile regression)笔记相关推荐
- R语言分位数回归Quantile Regression分析租房价格
全文链接:http://tecdat.cn/?p=18422 本文想在R软件中更好地了解分位数回归优化.在查看分位数回归之前,让我们从样本中计算中位数或分位数(点击文末"阅读原文" ...
- 【regression】分位数回归 quantile regression
quantile regression --python实现 前言 分位数回归可调用的库 1. scikit-learn 2. statsmodels quantile loss function - ...
- Stata分位数回归I:理解边际效应和条件边际效应
全文阅读:Stata分位数回归I:理解边际效应和条件边际效应| 连享会主页 目录 1. 简介 2. 从线性回归模型开始 3. 三种边际效应解释 3.1 个体效应--对 "我" 来说 ...
- R中怎么做加权最小二乘_Stata+R:分位数回归一文读懂
NEW!连享会·推文专辑: Stata资源 | 数据处理 | Stata绘图 | Stata程序 结果输出 | 回归分析 | 时间序列 | 面板数据 | 离散数据 交乘调节 | DID | RDD ...
- 用R语言的quantreg包进行分位数回归
什么是分位数回归 分位数回归(Quantile Regression)是计量经济学的研究前沿方向之一,它利用解释变量的多个分位数(例如四分位.十分位.百分位等)来得到被解释变量的条件分布的相应的分位数 ...
- R构建分位数回归模型(Quantile Regression)
R构建分位数回归模型(Quantile Regression) 目录 R构建分位数回归模型(Quantile Regression) 数据集 分位数回归模型
- 多项式回归、分位数回归(Quantile Regression)、保序回归(Isotonic Regression)、RANSAC回归、核岭回归、基准回归模型(baseline)
多项式回归.分位数回归(Quantile Regression).保序回归(Isotonic Regression).RANSAC回归.核岭回归.基准回归模型(baseline) 目录
- R语言构建分位数回归(Quantile Regression)并计算R方指标实战
R语言构建分位数回归(Quantile Regression)并计算R方指标实战 目录 R语言构建分位数回归(Quantile Regression)并计算R方指标实战 R方指标 调整的R方指标
- 分位数回归(Quantile Regression)代码解析
实验代码 本文采用python sklearn库中,作为quantile regression的示例代码.以下为详细解析: import numpy as np import matplotlib.p ...
最新文章
- 9. 混合模型和EM(3)
- wcf 远程终结点已终止该序列 可靠会话出错
- webpack使用教程
- vue - blog开发学7
- java依赖和约束有啥区别_Java – Maven依赖关系太多了
- 【渝粤题库】陕西师范大学292961 会计学 作业 (高起专)
- 协议簇:TCP 解析: Sequence Number
- django使用mysql原始语句,Django中使用mysql数据库并使用原生sql语句操作
- mysql查询结果插原表_新建表需要原表的数据,mysql 如何把查询到的结果插入到新表中...
- 系统架构设计_分布式、服务化的ERP系统架构设计
- MATLAB中的命令行输出
- python的ctypes模块详解数组_如何使用Python的ctypes和readinto读取包含数组的结构?...
- 编译原理初学者入门指南
- argis怎么关掉对象捕捉_ArcGIS ArcMap编辑捕捉教程
- 泛微e9隐藏明细表_泛微协同 泛微OA e-cology产品功能清单 模块列表
- 区块链入门教程(1)--概述
- IBM InfoSphere Optim数据增长解决方案:在Optim归档文件上启用安全性
- 2022-2028全球针织捆包网行业调研及趋势分析报告
- 解决命令提示符已被系统管理员停用的问题
- elasticsearch 使用词干提取器处理英语语言
热门文章
- 什么是人脸识别,人脸识别的主要分为哪几步?
- GPIO的8种工作模式——基于STM32F767IGT6
- camunda modeler 汉化方法
- sklearn决策树(Decision Trees)模型
- Java计算机毕业设计电子竞技赛事管理系统源码+系统+数据库+lw文档
- nginx配置错误页面,处理tomat版本号泄露问题
- Linux(Centos7) 运行脚本程序,终端只返回 “已杀死”
- 【Visual c++】+【EasyX】游戏组件1 移动的小人
- mysql修改user表密码_修改MySQL数据库中表的用户名和密码
- 【梳理】简明操作系统原理:银行家算法(内附文档高清截图)