引言

TensorFlow 版本1.15pip3 install tensorflow==1.15.0
这是《TensorFlow实战Google深度学习框架(第2版)》的学习笔记,所有代码在TensorFlow 1.15版本中运行正常

TFRecord输入数据格式

TensorFlow提供了一种统一的格式来存储数据,这个格式就是TFRecord

TFRecord格式介绍

TFRecord文件中的数据都是通过tf.train.Example Protocol Buffer的格式存储的。以下给出了tf.train.Example的定义:

message Example {Features features = 1
};message Features {map<string, Feature> feature = 1;
};
message Feature {oneof kind {BytesList bytes_list = 1;FloatList float_list = 2;Int64List int64_list = 3; }
};

tf.train.Example中包含了一个从属性名称到取值的字典feature。取值可以为字符串BytesList、实数列表float_list或整数列表int64_list
比如将一张解码前的图像存为一个字符串,图像所对应的类别编号存为整数列表。下面将给出一个使用TFRecord的具体样例。

TFRecord样例程序

以下程序给出了如何将MNIST输入数据转化为TFRecord的格式。

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_dataimport numpy as np# 生成整数型的属性
def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))# 生成字符串型的属性
def _bytes_feature(value):return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))mnist = input_data.read_data_sets('./datasets/mnist',dtype=tf.uint8, one_hot=True)images = mnist.train.images
# 训练数据所对应的正确答案,可以作为一个属性保存在TFRecord中
labels = mnist.train.labels
# 训练数据的图像分辨率,可以作为Example中的一个属性
pixels = images.shape[1]
num_examples = mnist.train.num_examples# 输出TFRecord文件的地址
filename = './records/output.tfrecords'
# 创建一个writer来写TFRecord文件
writer = tf.python_io.TFRecordWriter(filename)
for index in range(num_examples):# 将图像矩阵转化成一个字符串image_raw = images[index].tostring()# 将一个样例转化为Example Protocol Buffer,并将所有的信息写入这个数据结构example = tf.train.Example(features=tf.train.Features(feature={'pixels': _int64_feature(pixels),'label': _int64_feature(np.argmax(labels[index])),'image_raw': _bytes_feature(image_raw)}))# 将一个Example写入TFRecord文件writer.write(example.SerializeToString())
writer.close()

以上程序可以将MNIST数据集中所有的训练数据存储到一个TFRecord中。TensorFlow对从文件列表中读取数据提供了很好的支持。
以下给出了如何读取TFRecord文件中的数据:

import tensorflow as tf# 创建一个reader来读取TFRecord文件中的样例
reader = tf.TFRecordReader()
# 创建一个队里来维护输入文件列表
filename_queue = tf.train.string_input_producer(['./records/output.tfrecords'])# 从文件中读出一个样例,也可以用read_up_to函数一次性读取多个样例
_, serialized_example = reader.read(filename_queue)
# 解析读入的一个样例,如果需要解析多个,用parse_example
features = tf.parse_single_example(serialized_example,features = {# TensorFlow提供了两种不同的属性解析方法。一种是tf.FixedLenFeature,它返回一个Tensor# 另一种是tf.VarLenFeature,返回SparseTensor,用于处理系稀疏矩阵# 这里解析数据的格式需要和上面程序写入数据的格式一致'image_raw': tf.FixedLenFeature([], tf.string),'pixels': tf.FixedLenFeature([], tf.int64),'label': tf.FixedLenFeature([], tf.int64),}
)# tf.decode_raw 可以将字符串解析成图像对应的像素数组
image =tf.decode_raw(features['image_raw'], tf.uint8)
label = tf.cast(features['label'], tf.int32)
pixels = tf.cast(features['pixels'], tf.int32)sess = tf.Session()
# 启动多线程处理输入数据
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 每次运行可以读取TFRecord文件中的一个样例,当所有样例读完只有,次样例中程序会从头读取
for i in range(10):print(sess.run([image, label, pixels]))

图像数据处理

在大部分图像识别问题中,通过图像预处理过程可以提高模型的准确率。

TensorFlow图像处理函数

图像编码处理

一张RGB色彩模式的图像可以看成一个三维矩阵,矩阵中的每个数表示了图像上不同位置,不同颜色的亮度。

然而图像在存储时并不是直接记录这些矩阵中的数字,而是记录经过压缩编码之后的结果。所以要将一张图像还原成一个三维矩阵,需要解码的过程。

以下代码示范了如何对jpeg格式图像进行编码/解码。


import tensorflow as tfimport matplotlib.pyplot as plt# 读取图像的原始数据
image_raw_data = tf.gfile.FastGFile('./images/cat.jpg','rb').read()with tf.Session() as sess:# 对图像进行jpeg的格式解码从而得到图像对应的三维矩阵img_data = tf.image.decode_jpeg(image_raw_data)print(img_data.eval()) # 输出解码之后的三维矩阵# 使用plt可视化plt.imshow(img_data.eval())plt.show()# 将表示一张图像的三维矩阵重新按照jpeg格式编码并存入文件encoded_image = tf.image.encode_jpeg(img_data)with tf.gfile.GFile('./images/cat_output.jpg','wb') as f:f.write(encoded_image.eval())

图像大小调整

一般来说,网络上获取的图像大小是不固定的,但是神经网络输入节点的个数是固定的。所以需要将图像的大小统一。

图像大小调整有两种方式,第一种是通过算法使得新的图像尽量保存原始图像上的所有信息。
TensorFlow提供了4种不同的方法,并且将它们封装到了tf.image.resize_images函数。

import tensorflow as tfimport matplotlib.pyplot as plt# 读取图像的原始数据
image_raw_data = tf.gfile.FastGFile('./images/cat.jpg','rb').read()with tf.Session() as sess:# 对图像进行jpeg的格式解码从而得到图像对应的三维矩阵img_data = tf.image.decode_jpeg(image_raw_data)# 首先将图片数据转化为实数类型。将0-255的像素值转化为0.0-1.0的实数img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)# 通过tf.image.resize_images函数调整图像的大小# [300,300]是调整后图像的大小# method指定了调整图像大小的算法resized = tf.image.resize_images(img_data, [300,300], method=0)# 使用plt可视化plt.figure(1)plt.subplot(1, 2, 1) #图一包含1行2列子图,当前画在第一行第一列图上plt.imshow(resized.eval())resized = tf.image.resize_images(img_data, [300,300], method=1)plt.figure(1)plt.subplot(1, 2, 2)#当前画在第一行第2列图上plt.imshow(resized.eval())plt.show()

上面method参数取值有:

Method取值 图像大小调整算法
0 双线性插值法
1 最邻居法
2 双三次插值法
3 面积插值法

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PS6OR0I9-1611829552555)(attachment:method4.png)]

上图是这几种方法的区别,不同的算法调整出来的结果会有细微差别,但不会相差太远。

TensorFlow还提供了对图片进行裁剪或填充。


# 第一个参数是原始图像,后两个参数是调整后的图像大小
croped = tf.image.resize_image_with_crop_or_pad(img_data, 1000, 1000)
padded = tf.image.resize_image_with_crop_or_pad(img_data, 3000, 3000)

TensorFlow还支持通过比例调整图像大小,以下代码给出了一个样例。

central_cropped = tf.image.central_crop(img_data, 0.5)

图像翻转

TensorFlow提供了一些函数来支持对图像的翻转。以下代码实现了将图像上下翻转、左右翻转以及沿对角线翻转:

# 将图像上下翻转
flipped_up_down = tf.image.flip_up_down(img_data)
# 将图像左右翻转
flipped_left_right = tf.image.flip_left_right(img_data)
# 沿对角线翻转
transposed = tf.image.transpose_image(img_data)

在很多图像识别问题中,图像翻转不会影响到识别的结果。于是在训练图像识别的神经网络模型时,可以随机地翻转训练图像,这样训练得到的模型可以识别不同角度的实体。

TensorFlow提供了方便的API完成随机图像翻转的过程(随机就意味着可能不会翻转)。

# 以50%概率上下翻转
flipped_up_down = tf.image.random_flip_up_down(img_data)
# 以50%概率左右翻转
flipped_left_right = tf.image.random_flip_left_right(img_data)

图像色彩调整

调整图像的亮度、对比度、饱和度和色相在很多图像识别应用中都不会影响识别结果。所以在训练时,也可以随机调整这些属性。
以下代码展示了如何修改图像的亮度:

 # 将图像的亮度-0.5
