转载请注明作者和出处: http://blog.csdn.net/wiinter_fdd/article/details/72835939
运行平台: Windows
Python版本: Python3.x
IDE: Spyder

前言

   最近一直在研究深度学习,主要是针对卷积神经网络(CNN),接触过的数据集也有了几个,最经典的就是MNIST, CIFAR10/100, NOTMNIST, CATS_VS_DOGS 这几种,由于这几种是在深度学习入门中最被广泛应用的,所以很多深度学习框架 Tensorflow、keraspytorch都有针对这些数据集专用的数据导入的函数封装,但是一般情况下我们的数据集并不是这种很规范的形式,那么如何把自己的数据集转换成这些框架能够使用的数据形式至关重要,接下来博主将会对现有的较流行的深度学习框架封装自己的数据进行讲解,首先是针对最流行的Tensorflow。

   查阅tensorflow的官方API,在GET STARTED下面的Programmer’s Guide中有一个Reading Data的章节介绍,大体内容就是tensorflow读取数据的方式:

可以看到,tensorflow官网给出了三种读取数据的方法:
对于数据量较小而言,可能一般选择直接将数据加载进内存,然后再分batch输入网络进行训练(tip:使用这种方法时,结合yield 使用更为简洁,大家自己尝试一下吧,我就不赘述了)。但是,如果数据量较大,这样的方法就不适用了,因为太耗内存,所以这时最好使用tensorflow提供的队列queue,也就是第二种方法 从文件读取数据。对于一些特定的读取,比如csv文件格式,官网有相关的描述,在这儿我介绍一种比较通用,高效的读取方法(官网介绍的少),即使用tensorflow内定标准格式——TFRecord.

   那下面就让我们了解一下什么是TFRecord:

1. What is TFRecord?

TFRecords其实是一种二进制文件,虽然它不如其他格式好理解,但是它能更好的利用内存,更方便复制和移动,并且不需要单独的标签文件(等会儿就知道为什么了)… …总而言之,这样的文件格式好处多多,所以让我们用起来吧。这里注意:TFRecord会根据你输入的文件的类,自动给每一类打上同样的标签。

TFRecord文件中的数据是通过tf.train.Example Protocol Buffer的格式存储的,下面是tf.train.Example的定义:

message Example {Features features = 1;
};message Features{map<string,Feature> featrue = 1;
};message Feature{oneof kind{BytesList bytes_list = 1;FloatList float_list = 2;Int64List int64_list = 3;}
};

从上述代码可以看到,ft.train.Example 的数据结构相对简洁。tf.train.Example中包含了一个从属性名称到取值的字典,其中属性名称为一个字符串,属性的取值可以为字符串(BytesList ),实数列表(FloatList )或整数列表(Int64List )。例如我们可以将解码前的图片作为字符串,图像对应的类别标号作为整数列表。

2. How to convert our own data to TFRecord?

终于我们关心的话题来了,怎么转换?这里我们使用Kaggle上面有名猫狗大战的数据集可以通过Dogs vs Cats来下载,为了方便演示,我们利用这个数据集创建了一个新的数据集,取猫狗图片中各100张分别放在data文件夹下面的cats和dogs子文件中,入下图所示。

数据准备好以后,我们就要开始读取数据,生成TFRecord了,下面直接上代码,对于代码内容随后会有相应的说明:

# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as npcwd = "D://Anaconda3//spyder//Tensorflow_ReadData//data//"
classes = {'cats', 'dogs'} #预先自己定义的类别
writer = tf.python_io.TFRecordWriter('train.tfrecords') #输出成tfrecord文件def _int64_feature(value):return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))def _bytes_feature(value):return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))for index, name in enumerate(classes):class_path = cwd + name + '//'for img_name in os.listdir(class_path):img_path = class_path + img_name    #每个图片的地址img = Image.open(img_path)img = img.resize((208, 208))img_raw = img.tobytes()  #将图片转化为二进制格式example = tf.train.Example(features = tf.train.Features(feature = {"label": _int64_feature(index),"img_raw": _bytes_feature(img_raw),                                                                          }))writer.write(example.SerializeToString())  #序列化为字符串
writer.close()

以上代码就是将数据读去进来,生成tfrecord文件,看过Tensorflow官方API的同学们应该都可以看懂。

3. How to read data from TFRecords?

上面已经把自己的数据保存成tensorflow可以使用的tfrecord的形式了,那么tensorflow到底如何使用呢?下面继续看代码:

def read_and_decode(filename, batch_size): # read train.tfrecordsfilename_queue = tf.train.string_input_producer([filename])# create a queuereader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)#return file_name and filefeatures = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw' : tf.FixedLenFeature([], tf.string),})#return image and labelimg = tf.decode_raw(features['img_raw'], tf.uint8)img = tf.reshape(img, [208, 208, 3])  #reshape image to 208*208*3label = tf.cast(features['label'], tf.int32) #throw label tensorimg_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size= batch——size,num_threads=64,capacity=2000,min_after_dequeue=1500,)return img_batch, tf.reshape(label_batch,[batch_size])

   以上是我们定义的从tfrecord文件中读取数据的函数,在这里我们使用的tensorflow的队列读取方式。在读取到队列中后,数据输出之前还要作解码的操作从,可以使用tf.TFRecordReader的tf.parse_single_example解析器。这个操作可以将Example协议内存块(protocol buffer)解析为张量。
   可以看到,这个函数除了tfrecord文件的名这一个参数外,还有batch_size这个参数,利用tf.train.shuffle_batch()这个函数对读取到的数据进行了batch处理,这样更有利于后续的训练。
注意:当数据量加大时,也可以将数据写入多个TFRecord文件。

  我们的数据是读进来,那么到底是不是我们想要的呢,下面就是我们的测试程序。

4. How to show TFRecords’ images?

tfrecords_file = 'D://Anaconda3//spyder//Tensorflow_ReadData//train.tfrecords'
BATCH_SIZE = 4
image_batch, label_batch = read_and_decode(tfrecords_file,BATCH_SIZE)with tf.Session()  as sess:i = 0coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)try:while not coord.should_stop() and i<1:# just plot one batch sizeimage, label = sess.run([image_batch, label_batch])for j in np.arange(4):print('label: %d' % label[j])plt.imshow(image[j,:,:,:])plt.show()i+=1except tf.errors.OutOfRangeError:print('done!')finally:coord.request_stop()coord.join(threads)

这里我们也是用的tensorflow官网推荐的队列管理形式,batch_size这里可以大家任意设定,显示几幅图片都可以,这里博主设置的是4。

这样就可以把任意格式的数据转换成tensorflow推荐的TFRecord的格式的了,是不是随你有很大帮助呢。下面是完整的代码:

# -*- coding: utf-8 -*-
import os
import tensorflow as tf
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np#%%
cwd = "D://Anaconda3//spyder//Tensorflow_ReadData//data//"
classes = {'cats', 'dogs'}
writer = tf.python_io.TFRecordWriter('train.tfrecords')def _int64_feature(value):return tf.train.Feature(int64_list = tf.train.Int64List(value = [value]))def _bytes_feature(value):return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value]))for index, name in enumerate(classes):class_path = cwd + name + '//'for img_name in os.listdir(class_path):img_path = class_path + img_name    #每个图片的地址img = Image.open(img_path)img = img.resize((208, 208))img_raw = img.tobytes()  #将图片转化为二进制格式example = tf.train.Example(features = tf.train.Features(feature = {"label": _int64_feature(index),"img_raw": _bytes_feature(img_raw),                                                                          }))writer.write(example.SerializeToString())  #序列化为字符串
writer.close()
#%%
def read_and_decode(filename, batch_size): # read train.tfrecordsfilename_queue = tf.train.string_input_producer([filename])# create a queuereader = tf.TFRecordReader()_, serialized_example = reader.read(filename_queue)#return file_name and filefeatures = tf.parse_single_example(serialized_example,features={'label': tf.FixedLenFeature([], tf.int64),'img_raw' : tf.FixedLenFeature([], tf.string),})#return image and labelimg = tf.decode_raw(features['img_raw'], tf.uint8)img = tf.reshape(img, [208, 208, 3])  #reshape image to 512*80*3
#    img = tf.cast(img, tf.float32) * (1. / 255) - 0.5 #throw img tensorlabel = tf.cast(features['label'], tf.int32) #throw label tensorimg_batch, label_batch = tf.train.shuffle_batch([img, label],batch_size= batch_size,num_threads=64,capacity=2000,min_after_dequeue=1500,)return img_batch, tf.reshape(label_batch,[batch_size])#%%
tfrecords_file = 'D://Anaconda3//spyder//Tensorflow_ReadData//train.tfrecords'
BATCH_SIZE = 4
image_batch, label_batch = read_and_decode(tfrecords_file, BATCH_SIZE)with tf.Session()  as sess:i = 0coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(coord=coord)try:while not coord.should_stop() and i<1:# just plot one batch sizeimage, label = sess.run([image_batch, label_batch])for j in np.arange(BATCH_SIZE):print('label: %d' % label[j])plt.imshow(image[j,:,:,:])plt.show()i+=1except tf.errors.OutOfRangeError:print('done!')finally:coord.request_stop()coord.join(threads)

运行结果如下:

只需要Ctrl+C和V点击运行,就可以得到上面的结果了。

