python 决策树分类 泰坦尼克生存预测
决策树二分类之泰坦尼号克生存预测
- 一、项目简介
- 1.1 项目背景
- 1.2 目标问题
- 1.3 字段描述
- 二、训练集(train)建模
- 2.1 导入相关库
- 2.2 自定义函数
- 2.3 特征工程
- 2.3.1 数据导入
- 2.3.2 数据初探
- (1)特征信息
- (2)特征缺失值比例统计
- (3)数值特征描述统计
- 2.3.3 单特征可视化分析与处理
- (1)Survived 是否存活
- (2)Pclass 乘客等级
- (3)Name 乘客姓名
- (4)Sex 性别
- (5)Age 年龄
- (6)SibSp 堂兄弟妹个数
- (7)Parch 父母与小孩的个数
- (8)Ticket 船票信息
- (9)Fare 票价
- (10)Cabin 船舱
- (11)Embarked 登船的港口
- 2.3.4 衍生特征可视化分析与处理
- FamilyNumbers 家庭人数
- 2.3.5 删除冗余字段
- 2.3.6 相关性矩阵可视化
- 2.4 决策树模型训练
- 2.4.1 数据标准化(Z-score)
- 2.4.2 划分训练集、测试集
- 2.4.3 网格寻参与交叉验证
- 2.4.4 模型评价
- 2.4.5 混淆矩阵可视化
- 2.4.6 ROC曲线
- 三、完整代码(含对test预测)
- 四、Kaggle 得分
一、项目简介
官方链接:Titanic - Machine Learning from Disaster
1.1 项目背景
- 1、泰坦尼克号: 英国白星航运公司下辖的一艘奥林匹克级邮轮,于1909年3月31日在爱尔兰贝尔法斯特港的哈兰德与沃尔夫造船厂动工建造,1911年5月31日下水,1912年4月2日完工试航。
- 2、首航时间: 1912年4月10日
- 3、航线: 从英国南安普敦出发,途经法国瑟堡-奥克特维尔以及爱尔兰昆士敦,驶向美国纽约。
- 4、沉船: 1912年4月15日(1912年4月14日23时40分左右撞击冰山)
船员+乘客人数:2224 - 5、遇难人数: 1502(67.5%)
1.2 目标问题
- 根据训练集中各位乘客的特征及是否获救标志的对应关系训练模型,预测测试集中的乘客是否获救。(
二元分类问题
)
1.3 字段描述
二、训练集(train)建模
- 数据集链接:train.csv
2.1 导入相关库
import numpy as np
import pandas as pd
from scipy import stats# sklearn 相关库
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.metrics import confusion_matrix,accuracy_score,roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score# 可视化相关库
import seaborn as sns
import matplotlib.pyplot as plt# 解决mac 系统画图中文不显示问题
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # # 解决win 系统中文不显示问题
# from pylab import mpl
# mpl.rcParams['font.sans-serif']=['SimHei']# 不显示警告
import warnings
warnings.filterwarnings('ignore')
2.2 自定义函数
def PieChart(df):'''绘制环形饼图'''plt.figure(figsize = (4,4), # 设置图片大小dpi = 100 # 精度)df.value_counts().plot( kind = 'pie', # 设置绘图类型为饼图wedgeprops = {'width':0.4}, # 设置空心比例autopct = "%.1f%%" # 显示百分比)def BarPlot(df,ColumnsName):'''绘制不同 ColumnsName 的存活人数柱形图'''ColumnsDf = df.groupby(['Survived',ColumnsName]).count()[['PassengerId']].reset_index()\.rename(columns={"PassengerId":"Count"})plt.figure(figsize=(4,3),dpi=150)sns.barplot(data=ColumnsDf,x=ColumnsName,y="Count",hue="Survived")plt.title('Survived Count Of {}'.format(ColumnsName))def OneHot(x):'''功能:one-hot 编码传入:需要编码的分类变量返回:返回编码后的结果,形式为 dataframe'''# 通过 LabelEncoder 将分类变量打上数值标签 lb = LabelEncoder() # 初始化x_pre = lb.fit_transform(x) # 模型拟合x_dict = dict([[i,j] for i,j in zip(x,x_pre)]) # 生成编码字典--> {'收藏': 1, '点赞': 2, '关注': 0}x_num = [[x_dict[i]] for i in x] # 通过 x_dict 将分类变量转为数值型# 进行one-hot编码enc = OneHotEncoder() # 初始化enc.fit(x_num) # 模型拟合array_data = enc.transform(x_num).toarray() # one-hot 编码后的结果,二维数组形式# 转成 dataframe 形式df = pd.DataFrame(array_data)inverse_dict = dict([val,key] for key,val in x_dict.items()) # 反转 x_dict 的键、值# columns 重命名if type(x) == pd.Series:firs_name = x.nameelse:firs_name = ""df.columns = [firs_name+"_"+inverse_dict[i] for i in df.columns] return df
2.3 特征工程
2.3.1 数据导入
train = pd.read_csv("train.csv")
train.head(5)
2.3.2 数据初探
(1)特征信息
train.info()
- 可以看出训练集共有891个样本,且有三个字段(Age、Cabin、Embarked)存在缺失值。
(2)特征缺失值比例统计
train.isnull().sum()/len(train)
- 可以看出,字段Cabin缺失比例较大,达到77%。
(3)数值特征描述统计
train.describe()
- 可以看出,票价(Fare)最低为0,估计是船上的员工。
2.3.3 单特征可视化分析与处理
(1)Survived 是否存活
########################## 1、Survived 是否存活 ##########################
# Y标签,{0:不存活,1:存活}
# 有无缺失值:无
# 数据处理:不处理
# 从图中可以看出,死亡人数与存活人数占比差异不大PieChart(train['Survived'])
(2)Pclass 乘客等级
########################## 2、Pclass 乘客等级 ##########################
# 无缺失值,等级变量
# 用柱状图查看各乘客等级的存活情况
# 可以看出 Pclass=3 的乘客中,存活人数远低于死亡人数
BarPlot(train,"Pclass")# 数据处理:将Pclass分成两类,Pclass>=3、Pclass<3
train['PclassType'] = ["Pclass>=3" if i >= 3 else "Pclass<3" for i in train['Pclass']]# 查看不同 PclassType 的存活情况
BarPlot(train,"PclassType")# 再对 PclassType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['PclassType']),left_index=True,right_index=True
)
(3)Name 乘客姓名
########################## 3、Name 乘客姓名 ##########################
# 字符串变量
# 有无缺失值:无# 从乘客姓名中获取头街
# 姓名中头街字符串与定义头街类别之间的关系
# Officer: 政府官员,
# RoyaIty: 王室(皇室),
# Mr: 已婚男士,
# Mrs: 已婚女士,
# Miss: 年轻未婚女子,
# Master: 有技能的人/教师
# 新建字段 Title_Dict
Title_Dict = {'Mr':'Mr','Mrs':'Mrs', 'Miss':'Miss','Master': 'Master', 'Don':'Royalty','Rev':'Officer','Dr':')fficer', 'Mme':'Mrs','Ms':'Mrs','Major':'Officer', 'Lady': 'Royalty','Sir': 'Royalty','Mlle':'Miss', 'Col': 'Officer','Capt':'Officer','the Countess': 'Royalty','Jonkheer': 'Royalty','Dona': 'Royalty'
}
train['NameType'] = [Title_Dict[i.split(".")[0].split(", ")[-1]] for i in train['Name']] # 对Name进行分类
# 用柱状图查看各 NameType 的存活情况
# 可以看出 乘客为Mr(已婚男士)中,死亡人数远远大于存活人数;
# 乘客为Mrs(已婚女士)、Miss(年轻未婚女子)中,死亡人数远远低于存活人数;
BarPlot(train,"NameType")# 数据进一步处理:将 NameType 分成三类
# Mr(已婚男士)
# Mrs(已婚女士)、Miss(年轻未婚女子)
# 其他
train['NameType2'] = ["Mr" if i == "Mr" else ("Mrs and Miss" if i in ['Mrs','Miss'] else "Other") \for i in train['NameType']]# 查看不同 NameType2 的存活情况
BarPlot(train,"NameType2")# 再对 NameType2 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['NameType2']),left_index=True,right_index=True
)
(4)Sex 性别
########################## 4、Sex 性别 ##########################
# 分类变量
# 有无缺失值:无# 用柱状图查看各 NameType 的存活情况
# 可以看出 乘客为男性中,死亡人数远远大于存活人数
BarPlot(train,"Sex")# 对 Sex 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['Sex']),left_index=True,right_index=True)
(5)Age 年龄
########################## 5、Age 年龄 ##########################
# 连续变量
# 有无缺失值:有,缺失比例19.9%# 缺失值用均值填充
train['Age'] = train['Age'].fillna(train['Age'].mean())# 用直方图查看各 Age 的存活情况
# 可以看出 可以看出小于5岁的小孩存活率很高
plt.figure(figsize=(8,4),dpi=150)
sns.distplot(train[train['Survived']==0]['Age'],color="red",kde=False)
sns.distplot(train[train['Survived']==1]['Age'],color="blue",kde=False)# 数据处理:将 Age 分成两类,Age<=5、Age>5
train['AgeType'] = ["Age<=5" if i <= 5 else "Age>5" for i in train['Age']]# 查看不同 AgeType 的存活情况
BarPlot(train,"AgeType")# 再对 AgeType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['AgeType']),left_index=True,right_index=True
)
(6)SibSp 堂兄弟妹个数
########################## 6、SibSp 堂兄弟妹个数 ##########################
# 无缺失值,等级变量
# 用柱状图查看各堂兄弟妹个数的存活情况
# 可以看出 SibSp=0 的乘客中,死亡人数较多
BarPlot(train,"SibSp")# 数据处理:将 SibSp 分成两类,SibSp=0、SibSp>0
train['SibSpType'] = ["SibSp=0" if i == 0 else "SibSp>0" for i in train['SibSp']]# 查看不同 SibSpType 的存活情况
BarPlot(train,"SibSpType")# 再对 SibSpType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['SibSpType']),left_index=True,right_index=True
)
(7)Parch 父母与小孩的个数
########################## 7、Parch 父母与小孩的个数 ##########################
# 连续变量
# 有无缺失值:无# 用柱状图查看父母与小孩的个数的存活情况
# 可以看出 Parch=0 的乘客中,死亡人数较多
BarPlot(train,"Parch")# 数据处理:将 Parch 分成两类,Parch=0、Parch>0
train['ParchType'] = ["Parch=0" if i == 0 else "Parch>0" for i in train['Parch']]# 查看不同 ParchType 的存活情况
BarPlot(train,"ParchType")# 再对 ParchType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['ParchType']),left_index=True,right_index=True
)
(8)Ticket 船票信息
- 字符变量
- 有无缺失值:无
- 数据处理:这里直接删去(下文会删)
(9)Fare 票价
########################## 9、Fare 票价 ##########################
# 连续变量
# 有无缺失值:无# 查看Fare(票价)= 0 的生存情况
Fare0 = train[train['Fare']==0]
Fare0Survived = Fare0.groupby(['Survived']).count()[['PassengerId']].reset_index().rename(columns={"PassengerId":"Count"})
plt.figure(figsize=(4,3),dpi=150)
sns.barplot(data=Fare0Survived,x="Survived",y="Count"
)
plt.title('Survived Count Of Fare=0')# 查看Fare(票价)!= 0 的生存情况
Fare1 = train[train['Fare']!=0]
plt.figure(figsize=(8,4),dpi=150)
sns.distplot(Fare1[Fare1['Survived']==0]['Fare'],color="red",kde=False)
sns.distplot(Fare1[Fare1['Survived']==1]['Fare'],color="blue",kde=False)
plt.title('Survived Count Of Fare!=0')# 对 Fare 分成三类
# Fare = 0
# Fare <=50
# Fare > 50
train['FareType'] = ["Fare=0" if i == 0 else ("Fare<=50" if i <= 50 else "Fare>50") for i in train['Fare']]# 用柱状图查看不同 FareType 的存活情况
# 可以看出 Fare=0 的乘客中,乘客几乎都死亡
# Fare <=50 的乘客中,死亡人数大于存活人数
# Fare > 50 的乘客中,存活人数大于死亡人数
BarPlot(train,"FareType")# 再对 FareType 进行One-Hot编码处理
train = pd.merge(train,OneHot(train['FareType']),left_index=True,right_index=True
)
(10)Cabin 船舱
- 离散变量
- 有无缺失值:有,缺失值比例高达77%
- 数据处理:缺失值比例较大,直接删去(下文会删)
(11)Embarked 登船的港口
########################## 11、Embarked 登船的港口 ##########################
# 离散变量
# 有无缺失值:有,缺失值比例很低# 用柱状图查看各登船的港口的存活情况
# 可以看出 Embarked=S 的乘客中,死亡人数较多
BarPlot(train,"Embarked")# 数据处理:缺失值按众数填充,然后再进行One-hot编码处理
mode = stats.mode(train['Embarked'])[0][0] # 众数
train['Embarked'] = train['Embarked'].fillna(mode)train = pd.merge(train,OneHot(train['Embarked']),left_index=True,right_index=True)
2.3.4 衍生特征可视化分析与处理
FamilyNumbers 家庭人数
########################## FamilyNumbers 家庭人数 ##########################
# 计算方式:SibSp(堂兄弟妹个数) + Parch(父母与小孩的个数) + 1(自己)
train['FamilyNumbers'] = train['SibSp'] + train['Parch'] + 1# 用柱状图查看各家庭人数的存活情况
# 可以看出 家庭人数=1 的乘客中,死亡人数较多
# 家庭人数>=5 的乘客中,存活人数较多
BarPlot(train,"FamilyNumbers")# 新增 FamilyType 字段
# 1 : 单身(Single)
# 2-4:小家庭(Family_Small)
# >4: 大家庭(Family_Large)
train['FamilyType'] = ['Single' if i == 1 else('Family_Small' if i<=4 else 'Family_Large') for i in train['FamilyNumbers']]# 查看不同 FamilyType 的存活情况
BarPlot(train,"FamilyType")# 对 FamilyType 进行One-hot编码处理
train = pd.merge(train,OneHot(train['FamilyType']),left_index=True,right_index=True)
2.3.5 删除冗余字段
drop_columns = ['PassengerId','Pclass','PclassType','Name','NameType','NameType2','Sex','Age','AgeType',\'SibSp','SibSpType','Parch','ParchType','Fare','FareType','Ticket','Cabin','Embarked',\'FamilyNumbers','FamilyType']
train.drop(drop_columns,axis=1,inplace=True)
2.3.6 相关性矩阵可视化
- 采用斯皮尔曼相关系数
corr_df = train.corr(method="spearman")[['Survived']].sort_values(by="Survived",ascending=False)
plt.figure(figsize=(1,8),dpi=100)
sns.heatmap(corr_df,cmap='Blues',center=0,vmax=1,vmin=-1,annot=True,annot_kws={'size':10,'weight':'bold', 'color':'red'}
)
2.4 决策树模型训练
2.4.1 数据标准化(Z-score)
def ZscoreNormalization(x):'''Z-score 标准化'''return (x - np.mean(x)) / np.std(x)data = train.drop("Survived",axis=1).agg(ZscoreNormalization)
data['Lable'] = train['Survived']
2.4.2 划分训练集、测试集
- 按7:3比例划分
x_train, x_test, y_train, y_test = train_test_split(data.drop("Lable",axis=1), data['Lable'], test_size = 0.3, random_state = 0
)
2.4.3 网格寻参与交叉验证
param_grid = {'criterion' : ['gini','entropy'], # 划分属性时选用的准则:{“gini”, “entropy”}, default=”gini”'splitter' : ['best','random'], # 划分方式:{“best”, “random”}, default=”best”'max_depth' : range(1,6), # 最大深度'min_samples_split' : range(1,6), # 拆分内部节点所需的最小样本数'min_samples_leaf' : range(1,6), # 叶节点所需的最小样本数
}
clf = DecisionTreeClassifier() # 初始化
gs = GridSearchCV(clf,param_grid,cv=5) # 网格搜索与交叉验证
gs.fit(x_train,y_train) # 模型训练
print("Best Estimator: ",gs.best_estimator_) # 打印最好的分类器
print("Best Score: ",gs.best_score_) # 打印最好分数
注意: 每次运行的结果输出会存在差别。
2.4.4 模型评价
print("\n---------- 模型评价 ----------")
y_pred = gs.predict(x_test) # 预测
cm = confusion_matrix(y_test, y_pred,labels=[0, 1]) # 混淆矩阵
df_cm = pd.DataFrame(cm) # 构建DataFrame
print('Accuracy score:', accuracy_score(y_test, y_pred)) # 准确率
print('Recall:', recall_score(y_test, y_pred, average='weighted')) # 召回率
print('F1-score:', f1_score(y_test, y_pred, average='weighted')) # F1分数
print('Precision score:', precision_score(y_test, y_pred, average='weighted')) # 精确度
2.4.5 混淆矩阵可视化
plt.figure(dpi=150)heatmap = sns.heatmap(df_cm, annot=True, fmt='.0f', cmap='Blues')
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right')
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=0, ha='right')plt.title('DecisionTreeClassifier Model Results')
plt.show()
2.4.6 ROC曲线
y_pred_proba = gs.predict_proba(np.array(x_test))[:,1]
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)sns.set()plt.figure(figsize=(5,4),dpi=150)
plt.plot(fpr, tpr)
plt.plot(fpr, fpr, linestyle = '-' , color = 'k')plt.xlabel('False positive rate')
plt.ylabel('True positive rate')AU = np.round(roc_auc_score(y_test, y_pred_proba), 2)plt.title(f'AU: {AU}');plt.show()
三、完整代码(含对test预测)
- 含预测,不含可视化
import numpy as np
import pandas as pd
from scipy import stats# sklearn 相关库
from sklearn.tree import DecisionTreeClassifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder,OneHotEncoder
from sklearn.model_selection import train_test_split,GridSearchCV
from sklearn.metrics import confusion_matrix,accuracy_score,roc_curve, roc_auc_score
from sklearn.metrics import accuracy_score, recall_score, f1_score, precision_score# 不显示红色警告
import warnings
warnings.filterwarnings('ignore')def OneHot(x):'''功能:one-hot 编码传入:需要编码的分类变量返回:返回编码后的结果,形式为 dataframe'''# 通过 LabelEncoder 将分类变量打上数值标签 lb = LabelEncoder() # 初始化x_pre = lb.fit_transform(x) # 模型拟合x_dict = dict([[i,j] for i,j in zip(x,x_pre)]) # 生成编码字典--> {'收藏': 1, '点赞': 2, '关注': 0}x_num = [[x_dict[i]] for i in x] # 通过 x_dict 将分类变量转为数值型# 进行one-hot编码enc = OneHotEncoder() # 初始化enc.fit(x_num) # 模型拟合array_data = enc.transform(x_num).toarray() # one-hot 编码后的结果,二维数组形式# 转成 dataframe 形式df = pd.DataFrame(array_data)inverse_dict = dict([val,key] for key,val in x_dict.items()) # 反转 x_dict 的键、值# columns 重命名if type(x) == pd.Series:firs_name = x.nameelse:firs_name = ""df.columns = [firs_name+"_"+inverse_dict[i] for i in df.columns] return dfdef ZscoreNormalization(x):'''Z-score 标准化'''return (x - np.mean(x)) / np.std(x)def DataClean(df,Lable=True):'''数据预处理函数'''########################## 1、Pclass 乘客等级 ########################### 无缺失值,等级变量# 数据处理:将Pclass分成两类,Pclass>=3、Pclass<3df['PclassType'] = ["Pclass>=3" if i >= 3 else "Pclass<3" for i in df['Pclass']]# 再对 PclassType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['PclassType']),left_index=True,right_index=True) ########################## 2、Name 乘客姓名 ########################### 字符串变量# 有无缺失值:无# 从乘客姓名中获取头街# 姓名中头街字符串与定义头街类别之间的关系# Officer: 政府官员,# RoyaIty: 王室(皇室),# Mr: 已婚男士,# Mrs: 已婚女士,# Miss: 年轻未婚女子,# Master: 有技能的人/教师 # 新建字段 Title_Dict Title_Dict = {'Mr':'Mr','Mrs':'Mrs', 'Miss':'Miss','Master': 'Master', 'Don':'Royalty','Rev':'Officer','Dr':')fficer', 'Mme':'Mrs','Ms':'Mrs','Major':'Officer', 'Lady': 'Royalty','Sir': 'Royalty','Mlle':'Miss', 'Col': 'Officer','Capt':'Officer','the Countess': 'Royalty','Jonkheer': 'Royalty','Dona': 'Royalty'}df['NameType'] = [Title_Dict[i.split(".")[0].split(", ")[-1]] for i in df['Name']] # 对Name进行分类# 数据进一步处理:将 NameType 分成三类# Mr(已婚男士)# Mrs(已婚女士)、Miss(年轻未婚女子)# 其他df['NameType2'] = ["Mr" if i == "Mr" else ("Mrs and Miss" if i in ['Mrs','Miss'] else "Other") \for i in df['NameType']]# 再对 NameType2 进行One-Hot编码处理df = pd.merge(df,OneHot(df['NameType2']),left_index=True,right_index=True) ########################## 3、Sex 性别 ########################### 分类变量# 有无缺失值:无# 对 Sex 进行One-Hot编码处理df = pd.merge(df,OneHot(df['Sex']),left_index=True,right_index=True)########################## 4、Age 年龄 ########################### 连续变量# 有无缺失值:有,缺失比例19.9%# 缺失值用均值填充df['Age'] = df['Age'].fillna(df['Age'].mean())# 数据处理:将 Age 分成两类,Age<=5、Age>5df['AgeType'] = ["Age<=5" if i <= 5 else "Age>5" for i in df['Age']]# 再对 AgeType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['AgeType']),left_index=True,right_index=True) ########################## 5、SibSp 堂兄弟妹个数 ########################### 无缺失值,等级变量# 数据处理:将 SibSp 分成两类,SibSp=0、SibSp>0df['SibSpType'] = ["SibSp=0" if i == 0 else "SibSp>0" for i in df['SibSp']]# 再对 SibSpType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['SibSpType']),left_index=True,right_index=True) ########################## 6、Parch 父母与小孩的个数 ########################### 连续变量# 有无缺失值:无# 数据处理:将 Parch 分成两类,Parch=0、Parch>0df['ParchType'] = ["Parch=0" if i == 0 else "Parch>0" for i in df['Parch']]# 再对 ParchType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['ParchType']),left_index=True,right_index=True) ########################## 8、Fare 票价 ########################### 连续变量# 有无缺失值:无# 对 Fare 分成三类# Fare = 0# Fare <=50# Fare > 50df['FareType'] = ["Fare=0" if i == 0 else ("Fare<=50" if i <= 50 else "Fare>50") for i in df['Fare']]# 再对 FareType 进行One-Hot编码处理df = pd.merge(df,OneHot(df['FareType']),left_index=True,right_index=True) ########################## 10、Embarked 登船的港口 ########################### 离散变量# 有无缺失值:有,缺失值比例很低# 数据处理:缺失值按众数填充,然后再进行One-hot编码处理mode = stats.mode(df['Embarked'])[0][0] # 众数df['Embarked'] = df['Embarked'].fillna(mode)df = pd.merge(df,OneHot(df['Embarked']),left_index=True,right_index=True)########################## 11、FamilyNumbers 家庭人数 ########################### 计算方式:SibSp(堂兄弟妹个数) + Parch(父母与小孩的个数) + 1(自己)df['FamilyNumbers'] = df['SibSp'] + df['Parch'] + 1# 新增 FamilyType 字段# 1 : 单身(Single) # 2-4:小家庭(Family_Small)# >4: 大家庭(Family_Large)df['FamilyType'] = ['Single' if i == 1 else('Family_Small' if i<=4 else 'Family_Large') for i in df['FamilyNumbers']]# 对 FamilyType 进行One-hot编码处理df = pd.merge(df,OneHot(df['FamilyType']),left_index=True,right_index=True)########################## 删除冗余变量 ##########################drop_columns = ['PassengerId','Pclass','PclassType','Name','NameType','NameType2','Sex','Age','AgeType',\'SibSp','SibSpType','Parch','ParchType','Fare','FareType','Ticket','Cabin','Embarked',\'FamilyNumbers','FamilyType']df.drop(drop_columns,axis=1,inplace=True)########################## 数据标准化 ##########################if Lable == True: # 判断是否是测试集(测试集不含标签)data = df.drop("Survived",axis=1).agg(ZscoreNormalization)data['Lable'] = df['Survived']else:data = df.agg(ZscoreNormalization)return datadef sklearn_DecisionTreeClassifier(data):'''决策树二分类'''# 划分训练集、测试集x_train, x_test, y_train, y_test = train_test_split(data.drop("Lable",axis=1), data['Lable'], test_size = 0.3, random_state = 0)print("\n---------- 模型训练 ----------")# 网格寻参param_grid = {'criterion' : ['gini','entropy'], # 划分属性时选用的准则:{“gini”, “entropy”}, default=”gini”'splitter' : ['best','random'], # 划分方式:{“best”, “random”}, default=”best”'max_depth' : range(1,6), # 最大深度'min_samples_split' : range(1,6), # 拆分内部节点所需的最小样本数'min_samples_leaf' : range(1,6), # 叶节点所需的最小样本数}clf = DecisionTreeClassifier() # 初始化gs = GridSearchCV(clf,param_grid,cv=5) # 网格搜索与交叉验证gs.fit(x_train,y_train) # 模型训练print("Best Estimator: ",gs.best_estimator_) # 打印最好的分类器print("Best Score: ",gs.best_score_) # 打印最好分数# 模型预测print("\n---------- 模型评价 ----------")y_pred = gs.predict(x_test) # 预测cm = confusion_matrix(y_test, y_pred,labels=[0, 1]) # 混淆矩阵df_cm = pd.DataFrame(cm) # 构建DataFrameprint('Accuracy score:', accuracy_score(y_test, y_pred)) # 准确率print('Recall:', recall_score(y_test, y_pred, average='weighted')) # 召回率print('F1-score:', f1_score(y_test, y_pred, average='weighted')) # F1分数print('Precision score:', precision_score(y_test, y_pred, average='weighted')) # 精确度return gs.best_estimator_ # 返回最好的训练模型if __name__ == "__main__":train = pd.read_csv("train.csv")test = pd.read_csv("test.csv")print("\n---------- 数据预处理 ----------")train_data = DataClean(train) test_data = DataClean(test,Lable=False) # 决策树二分类best_estimator = sklearn_DecisionTreeClassifier(train_data) # 预测y_pred = best_estimator.predict(test_data) # 输出预测结果result = test[['PassengerId']]result['Survived'] = y_predresult.to_csv("Titanic Results.csv",index=False)print("\n程序运行完成")
四、Kaggle 得分
- 得分:0.77511
- 排名:7651
参考:
1、Kaggle泰坦尼克号比赛项目详解
2、机器学习实战——kaggle 泰坦尼克号生存预测——六种算法模型实现与比较
python 决策树分类 泰坦尼克生存预测相关推荐
- kaggle房价预测特征意思_机器学习-kaggle泰坦尼克生存预测(一)-数据清洗与特征构建...
1.背景: 1.1 关于kaggle: 谷歌旗下的 Kaggle 是一个数据建模和数据分析竞赛平台.该平台是当下最流行的数据科研赛事平台,其组织的赛事受到全球数据科学爱好者追捧. 如果学生能够在该平台 ...
- MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测
Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...
- 实验六:泰坦尼克生存预测之缺失值处理
一.任务描述 背景故事: 泰坦尼克号(RMS Titanic),又译作铁达尼号,是英国白星航运公司下辖的一艘奥林匹克级游轮,排水量46000吨,于1909年3月31日在北爱尔兰贝尔法斯特港的哈兰德与沃 ...
- Python项目实战-Tensorflow2.0实现泰坦尼克生存预测
目录 一.数据集下载地址 二.探索性因子分析(EDA) 三.特征工程 四.构建Dataset与Model fit和自定义estimator使用 预定义estimator的使用 一.数据集下载地址 # ...
- 决策树实例-泰坦尼克幸存者预测
- 基于逻辑回归的泰坦尼克生存预测
- Kaggle实战:泰坦尼克幸存者预测 - 上
(文章同步更新于个人博客@dai98.github.io) 源代码: Github Kaggle 泰坦尼克幸存者预测是Kaggle上数据竞赛的入门级别的比赛,我曾经在一年前作为作业参加过这个比赛,我想 ...
- Kaggle实战:泰坦尼克幸存者预测 -下
(文章同步更新于个人博客@dai98.github.io) 源代码:Github 上一篇文章介绍了如何使用深度学习来预测泰坦尼克号幸存者,这一部分使用多分类器投票来做.由于数据预处理部分比较相似,重复 ...
- 泰坦尼克号生存预测python_用Python预测泰坦尼克生存情况-附数据集
介绍:通过逻辑回归算法,解决kaggle网站上的泰坦尼克生存情况预测问题,准确率在80%左右. 一.提出问题 什么样的人在泰坦尼克号中更容易存活? 二.理解数据 2.1 数据来源 数据来自kaggle ...
- Kaggle初体验之泰坦尼特生存预测
Kaggle初体验之泰坦尼特生存预测 学习完了决策树的ID3.C4.5.CART算法,找一个试手的地方,Kaggle的练习赛泰坦尼特很不错,记录下 流程 首先注册一个账号,然后在顶部菜单栏Co ...
最新文章
- python note
- navicat保存查询语句_MySQL数据库安装创建及Navicat客户端连接
- 父亲和女儿同为互联网大佬, 但不幸都得癌症
- ajax请求php的过程,php如何实现ajax请求
- poj1463 Strategic game
- Nginx设置TCP上游服务器的SSL配置
- 深度 | 从各种注意力机制窥探深度学习在NLP中的神威
- 语音转写可实时,直播也能同步字幕
- 反激变压器结构设计学习笔记(进阶)
- Sketchup2019安装包安装教程
- 苹果IOS隐藏复制链接等按钮失效及报错the permission value is offline verifying解决办法
- ospfdr选举规则_DR/BDR详细选举过程
- abaqus个人总结 各种问题各种debug
- shell 脚本学习
- WPP和iHeartMedia推出“聆听项目”
- java word apache poi 操作word模板。
- 0xC000005:Access Violation和指针强制转换问题
- 区块链培训就业方向多不多?
- 单片机开发—呼吸灯的三种实现方法
- Safari浏览器意外退出无法重新启动怎么办?
热门文章
- Chrome主页被强制修改为百度解决办法
- [CGAL]建立一个正四面体
- storm架构及原理详解
- 2022 python获取和风天气 web api v7版本
- java中对手机号、邮箱等隐私信息脱敏展示,如手机号138****8888。
- ArcGIS Server 发布地图服务遇到的问题
- 技术帖:如何把mobi文件转化成pdf
- wap pc html,PCWAP手机PC网站信息管理系统 v1.4.3
- 深度学习:GAN 对抗网络原理详细解析(零基础必看)
- 详解 LVS、Nginx 及 HAProxy 工作原理