8、源码分析

1、入口函数

要训练tensorflow官方的cifar10模型,只要执行python cifar10_train.py即可,所以入口函数应该是在cifar10_train.py里。找到

def main(argv=None):  # pylint: disable=unused-argumentcifar10.maybe_download_and_extract()if tf.gfile.Exists(FLAGS.train_dir):tf.gfile.DeleteRecursively(FLAGS.train_dir)tf.gfile.MakeDirs(FLAGS.train_dir)train()if __name__ == '__main__':tf.app.run()

前面是下载和解压cifar10数据集的功能,不是重点,接着看train()函数

2、train()函数

def train():"""Train CIFAR-10 for a number of steps."""with tf.Graph().as_default():global_step = tf.train.get_or_create_global_step()with tf.device('/cpu:0'):images, labels = cifar10.distorted_inputs()
...

函数cifar10.distorted_inputs()是获取CIFAR10数据集的图片数据和对应的标签的,我们接着去看cifar10.py里的distorted_inputs()函数干了什么

3、cifar10.distorted_inputs()函数

def distorted_inputs():if not FLAGS.data_dir:raise ValueError('Please supply a data_dir')data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')images, labels = cifar10_input.distorted_inputs(data_dir=data_dir,batch_size=FLAGS.batch_size)if FLAGS.use_fp16:images = tf.cast(images, tf.float16)labels = tf.cast(labels, tf.float16)return images, labels

可以看到其将CIFAR10数据集的路径和batch大小传到cifar10_input.distorted_inputs函数,cifar10_input.distorted_inputs函数再返回图片和标签的数据,我们接着看cifar10_input.distorted_inputs函数

4、cifar10_input.distorted_inputs(data_dir, batch_size)函数

def distorted_inputs(data_dir, batch_size):filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)for i in xrange(1, 6)]for f in filenames:if not tf.gfile.Exists(f):raise ValueError('Failed to find file: ' + f)# Create a queue that produces the filenames to read.filename_queue = tf.train.string_input_producer(filenames)

因为数据集图片和标签的数据实际上放在data_batch_1.bin~data_batch_5.bin里,所以先将其放到数组filenames里,然后传给tf.train.string_input_producer函数,而tf.train.string_input_producer函数就是创建文件名队列的。接着往下看,

# Read examples from files in the filename queue.read_input = read_cifar10(filename_queue)reshaped_image = tf.cast(read_input.uint8image, tf.float32)

read_cifar10函数实际上跟CIFAR10图像识别(上)那节写的get_record函数的作用类似,该函数返回一个类,而标签数据存在read_input.label中,图片数据存在read_input.uint8image中,再经过tf.cast将数据转成float32型,接着看,

height = IMAGE_SIZE #24
width = IMAGE_SIZE #24distorted_image = tf.random_crop(reshaped_image, [height, width, 3])
distorted_image = tf.image.random_flip_left_right(distorted_image)
distorted_image = tf.image.random_brightness(distorted_image,max_delta=63)
distorted_image = tf.image.random_contrast(distorted_image,lower=0.2, upper=1.8)
float_image = tf.image.per_image_standardization(distorted_image)
# Set the shapes of tensors.
float_image.set_shape([height, width, 3])
read_input.label.set_shape([1])
min_fraction_of_examples_in_queue = 0.4
min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *min_fraction_of_examples_in_queue)return _generate_image_and_label_batch(float_image, read_input.label,min_queue_examples, batch_size,shuffle=True)

上面操作都是对图片数据进行数据增强操作,将源图片随机切割成24*24图片,随机左右翻转等等操作,再转成张量[24,24,3]的形式。接着看看_generate_image_and_label_batch函数做了什么,

def _generate_image_and_label_batch(image, label, min_queue_examples,batch_size, shuffle):num_preprocess_threads = 16if shuffle:images, label_batch = tf.train.shuffle_batch([image, label],batch_size=batch_size,num_threads=num_preprocess_threads,capacity=min_queue_examples + 3 * batch_size,min_after_dequeue=min_queue_examples)else:images, label_batch = tf.train.batch([image, label],batch_size=batch_size,num_threads=num_preprocess_threads,capacity=min_queue_examples + 3 * batch_size)# Display the training images in the visualizer.tf.summary.image('images', images)return images, tf.reshape(label_batch, [batch_size])