接下来还会讲解到另外两个深度学习框架Keras,pytorch如何将自己的数据转化为框架可以使用的格式,敬请期待吧!!!!

用Tensorflow处理自己的数据:制作自己的TFRecords数据集相关推荐

  1. 利用matlab将自己的数据制作为标准VOC数据集格式

    在使用各种深度网络的时候,需要根据自己的需求,自己的数据fine-tuning自己的模型,首要的一步就是讲自己的数据制作成标准VOC数据集,本文记录自己利用matlab制作标准VOC数据集的方法. 1 ...

  2. TensorFlow 制作自己的TFRecord数据集

    TensorFlow 制作自己的TFRecord数据集 准备图片数据 网上下载了2类吉他和房子的图片, 全部 resize成64*64大小 如下图, 保存项目下: 现在利用这2 类 共108张图片制作 ...

  3. TensorFlow神经网络(六)制作数据集,实现特定应用

    [致谢]内容来自mooc人工智能实践第六讲 并广泛参考文章 https://www.jianshu.com/p/766a2af5eb6a 一.数据集生成读取文件mnist_generateds.py ...

  4. 如何制作自己的图片数据集-基于tensorflow

    写在开始 自己最开始接触python的时候,第一个学会使用的库就是tensorflow,在经历了everyone 都会经历的mnist数据集训练后,就开始想自己做一个图片分类的深度学习,期间也是一波三 ...

  5. NIN模块tensorflow实现和一个自己制作的手写字母数据集

    NIN层 简介: 我们提出了一种新型的深度网络结构,称为"Network In Network"(NIN),它可以增强模型在感受野(receptive field)内对局部区域(l ...

  6. TensorFlow csv读取文件数据(代码实现)

    TensorFlow csv读取文件数据(代码实现) 大多数人了解 Pandas 及其在处理大数据文件方面的实用性.TensorFlow 提供了读取这种文件的方法. 前面章节中,介绍了如何在 Tens ...

  7. [caffe] 数据制作和训练

    [caffe] 数据制作和训练 在使用caffe时,我们希望使用自己的数据进行训练,以下给出如何制作自己的数据.所有的数据制作都是基于imagenet的. 1.数据准备,我们需要一个train和val ...

  8. 公安网三合一EWSD交换机数据制作

    公安网三合一EWSD交换机数据制作 信令点编码:xx-xx-xxx     DIU=0-22-1     TUP部分 1.  确定到对端信令点的通道 CRC7LSET:LSNAME=LSGAJ,SPC ...

  9. python读取图像数据流_浅谈TensorFlow中读取图像数据的三种方式

    本文面对三种常常遇到的情况,总结三种读取数据的方式,分别用于处理单张图片.大量图片,和TFRecorder读取方式.并且还补充了功能相近的tf函数. 1.处理单张图片 我们训练完模型之后,常常要用图片 ...

最新文章

  1. PHP判断ajax请求:HTTP_X_REQUESTED_WITH
  2. [云炬ThinkPython阅读笔记]2.9 术语表
  3. #1419 : 后缀数组四·重复旋律4 (重复次数最多的连续字串)
  4. 提交git push 的时候报错,Please make sure you have the correct access rights
  5. hdu2846(字典树)
  6. Delphi Note
  7. 27 构造连续的ICMP数据包
  8. 想用好低代码这把“双刃剑”,先搞清楚这三个问题|低代码系列(四)
  9. Android之多媒体视频的播放和录制
  10. 裂变红包,金额打散的方法
  11. 条码公司的互联网业务调查分析及展望
  12. word文档怎么调成黑底白字
  13. 火鸟门户v4.0 2019全能地方门户系统源码
  14. 锦尚中国 授权文件解密
  15. 小米扫地机器人充电座指示灯不亮_小米扫地机器人怎么充电
  16. greensock缓动类包之TweenLite
  17. 常微分方程数值解法——python实现
  18. 淘客是什么 淘客怎么做
  19. 求一个数所有因子的个数,因子和
  20. 乔布斯05年斯坦福大学毕业典礼上的演讲

热门文章

  1. 《大学数学手册》摘抄
  2. R语言与数据分析之八:时间序列--霍尔特指数平滑法
  3. Flink 命令行参数介绍
  4. oracle+分页很慢,oracle分页查询缓慢的情况
  5. Vue $axios 插件使用
  6. PyQt5遇到QWidget: Must construct a QApplication before a QWidget
  7. 激光雷达方程应用(使用MATLAB和Python语言得到回波强度、距离平方矫正信号、消光系数随距离的变化曲线)
  8. 第一次电赛体会-2019年国赛
  9. 相比之下,在美国看病简直是噩梦
  10. html 判断复选框是否被选中