文章目录

  • 1 tf.data.Dataset.from_tensor_slices() 数据集建立
  • 2. Dataset.map(f) 数据集预处理
  • 3. Dataset.prefetch() 并行处理
  • 4. for 循环获取数据
  • 5. 例子: 猫狗分类

学习于:简单粗暴 TensorFlow 2

1 tf.data.Dataset.from_tensor_slices() 数据集建立

tf.data.Dataset.from_tensor_slices()

import matplotlib.pyplot as plt
(train_data, train_label), (_, _) = tf.keras.datasets.mnist.load_data()
train_data = np.expand_dims(train_data.astype(np.float32)/255., axis=-1)mnistdata = tf.data.Dataset.from_tensor_slices((train_data, train_label))for img, label in mnistdata:plt.title(label.numpy())plt.imshow(img.numpy())plt.show()

2. Dataset.map(f) 数据集预处理

  • Dataset.map(f) 应用变换
def rotate90(img, label):img = tf.image.rot90(img)return img, labelmnistdata = mnistdata.map(rotate90)

  • Dataset.batch(batch_size) 分批
mnistdata = mnistdata.batch(5)for img, label in mnistdata:  # img [5,28,28,1],label [5] 包含 5个样本 fig, axs = plt.subplots(1, 5)  # 1 行 5列for i in range(5):axs[i].set_title(label.numpy()[i])axs[i].imshow(img.numpy()[i, :, :, :])plt.show()

  • Dataset.shuffle(buffer_size) 随机打乱
    buffer_size = 1,没有打乱的效果
    数据集较随机,buffer_size 可小一些,否则,设置大一些
    我在做猫狗分类例子的时候,遇到内存不足的报错,建议可以提前打乱数据
# 每次得到的数字不太一样
mnistdata = mnistdata.shuffle(buffer_size=100).batch(5)

3. Dataset.prefetch() 并行处理

  • Dataset.prefetch() 开启预加载数据,使得在 GPU 训练的同时 CPU 可以准备数据
mnistdata = mnistdata.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
# 可设置自动寻找 合适的 buffer_size
  • num_parallel_calls 多核心并行处理
mnistdata = mnistdata.map(map_func=rotate90,num_parallel_calls=2)
# 也可以自动找参数 tf.data.experimental.AUTOTUNE

4. for 循环获取数据

# for 循环
dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
for a, b, c, ... in dataset:# 对张量a, b, c等进行操作,例如送入模型进行训练# 或者 创建迭代器,使用 next() 获取
dataset = tf.data.Dataset.from_tensor_slices((A, B, C, ...))
it = iter(dataset)
a_0, b_0, c_0, ... = next(it)
a_1, b_1, c_1, ... = next(it)

5. 例子: 猫狗分类

项目及数据地址:https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition/overview

The train folder contains 25,000 images of dogs and cats. Each image in this folder has the label as part of the filename. The test folder contains 12,500 images, named according to a numeric id.

For each image in the test set, you should predict a probability that the image is a dog (1 = dog, 0 = cat).

