• CIFAR
  • Code Organization
  • Code Analysis
  • cifar10_train.py
  • cifar10.py
  • cifar10_eval.py
  • Reference

Updated for TensorFlow 1.2.1; the errors raised by the official version have been fixed.

CIFAR

For more information, see the CIFAR-10 page and the technical report by Alex Krizhevsky. The example demonstrates several key techniques:

  • The core mathematical components: convolution, rectified linear (ReLU) activation, max pooling, and local response normalization;
  • Visualization of network behavior during training, including the input images, losses, the distributions of activations, and gradients;
  • Functions that compute moving averages of the learned parameters, and the use of these averages during evaluation to boost prediction performance;
  • A mechanism that decays the learning rate over time (see the sketch after this list);
  • Prefetch queues for the input data, which isolate the model from disk latency and expensive image preprocessing.
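
For the learning-rate schedule in particular, cifar10.py uses staircase exponential decay, which reduces to the following formula (a minimal numeric sketch; the constants are the defaults defined in the official cifar10.py):

INITIAL_LEARNING_RATE = 0.1       # Initial learning rate (official default).
LEARNING_RATE_DECAY_FACTOR = 0.1  # Multiplicative decay factor (official default).

def decayed_learning_rate(global_step, decay_steps):
    # Equivalent to tf.train.exponential_decay(..., staircase=True):
    # integer division makes the rate drop in discrete jumps every
    # decay_steps steps (decay_steps is derived from NUM_EPOCHS_PER_DECAY)
    # instead of decaying continuously.
    return INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR ** (global_step // decay_steps)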

Code Organization

File                       | Purpose
--------------------------------------------------------------------------
cifar10_input.py           | Reads the local CIFAR-10 binary file format.
cifar10.py                 | Builds the CIFAR-10 model.
cifar10_train.py           | Trains the CIFAR-10 model on a CPU or GPU.
cifar10_multi_gpu_train.py | Trains the CIFAR-10 model on multiple GPUs.
cifar10_eval.py            | Evaluates the predictive performance of the CIFAR-10 model.

Code Analysis

TensorFlow 1.0 introduced API changes that break backward compatibility with the older example code; for the specific fixes, see: TensorFlow CIFAR-10训练例子报错解决.
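
Most of the fixes amount to API renames from the TF 1.0 migration. A sketch of the substitutions most likely to appear in a traceback from this example (check your own error message for the exact symbol):

# Pre-1.0 name                     ->  TF 1.x name
# tf.all_variables()               ->  tf.global_variables()
# tf.initialize_all_variables()    ->  tf.global_variables_initializer()
# tf.train.SummaryWriter(...)      ->  tf.summary.FileWriter(...)
# tf.merge_all_summaries()         ->  tf.summary.merge_all()
# tf.scalar_summary(tag, value)    ->  tf.summary.scalar(tag, value)
# tf.histogram_summary(tag, value) ->  tf.summary.histogram(tag, value)
saver = tf.train.Saver(tf.global_variables())
init = tf.global_variables_initializer()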

The official training results are shown below. Accuracy is around 86% by 100k steps (and already by 60k steps on a K20m), so there is no real need to train all the way to 100k.

System        | Step Time (sec/batch)  |     Accuracy
------------------------------------------------------------------
1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)

cifar10_train.py

The core of the file is the train() function:

def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-10.
    # Preprocesses the input images: brightness, contrast, random flips, etc.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    # Build an initialization operation to run below.
    init = tf.initialize_all_variables()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir,
                                           graph_def=sess.graph_def)

    # Iterate for the configured number of steps.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      # Every 10 steps, print the step, loss, and timing statistics.
      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, sec_per_batch))

      # Every 100 steps, record the state of the network in the summaries.
      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically (every 1000 steps).
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)


def main(argv=None):  # pylint: disable=unused-argument
  # Download the data set if it is not already present.
  cifar10.maybe_download_and_extract()
  # Delete any previous training logs.
  if gfile.Exists(FLAGS.train_dir):
    gfile.DeleteRecursively(FLAGS.train_dir)
  gfile.MakeDirs(FLAGS.train_dir)
  # Train.
  train()


if __name__ == '__main__':
  # Parse the flags and run main().
  tf.app.run()
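
train() reads its settings from module-level flags. In the official cifar10_train.py they are declared roughly as follows (the defaults shown are from that file; lower max_steps if 100k steps is enough for your run):

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoints.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")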

The distorted_inputs() function applies random operations to the training set, including random crops and horizontal flips, so that the training data covers the variations that appear in the evaluation set. It is defined in cifar10_input.py:

def distorted_inputs(data_dir, batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.

  Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # The training set gets extra random distortions (crops, flips, etc.)
  # so that it covers the variations seen in the evaluation set.
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
  for f in filenames:
    if not gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for training the network. Note the many random
  # distortions applied to the image.
  # Step 1: randomly crop a [height, width] section of the image.
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

  # Step 2: randomly flip the image horizontally, with 50% probability.
  distorted_image = tf.image.random_flip_left_right(distorted_image)

  # Because these operations are not commutative, consider randomizing
  # the order of their operation.
  # Step 3: randomly distort the brightness and contrast of the image.
  distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
  distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)

  # Subtract off the mean and divide by the variance of the pixels.
  float_image = tf.image.per_image_standardization(distorted_image)

  # Ensure that the random shuffling has good mixing properties.
  # Training starts only once the queue holds at least 40% of the examples.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                           min_fraction_of_examples_in_queue)
  print('Filling queue with %d CIFAR images before starting to train. '
        'This will take a few minutes.' % min_queue_examples)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size)
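
The batching helper _generate_image_and_label_batch() that this function ends with is not shown above. In the official cifar10_input.py it boils down to a tf.train.shuffle_batch call over a queue filled by multiple preprocessing threads (sketched here; 16 threads is the official default):

def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size):
  """Construct a queued batch of images and labels."""
  # 16 threads run the preprocessing pipeline and fill a shuffling queue;
  # training dequeues batches from it, which decouples disk I/O and image
  # preprocessing from the model.
  num_preprocess_threads = 16
  images, label_batch = tf.train.shuffle_batch(
      [image, label],
      batch_size=batch_size,
      num_threads=num_preprocess_threads,
      capacity=min_queue_examples + 3 * batch_size,
      min_after_dequeue=min_queue_examples)

  # Display the training images in the visualizer.
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])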

The evaluation set does not need these distortions, so the inputs() function is used instead. It closely mirrors the function above, so only the code is given; it is also in cifar10_input.py:

def inputs(eval_data, data_dir, batch_size):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  if not eval_data:
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
  else:
    filenames = [os.path.join(data_dir, 'test_batch.bin')]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

  for f in filenames:
    if not gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for evaluation.
  # Crop the central [height, width] of the image.
  resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                         width, height)

  # Subtract off the mean and divide by the variance of the pixels.
  float_image = tf.image.per_image_standardization(resized_image)

  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(num_examples_per_epoch *
                           min_fraction_of_examples_in_queue)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size)
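
Both input functions rely on read_cifar10() to parse a single example from the binary files. Its core, roughly as in the official cifar10_input.py, reads one fixed-length record and splits it into a label byte and the image bytes:

def read_cifar10(filename_queue):
  """Reads and parses one example from the CIFAR-10 binary data files."""

  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

  # Each record is 1 label byte followed by a 32x32x3 image, with no
  # header or footer.
  label_bytes = 1
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  record_bytes = label_bytes + image_bytes

  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Convert from a string to a vector of uint8 that is record_bytes long.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # The first byte is the label.
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The rest is the image, stored as [depth, height, width]; transpose it
  # to the [height, width, depth] layout the rest of the code expects.
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result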

cifar10.py

This file contains the code that builds and trains the CNN model. The inference() function defines the CNN architecture; it was explained in detail in the MNIST walkthrough, see tensorFlow搭建CNN-mnist上手:

def inference(images):
  """Build the CIFAR-10 model.

  Args:
    images: Images returned from distorted_inputs() or inputs().

  Returns:
    Logits.
  """
  # We instantiate all variables using tf.get_variable() instead of
  # tf.Variable() in order to share variables across multiple GPU training runs.
  # If we only ran this model on a single GPU, we could simplify this function
  # by replacing all instances of tf.get_variable() with tf.Variable().
  #
  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64],
                                         stddev=1e-4, wd=0.0)
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope.name)
    _activation_summary(conv1)

  # pool1
  pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                         padding='SAME', name='pool1')
  # norm1
  norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm1')

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64],
                                         stddev=1e-4, wd=0.0)
    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
    bias = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(bias, name=scope.name)
    _activation_summary(conv2)

  # norm2
  norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm2')
  # pool2
  pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1], padding='SAME', name='pool2')

  # local3
  with tf.variable_scope('local3') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    dim = 1
    for d in pool2.get_shape()[1:].as_list():
      dim *= d
    reshape = tf.reshape(pool2, [FLAGS.batch_size, dim])
    weights = _variable_with_weight_decay('weights', shape=[dim, 384],
                                          stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
    local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
    _activation_summary(local3)

  # local4
  with tf.variable_scope('local4') as scope:
    weights = _variable_with_weight_decay('weights', shape=[384, 192],
                                          stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
    _activation_summary(local4)

  # softmax, i.e. softmax(WX + b)
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
                                          stddev=1 / 192.0, wd=0.0)
    biases = _variable_on_cpu('biases', [NUM_CLASSES],
                              tf.constant_initializer(0.0))
    softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)
    _activation_summary(softmax_linear)

  return softmax_linear
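
The helpers _variable_on_cpu() and _variable_with_weight_decay() used above create variables pinned to CPU memory, the latter optionally adding an L2 weight-decay term to the 'losses' collection. Roughly as in the official cifar10.py:

def _variable_on_cpu(name, shape, initializer):
  """Create a variable stored on CPU memory (shared across GPUs)."""
  with tf.device('/cpu:0'):
    var = tf.get_variable(name, shape, initializer=initializer)
  return var


def _variable_with_weight_decay(name, shape, stddev, wd):
  """Create a truncated-normal-initialized variable with optional L2 decay."""
  var = _variable_on_cpu(name, shape,
                         tf.truncated_normal_initializer(stddev=stddev))
  if wd:
    # The decay term lands in the 'losses' collection and is summed
    # with the cross entropy in loss().
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var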

Training the CNN is driven by the train() function:

def train(total_loss, global_step):
  """Train CIFAR-10 model.

  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.

  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  # Note the control dependency: the gradient-descent step only runs
  # after loss_averages_op has completed.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Add histograms for gradients.
  for grad, var in grads:
    if grad is not None:
      tf.summary.histogram(var.op.name + '/gradients', grad)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  return train_op
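
The loss() function called from cifar10_train.py, whose moving averages _add_loss_summaries() tracks, sums the cross entropy with the weight-decay terms collected above. Roughly per the official cifar10.py (the labels=/logits= keyword form is what TF 1.x expects):

def loss(logits, labels):
  """Add L2 loss to all trainable variables and return the total loss."""
  labels = tf.cast(labels, tf.int64)

  # Average cross-entropy loss across the batch.
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=labels, logits=logits, name='cross_entropy_per_example')
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
  tf.add_to_collection('losses', cross_entropy_mean)

  # Total loss = cross entropy + all of the weight decay terms.
  return tf.add_n(tf.get_collection('losses'), name='total_loss')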

cifar10_eval.py

Finally, the evaluation function evaluate():

def evaluate():
  """Eval CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    # Get images and labels for CIFAR-10.
    eval_data = FLAGS.eval_data == 'test'
    images, labels = cifar10.inputs(eval_data=eval_data)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate predictions.
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    graph_def = tf.get_default_graph().as_graph_def()
    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                           graph_def=graph_def)

    while True:
      eval_once(saver, summary_writer, top_k_op, summary_op)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)
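
eval_once(), called in the loop above, restores the latest checkpoint and measures precision@1 over the evaluation set. A condensed sketch along the lines of the official cifar10_eval.py (summary writing and exception handling trimmed):

from __future__ import division  # true division, as in the official file

import math
import numpy as np

def eval_once(saver, summary_writer, top_k_op, summary_op):
  """Run one evaluation pass against the latest checkpoint."""
  with tf.Session() as sess:
    ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
    if not (ckpt and ckpt.model_checkpoint_path):
      print('No checkpoint file found')
      return
    # Restores the moving-average shadow variables in place of the raw
    # weights (see the Saver built in evaluate()).
    saver.restore(sess, ckpt.model_checkpoint_path)

    # Start the queue runners that feed the input pipeline.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
    true_count = 0  # Counts correct top-1 predictions.
    step = 0
    while step < num_iter and not coord.should_stop():
      predictions = sess.run([top_k_op])
      true_count += np.sum(predictions)
      step += 1

    # precision @ 1: fraction of examples whose top prediction is correct.
    precision = true_count / (num_iter * FLAGS.batch_size)
    print('precision @ 1 = %.3f' % precision)

    coord.request_stop()
    coord.join(threads, stop_grace_period_secs=10)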

That is the main flow; the full code is available at: CIFAR_TensorFlow

Reference

Convolutional Neural Networks (卷积神经网络)

Original article: http://blog.csdn.net/shine19930820/article/details/76608648