adjusted_minus = tf.image.adjust_brightness(img_data,-0.5)
# 色彩调整的API可能导致像素的实际数值超过0.0-1.0范围,因此在输出最终图像前需要将其截断在0.0-1.0范围内
# 截断过程应该在所有处理完成之后进行
adjusted_minus = tf.clip_by_value(adjusted_minus, 0.0,1.0)
# 将图像的亮度+0.5
adjusted_plus = tf.image.adjust_brightness(img_data,0.5)
# 在[-max_delta,max_delta)的范围随机调整图像的亮度
adjusted = tf.image.random_brightness(img_data,max_delta)

以下代码显示了如何调整图像的对比度。

# 将图像的对比度减少到0.5倍
adjusted_minus = tf.image.adjust_contrast(img_data,0.5)
# 将图像的对比度增加5倍
adjusted_plus = tf.image.adjust_contrast(img_data,5)
# 在[lower,upper]的范围随机调整图像的对比度
adjusted = tf.image.random_contrast(img_data,lower,upper)

以下代码显示了如何调整图像的色相。

# 将图像的色相分别增加0.1、0.3、0.6、0.9
adjusted_0_1 = tf.image.adjust_hue(img_data,0.1)
adjusted_0_3 = tf.image.adjust_hue(img_data,0.3)
adjusted_0_6 = tf.image.adjust_hue(img_data,0.6)
adjusted_0_9 = tf.image.adjust_hue(img_data,0.9)
# 在[-max_delta,max_delta]的范围内随机调整图像的色相
# max_delta的取值在[0,0.5}之间
adjusted_random = tf.image.random_hue(img_data,max_delta)

以下代码显示了如何调整图像的饱和度:

 # 将图像的饱和度-5
adjusted_minus = tf.image.adjust_saturation(img_data,-5)
# 将图像的饱和度+5
adjusted_plus = tf.image.adjust_saturation(img_data,+53)
# 在[lower,upper]的范围内随机调整图像的饱和度
adjusted_random = tf.image.random_saturation(img_data,lower,upper)

除了上面这些功能,TensorFlow还提供API来完成图像标准化的过程。这个操作就是将图像上的亮度均值变为0,方差变为1:

# 将代表一张图像的三维矩阵中的数字均值变为0,方差变为1
adjusted= tf.image.per_image_standardization(img_data)

处理标注框

在目标检测数据集中,图像中需要关注的物体通常会被标注框圈出来。
TensorFlow提供了一些工具来处理标注框,以下代码展示了如何做:

# 将图像缩小一些,这样可视化能让标注框更加清楚
img_data= tf.image.resize_images(img_data,[180,267],method=1)
# 将解码后的图像矩阵增加一维
batched = tf.expand_dims(tf.image.convert_image_dtype(img_data, tf.float32),0)
# 给出每一张图像的所有标注框,一个标注框有4个数字,分别代表[y_min,x_min,y_max,x_max]。
# 这里给出的数字都是图像的相对位置
boxes = tf.constant([[[0.05,0.05,0.9,0.7],[0.35,0.47,0.5,0.56]]])
result = tf.image.draw_bounding_boxes(batched, boxes)

和随机翻转图像、随机调整颜色类似,随机截取图像上有信息含量的部分也一个提模型健壮性的一种方式。
这样可以使训练得到的模型不受被识别物体大小的影响。

以下程序中展示了如何通过 tf.image.sample_distorted_bounding_box 函数来完成随机截取图像的过程:

