本文介绍了如何加载各种数据源,以生成可以用于tensorflow使用的数据集,一般指Dataset。主要包括以下几类数据源:

  • 预定义的公共数据源
  • 内存中的数据
  • csv文件
  • TFRecord
  • 任意格式的数据文件
  • 稀疏数据格式文件

更完整的数据加载方式请参考:https://www.tensorflow.org/tutorials/load_data/images?hl=zh-cn

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import os
print(tf.__version__)
2.5.0

1、预定义的公共数据源

为了方便使用,tensorflow将一些常用的数据源预先处理好,用户可以直接使用。完整内容请参考:

https://www.tensorflow.org/datasets/overview

tensorflow的数据集有2种类型:

  • 简单的数据集,使用keras.datasets.***.load_data()即可以得到数据
  • 在tensorflow_datasets中的数据集。

1.1 简单数据集

常见的有mnist,fashion_mnist等返回的是numpy.ndarray的数据格式。

(x_train_all,y_train_all),(x_test,y_test) = keras.datasets.fashion_mnist.load_data()
print(type(x_train_all))
x_train_all[5,1],y_train_all[5]
<class 'numpy.ndarray'>(array([  0,   0,   0,   1,   0,   0,  20, 131, 199, 206, 196, 202, 242,255, 255, 250, 222, 197, 206, 188, 126,  17,   0,   0,   0,   0,0,   0], dtype=uint8),2)
(x_train_all,y_train_all),(x_test,y_test) = keras.datasets.mnist.load_data()
x_train_all[14,14],y_train_all[14]
(array([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  29,255, 254, 109,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,0,   0], dtype=uint8),1)

1.2 tensorflow_datasets

tensorflow_datasets提供的数据集。

题外话,由于tensorflow dataset被墙,请自备梯子。若在服务器等无法fq的环境,可以先在其它机器下载好,数据一般会下载到~/tensorflow_datasets目录下,然后把目录下的数据集上传到服务器相同的目录即可。tensorflow会优先检查本地目录是否有文件,再去下载。

通过tfds.load()可以方便的加载数据集,返回值为tf.data.Dataset类型;如果with_info=True,则返回(Dataset,ds_info)组成的tuple。
完整内容可参考:
https://www.tensorflow.org/datasets/api_docs/python/tfds/load

1.3 flower数据集

import tensorflow_datasets as tfdsdataset, info = tfds.load("tf_flowers", as_supervised=True, with_info=True)
class_names = info.features["label"].names
n_classes = info.features["label"].num_classes
dataset_size = info.splits["train"].num_examplestest_set_raw, valid_set_raw, train_set_raw = tfds.load("tf_flowers",split=["train[:10%]", "train[10%:25%]", "train[25%:]"],as_supervised=True)# 画一些花朵看一下
plt.figure(figsize=(12, 10))
index = 0
for image, label in train_set_raw.take(9):index += 1plt.subplot(3, 3, index)plt.imshow(image)plt.title("Class: {}".format(class_names[label]))plt.axis("off")plt.show()

2、加载内存中的数据

本部分内容主要将内存中的数据(numpy)转换为Dataset。

from_tensor_slices()将numpy数组中的每一个元素都转化为tensorflow Dataset中的一个元素:

dataset = tf.data.Dataset.from_tensor_slices(np.arange(10))
print(dataset)
for item in dataset:print(item)
<TensorSliceDataset shapes: (), types: tf.int64>
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(2, shape=(), dtype=int64)
tf.Tensor(3, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
tf.Tensor(5, shape=(), dtype=int64)
tf.Tensor(6, shape=(), dtype=int64)
tf.Tensor(7, shape=(), dtype=int64)
tf.Tensor(8, shape=(), dtype=int64)
tf.Tensor(9, shape=(), dtype=int64)

我们可以对这个Dataset做各种的操作,比如:

dataset = dataset.repeat(3).batch(7)
for item in dataset:print(item)
tf.Tensor([0 1 2 3 4 5 6], shape=(7,), dtype=int64)
tf.Tensor([7 8 9 0 1 2 3], shape=(7,), dtype=int64)
tf.Tensor([4 5 6 7 8 9 0], shape=(7,), dtype=int64)
tf.Tensor([1 2 3 4 5 6 7], shape=(7,), dtype=int64)
tf.Tensor([8 9], shape=(2,), dtype=int64)

我们还可以将多个数组整合成一个Dataset,常见的比如feature和label组合成训练样本:

x = np.array([[1, 2], [3, 4], [5, 6]])
y = np.array(['cat', 'dog', 'fox'])
dataset3 = tf.data.Dataset.from_tensor_slices((x, y))
print(dataset3)for item_x, item_y in dataset3:print(item_x.numpy(), item_y.numpy())
<TensorSliceDataset shapes: ((2,), ()), types: (tf.int64, tf.string)>
[1 2] b'cat'
[3 4] b'dog'
[5 6] b'fox'

或者这样做:

dataset4 = tf.data.Dataset.from_tensor_slices({"feature": x,"label": y})
for item in dataset4:print(item["feature"].numpy(), item["label"].numpy())
[1 2] b'cat'
[3 4] b'dog'
[5 6] b'fox'

3、加载csv文件的数据

本部分介绍了tensorflow如何加载csv文件生成Dataset。除了本部分介绍的方法外,如果数据量不大,也可以使用pandas.read_csv加载到内存后,再使用上面介绍的from_tensor_slice()。

3.1 生成csv文件

由于我们没有现成的csv文件,所以我们使用预定义好的公共数据集生成csv文件:

from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler# 获取数据
housing = fetch_california_housing()
x_train_all, x_test, y_train_all, y_test = train_test_split(housing.data, housing.target, random_state = 7)
x_train, x_valid, y_train, y_valid = train_test_split(x_train_all, y_train_all, random_state = 11)
print(x_train.shape, y_train.shape)
print(x_valid.shape, y_valid.shape)
print(x_test.shape, y_test.shape)# 标准化
scaler = StandardScaler()
x_train_scaled = scaler.fit_transform(x_train)
x_valid_scaled = scaler.transform(x_valid)
x_test_scaled = scaler.transform(x_test)# 写入csv文件
output_dir = "generate_csv"
if not os.path.exists(output_dir):os.mkdir(output_dir)def save_to_csv(output_dir, data, name_prefix,header=None, n_parts=10):path_format = os.path.join(output_dir, "{}_{:02d}.csv")filenames = []for file_idx, row_indices in enumerate(np.array_split(np.arange(len(data)), n_parts)):part_csv = path_format.format(name_prefix, file_idx)filenames.append(part_csv)with open(part_csv, "wt", encoding="utf-8") as f:if header is not None:f.write(header + "\n")for row_index in row_indices:f.write(",".join([repr(col) for col in data[row_index]]))f.write('\n')return filenamestrain_data = np.c_[x_train_scaled, y_train]
valid_data = np.c_[x_valid_scaled, y_valid]
test_data = np.c_[x_test_scaled, y_test]
header_cols = housing.feature_names + ["MidianHouseValue"]
header_str = ",".join(header_cols)train_filenames = save_to_csv(output_dir, train_data, "train",header_str, n_parts=20)
valid_filenames = save_to_csv(output_dir, valid_data, "valid",header_str, n_parts=10)
test_filenames = save_to_csv(output_dir, test_data, "test",header_str, n_parts=10)# 看一下生成的文件:
import pprint
print("train filenames:")
pprint.pprint(train_filenames)
print("valid filenames:")
pprint.pprint(valid_filenames)
print("test filenames:")
pprint.pprint(test_filenames)
(11610, 8) (11610,)
(3870, 8) (3870,)
(5160, 8) (5160,)
train filenames:
['generate_csv/train_00.csv','generate_csv/train_01.csv','generate_csv/train_02.csv','generate_csv/train_03.csv','generate_csv/train_04.csv','generate_csv/train_05.csv','generate_csv/train_06.csv','generate_csv/train_07.csv','generate_csv/train_08.csv','generate_csv/train_09.csv','generate_csv/train_10.csv','generate_csv/train_11.csv','generate_csv/train_12.csv','generate_csv/train_13.csv','generate_csv/train_14.csv','generate_csv/train_15.csv','generate_csv/train_16.csv','generate_csv/train_17.csv','generate_csv/train_18.csv','generate_csv/train_19.csv']
valid filenames:
['generate_csv/valid_00.csv','generate_csv/valid_01.csv','generate_csv/valid_02.csv','generate_csv/valid_03.csv','generate_csv/valid_04.csv','generate_csv/valid_05.csv','generate_csv/valid_06.csv','generate_csv/valid_07.csv','generate_csv/valid_08.csv','generate_csv/valid_09.csv']
test filenames:
['generate_csv/test_00.csv','generate_csv/test_01.csv','generate_csv/test_02.csv','generate_csv/test_03.csv','generate_csv/test_04.csv','generate_csv/test_05.csv','generate_csv/test_06.csv','generate_csv/test_07.csv','generate_csv/test_08.csv','generate_csv/test_09.csv']

3.2 加载csv的文件内的数据

# 1. filename -> dataset
# 2. read file -> dataset -> datasets -> merge
# 3. parse csv
def csv_reader_dataset(filenames, n_readers=5,batch_size=32, n_parse_threads=5,shuffle_buffer_size=10000):dataset = tf.data.Dataset.list_files(filenames)dataset = dataset.repeat()dataset = dataset.interleave(lambda filename: tf.data.TextLineDataset(filename).skip(1),cycle_length = n_readers)dataset.shuffle(shuffle_buffer_size)dataset = dataset.map(parse_csv_line,num_parallel_calls=n_parse_threads)dataset = dataset.batch(batch_size)return datasetdef parse_csv_line(line, n_fields = 9):defs = [tf.constant(np.nan)] * n_fieldsparsed_fields = tf.io.decode_csv(line, record_defaults=defs)x = tf.stack(parsed_fields[0:-1])y = tf.stack(parsed_fields[-1:])return x, ytrain_set = csv_reader_dataset(train_filenames, batch_size=3)
for x_batch, y_batch in train_set.take(2):print("x:")pprint.pprint(x_batch)print("y:")pprint.pprint(y_batch)
x:
<tf.Tensor: shape=(3, 8), dtype=float32, numpy=
array([[-0.32652634,  0.4323619 , -0.09345459, -0.08402992,  0.8460036 ,-0.02663165, -0.56176794,  0.1422876 ],[ 0.48530516, -0.8492419 , -0.06530126, -0.02337966,  1.4974351 ,-0.07790658, -0.90236324,  0.78145146],[-1.0591781 ,  1.3935647 , -0.02633197, -0.1100676 , -0.6138199 ,-0.09695935,  0.3247131 , -0.03747724]], dtype=float32)>
y:
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[2.431],[2.956],[0.672]], dtype=float32)>
x:
<tf.Tensor: shape=(3, 8), dtype=float32, numpy=
array([[ 8.0154431e-01,  2.7216142e-01, -1.1624393e-01, -2.0231152e-01,-5.4305160e-01, -2.1039616e-02, -5.8976209e-01, -8.2418457e-02],[ 4.9710345e-02, -8.4924191e-01, -6.2146995e-02,  1.7878747e-01,-8.0253541e-01,  5.0660671e-04,  6.4664572e-01, -1.1060793e+00],[ 2.2754266e+00, -1.2497431e+00,  1.0294788e+00, -1.7124432e-01,-4.5413753e-01,  1.0527152e-01, -9.0236324e-01,  9.0129471e-01]],dtype=float32)>
y:
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[3.226],[2.286],[3.798]], dtype=float32)>
batch_size = 32
train_set = csv_reader_dataset(train_filenames,batch_size = batch_size)
valid_set = csv_reader_dataset(valid_filenames,batch_size = batch_size)
test_set = csv_reader_dataset(test_filenames,batch_size = batch_size)

