决策树二分类之泰坦尼号克生存预测

  • 一、项目简介
    • 1.1 项目背景
    • 1.2 目标问题
    • 1.3 字段描述
  • 二、训练集(train)建模
    • 2.1 导入相关库
    • 2.2 自定义函数
    • 2.3 特征工程
      • 2.3.1 数据导入
      • 2.3.2 数据初探
        • (1)特征信息
        • (2)特征缺失值比例统计
        • (3)数值特征描述统计
      • 2.3.3 单特征可视化分析与处理
        • (1)Survived 是否存活
        • (2)Pclass 乘客等级
        • (3)Name 乘客姓名
        • (4)Sex 性别
        • (5)Age 年龄
        • (6)SibSp 堂兄弟妹个数
        • (7)Parch 父母与小孩的个数
        • (8)Ticket 船票信息
        • (9)Fare 票价
        • (10)Cabin 船舱
        • (11)Embarked 登船的港口
      • 2.3.4 衍生特征可视化分析与处理
        • FamilyNumbers 家庭人数
      • 2.3.5 删除冗余字段
      • 2.3.6 相关性矩阵可视化
    • 2.4 决策树模型训练
      • 2.4.1 数据标准化(Z-score)
      • 2.4.2 划分训练集、测试集
      • 2.4.3 网格寻参与交叉验证
      • 2.4.4 模型评价
      • 2.4.5 混淆矩阵可视化
      • 2.4.6 ROC曲线
  • 三、完整代码(含对test预测)
  • 四、Kaggle 得分

一、项目简介

官方链接:Titanic - Machine Learning from Disaster

1.1 项目背景

  • 1、泰坦尼克号: 英国白星航运公司下辖的一艘奥林匹克级邮轮,于1909年3月31日在爱尔兰贝尔法斯特港的哈兰德与沃尔夫造船厂动工建造,1911年5月31日下水,1912年4月2日完工试航。
  • 2、首航时间: 1912年4月10日
  • 3、航线: 从英国南安普敦出发,途经法国瑟堡-奥克特维尔以及爱尔兰昆士敦,驶向美国纽约。
  • 4、沉船: 1912年4月15日(1912年4月14日23时40分左右撞击冰山)
    船员+乘客人数:2224
  • 5、遇难人数: 1502(67.5%)

