【机器学习与算法】python手写算法:Cart树

  • 背景
  • 代码
  • 输出示例

背景

Cart树算法原理即遍历每个变量的每个分裂节点,找到增益(gini或entropy)最大的分裂节点进行二叉分割。
这里只输出最优分割变量,最优分割点,分割后的样本数、分割后的坏账率,方便用于风控决策中指定最优策略。

代码

import pandas as pd
import numpy as npclass CartTree():def __init__(self, criterion='gini', min_sample_leaf=50):'''初始化Cart树criterion:增益计算标准可用gini 或 entropymin_sample_leaf:最小叶子节点数'''self.criterion = criterionself.min_sample_leaf = min_sample_leafdef cal_gain(self,df):'''计算分裂后的增益return:增益值'''df['sum'] = df.sum(axis=1)df = df.append(df.sum(),ignore_index=True)df[0] = df[0]/df['sum']df[1] = df[1]/df['sum']if self.criterion == 'gini':df['criterion'] = 1 - df[0]**2 - df[1]**2elif self.criterion == 'entropy':df['criterion'] = -df[0]*np.log(df[0]) - df[1]*np.log(df[1])else:raisecal_gain = df.loc[2,'criterion']- \(df.loc[1,'sum']/df.loc[2,'sum'])*df.loc[1,'criterion'] - \(df.loc[0,'sum']/df.loc[2,'sum'])*df.loc[0,'criterion']return cal_gaindef FindBestSplitPoint(self, Ser:pd.Series, y:pd.Series, num = 20):'''在num个等频分割点上寻找增益最大的分割点Ser:要计算分割增益的变量y:标签列num:等频分割数return:最大增益的增益值及分割点'''if Ser.value_counts().shape[0]==1:return 0,Nonetmp,cuts = pd.qcut(Ser,num,retbins=True,duplicates='drop')cuts = list(cuts[:-1])max_gain,best_cut=0,0for item in cuts:if (Ser[Ser<=item].shape[0]<self.min_sample_leaf)\|(Ser[Ser>item].shape[0]<self.min_sample_leaf):continuetmp = pd.cut(Ser,bins=[-np.inf,item,np.inf]).astype(str)df_binary = pd.crosstab(tmp,y)gain = self.cal_gain(df_binary)if gain>max_gain:max_gain = gainbest_cut = itemreturn max_gain,best_cutdef FindBestSplitVar(self, df:pd.DataFrame, y:str):'''遍历所有变量,寻找最优分割变量及分割点df:数据集y:标签列的列名return:最优分割变量、最大增益值、最优切割点'''best_col,max_gain,best_cut = None,0,0for col in [x for x in df.columns if x!=y]:if df[col].dtype in ['int64','float64']:gain,cut = self.FindBestSplitPoint(df[col], df[y])if df[col].dtype in ['object']:raiseif gain>max_gain:best_col = colmax_gain = gainbest_cut = cutreturn best_col,max_gain,best_cutdef CreateTree(self, df:pd.DataFrame, y:str, n_depth, pos_indicator=40):'''生成二叉树df:数据集y:标签列列名n_depth:树的深度'''if n_depth<1:returnbest_col,best_gain,best_cuts = self.FindBestSplitVar(df, y)if best_col is None:print(' '*pos_indicator+'case_number:',df.shape[0])print(' '*pos_indicator+'bad_rate:',round(df[y].mean(),3))returnprint(' '*pos_indicator+'case_number:',df.shape[0])print(' '*pos_indicator+'bad_rate:',round(df[y].mean(),3))print(' '*pos_indicator+'best_split_var:',best_col)print(' '*pos_indicator+'best_split_point:',best_cuts)print(' '*pos_indicator+'      /\\ ')print(' '*pos_indicator+'   /      \\')print(' '*pos_indicator+'/            \\')df_left = df.loc[df[best_col]<=best_cuts]df_right = df.loc[df[best_col]>best_cuts]if n_depth==1:print(' '*(pos_indicator-10)+'case_number:',df_left.shape[0])print(' '*(pos_indicator-10)+'bad_rate:',round(df_left[y].mean(),3))print(' '*(pos_indicator+10)+'case_number:',df_right.shape[0])print(' '*(pos_indicator+10)+'bad_rate:',round(df_right[y].mean(),3))self.CreateTree(df_left, y, n_depth-1, pos_indicator-20)self.CreateTree(df_right, y, n_depth-1, pos_indicator+20)if __name__ == '__main__':tree = CartTree(criterion='gini',min_sample_leaf=100)tree.CreateTree(df,'y',n_depth=2)

输出示例


