文章目录

  • 1.Dataset类相关操作
  • 2.如何提升Dataset的读取性能
    • 1)prefetch方法
    • 2)interleave方法
    • 3)map多进程执行
    • 4)cache方法-小数据集
    • 5)性能优化代码演示
  • 3.案例讲解-猫狗图片分类

1.Dataset类相关操作





flat_map()

interleave使用




2.如何提升Dataset的读取性能


1)prefetch方法




2)interleave方法



3)map多进程执行



4)cache方法-小数据集


5)性能优化代码演示

import tensorflow as tf
import time
import os
print(tf.__version__)
data_dir = './datasets'
train_cats_dir = data_dir + '/train/cats/'
train_dogs_dir = data_dir + '/train/dogs/'
test_cats_dir = data_dir + '/valid/cats/'
test_dogs_dir = data_dir + '/valid/dogs/'# 构建训练数据集
train_cat_filenames = tf.constant([train_cats_dir + filename for filename in os.listdir(train_cats_dir)][:1000])
train_dog_filenames = tf.constant([train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)][:1000])
train_filenames = tf.concat([train_cat_filenames, train_dog_filenames], axis=-1)
train_labels = tf.concat([tf.zeros(train_cat_filenames.shape, dtype=tf.int32), tf.ones(train_dog_filenames.shape, dtype=tf.int32)], axis=-1)#构建训练集def _decode_and_resize(filename, label):image_string = tf.io.read_file(filename)            # 读取原始文件image_decoded = tf.image.decode_jpeg(image_string)  # 解码JPEG图片image_resized = tf.image.resize(image_decoded, [256, 256]) / 255.0return image_resized, labelbatch_size = 32
train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames, train_labels))def benchmark(dataset, num_epochs=1):start_time = time.perf_counter()for epoch_num in range(num_epochs):for sample in dataset:# Performing a training steptime.sleep(0.01)    # 用睡眠代替训练tf.print("Execution time:", time.perf_counter() - start_time)



3.案例讲解-猫狗图片分类

任务目标:Cats vs Dogs(猫狗大战)是Kaggle大数据竞赛的赛题,利用给定的数据集,用算法实现猫和狗的识别
图像分类问题

