TensorFlow CIFAR-10 CNN Code Analysis
- CIFAR
- Code Organization
- Code Analysis
- cifar10_train.py
- cifar10.py
- cifar10_eval.py
- Reference
This post is based on TensorFlow 1.2.1; I fixed the errors the official version raises under it.
CIFAR
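For more information, see the CIFAR-10 page and the technical report by Alex Krizhevsky. The example code demonstrates:
- Core mathematical building blocks, such as convolution, rectified linear activations, max pooling, and local response normalization;
- Visualization of network behavior during training, including the input images, the losses, and the distributions of activations and gradients;
- Computing moving averages of the learned parameters and using those averages during evaluation to improve predictive performance;
- A learning rate that decays over time;
- Prefetching queues for the input data, which keep disk latency and expensive image preprocessing separate from the model.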
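Local response normalization is probably the least familiar item in that list. tf.nn.lrn divides each activation by a power of the sum of squares of its neighbors along the channel (depth) axis: sqr_sum[d] = sum(input[d - depth_radius : d + depth_radius + 1] ** 2), then output = input / (bias + alpha * sqr_sum) ** beta. The following is a minimal pure-Python sketch of that formula for the channel vector of a single pixel. The function name `lrn_1d` is hypothetical. The default arguments are the values inference() passes to tf.nn.lrn later in this post.

```python
def lrn_1d(channels, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75):
    """Normalize one pixel's channel vector following the tf.nn.lrn formula.

    Hypothetical pure-Python helper for illustration; the defaults are the
    arguments inference() passes to tf.nn.lrn.
    """
    n = len(channels)
    out = []
    for d, x in enumerate(channels):
        # Neighboring channels d - depth_radius .. d + depth_radius,
        # clipped at both ends of the vector.
        lo = max(0, d - depth_radius)
        hi = min(n, d + depth_radius + 1)
        sqr_sum = sum(c * c for c in channels[lo:hi])
        out.append(x / (bias + alpha * sqr_sum) ** beta)
    return out


if __name__ == "__main__":
    activations = [0.0, 1.0, 5.0, 20.0, 1.0, 0.0]
    print(lrn_1d(activations))
```

Because bias = 1.0 and the sum of squares is non-negative, every denominator is at least 1. LRN can therefore only shrink activations, and it shrinks them more when the neighboring channels are strongly active.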
Code Organization
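File | Purpose |
---|---|
cifar10_input.py | Reads the native CIFAR-10 binary file format. |
cifar10.py | Builds the CIFAR-10 model. |
cifar10_train.py | Trains the CIFAR-10 model on a CPU or GPU. |
cifar10_multi_gpu_train.py | Trains the CIFAR-10 model on multiple GPUs. |
cifar10_eval.py | Evaluates the predictive performance of the CIFAR-10 model. |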
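Code Analysis
TensorFlow 1.0 changed several APIs, so newer versions are not compatible with the older code. For fixes to the specific errors, see: Fixing Errors in the TensorFlow CIFAR-10 Training Example
The official training results are shown below. Accuracy reaches roughly 86% within 60K to 100K steps, so in practice there is no need to train for the full default number of steps.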
System | Step Time (sec/batch) | Accuracy
---|---|---
1 Tesla K20m | 0.35-0.60 | ~86% at 60K steps (5 hours)
1 Tesla K40m | 0.25-0.35 | ~86% at 100K steps (4 hours)
cifar10_train.py
```python
from datetime import datetime
import os.path
import time

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

import cifar10

# FLAGS.batch_size is defined in cifar10.py.
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 1000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")


def train():
  """Train CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Get images and labels for CIFAR-10.
    # Input preprocessing: random cropping, flipping, brightness and
    # contrast changes, etc.
    images, labels = cifar10.distorted_inputs()

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate loss.
    loss = cifar10.loss(logits, labels)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = cifar10.train(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Start running operations on the Graph.
    sess = tf.Session(config=tf.ConfigProto(
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph)

    # Run the configured number of training steps.
    for step in xrange(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      # Every 10 steps, print the step, loss, and timing statistics.
      if step % 10 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, sec_per_batch))

      # Every 100 steps, record the network's state in the summaries.
      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save a model checkpoint every 1000 steps and at the final step.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)


def main(argv=None):  # pylint: disable=unused-argument
  # Download and extract the dataset if it is not already in the directory.
  cifar10.maybe_download_and_extract()
  # Delete the previous training logs.
  if tf.gfile.Exists(FLAGS.train_dir):
    tf.gfile.DeleteRecursively(FLAGS.train_dir)
  tf.gfile.MakeDirs(FLAGS.train_dir)
  # Train.
  train()


if __name__ == '__main__':
  # Parse the flags, then call main().
  tf.app.run()
```
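Besides the training step itself, the loop above does three kinds of periodic bookkeeping. It prints progress every 10 steps and writes summaries every 100 steps. It saves a checkpoint every 1000 steps and also on the last step. Throughput in the log line is batch_size / duration. The following plain-Python sketch reproduces only that scheduling and throughput arithmetic, so it can be checked without TensorFlow. Both function names are hypothetical, and the batch size and timing in the usage lines are made-up example values.

```python
def actions_for_step(step, max_steps):
    """Return the periodic actions the train() loop performs at `step`."""
    actions = []
    if step % 10 == 0:
        actions.append("log")
    if step % 100 == 0:
        actions.append("summary")
    if step % 1000 == 0 or (step + 1) == max_steps:
        actions.append("checkpoint")
    return actions


def throughput(batch_size, duration):
    """Return (examples/sec, sec/batch) as computed for the log line."""
    return batch_size / duration, float(duration)


if __name__ == "__main__":
    for step in (0, 10, 250, 1000, 1999):
        print(step, actions_for_step(step, max_steps=2000))
    print(throughput(128, 0.35))
```

The `(step + 1) == max_steps` condition makes sure the final weights are always saved, even when max_steps is not a multiple of 1000.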
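The distorted_inputs function applies random transformations to the training set: random cropping, random horizontal flipping, and random brightness and contrast changes. This way the training data also covers the kinds of variation that can appear in the evaluation set. It is defined in cifar10_input.py as follows: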
```python
import os

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# IMAGE_SIZE, NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, read_cifar10() and
# _generate_image_and_label_batch() are defined elsewhere in cifar10_input.py.


def distorted_inputs(data_dir, batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.

  Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # The training set gets extra random distortions (cropping, flipping,
  # brightness/contrast) so that it covers the variation seen at eval time.
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for training the network. Note the many random
  # distortions applied to the image.

  # Step 1: randomly crop a [height, width] section of the image.
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

  # Step 2: randomly flip the image left to right (probability 50%).
  distorted_image = tf.image.random_flip_left_right(distorted_image)

  # Step 3: randomly change the brightness and contrast.
  # Because these operations are not commutative, consider randomizing
  # the order in which they are applied.
  distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
  distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)

  # Subtract off the mean and divide by the standard deviation of the pixels.
  float_image = tf.image.per_image_standardization(distorted_image)

  # Ensure that the random shuffling has good mixing properties:
  # training starts only once the queue holds at least 40% of an epoch.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                           min_fraction_of_examples_in_queue)
  print('Filling queue with %d CIFAR images before starting to train. '
        'This will take a few minutes.' % min_queue_examples)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size)
```
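Here is how the crop, flip, and standardization steps work:
- random_crop picks a random [height, width] window. The official cifar10_input.py sets IMAGE_SIZE to 24, which crops from the 32x32 CIFAR-10 originals; treat that value as an assumption if your copy differs.
- random_flip_left_right mirrors the image with probability 0.5.
- per_image_standardization computes (x - mean) / adjusted_stddev, where adjusted_stddev = max(stddev, 1 / sqrt(num_elements)). The 1 / sqrt(num_elements) floor avoids dividing by zero on a uniform image.

The following pure-Python sketch applies those three operations to a single-channel image stored as a list of rows. It omits the brightness and contrast steps, and every function name is hypothetical rather than part of TensorFlow.

```python
import math
import random


def random_crop(image, height, width, rng=random):
    """Pick a random height x width window from a 2-D list of pixels."""
    top = rng.randint(0, len(image) - height)
    left = rng.randint(0, len(image[0]) - width)
    return [row[left:left + width] for row in image[top:top + height]]


def random_flip_left_right(image, rng=random):
    """Mirror every row with probability 0.5."""
    if rng.random() < 0.5:
        return [row[::-1] for row in image]
    return image


def per_image_standardization(image):
    """(x - mean) / adjusted_stddev, following the tf.image formula."""
    pixels = [p for row in image for p in row]
    n = len(pixels)
    mean = sum(pixels) / n
    stddev = math.sqrt(sum((p - mean) ** 2 for p in pixels) / n)
    adjusted_stddev = max(stddev, 1.0 / math.sqrt(n))
    return [[(p - mean) / adjusted_stddev for p in row] for row in image]


if __name__ == "__main__":
    rng = random.Random(0)
    # A 32x32 toy "image" whose pixel values increase across the rows.
    img = [[float(r * 32 + c) for c in range(32)] for r in range(32)]
    crop = random_flip_left_right(random_crop(img, 24, 24, rng), rng)
    std = per_image_standardization(crop)
    print(len(std), len(std[0]))  # 24 24
```

After standardization each image has zero mean and unit variance. This is why brightness and contrast shifts in the earlier steps do not blow up the input scale.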
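The evaluation set does not need these distortions, so it is loaded with the inputs() function instead. That function is largely the same as the one above (it only center-crops and standardizes each image), so I will not go through it in detail. It is also defined in cifar10_input.py: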
```python
import os

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# IMAGE_SIZE, NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, NUM_EXAMPLES_PER_EPOCH_FOR_EVAL,
# read_cifar10() and _generate_image_and_label_batch() are defined elsewhere
# in cifar10_input.py.


def inputs(eval_data, data_dir, batch_size):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  if not eval_data:
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
  else:
    filenames = [os.path.join(data_dir, 'test_batch.bin')]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for evaluation.
  # Crop the central [height, width] of the image
  # (the signature is image, target_height, target_width).
  resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                         height, width)

  # Subtract off the mean and divide by the standard deviation of the pixels.
  float_image = tf.image.per_image_standardization(resized_image)

  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(num_examples_per_epoch *
                           min_fraction_of_examples_in_queue)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size)
```
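Two details of inputs() can be checked by hand.
- Cropping: when the target is smaller than the image, resize_image_with_crop_or_pad cuts out the centered window. For a 32x32 CIFAR image and a 24x24 target, that drops (32 - 24) / 2 = 4 pixels on each side. This assumes IMAGE_SIZE = 24, as in the official cifar10_input.py.
- Queue size: min_queue_examples depends on which split is read. Assuming the official constants NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 and NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000, the queue is pre-filled with 20000 training or 4000 evaluation examples.

Below is a plain-Python sketch of both calculations. The function names are hypothetical.

```python
def central_crop(image, target_height, target_width):
    """Crop the centered window (the cropping case of resize_image_with_crop_or_pad)."""
    top = (len(image) - target_height) // 2
    left = (len(image[0]) - target_width) // 2
    return [row[left:left + target_width] for row in image[top:top + target_height]]


def min_queue_examples(eval_data, fraction=0.4, num_train=50000, num_eval=10000):
    """Queue pre-fill size used by inputs(); the example counts are assumed."""
    num_examples_per_epoch = num_eval if eval_data else num_train
    return int(num_examples_per_epoch * fraction)


if __name__ == "__main__":
    # Store (row, col) coordinates as pixels so the crop window is visible.
    img = [[(r, c) for c in range(32)] for r in range(32)]
    crop = central_crop(img, 24, 24)
    print(crop[0][0], crop[-1][-1])                             # (4, 4) (27, 27)
    print(min_queue_examples(False), min_queue_examples(True))  # 20000 4000
```

The center crop has no randomness, so repeated evaluation runs see exactly the same images.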
cifar10.py
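This part contains the model code. The inference() function defines the CNN architecture, which I already explained in detail for MNIST; see "TensorFlow: Getting Started with a CNN on MNIST":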
```python
import tensorflow as tf

# FLAGS, NUM_CLASSES, _variable_with_weight_decay(), _variable_on_cpu() and
# _activation_summary() are defined elsewhere in cifar10.py.


def inference(images):
  """Build the CIFAR-10 model.

  Args:
    images: Images returned from distorted_inputs() or inputs().

  Returns:
    Logits.
  """
  # We instantiate all variables using tf.get_variable() instead of
  # tf.Variable() in order to share variables across multiple GPU training runs.
  # If we only ran this model on a single GPU, we could simplify this function
  # by replacing all instances of tf.get_variable() with tf.Variable().

  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64],
                                         stddev=1e-4, wd=0.0)
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope.name)
    _activation_summary(conv1)

  # pool1
  pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1],
                         padding='SAME', name='pool1')
  # norm1
  norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm1')

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay('weights', shape=[5, 5, 64, 64],
                                         stddev=1e-4, wd=0.0)
    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1))
    bias = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(bias, name=scope.name)
    _activation_summary(conv2)

  # norm2
  norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75,
                    name='norm2')
  # pool2
  pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1], padding='SAME', name='pool2')

  # local3
  with tf.variable_scope('local3') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    dim = 1
    for d in pool2.get_shape()[1:].as_list():
      dim *= d
    reshape = tf.reshape(pool2, [FLAGS.batch_size, dim])
    weights = _variable_with_weight_decay('weights', shape=[dim, 384],
                                          stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
    local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
    _activation_summary(local3)

  # local4
  with tf.variable_scope('local4') as scope:
    weights = _variable_with_weight_decay('weights', shape=[384, 192],
                                          stddev=0.04, wd=0.004)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
    _activation_summary(local4)

  # softmax_linear computes the unnormalized logits WX + b; the softmax
  # itself is applied inside the cross-entropy loss in loss().
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES],
                                          stddev=1 / 192.0, wd=0.0)
    biases = _variable_on_cpu('biases', [NUM_CLASSES],
                              tf.constant_initializer(0.0))
    softmax_linear = tf.add(tf.matmul(local4, weights), biases, name=scope.name)
    _activation_summary(softmax_linear)

  return softmax_linear
```
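It helps to trace the tensor shapes through inference(). With padding='SAME', a convolution or pooling layer outputs ceil(input / stride) along each spatial axis. The 5x5 stride-1 convolutions therefore keep the spatial size, and the 3x3 stride-2 max-pools halve it. The sketch below recomputes the flattened `dim` that local3 reshapes to, plus the number of weights and biases in each layer. It assumes 24x24x3 input crops (IMAGE_SIZE = 24 in the official cifar10_input.py) and NUM_CLASSES = 10. It is plain arithmetic rather than TensorFlow code, and the helper names are hypothetical.

```python
import math


def same_out(size, stride):
    """Output length along one axis of a 'SAME'-padded conv or pool."""
    return math.ceil(size / stride)


def cifar10_shapes(image_size=24, num_classes=10):
    """Trace spatial sizes and parameter counts through inference()."""
    s = same_out(image_size, 1)  # conv1: 24 x 24 x 64
    s = same_out(s, 2)           # pool1: 12 x 12 x 64 (norm1 keeps the shape)
    s = same_out(s, 1)           # conv2: 12 x 12 x 64 (norm2 keeps the shape)
    s = same_out(s, 2)           # pool2:  6 x  6 x 64
    dim = s * s * 64             # flattened input size of local3
    params = {                   # weights + biases; LRN and pooling have none
        "conv1": 5 * 5 * 3 * 64 + 64,
        "conv2": 5 * 5 * 64 * 64 + 64,
        "local3": dim * 384 + 384,
        "local4": 384 * 192 + 192,
        "softmax_linear": 192 * num_classes + num_classes,
    }
    return dim, params


if __name__ == "__main__":
    dim, params = cifar10_shapes()
    print(dim)                   # 2304
    print(sum(params.values()))  # 1068298
```

Under these assumptions, most of the model's roughly one million parameters sit in local3, the first fully connected layer.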
The train() function, which trains the CNN:
```python
import tensorflow as tf

# FLAGS, NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN, NUM_EPOCHS_PER_DECAY,
# INITIAL_LEARNING_RATE, LEARNING_RATE_DECAY_FACTOR, MOVING_AVERAGE_DECAY and
# _add_loss_summaries() are defined elsewhere in cifar10.py.


def train(total_loss, global_step):
  """Train CIFAR-10 model.

  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.

  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / FLAGS.batch_size
  decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY)

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = _add_loss_summaries(total_loss)

  # Compute gradients.
  # Control dependencies: the gradient computation only runs after
  # loss_averages_op has completed.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Add histograms for gradients.
  for grad, var in grads:
    if grad is not None:
      tf.summary.histogram(var.op.name + '/gradients', grad)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  # train_op does nothing itself; running it forces both the gradient update
  # and the moving-average update to run.
  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  return train_op
```
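With staircase=True, tf.train.exponential_decay computes lr = INITIAL_LEARNING_RATE * LEARNING_RATE_DECAY_FACTOR ** floor(global_step / decay_steps). The learning rate therefore stays flat and then drops in discrete jumps. The constants' values are not shown above. The sketch below assumes the ones in the official cifar10.py: INITIAL_LEARNING_RATE = 0.1, LEARNING_RATE_DECAY_FACTOR = 0.1, NUM_EPOCHS_PER_DECAY = 350, 50000 training examples, and a batch size of 128. Adjust them if your copy differs. Under those assumptions, decay_steps = int(50000 / 128 * 350) = 136718, so the learning rate is divided by 10 roughly every 137K steps. The function names are hypothetical.

```python
def decay_steps(num_examples=50000, batch_size=128, epochs_per_decay=350.0):
    """decay_steps as computed in cifar10.train() (Python 3 true division)."""
    num_batches_per_epoch = num_examples / batch_size
    return int(num_batches_per_epoch * epochs_per_decay)


def staircase_lr(global_step, initial_lr=0.1, decay_factor=0.1, steps=None):
    """Learning rate of tf.train.exponential_decay(..., staircase=True)."""
    if steps is None:
        steps = decay_steps()
    # Integer division gives the floor(global_step / decay_steps) staircase.
    return initial_lr * decay_factor ** (global_step // steps)


if __name__ == "__main__":
    steps = decay_steps()
    for step in (0, steps - 1, steps, 2 * steps):
        print(step, staircase_lr(step, steps=steps))
```

Under Python 2, `/` between two integers truncates, which gives 390 * 350 = 136500 instead. The difference is small, but it explains why decay_steps can vary slightly between environments.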
cifar10_eval.py
```python
import time

import tensorflow as tf

import cifar10

FLAGS = tf.app.flags.FLAGS
# FLAGS.eval_dir, FLAGS.eval_data, FLAGS.run_once, FLAGS.eval_interval_secs
# and eval_once() are defined elsewhere in cifar10_eval.py.


def evaluate():
  """Eval CIFAR-10 for a number of steps."""
  with tf.Graph().as_default():
    # Get images and labels for CIFAR-10.
    eval_data = FLAGS.eval_data == 'test'
    images, labels = cifar10.inputs(eval_data=eval_data)

    # Build a Graph that computes the logits predictions from the
    # inference model.
    logits = cifar10.inference(images)

    # Calculate predictions: is the true label the top-1 prediction?
    top_k_op = tf.nn.in_top_k(logits, labels, 1)

    # Restore the moving average version of the learned variables for eval.
    variable_averages = tf.train.ExponentialMovingAverage(
        cifar10.MOVING_AVERAGE_DECAY)
    variables_to_restore = variable_averages.variables_to_restore()
    saver = tf.train.Saver(variables_to_restore)

    # Build the summary operation based on the TF collection of Summaries.
    summary_op = tf.summary.merge_all()

    summary_writer = tf.summary.FileWriter(FLAGS.eval_dir,
                                           tf.get_default_graph())

    # Evaluate repeatedly, sleeping eval_interval_secs between runs,
    # unless run_once is set.
    while True:
      eval_once(saver, summary_writer, top_k_op, summary_op)
      if FLAGS.run_once:
        break
      time.sleep(FLAGS.eval_interval_secs)
```
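evaluate() restores the moving-average ("shadow") copies of the weights rather than the raw ones. tf.train.ExponentialMovingAverage updates each shadow value as shadow = decay * shadow + (1 - decay) * variable. When num_updates is supplied (cifar10.train() passes global_step), it instead uses decay = min(decay, (1 + num_updates) / (10 + num_updates)), so the average follows the variables more closely early in training. The sketch below reproduces that update rule for a single scalar. The class name `ScalarEMA` is hypothetical. The default decay of 0.9999 is assumed from the official cifar10.py's MOVING_AVERAGE_DECAY.

```python
class ScalarEMA:
    """Hypothetical scalar version of tf.train.ExponentialMovingAverage."""

    def __init__(self, initial_value, decay=0.9999):
        self.decay = decay
        # The shadow value starts at the variable's initial value.
        self.shadow = initial_value

    def update(self, value, num_updates=None):
        decay = self.decay
        if num_updates is not None:
            # Smaller effective decay early in training.
            decay = min(decay, (1.0 + num_updates) / (10.0 + num_updates))
        self.shadow = decay * self.shadow + (1.0 - decay) * value
        return self.shadow


if __name__ == "__main__":
    ema = ScalarEMA(0.0)
    for step in range(5):
        ema.update(1.0, num_updates=step)
    print(ema.shadow)
```

Averaging the weights over many steps smooths out the noise from individual SGD updates. The official tutorial cites this as the reason evaluation uses the averaged values.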
That covers the main flow. For the complete code, see: CIFAR_TensorFlow
Reference
Convolutional Neural Networks
Original post: http://blog.csdn.net/shine19930820/article/details/76608648