1、CIFAR-10数据集简介

CIFAR-10数据集包含10个类别的RGB彩色图片。图片尺寸为32×32,这十个类别包括:飞机、汽车、鸟、猫、鹿、狗、蛙、马、船、卡车。一共有50000张训练图片和10000张测试图片。

CIFAR-10数据集有如下文件:

batches.meta.txt  data_batch_2.bin  data_batch_4.bin  readme.html

data_batch_1.bin  data_batch_3.bin  data_batch_5.bin  test_batch.bin

其中,data_batch_1.bin~data_batch_5.bin五个文件是训练数据,每个文件以二进制的格式存储10000张图片和这些图片对于的标签。test_batch.bin存储的是10000张测试图像的测试标签。一张图片和对于的标签组成一个样本,一个样本有3073个字节组成,第一个字节为标签,后面3072个字节是图片数据{1024(R) + 1024(G) + 1024(B)}。

2、下载CIFAR-10数据集

2.1、首先下载tensorflow官方CIFAR-10代码

git clone https://github.com/tensorflow/models.git

模块位于models/tutorials/image/cifar10目录下。

2.2、下载CIFAR-10数据集

#coding:utf-8
#导入官方cifar10模块
import cifar10
import tensorflow as tf#tf.app.flags.FLAGS是tensorflow的一个内部全局变量存储器
FLAGS = tf.app.flags.FLAGS
#cifar10模块中预定义下载路径的变量data_dir为'/tmp/cifar10_eval',预定义如下:
#tf.app.flags.DEFINE_string('data_dir', './cifar10_data',
#                           """Path to the CIFAR-10 data directory.""")
#为了方便,我们将这个路径改为当前位置
FLAGS.data_dir = './cifar10_data'#如果不存在数据文件则下载,并且解压
cifar10.maybe_download_and_extract()
下载完数据后提示:>> Downloading cifar-10-binary.tar.gz 100.0%
Successfully downloaded cifar-10-binary.tar.gz 170052171 bytes

3、tensorflow数据读取机制

3.1、tensorflow数据读取机制简介

目前我们接触的Tensorflow有两种数据读取机制,第一种就是往占位符placeholder传入feed_dict,这种机制比较简单,前面的例子也用过,这里不赘述。现在我们讲第二种机制:

如上图,要训练数据,得分两步,第一步,先将数据从硬盘加载到内存中,第二步,拱给CPU或者GPU运算。如果只用一个线程,那么,运行第一步的时候第二步就得等着,运行第二步的时候,第一步就得等着,这样就浪费时间了。

如上图,解决这个问题,就得将第一步和第二步分别放在两个线程中:一个线程不断的把数据读入内存,另一个线程从内存中取出数据进行计算。

为了方便管理,tensorflow在内存队列前又加了一层“文件名队列”。

3.2、数据读取机制对应的函数

对于文件名队列,使用tf.train.string_input_producer函数,该函数有三个比较重要的参数,string_tensor参数向这个函数传入文件名list,系统就自动将它转为一个文件名队列。num_epochs参数传入的是epoch数,即将传入的list全部运算几遍。shuffle参数决定在一个epoch内,文件的顺序是否被打乱,若shuffle=False,则不打乱,否则打乱,默认是打乱的。

在tensorflow中,内存队列不需要我们自己建立,只需要使用reader对象从文件名队列中读取数据即可。

需要注意的是,使用tf.train.string_input_producer函数创建队列后,程序并没有马上将文件名加入队列,而是要运行tf.train.start_queue_runners函数后,才真正开始工作。

3.3、例子

为了便于理解,下面给出一个简单的例子

如上图,有三张图片,分别为1.jpg, 2.jpg,3.jpg,首先来看了当shuffle=False时的读取顺序。

不打乱读取顺序的代码

