Bart模型应用实例及解析(二)————基于泰坦尼克号数据集的分类模型

  • 前言
    • 一、数据集
      • 1、数据集的获取
      • 2、数据集变量名及意义
      • 3、数据集处理
    • 二、完整代码
    • 三、代码运行结果及解析
      • 1.数据描述性分析
      • 2.建立Bart模型以及分析
      • 3.各模型效果对比
    • 特别声明

前言

这里是在实战中使用Bart模型对数据进行建模及分析,如果有读者对如何建模以及建模函数的参数不了解,对建模后的结果里的参数不清楚的话,可以参考学习作者前面两篇文章内容。以便更好地理解模型、建模过程及思想。

R bartMachine包内bartMachine函数参数详解
https://blog.csdn.net/qq_35674953/article/details/115774921

BartMachine函数建模结果参数解析
https://blog.csdn.net/qq_35674953/article/details/115804662


提示:以下是本篇文章正文内容

一、数据集

1、数据集的获取

链接:https://pan.baidu.com/s/1TZejG8fZTS35RQctwtTn-Q
提取码:h6sv
数据部分截图:

2、数据集变量名及意义

变量名 意义
Survived 分类变量,是否死亡。0代表死亡,1代表存活
Pclass 乘客所持票类,有三种值(1,2,3)
Name 乘客姓名
Sex 乘客性别
Age 乘客年龄(有缺失)
SibSp 乘客兄弟姐妹/配偶的个数(整数值)
Parch 乘客父母/孩子的个数(整数值)
Ticket 票号(字符串)
Fare 乘客所持票的价格(浮点数,0-500不等)
Cabin 乘客所在船舱(有缺失)
Embark 乘客登船港口:S、C、Q(有缺失)。赋值1,、2、3。

3、数据集处理

由实际经验,作者认为自变量乘客姓名(Name)、船票票号(Ticket)、船舱号(Cabin)对因变量是否存活(Survived)没有影响,所以删去这几个变量。对于变量年龄(Age)的缺失值,由于数据集比较大,就删去了有缺失数据。

二、完整代码

代码如下(示例):