主要看tf.train.shuffle_batch函数,该函数主要输出一个打乱顺序排列的样本batch,[image, label]表示样本和样本标签,batch_size是样本batch长度,capacity是队列的容量,num_threads表示开启多少个线程,min_after_dequeue表示出队后,队列中最少要有min_after_dequeue个数据。所以可知,经过这些运算以后,得到的图片数据为一个四维张量[batch_size, height, width, 3],标签为一维张量[batch_size]。回到train()函数,继续往下看,

# Build a Graph that computes the logits predictions from the
# inference model.
logits = cifar10.inference(images)

这里cifar10.inference函数就是我们卷积模型的重点了,进去看看,

def inference(images):# conv1with tf.variable_scope('conv1') as scope:kernel = _variable_with_weight_decay('weights',shape=[5, 5, 3, 64],stddev=5e-2,wd=None)conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))pre_activation = tf.nn.bias_add(conv, biases)conv1 = tf.nn.relu(pre_activation, name=scope.name)_activation_summary(conv1)# pool1pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],padding='SAME', name='pool1')# norm1norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,name='norm1')# conv2with tf.variable_scope('conv2') as scope:kernel = _variable_with_weight_decay('weights',shape=[5, 5, 64, 64],stddev=5e-2,wd=None)conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))pre_activation = tf.nn.bias_add(conv, biases)conv2 = tf.nn.relu(pre_activation, name=scope.name)_activation_summary(conv2)# norm2norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,name='norm2')# pool2pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],strides=[1, 2, 2, 1], padding='SAME', name='pool2')# local3with tf.variable_scope('local3') as scope:# Move everything into depth so we can perform a single matrix multiply.reshape = tf.reshape(pool2, [images.get_shape().as_list()[0], -1])dim = reshape.get_shape()[1].valueweights = _variable_with_weight_decay('weights', shape=[dim, 384],stddev=0.04, wd=0.004)biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)_activation_summary(local3)# local4with tf.variable_scope('local4') as scope:weights = _variable_with_weight_decay('weights', shape=[384, 192],stddev=0.04, wd=0.004)biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)_activation_summary(local4)# linear layer(WX + b),with tf.variable_scope('softmax_linear') as scope:weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],stddev=1/192.0, wd=None)biases = _variable_on_cpu('biases', [NUM_CLASSES],tf.constant_initializer(0.0))softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)_activation_summary(softmax_linear)return softmax_linear

可以看到,这个模型跟我们讲的两层卷积神经网络识别MNIST模型是类似的,经过第一层卷积层和池化层,第二层卷积层和池化层,再经过三层全连接层。接着往下看,

# Calculate loss.
loss = cifar10.loss(logits, labels)

这里就是计算损失函数了,进去看看,

def loss(logits, labels):labels = tf.cast(labels, tf.int64)cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits, name='cross_entropy_per_example')cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')tf.add_to_collection('losses', cross_entropy_mean)# The total loss is defined as the cross entropy loss plus all of the weight# decay terms (L2 loss).return tf.add_n(tf.get_collection('losses'), name='total_loss')

其中,tf.nn.sparse_softmax_cross_entropy_with_logits函数是计算logits和labels的softmax交叉熵,再用tf.reduce_mean求均值,再用tf.add_n求和。回到cifar10_train继续往下看,

# Build a Graph that trains the model with one batch of examples and
# updates the model parameters.
train_op = cifar10.train(loss, global_step)

这个就是训练的函数,进去看看,

def train(total_loss, global_step):# Variables that affect learning rate.num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_sizedecay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)# Decay the learning rate exponentially based on the number of steps.lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,global_step,decay_steps,LEARNING_RATE_DECAY_FACTOR,staircase=True)tf.summary.scalar('learning_rate', lr)# Generate moving averages of all losses and associated summaries.loss_averages_op = _add_loss_summaries(total_loss)# Compute gradients.with tf.control_dependencies([loss_averages_op]):opt = tf.train.GradientDescentOptimizer(lr)grads = opt.compute_gradients(total_loss)# Apply gradients.apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)# Add histograms for trainable variables.for var in tf.trainable_variables():tf.summary.histogram(var.op.name, var)# Add histograms for gradients.for grad, var in grads:if grad is not None:tf.summary.histogram(var.op.name + '/gradients', grad)# Track the moving averages of all trainable variables.variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)with tf.control_dependencies([apply_gradient_op]):variables_averages_op = variable_averages.apply(tf.trainable_variables())return variables_averages_op