1.2 目标问题

  • 根据训练集中各位乘客的特征及是否获救标志的对应关系训练模型,预测测试集中的乘客是否获救。(二元分类问题

1.3 字段描述

二、训练集(train)建模

  • 数据集链接:train.csv

2.1 导入相关库

import numpy as np
import pandas as pd
from scipy import stats# sklearn 相关库
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.metrics import confusion_matrix,accuracy_score,roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score# 可视化相关库
import seaborn as sns
import matplotlib.pyplot as plt# 解决mac 系统画图中文不显示问题
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # # 解决win 系统中文不显示问题
# from pylab import mpl
# mpl.rcParams['font.sans-serif']=['SimHei']# 不显示警告
import warnings
warnings.filterwarnings('ignore')

2.2 自定义函数

def PieChart(df):'''绘制环形饼图'''plt.figure(figsize = (4,4), # 设置图片大小dpi = 100        # 精度)df.value_counts().plot( kind = 'pie',               # 设置绘图类型为饼图wedgeprops = {'width':0.4}, # 设置空心比例autopct = "%.1f%%"          # 显示百分比)def BarPlot(df,ColumnsName):'''绘制不同 ColumnsName 的存活人数柱形图'''ColumnsDf = df.groupby(['Survived',ColumnsName]).count()[['PassengerId']].reset_index()\.rename(columns={"PassengerId":"Count"})plt.figure(figsize=(4,3),dpi=150)sns.barplot(data=ColumnsDf,x=ColumnsName,y="Count",hue="Survived")plt.title('Survived Count Of {}'.format(ColumnsName))def OneHot(x):'''功能:one-hot 编码传入:需要编码的分类变量返回:返回编码后的结果,形式为 dataframe'''# 通过 LabelEncoder 将分类变量打上数值标签 lb = LabelEncoder()                             # 初始化x_pre = lb.fit_transform(x)                     # 模型拟合x_dict = dict([[i,j] for i,j in zip(x,x_pre)])  # 生成编码字典--> {'收藏': 1, '点赞': 2, '关注': 0}x_num = [[x_dict[i]] for i in x]                # 通过 x_dict 将分类变量转为数值型# 进行one-hot编码enc = OneHotEncoder()                        # 初始化enc.fit(x_num)                               # 模型拟合array_data = enc.transform(x_num).toarray()  # one-hot 编码后的结果,二维数组形式# 转成 dataframe 形式df = pd.DataFrame(array_data)inverse_dict = dict([val,key] for key,val in x_dict.items()) # 反转 x_dict 的键、值# columns 重命名if type(x) == pd.Series:firs_name = x.nameelse:firs_name = ""df.columns = [firs_name+"_"+inverse_dict[i] for i in df.columns]           return df

2.3 特征工程

2.3.1 数据导入

train = pd.read_csv("train.csv")
train.head(5)

2.3.2 数据初探

(1)特征信息

train.info()
  • 可以看出训练集共有891个样本,且有三个字段(Age、Cabin、Embarked)存在缺失值。

(2)特征缺失值比例统计

train.isnull().sum()/len(train)
  • 可以看出,字段Cabin缺失比例较大,达到77%。

(3)数值特征描述统计

train.describe()
  • 可以看出,票价(Fare)最低为0,估计是船上的员工。

2.3.3 单特征可视化分析与处理

(1)Survived 是否存活

########################## 1、Survived 是否存活 ##########################
# Y标签,{0:不存活,1:存活}
# 有无缺失值:无
# 数据处理:不处理
# 从图中可以看出,死亡人数与存活人数占比差异不大PieChart(train['Survived'])

(2)Pclass 乘客等级

########################## 2、Pclass 乘客等级 ##########################
# 无缺失值,等级变量
# 用柱状图查看各乘客等级的存活情况
# 可以看出 Pclass=3 的乘客中,存活人数远低于死亡人数
BarPlot(train,"Pclass")# 数据处理:将Pclass分成两类,Pclass>=3、Pclass<3
train['PclassType'] = ["Pclass>=3" if i >= 3 else "Pclass<3" for i in train['Pclass']]# 查看不同 PclassType 的存活情况
BarPlot(train,"PclassType")# 再对 PclassType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['PclassType']),left_index=True,right_index=True
)

(3)Name 乘客姓名

########################## 3、Name 乘客姓名 ##########################
# 字符串变量
# 有无缺失值:无# 从乘客姓名中获取头街
# 姓名中头街字符串与定义头街类别之间的关系
#     Officer: 政府官员,
#     RoyaIty: 王室(皇室),
#     Mr:      已婚男士,
#     Mrs:     已婚女士,
#     Miss:    年轻未婚女子,
#     Master:  有技能的人/教师
# 新建字段 Title_Dict
Title_Dict = {'Mr':'Mr','Mrs':'Mrs', 'Miss':'Miss','Master': 'Master', 'Don':'Royalty','Rev':'Officer','Dr':')fficer', 'Mme':'Mrs','Ms':'Mrs','Major':'Officer', 'Lady': 'Royalty','Sir': 'Royalty','Mlle':'Miss', 'Col': 'Officer','Capt':'Officer','the Countess': 'Royalty','Jonkheer': 'Royalty','Dona': 'Royalty'
}
train['NameType'] = [Title_Dict[i.split(".")[0].split(", ")[-1]] for i in train['Name']] # 对Name进行分类
# 用柱状图查看各 NameType 的存活情况
# 可以看出 乘客为Mr(已婚男士)中,死亡人数远远大于存活人数;
#        乘客为Mrs(已婚女士)、Miss(年轻未婚女子)中,死亡人数远远低于存活人数;
BarPlot(train,"NameType")# 数据进一步处理:将 NameType 分成三类
# Mr(已婚男士)
# Mrs(已婚女士)、Miss(年轻未婚女子)
# 其他
train['NameType2'] = ["Mr" if i == "Mr" else ("Mrs and Miss" if i in ['Mrs','Miss'] else "Other") \for i in train['NameType']]# 查看不同 NameType2 的存活情况
BarPlot(train,"NameType2")# 再对 NameType2 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['NameType2']),left_index=True,right_index=True
)

