最近开发公司的机器学习平台的XGBoost控件。结果报了一个bug,说“feature_names mismatch”。

现在我们来复现这个bug:

import xgboost as xgb
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_splitX, y = make_classification(n_features=4)
X_train, X_test, y_train, y_test = train_test_split(X, y)
features = ['a1', 'a2', 'a3', 'a4']
df = pd.DataFrame(data=X_train, columns=features)
df['y'] = y_trainX_train = df[['a1', 'a2', 'a3', 'a4']]
y_train = df['y']
model = xgb.XGBClassifier()
model.fit(X_train, y_train)

以上代码随机生成了一个分类数据集,它的特征的名字是a1, a2, a3, a4。我们用pandas的数据集,而不是numpy的数列,传给XGB的fit函数来训练模型。这里还没有出现bug。

然而在预测的时候,出现了bug。

model.predict(X_test)

bug的描述如下:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-23-decd67bbf386> in <module>
----> 1 model.predict(X_test)C:\ProgramData\Anaconda3\lib\site-packages\xgboost\sklearn.py in predict(self, data, output_margin, ntree_limit, validate_features, base_margin)882         if ntree_limit is None:883             ntree_limit = getattr(self, "best_ntree_limit", 0)
--> 884         class_probs = self.get_booster().predict(885             test_dmatrix,886             output_margin=output_margin,C:\ProgramData\Anaconda3\lib\site-packages\xgboost\core.py in predict(self, data, output_margin, ntree_limit, pred_leaf, pred_contribs, approx_contribs, pred_interactions, validate_features, training)1569 1570         if validate_features:
-> 1571             self._validate_features(data)1572 1573         length = c_bst_ulong()C:\ProgramData\Anaconda3\lib\site-packages\xgboost\core.py in _validate_features(self, data)2128                             ', '.join(str(s) for s in my_missing))2129
-> 2130                 raise ValueError(msg.format(self.feature_names,2131                                             data.feature_names))2132 ValueError: feature_names mismatch: ['a1', 'a2', 'a3', 'a4'] ['f0', 'f1', 'f2', 'f3']
expected a1, a3, a2, a4 in input data
training data did not have the following fields: f2, f1, f3, f0

遇到Bug,首先就是去stackoverflow上面找答案,很快找到了解决方案。可以在模型预测的时候加上validate_features=False

model.predict(X_test, validate_features=False)

这里问题来了。平时如果我们自己建模,到这里,问题就解决了。但是在机器学习平台中,预测是另外一个组件。如果我改了预测组件的代码,可能会影响到很多地方,可能会引起更多的bug。

于是,有的人可能会这样做:

if is_xgb:model.predict(X_test, validate_features=False)
else:model.predict(X_test)

我一开始学编程的时候,也是这样解决问题的。但其实还有更好的办法,一个可以不修改预测组件的办法。

我们看到,XGB和其他算法的【行为不一致】。其他算法,包括sklearn里面的主要算法,我们可以用pandas数据集训练模型,然后用numpy数列做预测。整个系统也是基于sklearn里面的模型的这种行为设计的。可是偏偏XGB的行为方式和大家不一样。在一个系统当中,应该要求每个模块都遵循一定的规则。这个就是为什么在java和其他语言中,我们会使用继承和接口。所以,我们这里要做的是,修改XGB的predict函数,在不给定validate_features=False的情况下,也实现一样的功能。所以,这里要怎么办呢?

有人会说,可以修改XGB的源代码。也可以,就是难度太大。有没有简单一点的方法?

(在我继续之前,我特地放了一张一休的图片,让大家闭上眼睛想一想。)

我提醒你一下,我们在连接投影仪的时候,经常碰到电脑显卡的插口和投影仪的不一致。这个时候怎么办?是不是可以用一个名字叫adapter的中间设备来转换一下?

我的方法,就是用一个adapter类,套在原始的XGBClassifier的外面。用新类的predict调用老类的predict,在调用的时候,加上validate_features=False。具体如下:

class XGBClassifierAdapter():model = Nonedef __init__(self, **params):self.model = xgb.XGBClassifier(**params)def fit(self, X, y):self.model.fit(X, y)def predict(self, X):return self.model.predict(X, validate_features=False)

这时,我们用这个新类去替换老类,就不会报错了。

X, y=make_classification(n_features=4)
X_train, X_test, y_train, y_test = train_test_split(X, y)
features=['a1', 'a2', 'a3', 'a4']
df=pd.DataFrame(data=X_train, columns=features)
df['y']=y_trainX_train = df[['a1', 'a2', 'a3', 'a4']]
y_train = df['y']
model = XGBClassifierAdapter()
model.fit(X_train, y_train)
model.predict(X_test)

这种方法,有个学名,叫adapter,是一种设计模式。它将一个类的接口转换成另外一个接口,使得原本由于接口不兼容而不能一起工作的那些类可以一起工作。

(这里转化前后的函数名字都叫predict,但是他们其实不一样。)

也就是说XGBClassifierAdepter要满足两个条件。第一它要改变XGBClassifier的接口,那么它就要依赖于这个类。另外,它又要和其他的sklearn的类接口要一致,或者说他们要具有共同的父类/接口,我们这里想象一个接口叫Classifier,那么XGBClassifierAdepter也要同时继承这个父类/接口。