#encoding:utf-8
import tensorflow as tffilenames = ['1.jpg', '2.jpg', '3.jpg']
#shuffle=False表示不打乱顺序,num_epochs=3表示整个队列获取三次
queue = tf.train.string_input_producer(filenames, shuffle=False, num_epochs=3)#读取文件名队列中的数据
reader = tf.WholeFileReader()
key,value = reader.read(queue)with tf.Session() as sess:#初始局部化变量,注意这个函数跟tf.global_variables_initializer.run()是不一样的#因为string_input_producer函数的num_epochs=3传入的是局部变量tf.local_variables_initializer().run()threads = tf.train.start_queue_runners(sess=sess)i = 0while True:i += 1data = sess.run(value)with open('shuffle_false/image_%d.jpg' % i, 'wb') as fd:fd.write(data)

运行结果:

打乱读取顺序的运行结果

4、将CIFAR-10数据集保存为图片

4.1、下载CIFAR-10数据集

# 查看CIFAR-10数据是否存在,如果不存在则下载并解压
def download():# tf.app.flags.FLAGS是tensorflow的一个内部全局变量存储器FLAGS = tf.app.flags.FLAGS# 为了方便,我们将这个路径改为当前位置FLAGS.data_dir = './cifar10_data'# 如果不存在数据文件则下载,并且解压cifar10.maybe_download_and_extract()

4.2、设置图片保存的路径

#将获取的图片保存到这里
image_save_path = './cifar10_image/'
if os.path.exists(image_save_path) == False:os.mkdir(image_save_path)

4.3、根据tensorflow读取机制,设置文件名队列,然后调用获取并解析图片函数

#检测CIFAR-10数据是否存在,如果不存在则返回False
def check_cifar10_data_files(filenames):for file in filenames:if os.path.exists(file) == False:print('Not found cifar10 data.')return Falsereturn True#获取图片前的预处理,检测CIFAR10数据是否存在,如果不存在直接退出
#如果存在,用string_input_producer函数创建文件名队列,
# 并且通过get_record函数获取图片标签和图片数据,并返回
def get_image(data_path):filenames = [os.path.join(data_path, "data_batch_%d.bin" % i) for i in range(1, 6)]print(filenames)if check_cifar10_data_files(filenames) == False:exit()queue = tf.train.string_input_producer(filenames)return get_record(queue)

4.4、读取并解析图片

#获取每个样本数据,样本由一个标签+一张图片数据组成
def get_record(queue):print('get_record')#定义label大小,图片宽度、高度、深度,图片大小、样本大小label_bytes = 1image_width = 32image_height = 32image_depth = 3image_bytes = image_width * image_height * image_depthrecord_bytes = label_bytes + image_bytes#根据样本大小读取数据reader = tf.FixedLengthRecordReader(record_bytes)key, value = reader.read(queue)#将获取的数据转变成一维数组#例如# source = 'abcde'# record_bytes = tf.decode_raw(source, tf.uint8)#运行结果为[ 97  98  99 100 101]record_bytes = tf.decode_raw(value, tf.uint8)#获取label,label数据在每个样本的第一个字节label_data = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)#获取图片数据,label后到样本末尾的数据即图片数据,# 再用tf.reshape函数将图片数据变成一个三维数组depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],[label_bytes + image_bytes]),[3, 32, 32])#矩阵转置,上面得到的矩阵形式是[depth, height, width],即红、绿、蓝分别属于一个维度的,#假设只有3个像素,上面的格式就是RRRGGGBBB#但是我们图片数据一般是RGBRGBRGB,所以这里要进行一下转置#注:上面注释都是我个人的理解,不知道对不对image_data = tf.transpose(depth_major, [1, 2, 0])#统一将数据转为float32格式image_data = tf.cast(image_data, tf.float32)return label_data, image_data

4.5、主函数