(4)Sex 性别

########################## 4、Sex 性别 ##########################
# 分类变量
# 有无缺失值:无# 用柱状图查看各 NameType 的存活情况
# 可以看出 乘客为男性中,死亡人数远远大于存活人数
BarPlot(train,"Sex")# 对 Sex 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['Sex']),left_index=True,right_index=True)

(5)Age 年龄

########################## 5、Age 年龄 ##########################
# 连续变量
# 有无缺失值:有,缺失比例19.9%# 缺失值用均值填充
train['Age'] = train['Age'].fillna(train['Age'].mean())# 用直方图查看各 Age 的存活情况
# 可以看出 可以看出小于5岁的小孩存活率很高
plt.figure(figsize=(8,4),dpi=150)
sns.distplot(train[train['Survived']==0]['Age'],color="red",kde=False)
sns.distplot(train[train['Survived']==1]['Age'],color="blue",kde=False)# 数据处理:将 Age 分成两类,Age<=5、Age>5
train['AgeType'] = ["Age<=5" if i <= 5 else "Age>5"  for i in train['Age']]# 查看不同 AgeType 的存活情况
BarPlot(train,"AgeType")# 再对 AgeType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['AgeType']),left_index=True,right_index=True
)


(6)SibSp 堂兄弟妹个数

########################## 6、SibSp 堂兄弟妹个数 ##########################
# 无缺失值,等级变量
# 用柱状图查看各堂兄弟妹个数的存活情况
# 可以看出 SibSp=0 的乘客中,死亡人数较多
BarPlot(train,"SibSp")# 数据处理:将 SibSp 分成两类,SibSp=0、SibSp>0
train['SibSpType'] = ["SibSp=0" if i == 0 else "SibSp>0" for i in train['SibSp']]# 查看不同 SibSpType 的存活情况
BarPlot(train,"SibSpType")# 再对 SibSpType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['SibSpType']),left_index=True,right_index=True
)

(7)Parch 父母与小孩的个数

########################## 7、Parch 父母与小孩的个数 ##########################
# 连续变量
# 有无缺失值:无# 用柱状图查看父母与小孩的个数的存活情况
# 可以看出 Parch=0 的乘客中,死亡人数较多
BarPlot(train,"Parch")# 数据处理:将 Parch 分成两类,Parch=0、Parch>0
train['ParchType'] = ["Parch=0" if i == 0 else "Parch>0" for i in train['Parch']]# 查看不同 ParchType 的存活情况
BarPlot(train,"ParchType")# 再对 ParchType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['ParchType']),left_index=True,right_index=True
)


(8)Ticket 船票信息

  • 字符变量
  • 有无缺失值:无
  • 数据处理:这里直接删去(下文会删)

(9)Fare 票价

########################## 9、Fare 票价 ##########################
# 连续变量
# 有无缺失值:无# 查看Fare(票价)= 0 的生存情况
Fare0 = train[train['Fare']==0]
Fare0Survived = Fare0.groupby(['Survived']).count()[['PassengerId']].reset_index().rename(columns={"PassengerId":"Count"})
plt.figure(figsize=(4,3),dpi=150)
sns.barplot(data=Fare0Survived,x="Survived",y="Count"
)
plt.title('Survived Count Of Fare=0')# 查看Fare(票价)!= 0 的生存情况
Fare1 = train[train['Fare']!=0]
plt.figure(figsize=(8,4),dpi=150)
sns.distplot(Fare1[Fare1['Survived']==0]['Fare'],color="red",kde=False)
sns.distplot(Fare1[Fare1['Survived']==1]['Fare'],color="blue",kde=False)
plt.title('Survived Count Of Fare!=0')# 对 Fare 分成三类
# Fare = 0
# Fare <=50
# Fare > 50
train['FareType'] = ["Fare=0" if i == 0 else ("Fare<=50" if i <= 50 else "Fare>50") for i in train['Fare']]# 用柱状图查看不同 FareType 的存活情况
# 可以看出 Fare=0 的乘客中,乘客几乎都死亡
#        Fare <=50 的乘客中,死亡人数大于存活人数
#        Fare > 50 的乘客中,存活人数大于死亡人数
BarPlot(train,"FareType")# 再对 FareType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['FareType']),left_index=True,right_index=True
)