3.3 训练模型

model = keras.models.Sequential([keras.layers.Dense(30, activation='relu',input_shape=[8]),keras.layers.Dense(1),
])
model.compile(loss="mean_squared_error", optimizer="sgd")
callbacks = [keras.callbacks.EarlyStopping(patience=5, min_delta=1e-2)]history = model.fit(train_set,validation_data = valid_set,steps_per_epoch = 11160 // batch_size,validation_steps = 3870 // batch_size,epochs = 10,callbacks = callbacks)
Epoch 1/100
348/348 [==============================] - 1s 3ms/step - loss: 1.5927 - val_loss: 2.1706
Epoch 2/100
348/348 [==============================] - 1s 2ms/step - loss: 0.7043 - val_loss: 0.5049
Epoch 3/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4733 - val_loss: 0.4638
Epoch 4/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4384 - val_loss: 0.4345
Epoch 5/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4070 - val_loss: 0.4233
Epoch 6/100
348/348 [==============================] - 1s 4ms/step - loss: 0.4066 - val_loss: 0.4139
Epoch 7/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4051 - val_loss: 0.4155
Epoch 8/100
348/348 [==============================] - 1s 4ms/step - loss: 0.3824 - val_loss: 0.3957
Epoch 9/100
348/348 [==============================] - 1s 3ms/step - loss: 0.3956 - val_loss: 0.3884
Epoch 10/100
348/348 [==============================] - 1s 3ms/step - loss: 0.3814 - val_loss: 0.3856
Epoch 11/100
348/348 [==============================] - 1s 2ms/step - loss: 0.4826 - val_loss: 0.3887
Epoch 12/100
348/348 [==============================] - 1s 3ms/step - loss: 0.3653 - val_loss: 0.3853
Epoch 13/100
348/348 [==============================] - 1s 3ms/step - loss: 0.3765 - val_loss: 0.3810
Epoch 14/100
348/348 [==============================] - 1s 4ms/step - loss: 0.3632 - val_loss: 0.3775
Epoch 15/100
348/348 [==============================] - 1s 4ms/step - loss: 0.3654 - val_loss: 0.3758
model.evaluate(test_set, steps = 5160 // batch_size)
161/161 [==============================] - 1s 2ms/step - loss: 0.38110.38114801049232483

tensorflow系列之1:加载数据相关推荐

  1. 7.3 TensorFlow笔记(基础篇):加载数据之从队列中读取

    前言 整体步骤 在TensorFlow中进行模型训练时,在官网给出的三种读取方式,中最好的文件读取方式就是将利用队列进行文件读取,而且步骤有两步: 1. 把样本数据写入TFRecords二进制文件 2 ...

  2. 7.1 TensorFlow笔记(基础篇):加载数据之预加载数据与填充数据

    TensorFlow加载数据 TensorFlow官方共给出三种加载数据的方式: 1. 预加载数据 2. 填充数据 预加载数据的缺点: 将数据直接嵌在数据流图中,当训练数据较大时,很消耗内存.填充的方 ...

  3. 在TensorFlow中使用pipeline加载数据

    正文共2028个字,6张图,预计阅读时间6分钟. 前面对TensorFlow的多线程做了测试,接下来就利用多线程和Queue pipeline地加载数据.数据流如下图所示: 首先,A.B.C三个文件通 ...

  4. 【TensorFlow-windows】keras接口——利用tensorflow的方法加载数据

    前言 之前使用tensorflow和keras的时候,都各自有一套数据读取方法,但是遇到一个问题就是,在训练的时候,GPU的利用率忽高忽低,极大可能是由于训练过程中读取每个batch数据造成的,所以又 ...

  5. 用爬虫抓取动态加载数据丨Python爬虫实战系列(6)

    提示:最新Python爬虫资料/代码练习>>戳我直达 前言 抓取动态加载数据 话不多说,开练! 爬虫抓取动态加载数据 确定网站类型 首先要明确网站的类型,即是动态还是静态.检查方法:右键查 ...

  6. TensorFlow加载数据的方式

    tensorflow作为符号编程框架,需要先构建数据流图,再读取数据,然后再进行训练.tensorflow提供了以下三种方式来加载数据: 预加载数据(preloaded data):在tensorfl ...

  7. Android官方开发文档Training系列课程中文版:后台加载数据之处理CursorLoader的查询结果

    原文地址:http://android.xsoftlab.net/training/load-data-background/handle-results.html 就像上节课所说的,我们应该在onC ...

  8. skyline系列10 - Skyline TerraExplorer 加载数据使用方法 (客户版)

    Skyline  TerraExplorer 加载数据使用方法 方法1:直接打开已经加载了地形数据的pmt文件,找到影像位置即可 如下图: 加载以后可能会因为影像片太小看不清,需要自己找到并放大查看, ...

  9. echarts在.Net中使用实例(二) 使用ajax动态加载数据

    通过上一篇文章可以知道和echarts参考手册可知,series字段就是用来存储我们显示的数据,所以我们只需要用ajax来获取series的值就可以. option 名称 描述 {color}back ...

  10. spring 启动加载数据_12个很棒的Spring数据教程来启动您的数据项目

    spring 启动加载数据 Spring Data的任务是为数据访问提供一个熟悉且一致的,基于Spring的编程模型,同时仍保留基础数据存储的特​​殊特征. 它使使用数据访问技术,关系和非关系数据库, ...

最新文章

  1. java中记忆深刻的问题_工作中碰到比较印象深刻的问题(面试必问)
  2. 自学使用sort他命令使用
  3. UIScrollView 的代理方法简单注解
  4. scjp考试准备 - 4 - 关于数组
  5. 2021-01-07 matlab数值分析 线性方程组的迭代解法 高斯-赛德尔迭代法
  6. QToolButton设置图标位置
  7. μC/OS-Ⅱ的移植
  8. 使用 Visual Studio 2019 批量添加代码文件头
  9. pat 乙级 1011 A+B 和 C(C++)
  10. golang日志输出
  11. 这可能是最好的RxJava 2.x 入门教程学习系列
  12. 解决mysql地区时间错误_mysql time zone时区的错误解决
  13. python selenium 等待js加载完成_一个用python完成的RSA成功模拟JS加密完成自动登录...
  14. xps文件的查看及转换
  15. MySQL开发者需要了解的12个技巧与窍门
  16. 你怎么保存微博中喜欢的视频
  17. 金蝶迷你版凭证导入工具_金蝶kis迷你版如何插入凭证?
  18. 手把手教你如何制作iPhone卡贴(多图)
  19. linux网络编程面试题
  20. python自己的手稿四之互动沟通

热门文章

  1. 搜索引擎链接算法之:HITS算法解析
  2. gsu 2524 Frozen Rose-Heads
  3. python参数是什么_最全Python快速入门教程,满满都是干货
  4. response.setHeader各种用法 .
  5. Android 通过字符串来获取R下面资源的ID 值 文字资源
  6. HTML 取消超链接下划线
  7. 如何判断Socket连接失效
  8. xml 和android脚本之家,AndroidManifest.xml配置文件解析_Android_脚本之家
  9. jest java_✅使用jest进行测试驱动开发
  10. java list数组排序_浅谈对象数组或list排序及Collections排序原理