文章目录

  • 1.TFRecord简介
    • 1)TFRecord是什么
    • 2)为什么用TFRecord
    • 3)TFRecord格式
  • 2.写入TFRecord
  • 3.读取TFRecord文件
  • 4.案例实战-猫狗图片分类

1.TFRecord简介

1)TFRecord是什么

TFRecord --> Example --> feature --> key-value键值对,并且value的取值有三种


2)为什么用TFRecord


为什么要用TFRecord?

3)TFRecord格式



2.写入TFRecord



3.读取TFRecord文件


4.案例实战-猫狗图片分类

import tensorflow as tf
import os# 处理文件路径
data_dir = 'datasets'train_cats_dir = data_dir + '/train/cats/'
train_dogs_dir = data_dir + '/train/dogs/'train_tfrecord_file = data_dir + '/train/train.tfrecords'test_cats_dir = data_dir + '/valid/cats/'
test_dogs_dir = data_dir + '/valid/dogs/'test_tfrecord_file = data_dir + '/valid/test.tfrecords'# 将数据存储为TFRecord文件# 文件名字变成list
train_cat_filenames = [train_cats_dir + filename for filename in os.listdir(train_cats_dir)]
train_dog_filenames = [train_dogs_dir + filename for filename in os.listdir(train_dogs_dir)]
train_filenames = train_cat_filenames + train_dog_filenames# 将猫类的标签设为0,dog类的标签设为1
train_labels = [0] * len(train_cat_filenames) + [1] * len(train_dog_filenames)# 写入TFRcord文件
with tf.io.TFRecordWriter(train_tfrecord_file) as writer:for filename, label in zip(train_filenames,train_labels):# 读取数据集图片到内存,image为一个byte类型的字符串image = open(filename,'rb').read()# 建立tf.train.Feature字典feature = {'image' : tf.train.Feature(bytes_list = tf.train.BytesList(value=[image])),    # 图片是一个Bytes对象'label' : tf.train.Feature(int64_list = tf.train.Int64List(value=[label]))     # 标签是一个Int对象}# 通过字典建立Exampleexample = tf.train.Example(features=tf.train.Features(feature=feature))# 将Example序列化serialized = example.SerializeToString()# 写入TFRecord文件writer.write(serialized)test_cat_filenames = [test_cats_dir + filename for filename in os.listdir(test_cats_dir)]
test_dog_filenames = [test_dogs_dir + filename for filename in os.listdir(test_dogs_dir)]
test_filenames = test_cat_filenames + test_dog_filenames# 将猫类的标签设为0,dog类的标签设为1
test_labels = [0] * len(test_cat_filenames) + [1] * len(test_dog_filenames)
with tf.io.TFRecordWriter(test_tfrecord_file) as writer:for filename, label in zip(test_filenames, test_labels):image = open(filename, 'rb').read()     # 读取数据集图片到内存,image 为一个 Byte 类型的字符串feature = {                             # 建立 tf.train.Feature 字典'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),  # 图片是一个 Bytes 对象'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))   # 标签是一个 Int 对象}example = tf.train.Example(features=tf.train.Features(feature=feature)) # 通过字典建立 Exampleserialized = example.SerializeToString() #将Example序列化writer.write(serialized)   # 写入 TFRecord 文件# 读取TFRecord文件# 定义Feature结构,告诉解码器每个Feature的类型是什么
feature_description = {'image': tf.io.FixedLenFeature([],tf.string),'label' : tf.io.FixedLenFeature([],tf.int64),
}# 定义解码函数
def _parse_example(example_string):# 将TFRecord文件中的每一个序列化的tf.train.Example解码feature_dict = tf.io.parse_single_example(example_string, feature_description)# 解码JPEG图片feature_dict['image'] = tf.io.decode_jpeg(feature_dict['image'])   # 处理大小与像素feature_dict['image'] = tf.image.resize(feature_dict['image'],[256, 256]) / 255.0return feature_dict['image'], feature_dict['label']batch_size = 32# 读取TFRecord文件
train_dataset = tf.data.TFRecordDataset(train_tfrecord_file)# 解码
train_dataset = train_dataset.map(_parse_example)
for image,label in train_dataset.take(1):print(image.shape,label)
# (256, 256, 3) tf.Tensor(0, shape=(), dtype=int64)# 模型批量读取
train_dataset = train_dataset.shuffle(buffer_size = 23000)
train_dataset = train_dataset.batch(batch_size)
# (32,256,256,3)
# 优化
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)# 测试集的读取
test_dataset = tf.data.TFRecordDataset(test_tfrecord_file)
test_dataset = test_dataset.map(_parse_example)
test_dataset = test_dataset.batch(batch_size)# 定义CNN模型
class CNNModel(tf.keras.models.Model):def __init__(self):super(CNNModel,self).__init__()self.conv1 = tf.keras.layers.Conv2D(32,3,activation='relu')self.maxpool1 = tf.keras.layers.MaxPooling2D()self.conv2 = tf.keras.layers.Conv2D(32,5,activation='relu')self.maxpool2 = tf.keras.layers.MaxPooling2D()self.flatten = tf.keras.layers.Flatten()self.d1 = tf.keras.layers.Dense(64,activation='relu')self.d2 = tf.keras.layers.Dense(2,activation='softmax')def call(self,x):# 定义前向传播x = self.conv1(x)x = self.maxpool1(x)x = self.conv2(x)x = self.maxpool2(x)x = self.flatten(x)x = self.d1(x)x = self.d2(x)return xlearning_rate = 0.001
model = CNNModel()loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate)# 损失与评估
train_loss = tf.keras.metrics.Mean(name = 'train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name = 'test_accuracy')# batch
# 将动态图转换为静态图,静态图执行效率高
@tf.function
def train_step(images,labels):with tf.GradientTape() as tape:predictions = model(images)loss = loss_object(labels,predictions)# 计算梯度gradients = tape.gradient(loss,model.trainable_variables)optimizer.apply_gradients(zip(gradients,model.trainable_variables))train_loss(loss)train_accuracy(labels,predictions)  # update@tf.function
def test_step(images,labels):predictions = model(images)t_loss = loss_object(labels,predictions)test_loss(t_loss)test_accuracy(labels,predictions)# 模型训练
EPOCHS = 10for epoch in range(EPOCHS):# 重置评估指标train_loss.reset_states()train_accuracy.reset_states()test_loss.reset_states()test_accuracy.reset_states()for images,labels in train_dataset:train_step(images,labels)for test_images,test_labels in test_dataset:test_step(images,labels)template = 'Epoch {}, Loss : {}, Accuracy : {},Test Loss : {},Test Accuracy : {}'# 打印print(template.format(epoch + 1,train_loss.result(),train_accuracy.result() * 100,test_loss.result(),test_accuracy.result() * 100))

深度学习12-TFRecord详解相关推荐

  1. 深度学习开发环境调查结果公布,你的配置是这样吗?(附新环境配置) By 李泽南2017年6月26日 15:57 本周一(6 月 19 日)机器之心发表文章《我的深度学习开发环境详解:Te

    深度学习开发环境调查结果公布,你的配置是这样吗?(附新环境配置) 机器之心 2017-06-25 12:27 阅读:108 摘要:参与:李泽南.李亚洲本周一(6月19日)机器之心发表文章<我的深 ...

  2. 深度学习 --- 玻尔兹曼分布详解

    上一节我们从Hopfield神经网络存在伪吸引子的问题出发,为了解决伪吸引子带来的问题,详细介绍了模拟退火算法,本节也是基础性的讲解,为了解决伪吸引子还需要引入另外一个重要概念即:玻尔兹曼分布.本篇将 ...

  3. 深度学习 --- BP算法详解(流程图、BP主要功能、BP算法的局限性)

    上一节我们详细推倒了BP算法的来龙去脉,请把原理一定要搞懂,不懂的请好好理解BP算法详解,我们下面就直接把上一节推导出的权值调整公式拿过来,然后给出程序流程图,该流程图是严格按照上一节的权值更新过程写 ...

  4. 深度学习 --- BP算法详解(BP算法的优化)

    上一节我们详细分析了BP网络的权值调整空间的特点,深入分析了权值空间存在的两个问题即平坦区和局部最优值,也详细探讨了出现的原因,本节将根据上一节分析的原因进行改进BP算法,本节先对BP存在的缺点进行全 ...

  5. 深度学习归一化算法详解(BN,LN,IN,GN)

    目录 一.Batch Normalization(BN) 1.1为什么提出BN? 1.2BN的基本原理和公式 1.3BN在神经网络中的实现 1.4BN的优点和缺点 二.LN,IN,GN的原理和适用范围 ...

  6. 从未看过如此详细的深度学习推荐系统应用详解,读它!

    作者丨gongyouliu 编辑丨zandy 来源 | 大数据与人工智能(ID:ai-big-data) [导读]2016年DeepMind开发的AlphaGo在围棋对决中战胜了韩国九段选手李世石,一 ...

  7. 深度学习 --- Hopfield神经网络详解

    前面几节我们详细探讨了BP神经网络,基本上很全面深入的探讨了BP,BP属于前馈式类型,但是和BP同一时期的另外一个神经网络也很重要,那就是Hopfield神经网络,他是反馈式类型.这个网络比BP出现的 ...

  8. 深度学习生态圈【详解深度学习工具Keras】

    文章目录: 1 CNTK 2 Tensorflow2.1 介绍2.2 安装2.3 简单例子 3 Keras3.1 介绍3.2 安装Keras3.3 使用Keras构建深度学习模型3.4 一个例子 4 ...

  9. 深度学习之优化详解:batch normalization

    摘要: 一项优化神经网络的技术,点进来了解一下? 认识上图这些人吗?这些人把自己叫做"The Myth Busters",为什么?在他们的启发下,我们才会在Paperspace做类 ...

  10. 深度学习 --- BP算法详解(误差反向传播算法)

    本节开始深度学习的第一个算法BP算法,本打算第一个算法为单层感知器,但是感觉太简单了,不懂得找本书看看就会了,这里简要的介绍一下单层感知器: 图中可以看到,单层感知器很简单,其实本质上他就是线性分类器 ...

最新文章

  1. 负样本修正:既然数据是模型的上限,就不要破坏这个上限
  2. Java 反射 使用总结
  3. 经典C语言程序100例之十一
  4. socket服务器显示未响应,“程序未响应”的思考总结
  5. OAuth2解决什么问题
  6. 95-24-030-Future-ChannelFuture
  7. 某厂面试归来,发现自己落伍了!
  8. dpkg和apt-get命令的用法
  9. eclipse搭建简单的web服务,使用tomcat服务
  10. 课程设计旅游景点咨询系统
  11. atitit.系统托盘图标的设计java swing c# .net c++ js
  12. Vue3.0中文地址文档
  13. 经常使用的几种OCR文档扫描工具|无水印|避免智商税
  14. 线性代数笔记4.3 齐次线性方程组
  15. 频繁默认网关不可用_电脑经常掉线提示默认网关不可用原因分析和解决办法
  16. 性格木讷面试时如何脱颖而出?
  17. java怎么使用sni,启用SNI扩展的SSL握手 - 服务器上的证书选择
  18. excel数据处理:说说数据源表必须遵守的那些规则
  19. 2022-2027年(新版)中国伺服电机行业发展前景及趋势预测分析报告
  20. 模拟算法考试训练题和答案1

热门文章

  1. 在C++中各类型拼接成一个string
  2. OCS UCCA 开发笔记(Unified Communications Client API)
  3. 解决 Invalid character found in method name. HTTP method names must be tokens 异常信息
  4. MySQL数据库架构相关笔记(二)
  5. 高考志愿填报:java 软件 程序员 目前的就业现状
  6. 【转载】linux环境下大数据网站搬家
  7. MVC控制器取参数值
  8. NonComVisibleBaseClass Exception
  9. CentreonMonitoringEvent Logs没有结果的解决方法
  10. PE文件结构及其加载机制(四)