(10)Cabin 船舱

  • 离散变量
  • 有无缺失值:有,缺失值比例高达77%
  • 数据处理:缺失值比例较大,直接删去(下文会删)

(11)Embarked 登船的港口

########################## 11、Embarked 登船的港口 ##########################
# 离散变量
# 有无缺失值:有,缺失值比例很低# 用柱状图查看各登船的港口的存活情况
# 可以看出 Embarked=S 的乘客中,死亡人数较多
BarPlot(train,"Embarked")# 数据处理:缺失值按众数填充,然后再进行One-hot编码处理
mode = stats.mode(train['Embarked'])[0][0] # 众数
train['Embarked'] = train['Embarked'].fillna(mode)train = pd.merge(train,OneHot(train['Embarked']),left_index=True,right_index=True)

2.3.4 衍生特征可视化分析与处理

FamilyNumbers 家庭人数

########################## FamilyNumbers 家庭人数 ##########################
# 计算方式:SibSp(堂兄弟妹个数) + Parch(父母与小孩的个数) + 1(自己)
train['FamilyNumbers'] = train['SibSp'] + train['Parch'] + 1# 用柱状图查看各家庭人数的存活情况
# 可以看出 家庭人数=1 的乘客中,死亡人数较多
#        家庭人数>=5 的乘客中,存活人数较多
BarPlot(train,"FamilyNumbers")# 新增 FamilyType 字段
# 1 : 单身(Single)
# 2-4:小家庭(Family_Small)
# >4: 大家庭(Family_Large)
train['FamilyType'] = ['Single' if i == 1 else('Family_Small' if i<=4 else 'Family_Large') for i in train['FamilyNumbers']]# 查看不同 FamilyType 的存活情况
BarPlot(train,"FamilyType")# 对 FamilyType 进行One-hot编码处理
train = pd.merge(train,OneHot(train['FamilyType']),left_index=True,right_index=True)

2.3.5 删除冗余字段

drop_columns = ['PassengerId','Pclass','PclassType','Name','NameType','NameType2','Sex','Age','AgeType',\'SibSp','SibSpType','Parch','ParchType','Fare','FareType','Ticket','Cabin','Embarked',\'FamilyNumbers','FamilyType']
train.drop(drop_columns,axis=1,inplace=True)

2.3.6 相关性矩阵可视化

  • 采用斯皮尔曼相关系数
corr_df = train.corr(method="spearman")[['Survived']].sort_values(by="Survived",ascending=False)
plt.figure(figsize=(1,8),dpi=100)
sns.heatmap(corr_df,cmap='Blues',center=0,vmax=1,vmin=-1,annot=True,annot_kws={'size':10,'weight':'bold', 'color':'red'}
)

2.4 决策树模型训练

2.4.1 数据标准化(Z-score)

def ZscoreNormalization(x):'''Z-score 标准化'''return (x - np.mean(x)) / np.std(x)data = train.drop("Survived",axis=1).agg(ZscoreNormalization)
data['Lable'] = train['Survived']

2.4.2 划分训练集、测试集

  • 按7:3比例划分
x_train, x_test, y_train, y_test = train_test_split(data.drop("Lable",axis=1), data['Lable'], test_size = 0.3, random_state = 0
)

2.4.3 网格寻参与交叉验证

