目录

背景

数据集

特征处理

模型构建及评估


背景:

很多TF模型的例子都是使用dataframe进行数据处理及读取的,在部署及大任务处理时可能会遇到需要特征额外处理及内存不足等问题,所以想直接使用tf.data将预处理及数据读取批次等问题直接处理掉。

本Demo包含了以下完整代码:

  • 用 tf.data 建立了一个输入流水线(pipeline),用于对行进行分批(batch)和随机排序(shuffle)。
  • 用特征列将 CSV 中的列映射到用于训练模型的特征。
  • 用 Keras 构建,训练并评估模型。

数据集

使用kaggle heart-disease做例子(链接可以下载)。CSV 中有几百行数据。每行描述了一个病人(patient),每列描述了一个属性(attribute)。我们将使用这些信息来预测一位病人是否患有心脏病,这是在该数据集上的二分类任务。

下面是该数据集的描述。 请注意,有数值(numeric)和类别(categorical)类型的列。

描述 特征类型 数据类型
Age 年龄以年为单位 Numerical integer
Sex (1 = 男;0 = 女) Categorical integer
CP 胸痛类型(0,1,2,3,4) Categorical integer
Trestbpd 静息血压(入院时,以mm Hg计) Numerical integer
Chol 血清胆固醇(mg/dl) Numerical integer
FBS (空腹血糖> 120 mg/dl)(1 = true;0 = false) Categorical integer
RestECG 静息心电图结果(0,1,2) Categorical integer
Thalach 达到的最大心率 Numerical integer
Exang 运动诱发心绞痛(1 =是;0 =否) Categorical integer
Oldpeak 与休息时相比由运动引起的 ST 节段下降 Numerical integer
Slope 在运动高峰 ST 段的斜率 Numerical float
CA 荧光透视法染色的大血管动脉(0-3)的数量 Numerical integer
Thal 3 =正常;6 =固定缺陷;7 =可逆缺陷 Categorical string
Target 心脏病诊断(1 = true;0 = false) Classification integer

tf.data数据读取