options(java.parameters = "-Xmx10g")library(ggplot2)
library(bartMachine)
library(reshape2)
library(knitr)
library(ggplot2)
library(GGally)
library(scales)
percent((1:5) / 100)##读取数据
data<-read.csv(file="C:/Users/LHW/Desktop/tt.csv",header=T,sep=",")
head(data)
n=dim(data)
nda<-melt(data)#画出数据箱线图
ggplot(da, aes(x=variable, y=value, fill=variable))+ geom_boxplot()+facet_wrap(~variable,scales="free")#画出数据直方图
ggplot(da, aes(value, fill=variable))+ geom_histogram()+facet_wrap(~variable,scales="free")cormat <- round(cor(data[,2:8]), 2)
head(cormat)
melted_cormat <- melt(cormat)
head(melted_cormat)# 把一侧三角形的值转化为NA
get_upper_tri <- function(cormat){cormat[lower.tri(cormat)]<- NAreturn(cormat)
}
upper_tri <- get_upper_tri(cormat)
upper_tri#转化为矩阵
library(reshape2)
melted_cormat <- melt(upper_tri,na.rm = T)#作相关系数热力图
ggplot(data = melted_cormat, aes(x=Var2, y=Var1, fill = value)) +geom_tile(color = "white") +scale_fill_gradient2(low = "blue", high = "red", mid = "white",midpoint = 0, limit = c(-1, 1), space = "Lab",name="Pearson\nCorrelation") +theme_minimal() +theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12, hjust = 1)) +coord_fixed() + geom_text(aes(Var2, Var1, label = value), color = "black", size = 4) #随机种子
set.seed(1000)
#按照80%和20%比例划分训练集和测试集
index2=sample(x=2,size=nrow(data),replace=TRUE,prob=c(0.8,0.2))#训练集
train2=data[index2==1,]
head(train2)
x=train2[,-c(1)]
y=train2[,1]
y = factor(y)#预测集
data2=data[index2==2,]
x.test_data=data2[,-c(1)]
head(data2)
xp=x.test_data
yp=data2[,1]
yp = factor(yp)#建立Bart模型
res = bartMachine(x,y,prob_rule_class = 0.5)
print(res)
rm<-res$confusion_matrix#计算精度、查准率、查全率
A=(rm[1,1]+rm[2,2])/length(y)
cat("精度为:",percent(A,accuracy = 0.01), "\n")
P=rm[2,2]/(rm[2,2]+rm[1,2])
cat("查准率为:",percent(P,accuracy = 0.01), "\n")
R=rm[2,2]/(rm[2,2]+rm[2,1])
cat("查全率为:",percent(R,accuracy = 0.01), "\n")#对预测集进行预测
resp1<-predict(res,new_data=xp,type = "class", prob_rule_class = 0.5)#画出预测集的混淆矩阵
tp=0
fn=0
fp=0
tn=0
for(i in 1:length(resp1)){if(resp1[i]==1){if(yp[i]==1){tp=tp+1}else{fn=fn+1}}else{if(yp[i]==1){fp=fp+1}else{tn=tn+1}}
}# 定义行和列的名称
rownames = c("正例", "反例")
colnames = c("正例", "反例")m <- matrix(c(tp,fn,fp,tn), nrow = 2, byrow = TRUE, dimnames = list(rownames, colnames))
print(m)#计算精度、查准率、查全率
A_p=(tp+tn)/length(resp1)
cat("精度为:",percent(A_p,accuracy = 0.01), "\n")
P_p=tp/(tp+fp)
cat("查准率为:",percent(P_p,accuracy = 0.01), "\n")
R_p=tp/(tp+fn)
cat("查全率为:",percent(R_p,accuracy = 0.01), "\n")#用十折交叉验证,选出最佳先验参数
bmcv<-bartMachineCV(X = x, y = y,num_tree_cvs = c(50, 100,150), k_cvs = c(2, 3, 5),nu_q_cvs = NULL, k_folds = 10, verbose = FALSE)
print(bmcv)
bmcv$cv_statsbm<-bmcv$confusion_matrix#计算精度、查准率、查全率
A_cv=(bm[1,1]+bm[2,2])/length(y)
cat("精度为:",percent(A,accuracy = 0.01), "\n")
P_cv=bm[2,2]/(bm[2,2]+bm[1,2])
cat("查准率为:",percent(P,accuracy = 0.01), "\n")
R_cv=bm[2,2]/(bm[2,2]+bm[2,1])
cat("查全率为:",percent(R,accuracy = 0.01), "\n")#对预测集进行预测
resp2<-predict(bmcv,new_data=xp,type = "class", prob_rule_class = 0.5)#画出预测集的混淆矩阵
tp_cv=0
fn_cv=0
fp_cv=0
tn_cv=0
for(i in 1:length(resp2)){if(resp1[i]==1){if(yp[i]==1){tp_cv=tp_cv+1}else{fn_cv=fn_cv+1}}else{if(yp[i]==1){fp_cv=fp_cv+1}else{tn_cv=tn_cv+1}}
}# 定义行和列的名称
rownames = c("正例", "反例")
colnames = c("正例", "反例")m_cv <- matrix(c(tp_cv,fn_cv,fp_cv,tn_cv), nrow = 2, byrow = TRUE, dimnames = list(rownames, colnames))
print(m_cv)#计算精度、查准率、查全率
A_cv_p=(tp_cv+tn_cv)/length(resp1)
cat("精度为:",percent(A_cv_p,accuracy = 0.01), "\n")
P_cv_p=tp/(tp_cv+fp_cv)
cat("查准率为:",percent(P_cv_p,accuracy = 0.01), "\n")
R_cv_p=tp/(tp_cv+fn_cv)
cat("查全率为:",percent(R_cv_p,accuracy = 0.01), "\n")#查看迭代时模型的变化
pcd<-plot_convergence_diagnostics(bmcv,plots = c("sigsqs", "mh_acceptance", "num_nodes", "tree_depths"))#返回每一颗回归树的信息
er<-extract_raw_node_data(bmcv, g = 1)
head(er,1)# 计算每个变量在树中出现的次数。
ivi<-investigate_var_importance(bmcv, type = "trees",plot = TRUE, num_replicates_for_avg = 5, num_trees_bottleneck = 20,num_var_plot = Inf, bottom_margin = 10)
ivi#画出变量的部分依赖图,可以用来展示一个特征是怎样影响模型预测的。这里展示第一个变量的部分依赖图pd<-pd_plot(bmcv, 6,levs = c(0.05, seq(from = 0.1, to = 0.9, by = 0.05), 0.95),lower_ci = 0.025, upper_ci = 0.975, prop_data = 1)