param_grid = {'criterion' : ['gini','entropy'], # 划分属性时选用的准则:{“gini”, “entropy”}, default=”gini”'splitter' : ['best','random'],   # 划分方式:{“best”, “random”}, default=”best”'max_depth' : range(1,6),         # 最大深度'min_samples_split' : range(1,6), # 拆分内部节点所需的最小样本数'min_samples_leaf' : range(1,6),  # 叶节点所需的最小样本数
}
clf = DecisionTreeClassifier()               # 初始化
gs = GridSearchCV(clf,param_grid,cv=5)       # 网格搜索与交叉验证
gs.fit(x_train,y_train)                      # 模型训练
print("Best Estimator: ",gs.best_estimator_) # 打印最好的分类器
print("Best Score: ",gs.best_score_)         # 打印最好分数


注意: 每次运行的结果输出会存在差别。

2.4.4 模型评价

print("\n---------- 模型评价 ----------")
y_pred = gs.predict(x_test)                         # 预测
cm = confusion_matrix(y_test, y_pred,labels=[0, 1]) # 混淆矩阵
df_cm = pd.DataFrame(cm)                            # 构建DataFrame
print('Accuracy score:', accuracy_score(y_test, y_pred))                       # 准确率
print('Recall:', recall_score(y_test, y_pred, average='weighted'))             # 召回率
print('F1-score:', f1_score(y_test, y_pred, average='weighted'))               # F1分数
print('Precision score:', precision_score(y_test, y_pred, average='weighted')) # 精确度

2.4.5 混淆矩阵可视化

plt.figure(dpi=150)heatmap = sns.heatmap(df_cm, annot=True, fmt='.0f', cmap='Blues')
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right')
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=0, ha='right')plt.title('DecisionTreeClassifier Model Results')
plt.show()

2.4.6 ROC曲线

y_pred_proba = gs.predict_proba(np.array(x_test))[:,1]
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)sns.set()plt.figure(figsize=(5,4),dpi=150)
plt.plot(fpr, tpr)
plt.plot(fpr, fpr, linestyle = '-' , color = 'k')plt.xlabel('False positive rate')
plt.ylabel('True positive rate')AU = np.round(roc_auc_score(y_test, y_pred_proba), 2)plt.title(f'AU: {AU}');plt.show()

三、完整代码(含对test预测)

  • 含预测,不含可视化