if __name__ == '__main__':#查看CIFAR-10数据是否存在,如果不存在则下载并解压download()#将获取的图片保存到这里image_save_path = './cifar10_image/'if os.path.exists(image_save_path) == False:os.mkdir(image_save_path)#获取图片数据key, value = get_image('./cifar10_data/cifar-10-batches-bin/')with tf.Session() as sess:#初始化变量sess.run(tf.global_variables_initializer())coord = tf.train.Coordinator()#这里才真的启动队列threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(50):# print("i:%d" % i)#####################################这里data和label不能分开run,否则图片和标签就不匹配了,多谢网友ATPY提醒#data = sess.run(value)#label = sess.run(key)#应该这样label, data = sess.run([key, value])####################################print(label)scipy.misc.toimage(data).save(image_save_path + '/%d_%d.jpg' % (label, i))coord.request_stop()coord.join()

4.6、运行结果:

4.7、完整代码

#encoding:utf-8
import tensorflow as tf
import os
import cifar10
import scipy.misc# 查看CIFAR-10数据是否存在,如果不存在则下载并解压
def download():# tf.app.flags.FLAGS是tensorflow的一个内部全局变量存储器FLAGS = tf.app.flags.FLAGS# 为了方便,我们将这个路径改为当前位置FLAGS.data_dir = './cifar10_data'# 如果不存在数据文件则下载,并且解压cifar10.maybe_download_and_extract()#检测CIFAR-10数据是否存在,如果不存在则返回False
def check_cifar10_data_files(filenames):for file in filenames:if os.path.exists(file) == False:print('Not found cifar10 data.')return Falsereturn True#获取每个样本数据,样本由一个标签+一张图片数据组成
def get_record(queue):print('get_record')#定义label大小,图片宽度、高度、深度,图片大小、样本大小label_bytes = 1image_width = 32image_height = 32image_depth = 3image_bytes = image_width * image_height * image_depthrecord_bytes = label_bytes + image_bytes#根据样本大小读取数据reader = tf.FixedLengthRecordReader(record_bytes)key, value = reader.read(queue)#将获取的数据转变成一维数组#例如# source = 'abcde'# record_bytes = tf.decode_raw(source, tf.uint8)#运行结果为[ 97  98  99 100 101]record_bytes = tf.decode_raw(value, tf.uint8)#获取label,label数据在每个样本的第一个字节label_data = tf.cast(tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)#获取图片数据,label后到样本末尾的数据即图片数据,# 再用tf.reshape函数将图片数据变成一个三维数组depth_major = tf.reshape(tf.strided_slice(record_bytes, [label_bytes],[label_bytes + image_bytes]),[3, 32, 32])#矩阵转置,上面得到的矩阵形式是[depth, height, width],即红、绿、蓝分别属于一个维度的,#假设只有3个像素,上面的格式就是RRRGGGBBB#但是我们图片数据一般是RGBRGBRGB,所以这里要进行一下转置#注:上面注释都是我个人的理解,不知道对不对image_data = tf.transpose(depth_major, [1, 2, 0])return label_data, image_data#获取图片前的预处理,检测CIFAR10数据是否存在,如果不存在直接退出
#如果存在,用string_input_producer函数创建文件名队列,
# 并且通过get_record函数获取图片标签和图片数据,并返回
def get_image(data_path):filenames = [os.path.join(data_path, "data_batch_%d.bin" % i) for i in range(1, 6)]print(filenames)if check_cifar10_data_files(filenames) == False:exit()queue = tf.train.string_input_producer(filenames, shuffle=False)# return tf.cast((cifar10_input.read_cifar10(queue)).uint8image, tf.float32)return get_record(queue)if __name__ == '__main__':#查看CIFAR-10数据是否存在,如果不存在则下载并解压download()#将获取的图片保存到这里image_save_path = './cifar10_image/'if os.path.exists(image_save_path) == False:os.mkdir(image_save_path)#获取图片数据key, value = get_image('./cifar10_data/cifar-10-batches-bin/')with tf.Session() as sess:#初始化变量sess.run(tf.global_variables_initializer())coord = tf.train.Coordinator()#这里才真的启动队列threads = tf.train.start_queue_runners(sess=sess, coord=coord)for i in range(50):# print("i:%d" % i)#####################################这里data和label不能分开run,否则图片和标签就不匹配了,多谢网友ATPY提醒#data = sess.run(value)#label = sess.run(key)#应该这样label, data = sess.run([key, value])####################################print(label)scipy.misc.toimage(data).save(image_save_path + '/%d_%d.jpg' % (label, i))coord.request_stop()coord.join()