机器学习平台系列——XGB feature_names mismatch 问题解决方案相关推荐

  1. 机器学习平台系列(六) - 再探 Jupyter Lab:在 CentOS 下制作 Docker 镜像

    文章目录 1.环境版本 2.准备工作 2.1 安装 Docker 2.2.上传 Anaconda3 3.制作镜像 3.1 拉取镜像 3.2 安装 Anaconda 3.3 安装 Jupyter Lab ...

  2. 美团十年,支撑最大规模外卖配送的一站式机器学习平台如何炼成?

    作者 | 艳伟,美团配送技术团队资深技术专家 编辑 | 唐小引 题图 | 东方 IC AI 是目前互联网行业炙手可热的"明星",无论是老牌巨头,还是流量新贵,都在大力研发 AI 技 ...

  3. 一站式机器学习平台建设实践

    本文根据美团配送资深技术专家郑艳伟在2019 SACC(中国系统架构师大会)上的演讲内容整理而成,主要介绍了美团配送技术团队在建设一站式机器学习平台过程中的经验总结和探索,希望对从事此领域的同学有所帮 ...

  4. 美团十年,支撑全球最大规模外卖配送的一站式机器学习平台是如何炼成的?...

    作者 | 艳伟,美团配送技术团队资深技术专家 责编 | 唐小引 封图 | CSDN 下载自东方 IC 本文为美团技术团队投稿 AI 是目前互联网行业炙手可热的"明星",无论是老牌巨 ...

  5. cube云原生机器学习平台-架构(三)

    全栈工程师开发手册 (作者:栾鹏) 一站式云原生机器学习平台 前言:cube是开源的云原生机器学习平台,目前包含特征平台,支持在/离线特征:数据源管理,支持结构数据和媒体标注数据管理:在线开发,在线的 ...

  6. 京东到家机器学习平台建设

    文|巩学超/戴枫/魏铮 编辑|刘慧卿/闫文广 目录 前言 机器学习平台总体架构 模型训练平台 特征模型管理平台 在线模型预测服务 算法应用实践 总结和展望 1. 前言 京东到家作为行业领先的即时零售平 ...

  7. 如何建设机器学习平台

    00. 平台的业务 从平台这个概念本身来说,它提供的是支撑作用,通过整合.管理不同的基础设施.技术框架,一些通用的流程规范来形成一个通用的.易用的GUI来给用户使用.通用性是它的考量之一.也是所有平台 ...

  8. Large Scale Machine Learning--An Engineering Perspective--1. 大规模机器学习平台的构成

    机器学习/数据挖掘在各种业务场景中的应用已经非常之多了,在线广告/搜索/商品推荐/风险建模/图像处理/语音识别/机器翻译都是机器学习成功应用的典型case. 有效应用机器学习解决业务问题,在我看来依赖 ...

  9. 谷歌Cloud AutoML自动机器学习平台初步研究

    一.AutoML背景 机器学习(Machine Learning, ML)技术近年来已取得较大成功,越来越多行业领域依赖它.但目前成功的关键还需依赖人类机器学习工程师完成以下工作: 预处理数据 选择适 ...

最新文章

  1. 推荐几个开源类库,超好用,远离996!
  2. 用户系统-开放平台的一些思考
  3. 对称矩阵、Hermite矩阵、正交矩阵、酉矩阵、奇异矩阵、正规矩阵、幂等矩阵
  4. [JZOJ4788] 【NOIP2016提高A组模拟9.17】序列
  5. Linq 数据库操作(增删改查)
  6. 自动驾驶横向运动学分析和非线性问题处理方法
  7. 使用go来做系统,如何比java node php 更 简单
  8. poj 1651区间dp
  9. OpenCV自然场景文本检测(附Python代码)
  10. 20170910算法工程师在线笔试之求第n个丑数
  11. 串口线的交叉直连之痛
  12. 《自控力》第七章读书笔记
  13. 高通820(msm8996)camera hal源码分析
  14. Conflicted Confucians
  15. Valley Numer
  16. angular async和await (实用)
  17. 神盾加密php文件夹,[宜配屋]听图阁
  18. 谷歌、Microsoft、火狐浏览器主页被篡改解决方法
  19. 员工激励对组织绩效的影响
  20. win7将mysql换个盘_win7系统镜像文件包win7系统驱动光盘

热门文章

  1. A type incompatibility occurred while executing org.springframework.boot:spring-boot-maven-plugin:2.
  2. 失落的帝国:盛大业务大收缩
  3. 苦练基本功《如何阅读看懂一篇Datasheet》
  4. 电子科技大学格拉斯哥学院基础实践————寝室情况及存在问题
  5. Hi3518E音频部分设计
  6. 自然语言处理--MM、RMM算法及Python 复习
  7. IBM ILOG CPLEX Optimization Studio V12.9.0官方文档
  8. 冒泡排序保姆级心得分享
  9. 企业号、企业微信、企业邮箱三者融合,IBOS微信生态掘金之路
  10. PHP系统常量及判断某常量是否被定义