import numpy as np
import pandas as pd
from scipy import stats# sklearn 相关库
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.metrics import confusion_matrix,accuracy_score,roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score# 不显示红色警告
import warnings
warnings.filterwarnings('ignore')def OneHot(x):'''功能:one-hot 编码传入:需要编码的分类变量返回:返回编码后的结果,形式为 dataframe'''# 通过 LabelEncoder 将分类变量打上数值标签 lb = LabelEncoder()                             # 初始化x_pre = lb.fit_transform(x)                     # 模型拟合x_dict = dict([[i,j] for i,j in zip(x,x_pre)])  # 生成编码字典--> {'收藏': 1, '点赞': 2, '关注': 0}x_num = [[x_dict[i]] for i in x]                # 通过 x_dict 将分类变量转为数值型# 进行one-hot编码enc = OneHotEncoder()                        # 初始化enc.fit(x_num)                               # 模型拟合array_data = enc.transform(x_num).toarray()  # one-hot 编码后的结果,二维数组形式# 转成 dataframe 形式df = pd.DataFrame(array_data)inverse_dict = dict([val,key] for key,val in x_dict.items()) # 反转 x_dict 的键、值# columns 重命名if type(x) == pd.Series:firs_name = x.nameelse:firs_name = ""df.columns = [firs_name+"_"+inverse_dict[i] for i in df.columns]           return dfdef ZscoreNormalization(x):'''Z-score 标准化'''return (x - np.mean(x)) / np.std(x)def DataClean(df,Lable=True):'''数据预处理函数'''########################## 1、Pclass 乘客等级 ########################### 无缺失值,等级变量# 数据处理:将Pclass分成两类,Pclass>=3、Pclass<3df['PclassType'] = ["Pclass>=3" if i >= 3 else "Pclass<3" for i in df['Pclass']]# 再对 PclassType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['PclassType']),left_index=True,right_index=True)    ########################## 2、Name 乘客姓名 ########################### 字符串变量# 有无缺失值:无# 从乘客姓名中获取头街# 姓名中头街字符串与定义头街类别之间的关系#     Officer: 政府官员,#     RoyaIty: 王室(皇室),#     Mr:      已婚男士,#     Mrs:     已婚女士,#     Miss:    年轻未婚女子,#     Master:  有技能的人/教师 # 新建字段 Title_Dict Title_Dict = {'Mr':'Mr','Mrs':'Mrs', 'Miss':'Miss','Master': 'Master', 'Don':'Royalty','Rev':'Officer','Dr':')fficer', 'Mme':'Mrs','Ms':'Mrs','Major':'Officer', 'Lady': 'Royalty','Sir': 'Royalty','Mlle':'Miss', 'Col': 'Officer','Capt':'Officer','the Countess': 'Royalty','Jonkheer': 'Royalty','Dona': 'Royalty'}df['NameType'] = [Title_Dict[i.split(".")[0].split(", ")[-1]] for i in df['Name']] # 对Name进行分类# 数据进一步处理:将 NameType 分成三类# Mr(已婚男士)# Mrs(已婚女士)、Miss(年轻未婚女子)# 其他df['NameType2'] = ["Mr" if i == "Mr" else ("Mrs and Miss" if i in ['Mrs','Miss'] else "Other") \for i in df['NameType']]# 再对 NameType2 进行One-Hot编码处理df = pd.merge(df,OneHot(df['NameType2']),left_index=True,right_index=True)   ########################## 3、Sex 性别 ########################### 分类变量# 有无缺失值:无# 对 Sex 进行One-Hot编码处理df = pd.merge(df,OneHot(df['Sex']),left_index=True,right_index=True)########################## 4、Age 年龄 ########################### 连续变量# 有无缺失值:有,缺失比例19.9%# 缺失值用均值填充df['Age'] = df['Age'].fillna(df['Age'].mean())# 数据处理:将 Age 分成两类,Age<=5、Age>5df['AgeType'] = ["Age<=5" if i <= 5 else "Age>5"  for i in df['Age']]# 再对 AgeType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['AgeType']),left_index=True,right_index=True)   ########################## 5、SibSp 堂兄弟妹个数 ########################### 无缺失值,等级变量# 数据处理:将 SibSp 分成两类,SibSp=0、SibSp>0df['SibSpType'] = ["SibSp=0" if i == 0 else "SibSp>0" for i in df['SibSp']]# 再对 SibSpType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['SibSpType']),left_index=True,right_index=True)    ########################## 6、Parch 父母与小孩的个数 ########################### 连续变量# 有无缺失值:无# 数据处理:将 Parch 分成两类,Parch=0、Parch>0df['ParchType'] = ["Parch=0" if i == 0 else "Parch>0" for i in df['Parch']]# 再对 ParchType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['ParchType']),left_index=True,right_index=True)    ########################## 8、Fare 票价 ########################### 连续变量# 有无缺失值:无# 对 Fare 分成三类# Fare = 0# Fare <=50# Fare > 50df['FareType'] = ["Fare=0" if i == 0 else ("Fare<=50" if i <= 50 else "Fare>50") for i in df['Fare']]# 再对 FareType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['FareType']),left_index=True,right_index=True)    ########################## 10、Embarked 登船的港口 ########################### 离散变量# 有无缺失值:有,缺失值比例很低# 数据处理:缺失值按众数填充,然后再进行One-hot编码处理mode = stats.mode(df['Embarked'])[0][0] # 众数df['Embarked'] = df['Embarked'].fillna(mode)df = pd.merge(df,OneHot(df['Embarked']),left_index=True,right_index=True)########################## 11、FamilyNumbers 家庭人数 ########################### 计算方式:SibSp(堂兄弟妹个数) + Parch(父母与小孩的个数) + 1(自己)df['FamilyNumbers'] = df['SibSp'] + df['Parch'] + 1# 新增 FamilyType 字段# 1 : 单身(Single)        # 2-4:小家庭(Family_Small)# >4: 大家庭(Family_Large)df['FamilyType'] = ['Single' if i == 1 else('Family_Small' if i<=4 else 'Family_Large') for i in df['FamilyNumbers']]# 对 FamilyType 进行One-hot编码处理df = pd.merge(df,OneHot(df['FamilyType']),left_index=True,right_index=True)########################## 删除冗余变量 ##########################drop_columns = ['PassengerId','Pclass','PclassType','Name','NameType','NameType2','Sex','Age','AgeType',\'SibSp','SibSpType','Parch','ParchType','Fare','FareType','Ticket','Cabin','Embarked',\'FamilyNumbers','FamilyType']df.drop(drop_columns,axis=1,inplace=True)########################## 数据标准化 ##########################if Lable == True: # 判断是否是测试集(测试集不含标签)data = df.drop("Survived",axis=1).agg(ZscoreNormalization)data['Lable'] = df['Survived']else:data = df.agg(ZscoreNormalization)return datadef sklearn_DecisionTreeClassifier(data):'''决策树二分类'''# 划分训练集、测试集x_train, x_test, y_train, y_test = train_test_split(data.drop("Lable",axis=1), data['Lable'], test_size = 0.3, random_state = 0)print("\n---------- 模型训练 ----------")# 网格寻参param_grid = {'criterion' : ['gini','entropy'], # 划分属性时选用的准则:{“gini”, “entropy”}, default=”gini”'splitter' : ['best','random'],   # 划分方式:{“best”, “random”}, default=”best”'max_depth' : range(1,6),         # 最大深度'min_samples_split' : range(1,6), # 拆分内部节点所需的最小样本数'min_samples_leaf' : range(1,6),  # 叶节点所需的最小样本数}clf = DecisionTreeClassifier()               # 初始化gs = GridSearchCV(clf,param_grid,cv=5)       # 网格搜索与交叉验证gs.fit(x_train,y_train)                      # 模型训练print("Best Estimator: ",gs.best_estimator_) # 打印最好的分类器print("Best Score: ",gs.best_score_)         # 打印最好分数# 模型预测print("\n---------- 模型评价 ----------")y_pred = gs.predict(x_test)                         # 预测cm = confusion_matrix(y_test, y_pred,labels=[0, 1]) # 混淆矩阵df_cm = pd.DataFrame(cm)                            # 构建DataFrameprint('Accuracy score:', accuracy_score(y_test, y_pred))                       # 准确率print('Recall:', recall_score(y_test, y_pred, average='weighted'))             # 召回率print('F1-score:', f1_score(y_test, y_pred, average='weighted'))               # F1分数print('Precision score:', precision_score(y_test, y_pred, average='weighted')) # 精确度return gs.best_estimator_ # 返回最好的训练模型if __name__ == "__main__":train = pd.read_csv("train.csv")test  = pd.read_csv("test.csv")print("\n---------- 数据预处理 ----------")train_data = DataClean(train)           test_data = DataClean(test,Lable=False) # 决策树二分类best_estimator = sklearn_DecisionTreeClassifier(train_data) # 预测y_pred = best_estimator.predict(test_data) # 输出预测结果result = test[['PassengerId']]result['Survived'] = y_predresult.to_csv("Titanic Results.csv",index=False)print("\n程序运行完成")

