作者:Samuele Mazzanti翻译:欧阳锦
校对:赵茹萱本文约3900字,建议阅读10分钟本文通过实验验证了一个通用模型优于多个专用模型的有效性的结论。

比较专门针对不同群体训练多个 ML 模型与为所有数据训练一个独特模型的有效性。

图源作者

我最近听到一家公司宣称:“我们在生产中有60个流失模型。”(注:流失模型是一种通过数学来建模流失对业务的影响。)我问他们为什么这么多。他们回答说,他们拥有 5 个品牌,在 12 个国家/地区运营,并且由于他们想为每个品牌和国家/地区的组合开发一种模型,因此共计 60 种模型。于是,我问他们:“你试过只用一种模型吗?” 

他们认为这没有意义,因为他们的品牌彼此之间非常不同,他们经营的目标国家也是如此:“你不能训练一个单一的模型并期望它对品牌 A 的美国客户和对品牌 B 的德国客户”。

由于在业内经常听到这样的说法,我很好奇这个论点是否反映在数据中,或者只是没有事实支持的猜测。

这就是为什么在本文中,我将系统地比较两种方法:

  • 将所有数据提供给一个模型,也就是一个通用模型(general model);

  • 为每个细分市场构建一个模型(在前面的示例中,品牌和国家/地区的组合),也就是许多专业模型(specialized models)。

我将在流行的Python库Pycaret提供的12个真实数据集上测试这两种策略。

通用模型与专用模型

这两种方法究竟是如何工作的?

假设我们有一个数据集。数据集由预测变量矩阵(称为X)和目标变量(称为y)组成。此外,X包含一个或多个可用于分割数据集的列(在前面的示例中,这些列是“品牌”和“国家/地区”)。

现在让我们尝试以图形方式表示这些元素。我们可以使用X的其中一列来可视化这些段:每种颜色(蓝色、黄色和红色)标识不同的段。我们还需要一个额外的向量来表示训练集(绿色)和测试集(粉色)的划分。

训练集的行标记为绿色,而测试集的行标记为粉红色。X 的彩色列是分段列:每种颜色标识不同的分段(例如,蓝色是美国,黄色是英国,红色是德国)。图源作者。

鉴于这些要素,以下是这两种方法的不同之处。

第一种策略:通用模型

在整个训练集上拟合一个独特的模型,然后在整个测试集上测量其性能:

通用模型。所有片段(蓝色、黄色和红色)都被馈送到同一个模型。图源作者

第二个策略:专业模型

第二种策略涉及为每个段建立模型,这意味着重复训练/测试过程k次(其中k是片段数,在本例中为 3)。

专用模型。每个段被馈送到不同的模型。[作者图片]

请注意,在实际用例中,分段的数量可能是相关的,从几十个到数百个不等。因此,与使用一个通用模型相比,使用专用模型存在几个实际缺点,例如:

  • 更高的维护工作量;

  • 更高的系统复杂度;

  • 更高的(累积的)培训时间;

  • 更高的计算成本:

  • 更高的存储成本。

那么,为什么会有人想要这样做呢?

对通用模型的偏见

专用模型的支持者声称,独特的通用模型在给定的细分市场(比如美国客户)上可能不太精确,因为它还了解了不同细分市场(例如欧洲客户)的特征。

我认为这是因使用简单模型(例如逻辑回归)而产生的错误认识。让我用一个例子来解释。

假设我们有一个汽车数据集,由三列组成:

  • 汽车类型(经典或现代);

  • 汽车时代;

  • 车价。

我们想使用前两个特征来预测汽车价格。这些是数据点:

具有两个部分(经典汽车和现代汽车)的数据集,显示出与目标变量相关的非常不同的行为。图源作者

如您所见,根据汽车类型,有两种完全不同的行为:随着时间的推移,现代汽车贬值,而老爷车价格上涨。

现在,如果我们在完整数据集上训练线性回归:

linear_regression = LinearRegression().fit(df[[ "car_type_classic" , "car_age" ]], df[ "car_price" ])

得到的系数是:

在数据集上训练的线性回归系数。图源作者

这意味着模型将始终为任何输入预测相同的值12。

通常,如果数据集包含不同的行为(除非您进行额外的特征工程),简单模型将无法正常工作。因此,在这种情况下,人们可能会想训练两种专门的模型:一种用于经典汽车,一种用于现代汽车。

但是让我们看看如果我们使用决策树而不是线性回归会发生什么。为了使比较公平,我们将生成一棵有3个分支的树(即3个决策阈值),因为线性回归也有3个参数(3个系数)。

decision_tree = DecisionTreeRegressor(max_depth= 2 ).fit(df[
[ "car_type_classic" , "car_age" ]], df[ "car_price" ])

这是结果:

在玩具数据集上训练的决策树。[作者图片]

这比我们用线性回归得到的结果要好得多!

关键是基于树的模型(例如 XGBoost、LightGBM 或 Catboost)能够处理不同的行为,因为它们天生就可以很好地处理特征交互。

这就是为什么在理论上没有理由比一个通用模型更喜欢几个专用模型的主要原因。但是,一如既往,我们并不满足于理论解释。我们还想确保这一猜想得到真实数据的支持。

实验细节

在本段中,我们将看到测试哪种策略效果更好所需的 Python 代码。如果您对细节不感兴趣,可以直接跳到下一段,我将在这里讨论结果。

我们的目标是定量比较两种策略:

  • 训练一个通用模型;

  • 训练许多个专用模型。

比较它们的最明显方法如下:

1. 获取数据集;

2. 根据一列的值选择数据集的一部分;

3. 将数据集拆分为训练数据集和测试数据集;

4. 在整个训练数据集上训练通用模型;

5. 在属于该段的训练数据集部分上训练专用模型;

6. 比较通用模型和专用模型在属于该段的测试数据集部分上的性能。

图形化:

X 中的彩色列是我们用来对数据集进行分层的列。[作者图片]

这工作得很好,但是,由于我们不想被随机性愚弄,我们将重复这个过程:

  • 对于不同的数据集;

  • 使用不同的列来分割数据集本身;

  • 使用同一列的不同值来定义段。

换句话说,这就是我们要用伪代码做的:

for each dataset:train general model on the training setfor each column of the dataset:for each value of the column:train specialized model on the portion of the training set for which column = valuecompare performance of general model vs. specialized model

实际上,我们需要对这个过程做一些微小的调整。

首先,我们说过我们正在使用数据集的列来分割数据集本身。这适用于分类列和具有很少值的离散数字列。对于剩余的数字列,我们必须通过分箱(binning)使它们分类。

其次,我们不能简单地使用所有的列。如果我们这样做,我们将会惩罚专用模型。事实上,如果我们根据与目标变量无关的列选择细分,就没有理由相信专门的模型可以表现得更好。为避免这种情况,我们将只使用与目标变量有某种关系的列。

此外,出于类似的原因,我们不会使用所有细分列的值。我们将避免过于频繁(超过50%)的值,因为期望在大多数数据集上训练的模型与在完整数据集上训练的模型具有不同的性能是没有意义的。我们还将避免测试集中少于100个案例的值,因为结果肯定不会很重要。

鉴于此,这是我使用的完整代码:

for dataset_name in tqdm(dataset_names):# get dataX, y, num_features, cat_features, n_classes = get_dataset(dataset_name)# split index in training and test set, then train general model on the training setix_train, ix_test = train_test_split(X.index, test_size=.25, stratify=y)model_general = CatBoostClassifier().fit(X=X.loc[ix_train,:], y=y.loc[ix_train], cat_features=cat_features, silent=True)pred_general = pd.DataFrame(model_general.predict_proba(X.loc[ix_test, :]), index=ix_test, columns=model_general.classes_)# create a dataframe where all the columns are categorical: # numerical columns with more than 5 unique values are binnizedX_cat = X.copy()X_cat.loc[:, num_features] = X_cat.loc[:, num_features].fillna(X_cat.loc[:, num_features].median()).apply(lambda col: col if col.nunique() <= 5 else binnize(col))# get a list of columns that are not (statistically) independent # from y according to chi 2 independence testcandidate_columns = get_dependent_columns(X_cat, y)for segmentation_column in candidate_columns:# get a list of candidate values such that each candidate:# - has at least 100 examples in the test set# - is not more common than 50%vc_test = X_cat.loc[ix_test, segmentation_column].value_counts()nu_train = y.loc[ix_train].groupby(X_cat.loc[ix_train, segmentation_column]).nunique()nu_test = y.loc[ix_test].groupby(X_cat.loc[ix_test, segmentation_column]).nunique()candidate_values = vc_test[(vc_test>=100) & (vc_test/len(ix_test)<.5) & (nu_train==n_classes) & (nu_test==n_classes)].index.to_list()for value in candidate_values:# split index in training and test set, then train specialized model # on the portion of the training set that belongs to the segmentix_value = X_cat.loc[X_cat.loc[:, segmentation_column] == value, segmentation_column].index    ix_train_specialized = list(set(ix_value).intersection(ix_train))ix_test_specialized = list(set(ix_value).intersection(ix_test)) model_specialized = CatBoostClassifier().fit(X=X.loc[ix_train_specialized,:], y=y.loc[ix_train_specialized], cat_features=cat_features, silent=True)pred_specialized = pd.DataFrame(model_specialized.predict_proba(X.loc[ix_test_specialized, :]), index=ix_test_specialized, columns=model_specialized.classes_)# compute roc score of both the general model and the specialized model and save themroc_auc_score_general = get_roc_auc_score(y.loc[ix_test_specialized], pred_general.loc[ix_test_specialized, :])roc_auc_score_specialized = get_roc_auc_score(y.loc[ix_test_specialized], pred_specialized)  results = results.append(pd.Series(data=[dataset_name, segmentation_column, value, len(ix_test_specialized), y.loc[ix_test_specialized].value_counts().to_list(), roc_auc_score_general, roc_auc_score_specialized],index=results.columns),ignore_index=True)

为了便于理解,我省略了一些实用函数的代码,get_dataset例如get_dependent_columns和get_roc_auc_score。但是,您可以在此GitHub存储库中找到完整代码。

结果

为了对通用模型与专用模型进行大规模比较,我使用了Pycaret(MIT许可下的 Python库)中提供的12个真实世界数据集。

对于每个数据集,我发现列与目标变量显示出一些显着关系(独立性卡方检验的p值<1%)。对于任何一列,我只保留不太罕见(它们必须在测试集中至少有100个案例)或过于频繁(它们必须占数据集的比例不超过50%)的值。这些值中的每一个都标识数据集的一个片段。

对于每个数据集,我在整个训练数据集上训练了一个通用模型(CatBoost,没有参数调整)。然后,对于每个片段,我在属于相应片段的训练数据集部分上训练了一个专门的模型(同样是CatBoost,没有参数调整)。最后,我比较了两种方法在属于该段的测试数据集部分上的性能(ROC曲线下的面积)。

让我们看一下最终输出:

12 个真实数据集的模拟结果。每行都是一个段,由数据集、列和值的组合标识。图源作者。

原则上,要选出获胜者,我们可以只看“roc_general”和“roc_specialized”之间的区别。然而,在某些情况下,这种差异可能是偶然的。因此,我也计算了差异何时具有统计显着性(有关如何判断两个ROC分数之间的差异是否显着的详细信息,请参阅本文)。

因此,我们可以在两个维度上对601比较进行分类:通用模型是否优于专用模型以及这种差异是否显着。这是结果:

601 比较的总结。“general > specialized”表示通用模型的ROC曲线下面积高于专用模型,“specialized > general”则相反。“显着”/“不显着”表明这种差异是否显着。图源作者。