import tensorflow as tfimport matplotlib.pyplot as plt# 读取图像的原始数据
image_raw_data = tf.gfile.FastGFile('./images/cat.jpg','rb').read()plt.rcParams['font.sans-serif'] = ['KaiTi'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题def trim_axs(axs, N):axs = axs.flatfor ax in axs[N:]:ax.remove()return axs[:N]with tf.Session() as sess:# 对图像进行jpeg的格式解码从而得到图像对应的三维矩阵img_data = tf.image.decode_jpeg(image_raw_data)# 首先将图片数据转化为实数类型。将0-255的像素值转化为0.0-1.0的实数# img_data = tf.image.convert_image_dtype(img_data, dtype=tf.float32)# 将图像缩小一些,这样可视化能让标注框更加清楚img_data= tf.image.resize_images(img_data,[180,267],method=1)# 将解码后的图像矩阵增加一维# 给出每一张图像的所有标注框,一个标注框有4个数字,分别代表[y_min,x_min,y_max,x_max]。# 这里给出的数字都是图像的相对位置boxes = tf.constant([[[0.05,0.05,0.9,0.7],[0.35,0.47,0.5,0.56]]])# min_object_covered=0.4表示截图部分至少包含某个标注框40%的内容begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(tf.shape(img_data), bounding_boxes=boxes,min_object_covered=0.4)# 通过标注框可视化随机截取到的图像batched = tf.expand_dims(tf.image.convert_image_dtype(img_data, tf.float32),0)image_with_box = tf.image.draw_bounding_boxes(batched, bbox_for_draw)# 截取随机出来的图像distored_image = tf.slice(img_data,begin,size)axs = plt.figure(figsize=(9, 4.5),constrained_layout=True).subplots(1,2)axs = trim_axs(axs,2)axs[0].set_xlabel('在图像中随机加入的标注框')axs[0].imshow(image_with_box[0].eval())axs[1].set_xlabel('通过这个标注框截取的图像')axs[1].imshow(distored_image.eval())plt.show()

图像处理完整样例

在解决真实的图像识别问题时,一般会同时使用多种处理方法。

以下程序完成了从图像片段截取,到图像大小调整再到图像翻转及色彩调整的整个图像预处理过程。

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt# 给定一张图片,随机调整色彩
# 因为调整亮度、对比度、饱和度和色相的顺序会影响最后得到的结果,所以可以定义多种不通的训练
def distort_color(image,color_ordering=0):if color_ordering == 0:image = tf.image.random_brightness(image, max_delta=32. / 255.)iamge = tf.image.random_saturation(image,lower=0.5,upper=1.5)image = tf.image.random_hue(image, max_delta=0.2)image = tf.image.random_contrast(image,lower=0.5,upper=1.5)elif color_ordering == 1:iamge = tf.image.random_saturation(image,lower=0.5,upper=1.5)image = tf.image.random_brightness(image, max_delta=32. / 255.)image = tf.image.random_contrast(image,lower=0.5,upper=1.5)image = tf.image.random_hue(image, max_delta=0.2)else:image = tf.image.random_contrast(image,lower=0.5,upper=1.5)image = tf.image.random_hue(image, max_delta=0.2)image = tf.image.random_brightness(image, max_delta=32. / 255.)iamge = tf.image.random_saturation(image,lower=0.5,upper=1.5)return tf.clip_by_value(image,0.0,1.0)# 给定一张解码后的图像、目标图像的尺寸以及图像上的标注框,此函数可以对给出的图像进行预处理。
# 输入是原始的训练图像,输出是神经网络模型的输入层,一般只需要处理训练数据。
def preprocess_for_train(image,height,width,bbox):# 如果没有提供标注框,则认为整个图像就是需要关注的部分if bbox is None:bbox = tf.constant([0.0,0.0,1.0,1.0],dtype=tf.float32,shape=[1,1,4])# 转换图像张量的类型if image.dtype != tf.float32:image = tf.image.convert_image_dtype(image, dtype=tf.float32)# 随机截取图像,减小需要关注的物体大小对图像识别算法的影响。bbox_begin, bbox_size, _ = tf.image.sample_distorted_bounding_box(tf.shape(image), bounding_boxes=bbox)distorted_image = tf.slice(image, bbox_begin, bbox_size)# 将随机截取的图像调整为神经网络输入层的大小。大小调整的算法是随机选择的distorted_image = tf.image.resize_images(distorted_image, [height, width], method=np.random.randint(4))# 随机左右翻转图像distorted_image = tf.image.random_flip_left_right(distorted_image, np.random.randint(2))# 使用一种随机的顺序调整图像色彩distorted_image = distort_color(distorted_image, np.random.randint(3))return distorted_image# 读取图像的原始数据
image_raw_data = tf.gfile.FastGFile('./images/cat.jpg','rb').read()plt.rcParams['font.sans-serif'] = ['KaiTi'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False # 解决保存图像是负号'-'显示为方块的问题def trim_axs(axs, N):axs = axs.flatfor ax in axs[N:]:ax.remove()return axs[:N]with tf.Session() as sess:# 对图像进行jpeg的格式解码从而得到图像对应的三维矩阵img_data = tf.image.decode_jpeg(image_raw_data)boxes = tf.constant([[[0.05,0.05,0.9,0.7],[0.35,0.47,0.5,0.56]]])axs = plt.figure(figsize=(9, 4.5),constrained_layout=True).subplots(2,3)    axs = trim_axs(axs,6)# 运行6次获得6种不同的图像for i in range(6):# 将图像的尺寸调整为299x299result = preprocess_for_train(img_data, 299, 299, boxes)axs[i].imshow(result.eval())plt.show()

这样一张训练图像就可以衍生出很多训练样本。

多线程输入数据处理框架

上面介绍的这些预处理会减慢整个训练过程,为了避免预处理成为模型训练效率的瓶颈,TensorFlow提供了一套多线程处理输入数据的框架。

#mermaid-svg-wudu4OWNsgubyfwp .label{font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family);fill:#333;color:#333}#mermaid-svg-wudu4OWNsgubyfwp .label text{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .node rect,#mermaid-svg-wudu4OWNsgubyfwp .node circle,#mermaid-svg-wudu4OWNsgubyfwp .node ellipse,#mermaid-svg-wudu4OWNsgubyfwp .node polygon,#mermaid-svg-wudu4OWNsgubyfwp .node path{fill:#ECECFF;stroke:#9370db;stroke-width:1px}#mermaid-svg-wudu4OWNsgubyfwp .node .label{text-align:center;fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .node.clickable{cursor:pointer}#mermaid-svg-wudu4OWNsgubyfwp .arrowheadPath{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .edgePath .path{stroke:#333;stroke-width:1.5px}#mermaid-svg-wudu4OWNsgubyfwp .flowchart-link{stroke:#333;fill:none}#mermaid-svg-wudu4OWNsgubyfwp .edgeLabel{background-color:#e8e8e8;text-align:center}#mermaid-svg-wudu4OWNsgubyfwp .edgeLabel rect{opacity:0.9}#mermaid-svg-wudu4OWNsgubyfwp .edgeLabel span{color:#333}#mermaid-svg-wudu4OWNsgubyfwp .cluster rect{fill:#ffffde;stroke:#aa3;stroke-width:1px}#mermaid-svg-wudu4OWNsgubyfwp .cluster text{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp div.mermaidTooltip{position:absolute;text-align:center;max-width:200px;padding:2px;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family);font-size:12px;background:#ffffde;border:1px solid #aa3;border-radius:2px;pointer-events:none;z-index:100}#mermaid-svg-wudu4OWNsgubyfwp .actor{stroke:#ccf;fill:#ECECFF}#mermaid-svg-wudu4OWNsgubyfwp text.actor>tspan{fill:#000;stroke:none}#mermaid-svg-wudu4OWNsgubyfwp .actor-line{stroke:grey}#mermaid-svg-wudu4OWNsgubyfwp .messageLine0{stroke-width:1.5;stroke-dasharray:none;stroke:#333}#mermaid-svg-wudu4OWNsgubyfwp .messageLine1{stroke-width:1.5;stroke-dasharray:2, 2;stroke:#333}#mermaid-svg-wudu4OWNsgubyfwp #arrowhead path{fill:#333;stroke:#333}#mermaid-svg-wudu4OWNsgubyfwp .sequenceNumber{fill:#fff}#mermaid-svg-wudu4OWNsgubyfwp #sequencenumber{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp #crosshead path{fill:#333;stroke:#333}#mermaid-svg-wudu4OWNsgubyfwp .messageText{fill:#333;stroke:#333}#mermaid-svg-wudu4OWNsgubyfwp .labelBox{stroke:#ccf;fill:#ECECFF}#mermaid-svg-wudu4OWNsgubyfwp .labelText,#mermaid-svg-wudu4OWNsgubyfwp .labelText>tspan{fill:#000;stroke:none}#mermaid-svg-wudu4OWNsgubyfwp .loopText,#mermaid-svg-wudu4OWNsgubyfwp .loopText>tspan{fill:#000;stroke:none}#mermaid-svg-wudu4OWNsgubyfwp .loopLine{stroke-width:2px;stroke-dasharray:2, 2;stroke:#ccf;fill:#ccf}#mermaid-svg-wudu4OWNsgubyfwp .note{stroke:#aa3;fill:#fff5ad}#mermaid-svg-wudu4OWNsgubyfwp .noteText,#mermaid-svg-wudu4OWNsgubyfwp .noteText>tspan{fill:#000;stroke:none}#mermaid-svg-wudu4OWNsgubyfwp .activation0{fill:#f4f4f4;stroke:#666}#mermaid-svg-wudu4OWNsgubyfwp .activation1{fill:#f4f4f4;stroke:#666}#mermaid-svg-wudu4OWNsgubyfwp .activation2{fill:#f4f4f4;stroke:#666}#mermaid-svg-wudu4OWNsgubyfwp .mermaid-main-font{font-family:"trebuchet ms", verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp .section{stroke:none;opacity:0.2}#mermaid-svg-wudu4OWNsgubyfwp .section0{fill:rgba(102,102,255,0.49)}#mermaid-svg-wudu4OWNsgubyfwp .section2{fill:#fff400}#mermaid-svg-wudu4OWNsgubyfwp .section1,#mermaid-svg-wudu4OWNsgubyfwp .section3{fill:#fff;opacity:0.2}#mermaid-svg-wudu4OWNsgubyfwp .sectionTitle0{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .sectionTitle1{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .sectionTitle2{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .sectionTitle3{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .sectionTitle{text-anchor:start;font-size:11px;text-height:14px;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp .grid .tick{stroke:#d3d3d3;opacity:0.8;shape-rendering:crispEdges}#mermaid-svg-wudu4OWNsgubyfwp .grid .tick text{font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp .grid path{stroke-width:0}#mermaid-svg-wudu4OWNsgubyfwp .today{fill:none;stroke:red;stroke-width:2px}#mermaid-svg-wudu4OWNsgubyfwp .task{stroke-width:2}#mermaid-svg-wudu4OWNsgubyfwp .taskText{text-anchor:middle;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp .taskText:not([font-size]){font-size:11px}#mermaid-svg-wudu4OWNsgubyfwp .taskTextOutsideRight{fill:#000;text-anchor:start;font-size:11px;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp .taskTextOutsideLeft{fill:#000;text-anchor:end;font-size:11px}#mermaid-svg-wudu4OWNsgubyfwp .task.clickable{cursor:pointer}#mermaid-svg-wudu4OWNsgubyfwp .taskText.clickable{cursor:pointer;fill:#003163 !important;font-weight:bold}#mermaid-svg-wudu4OWNsgubyfwp .taskTextOutsideLeft.clickable{cursor:pointer;fill:#003163 !important;font-weight:bold}#mermaid-svg-wudu4OWNsgubyfwp .taskTextOutsideRight.clickable{cursor:pointer;fill:#003163 !important;font-weight:bold}#mermaid-svg-wudu4OWNsgubyfwp .taskText0,#mermaid-svg-wudu4OWNsgubyfwp .taskText1,#mermaid-svg-wudu4OWNsgubyfwp .taskText2,#mermaid-svg-wudu4OWNsgubyfwp .taskText3{fill:#fff}#mermaid-svg-wudu4OWNsgubyfwp .task0,#mermaid-svg-wudu4OWNsgubyfwp .task1,#mermaid-svg-wudu4OWNsgubyfwp .task2,#mermaid-svg-wudu4OWNsgubyfwp .task3{fill:#8a90dd;stroke:#534fbc}#mermaid-svg-wudu4OWNsgubyfwp .taskTextOutside0,#mermaid-svg-wudu4OWNsgubyfwp .taskTextOutside2{fill:#000}#mermaid-svg-wudu4OWNsgubyfwp .taskTextOutside1,#mermaid-svg-wudu4OWNsgubyfwp .taskTextOutside3{fill:#000}#mermaid-svg-wudu4OWNsgubyfwp .active0,#mermaid-svg-wudu4OWNsgubyfwp .active1,#mermaid-svg-wudu4OWNsgubyfwp .active2,#mermaid-svg-wudu4OWNsgubyfwp .active3{fill:#bfc7ff;stroke:#534fbc}#mermaid-svg-wudu4OWNsgubyfwp .activeText0,#mermaid-svg-wudu4OWNsgubyfwp .activeText1,#mermaid-svg-wudu4OWNsgubyfwp .activeText2,#mermaid-svg-wudu4OWNsgubyfwp .activeText3{fill:#000 !important}#mermaid-svg-wudu4OWNsgubyfwp .done0,#mermaid-svg-wudu4OWNsgubyfwp .done1,#mermaid-svg-wudu4OWNsgubyfwp .done2,#mermaid-svg-wudu4OWNsgubyfwp .done3{stroke:grey;fill:#d3d3d3;stroke-width:2}#mermaid-svg-wudu4OWNsgubyfwp .doneText0,#mermaid-svg-wudu4OWNsgubyfwp .doneText1,#mermaid-svg-wudu4OWNsgubyfwp .doneText2,#mermaid-svg-wudu4OWNsgubyfwp .doneText3{fill:#000 !important}#mermaid-svg-wudu4OWNsgubyfwp .crit0,#mermaid-svg-wudu4OWNsgubyfwp .crit1,#mermaid-svg-wudu4OWNsgubyfwp .crit2,#mermaid-svg-wudu4OWNsgubyfwp .crit3{stroke:#f88;fill:red;stroke-width:2}#mermaid-svg-wudu4OWNsgubyfwp .activeCrit0,#mermaid-svg-wudu4OWNsgubyfwp .activeCrit1,#mermaid-svg-wudu4OWNsgubyfwp .activeCrit2,#mermaid-svg-wudu4OWNsgubyfwp .activeCrit3{stroke:#f88;fill:#bfc7ff;stroke-width:2}#mermaid-svg-wudu4OWNsgubyfwp .doneCrit0,#mermaid-svg-wudu4OWNsgubyfwp .doneCrit1,#mermaid-svg-wudu4OWNsgubyfwp .doneCrit2,#mermaid-svg-wudu4OWNsgubyfwp .doneCrit3{stroke:#f88;fill:#d3d3d3;stroke-width:2;cursor:pointer;shape-rendering:crispEdges}#mermaid-svg-wudu4OWNsgubyfwp .milestone{transform:rotate(45deg) scale(0.8, 0.8)}#mermaid-svg-wudu4OWNsgubyfwp .milestoneText{font-style:italic}#mermaid-svg-wudu4OWNsgubyfwp .doneCritText0,#mermaid-svg-wudu4OWNsgubyfwp .doneCritText1,#mermaid-svg-wudu4OWNsgubyfwp .doneCritText2,#mermaid-svg-wudu4OWNsgubyfwp .doneCritText3{fill:#000 !important}#mermaid-svg-wudu4OWNsgubyfwp .activeCritText0,#mermaid-svg-wudu4OWNsgubyfwp .activeCritText1,#mermaid-svg-wudu4OWNsgubyfwp .activeCritText2,#mermaid-svg-wudu4OWNsgubyfwp .activeCritText3{fill:#000 !important}#mermaid-svg-wudu4OWNsgubyfwp .titleText{text-anchor:middle;font-size:18px;fill:#000;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp g.classGroup text{fill:#9370db;stroke:none;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family);font-size:10px}#mermaid-svg-wudu4OWNsgubyfwp g.classGroup text .title{font-weight:bolder}#mermaid-svg-wudu4OWNsgubyfwp g.clickable{cursor:pointer}#mermaid-svg-wudu4OWNsgubyfwp g.classGroup rect{fill:#ECECFF;stroke:#9370db}#mermaid-svg-wudu4OWNsgubyfwp g.classGroup line{stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp .classLabel .box{stroke:none;stroke-width:0;fill:#ECECFF;opacity:0.5}#mermaid-svg-wudu4OWNsgubyfwp .classLabel .label{fill:#9370db;font-size:10px}#mermaid-svg-wudu4OWNsgubyfwp .relation{stroke:#9370db;stroke-width:1;fill:none}#mermaid-svg-wudu4OWNsgubyfwp .dashed-line{stroke-dasharray:3}#mermaid-svg-wudu4OWNsgubyfwp #compositionStart{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp #compositionEnd{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp #aggregationStart{fill:#ECECFF;stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp #aggregationEnd{fill:#ECECFF;stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp #dependencyStart{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp #dependencyEnd{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp #extensionStart{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp #extensionEnd{fill:#9370db;stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp .commit-id,#mermaid-svg-wudu4OWNsgubyfwp .commit-msg,#mermaid-svg-wudu4OWNsgubyfwp .branch-label{fill:lightgrey;color:lightgrey;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp .pieTitleText{text-anchor:middle;font-size:25px;fill:#000;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp .slice{font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp g.stateGroup text{fill:#9370db;stroke:none;font-size:10px;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp g.stateGroup text{fill:#9370db;fill:#333;stroke:none;font-size:10px}#mermaid-svg-wudu4OWNsgubyfwp g.statediagram-cluster .cluster-label text{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp g.stateGroup .state-title{font-weight:bolder;fill:#000}#mermaid-svg-wudu4OWNsgubyfwp g.stateGroup rect{fill:#ECECFF;stroke:#9370db}#mermaid-svg-wudu4OWNsgubyfwp g.stateGroup line{stroke:#9370db;stroke-width:1}#mermaid-svg-wudu4OWNsgubyfwp .transition{stroke:#9370db;stroke-width:1;fill:none}#mermaid-svg-wudu4OWNsgubyfwp .stateGroup .composit{fill:white;border-bottom:1px}#mermaid-svg-wudu4OWNsgubyfwp .stateGroup .alt-composit{fill:#e0e0e0;border-bottom:1px}#mermaid-svg-wudu4OWNsgubyfwp .state-note{stroke:#aa3;fill:#fff5ad}#mermaid-svg-wudu4OWNsgubyfwp .state-note text{fill:black;stroke:none;font-size:10px}#mermaid-svg-wudu4OWNsgubyfwp .stateLabel .box{stroke:none;stroke-width:0;fill:#ECECFF;opacity:0.7}#mermaid-svg-wudu4OWNsgubyfwp .edgeLabel text{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .stateLabel text{fill:#000;font-size:10px;font-weight:bold;font-family:'trebuchet ms', verdana, arial;font-family:var(--mermaid-font-family)}#mermaid-svg-wudu4OWNsgubyfwp .node circle.state-start{fill:black;stroke:black}#mermaid-svg-wudu4OWNsgubyfwp .node circle.state-end{fill:black;stroke:white;stroke-width:1.5}#mermaid-svg-wudu4OWNsgubyfwp #statediagram-barbEnd{fill:#9370db}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-cluster rect{fill:#ECECFF;stroke:#9370db;stroke-width:1px}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-cluster rect.outer{rx:5px;ry:5px}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-state .divider{stroke:#9370db}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-state .title-state{rx:5px;ry:5px}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-cluster.statediagram-cluster .inner{fill:white}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-cluster.statediagram-cluster-alt .inner{fill:#e0e0e0}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-cluster .inner{rx:0;ry:0}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-state rect.basic{rx:5px;ry:5px}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-state rect.divider{stroke-dasharray:10,10;fill:#efefef}#mermaid-svg-wudu4OWNsgubyfwp .note-edge{stroke-dasharray:5}#mermaid-svg-wudu4OWNsgubyfwp .statediagram-note rect{fill:#fff5ad;stroke:#aa3;stroke-width:1px;rx:0;ry:0}:root{--mermaid-font-family: '"trebuchet ms", verdana, arial';--mermaid-font-family: "Comic Sans MS", "Comic Sans", cursive}#mermaid-svg-wudu4OWNsgubyfwp .error-icon{fill:#522}#mermaid-svg-wudu4OWNsgubyfwp .error-text{fill:#522;stroke:#522}#mermaid-svg-wudu4OWNsgubyfwp .edge-thickness-normal{stroke-width:2px}#mermaid-svg-wudu4OWNsgubyfwp .edge-thickness-thick{stroke-width:3.5px}#mermaid-svg-wudu4OWNsgubyfwp .edge-pattern-solid{stroke-dasharray:0}#mermaid-svg-wudu4OWNsgubyfwp .edge-pattern-dashed{stroke-dasharray:3}#mermaid-svg-wudu4OWNsgubyfwp .edge-pattern-dotted{stroke-dasharray:2}#mermaid-svg-wudu4OWNsgubyfwp .marker{fill:#333}#mermaid-svg-wudu4OWNsgubyfwp .marker.cross{stroke:#333}:root { --mermaid-font-family: "trebuchet ms", verdana, arial;}#mermaid-svg-wudu4OWNsgubyfwp {color: rgba(0, 0, 0, 0.75);font: ;}

指定原始数据的文件列表
创建文件列表队列
从文件中读取数据
数据预处理
整理成Batch作为神经网络输入

上图总结了一个经典的输入数据处理的流程。在下面的各小节中,将依次介绍这个流程的不同部分。

队列与多线性

队列和变量类似,都是计算图上有状态的节点。
对于队列,修改队列状态的操作主要有EnqueueEnqueueManyDequeue
以下程序展示了如何使用这些函数来操作队列:

import tensorflow as tf# 创建一个FIFO队列,指定队列中最多可以保存两个元素
q = tf.FIFOQueue(2, 'int32')
# 使用enqueue_many函数来初始化队列中的元素,和变量初始化类似,在使用队列之前需要明确的调用这个初始化过程
init = q.enqueue_many(([0,10],))
# 使用Dequeue函数将队列的第一个元素出队,这个元素的值被保存在变量x中
x = q.dequeue()
# 将得到的值加1
y = x + 1
# 将加1后的值再重新加入队列
q_inc = q.enqueue([y])with tf.Session() as sess:# 运行队列初始化操作init.run()for _ in range(5):# 运行q_inc执行数据出队、出队的元素+1、重新加入队列的整个过程v, _ = sess.run([x, q_inc])# 打印出队元素的取值print(v)

输出

0
10
1
11
2

TensorFlow中提供了FIFOQueueRandomShuffleQueue两种队列。RandomShuffleQueue会将队列中的元素打乱,每次出队时从当前队列所有元素中随机选择一个。

在TensorFlow中,队列不仅是一种数据结构,还是异步计算张量取值的一个重要机制(生产者-消费者模式)。
TensorFlow提供了tf.Coordinatortf.QueueRunner两个类来完成多线性协同的功能。
tf.Coordinator主要用于协同多个线程一起停止,并提供了should_stoprequest_stopjoin三个函数。
在启动线程之前,需要先声明一个tf.Coordinator类,并将这个类传入每一个创建的线程中。
启动的线程需要一直查询tf.Coordinator类中提供的should_stop函数,当这个函数的返回值为True时,则当前线程也需要退出。
每一个启动的线程都可以通过调用request_stop函数来通知其他线程退出。
当某一个线程调用request_stop之后,should_stop的返回值被设置为True,这样其他的线程就可以同时停止了。
如下所示:

import tensorflow as tfimport numpy as np
import threading
import time# 线程中运行的程序,这个程序每隔1秒判断是否需要停止并打印自己的ID
def MyLoop(coord, worker_id):# 使用tf.Coordinator类提供的协同工具判断当前线程是否需要停止while not coord.should_stop():# 随机停止所有的线程if np.random.rand() < 0.1:print('Stopping from id: %d\n' % worker_id)# 调用request_stop函数来通知其他线程停止coord.request_stop()else:# 打印当前线程的IDprint('Working on id:%d\n' % worker_id)# 暂停1秒time.sleep(1)# 声明一个tf.train.Coordinator类协同多个线程
coord = tf.train.Coordinator()
# 声明创建5个线程
threads = [threading.Thread(target=MyLoop, args=(coord, i,)) for i in range(5)]
# 启动所有的线程
for t in threads:t.start()
# 等待所有线程退出
coord.join(threads)

输出

Working on id:0Working on id:1Working on id:2Stopping from id: 3

每个线程启动后,有90%的概率打印自己的id,有10%的概率停掉所有线程。
从输出看,id为3的线程触发了停止操作,它没有打印自己的id,同时id为4的线程可能还未打印就碰到线程关闭了。

tf.QueueRunner主要用于启动多个线程来操作同一个队列,启动的这些线程可以通过tf.Coordinator来统一管理。如下所示:

import tensorflow as tf# 队列中最多100个元素
queue = tf.FIFOQueue(100,'float')
# 定义队列的入队操作
enqueue_op = queue.enqueue([tf.random_normal([1])])
# 使用tf.train.QueueRunner来创建多个线程运行队列的入队操作
# 第一个参数给出了被操作的队列
# [enqueue_op] * 5表示需要启动5个线程,每个线程中运行的是enqueue_op操作
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
# 将定义的QueueRunner加入TensorFlow计算图上指定的集合
# tf.train.add_queue_runner函数没有指定集合的话,则加入默认集合tf.GraphKeys.QUEUE_RUNNERS
# 下面的函数就是将刚刚定义的qr加入模型的tf.GraphKeys.QUEUE_RUNNERS集合
tf.train.add_queue_runner(qr)
# 定义出队操作
out_tensor = queue.dequeue()with tf.Session() as sess:# 使用tf.train.Coordinator来协同启动的线程coord = tf.train.Coordinator()# 使用QueueRunner时,需要明确调用start_queue_runners来启动所有线程。# 否则因为没有线程运行入队操作,当调用出队操作时,程序会一直等待入队操作被运行。# start_queue_runners会默认启动tf.GraphKeys.QUEUE_RUNNERS集合中所有的QueueRunner。# 因为这个函数只支持启动指定集合中的QueueRunner,所以一般来说add_queue_runner函数和# start_queue_runners函数会指定同一个集合threads = tf.train.start_queue_runners(sess=sess, coord=coord)# 获取队列中的取值for _ in range(3):print(sess.run(out_tensor)[0])# 使用Coordinator来停止所有的线程coord.request_stop()coord.join(threads)

输出

-0.9941363
-0.91206753
-1.8396323

以上程序将启动5个线程来执行队列入队操作,其中每个线程都是将随机数写入队列。于是在每次运行出队操作时,可以得到一个随机数。

输入文件队列

本节将介绍如何使用TensorFlow中的队列管理输入文件列表。
假设所有的输入数据都已经整理成了TFRecord格式。

TensorFlow提供了 tf.train.match_filenames_once 函数来获取符合一个正则表达式的所有文件,
得到的文件列表可以通过 tf.train.string_input_producer 函数进行有效的管理。

tf.train.string_input_producer 函数会使用初始化时提供的文件列表创建一个输入队列,
输入队列中原始的元素为文件列表中的所有文件。
创建好的输入队列可以作为文件读取函数的参数。每次调用文件读取函数时,该函数会先判断当前是否己有打开的文件可读,如果没有或者打开的文件己经读完,这个函数会从输入队列中出队一个文件并从这个文件中读取数据。

通过设置 shuffle 参数, tf.train.string_input_producer 函数支持随机打乱文件列表中文件出队的顺序 。
shuffle 参数为 True 时,文件在加入队列之前会被打乱顺序,所以出队的顺序也是随机的。
随机打乱文件顺序以及加入输入队列的过程会跑在一个单独的线程上,这样不会影响获取文件的速度。
tf.train.string_input_producer 生成的输入队列可以同时被多个文件读取线程操作,而且输入队列会将队列中的文件均匀地分给不同的线程,不出现有些文件被处理过多次而有些文件还没有被处理过的情况。

当一个输入队列中的所有文件都被处理完后,它会将初始化时提供的文件列表中的文件全部重新加入队列。
tf.train.string_input_producer 函数可以设 num_epochs 参数来限制加载初始文件列表的最大轮数。
当所有文件都己经被使用了设定的轮数后,如果继续尝试读取新的文件,
输入队列会报 OutOfRange 的错误。
在测试神经网络模型时,因为所有测试数据只需要使用一次,所以可以将 num_epochs 参数设置为1
这样在计算完一轮之后程序将自动停止。在展示 tf.train.match_ filenames_oncetf.train.string_ input_producer 函数的使用方法之前,下面先给出一个简单的程序来生成样例数据:

import tensorflow as tf# 创建TFRecord文件的帮助函数
def _int64_feature(value):return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))# 模型海量数据情况下将数据写入不同的文件。num_shards定义了总共写入多少个文件
# instances_per_shard定义了每个文件中有多少个数据
num_shards = 2
instances_per_shard = 2
for i in range(num_shards):# 将数据分为多个文件时,可以将不同文件以类似0000n-of-0000m的后缀区分。# 其中m表示了数据总共被存在了多少个文件中,n表示当前文件的编号。# 这样的方式既方便了通过正则获取文件列表,又在文件名中加入了更多的信息。filename = ('./files/data.tfrecords-%.5d-of-%.5d' % (i, num_shards))writer = tf.python_io.TFRecordWriter(filename)# 将数据封装成Example结构并写入TFRecord文件for j in range(instances_per_shard):# Example结构仅包含当前样例属于第几个文件以及是当前文件的第几个样本example = tf.train.Example(features=tf.train.Features(feature={'i': _int64_feature(i),'j': _int64_feature(j)}))writer.write(example.SerializeToString())writer.close()

