在本教程中,您将学习使用TFLearn和TensorFlow来估计泰坦尼克号乘客使用其个人信息(如性别,年龄等)的幸存机会。 为了解决这个经典机器学习任务,我们要构建一个深层神经网络分类器。

前提条件

确保已经安装了TensorFlow和tflearn。 如果没有,请先根据教程安装 。

概述

介绍

1912年4月15日,泰坦尼克号在与冰山相撞后沉没,在2224名乘客和船员中造成1502人死亡。 虽然不幸遇难下沉,但一些群体比如女性,儿童和上层阶级更有可能生存下去。 在本教程中,我们进行分析,以了解这些人是谁。

数据集

我们来看看数据集(TFLearn会自动为你下载)。 对于每位乘客,提供以下信息:

VARIABLE DESCRIPTIONS:

survived Survived

(0 = No; 1 = Yes)

pclass Passenger Class

(1 = 1st; 2 = 2nd; 3 = 3rd)

name Name

sex Sex

age Age

sibsp Number of Siblings/Spouses Aboard

parch Number of Parents/Children Aboard

ticket Ticket Number

fare Passenger Fare

以下是从数据集中提取的一些样本:

survived

pclass

name

sex

age

sibsp

parch

ticket

fare

1

1

Aubart, Mme. Leontine Pauline

female

24

0

0

PC 17477

69.3000

0

2

Bowenur, Mr. Solomon

male

42

0

0

211535

13.0000

1

3

Baclini, Miss. Marie Catherine

female

5

2

1

2666

19.2583

0

3

Youseff, Mr. Gerious

male

45.5

0

0

2628

7.2250

我们的任务是区分2种类型的乘客未存活(标签0)和“存活的”(标签1),乘客数据有8个特征。

构建分类器

加载数据

数据集存储在csv文件中,因此我们可以使用TFLearn load_csv()函数将文件中的数据加载到python list 。 我们指定'target_column'参数来表示我们的标签(存活或未存活)位于第一列(id:0)。 该函数将返回一个元组:(数据,标签)。

import numpy as np

import tflearn

# Download the Titanic dataset

from tflearn.datasets import titanic

titanic.download_dataset('titanic_dataset.csv')

# Load CSV file, indicate that the first column represents labels

from tflearn.data_utils import load_csv

data, labels = load_csv('titanic_dataset.csv', target_column=0,

categorical_labels=True, n_classes=2)

预处理数据

给出的数据需要一些预处理才能在我们的深度神经网络分类器中使用。

首先,我们将丢弃在我们的分析中不太可能帮助的领域。 例如,我们假设“姓名”字段在我们的任务中将不会很有用,因为我们估计一个乘客姓名和他的生存机会是不相关的。 有了这样的想法,我们抛弃了“名字”和“票”字段。

然后,我们需要将所有的数据转换为数值,因为神经网络模型只能对数字执行操作。 但是,我们的数据集包含一些非数值值,例如'name'或'sex'。 因为'name'被丢弃,我们只需要处理'sex'字段。 在这种简单的情况下,我们将把'0'分配给男性,'1'分配给女性。

这是预处理函数:

# Preprocessing function

def preprocess(passengers, columns_to_delete):

# Sort by descending id and delete columns

for column_to_delete in sorted(columns_to_delete, reverse=True):

[passenger.pop(column_to_delete) for passenger in passengers]

for i in range(len(passengers)):

# Converting 'sex' field to float (id is 1 after removing labels column)

passengers[i][1] = 1. if passengers[i][1] == 'female' else 0.

return np.array(passengers, dtype=np.float32)

# Ignore 'name' and 'ticket' columns (id 1 & 6 of data array)

to_ignore=[1, 6]

# Preprocess data

data = preprocess(data, to_ignore)

建立深度神经网络

我们正在使用TFLearn构建一个3层神经网络。 我们需要指定输入数据的形状。 在我们的示例中,每个样本共有6个特征,我们将处理每批次的样本以节省内存,因此我们的数据输入形状为[None,6](“None”代表未知维度,因此我们可以更改在批处理中的样本总数)。

# Build neural network

net = tflearn.input_data(shape=[None, 6])

net = tflearn.fully_connected(net, 32)

net = tflearn.fully_connected(net, 32)

net = tflearn.fully_connected(net, 2, activation='softmax')

net = tflearn.regression(net)

训练

TFLearn提供了一个可以自动执行神经网络分类器任务的模型包装器“DNN”,例如训练,预测,保存/恢复等...我们将运行10个次(神经网络将处理所有数据10次)批量大小为16。

# Define model

model = tflearn.DNN(net)

# Start training (apply gradient descent algorithm)

model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)

输出:

---------------------------------

Run id: MG9PV8

Log directory: /tmp/tflearn_logs/