很容易看出,通用模型在89%的时间(454+83/601)优于专用模型。但是,如果我们坚持重要的案例,一般模型在95%的时间(87个中的83个)优于专用模型。

出于好奇,我们也将87个重要案例可视化为一个图表,x轴为专用模型的ROC分数,y轴为通用模型的ROC分数。

比较:专用模型的 ROC 与通用模型的 ROC。仅包括显示出显着差异的部分。图源作者。

对角线上方的所有点都标识了通用模型比专用模型表现更好的情况。

但是,更好在哪里?

我们可以计算两个ROC分数之间的平均差。事实证明,在87个显着案例中,通用模型的 ROC 平均比专用模型高2.4%,这是很多!


结论

在本文中,我们比较了两种策略:使用在整个数据集上训练的通用模型与使用专门针对数据集不同部分的许多模型。

我们已经看到,没有令人信服的理由使用专用模型,因为强大的算法(例如基于树的模型)可以在本地处理不同的行为。此外,从维护工作、系统复杂性、训练时间、计算成本和存储成本的角度来看,使用专用模型涉及到几个实际的复杂问题。

我们还在12个真实数据集上测试了这两种策略,总共有601个可能的片段。在这个实验中,通用模型在89%的时间里优于专用模型。只看具有统计显着性的案例,这个数字上升到95%,ROC得分平均提高2.4%。

您可以在这个GitHub存储库中找到本文使用的所有Python代码。

原文标题:

What Is Better: One General Model or Many Specialized Models?

原文链接:

https://towardsdatascience.com/what-is-better-one-general-model-or-many-specialized-models-9500d9f8751d

编辑:于腾凯

校对:杨学俊

译者简介

欧阳锦,一名在埃因霍温理工大学就读的硕士生。喜欢数据科学和人工智能相关方向。欢迎不同观点和想法的交流与碰撞,对未知充满好奇,对热爱充满坚持。

翻译组招募信息

工作内容:需要一颗细致的心,将选取好的外文文章翻译成流畅的中文。如果你是数据科学/统计学/计算机类的留学生,或在海外从事相关工作,或对自己外语水平有信心的朋友欢迎加入翻译小组。

你能得到:定期的翻译培训提高志愿者的翻译水平,提高对于数据科学前沿的认知,海外的朋友可以和国内技术应用发展保持联系,THU数据派产学研的背景为志愿者带来好的发展机遇。

其他福利:来自于名企的数据科学工作者,北大清华以及海外等名校学生他们都将成为你在翻译小组的伙伴。

点击文末“阅读原文”加入数据派团队~

转载须知

如需转载,请在开篇显著位置注明作者和出处(转自:数据派ID:DatapiTHU),并在文章结尾放置数据派醒目二维码。有原创标识文章,请发送【文章名称-待授权公众号名称及ID】至联系邮箱,申请白名单授权并按要求编辑。

发布后请将链接反馈至联系邮箱(见下方)。未经许可的转载以及改编者,我们将依法追究其法律责任。

点击“阅读原文”拥抱组织