程序运之后,在指定的目录下生产两个文件:data.tfrecords-00000-of-00002data.tfrecords-00001-of-00002

每个文件中存储了两个样例。在生成了样例数据之后,以下代码展示了tf.train.match_filenames_once函数和tf.train.string_input_producer函数的使用方法。

import tensorflow as tf
# 获取文件列表
files = tf.train.match_filenames_once('./files/data.tfrecords-*')# 通过string_input_producer创建输入队列,输入队列中的文件列表为match_filenames_once函数获取的文件列表。
# 这里将shuffle参数设为False来避免随机打乱文件的顺序
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 读取并解析一个样本
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features = {'i': tf.FixedLenFeature([], tf.int64),'j': tf.FixedLenFeature([], tf.int64),}
)with tf.Session() as sess:# 使用match_filenames_once函数时需要初始化tf.local_variables_initializer().run()print(sess.run(files))'''打印:[b'.\\files\\data.tfrecords-00000-of-00002'b'.\\files\\data.tfrecords-00001-of-00002']'''# 声明Coordinator类来协同不同线程coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)# 多次执行获取数据的操作for i in range(6):print(sess.run([features['i'] , features['j']]))coord.request_stop()coord.join(threads)

输出

[0, 0]
[0, 1]
[1, 0]
[1, 1]
[0, 0]
[0, 1]