---------------------------------

Training samples: 1309

Validation samples: 0

--

Training Step: 82 | total loss: 0.64003

| Adam | epoch: 001 | loss: 0.64003 - acc: 0.6620 -- iter: 1309/1309

--

Training Step: 164 | total loss: 0.61915

| Adam | epoch: 002 | loss: 0.61915 - acc: 0.6614 -- iter: 1309/1309

--

Training Step: 246 | total loss: 0.56067

| Adam | epoch: 003 | loss: 0.56067 - acc: 0.7171 -- iter: 1309/1309

--

Training Step: 328 | total loss: 0.51807

| Adam | epoch: 004 | loss: 0.51807 - acc: 0.7799 -- iter: 1309/1309

--

Training Step: 410 | total loss: 0.47475

| Adam | epoch: 005 | loss: 0.47475 - acc: 0.7962 -- iter: 1309/1309

--

Training Step: 492 | total loss: 0.51677

| Adam | epoch: 006 | loss: 0.51677 - acc: 0.7701 -- iter: 1309/1309

--

Training Step: 574 | total loss: 0.48988

| Adam | epoch: 007 | loss: 0.48988 - acc: 0.7891 -- iter: 1309/1309

--

Training Step: 656 | total loss: 0.55073

| Adam | epoch: 008 | loss: 0.55073 - acc: 0.7427 -- iter: 1309/1309

--

Training Step: 738 | total loss: 0.50242

| Adam | epoch: 009 | loss: 0.50242 - acc: 0.7854 -- iter: 1309/1309

--

Training Step: 820 | total loss: 0.41557

| Adam | epoch: 010 | loss: 0.41557 - acc: 0.8110 -- iter: 1309/1309

--

我们的模型完成训练,总体精度达到81%左右,这意味着它能以81%的准确率预测结果(乘客的存活与否)。

测试模型

现在是测试我们的模型的时候了。 为了乐趣,我们来看看泰坦尼克号电影的主角(DiCaprio和温斯莱特),并计算他们的生存机会(1级)。

# Let's create some data for DiCaprio and Winslet

dicaprio = [3, 'Jack Dawson', 'male', 19, 0, 0, 'N/A', 5.0000]

winslet = [1, 'Rose DeWitt Bukater', 'female', 17, 1, 2, 'N/A', 100.0000]

# Preprocess data

dicaprio, winslet = preprocess([dicaprio, winslet], to_ignore)

# Predict surviving chances (class 1 results)

pred = model.predict([dicaprio, winslet])

print("DiCaprio Surviving Rate:", pred[0][1])

print("Winslet Surviving Rate:", pred[1][1])

输出:

DiCaprio Surviving Rate: 0.13849584758281708

Winslet Surviving Rate: 0.92201167345047

令人印象深刻! 我们的模型准确预测了电影的结果。 DiCaprio生存机会不大,但温斯莱特有很高的生存机会。

更一般来说,通过这项研究可以看出,头等舱的妇女和儿童乘客的生存机会最高,而三等舱男性乘客生存机会最少。

源代码

from __future__ import print_function

import numpy as np

import tflearn

# Download the Titanic dataset

from tflearn.datasets import titanic

titanic.download_dataset('titanic_dataset.csv')

# Load CSV file, indicate that the first column represents labels

from tflearn.data_utils import load_csv

data, labels = load_csv('titanic_dataset.csv', target_column=0,

categorical_labels=True, n_classes=2)

# Preprocessing function

def preprocess(passengers, columns_to_delete):

# Sort by descending id and delete columns

for column_to_delete in sorted(columns_to_delete, reverse=True):

[passenger.pop(column_to_delete) for passenger in passengers]

for i in range(len(passengers)):

# Converting 'sex' field to float (id is 1 after removing labels column)

passengers[i][1] = 1. if data[i][1] == 'female' else 0.

return np.array(passengers, dtype=np.float32)

# Ignore 'name' and 'ticket' columns (id 1 & 6 of data array)

to_ignore=[1, 6]

# Preprocess data

data = preprocess(data, to_ignore)

# Build neural network

net = tflearn.input_data(shape=[None, 6])

net = tflearn.fully_connected(net, 32)

net = tflearn.fully_connected(net, 32)

net = tflearn.fully_connected(net, 2, activation='softmax')

net = tflearn.regression(net)

# Define model

model = tflearn.DNN(net)

# Start training (apply gradient descent algorithm)

model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)

# Let's create some data for DiCaprio and Winslet

dicaprio = [3, 'Jack Dawson', 'male', 19, 0, 0, 'N/A', 5.0000]

winslet = [1, 'Rose DeWitt Bukater', 'female', 17, 1, 2, 'N/A', 100.0000]

# Preprocess data

dicaprio, winslet = preprocess([dicaprio, winslet], to_ignore)