最后,欢迎阅读其它算法的python实现:
【机器学习与算法】python手写算法:Cart树
【机器学习与算法】python手写算法:带正则化的逻辑回归
【机器学习与算法】python手写算法:xgboost算法
【机器学习与算法】python手写算法:Kmeans和Kmeans++算法
【机器学习与算法】python手写算法:softmax回归

【机器学习与算法】python手写算法:Cart树相关推荐

  1. 前端算法及手写算法JavaScript

    一.手写算法 1.获取url中参数列表,保存为对象 function getUrlParam(){ //获取url中参数列表,保存为对象 var url="http://jjhs/dddh? ...

  2. python实现tomasulo算法_手写算法-python代码实现KNN

    本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理 原理解析 KNN-全称K-Nearest Neighbor,最近邻算法,可以做分类任务,也可以做回归 ...

  3. 人工智能的本质:最优化 (神经网络优化算法python手写实现)

    人工智能的本质就是最优化.假设把任务比作是一碗饭, 传统的解决方法,就是根据数学公式,然后一口气吃完饭,如果饭碗小,数学公式还行,如果饭碗大,数学公式能一口吃完饭吗? 人工智能的本质就是最优化,得益于 ...

  4. 用 Python 手写机器学习最简单的 KNN 算法

    作者 | 苏克1900 责编 | 胡巍巍 说实话,相比爬虫,掌握机器学习更实用竞争力也更强些. 目前网上大多这类教程对新手都不友好,要么直接调用 Sklearn 包,要么满篇抽象枯燥的算法公式文字,看 ...

  5. Python 手写机器学习最简单的 kNN 算法

    https://www.toutiao.com/a6698919092876739079/ Python 手写机器学习最简单的 kNN 算法 苏克1900 Python爬虫与数据挖掘 本文 3000 ...

  6. python机器学习手写算法系列——逻辑回归

    从机器学习到逻辑回归 今天,我们只关注机器学习到线性回归这条线上的概念.别的以后再说.为了让大家听懂,我这次也不查维基百科了,直接按照自己的理解用大白话说,可能不是很严谨. 机器学习就是机器可以自己学 ...

  7. python手写字母识别_机器学习--kNN算法识别手写字母

    本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...

  8. python机器学习手写算法系列——线性回归

    本系列另一篇文章<决策树> https://blog.csdn.net/juwikuang/article/details/89333344 本文源代码: https://github.c ...

  9. python机器学习手写算法系列——kmeans聚类

    从机器学习到kmeans 聚类是一种非监督学习,他和监督学习里的分类有相似之处,两者都是把样本分布到不同的组里去.区别在于,分类分析是有标签的,聚类是没有标签的.或者说,分类是有y的,聚类是没有y的, ...

最新文章

  1. python网页爬虫-Python网页爬虫
  2. 反制爬虫之Burp Suite RCE
  3. 蓝凌ekp开发_新华教育集团战略升级,携手蓝凌量身定制数字化办公平台
  4. c语言获取五子棋盘光标位置,跪求C语言五子棋悔棋部分实现
  5. android jni malloc和free的使用
  6. 如何解决90%的问题?10位阿里大牛公布方法
  7. pwm波程序如何实现_【优秀成果】如何做好算法与程序实现教学的知识储备
  8. 第五章、使用复合赋值和循环语句
  9. C++调用matlab接口
  10. 第五章 运输层[练习题+课后习题]
  11. vscode 不能运行h5c3代码_Golang安装与环境搭建并在VSCode里面输出HelloWord
  12. 服务器win10系统开机慢,Win10系统开机慢怎么办 windows10开机慢的解决方法
  13. 设置表格表头字体_Excel双栏和三栏斜线表头制作技巧
  14. [人工智能]动物专家系统work
  15. 文本聚类 java_【Java】文本聚类
  16. python将word转成excel_Python实现Word表格转成Excel表格的示例代码
  17. ERP电商管理系统开发实现功能
  18. win10计算机打开之后隐藏3d对象视频,Win10 3D对象文件夹如何隐藏?手把手教你隐藏3D对象文件夹...
  19. CDH6.2环境中启用Kerberos
  20. js动态修改表格数据

热门文章

  1. eclipse各版本说明
  2. 什么是Jython?
  3. 144显示器只有60_DIY老司机:吃鸡显示器非得用144Hz,60Hz就不行?
  4. Alcatel OminPCX Office 模拟分机不显示来电号码的解决
  5. 【java文本处理】实现一个简单的小说文本阅读器(分页、翻页、页码跳转)
  6. 【整理】PJSIP开源库详解
  7. 电子信息工程专业就业形势分析
  8. COBIT4.0简介
  9. 32 回归分析——一元线性回归模型
  10. 三十四载Windows崛起之路: 苹果、可视做过微软“铺路石”