import tensorflow as tf
import os# 读取文件
data_dir = 'datasets'train_cats_dir = data_dir + '/train/cats/'
train_dogs_dir = data_dir + '/train/dogs/'
test_cats_dir = data_dir + '/valid/cats/'
test_dogs_dir = data_dir + '/valid/dogs/'# 构建训练数据集
# tf.constant创建常量
train_cat_filenames = tf.constant([train_cats_dir + filename for filename in os.listdir(train_cats_dir)])
train_dog_filenames = tf.constant([train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)])
# tf.concat用来拼接张量
train_filenames = tf.concat([train_cat_filenames,train_dog_filenames],axis=-1)# 设置label-猫=0,狗=1
train_labels = tf.concat([tf.zeros(train_cat_filenames.shape,dtype=tf.int32),tf.ones(train_dog_filenames.shape,dtype=tf.int32)],axis=-1)# 构建训练集def _decode_and_resize(filename,label):image_string = tf.io.read_file(filename)      # 读取原始文件image_decoded = tf.image.decode_jpeg(image_string)    # 解码JPEG图片image_resized = tf.image.resize(image_decoded,[256,256]) / 255.0  # 归一化return image_resized,labelbatch_size = 32
train_dataset = tf.data.Dataset.from_tensor_slices((train_filenames,train_labels))# map多进程执行
train_dataset = train_dataset.map(map_func = _decode_and_resize,num_parallel_calls = tf.data.experimental.AUTOTUNE)# 取出前buffer_size个数据放入buffer,并从中随机采样,采样后的数据用后续数据替换
train_dataset = train_dataset.shuffle(buffer_size=23000)     # 缓冲区train_dataset = train_dataset.repeat(count=3)                # 重复三次train_dataset = train_dataset.batch(batch_size)              train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)    # 优化# 构建测试数据集
test_cat_filenames = tf.constant([test_cats_dir + filename for filename in os.listdir(test_cats_dir)])
test_dog_filenames = tf.constant([test_dogs_dir + filename for filename in os.listdir(test_dogs_dir)])
test_filenames = tf.concat([test_cat_filenames, test_dog_filenames], axis=-1)
test_labels = tf.concat([tf.zeros(test_cat_filenames.shape, dtype=tf.int32), tf.ones(test_dog_filenames.shape, dtype=tf.int32)], axis=-1)test_dataset = tf.data.Dataset.from_tensor_slices((test_filenames, test_labels))
test_dataset = test_dataset.map(_decode_and_resize)
test_dataset = test_dataset.batch(batch_size)class CNNModel(tf.keras.models.Model):def __init__(self):super(CNNModel,self).__init__()self.conv1 = tf.keras.layers.Conv2D(32,3,activation='relu')self.maxpool1 = tf.keras.layers.MaxPooling2D()self.conv2 = tf.keras.layers.Conv2D(32,5,activation='relu')self.maxpool2 = tf.keras.layers.MaxPooling2D()self.flatten = tf.keras.layers.Flatten()self.d1 = tf.keras.layers.Dense(64,activation='relu')self.d2 = tf.keras.layers.Dense(2,activation='softmax')   # sigmoid 与 softmax的区别
# softmax      CategoricalCrossentropy
# sigmoid      BinaryCrossentropydef call(self,x):x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.flatten(x)x = self.d1(x)x = self.d2(x)return xlearning_rate = 0.001
model = CNNModel()loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)train_loss = tf.keras.metrics.Mean(name='train_loss')     # 平均
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')# 每个batch后调用这些函数
@tf.function
def train_step(images,labels):with tf.GradientTape() as tape:predictions = model(images)loss = loss_object(labels,predictions)# 对可训练参数进行求导gradients = tape.gradient(loss,model.trainable_variables)# 将求导后的可训练参数进行更新optimizer.apply_gradients(zip(gradients,model.trainable_variables))train_loss.update(loss)train_accuracy(labels,predictions)def test_step(images,labels):predictions = model(images)t_loss =loss_object(labels,predictions)test_loss(t_loss)test_accuracy(labels,predictions)EPOCHS = 10
for epoch in range(EPOCHS):# 在下一个epoch开始时,重置评估指标train_loss.reset_states()train_accuracy.reset_states()test_loss.reset_states()test_accuracy.reset_states()for images,labels in train_dataset:train_step(images,labels)for test_images,test_labels in test_dataset:test_step(test_images,test_labels)template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'print(template.format(epoch + 1,train_loss.result(),train_accuracy.result() * 100,test_loss.result(),test_accuracy.result() * 100))

运行结果:

深度学习11-tf.data详解以及猫狗图片分类实战相关推荐

  1. 深度学习网络模型——RepVGG网络详解、RepVGG网络训练花分类数据集整体项目实现

    深度学习网络模型--RepVGG网络详解.RepVGG网络训练花分类数据集整体项目实现 0 前言 1 RepVGG Block详解 2 结构重参数化 2.1 融合Conv2d和BN 2.2 Conv2 ...

  2. 机器学习,深度学习基础算法原理详解(图的搜索、交叉验证、PAC框架、VC-维(持续更新))

    机器学习,深度学习基础算法原理详解(图的搜索.交叉验证.PAC框架.VC-维.支持向量机.核方法(持续更新)) 机器学习,深度学习基础算法原理详解(数据结构部分(持续更新)) 文章目录 1. 图的搜索 ...

  3. 深度学习网络模型——Vision Transformer详解 VIT详解

    深度学习网络模型--Vision Transformer详解 VIT详解 通用深度学习网络效果改进调参训练公司自己的数据集,训练步骤记录: 代码实现version-Transformer网络各个流程, ...

  4. 深度学习之自编码器(2)Fashion MNIST图片重建实战

    深度学习之自编码器(2)Fashion MNIST图片重建实战 1. Fashion MNIST数据集 2. 编码器 3. 解码器 4. 自编码器 5. 网络训练 6. 图片重建 完整代码  自编码器 ...

  5. 记录|深度学习100例-卷积神经网络(CNN)彩色图片分类 | 第2天

    记录|深度学习100例-卷积神经网络(CNN)彩色图片分类 | 第2天 1. 彩色图片分类效果图 数据集如下: 测试图1如下 训练/验证精确图如下: 优化后:测试图--打印预测标签: 优化后:测试图- ...

  6. Keras深度学习使用VGG16预训练神经网络实现猫狗分类

    Keras深度学习使用VGG16预训练神经网络实现猫狗分类 最近刚刚接触深度学习不久,而Keras呢,是在众多的深度学习框架中,最适合上手的,而猫狗的图像分类呢,也算是计算机视觉中的一个经典案例,下面 ...

  7. 体验AI乐趣:基于AI Gallery的二分类猫狗图片分类小数据集自动学习

    摘要:直接使用AI Gallery里面现有的数据集进行自动学习训练,很简单和方便,节约时间,不用自己去训练了,AI Gallery 里面有很多类似的有趣数据集,也非常好玩,大家一起试试吧. 本文分享自 ...

  8. 11.CNN实现真实猫狗图片分类

    CNN实现真实猫狗图片分类 个人认为,和上一节的mnist数据集里面的手写数字图片不同之处就是,真实的图片更加复杂,像素点更多.因此在对应的图片预处理方面会稍微麻烦一些.但是这个例子能让我们可以处理自 ...

  9. 从手写数字识别入门深度学习丨MNIST数据集详解

    就像无数人从敲下"Hello World"开始代码之旅一样,许多研究员从"MNIST数据集"开启了人工智能的探索之路. MNIST数据集(Mixed Natio ...

  10. 独家总结| 基于深度学习的目标检测详解

    欢迎关注微信公众号[计算机视觉联盟] 获取更多前沿AI.CV资讯                                                基于深度学习的目标检测 基于深度学习的目 ...

最新文章

  1. 成功解决 bs4\__init__.py:219: UserWarning: b'.' looks like a filename, not markup. You should probably
  2. visio 科学图形包_R可视乎|混合多个图形
  3. Yii2 主从 数据库
  4. javaone_JavaOne 2012:101种改进Java的方法-开发人员参与为何如此重要
  5. fread读取同一个文件得到缓冲区大小不同_缓冲区——计算机科学
  6. java synchronized atomic_atomic 包、synchronized | Java 中线程安全
  7. 3大细节让新站SEO更有竞争优势
  8. 一个开发周期为6个月的中小型软件开发项目成本预算大致表,不足之处请指点...
  9. Linux 性能调试 之 drop_caches
  10. 简单推箱子java_Java实现简单推箱子游戏
  11. 如何实现团队高效协作办公?
  12. 大学生软件设计大赛文档要求
  13. BZOJ2827 千山鸟飞绝 (离散+treap)
  14. 先定一个能达到的小决心,比方读个一本书 ——《小决心》读后感 @阿狸不歌
  15. php 通过sendcloud发送邮件附件功能
  16. Lombok链式调用,子类对象set父类属性,返回父类对象
  17. html像素测量尺,屏幕尺子:ruul.Screen ruler
  18. 尝试用python解概率题,并祝大小朋友儿童节快乐
  19. 获取当天,本周,本月,本季度,本半年,本年时间
  20. AP Memory IoT RAM——嵌入式物联网最佳存储解决方案

热门文章

  1. ByteBuffer分散和聚集的应用场景
  2. 精简ICO图标可减小EXE程序文件大小
  3. java day39【HTTP协议:响应消息 、Response对象 、ServletContext对象】
  4. ADO.Net之SqlConnection、 Sqlcommand的应用
  5. python对拍程序
  6. 在浏览器上运行Qt应用 emscripten-qt
  7. Linux操作系统中的IP配置
  8. C++编程语言中引用(reference)介绍
  9. python中静态方法、类方法、属性方法区别
  10. Android 怎么使用Bitmap+Canvas 自适应屏幕