# Predict surviving chances (class 1 results)

pred = model.predict([dicaprio, winslet])

print("DiCaprio Surviving Rate:", pred[0][1])

print("Winslet Surviving Rate:", pred[1][1])

tflearn教程_TFlearn 快速入门相关推荐

  1. Tomcat 教程之快速入门

    Tomcat 教程之快速入门 版本说明 本文使用 Tomcat 版本为 Tomcat 8.5.24. Tomcat 8.5 要求 JDK 版本为 1.7 以上. 简介 Tomcat 是什么 Tomca ...

  2. ArcGIS教程——ArcGIS快速入门

    实例数据:https://pan.baidu.com/s/184wwCmWrJdb-qjxsT614EQ 密码:dowv ArcGIS for Desktop是一套完整的专业GIS应用程序,包含有Ar ...

  3. tensorflow2.0教程- Keras 快速入门

    tensorflow2.0教程-tensorflow.keras 快速入门 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/artic ...

  4. SWMM从入门到实践教程 03 快速入门案例的设施参数设置与批量设置

    文章目录 1 雨量计 1.1 雨量计基础设置 1.2 雨量计数据来源 2 汇水区 2.1 参数讲解 2.2 设置结果 3 检查井 3.1 参数讲解 3.2 批量设置 4 管道 4.1 参数讲解 4.2 ...

  5. SWMM从入门到实践教程 02 快速入门案例的绘制

    文章目录 1 建模准备 2 设置各类设施 2.1 添加雨量计 2.2 添加子汇水区(正方形) 2.3 绘制节点(圆形) 2.4 绘制管渠 2.5 添加排水口(三角形) 3 画面调节 1 建模准备 建模 ...

  6. esp8266灯上电闪一下_【零知ESP8266教程】快速入门2-点亮外部LED灯

    [零知ESP8266教程]快速入门2-点亮外部LED灯 [复制链接] 一.工具原料 电脑,windows系统 ESP8266开发板 micro-usb线 LED灯1个 220Ω 电阻1个 面包板一个+ ...

  7. Python零基础入门教程( 快速入门)

    前言 学无止境,无止境学. 今天要给大家分享的是<Python零基础入门教程01 快速入门>,这是一个系列的教程,从零基础到项目实战.在本教程中,我会给大家介绍Python入门的一些基础知 ...

  8. tflearn教程_TFLearn:为TensorFlow提供更高级别的API 的深度学习库

    TFlearn是一个基于Tensorflow构建的模块化透明深度学习库.它旨在为TensorFlow提供更高级别的API,以促进和加速实验,同时保持完全透明并与之兼容. TFLearn功能包括: 通过 ...

  9. Spring Boot 2.x基础教程:快速入门

    点击蓝色"程序猿DD"关注我哟 来源:http://t./ <Star最多的Spring Boot教程继续更新了> 牛皮吹过了! Git仓库和博客专题页也改版完成! 是 ...

最新文章

  1. 记录一下利用ffmpeg将avi转为mp4
  2. atitit. groupby linq的实现(1)-----linq框架选型 java .net php
  3. anddroid异常处理之UncaughtException
  4. 使用MEF构建可扩展的Silverlight应用
  5. linux系统设置服务开机启动3种方法,Linux开机启动程序详解
  6. openstack单元測试用组件一览
  7. JS-元素大小深入学习-offset、client、scroll等学习研究笔记
  8. wifi安装linux分区,centos7配置wifi驱动
  9. windows下python 自动截图功能
  10. Echarts 地图绘制
  11. 安装conntrack-tools
  12. zepto.js学习笔记02
  13. linux查看串口驱动
  14. ISO27001体系的价值(详解)
  15. FFmpeg[11] - ffmpeg去除水印(图片和文字)
  16. 武汉市计算机类中专学校排名,武汉中职中专学校一览表 2021最新排名
  17. ServiceNow获得FedRAMP高基准授权
  18. linux下wifi连接方法
  19. 软件测试面试-为什么选择软件测试?
  20. 使用visio创建跨职能流程图

热门文章

  1. 中国电动汽车换电行业需求现状及未来发展规划报告2022-2028年版
  2. java开发webservice简单实例_jsp实现的webservice的简单实例
  3. 【Web技术】959- JavaScript 如何在线解压 ZIP 文件?
  4. google play支付提示“此版本的应用程序未配置为通过Google Play结算。有关详情,请访问帮助中心。”
  5. 如何使用思维导图做技术书籍笔记?
  6. 如何使用egg.js开发后端,包含连接数据库
  7. 递推最小二乘法RLS公式详细推导
  8. 智慧应急解决方案-最新全套文件
  9. 【pymongo】连接认证 auth failed解决方法
  10. 计算100以内的奇数之和