Stacking

简述

主要的三类集成学习方法为Bagging、Boosting和Stacking。目前,大型的数据挖掘比赛(如Kaggle),排名靠前的基本上都是集成机器学习模型或者深度神经网络。

将训练好的所有基模型对整个训练集进行预测,第jjj个基模型对第i个训练样本的预测值将作为新的训练集中第iii个样本的第jjj个特征值,最后基于新的训练集进行训练。同理,预测的过程也要先经过所有基模型的预测形成新的测试集,最后再对测试集进行预测。

具体原理讲解参考这篇博客,简单来说,集成学习其实都是将基本模型组合形成更优秀的模型,Stacking也不例外。stacking是得到各个算法训练全样本的结果再用一个元算法融合这些结果,它可以选择使用网格搜索和交叉验证。

Mlxtend框架

众所周知,如今传统机器学习领域的库基本上被sciket-learn(sklearn)占领,若你没有使用过sklearn库,那就不能称为使用过机器学习算法进行数据挖掘。但是,自定义集成学习库依然没有什么太过主流的框架,sklearn也只是实现了一些比较主流的集成学习方法如随机森林(RF)、AdaBoost等。当然,这也是因为bagging和boosting可以直接调用而stacking需要自己设计。

Mlxtend完美兼容sklearn,可以使用sklearn的模型进行组合生成新模型。它同时集成了stacking分类和回归模型以及它们的交叉验证的版本。由于已经有很多stacking的分类介绍,本例以回归为例讲讲stacking的回归实现。

Mlxtend安装

pip install mlxtend

官方文档

地址

项目实战

stacking回归

stacking回归是一种通过元回归器(meta-regressor)组合多个回归模型(lr,svr等)的集成学习技术。而且,每个基回归模型(就是上述的多个回归模型)在训练时都要使用完整训练集,集成学习过程中每个基回归模型的输出作为元特征成为元回归器的输入,元回归器通过拟合这些元特征来组合多个模型。

使用StackingRegressor

简单使用stacking模型预测波士顿房价(使用经典波士顿房价数据集)由于大数据集需要精细调参,这里简单使用100条数据进行回归测试。

from mlxtend.regressor import StackingRegressor
from mlxtend.data import boston_housing_data
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as pltx, y = boston_housing_data()
x = x[:100]
y = y[:100]
# 划分数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
# 初始化基模型
lr = LinearRegression()
svr_lin = SVR(kernel='linear', gamma='auto')
svr_rbf = SVR(kernel='rbf', gamma='auto')
ridge = Ridge(random_state=2019)
models = [lr, svr_lin, svr_rbf, ridge]print('base model')
for model in models:model.fit(x_train, y_train)pred = model.predict(x_test)print("loss is {}".format(mean_squared_error(y_test, pred)))
sclf = StackingRegressor(regressors=models, meta_regressor=ridge)
# 训练回归器
sclf.fit(x_train, y_train)
pred = sclf.predict(x_test)print('stacking model')
print("loss is {}".format(mean_squared_error(y_test, pred)))
plt.scatter(np.arange(len(pred)), pred)
plt.plot(np.arange(len(y_test)), y_test)
plt.show()


可以看到stacking模型的一般预测准确率是高于所有基模型的。

对stacking模型网格搜索调参

这里仍然使用上一个案例的模型,下面是代码及结果。

from mlxtend.regressor import StackingRegressor
from mlxtend.data import boston_housing_data
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as pltx, y = boston_housing_data()
x = x[:100]
y = y[:100]
# 划分数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
# 初始化基模型
lr = LinearRegression()
svr_lin = SVR(kernel='linear', gamma='auto')
svr_rbf = SVR(kernel='rbf', gamma='auto')
ridge = Ridge(random_state=2019,)
models = [lr, svr_lin, svr_rbf, ridge]params = {'ridge__alpha': [0.1, 1.0, 10.0],}
sclf = StackingRegressor(regressors=models, meta_regressor=ridge)
# 训练回归器
grid = GridSearchCV(estimator=sclf, param_grid=params, cv=5, refit=True)
grid.fit(x_train, y_train)
print(grid.best_params_, grid.best_score_)

使用StackingCVRegressor

