作者 | 王树义

来源 | 玉树芝兰(ID:nkwangshuyi)

以客户流失数据为例,看 Tensorflow 2.0 版本如何帮助我们快速构建表格(结构化)数据的神经网络分类模型。

变化

表格数据,你应该并不陌生。毕竟, Excel 这东西在咱们平时的工作和学习中,还是挺常见的。


在之前的教程里,我为你分享过,如何利用深度神经网络,锁定即将流失的客户。里面用到的,就是这样的表格数据。

时间过得真快,距离写作那篇教程,已经一年半了。

这段时间里,出现了2个重要的变化,使我觉得有必要重新来跟你谈谈这个话题。

这两个变化分别是:

首先,tflearn 框架的开发已经不再活跃。


tflearn 是当时教程中我们使用的高阶深度学习框架,它基于 Tensorflow 之上,包裹了大量的细节,让用户可以非常方便地搭建自己的模型。

但是,由于 Tensorflow 选择拥抱了它的竞争者 Keras ,导致后者的竞争优势凸显。


对比二者获得的星数,已经不在同一量级。

观察更新时间,tflearn 已经几个月没有动静;而 Keras 几个小时之前,还有更新。

我们选择免费开源框架,一定要使用开发活跃社区支持完善的。只有这样,遇到问题才能更低成本、高效率地解决。

看过我的《Python编程遇问题,文科生怎么办?》一文之后,你对上述结论,应该不陌生。

另一项新变化,是 Tensorflow 发布了 2.0 版本。

相对 1.X 版本,这个大版本的变化,我在《如何用 Python 和 BERT 做中文文本二元分类?》一文中,已经粗略地为你介绍过了。简要提炼一下,就是:

之前的版本,以计算图为中心。开发者需要为这张图服务。因此,引入了大量的不必要术语。新版本以人为中心,用户撰写高阶的简洁语句,框架自动将其转化为对应的计算图。

之前的版本,缺少目前竞争框架(如 PyTorch 等)包含的新特性。例如计算图动态化、运行中调试功能等。

但对普通开发者来说,最为重要的是,官方文档和教程变得对用户友好许多。不仅写得清晰简明,更靠着 Google Colab 的支持,全都能一键运行。我尝试了 2.0 版本的一些教程样例,确实感觉大不一样了。


其实你可能会觉得奇怪—— Tensorflow 大张旗鼓宣传的大版本改进,其实也无非就是向着 PyTorch 早就有的功能靠拢而已嘛。那我干脆去学 PyTorch 好了!

如果我们只说道理,这其实没错。然而,还是前面那个论断,一个框架好不好,主要看是否开发活跃社区支持完善。这就是一个自证预言。一旦人们都觉得 Tensorflow 好用,那么 Tensorflow 就会更好用。因为会有更多的人参与进来,帮助反馈和改进。

看看现在 PyTorch 的 Github 页面。


受关注度,确实已经很高了。

然而你再看看 Tensorflow 的。


至少在目前,二者根本不在一个数量级。

Tensorflow 的威力,不只在于本身构建和训练模型是不是好用。那其实只是深度学习中,非常小的一个环节。不信?你在下图里找找看。


真正的问题,在于是否有完整的生态环境支持。其中的逻辑,我在《学 Python ,能提升你的竞争力吗?》一文中,已经为你详细分析过了。

而 Tensorflow ,早就通过一系列的布局,使得其训练模型可以直接快速部署,最快速度铺开,帮助开发者占领市场先机。


如果你使用 PyTorch ,那么这样的系统,是相对不完善的。当然你可以在 PyTorch 中训练,然后转换并且部署到 Tensorflow 里面。毕竟三巨头达成了协议,标准开放,这样做从技术上并不困难。


但是,人的认知带宽,是非常有限的。大部分人,是不会选择在两个框架甚至生态系统之间折腾的。这就是路径依赖

所以,别左顾右盼了,认认真真学 Tensorflow 2.0 吧。

这篇文章里面,我给你介绍,如何用 Tensorflow 2.0 ,来训练神经网络,对用户流失数据建立分类模型,从而可以帮你见微知著,洞察风险,提前做好干预和防范。

数据

你手里拥有的,是一份银行欧洲区客户的数据,共有10000条记录。客户主要分布在法国、德国和西班牙。


数据来自于匿名化处理后的真实数据集,下载自 superdatascience 官网。

从表格中,可以读取的信息,包括客户们的年龄、性别、信用分数、办卡信息等。客户是否已流失的信息在最后一列(Exited)。

这份数据,我已经上传到了这个地址,你可以下载,并且用 Excel 查看。