# ---------cat vs dog-------------
# https://michael.blog.csdn.net/
import tensorflow as tf
import pandas as pd
import numpy as np
import random
import osnum_epochs = 10
batch_size = 32
learning_rate = 1e-4
train_data_dir = "./dogs-vs-cats/train/"
test_data_dir = "./dogs-vs-cats/test/"# 数据处理
def _decode_and_resize(filename, label=None):img_string = tf.io.read_file(filename)img_decoded = tf.image.decode_jpeg(img_string)img_resized = tf.image.resize(img_decoded, [256, 256]) / 255.if label == None:return img_resizedreturn img_resized, label# 使用 tf.data.Dataset 生成数据
def processData(train_filenames, train_labels):train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))train_dataset = train_dataset.map(map_func=_decode_and_resize)# train_dataset = train_dataset.shuffle(buffer_size=25000) # 非常耗内存,不使用train_dataset = train_dataset.batch(batch_size)train_dataset = train_dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)return train_datasetif __name__ == "__main__":# 训练文件路径file_dir = [train_data_dir + filename for filename in os.listdir(train_data_dir)]labels = [0 if filename[0] == 'c' else 1for filename in os.listdir(train_data_dir)]# 打包并打乱f_l = list(zip(file_dir, labels))random.shuffle(f_l)file_dir, labels = zip(*f_l)# 切分训练集,验证集valid_ratio = 0.1idx = int((1 - valid_ratio) * len(file_dir))train_files, valid_files = file_dir[:idx], file_dir[idx:]train_labels, valid_labels = labels[:idx], labels[idx:]# 使用 tf.data.Dataset 生成数据集train_filenames, valid_filenames = tf.constant(train_files), tf.constant(valid_files)train_labels, valid_labels = tf.constant(train_labels), tf.constant(valid_labels)train_dataset = processData(train_filenames, train_labels)valid_dataset = processData(valid_filenames, valid_labels)# 建模 model = tf.keras.Sequential([tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(256, 256, 3)),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Dropout(0.2),tf.keras.layers.Conv2D(64, 5, activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Dropout(0.2),tf.keras.layers.Conv2D(128, 5, activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Dropout(0.2),tf.keras.layers.Flatten(),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(2, activation='softmax')])# 模型配置model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),loss=tf.keras.losses.sparse_categorical_crossentropy,metrics=[tf.keras.metrics.sparse_categorical_accuracy])# 训练model.fit(train_dataset, epochs=num_epochs, validation_data=valid_dataset)# 测试 testtest_filenames = tf.constant([test_data_dir + filename for filename in os.listdir(test_data_dir)])test_data = tf.data.Dataset.from_tensor_slices(test_filenames)test_data = test_data.map(map_func=_decode_and_resize)test_data = test_data.batch(batch_size)ans = model.predict(test_data) # ans [12500, 2]prob = ans[:, 1] # dog 的概率# 写入提交文件id = list(range(1, 12501))output = pd.DataFrame({'id': id, 'label': prob})output.to_csv("submission.csv", index=False)

提交成绩:

榜首他人成绩:

  • 把模型改成 MobileNetV2 + FC,训练 2 个 epochs
basemodel = tf.keras.applications.MobileNetV2(input_shape=(256,256,3), include_top=False, classes=2)
model = tf.keras.Sequential([basemodel,tf.keras.layers.Flatten(),tf.keras.layers.Dense(64, activation='relu'),tf.keras.layers.Dense(2, activation='softmax')
])

结果:

704/704 [==============================] - 179s 254ms/step
- loss: 0.0741 - sparse_categorical_accuracy: 0.9737
- val_loss: 0.1609 - val_sparse_categorical_accuracy: 0.9744
704/704 [==============================] - 167s 237ms/step
- loss: 0.0128 - sparse_categorical_accuracy: 0.9955
- val_loss: 0.0724 - val_sparse_categorical_accuracy: 0.9848

准确率(99%, 98%)比上面第一种模型高(第一种模型大概是训练集 92%, 验证集80%)

测试时,损失值竟然比上面的大,怎么解释?貌似第二种方案也没有过拟合吧,训练集和验证集准确率差不多。