mlxtend.regressor中的StackingCVRegressor是一种集成学习元回归器。StackingCVRegressor扩展了标准Stacking算法(在mlxtend中的实现为StackingRegressor)。在标准Stacking算法中,拟合一级回归器的时候,我们如果使用与第二级回归器的输入的同一个训练集,这很可能会导致过拟合。 然而,StackingCVRegressor使用了"非折叠预测"的概念:数据集被分成k个折叠,并且在k个连续的循环中,使用k-1折来拟合第一级回归器,其实也就是k折交叉验证的StackingRegressor。在K轮中每一轮中,一级回归器被应用于在每次迭代中还未用于模型拟合的剩余1个子集。然后将得到的预测叠加起来并作为输入数据提供给二级回归器。在StackingCVRegressor的训练完成之后,一级回归器拟合整个数据集以获得最佳预测。

from mlxtend.regressor import StackingCVRegressor
from mlxtend.data import boston_housing_data
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as pltx, y = boston_housing_data()
x = x[:100]
y = y[:100]
# 划分数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
# 初始化基模型
lr = LinearRegression()
svr_lin = SVR(kernel='linear', gamma='auto')
ridge = Ridge(random_state=2019,)
lasso =Lasso()
models = [lr, svr_lin, ridge, lasso]print("base model")
for model in models:score = cross_val_score(model, x_train, y_train, cv=5)print(score.mean(), "+/-", score.std())
sclf = StackingCVRegressor(regressors=models, meta_regressor=lasso)
# 训练回归器
print("stacking model")
score = cross_val_score(sclf, x_train, y_train, cv=5)
print(score.mean(), "+/-", score.std())sclf.fit(x_train, y_train)
pred = sclf.predict(x_test)
print("loss is {}".format(mean_squared_error(y_test, pred)))

)

可以看到,对比第一次使用StackingRegressor模型的损失降低了。(尽管由于调参问题,评分没有基回归器高)

使用StackingCVRegressor网格搜索

代码及结果如下。

from mlxtend.regressor import StackingCVRegressor
from mlxtend.data import boston_housing_data
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.svm import SVR
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as pltx, y = boston_housing_data()
x = x[:100]
y = y[:100]
# 划分数据集
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
# 初始化基模型
lr = LinearRegression()
svr_lin = SVR(kernel='linear', gamma='auto')
ridge = Ridge(random_state=2019,)
lasso =Lasso()
models = [lr, svr_lin, ridge, lasso]params = {'lasso__alpha': [0.1, 1.0, 10.0],'ridge__alpha': [0.1, 1.0, 10.0]}sclf = StackingCVRegressor(regressors=models, meta_regressor=ridge)
grid = GridSearchCV(estimator=sclf, param_grid=params, cv=5, refit=True)
grid.fit(x_train, y_train)
print(grid.best_score_, grid.best_params_)

补充说明

本文主要介绍了框架Mlxtend的使用 ,具体的API函数见上面提到的官方文档。Stacking等集成模型可以说是大型数据挖掘比赛的利器。本文涉及到的具体代码见我的Github,欢迎Star或者Fork。