其中,tf.train.exponential_decay函数就是上一节提到的学习率的指数衰减法,设置学习率后,再使用梯度下降法tf.train.GradientDescentOptimizer优化损失,而grads = opt.compute_gradients(total_loss)和opt.apply_gradients(grads, global_step=global_step)函数,其实和前面用到的tf.train.Optimizer.minimize一样的,只不过minimize合并了这两个函数。tf.train.ExponentialMovingAverage函数是使用滑动平均法更新参数,在回到cifar10_train,继续往下看,

class _LoggerHook(tf.train.SessionRunHook):"""Logs loss and runtime."""def begin(self):self._step = -1self._start_time = time.time()def before_run(self, run_context):self._step += 1return tf.train.SessionRunArgs(loss)  # Asks for loss value.def after_run(self, run_context, run_values):if self._step % FLAGS.log_frequency == 0:current_time = time.time()duration = current_time - self._start_timeself._start_time = current_timeloss_value = run_values.resultsexamples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / durationsec_per_batch = float(duration / FLAGS.log_frequency)format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ''sec/batch)')print (format_str % (datetime.now(), self._step, loss_value,examples_per_sec, sec_per_batch))with tf.train.MonitoredTrainingSession(checkpoint_dir=FLAGS.train_dir,hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),tf.train.NanTensorHook(loss),_LoggerHook()],config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) as mon_sess:while not mon_sess.should_stop():mon_sess.run(train_op)

上面就是真的开始计算了,这里不用之前的tf.Session()会话来计算,而是用tf.train.MonitoredTrainingSession,好处是,这个会话能自动保存和载入模型的文件,默认每10分钟保存一次,就不需要我们自己写保存代码了。checkpoint_dir传入保存的路径,tf.train.StopAtStepHook函数指定训练多少步后就停止,tf.train.NanTensorHook用于监控loss,如果loss是Nan,则停止训练。_LoggerHook则用于打印时间、步数、损失值等,打印格式如下:

2018-05-22 15:57:15.755749: step 3370, loss = 1.22 (475.8 examples/sec; 0.269 sec/batch)

2018-05-22 15:57:18.434435: step 3380, loss = 1.12 (477.8 examples/sec; 0.268 sec/batch)

2018-05-22 15:57:21.054679: step 3390, loss = 1.30 (488.5 examples/sec; 0.262 sec/batch)

2018-05-22 15:57:23.721501: step 3400, loss = 1.22 (480.0 examples/sec; 0.267 sec/batch)

2018-05-22 15:57:26.337015: step 3410, loss = 1.21 (489.4 examples/sec; 0.262 sec/batch)


总结:

这个模型相比前面两个训练MNIST的模型来说稍微复杂一点,但是其过程还是差不多的:

1、定义神经网络的结构和前向传播输出结果

2、定义损失函数以及选择反向传播优化的算法

3、生成会话并且在训练数据上反复运行反向传播优化算法