环境

本文的配套源代码,我放在了这个 Github 项目中。请你点击这个链接(http://t.cn/EXffmgX)访问。


如果你对我的教程满意,欢迎在页面右上方的 Star 上点击一下,帮我加一颗星。谢谢!

注意这个页面的中央,有个按钮,写着“在 Colab 打开” (Open in Colab)。请你点击它。

然后,Google Colab 就会自动开启。


我建议你点一下上图中红色圈出的 “COPY TO DRIVE” 按钮。这样就可以先把它在你自己的 Google Drive 中存好,以便使用和回顾。


Colab 为你提供了全套的运行环境。你只需要依次执行代码,就可以复现本教程的运行结果了。

如果你对 Google Colab 不熟悉,没关系。我这里有一篇教程,专门讲解 Google Colab 的特点与使用方式。

为了你能够更为深入地学习与了解代码,我建议你在 Google Colab 中开启一个全新的 Notebook ,并且根据下文,依次输入代码并运行。在此过程中,充分理解代码的含义。

这种看似笨拙的方式,其实是学习的有效路径。

代码

首先,我们下载客户流失数据集。

!wget https://raw.githubusercontent.com/wshuyi/demo-customer-churn-ann/master/customer_churn.csv

载入 Pandas 数据分析包。

import pandas as pd

利用 read_csv 函数,读取 csv 格式数据到 Pandas 数据框。

df = pd.read_csv('customer_churn.csv')

我们来看看前几行显示结果:

df.head()

显示正常。下面看看一共都有哪些列。

df.columns

我们对所有列,一一甄别。

  • RowNumber:行号,这个对于模型没用,忽略

  • CustomerID:用户编号,这个是顺序发放的,忽略

  • Surname:用户姓名,对流失没有影响,忽略

  • CreditScore:信用分数,这个很重要,保留

  • Geography:用户所在国家/地区,这个有影响,保留

  • Gender:用户性别,可能有影响,保留

  • Age:年龄,影响很大,年轻人更容易切换银行,保留

  • Tenure:当了本银行多少年用户,很重要,保留

  • Balance:存贷款情况,很重要,保留

  • NumOfProducts:使用产品数量,很重要,保留

  • HasCrCard:是否有本行信用卡,很重要,保留

  • IsActiveMember:是否活跃用户,很重要,保留

  • EstimatedSalary:估计收入,很重要,保留

  • Exited:是否已流失,这将作为我们的标签数据

确定了不同列的含义和价值,下面我们处理起来,就得心应手了。

数据有了,我们来调入深度学习框架。

因为本次我们需要使用 Tensorflow 2.0 ,而写作本文时,该框架版本尚处于 Alpha 阶段,因此 Google Colab 默认使用的,还是 Tensorflow 1.X 版本。要用 2.0 版,便需要显式安装。

!pip install -q tensorflow==2.0.0-alpha0

安装框架后,我们载入下述模块和函数,后文会用到。

import numpy as npimport tensorflow as tffrom tensorflow import kerasfrom sklearn.model_selection import train_test_splitfrom tensorflow import feature_column

这里,我们设定一些随机种子值。这主要是为了保证结果可复现,也就是在你那边的运行结果,和我这里尽量保持一致。这样我们观察和讨论问题,会更方便。

首先是 Tensorflow 中的随机种子取值,设定为 1 。

tf.random.set_seed(1)

然后我们来分割数据。这里使用的是 Scikit-learn 中的 train_test_split 函数。指定分割比例即可。

我们先按照 80:20 的比例,把总体数据分成训练集测试集

train, test = train_test_split(df, test_size=0.2, random_state=1)

然后,再把现有训练集的数据,按照 80:20 的比例,分成最终的训练集,以及验证集

train, valid = train_test_split(train, test_size=0.2, random_state=1)

这里,我们都指定了 random_state ,为的是保证咱们随机分割的结果一致。

我们看看几个不同集合的长度。

print(len(train))print(len(valid))print(len(test))

验证无误。下面我们来做特征工程(feature engineering)。

因为我们使用的是表格数据(tabular data),属于结构化数据。因此特征工程相对简单一些。

先初始化一个空的特征列表。

feature_columns = []

然后,我们指定,哪些列是数值型数据(numeric data)。

numeric_columns = ['CreditScore', 'Age', 'Tenure', 'Balance', 'NumOfProducts', 'EstimatedSalary']

可见,包含了以下列:

  • CreditScore:信用分数

  • Age:年龄

  • Tenure:当了本银行多少年用户

  • Balance:存贷款情况

  • NumOfProducts:使用产品数量

  • EstimatedSalary:估计收入

对于这些列,只需要直接指定类型,加入咱们的特征列表就好。

for header in numeric_columns:  feature_columns.append(feature_column.numeric_column(header))

下面是比较讲究技巧的部分了,就是类别数据。

先看看都有哪些列:

categorical_columns = ['Geography', 'Gender', 'HasCrCard', 'IsActiveMember']
  • Geography:用户所在国家/地区

  • Gender:用户性别

  • HasCrCard:是否有本行信用卡

  • IsActiveMember:是否活跃用户

类别数据的特点,在于不能直接用数字描述。例如 Geography 包含了国家/地区名称。如果你把法国指定为1, 德国指定为2,电脑可能自作聪明,认为“德国”是“法国”的2倍,或者,“德国”等于“法国”加1。这显然不是我们想要表达的。

所以我这里编了一个函数,把一个类别列名输入进去,让 Tensorflow 帮我们将其转换成它可以识别的类别形式。例如把法国按照 [0, 0, 1],德国按照 [0, 1, 0] 来表示。这样就不会有数值意义上的歧义了。

def get_one_hot_from_categorical(colname):  categorical = feature_column.categorical_column_with_vocabulary_list(colname, train[colname].unique().tolist())  return feature_column.indicator_column(categorical)

我们尝试输入 Geography 一项,测试一下函数工作是否正常。

geography = get_one_hot_from_categorical('Geography'); geography

观察结果,测试通过。

下面我们放心大胆地把所有类别数据列都在函数里面跑一遍,并且把结果加入到特征列表中。

for col in categorical_columns:  feature_columns.append(get_one_hot_from_categorical(col))

看看此时的特征列表内容:

feature_columns

6个数值类型,4个类别类型,都没问题了。

下面该构造模型了。

我们直接采用 Tensorflow 2.0 鼓励开发者使用的 Keras 高级 API 来拼搭一个简单的深度神经网络模型。

from tensorflow.keras import layers

我们把刚刚整理好的特征列表,利用 DenseFeatures 层来表示。把这样的一个初始层,作为模型的整体输入层。

feature_layer = layers.DenseFeatures(feature_columns); feature_layer

下面,我们顺序叠放两个中间层,分别包含200个,以及100个神经元。这两层的激活函数,我们都采用 relu

relu 函数大概长这个样子:


model = keras.Sequential([  feature_layer,  layers.Dense(200, activation='relu'),  layers.Dense(100, activation='relu'),  layers.Dense(1, activation='sigmoid')])

我们希望输出结果是0或者1,所以这一层只需要1个神经元,而且采用的是 sigmoid 作为激活函数。

sigmoid 函数的长相是这样的:


模型搭建好了,下面我们指定3个重要参数,编译模型。

model.compile(optimizer='adam',              loss='binary_crossentropy',              metrics=['accuracy'])

这里,我们选择优化器为 adam


因为评判二元分类效果,所以损失函数选的是 binary_crossentropy

至于效果指标,我们使用的是准确率(accuracy)。

模型编译好之后。万事俱备,只差数据了。

你可能纳闷,一上来不就已经把训练、验证和测试集分好了吗?

没错,但那只是原始数据。我们模型需要接收的,是数据流

在训练和验证过程中,数据都不是一次性灌入模型的。而是一批次一批次分别载入。每一个批次,称作一个 batch;相应地,批次大小,叫做 batch_size

为了方便咱们把 Pandas 数据框中的原始数据转换成数据流。我这里编写了一个函数。

def df_to_tfdata(df, shuffle=True, bs=32):  df = df.copy()  labels = df.pop('Exited')  ds = tf.data.Dataset.from_tensor_slices((dict(df), labels))  if shuffle:    ds = ds.shuffle(buffer_size=len(df), seed=1)  ds = ds.batch(bs)  return ds

这里首先是把数据中的标记拆分出来。然后根据把数据读入到 ds 中。根据是否是训练集,我们指定要不要需要打乱数据顺序。然后,依据 batch_size 的大小,设定批次。这样,数据框就变成了神经网络模型喜闻乐见的数据流

train_ds = df_to_tfdata(train)valid_ds = df_to_tfdata(valid, shuffle=False)test_ds = df_to_tfdata(test, shuffle=False)

这里,只有训练集打乱顺序。因为我们希望验证和测试集一直保持一致。只有这样,不同参数下,对比的结果才有显著意义。

有了模型架构,也有了数据,我们把训练集和验证集扔进去,让模型尝试拟合。这里指定了,跑5个完整轮次(epochs)。

model.fit(train_ds,          validation_data=valid_ds,          epochs=5)

你会看到,最终的验证集准确率接近80%。

我们打印一下模型结构:

model.summary()

虽然我们的模型非常简单,却也依然包含了23401个参数。

下面,我们把测试集放入模型中,看看模型效果如何。

model.evaluate(test_ds)

依然,准确率接近80%。

还不错吧?

……

真的吗?

疑惑

如果你观察很仔细,可能刚才已经注意到了一个很奇特的现象:


训练的过程中,除了第一个轮次外,其余4个轮次的这几项重要指标居然都没变

它们包括:

  • 训练集损失

  • 训练集准确率

  • 验证集损失

  • 验证集准确率

所谓机器学习,就是不断迭代改进啊。如果每一轮下来,结果都一模一样,这难道不奇怪吗?难道没问题吗?

我希望你,能够像侦探一样,揪住这个可疑的线索,深入挖掘进去。

这里,我给你个提示。

看一个分类模型的好坏,不能只看准确率(accuracy)。对于二元分类问题,你可以关注一下 f1 score,以及混淆矩阵(confusion matrix)。

如果你验证了上述两个指标,那么你应该会发现真正的问题是什么

下一步要穷究的,是问题产生的原因

回顾一下咱们的整个儿过程,好像都很清晰明了,符合逻辑啊。究竟哪里出了问题呢?

如果你一眼就看出了问题。恭喜你,你对深度学习已经有感觉了。那么我继续追问你,该怎么解决这个问题呢?

欢迎你把思考后的答案在留言区告诉我。

对于第一名全部回答正确上述问题的读者,我会邀请你作为嘉宾,免费(原价199元)加入我本年度的知识星球。当然,前提是你愿意。

小结

希望通过本文的学习,你已掌握了以下知识点:

  1. Tensorflow 2.0 的安装与使用;

  2. 表格式数据的神经网络分类模型构建;

  3. 特征工程的基本流程;

  4. 数据集合的随机分割与利用种子数值保持一致;

  5. 数值型数据列与类别型数据列的分别处理方式;

  6. Keras 高阶 API 的模型搭建与训练;

  7. 数据框转化为 Tensorflow 数据流;

  8. 模型效果的验证;

  9. 缺失的一环,也即本文疑点产生的原因,以及正确处理方法。

希望本教程对于你处理表格型数据分类任务,能有帮助。

祝深度学习愉快!

(本文为AI科技大本营转载文章,转载请联系原作者

长三角开发者联盟

代码就是力量,长三角的开发者联合起来!

加入「长三角开发者联盟」将获得以下权益

长三角地区明星企业内推岗位
CSDN独家技术与行业报告
CSDN线下活动优先参与权
CSDN线上分享活动优先参与权

扫码添加联盟小助手,回复关键词“长三角2”,加入「长三角开发者联盟」。

推荐阅读:

  • 机器学习萌新必备的三种优化算法 | 选型指南

  • A* 算法之父、人工智能先驱Nils Nilsson逝世 | 缅怀

  • Python程序员Debug的利器,和Print说再见 | 技术头条

  • 入门AI第一步,从安装环境Ubuntu+Anaconda开始教!

  • 小程序的侵权“生死局”

  • @996 程序员,ICU 你真的去不起!

  • Elastic Jeff Yoshimura:开源正在开启新一轮的创新 | 人物志

  • 19岁当老板, 20岁ICO失败, 21岁将项目挂到了eBay, 为何初创公司如此艰难?

  • 她说:为啥程序员都特想要机械键盘?这答案我服!

点击阅读原文,了解CTA核心技术及应用峰会」

怎样搞定分类表格数据?有人用TF2.0构建了一套神经网络 | 技术头条相关推荐

  1. excel行列互换_1秒轻松搞定EXCEL表格行列内容互换

    Hello,大家好!办公一点通,工作更轻松.我是头条号职场加油驿站,我又来和大家分享EXCEL办公小技能了~今天和大家分享如何1秒搞定EXCEL表格行列内容互换,快来和我一起学习吧 问题:经常处理数据 ...

  2. 【Python基础】一文搞定pandas的数据合并

    作者:来源于读者投稿 出品:Python数据之道 一文搞定pandas的数据合并 在实际处理数据业务需求中,我们经常会遇到这样的需求:将多个表连接起来再进行数据的处理和分析,类似SQL中的连接查询功能 ...

  3. mvvm怎么让光标制定属性的文本框_Word怎么快速制作斜线表头?10秒搞定,表格颜值直线上升...

    Word怎么快速制作斜线表头?这个问题可能困扰着很多刚进入职场的小伙伴.不管是文员是一线工作人员,工作中或多或少都会涉及斜线表头的制作. 因此,今天就给大家分享一下制作斜线表头的方法,而在Word中斜 ...

  4. dfema规则_六步搞定DFMEA表格

    原标题:六步搞定DFMEA表格 档即用www.downjy.com向您分享如下的"六步搞定DFMEA表格"的知识.原版文档下载方法参照文章底部说明~ 1‍DFMEA的重大作用 FM ...

  5. 一文搞定pandas的数据合并

    作者:来源于读者投稿 出品:Python数据之道 一文搞定pandas的数据合并 在实际处理数据业务需求中,我们经常会遇到这样的需求:将多个表连接起来再进行数据的处理和分析,类似SQL中的连接查询功能 ...

  6. 查询所有_学会DSUM函数,轻松搞定所有的数据查询与数据求和

    在Excel表格中数据的查询与数据求和是我们经常会遇到的问题.今天和朋友们一起学习一下非常强大的DSUM函数,这个一个函数就可以轻松搞定单条件查询.多条件查询.反向查询.单条件求和.多条件求和. 一. ...

  7. 亿条数据读取工具_仅需1秒!搞定100万行数据:超强Python数据分析利器

    作者:Maarten.Roman.Jovan 编译:1+1=6 1.前言 使用Python进行大数据分析变得越来越流行.这一切都要从NumPy开始,它也是今天我们在推文介绍工具背后支持的模块之一. 2 ...

  8. 仅需1秒!搞定100万行数据:超强Python数据分析利器

    前言 使用Python进行大数据分析变得越来越流行.这一切都要从NumPy开始,它也是今天我们在推文介绍工具背后支持的模块之一. 2 Vaex 很多人学习python,不知道从何学起. 很多人学习py ...

  9. 服务器mbr文件丢失吗,硬盘中了MBR病毒不要急,一款工具帮你搞定,保证数据不丢失!...

    经常看见论坛上面有人说自己的硬盘被锁了,开机后出现一行红字:FUCK YOU POJIEZHE. 这个问题主要的原因是:病毒对MBR分区的修改导致的. MBR病毒简介: 引导区病毒是PC机上最早出现的 ...

最新文章

  1. java幂等性的解决方案
  2. 幂函数与指数函数的区别
  3. python定时关闭进程_Python子进程,定时延迟后终止进程
  4. java并发编程之美-阅读记录2
  5. 22个ES6面试、复习干货知识点汇总
  6. zabbix简介及部署
  7. Ubuntu16.04 安装 卸载 pip
  8. 信创办公--基于WPS的Word最佳实践系列(利用表格控制排版)
  9. 1. 从键盘输入一系列字符(以回车符结束,字符的个数不超过 200 个),统计输入字符串中数字与非数字字符的个数,并将计数结果输出。
  10. VMware15Pro 安装CentOS7
  11. http:网易云音乐
  12. Traffic shaping 一个事半功倍的程序化”噪音“解决方案
  13. 子类拷贝构造是否会调用父类的拷贝构造?
  14. 请假过来面试,没有被录用,总不能让我一点收获都没有吧
  15. 预约订座APP系统(基于uni-app框架)毕业设计毕业论文开题报告参考(3)系统后台管理功能
  16. DirectX 3D 简单渲染流程
  17. 零基础PS----制作不一样的个人简历
  18. 技嘉服务器主板是什么型号,ASUS华硕/技嘉/微星MSI工作站服务器主板型号对比说明,注入win7驱动工具...
  19. 啤酒游戏的牛鞭效应分析之供应链4层模式
  20. 抗渗等级p6是什么意思_关于混凝土抗渗等级p6 p8采用混凝土抗渗剂的用法

热门文章

  1. Error: could not open 'D:\Program Files\Java\jre7\lib\amd64\jvm.cfg'
  2. 腾讯微博快速有效增加广播转播量的方法与技巧
  3. SLF4J 的几种实际应用模式--之二:SLF4J+Logback
  4. [转]C# 2.0新特性与C# 3.5新特性
  5. Windows Server 2008 R2 Beta VHD镜像文件发布
  6. Java版开发原生App支付
  7. 解决js中数字相减为负数的情况
  8. 20165219王彦博《基于Cortex-M4的虚拟机制作与测试》课程设计个人报告
  9. microsoft 为microbit.org 设计的课程
  10. 如何将github上的 lib fork之后通过podfile 改变更新源到自己fork的地址