在本教程中,您将学习使用TFLearn和tensorflow评估泰坦尼克号乘客幸存的机会,数据根据是利用他们的个人信息(如性别、年龄等)。为了解决这一经典的机器学习任务,我们要建立一个深神经网络分类器。

准备工作:首先按照指引安装好tensorflow 和 tflearn。

1912年4月15日,泰坦尼克号撞上冰山后沉没,造成2224名乘客和机组人员中1502人死亡。虽然在这场事故中生存下来存在一些运气因素,但是一些群体如妇女、儿童和船体上层人员生存概率更大。在本教程中,我们进行了分析,找出这些人是谁。

数据集

TFlearn会自动下载泰坦尼克号的下面数据:

VARIABLE DESCRIPTIONS:

survived Survived

(0 = No; 1 = Yes)

pclass Passenger Class

(1 = 1st; 2 = 2nd; 3 = 3rd)

name Name

sex Sex

age Age

sibsp Number of Siblings/Spouses Aboard

parch Number of Parents/Children Aboard

ticket Ticket Number

fare Passenger Fare

建立分类器

数据集存储在csv文件中,能够使用TFlearn的load_csv()函数加载数据,使用target_column作为存活与否的标签,也就是数据集第一列survived,函数返回一对数组(data, label)

import numpy as np

import tflearn

# Download the Titanic dataset

from tflearn.datasets import titanic

titanic.download_dataset('titanic_dataset.csv')

# Load CSV file, indicate that the first column represents labels

from tflearn.data_utils import load_csv

data, labels = load_csv('titanic_dataset.csv', target_column=0,

categorical_labels=True, n_classes=2)

预处理

数据作预先处理,数据中name对于预测没有什么用处,取消name和ticket两个字段;其次,神经网络只能处理数字,因此,将sex字段男女转为数字0或1。

# Preprocessing function

def preprocess(data, columns_to_ignore):

# Sort by descending id and delete columns

for id in sorted(columns_to_ignore, reverse=True):

[r.pop(id) for r in data]

for i in range(len(data)):

# Converting 'sex' field to float (id is 1 after removing labels column)

data[i][1] = 1. if data[i][1] == 'female' else 0.

return np.array(data, dtype=np.float32)

# Ignore 'name' and 'ticket' columns (id 1 & 6 of data array)

to_ignore=[1, 6]

# Preprocess data

data = preprocess(data, to_ignore)

建立深度神经网络

我们使用TFLearn建立一个3层神经网络,需要规定输入数据的形态,每个样本有6个特征,我们按批次处理可以节省内存,我们的数据输入形态是 [None, 6] ,其中None代码不知道维度,我们能改变批处理中被处理后的样本总数量。

# Build neural network

net = tflearn.input_data(shape=[None, 6])

net = tflearn.fully_connected(net, 32)

net = tflearn.fully_connected(net, 32)

net = tflearn.fully_connected(net, 2, activation='softmax')

net = tflearn.regression(net)

训练

TFLearn提供DNN包装器自动执行神经网络分类任务,比如训练 预测和保存恢复等,我们训练10次,神经网络10次会看到全部数据,每次批处理大小是16:

# Define model

model = tflearn.DNN(net)

# Start training (apply gradient descent algorithm)

model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)

输出结果:

---------------------------------

Run id: MG9PV8

Log directory: /tmp/tflearn_logs/

---------------------------------

Training samples: 1309

Validation samples: 0

--

Training Step: 82 | total loss: 0.64003

| Adam | epoch: 001 | loss: 0.64003 - acc: 0.6620 -- iter: 1309/1309

--

Training Step: 164 | total loss: 0.61915

| Adam | epoch: 002 | loss: 0.61915 - acc: 0.6614 -- iter: 1309/1309

--

Training Step: 246 | total loss: 0.56067

| Adam | epoch: 003 | loss: 0.56067 - acc: 0.7171 -- iter: 1309/1309

--

Training Step: 328 | total loss: 0.51807

| Adam | epoch: 004 | loss: 0.51807 - acc: 0.7799 -- iter: 1309/1309

--

Training Step: 410 | total loss: 0.47475

| Adam | epoch: 005 | loss: 0.47475 - acc: 0.7962 -- iter: 1309/1309

--

Training Step: 492 | total loss: 0.51677

| Adam | epoch: 006 | loss: 0.51677 - acc: 0.7701 -- iter: 1309/1309

--

Training Step: 574 | total loss: 0.48988

| Adam | epoch: 007 | loss: 0.48988 - acc: 0.7891 -- iter: 1309/1309

--

Training Step: 656 | total loss: 0.55073

| Adam | epoch: 008 | loss: 0.55073 - acc: 0.7427 -- iter: 1309/1309

--

Training Step: 738 | total loss: 0.50242

| Adam | epoch: 009 | loss: 0.50242 - acc: 0.7854 -- iter: 1309/1309

--

Training Step: 820 | total loss: 0.41557

| Adam | epoch: 010 | loss: 0.41557 - acc: 0.8110 -- iter: 1309/1309

--

模型完成训练准确率达到81%,说明它对全部乘客存活与否能够有81%准确率。

下面我们试用这个模型,将泰坦尼克电影中男女主角杰克和露丝的资料输入:

# Let's create some data for DiCaprio and Winslet

dicaprio = [3, 'Jack Dawson', 'male', 19, 0, 0, 'N/A', 5.0000]

winslet = [1, 'Rose DeWitt Bukater', 'female', 17, 1, 2, 'N/A', 100.0000]

# Preprocess data

dicaprio, winslet = preprocess([dicaprio, winslet], to_ignore)

# Predict surviving chances (class 1 results)