4.8、为了更好理解get_record函数怎么将每个样本数据提取并转换的过成,我再给个小例子:

#encoding:utf-8
import tensorflow as tf# 为了简化过程,假设一个4×4×3的样本数据如下,
# 其中,第一个字符“0”表示图片的标签label
# “1”表示图片颜色值的R通道,“2”表示G通道,“3”表示B通道
source = '0111111111111111122222222222222223333333333333333'
sourcelist = tf.decode_raw(source, tf.uint8)
#上面运行后得到的数据如下:(0的ASCII值是48,同理推出1、2、3的值为49,50,51,这不是重点不用关心)
#[48 49 49 49 49 49 49 49 49 49 49 49 49 49 49 49 49 50 50 50 50 50 50 50
# 50 50 50 50 50 50 50 50 50 51 51 51 51 51 51 51 51 51 51 51 51 51 51 51
# 51]#获取label
label = tf.strided_slice(sourcelist, [0], [1]);#获取图片数据,并转为[3, 4, 4]的矩阵形式,其中,
#[1]表示从1下标开始截取,[49]表示截取到49下标,[3, 4, 4]中, 3表示通道数,4分别表示宽度和高度
image = tf.reshape(tf.strided_slice(sourcelist, [1], [49]), [3, 4, 4])
#上面运行后得到数据如下:
# [[[49 49 49 49]
#   [49 49 49 49]
#   [49 49 49 49]
#   [49 49 49 49]]
#
#  [[50 50 50 50]
#   [50 50 50 50]
#   [50 50 50 50]
#   [50 50 50 50]]
#
#  [[51 51 51 51]
#   [51 51 51 51]
#   [51 51 51 51]
#   [51 51 51 51]]]
#可以看到,RGB数据都分别在同一维度#这里就是对上面得到的矩阵进行转置
image_transpose = tf.transpose(image, [1, 2, 0])
#上面运行后得到的数据如下
# [[[49 50 51]
#   [49 50 51]
#   [49 50 51]
#   [49 50 51]]
#
#  [[49 50 51]
#   [49 50 51]
#   [49 50 51]
#   [49 50 51]]
#
#  [[49 50 51]
#   [49 50 51]
#   [49 50 51]
#   [49 50 51]]
#
#  [[49 50 51]
#   [49 50 51]
#   [49 50 51]
#   [49 50 51]]]with tf.Session() as sess:sess.run(tf.global_variables_initializer())result = sess.run(tf.cast(sourcelist, tf.int32))print(result)result = sess.run(tf.cast(image, tf.int32))print(result)result = sess.run(tf.cast(image_transpose, tf.int32))print(result)

总结:

今天的学习记录到此~