三、代码运行结果及解析

1.数据描述性分析

options(java.parameters = "-Xmx10g")library(ggplot2)
library(bartMachine)
library(reshape2)
library(knitr)
library(ggplot2)
library(GGally)
library(scales)
percent((1:5) / 100)##读取数据
data<-read.csv(file="C:/Users/LHW/Desktop/tt.csv",header=T,sep=",")
head(data)


数据集前六行数据展示。

n=dim(data)
n

显示数据集维度,数据集七自变量,一个因变量(Survived),一共八列。1043行数据样本。

da<-melt(data)#画出数据箱线图
ggplot(da, aes(x=variable, y=value, fill=variable))+ geom_boxplot()+facet_wrap(~variable,scales="free")

对数据的描述性统计,画出的八列数据的箱线图。从图中可以看出变量(Age、SibSp、Parch、Fare)有离群值。

#画出数据直方图
ggplot(da, aes(value, fill=variable))+ geom_histogram()+facet_wrap(~variable,scales="free")

对数据的描述性统计,画出的八列数据的直方图。可以跟直观看出各个分类变量数据分布。

cormat <- round(cor(data[,2:8]), 2)
head(cormat)

自变量的相关系数矩阵。

melted_cormat <- melt(cormat)
head(melted_cormat)

变换数据形式,以便用ggplot画图。


# 把一侧三角形的值转化为NA
get_upper_tri <- function(cormat){cormat[lower.tri(cormat)]<- NAreturn(cormat)
}
upper_tri <- get_upper_tri(cormat)
upper_tri

把自变量的相关系数矩阵一侧三角形的值转化为NA,方便画出相关系数热力图。

#转化为矩阵
library(reshape2)
melted_cormat <- melt(upper_tri,na.rm = T)#作相关系数热力图
ggplot(data = melted_cormat, aes(x=Var2, y=Var1, fill = value)) +geom_tile(color = "white") +scale_fill_gradient2(low = "blue", high = "red", mid = "white",midpoint = 0, limit = c(-1, 1), space = "Lab",name="Pearson\nCorrelation") +theme_minimal() +theme(axis.text.x = element_text(angle = 45, vjust = 1, size = 12, hjust = 1)) +coord_fixed() + geom_text(aes(Var2, Var1, label = value), color = "black", size = 4) 

画出热力图,相关系数矩阵热力图,相关系数范围[-1,1],颜色越红,相关系数就越接近于1,正相关性越高;颜色越蓝,相关系数就越接近于-1,负相关性越高。从图中可以看出Parch与SibSp,正相关性比较高;Fare与Pclass、Age与Pclass,负相关性比较高。

2.建立Bart模型以及分析

#随机种子
set.seed(1000)
#按照80%和20%比例划分训练集和测试集
index2=sample(x=2,size=nrow(data),replace=TRUE,prob=c(0.8,0.2))#训练集
train2=data[index2==1,]
head(train2)
x=train2[,-c(1)]
y=train2[,1]
y = factor(y)

训练集数据集展示。

#预测集
data2=data[index2==2,]
x.test_data=data2[,-c(1)]
head(data2)
xp=x.test_data
yp=data2[,1]
yp = factor(yp)

测试集数据集展示。

#建立Bart模型
res = bartMachine(x,y,prob_rule_class = 0.5)
print(res)

结果返回了训练样本数据集维度,以及分类预测结果的混淆矩阵。

rm<-res$confusion_matrix#计算精度、查准率、查全率
A=(rm[1,1]+rm[2,2])/length(y)
cat("精度为:",percent(A,accuracy = 0.01), "\n")
P=rm[2,2]/(rm[2,2]+rm[1,2])
cat("查准率为:",percent(P,accuracy = 0.01), "\n")
R=rm[2,2]/(rm[2,2]+rm[2,1])
cat("查全率为:",percent(R,accuracy = 0.01), "\n")