pred = model.predict([dicaprio, winslet])

print("DiCaprio Surviving Rate:", pred[0][1])

print("Winslet Surviving Rate:", pred[1][1])

输出结果是:

DiCaprio Surviving Rate: 0.13849584758281708

Winslet Surviving Rate: 0.92201167345047

预测露丝有92的高概率生存,而杰克则相反。

更普遍的是,通过这项研究表明,第一层的妇女和儿童的乘客有最高的机会生存,而第三层的男乘客有最低。

tflearn教程_TensorFlow/TFLearn学习案例:泰坦尼克相关推荐

  1. tflearn教程_TensorFlow TFLearn安装和使用

    TFLearn可以定义为TensorFlow框架中使用的模块化和透明的深度学习方面.TFLearn的主要动机是为TensorFlow提供更高级别的API,以促进和展示新的实验. 考虑TFLearn的以 ...

  2. tflearn教程_Tensorflow tflearn 编写RCNN

    两周多的努力总算写出了RCNN的代码,这段代码非常有意思,并且还顺带复习了几个Tensorflow应用方面的知识点,故特此总结下,带大家分享下经验.理论方面,RCNN的理论教程颇多,这里我不在做详尽说 ...

  3. 基于深度学习的泰坦尼克旅客生存预测

    基于深度学习的泰坦尼克旅客生存预测 摘要:近年来,随着深度学习的迅速发展和崛起,尤其在图像分类方向取得了巨大的成就.本文实验基于Windows10系统,仿真软件用的是Anaconda下基于python ...

  4. 小白的机器学习之路(1)---Kaggle竞赛:泰坦尼克之灾(Titanic Machine Learning from Disaster)

    我是目录 前言 数据导入 可视化分析 Pclass Sex Age SibSp Parch Fare Cabin Embarked 特征提取 Title Family Size Companion A ...

  5. tflearn教程_TFLearn:为TensorFlow提供更高级别的API 的深度学习库

    TFlearn是一个基于Tensorflow构建的模块化透明深度学习库.它旨在为TensorFlow提供更高级别的API,以促进和加速实验,同时保持完全透明并与之兼容. TFLearn功能包括: 通过 ...

  6. 二十一世纪大学英语读写教程(第三册)学习笔记(原文)——2 - The Titanic Puzzle(泰坦尼克难题——女权主义者应该接受优先坐上救生艇吗)

    Unit 2 - The Titanic Puzzle - Should a good feminist accept priority seating on a lifeboat?(泰坦尼克难题-- ...

  7. python泰坦尼克号案例分析_泰坦尼克Python经典案例

    12. 章节 12 - 结论和步骤 7: 优化和战略 如何使用本教程 : 请阅读本内核中提供的解释和相关链接.我们的目标不只是知道 " 是什么 " ,还要知道 " 为什么 ...

  8. 集成算法-随机森林与案例实战-泰坦尼克获救预测

    集成算法-随机森林 Ensemble learning 目的:让机器学习效果更好,单个不行,群殴走起 Bagging:训练多个分类器取平均 f ( x ) = 1 / M ∑ m = 1 M f m ...

  9. kaggle 泰坦尼克项目实战(详细代码分享)——集成学习Soft voting

    顺利注册完kaggle之后,终于可以开始上手撸项目啦! 先从大名鼎鼎的泰坦尼克号开始吧! 尽管网上有很多大神进行了"入门级别"的代码分享讲解,但我看了一轮仍然觉得对新手不够友好. ...

最新文章

  1. mysql取最接近的两个值_Mysql:获取一行中另一个字段的最高值和最...
  2. DSM: 域不变的立体匹配网络解析(Stereo Matching Networks)
  3. liunx 上get 不到url参数 java_URL传递中文参数,大坑一枚,Windows与Linux效果竟然不一致...
  4. jvisualvm安装Visual GC插件
  5. Listview中使用线程实现无限加载更多项目的功能
  6. 定期定量采购_定量采购方式
  7. 【C语言】C语言Code的编译与执行
  8. 【BZOJ3781】小B的询问 莫队
  9. html猜随机数游戏,用js制作简易计算器及猜随机数字游戏
  10. Pytorch Feature loss与Perceptual Loss的实现
  11. (clion 安装插件联网络失败,pycharm pip联网失败)当电脑选择拨号上网时,解决系统代理被篡改/pip提示“目标计算机积极拒绝,无法连接”的方法! [ 此方法绝对解决系统代理被篡改问题 ]
  12. html鼠标放在图片上图片自动放大,css使图片自动放大
  13. 天行健---宇宙的生与死
  14. 利用selenium模拟打开百度并输入‘淘宝‘,报错‘dict‘ object has no attribute ‘send_keys‘
  15. 前端轻松破解支付宝AR抢红包
  16. shiny 服务器未响应,在centos上重启shiny-server
  17. 记录一下gitHub跑项目的步骤
  18. 高德地图海量点 API 初探
  19. 集团企业邮箱申请哪家的好,怎么选择?
  20. Pr2022 视频剪辑软件MAC版正式更新,全新版本支持M1,今天详细介绍pr2022如何安装使用?

热门文章

  1. si4463开发总结
  2. 浅谈Facade外观模式
  3. 匈牙利算法解指派问题(Java代码)
  4. Spring学习篇底层核心原理解析
  5. java观察者模式异步notify_Java进阶篇设计模式之十三 ---- 观察者模式和空对象模式...
  6. 计算机考研408每日一题 day121
  7. java前端接收回显图片_图片上传并回显后端篇
  8. 华为设备DHCP snooping配置命令
  9. 计算机统考-ppt操作题2
  10. 步进电机整步、半步、细分波形理解