四、Kaggle 得分

  • 得分:0.77511
  • 排名:7651

参考
1、Kaggle泰坦尼克号比赛项目详解
2、机器学习实战——kaggle 泰坦尼克号生存预测——六种算法模型实现与比较

python 决策树分类 泰坦尼克生存预测相关推荐

  1. kaggle房价预测特征意思_机器学习-kaggle泰坦尼克生存预测(一)-数据清洗与特征构建...

    1.背景: 1.1 关于kaggle: 谷歌旗下的 Kaggle 是一个数据建模和数据分析竞赛平台.该平台是当下最流行的数据科研赛事平台,其组织的赛事受到全球数据科学爱好者追捧. 如果学生能够在该平台 ...

  2. MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测

    Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...

  3. 实验六:泰坦尼克生存预测之缺失值处理

    一.任务描述 背景故事: 泰坦尼克号(RMS Titanic),又译作铁达尼号,是英国白星航运公司下辖的一艘奥林匹克级游轮,排水量46000吨,于1909年3月31日在北爱尔兰贝尔法斯特港的哈兰德与沃 ...

  4. Python项目实战-Tensorflow2.0实现泰坦尼克生存预测

    目录 一.数据集下载地址 二.探索性因子分析(EDA) 三.特征工程 四.构建Dataset与Model fit和自定义estimator使用 预定义estimator的使用 一.数据集下载地址 # ...

  5. 决策树实例-泰坦尼克幸存者预测

  6. 基于逻辑回归的泰坦尼克生存预测

  7. Kaggle实战:泰坦尼克幸存者预测 - 上

    (文章同步更新于个人博客@dai98.github.io) 源代码: Github Kaggle 泰坦尼克幸存者预测是Kaggle上数据竞赛的入门级别的比赛,我曾经在一年前作为作业参加过这个比赛,我想 ...

  8. Kaggle实战:泰坦尼克幸存者预测 -下

    (文章同步更新于个人博客@dai98.github.io) 源代码:Github 上一篇文章介绍了如何使用深度学习来预测泰坦尼克号幸存者,这一部分使用多分类器投票来做.由于数据预处理部分比较相似,重复 ...

  9. 泰坦尼克号生存预测python_用Python预测泰坦尼克生存情况-附数据集

    介绍:通过逻辑回归算法,解决kaggle网站上的泰坦尼克生存情况预测问题,准确率在80%左右. 一.提出问题 什么样的人在泰坦尼克号中更容易存活? 二.理解数据 2.1 数据来源 数据来自kaggle ...

  10. Kaggle初体验之泰坦尼特生存预测

    Kaggle初体验之泰坦尼特生存预测 学习完了决策树的ID3.C4.5.CART算法,找一个试手的地方,Kaggle的练习赛泰坦尼特很不错,记录下 流程     首先注册一个账号,然后在顶部菜单栏Co ...