在不打乱文件列表的情况下,会依次读出样例数据中的每一个样例。而且当所有样例都被读完之后,程序会自动从头开始。

组合训练数据(batching)

将多个输入样例组成一个batch可以提高模型训练的效率。
TensorFlow提供了tf.train.batchtf.train.shuffle_batch函数来将单个的样例组织成batch的形式输出。

这两个函数都会生成一个队列,队列的入队操作是生成单个样例的方法,而每次出队得到的是一个batch的样例。它们唯一的区别在于是否会将数据顺序打乱。以下代码展示了这两个函数的使用方法:

import tensorflow as tf
# 获取文件列表
files = tf.train.match_filenames_once('./files/data.tfrecords-*')# 通过string_input_producer创建输入队列,输入队列中的文件列表为match_filenames_once函数获取的文件列表。
# 这里将shuffle参数设为False来避免随机打乱文件的顺序
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 读取并解析一个样本
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features = {'i': tf.FixedLenFeature([], tf.int64),'j': tf.FixedLenFeature([], tf.int64),}
)example, label = features['i'], features['j']# 一个batch中样例的个数
batch_size = 3
# 组合样例的队列中最多可以存储的样例个数。
capacity = 1000 + 3 * batch_size# 使用tf.train.batch函数来组合样例
# [example, label]参数给出了需要组合的元素,
# example和label分别代表训练样本和这个样本对应的正确标签。
# batch_size参数给出了每个batch中样例的个数
# capacity给出了队列的最大容量,当队列长度等于容量时,将暂停入队操作,
# 而只是等待元素出队。当元素个数小于容量时,将自动重启入队操作
example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size,capacity=capacity
)with tf.Session() as sess:tf.global_variables_initializer().run()tf.local_variables_initializer().run()coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess,coord=coord)#  获取并打印组合之后的样例for i in range(2):cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])print(cur_example_batch, cur_label_batch)coord.request_stop()coord.join(threads)