机器学习-Stacking方法的原理及实现相关推荐

  1. 诗人般的机器学习,ML工作原理大揭秘

    诗人般的机器学习,ML工作原理大揭秘 https://www.cnblogs.com/DicksonJYL/p/9698208.html 选自arXiv 作者:Cassie Kozyrkov 机器之心 ...

  2. 新闻上的文本分类:机器学习大乱斗 王岳王院长 王岳王院长 5 个月前 目标 从头开始实践中文短文本分类,记录一下实验流程与遇到的坑 运用多种机器学习(深度学习 + 传统机器学习)方法比较短文本分类处

    新闻上的文本分类:机器学习大乱斗 王岳王院长 5 个月前 目标 从头开始实践中文短文本分类,记录一下实验流程与遇到的坑 运用多种机器学习(深度学习 + 传统机器学习)方法比较短文本分类处理过程与结果差 ...

  3. [机器学习]正则化方法 -- Regularization

    一.参数方法和非参数方法 在讲正则化之前,需要介绍2个概念.机器学习的方法,可以大致分成两类. 参数方法(Parametric Methods) 通过训练来确定一组参数.当我们参数的值定下来,可以说预 ...

  4. 【机器学习】Weighted LSSVM原理与Python实现:LSSVM的稀疏化改进

    [机器学习]Weighted LSSVM原理与Python实现:LSSVM的稀疏化改进 一.LSSVM 1.LSSVM用于回归 2.LSSVM模型的缺点 二.WLSSVM的数学原理 三.WLSSVM的 ...

  5. 第七十四篇:机器学习优化方法及超参数设置综述

    第七十四篇:机器学习优化方法及超参数设置综述 置顶 2019-08-25 23:03:44 廖佳才 阅读数 207更多 分类专栏: 深度学习 版权声明:本文为博主原创文章,遵循 CC 4.0 BY-S ...

  6. 熟练掌握R语言的Meta分析全流程和不确定性分析,并结合机器学习等方法讲解Meta分析在文献大数据的延伸应用

    Meta分析是针对某一科研问题,根据明确的搜索策略.选择筛选文献标准.采用严格的评价方法,对来源不同的研究成果进行收集.合并及定量统计分析的方法,最早出现于"循证医学",现已广泛应 ...

  7. 机器学习 鸢尾花分类的原理和实现(一)

    机器学习 鸢尾花分类的原理和实现(一) 前言: 鸢尾花数据集是机器学习中的经典小规模数据集.通过查阅资料和视频进行学习,将整个实验的学习心得和实验过程分享,希望对喜爱机器学习并入门的新手提供帮助,同时 ...

  8. 用机器学习的方法鉴别红楼梦作者

    为什么80%的码农都做不了架构师?>>>    在学界一般认为,<红楼梦>后 40 回并非曹雪芹所著.本文尝试应用机器学习的方法来分析原著文本中作者的用词习惯,从技术角度 ...

  9. 机器学习 — K-Means、K-Means++ 原理及算法实现

    文章目录 机器学习 - K-Means.K-Means++ 原理及算法实现 一.K-Means 1. 概念 2. K-Means算法思想 3. K-Means特点 4.K-Means算法实现 二.K- ...

最新文章

  1. 【 FPGA 】FIR 滤波器之固定分数率重采样滤波器
  2. 交通预测论文笔记《Attention Based Spatial-Temporal Graph Convolutional Networks for Traffic Flow Forecasting》
  3. mysql b-a全局索引_MySQL中B+树索引的使用
  4. 大学生学编程系列」第五篇:自学编程需要多久才能找到工作?
  5. 服务端和客户端测试连通ip设置记录
  6. 【Python】浮点数计算时的不准确性以及如何进行精确计算
  7. Linux下编译运行Go程序
  8. 定时任务crontab
  9. ubuntu文件系统知识
  10. C# 使用 NPOI操作excle文件(读取与新建重写)
  11. C++11 关键字noexcept
  12. c语言自学基础知识视频,C语言 基础课堂视频教程
  13. 设备\Device\Harddisk1\DR1 有一个不正确的区块
  14. 计算机无法识别 手机,手机连接电脑无法识别usb设备的解决教程
  15. 微信小程序上传代码, Error: 分包大小超过限制,main package source size 4732KB exceed max limit 2MB
  16. java ad域 单点登录_系统集成-SSO微软ADSF单点认证-AD域认证
  17. jQuery获取元素定位位置:给td添加选中样式
  18. 利用卷积神经网络实现手写字识别
  19. jpush android 离线推送,JPush极光推送3分钟搞定Android Push
  20. SVG SMIL 动画(基本动画 、变换动画)

热门文章

  1. MyBatis 源码解读-databaseIdProviderElement()
  2. 代码演示:先来后到的特例、优劣、源码分析
  3. 权限基本操作:实体类和dao
  4. 定时任务四种实现方式
  5. java无忧网_零基础java入门课程 - 学途无忧网 - 做技术的王者 - Powered By EduSoho
  6. 我什么计算机作文600字,我家的电脑作文600字
  7. C#中所有对象共同的基类是System.Object
  8. MySQL 复制 - 性能与扩展性的基石:概述及其原理
  9. 事件监听一直报错Cannot set property 'display' of undefined
  10. 当网页太多时,用锚点 以及超链接的使用