使用make_csv_dataset方法直接批量读取csv数据。注意:file_path可以是一个list,可以是通配符比如:/path/to/dir/*.csv

import tensorflow as tf
from tensorflow import feature_column
from tensorflow.keras import layersdef csv_to_tfdata(file_path,LABEL_COLUMN='lable',batch_size = 10,**kwargs):dataset = tf.data.experimental.make_csv_dataset(file_path,batch_size=batch_size,label_name=LABEL_COLUMN,na_value="?",num_epochs=1,ignore_errors=True,**kwargs)return dataset

读取训练、验证、测试数据,我事先把数据按照80:4:16分割了

dataset_train = csv_to_tfdata("heart.csv",LABEL_COLUMN="target")
dataset_val = csv_to_tfdata("heart-val.csv",LABEL_COLUMN="target")
dataset_test = csv_to_tfdata("heart-test.csv",LABEL_COLUMN="target")
def show_data(data):for batch,label in data.take(1):for key, value in batch.items():print("{:20s}: {}".format(key,value.numpy()))show_data(dataset_val)

特征处理

# 要创建特征列,请调用 tf.feature_column 模块的函数。该模块中常用的九个函数如下图所示,所有九个函数都会返回一个 Categorical-Column 或一个
# Dense-Column 对象,但却不会返回 bucketized_column,后者继承自这两个类。
# 所有的Catogorical Column类型最终都要通过indicator_column转换成Dense Column类型才能传入模型!

  • numeric_column 数值列,最常用。
  • bucketized_column 分桶列,由数值列生成,可以由一个数值列出多个特征,one-hot编码。
  • categorical_column_with_identity 分类标识列,one-hot编码,相当于分桶列每个桶为1个整数的情况。
  • categorical_column_with_vocabulary_list 分类词汇列,one-hot编码,由list指定词典。
  • categorical_column_with_vocabulary_file 分类词汇列,由文件file指定词典。
  • categorical_column_with_hash_bucket 哈希列,整数或词典较大时采用。
  • indicator_column 指标列,由Categorical Column生成,one-hot编码
  • embedding_column 嵌入列,由Categorical Column生成,嵌入矢量分布参数需要学习。嵌入矢量维数建议取类别数量的 4 次方根。
  • crossed_column 交叉列,可以由除categorical_column_with_hash_bucket的任意分类列构成。
feature_columns = []# 数值列
for header in ['age', 'trestbps', 'chol', 'thalach', 'oldpeak', 'slope', 'ca']:feature_columns.append(feature_column.numeric_column(header))# 分桶列
age = feature_column.numeric_column("age")
age_buckets = feature_column.bucketized_column(age, boundaries=[18, 25, 30, 35, 40, 45, 50, 55, 60, 65])
feature_columns.append(age_buckets)# 分类列
thal = feature_column.categorical_column_with_vocabulary_list('thal', [1, 2, 3])
thal_one_hot = feature_column.indicator_column(thal)
feature_columns.append(thal_one_hot)# 嵌入列
thal_embedding = feature_column.embedding_column(thal, dimension=8)
feature_columns.append(thal_embedding)# 组合列
crossed_feature = feature_column.crossed_column([age_buckets, thal], hash_bucket_size=1000)
crossed_feature = feature_column.indicator_column(crossed_feature)
feature_columns.append(crossed_feature)

模型构建及评估

feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
batch_size = 32model = tf.keras.Sequential([feature_layer,layers.Dense(128, activation='relu'),layers.Dense(128, activation='relu'),layers.Dense(1, activation='sigmoid')
])model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'],run_eagerly=True)history = model.fit(dataset_train,validation_data=dataset_val,epochs=5)

model.summary()%matplotlib inline
%config InlineBackend.figure_format = 'svg'import matplotlib.pyplot as pltdef plot_metric(history, metric):train_metrics = history.history[metric]val_metrics = history.history['val_'+metric]epochs = range(1, len(train_metrics) + 1)plt.plot(epochs, train_metrics, 'bo--')plt.plot(epochs, val_metrics, 'ro-')plt.title('Training and validation '+ metric)plt.xlabel("Epochs")plt.ylabel(metric)plt.legend(["train_"+metric, 'val_'+metric])plt.show()plot_metric(history,"accuracy")

TF2.0使用tf.data处理数据建模Demo相关推荐

  1. TensorFlow tf.data 导入数据(tf.data官方教程) * * * * *

    原文链接:https://blog.csdn.net/u014061630/article/details/80728694 TensorFlow版本:1.10.0 > Guide > I ...

  2. TensorFlow 2.0 - tf.data.Dataset 数据预处理 猫狗分类

    文章目录 1 tf.data.Dataset.from_tensor_slices() 数据集建立 2. Dataset.map(f) 数据集预处理 3. Dataset.prefetch() 并行处 ...

  3. openlayers map获取全部feature_tf2.0基础-tf.data与tf.feature_column

    7.2.1 tf.data 使用 tf.data API 可以轻松处理大量数据.不同的数据格式以及复杂的转换.tf.data API 在 TensorFlow 中引入了两个新的抽象类: tf.data ...

  4. TensorFlow :tf.data 高性能数据输入管道设计指南

    TensorFlow版本:1.12.0 本篇主要介绍怎么使用 tf.data API 来构建高性能的输入 pipeline. tf.data官方教程详见前面的博客<<<<< ...

  5. ERWin -- erwin Data Modeler 数据建模

    erwin 的全称是erwin Data Modeler,是erwin公司的数据建模工具.支持各主流数据库系统.erwin数据建模市场占有率第一的产品,市场占有率33%. erwin数据建模工具是业界 ...

  6. TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制

    TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...

  7. 数据建模-聚类分析-K-Means算法 --聚类可视化工具TSNE

    使用TSNE口可视化工具显示 数据建模-聚类分析-K-Means算法 #-*- coding: utf-8 -*-import sys reload(sys) sys.setdefaultencodi ...

  8. tf.data官方教程 - - 基于TF-v2

    这是本人关于tf.data的第二篇博文,第一篇基于TF-v1详细介绍了tf.data,但是v1和v2很多地方不兼容,所以替大家瞧瞧v2的tf.data模块有什么新奇之处. TensorFlow版本:2 ...

  9. 使用tf.data 加载文件夹下的图片集合并分类

    Tensorflow原始教程链接在官网: https://tensorflow.google.cn/tutorials/load_data/images 简化版: https://colab.rese ...

最新文章

  1. 企业应用程序部署在iOS 7.1上不起作用
  2. react编译器jsxTransformer,babel
  3. 回调函数和闭包的理解
  4. Oracle相关报错
  5. Eclipse调试方法
  6. SpringBoot与JPA
  7. fullpage.js(cndjs)
  8. Linux 内核进程uid,Linux内核学习笔记: uid之ruid,euid,suid
  9. 从编程语言进化史,看 Java、C、C++ 等语言的演变
  10. GitHub使用笔记
  11. Atitit diy战略 attilax总结
  12. intel网卡win10 修改mac
  13. 浅谈,盘点历史上有哪些著名的电脑病毒,80%的人都不知道!
  14. 无源滤波器和有源滤波器有什么区别?-道合顺大数据infinigo
  15. 解决systemback 无法生成超过4G的iso的问题
  16. 2021厦大计算机考研炸了,厦门大学2021年硕士研究生复试名单
  17. 华为鸿蒙推送机型,华为鸿蒙系统开始推送,这15款机型可率先升级,有你的吗?...
  18. 从12个球任取8个球
  19. 分享一下 各类学习网站
  20. 通过边界代理一路打到三层内网+后渗透通用手法

热门文章

  1. 新浪财经分析报告(0605)
  2. ETSI TR101 290监测的三种级别错误接收端现象
  3. Navisworks2014-2020 安装说明
  4. 【stm32】引脚高低电平、上拉输入与下拉输入
  5. 【论文阅读】基于单幅图像的快速去雾
  6. 【vscode】vscode 一键删除所有注释
  7. MySQL 算数表达式
  8. 常用软件:FTP客户端 ftprush
  9. 【MD5】什么是MD5?md5的简要描述
  10. cir模型matlab代码,CIR模型MATLAB程序