可以看出精度较高,查全率、查准率表现都很可以,说明模型拟合得不错。

#对预测集进行预测
resp1<-predict(res,new_data=xp,type = "class", prob_rule_class = 0.5)#画出预测集的混淆矩阵
tp=0
fn=0
fp=0
tn=0
for(i in 1:length(resp1)){if(resp1[i]==1){if(yp[i]==1){tp=tp+1}else{fn=fn+1}}else{if(yp[i]==1){fp=fp+1}else{tn=tn+1}}
}# 定义行和列的名称
rownames = c("正例", "反例")
colnames = c("正例", "反例")m <- matrix(c(tp,fn,fp,tn), nrow = 2, byrow = TRUE, dimnames = list(rownames, colnames))
print(m)

这里返回了,预测集的混淆矩阵。

#计算精度、查准率、查全率
A_p=(tp+tn)/length(resp1)
cat("精度为:",percent(A_p,accuracy = 0.01), "\n")
P_p=tp/(tp+fp)
cat("查准率为:",percent(P_p,accuracy = 0.01), "\n")
R_p=tp/(tp+fn)
cat("查全率为:",percent(R_p,accuracy = 0.01), "\n")



可以看出精度较高,查全率、查准率表现都很可以,说明模型拟合得不错。

#用十折交叉验证,选出最佳先验参数
bmcv<-bartMachineCV(X = x, y = y,num_tree_cvs = c(50, 100,150), k_cvs = c(2, 3, 5),nu_q_cvs = NULL, k_folds = 10, verbose = FALSE)
print(bmcv)

结果返回了训练样本数据集维度,以及分类预测结果的混淆矩阵。

#交叉验证参数及训练结果汇总
print(bmcv$cv_stats)


我们选取的是第一行参数进行建模。

bm<-bmcv$confusion_matrix#计算精度、查准率、查全率
A_cv=(bm[1,1]+bm[2,2])/length(y)
cat("精度为:",percent(A,accuracy = 0.01), "\n")
P_cv=bm[2,2]/(bm[2,2]+bm[1,2])
cat("查准率为:",percent(P,accuracy = 0.01), "\n")
R_cv=bm[2,2]/(bm[2,2]+bm[2,1])
cat("查全率为:",percent(R,accuracy = 0.01), "\n")



可以看出精度较高,查全率、查准率表现都比较不错,说明模型拟合比较好。

#用最佳参数模型对预测集进行预测
resp2<-predict(bmcv,new_data=xp,type = "class", prob_rule_class = 0.5)#画出预测集的混淆矩阵
tp_cv=0
fn_cv=0
fp_cv=0
tn_cv=0
for(i in 1:length(resp2)){if(resp1[i]==1){if(yp[i]==1){tp_cv=tp_cv+1}else{fn_cv=fn_cv+1}}else{if(yp[i]==1){fp_cv=fp_cv+1}else{tn_cv=tn_cv+1}}
}# 定义行和列的名称
rownames = c("正例", "反例")
colnames = c("正例", "反例")m_cv <- matrix(c(tp_cv,fn_cv,fp_cv,tn_cv), nrow = 2, byrow = TRUE, dimnames = list(rownames, colnames))
print(m_cv)

这里返回了,预测集的混淆矩阵。

#计算精度、查准率、查全率
A_cv_p=(tp_cv+tn_cv)/length(resp1)
cat("精度为:",percent(A_cv_p,accuracy = 0.01), "\n")
P_cv_p=tp/(tp_cv+fp_cv)
cat("查准率为:",percent(P_cv_p,accuracy = 0.01), "\n")
R_cv_p=tp/(tp_cv+fn_cv)
cat("查全率为:",percent(R_cv_p,accuracy = 0.01), "\n")




可以看出精度较高,查全率、查准率表现都比较不错,说明模型拟合比较好。

#查看迭代时模型的变化
pcd<-plot_convergence_diagnostics(bmcv,plots = c("sigsqs", "mh_acceptance", "num_nodes", "tree_depths"))