输出

[0 0 1] [0 1 0] #对应第0个到第2个的example和 label值
[1 0 0] [1 0 1]# 对应第3个到第1个example和label值

从这个输出可以看到tf.train.batch函数可以将单个的数据组成3个一组的batch。

example,label中读到的数据依次为:

example: 0, label: 0
example: 0, label: 1
example: 1, label: 0
example: 1, label: 1

这是因为tf.train.batch函数不会随机打乱顺序,所以组合之后得到的数据成立上面的输出。

下面一段代码展示了tf.train.shuffle_batch函数的使用方法。

import tensorflow as tf
# 获取文件列表
files = tf.train.match_filenames_once('./files/data.tfrecords-*')# 通过string_input_producer创建输入队列,输入队列中的文件列表为match_filenames_once函数获取的文件列表。
# 这里将shuffle参数设为False来避免随机打乱文件的顺序
filename_queue = tf.train.string_input_producer(files, shuffle=False)
# 读取并解析一个样本
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features = {'i': tf.FixedLenFeature([], tf.int64),'j': tf.FixedLenFeature([], tf.int64),}
)example, label = features['i'], features['j']
# min_after_dequeue参数限制了出队时队列中元素的最少个数
# 当队列中元素太少时,随机打乱顺序的作用就不大了。
# 当出队函数被调用,但是队列中元素不够时,出队操作将等待更多的元素入队才能完成
example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size,capacity=capacity,min_after_dequeue=30
)with tf.Session() as sess:tf.global_variables_initializer().run()tf.local_variables_initializer().run()coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess,coord=coord)#  获取并打印组合之后的样例for i in range(2):cur_example_batch, cur_label_batch = sess.run([example_batch, label_batch])print(cur_example_batch, cur_label_batch)coord.request_stop()coord.join(threads)

输出

[0 1 0] [1 0 0]
[1 0 1] [0 1 0]

从输出可以看到,得到的样例顺序已经被打乱了。

tf.train.batch 函数和tf. train .shuffle_batch 函数除了可以将单个训练数据整理成输入batch ,
也提供了并行化处理输入数据的方法。
tf.train.batch 函数和tf.train.shuftle_batch 函数并行化的方式一致,
所以在本节中仅以应用得更多的tf.train.shuffle_batch 函数为例。
通过设置tf.train.shuftle_batch函数中的num_threads 参数,可以指定多个线程同时执行入队操作。

tf.train.shuftle_batch函数的入队操作就是数据读取以及预处理的过程。当num_threads 参数大于1时,多个线程会同时读取一个文件中的不同样例并进行预处理。
如果需要多个线程处理不同文件中的样例时,可以使用tf.train.shuftle_batch_join 函数。
此函数会从输入文件队列中获取不同的文件分配给不同的线程。

一般来说,输入文件队列是通过上面介绍的tf.train.string_input_producer函数生成的。这个函数会平均分配文件以保证不同文件中的数据会被尽量平均地使用。

tf.train.shuftle_batch函数和tf.train.shuftle_batch_join函数都可以完成多线程并行的方式来进行数据预处理,但它们各有优劣。对于tf.train.shuftle_batch函数,不同线程会读取同一个文件。
如果一个文件中的样例比较相似(比如都属于同一个类别),那么神经网络的训练效果有可能会受到影响。
所以在使用tf.train.shuftle_batch 函数时,需要尽量将同一个TFRecord 文件中的样例随机打乱。而使用tf.train.shuftle_batch_join函数时,不同线程会读取不同文件。如果读取数据的线程数比总文件数还大,那么多个线程可能会读取同一个文件中相近部分的数据。
而且多个线程读取多个文件可能导致过多的硬盘寻址,从而使得读取效率降低。不同的井行化方式各有所长,具体采用哪一种方法需要根据具体情况来确定。

输入数据处理框架

在前面的小节中已经介绍了流程图中的所有步骤。在这一节将把这些步骤串成一个完整的程序来处理输入数据:

import tensorflow as tf# 创建文件列表,并通过文件列表创建输入文件队列。
# 同一所有原始数据的格式并将它们存储到TFRecord文件中
# 下面给出的文件列表应该包含所有提供训练数据的TFRecord文件files = tf.train.match_filenames_once('./files/file-pattern-*')
filename_queue = tf.train.string_input_producer(files, shuffle=False)# 这里假设image中存储的是图像的原始数据
# label为该样例所对应的标签
# height、width和channels给出了图片的维度
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(serialized_example,features={'image': tf.FixedLenFeature([],tf.string),'label': tf.FixedLenFeature([],tf.int64),'height': tf.FixedLenFeature([],tf.int64),'width': tf.FixedLenFeature([],tf.int64),'channels': tf.FixedLenFeature([],tf.int64),}
)image, label = features['image'], features['label']
height, width = features['height'], features['width']channels = features['channels']# 从原始图像数据解析出像素矩阵,并根据图像尺寸还原图像
decoded_iamge = tf.decode_raw(iamge, tf.uint8)
decoded_iamge.set_shape([height, width, channels])image_size = 299
distorted_iamge = preprocess_for_train(decoded_iamge, image_size, image_size,None
)# 将处理后的图像和标签数据通过shuffle_batch整理成神经网络训练时需要的batchmin_after_dequeue = 10000
batch_size = 100
capacity = min_after_dequeue + 3 * batch_sizeimage_batch, label_batch = tf.train.shuffle_batch([distorted_iamge, label], batch_size=batch_size,capacity=capacity, min_after_dequeue=min_after_dequeue
)# 定义神经网络的结构以及优化过程
# image_batch可以作为输入提供给神经网络的输入层
# label_batch可则提供了输入batch中样例的正确答案learning_rate = 0.01
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
# 声明会话并运行神经网络的优化过程
with tf.Session() as sess:# 神经网络训练准备工作sess.run((tf.global_variables_initializer(),tf.local_variables_initializer()))coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess,coord=coord)# 神经网络训练过程TRAINING_ROUNDS = 5000for i in range(TRAINING_ROUNDS):sess.run(train_step)# 停止所有线程coord.request_stop()coord.join(threads)

上图展示了以上代码输入数据处理的整个流程。
输入数据处理的第一步为获取存储训练数据的文件列表。在图中,这个文件列表为{A,B,C}\{A,B,C\}{A,B,C}。
通过tf.train.string_input_producer函数,可以选择性地将文件列表中文件的顺序打乱,并加入输入队列。

tf.train.string_input_producer函数会生成并维护一个输入文件队列,不同线程中的文件读取函数可以共享这个输入文件队列。
在读取样例数据之后,需要将图像进行预处理。图像预处理的过程也会通过tf.train.shuffle_batch提供的机制并行地跑在多个线程中。
输入数据处理流程的最后通过tf.train.shuffle_batch函数将处理好的单个输入样例整理成batch提供给神经网络的输入层。

通过这种方式,可以有效地提高数据预处理的效率,避免数据预处理称为模型训练过程中的瓶颈。

数据集(Dataset)

除了队列以外,TensorFlow还提供了一套更高层的数据处理框架。
在新的框架中,每一个数据来源被抽象成一个数据集,开发者可以以数据集为基本对象,
方便地进行batching、shuffle等操作。

数据集的基本使用方法

在数据集框架中,每一个数据集代表一个数据来源:数据可能来自一个张量,一个TFRecord文件,一个文本文件,或者经过sharding的一系列文件等。由于训练数据通常无法全部写入内存中,从数据集中读取数据时需要使用一个迭代器按顺序进行读取。
数据集也是计算图上的一个节点。
下面先看一个简单的例子,从一个张量创建一个数据集,遍历这个数据集,并对每个输入输出y=x2y=x^2y=x2的值。

import tensorflow as tf# 从一个数组创建数据集
input_data = [1, 2, 3, 5, 8]
dataset = tf.data.Dataset.from_tensor_slices(input_data)# 定义一个迭代器用于遍历数据集
iterator = dataset.make_one_shot_iterator()
# get_next() 返回代表一个输入数据的张量,类似于队列的dequeue()
x = iterator.get_next()
y = x * x
with tf.Session() as sess:for i in range(len(input_data)):print(sess.run(y))

输出

1
4
9
25
64

从上面的例子可以看到,利用数据集读取数据有三个基本步骤。

  1. 定义数据集的构造方法
    本例使用了tf.data.Dataset.from_tensor_slices(),表明数据集是从一个张量中构建的
  2. 定义遍历器
    本例中使用了最简单的one_shot_iterator来遍历数据集
  3. 使用get_next()方法从遍历器中读取数据张量,作为计算图其他部分的输入。

在真实项目中,训练数据通常是保持在硬盘文件上的。比如在NLP任务中,训练数据通常是以每行一条数据的形式存在文本文件中,此时可以用TextLineDataset来更方便地读取数据:

import tensorflow as tf# 可以提供多个文件
input_files = ['input_file1','input_file2']
dataset = tf.data.TextLineDataset(input_files)# 定义迭代器用于遍历数据集
it = dataset.make_one_shot_iterator()
# 这里返回的是一个字符串类型的张量,代表文件中的一行
x = it.get_next()
with tf.Session() as sess:for i in range(3):print(sess.run(x))

在图像相关任务中,输入数据通常以TFRecord形式存储,这时可以用TFRecordDataset来读取数据。与文本文件不同,每一个TFRecord都有自己不同的feature格式,因此需要提供一个parse函数来解析所读取的TFRecord的数据格式。

import tensorflow as tfdef parse(record):features = tf.parse_single_example(record,features={'feat1': tf.FixedLenFeature([],tf.int64),'feat2': tf.FixedLenFeature([],tf.int64),})return features['feat1'], features['feat2']# 从TFRecord文件创建数据集
input_files = ['inut_file1','input_file2']
dataset = tf.data.TFRecordDataset(input_files)
# 通过map调用prase()对二进制数据进行解析
dataset = dataset.map(parse)it = dataset.make_one_shot_iterator()feat1, feat2 = it.get_next()
with tf.Session() as sess:for i in range(10):f1, f2 = sess.run([feat1, feat2])

以上例子都使用了最简单的one_shot_iterator来遍历数据集,如果需要用到placeholder来初始化数据集,那就需要用到initializable_iterator。以下代码给出了用initializable_iterator来动态初始化数据集的例子。

import tensorfow as tfdef parse(record):...# 从TFRecord文件创建数据集,路径是一个placeholder,稍后再提供具体路径
input_files = tf.placeholder(tf.string)
dataset = tf.data.TFRecordDataset(input_files)
dataset = dataset.map(parse)# 定义遍历dataset的initializable_iterator
it = dataset.make_initializable_iterator()
feat1, feat2 = it.get_next()with tf.Session() as sess:# 首先初始化iterator并给出input_files的值sess.run(it.initializer, feed_dict={input_files:['input_file1','input_file2']})# 遍历所有数据一个epoch,遍历结束会抛出OutOfRangeErrorwhile True:try:sess.run([feat1,feat2])except tf.errors.OutOfRangeError:break

数据集的高层操作

dataset = dataset.map(parse)

map是在数据集上进行操作的最常用方法之一。这里,map(parse)表示对数据集中的每一条数据调用参数中指定的parse方法。
对每条数据进行处理后,map将处理后的数据包装成一个新的数据集返回。
在队列框架下我们曾使用如下方法来对数据进行预处理:

distorted_image = preprocess_for_train(decoded_image, image_size, image_size, None)

而在数据集框架中,可以通过map来对每一条数据调用preprocess_for_train方法:

dataset = dataset.map(lambda x: preprocess_for_train(x, image_size, image_size, None))

lambda将原来有4个参数的函数转化为只有1个参数的函数。
preprocess_for_train函数的第一个参数decoded_image变成了lambda表达式中的x
preprocess_for_train函数中后3个参数都被换成了具体的数值。

map方法返回的是一个新的数据集,可以直接继续调用其他高层操作。

上面介绍了框架下的tf.train.batchtf.train.shuffle_batch方法。在数据集框中,
shufflebatch操作由两个方法独立实现:

dataset = dataset.shuffle(buffer_size)
dataset = dataset.batch(batch_size) #将数组组合成batch

其中shuffle方法的参数buffer_ size 等效于tf. train.shuffle_ batchmin_after_ dequeue 参数。shuffle 算法在内部使用一个缓冲区中保存buffer_size 条数据,每读入一条新数据时,从这个缓冲区中随机选择一条数据进行输出。
缓冲区的大小越大,随机的性能越好,但占用的内存也越多。

batch 方法的参数batch_size 代表要输出的每个batch 由多少条数据组成。如果数据集中包含多个张量,那么batch 操作将对每一个张量分开进行。举例而言,如果数据集中的每一个数据是imagelabel两个张量,其中image的维度是[300,300],label的维度是[],batch_size是128,那么经过batch操作后的数据集的每一个输出将包含两个维度分别是[128,300,300]和[128]的张量。

repeat 是另一个常用的操作方法。这个方法将数据集中的数据复制多份,其中每一份数据被称为一个epoch 。

dataset = dataset.repeat(N) # 将数据重复N份