TensorFlow 2.0 - tf.data.Dataset 数据预处理 猫狗分类相关推荐

  1. 记录 之 tensorflow函数:tf.data.Dataset.from_tensor_slices

    tf.data.Dataset.from_tensor_slices(),是常见的数据处理函数,它的作用是将给定的元组(turple).列表(list).张量(tensor)等特征进行特征切片.切片的 ...

  2. tensorflow基础:tf.data.Dataset.from_tensor_slices()

    tf.data.Dataset.from_tensor_slices() 语义解释:from_tensor_slices,从张量的切片读取数据. 工作原理:将输入的张量的第一个维度看做样本的个数,沿其 ...

  3. tensorflow基础:tf.data.Dataset.from_tensor_slices() 与 tf.data.Dataset.from_generator()的异同

    tf.data.Dataset.from_tensor_slices(tensor): -->将tensor沿其第一个维度切片,返回一个含有N个样本的数据集(假设tensor的第一个维度为N). ...

  4. TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和 tf.data.Dataset机制

    TensorFlow数据读取机制:文件队列 tf.train.slice_input_producer和tf.data.Dataset机制 之前写了一篇博客,关于<Tensorflow生成自己的 ...

  5. tensorflow的数据读取 tf.data.DataSet、tf.data.Iterator

    tensorflow的工程有使用python的多进程读取数据,然后给feed给神经网络进行训练. 也有tensorflow中的 tf.data.DataSet的使用.并且由于是tensorflow框架 ...

  6. tf.data.Dataset 用法

    tf.data.DatasetAPI支持写入的描述性和高效的输入管线.Dataset用法遵循一个常见模式: 从输入数据创建源数据集. 应用数据集转换来预处理数据. 迭代数据集并处理元素. 迭代以流式方 ...

  7. tf.data.Dataset.from_tensor_slices() 详解

    函数原型: tf.data.Dataset.from_tensor_slices(tensors, name=None ) 官网地址:https://www.tensorflow.org/api_do ...

  8. tf.data.Dataset.interleave

    本文对tf.data.Dataset.interleave该方法做点笔记. 在tensorflow中数据处理通道中,有一个方法interleave,tf定义如下: interleave(map_fun ...

  9. 【Tensorflow 2.0 正式版教程】tf.data.Dataset的基本使用方法

    Tensorflow 2.0中提供了专门用于数据输入的接口tf.data.Dataset,可以简洁高效的实现数据的读入.打乱(shuffle).增强(augment)等功能.下面以一个简单的实例讲解该 ...

最新文章

  1. 基于视角特征提取的3D检测方法汇总
  2. 魅族升级鸿蒙,魅族要“弯道超车”?率先升级鸿蒙OS,挥别安卓系统
  3. 神策数据实战学堂开课,分享行业最佳业务和技术实践
  4. VNC怎么和宿主机共享粘贴板(整理)
  5. 新装的主机没有ifconfig,route等命令,怎么查找对应的安装包
  6. JavaScript多继承(转载)
  7. 更改hadoop集群yarn的webui中的开始时间和结束时间为本地时间
  8. 畅通工程(并查集模版题)
  9. Web 2.0下一个版本是什么 3.0就要到来了吗?
  10. SQL Server 自增字段归零等问题
  11. JavaScript JSON.stringify()
  12. scrapy数据存储在mysql数据库的两种方式
  13. FreeRTOS 教程指南 学习笔记 第五章 软件计时器
  14. wps建立的文件后缀名为docx,写在里面的东西还不允许保存
  15. 阿里天池—2022江苏气象预测AI算法挑战赛
  16. 汇编语言里 eax ebx ecx edx esi edi ebp esp这些都是什么意思啊
  17. 系列 HTML+JS GAME制作 之 移动消灭-方块
  18. 【STM32 x ESP8266】连接阿里云 MQTT 服务器(报文连接)
  19. tekton入门 - tasks
  20. 由于代理原因,联网失败的解决方法

热门文章

  1. cookie和session之会话机制:   http 协议  ---》 无状态协议
  2. 回归分析什么时候取对数_冬蜜什么时候取,冬天取蜂蜜的方法
  3. html5carousel图片轮播,jQuery响应式轮播图插件VM Carousel
  4. java中非法运算符_Java 中的运算符和流程控制相关内容的理解
  5. 十字连接焊盘_你应该知道的焊盘基础知识
  6. 二进制、十进制、其他进制之间的转换
  7. wangeditor html编辑,Vue整合wangEditor富文本编辑器
  8. Java连载1-概述常用的dos命令
  9. WPF 开源项目 【watcher】 守望者,一款监控,统计,分析你每天在自己电脑上究竟干了什么的软件...
  10. 设置拖拽事件,获取拖拽内容