独家 | 哪个更好:一个通用模型还是多个专用模型?相关推荐

  1. 如何设计一个通用的权限管理系统

    点击上方蓝色"方志朋",选择"设为星标" 回复"666"获取独家整理的学习资料! 作者:PioneerYi juejin.im/post/6 ...

  2. 如何设计实现一个通用的分布式事务框架?

    公众号后台回复"学习",获取作者独家秘制精品资料 扫描下方海报二维码,试听课程: 本文来源:https://www.bytesoft.org/ 一个TCC事务框架需要解决的当然是分 ...

  3. 干货!如何设计实现一个通用的分布式事务框架?

    来源:https://www.bytesoft.org/ 一个TCC事务框架需要解决的当然是分布式事务的管理.关于TCC事务机制的介绍,可以参考TCC事务机制简介. TCC事务模型虽然说起来简单,然而 ...

  4. CVPR 2023 | 一键去除视频闪烁,该研究提出了一个通用框架

    该论文成功提出了第一个无需额外指导或了解闪烁的通用去闪烁方法,可以消除各种闪烁伪影. 高质量的视频通常在时间上具有一致性,但由于各种原因,许多视频会出现闪烁.例如,由于一些老相机硬件质量较差,不能将每 ...

  5. 如何设计一个通用的权限管理系统?说的太详细了!

    作者:PioneerYi 链接:https://juejin.im/post/6850037267554287629 一个系统,如果没有安全控制,是十分危险的,一般安全控制包括身份认证和权限管理.用户 ...

  6. 浅谈AI现状:它还不是万能的 更像一个“软体动物”

    虽然人工智能这个词是在20世纪50年代正式发明的,但是人工智能(AI)这个概念可以追溯到古埃及的自动机器和早期的希腊机器人神话.人工智能 虽然人工智能这个词是在20世纪50年代正式发明的,但是人工智能 ...

  7. 如何设计一个通用的权限管理系统?说的太详细了

    一个系统,如果没有安全控制,是十分危险的,一般安全控制包括身份认证和权限管理.用户访问时,首先需要查看此用户是否是合法用户,然后检查此用户可以对那些资源进行何种操作,最终做到安全访问.身份认证的方式有 ...

  8. 写一个通用数据访问组件

    出处:http://www.csharp-corner.com willsound(翻译) 我收到过好多Email来问我如何用一个通用的数据提供者(data provider)在不失自然数据提供者(n ...

  9. 如何构建一个通用的垂直爬虫平台?

    阅读本文大约需要15~20分钟. 本文章内容较多,非常干货!如果手机阅读体验不好,建议先收藏后到 PC 端阅读. 之前做爬虫时,在公司设计开发了一个通用的垂直爬虫平台,后来在公司做了内部的技术分享,这 ...

最新文章

  1. 洛谷P1966 火柴排队(逆序对)
  2. vfp 右键发送邮件_邮件批量发送的方法教程
  3. LeetCode-95-Unique Binary Search Trees II
  4. 丁丁打折网卷能用吗_超市货架上就能买到的好用护发素,平价好用,打折时可以多囤点...
  5. 使你的C/C++代码支持Unicode
  6. PaddleOCR 文本检测训练+推理模型转换教程
  7. [转载] python hasattr函数_Python的hasattr() getattr() setattr() 函数使用方法详解
  8. mysql数据库教程官网_数据库MySQL官方推荐教程-MySQL入门到删库
  9. Android 录制视频添加时间水印,Android开发教程入门
  10. java语言如何将小写字母转化为大写_java中如何把大写字母转换成小写字母,小写字母转换成大写字母...
  11. 注塑工艺需要考虑的7个因素
  12. W25Q128 Flash
  13. WebGL简易教程(十五):加载gltf模型
  14. 第三方测试什么意思?国内知名第三方测试公司排名
  15. 在Linux(fedora 20)上解压缩rar文件
  16. 通俗地解释下密码学中的归约证明
  17. MongoDB基本操作(Nosql数据库入门与实践)
  18. nacivate premium 12.1.12 安装包加破解注册机 亲测可用
  19. javaweb接入QQ登录
  20. Java Java!

热门文章

  1. 良/恶性乳腺肿瘤预测(逻辑回归分类器)
  2. 埇 mysql 不认这个字_输入法项目-用delphi生成GBK 中文编码 GBK 扩充汉字编码表(3) GBK/3: $8140 —$A0FE(部分)...
  3. 云服务器搭建crm系统,云服务器搭建crm网站教程
  4. Kepware的完美替代者
  5. IDEA中配置Tomcat(详细教程)
  6. 检测字符串是否包含特殊字符
  7. Pygame实战外星人入侵1.1——添加飞船
  8. AES128-ecb加解密
  9. 变量类型(cpu/gpu)
  10. 假若重新度过大学四年。。