概要

本文内容来源于TensorFlow教程
本文主要介绍了三种图片数据的加载和预处理方法:

  1. 使用高级的Keras预处理工具(如tf.keras.utils.image_dataset_from_directory)和预处理层(如tf.keras.layers.Rescaling)从磁盘的图片目录中加载数据。
  2. 使用tf.data的框架写你自己的输入通道。
  3. TensorFlow Datasets中从可用的类别加载数据集。

内容

import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
import tensorflow_datasets as tfds
import pathlib# 下载花的数据集
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url,fname='flower_photos',untar=True)
data_dir = pathlib.Path(data_dir)
os.listdir(data_dir)  # ['LICENSE.txt', 'tulips', 'roses', 'dandelion', 'daisy', 'sunflowers']
image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)  # 3670

数据集的目录格式:

flowers_photos/
          daisy/
          dandelion/
          roses/
          sunflowers/
          tulips/

# 画图展示
roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[0]))


使用tf.keras.utils.image_dataset_from_directory将图片数据集加载存入内存

batch_size = 32
img_height = 180
img_width = 180train_ds = tf.keras.utils.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.utils.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)# 可视化图片
import matplotlib.pyplot as pltplt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")# 查看训练集中数据的shape
for image_batch, labels_batch in train_ds:print(image_batch.shape)   # (32, 180, 180, 3)  32是batch size的大小,180 * 180是图片的维度,3是图片的通道数RGB格式print(labels_batch.shape)  # (32,)  batch_size=32break# RGB通道图像的像素值在[0,255],为了更好的模型训练,进行放缩到[0,1]。
normalization_layer = tf.keras.layers.Rescaling(1./255)# 也可以将其放缩到[-1,1]
# normalization_layer = tf.keras.layers.Rescaling(1./127.5, offset=-1)normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixel values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image))

这里注意,我们在使用tf.keras.utils.image_dataset_from_directory加载数据的时候使用image_size参数重新定义了图片的大小。这个步骤也可以定义在模型中,通过使用tf.keras.layers.Resizing
大数据集的情况数据加载有可能会成为模型训练的瓶颈,可以通过以下两种方法使用缓存的方式加载数据:

  1. Dataset.cache,数据集从磁盘上加载放入内存中,如果数据集太大内存放不下,则可以使用此方法创建一个性能磁盘缓存。
  2. Dataset.prefetch,训练时重叠数据预处理和模型执行。
AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

以上是使用tf.keras.utils.image_dataset_from_directory加载数据的方法,下面使用tf.data更好的控制数据输入,通过使用tf.data编写自己的数据输入通道。