最新文章

  1. python note
  2. navicat保存查询语句_MySQL数据库安装创建及Navicat客户端连接
  3. 父亲和女儿同为互联网大佬, 但不幸都得癌症
  4. ajax请求php的过程,php如何实现ajax请求
  5. poj1463 Strategic game
  6. Nginx设置TCP上游服务器的SSL配置
  7. 深度 | 从各种注意力机制窥探深度学习在NLP中的神威
  8. 语音转写可实时,直播也能同步字幕
  9. 反激变压器结构设计学习笔记(进阶)
  10. Sketchup2019安装包安装教程
  11. 苹果IOS隐藏复制链接等按钮失效及报错the permission value is offline verifying解决办法
  12. ospfdr选举规则_DR/BDR详细选举过程
  13. abaqus个人总结 各种问题各种debug
  14. shell 脚本学习
  15. WPP和iHeartMedia推出“聆听项目”
  16. java word apache poi 操作word模板。
  17. 0xC000005:Access Violation和指针强制转换问题
  18. 区块链培训就业方向多不多?
  19. 单片机开发—呼吸灯的三种实现方法
  20. Safari浏览器意外退出无法重新启动怎么办?

热门文章

  1. Chrome主页被强制修改为百度解决办法
  2. [CGAL]建立一个正四面体
  3. storm架构及原理详解
  4. 2022 python获取和风天气 web api v7版本
  5. java中对手机号、邮箱等隐私信息脱敏展示,如手机号138****8888。
  6. ArcGIS Server 发布地图服务遇到的问题
  7. 技术帖:如何把mobi文件转化成pdf
  8. wap pc html,PCWAP手机PC网站信息管理系统 v1.4.3
  9. 深度学习:GAN 对抗网络原理详细解析(零基础必看)
  10. 详解 LVS、Nginx 及 HAProxy 工作原理