TensorFlow精进之路(六):CIFAR-10图像是被(下)相关推荐

  1. tensorflow精进之路(十九)——python3网络爬虫(下)

    1.概述 这一节,我们将在百度图片中爬取需要训练的图片数据:猪.蛇.狗.大象.老虎. 2.打开待爬取网页 打开百度图片首页: http://image.baidu.com/ 在搜索框中输入" ...

  2. TensorFlow精进之路(三):两层卷积神经网络模型将MNIST未识别对的图片筛选出来

    1.概述 自从开了专栏<TensorFlow精进之路>关于对TensorFlow的整理思路更加清晰.上两篇讲到Softmax回归模型和两层卷积神经网络模型训练MNIST,虽然使用神经网络能 ...

  3. TensorFlow精进之路(十二):随时间反向传播BPTT

    1.概述 上一节介绍了TensorFlow精进之路(十一):反向传播BP,这一节就简单介绍一下BPTT. 2.网络结构 RNN正向传播可以用上图表示,这里忽略偏置. 上图中, x(1:T)表示输入序列 ...

  4. TensorFlow精进之路(九):TensorFlow编程基础

    1.概述 卷积部分的知识点在博客:TensorFlow精进之路(三):两层卷积神经网络模型将MNIST未识别对的图片筛选出来已经写过,所以不再赘述.这一节简单聊聊tensorflow的编程基础. 2. ...

  5. tensorflow精进之路(二十)——使用slim模型库训练自己的数据

    1.概述 上一节,我们使用python3爬取了百度图片的一些图片数据,这一节,我们就使用这些爬取下来的图片,训练我们自己的模型,用来识别猪.蛇.狗.大象.老虎这五种动物.在这里吐嘲一下百度图片搜索结果 ...

  6. TensorFlow精进之路(四):CIFAR-10图像识别(上)

    1.CIFAR-10数据集简介 CIFAR-10数据集包含10个类别的RGB彩色图片.图片尺寸为32×32,这十个类别包括:飞机.汽车.鸟.猫.鹿.狗.蛙.马.船.卡车.一共有50000张训练图片和1 ...

  7. TensorFlow精进之路(十四):RNN训练MNIST数据集

    1.概述 前面介绍了RNN,这一节就用tensorflow的RNN来训练MNIST数据集,看看准确率如何. 2.代码实现 2.1.导入数据集 # encoding:utf-8 import tenso ...

  8. tensorflow精进之路(二十六)——人脸识别(上)(MTCNN原理)

    1.概述 换了个固态硬盘,本想装最新的系统mint 19,谁知道却是个坑,NVIDIA驱动和CUDA工具老是装不上去,各种问题,折腾了几天,还是用回了原来的系统.不过,这次软件改了一下,使用了pyth ...

  9. TensorFlow精进之路(七):关于两层卷积神经网络对CIFAR-10图像的识别

    1.概述 在前面已经对官方的CIFAR10图像识别模块进行分析,但如果只做到这一步感觉还是不够,没能做到举一反三以及对之前学的知识的巩固,所以这一节,我打算结合之前学的双层卷积神经网络自己写一个dem ...

最新文章

  1. [UWP]实现一个轻量级的应用内消息通知控件
  2. 30岁找不到工作很绝望_计算机为绝望的新编码员工作方式的快速指南
  3. Sublime Text保存文件时自动去掉行末空格
  4. 基2频率抽取实现FFT的Verilog程序
  5. mac云显卡服务器_重磅!NVIDIA GeForce NOW登陆Mac:云显卡玩吃鸡逆天
  6. 利用mvc 模型绑定验证方法验证普通类对象数据是否合法
  7. Spring cloud整合zookeeper
  8. 解决安装pytorch慢的方法(pip安装)
  9. mysql数据库之忘记root密码
  10. Python之访问set
  11. C语言实现2048游戏(Windows版)
  12. 数据结构考研:线性表,顺序表,有序表,链表,数组的概念的区别与联系(软件工程/计算机/王道论坛)
  13. NVIDIA Jetson官网资料整理
  14. Excel表格中数据比对和查找的几种技巧
  15. 免费免安装!3s 获取云数据库,MySQL,Mongo、Redis 全都有!
  16. php排行榜系统,cms排行_PHP CMS系统排行榜
  17. Android类似微信详细地址选择(高德地图)
  18. 如何查看ORACLE各个表空间的使用情况
  19. 对数坐标归一化_归一化方法 Normalization Method
  20. SeleniumLibrary4.5.0 关键字详解(五)

热门文章

  1. 《软件工程(第4版?修订版)》—第2章2.9节本章对研究人员的意义
  2. 《Xcode实战开发》——1.1节下载
  3. 心得复述知识体系:《强化学习》中的蒙特卡洛方法 Monte Carlo Methods in Reinforcement Learning
  4. linux sql server调优,SQL SERVER性能优化(转)
  5. EDA实验课课程笔记(一)——linux操作系统及linux下的基本指令
  6. Android应用及应用管理
  7. C#自动切换Windows窗口程序,如何才能调出主窗口?
  8. 咬肌边上有个滑动疙瘩_猫逆子一个:摔杯子咬箱子,时常给我甩脸子!
  9. ROS Nodelet使用
  10. Jquery获取iframe中的元素