list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'), shuffle=False)
list_ds = list_ds.shuffle(image_count, reshuffle_each_iteration=False)for f in list_ds.take(5):print(f.numpy())# 使用文件的树结构生成类别组数
class_names = np.array(sorted([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"]))
print(class_names)
# 划分训练集和验证集
val_size = int(image_count * 0.2)
train_ds = list_ds.skip(val_size)
val_ds = list_ds.take(val_size)print(tf.data.experimental.cardinality(train_ds).numpy())
print(tf.data.experimental.cardinality(val_ds).numpy())# 转换文件路径成(img, label)对
def get_label(file_path):# Convert the path to a list of path componentsparts = tf.strings.split(file_path, os.path.sep)# The second to last is the class-directoryone_hot = parts[-2] == class_names# Integer encode the labelreturn tf.argmax(one_hot)def decode_img(img):# Convert the compressed string to a 3D uint8 tensorimg = tf.io.decode_jpeg(img, channels=3)# Resize the image to the desired sizereturn tf.image.resize(img, [img_height, img_width])def process_path(file_path):label = get_label(file_path)# Load the raw data from the file as a stringimg = tf.io.read_file(file_path)img = decode_img(img)return img, label# 使用`Dataset.map`创建一个image,label对数据集
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
train_ds = train_ds.map(process_path, num_parallel_calls=AUTOTUNE)
val_ds = val_ds.map(process_path, num_parallel_calls=AUTOTUNE)for image, label in train_ds.take(1):print("Image shape: ", image.numpy().shape)print("Label: ", label.numpy())

为了性能配置数据集。

def configure_for_performance(ds):ds = ds.cache()  # 缓存ds = ds.shuffle(buffer_size=1000)  # 打乱数据ds = ds.batch(batch_size) # 批处理ds = ds.prefetch(buffer_size=AUTOTUNE) # 保证批量数据尽快可用return dstrain_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)# 可视化数据
image_batch, label_batch = next(iter(train_ds))plt.figure(figsize=(10, 10))
for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(image_batch[i].numpy().astype("uint8"))label = label_batch[i]plt.title(class_names[label])plt.axis("off")

使用TensorFlow数据集

(train_ds, val_ds, test_ds), metadata = tfds.load('tf_flowers',split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],with_info=True,as_supervised=True,
)num_classes = metadata.features['label'].num_classes
print(num_classes)get_label_name = metadata.features['label'].int2strimage, label = next(iter(train_ds))
_ = plt.imshow(image)
_ = plt.title(get_label_name(label))train_ds = configure_for_performance(train_ds)
val_ds = configure_for_performance(val_ds)
test_ds = configure_for_performance(test_ds)

为了完整性,我们构建了一个卷积网络训练模型。三个带最大池化的卷积层,一个全连接层

num_classes = 5model = tf.keras.Sequential([tf.keras.layers.Rescaling(1./255),tf.keras.layers.Conv2D(32, 3, activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(32, 3, activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(32, 3, activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dense(num_classes)
])model.compile(optimizer='adam',loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model.fit(train_ds,validation_data=val_ds,epochs=3
)

我们也可以自己写一个训练循环器替代model.fit,详情参考:从头编写训练循环

TensorFlow构建模型(图片数据加载)六相关推荐

  1. nuScenes自动驾驶数据集:数据格式精解,格式转换,模型的数据加载 (一)

    nuScenes数据集及nuScenes开发工具包简介 文章目录 nuScenes数据集及nuScenes开发工具包简介 1.1. nuScenes数据集简介: 1.2 数据采集: 1.2.1 传感器 ...

  2. nuScenes自动驾驶数据集:格式转换,模型的数据加载(二)

    文章目录 一.nuScenes数据集格式精解 二.nuScenes数据格式转换(To COCO) 数据格式转换框架 2.1 核心:convert_nuScenes.py解析 其他格式转换文件 2.1. ...

  3. TensorFlow2数据加载与数据集

    加载数据集 keras 加载在线数据集 tf.keras.datasets提供了加载在线数据集的API,其中可加载的数据集包括: boston_housing module: Boston housi ...

  4. PyTorch1.12 亮点一览 | DataPipe + TorchArrow 新的数据加载与处理范式

    目录 前言 现有的 Dataset 和 DataLoader 及其存在的问题 新的数据加载方式:DataPipe 与 DataLoader2 结构化数据处理新范式:TorchArrow 总结 参考链接 ...

  5. android平台gallery2应用分析,Android5.1图库Gallery2代码分析数据加载流程

    图片数据加载流程. Gallery---->GalleryActivity------>AlbumSetPage------->AlbumPage--------->Photo ...

  6. 作业1:关于使用python中scikit-learn(sklearn)模块,实现鸢尾花(iris)相关数据操作(数据加载、标准化处理、构建聚类模型并训练、可视化、评价模型)

    操作题:利用鸢尾花数据实现数据加载.标准化处理.构建聚类模型并训练.聚类效果可视化展示及对模型进行评价 一.数据加载 from sklearn.datasets import load_iris fr ...

  7. R语言广义加性模型(GAMs:Generalized Additive Model)建模:数据加载、划分数据、并分别构建线性回归模型和广义线性加性模型GAMs、并比较线性模型和GAMs模型的性能

    R语言广义加性模型(GAMs:Generalized Additive Model)建模:数据加载.划分数据.并分别构建线性回归模型和广义线性加性模型GAMs.并比较线性模型和GAMs模型的性能 目录

  8. 深度学习-Pytorch:项目标准流程【构建、保存、加载神经网络模型;数据集构建器Dataset、数据加载器DataLoader(线性回归案例、手写数字识别案例)】

    1.拿到文本,分词,清晰数据(去掉停用词语): 2.建立word2index.index2word表 3.准备好预训练好的word embedding 4.做好DataSet / Dataloader ...

  9. [tensorflow] 模型保存、加载与转换详解

    TensorFlow模型加载与转换详解 本次讲解主要涉及到TensorFlow框架训练时候模型文件的管理以及转换. 首先我们需要明确TensorFlow模型文件的存储格式以及文件个数: model_f ...

  10. 多输出模型实例的数据加载

    多输出模型实例的数据加载 相关的数据集放在C:/Users/Administrator/data/moc import tensorflow as tf from tensorflow import ...

最新文章

  1. Python使用QRCode模块生成二维码
  2. 使用python获取路径问题
  3. 申请贷款必须留联系人吗?不留行不行?
  4. 一款强大的 Kubernetes API 流量查看神器
  5. linux open dev/tty0 receive_buf,书写基于内核的linux键盘纪录器(p9-0e)(3)
  6. npm install全局安装的模块路径自定义管理
  7. 【转】小议Bug敏感度---Bug敏感度的故事(一)
  8. 了解RxJava以及如何在Android应用中使用它
  9. php该怎么下载文件,php怎么实现下载文件?
  10. 【图像处理】海森矩阵(Hessian Matrix)及一个用例(图像增强)
  11. 注音输入与拼音输入对照
  12. python将多个列表合并_Python中将两个或多个list合成一个list的方法小结
  13. Centos 7搭建PPTP服务器方法
  14. Python 爬取新浪网新闻和存取CSV文件
  15. 提问的智慧!高手如何成长为高手,高手原来也是像我一样的菜鸟!
  16. (一)Xray-的安装,入门的使用方法
  17. 绝对定位后的DIV水平居中
  18. mtk,展讯等手机平台知识杂烩
  19. kernel 加载用户空间fw实现原理
  20. 在农村养殖什么最赚钱,推荐这两个项目,一年收入还是不错的

热门文章

  1. 蓝桥杯官网 试题 PREV-253 历届真题 质数行者【第十一届】【决赛】【研究生组】【C++】【Java】两种解法
  2. 蓝桥杯官网 试题 PREV-94 历届真题 矩阵计数【第十届】【决赛】【研究生组】【C++】解法
  3. 摩尔定律终结后 科技也许会向这3个方向前进
  4. 记一次jenkins构建无权限问题
  5. 黑帽SEO必须掌握的四种暗链代码
  6. Unity Koreographer 之 音乐制作插件介绍学习,一般使用步骤介绍(包括:一般音乐游戏制作流程简绍) 一
  7. .net 6简单使用NPOI 读取 Excel 案例+流程
  8. Unity3D学习 ④ Unity导入商店资源,实现基本的奔跑、攻击动作切换与交互
  9. 矩阵键盘行列扫描c语言,单片机矩阵键盘按钮行列逐级扫描法
  10. Web压力测试和手机App测试