TensorFlow精进之路(四):CIFAR-10图像识别(上)相关推荐

  1. TensorFlow精进之路(三):两层卷积神经网络模型将MNIST未识别对的图片筛选出来

    1.概述 自从开了专栏<TensorFlow精进之路>关于对TensorFlow的整理思路更加清晰.上两篇讲到Softmax回归模型和两层卷积神经网络模型训练MNIST,虽然使用神经网络能 ...

  2. TensorFlow精进之路(九):TensorFlow编程基础

    1.概述 卷积部分的知识点在博客:TensorFlow精进之路(三):两层卷积神经网络模型将MNIST未识别对的图片筛选出来已经写过,所以不再赘述.这一节简单聊聊tensorflow的编程基础. 2. ...

  3. TensorFlow精进之路(十二):随时间反向传播BPTT

    1.概述 上一节介绍了TensorFlow精进之路(十一):反向传播BP,这一节就简单介绍一下BPTT. 2.网络结构 RNN正向传播可以用上图表示,这里忽略偏置. 上图中, x(1:T)表示输入序列 ...

  4. tensorflow精进之路(二十四)——Object Detection API目标检测(中)(COCO数据集训练的模型—ssd_mobilenet_v1_coco模型)

    1.概述 上一讲简单的讲了目标检测的原理以及Tensorflow Object Detection API的安装,这一节继续讲Tensorflow Object Detection API怎么用. 2 ...

  5. tensorflow精进之路(二十二)——使用slim模型对图像识别与检测(下)(VGG19模型)

    1.概述 上一节使用slim对图像进行识别,但是一张图片里就识别出一样东西,这节我们就来学习怎么检测图片里更多的物品.上一节我们使用的是Inception-ResNet-v2模型,这一节我们使用的是V ...

  6. tensorflow精进之路(二十一)——使用slim模型对图像识别与检测(上)(Inception_ResNet_v2模型)

    1.概述 上一讲,我们使用了slim训练了自己的数据,主要用于分类任务.这一讲,我们还是继续学习slim库,用它来对图像进行识别和检测. 2.下载Inception_ResNet_v2模型 第十六讲中 ...

  7. TensorFlow精进之路(十四):RNN训练MNIST数据集

    1.概述 前面介绍了RNN,这一节就用tensorflow的RNN来训练MNIST数据集,看看准确率如何. 2.代码实现 2.1.导入数据集 # encoding:utf-8 import tenso ...

  8. TensorFlow精进之路(五):CIFAR-10图像识别(中)

    5.数据增强 5.1.简介 问题:深度学习中通常会要求数量很大的训练样本,一般来说,样本数量越多,训练效果越好,但是这么庞大的样本的收集整理是很大的工程. 依据:如果对一张图像进行简单的平移.翻转.缩 ...

  9. TensorFlow精进之路(七):关于两层卷积神经网络对CIFAR-10图像的识别

    1.概述 在前面已经对官方的CIFAR10图像识别模块进行分析,但如果只做到这一步感觉还是不够,没能做到举一反三以及对之前学的知识的巩固,所以这一节,我打算结合之前学的双层卷积神经网络自己写一个dem ...

最新文章

  1. docker 搭建Tomcat web 简单示例过程
  2. 如何平衡存储系统的一致性和可用性?
  3. Autodesk Revit DB Link 中文理解
  4. unix linux系统版本,怎么查看UNIX系统版本?
  5. [BZOJ4817]树点涂色
  6. nodejs 错误问题解决
  7. 全国计算机等级考试在线报名,全国计算机等级考试网上报考具体流程
  8. 图书馆系统软件测试计划,图书馆管理系统软件测试计划
  9. NuPlayer源码分析三:解码模块
  10. 【esp32-adf】按键服务源码分析
  11. 数据库DevOps:我们如何提供安全、稳定、高效的研发全自助数据库服务-iDB/DMS企业版
  12. 致远OA,小地球启动报错:读取系统初始化信息失败!
  13. 猿团君分析-程序员如何成功的提高影响力2.0
  14. 一行代码病毒小心谨慎
  15. React基础—state组件使用及分类
  16. 图卷积神经网络的数学原理——谱图理论和傅里叶变换初探
  17. CF #80 Cthulhu
  18. 360 PK QQ 始末
  19. 解决Visio导出图片没有边界或者边缘留白过少的问题
  20. L3-2 拼题A打卡奖励 (30 分)

热门文章

  1. App_Data 目录中的数据库位置指定了一个本地 SQL Server
  2. Visual Studio DSL 入门 13---结合T4生成代码
  3. springboot通过各种不同类型参数获取容器中的bean工具类
  4. 开始使用windows live writer写博客。
  5. ORM框架之Mybatis(二)数据库连接池、事务及动态SQL
  6. mysql基础知识(二)
  7. 2018-08-10 Netty:4.x
  8. JavaScript——(function(){})()立即执行函数解析
  9. JavaWeb——mybatis一对一、一对多查询
  10. SQL那些事儿(七)--oracle表空间、用户查看基本语句