上图为评估 BART 模型的收敛和特征的一组图,竖线前是被丢弃的抽样样本。
“Percent acceptance”选项绘制每个吉布斯样本接受的Metropolis Hastings步骤的比例,从图中可以看出接受率随着抽样样本增加较为稳定;
“Tree Num nodes”选项根据Gibbs样本数绘制树和模型中每棵树上的平均节点数 ,节点数随着抽样样本增加较为稳定。蓝线是所有树上的平均节点数;
“tree depth”选项根据Gibbs样本数在树和模型中绘制每棵树的平均树深。蓝线是所有树上的平均节点数。

#返回每一颗回归树的信息
er<-extract_raw_node_data(bmcv, g = 1)
er

部分结果:

由于篇幅较长,这里只展示了第一棵树的部分信息,更多信息可以自己运行代码进行查看。

# 计算每个变量在树中出现的次数。
ivi<-investigate_var_importance(bmcv, type = "trees",plot = TRUE, num_replicates_for_avg = 5, num_trees_bottleneck = 20,num_var_plot = Inf, bottom_margin = 10)
ivi

算出BART模型的变量被包含在树里的比例,了解不同协变量的相对影响。在图中,红条对应的是每一个变量比例的标准误差。用此来表示每一个变量的重要程度。

图中数据为每个变量比例的具体数值。

#画出变量的部分依赖图,可以用来展示一个特征是怎样影响模型预测的。这里展示第一个变量的部分依赖图pd<-pd_plot(bmcv, 6,levs = c(0.05, seq(from = 0.1, to = 0.9, by = 0.05), 0.95),lower_ci = 0.025, upper_ci = 0.975, prop_data = 1)

可以看出第六列(Fare)变量,数据集中在0到140,其他变量不变,预测值随着自变量的增加,变化不大。

3.各模型效果对比

模型 文中命名 精度 查准率 查全率 预测集精度 预测集查准率 预测集查全率
默认参数模型 res 85.97% 86.32% 77.49% 83.59% 72.60% 81.54%
最佳先验参数模型 bmcv 85.97% 86.32% 77.49% 83.59% 72.60% 81.54%

由不同模型对比可以看出,在交叉验证选择最佳参数后对模型没有改进,预测效果也相同。从交叉验证参数及训练结果也可以看出不同的模型参数对模型的影响比较小。

特别声明

作者也是初学者,水平有限,文章中会存在一定的缺点和谬误,恳请读者多多批评、指正和交流!