需要指出的是,如果数据集在repeat 前己经进行了shuffle操作,输出的每个epoch 中随机shuffle 的结果并不会相同。
例如,如果输入数据是[1,2,3], shuffle 后输出的第一个epoch是[2,1,3],而第二个epoch 则有可能是[3,2,1] 。repeatmapshufflebatch 等操作一样,
都只是计算图中的一个计算节点。

repeat 只代表重复相同的处理过程,并不会记录前一epoch的处理结果。

以下例子将这些方法组合起来:

import tensorflow as tf# 列举输入文件。训练和测试使用不同的数据
train_files = tf.train.match_filenames_once("output.tfrecords")
test_files = tf.train.match_filenames_once("output_test.tfrecords")# 解析一个TFRecord的方法
def parser(record):features = tf.parse_single_example(record,features={'image':tf.FixedLenFeature([],tf.string),'label':tf.FixedLenFeature([],tf.int64),'height':tf.FixedLenFeature([],tf.int64),'width':tf.FixedLenFeature([],tf.int64),'channels':tf.FixedLenFeature([],tf.int64),})# 从原始图像数据解析出像素矩阵,并根据尺寸还原图像decoded_image = tf.decode_raw(features['image'],tf.uint8)decoded_image.set_shape([features['heigth'],features['width'],features['channels'])label = features['label']return decoded_image, label image_size = 299          # 定义神经网络输入层图片的大小。
batch_size = 100          # 定义组合数据batch的大小。
shuffle_buffer = 10000   # 定义随机打乱数据时buffer的大小。# 定义读取训练数据的数据集。
dataset = tf.data.TFRecordDataset(train_files)
dataset = dataset.map(parser)# 对数据依次间预处理、shuffle和batch操作
dataset = dataset.map(lambda image, label: (preprocess_for_train(image, image_size, image_size,None), label))
dataset = dataset.shuffle(shuffle_buffer).batch(batch_size)# 重复NUM_EPOCHS个epoch,指的是整个数据集重复的次数
NUM_EPOCHS = 10
dataset = dataset.repeat(NUM_EPOCHS)# 定义数据集迭代器。
# 虽然定义数据集时没有直接使用placeholder来提供文件地址,但是
# tf.train.match_filenames_once方法得到的结果和placeholder的机制类似
# 也需要初始化,所以这里使用的是make_initializable_iterator
iterator = dataset.make_initializable_iterator()
image_batch, label_batch = iterator.get_next()# 定义神经网络的结构以及优化过程
learning_rate = 0.01
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)# 定义测试用的Dataset
# 测试数据的Dataset不需要经过随机翻转等预处理操作
# 也不要打乱顺序和重复多个epoch
# 这里使用和训练数据相同的parse进行解析
test_dataset = tf.data.TFRecordDataset(test_files)
test_dataset = test_dataset.map(parser).map(lambda image, label: (tf.image.resize_images(image, [image_size,image_size]), label))
test_dataset = test_dataset.batch(batch_size)# 定义测试数据上的迭代器。
test_iterator = test_dataset.make_initializable_iterator()
test_image_batch, test_label_batch = test_iterator.get_next()# 定义测试数据上的预测结果。
test_logit = inference(test_image_batch)
predictions = tf.argmax(test_logit, axis=-1, output_type=tf.int32)# 声明会话并运行神经网络的优化过程。
with tf.Session() as sess:  # 初始化变量。sess.run((tf.global_variables_initializer(),tf.local_variables_initializer()))# 初始化训练数据的迭代器。sess.run(iterator.initializer)# 循环进行训练,直到数据集完成输入、抛出OutOfRangeError错误。while True:try:sess.run(train_step)except tf.errors.OutOfRangeError:break# 初始化测试数据的迭代器。sess.run(test_iterator.initializer)# 获取预测结果。test_results = []test_labels = []while True:try:pred, label = sess.run([predictions, test_label_batch])test_results.extend(pred)test_labels.extend(label)except tf.errors.OutOfRangeError:break# 计算准确率
correct = [float(y == y_) for (y, y_) in zip (test_results, test_labels)]
accuracy = sum(correct) / len(correct)
print("Test accuracy is:", accuracy)

TensorFlow学习笔记——图像数据处理相关推荐

  1. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  2. Win10:tensorflow学习笔记(4)

    前言 学以致用,以学促用.输出检验,完整闭环. 经过前段时间的努力,已经在电脑上搭好了深度学习系统,接下来就要开始跑程序了,将AI落地了. 安装win10下tensforlow 可以参照之前的例子:w ...

  3. tensorflow学习笔记1

    tensorflow学习笔记1 本文主要记录我在慕课上观看北大曹建老师的<人工智能实践:Tensorflow笔记>,链接:https://www.icourse163.org/course ...

  4. tensorflow学习笔记(八):LSTM手写体(MNIST)识别

    文章目录 一.LSTM简介 二.主要函数 三.LSTM手写体(MNIST)识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.LSTM简介 LSTM是一种特殊的RNN,很好的解决了RNN中 ...

  5. [TensorFlow 学习笔记-06]激活函数(Activation Function)

    [版权说明] TensorFlow 学习笔记参考:  李嘉璇 著 TensorFlow技术解析与实战 黄文坚 唐源 著 TensorFlow实战郑泽宇  顾思宇 著 TensorFlow实战Googl ...

  6. [TensorFlow 学习笔记-04]卷积函数之tf.nn.conv2d

    [版权说明] TensorFlow 学习笔记参考: 李嘉璇 著 TensorFlow技术解析与实战 黄文坚 唐源 著 TensorFlow实战郑泽宇  顾思宇 著 TensorFlow实战Google ...

  7. tensorflow学习笔记(十):GAN生成手写体数字(MNIST)

    文章目录 一.GAN原理 二.项目实战 2.1 项目背景 2.2 网络描述 2.3 项目实战 一.GAN原理 生成对抗网络简称GAN,是由两个网络组成的,一个生成器网络和一个判别器网络.这两个网络可以 ...

  8. tensorflow学习笔记(七):CNN手写体(MNIST)识别

    文章目录 一.CNN简介 二.主要函数 三.CNN的手写体识别 1.MNIST数据集简介 2.网络描述 3.项目实战 一.CNN简介 一般的卷积神经网络由以下几个层组成:卷积层,池化层,非线性激活函数 ...

  9. TensorFlow学习笔记:Retrain Inception_v3(一)

    转:http://www.jianshu.com/p/613c3b08faea 0. 概要 最新的物体识别模型可能含有数百万个参数,将耗费几周的时间去完全训练.因此我们采用迁移学习的方法,在已经训练好 ...

  10. Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题

    Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 参考文章: (1)Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 (2)http ...

最新文章

  1. 吴恩达机器学习入门 2018 高清视频公开,还有习题解答和课程拓展,网友:找不到理由不学!...
  2. 关于MySQLdb连接数据的使用(插入数据——使用前端页面的方式进行可视化)
  3. python ndarray
  4. css3的3d起步——分享
  5. android 后台Activity移到前台
  6. Android使用VideoView播放本地视频及网络视频Demo
  7. 采用TCPListener/TCPClient实现图片传输
  8. 华清远见智能家居ppt_怀揣梦想,一路前行——西安华清与西安培华实训集
  9. 快速构建网站或移动端页面:关于Bootstrap的学习笔记
  10. Java基础---集合的概述---list、set、map、泛型、collections
  11. 【英语学习】【医学】【解剖学】Atlas of Human Anatomy (6e) 的目录
  12. 从零开始学Go之并发(四):互斥
  13. vs2017项目配置
  14. Java相关资源下载路径
  15. jquery 插件 分析
  16. SQL:postgresql中,将geometry转换为geojson数据
  17. 六、容器(高琪java300集+java从入门到精通笔记)
  18. 2022小米运维开发笔试1
  19. vue开发PC端响应式项目
  20. VB中Byval与Byref的区别。

热门文章

  1. 2014年第五届蓝桥杯国赛试题(JavaA组)
  2. 关于Python3的namedtuple问题
  3. 2. Linear Model
  4. [中等]寻找缺失的数
  5. debian安装中文字体
  6. Expression Blend学习四控件-按钮
  7. Cookie编码解码
  8. 从北京站到天通苑怎么走,
  9. 【LOJ】#2532. 「CQOI2018」社交网络
  10. 牛客网 牛客小白月赛2 H.武-最短路(Dijkstra)