Bart模型应用实例及解析(二)————基于泰坦尼克号数据集的分类模型相关推荐

  1. 寺冈labelnet使用说明_基于imagenet数据集的ResNet50模型训练示例

    基于imagenet数据集的ResNet50模型训练示例 训练前准备 数据集获取 本训练示例以imagenet数据集为例,从imagenet官方网站http://www.image-net.org/获 ...

  2. R语言使用yardstick包的conf_mat函数计算多分类(Multiclass)模型的混淆矩阵、并使用summary函数基于混淆矩阵输出分类模型评估的其它详细指标(kappa、npv等13个)

    R语言使用yardstick包的conf_mat函数计算多分类(Multiclass)模型的混淆矩阵(confusion matrix).并使用summary函数基于混淆矩阵输出分类模型评估的其它详细 ...

  3. 15 分钟搭建一个基于XLNET的文本分类模型——keras实战

    今天笔者将简要介绍一下后bert 时代中一个又一比较重要的预训练的语言模型--XLNET ,下图是XLNET在中文问答数据集CMRC 2018数据集(哈工大讯飞联合实验室发布的中文机器阅读理解数据,形 ...

  4. elm分类器功能_基于ELM的情绪分类模型研究

    龙源期刊网 http://www.qikan.com.cn 基于 ELM 的情绪分类模型研究 作者:陈珊 来源:<价值工程> 2017 年第 04 期 摘要: 采用计算机进行情绪判断对实现 ...

  5. 从零开始构建基于textcnn的文本分类模型(上),word2vec向量训练,预训练词向量模型加载,pytorch Dataset、collete_fn、Dataloader转换数据集并行加载

    伴随着bert.transformer模型的提出,文本预训练模型应用于各项NLP任务.文本分类任务是最基础的NLP任务,本文回顾最先采用CNN用于文本分类之一的textcnn模型,意在巩固分词.词向量 ...

  6. EL之Bagging:kaggle比赛之利用泰坦尼克号数据集建立Bagging模型对每个人进行获救是否预测

    EL之Bagging:kaggle比赛之利用泰坦尼克号数据集建立Bagging模型对每个人进行获救是否预测 目录 输出结果 设计思路 核心代码 输出结果 设计思路 核心代码 bagging_clf = ...

  7. ML之RF:kaggle比赛之利用泰坦尼克号数据集建立RF模型对每个人进行获救是否预测

    ML之RF:kaggle比赛之利用泰坦尼克号数据集建立RF模型对每个人进行获救是否预测 目录 输出结果 实现代码 输出结果 后期更新-- 实现代码 #预测模型选择的RF import numpy as ...

  8. R语言编写自定义函数计算分类模型评估指标:准确度、特异度、敏感度、PPV、NPV、数据数据为模型预测后的混淆矩阵、比较多个分类模型分类性能(逻辑回归、决策树、随机森林、支持向量机)

    R语言编写自定义函数计算分类模型评估指标:准确度.特异度.敏感度.PPV.NPV.数据数据为模型预测后的混淆矩阵.比较多个分类模型分类性能(逻辑回归.决策树.随机森林.支持向量机) 目录

  9. 基于双语数据集搭建seq2seq模型

    目录 一.前言 二.数据预处理 2.1 数据清洗 2.2 词元化 2.3 建立词表 2.4 数据加载 2.5 构建数据集 三.模型搭建 3.1 Encoder-Decoder 架构 3.2 Encod ...

  10. logit模型应用实例_第六章 逻辑斯谛回归与最大熵模型(第1节 逻辑斯谛回归模型)...

    逻辑斯谛回归(logistic regression)是经典的分类方法. 最大熵是概率模型学习的一个准则,将其推广到分类问题得到最大熵模型(maximum entropy model). 逻辑斯谛回归 ...

最新文章

  1. 计算机三年工作经验和研究生,三年工作经验和读三年研究生到底哪个更值?这个回答很权威...
  2. token,session,cookie
  3. SAP Spartacus store里引用的library是如何编译出来的
  4. SAP Fiori Launchpad Tile点击后跳转的调试技巧
  5. sap 分割评估_SAP那些事-实战篇-73-受托加工的几种方案探讨
  6. x299服务器芯片组,18核心炸裂!X299主板全集:为它真拼了
  7. 前端学习(3209):react中类中方法的this指向
  8. 【算法分析与设计】找到最重的球
  9. HDU1253 胜利大逃亡
  10. Windows Terminal Preview 1.5 发布
  11. Oracle XE安装具体解释
  12. android开发分页查询,Android开发中实现分页效果的简单步骤
  13. ROS☞通过两种方法提取.bag中的图像数据
  14. 京东基础架构部招聘GO/JAVA架构师两名(T7+)
  15. 在线动态几何编辑器 GeometryEditor
  16. VMware虚拟机XP系统安装图文教程
  17. wps启用编辑按钮在哪里_wps页面设置在哪里?wps页面设置使用教程
  18. Unity喷墨效果Shader实现
  19. 沙耶の唄(沙耶之歌)游戏全攻略
  20. 中国菜刀与一句话木马之间的原理分析

热门文章

  1. 测量电源纹波-正确测量方法
  2. 京享值超8万的京东钻石用户告诉你套路是这样的
  3. 输入某年某月某日,判断这一天是这一年的第几天?
  4. 【Axure高保真原型】用户详细画像可视化原型模板
  5. seurat质控Warning: Feature names cannot have underscores (‘_‘), replacing with dashes (‘-‘)
  6. C++输出流cout的执行顺序问题
  7. Spring Security 如何防止 Session Fixation 攻击
  8. Python利用re正则表达式抓取豆瓣电影Top250排行榜
  9. RouterOS 重置密码
  10. 拉着老公,逛了一趟绿源电动车连锁店,喜提新座驾。