tensorflow(2) TensorFlow Mechanics 101

标签(空格分隔): tensorflow


参考 TF英文社区TF中文社区

fully_connected_feed.py 是总体的运行过程
mnist.py中定义了四个函数,inference,training,loss,evaluation

mnist.py

一、inference

就是网络结构函数,mnist.py中的inference定义的网络有一对全连接层,和一个有10个线性节点的线性层

  • input:inference输入placeholder和第一层,第二层网络hidden units的个数
  • 每一层都有唯一的name_scope,所有的item都创建在这个namescope下,相当于给这一层的所有item加了一个前缀
with tf.name_scope('hidden1'):
  • 在每一个scope中,weight和biase由tf.Variable生成,大小根据(输入输出)的维度设置
    weight=[connect from,connect to]
    biase=[connect to]
  • 每个变量在创建时,都会被给予一个初始化操作
weights = tf.Variable(tf.truncated_normal([IMAGE_PIXELS, hidden1_units],stddev=1.0 / math.sqrt(float(IMAGE_PIXELS))),name='weights')
biases = tf.Variable(tf.zeros([hidden1_units]),name='biases')

比如weights会用tf.truncated_normal初始化器,根据给定的均值和标准差生成一个随机分布
biase根据tf.zeros保证它们的初始值都是0。

graph中主要有三个operation,两个tf.nn.relu和一个tf.matmul
最后,程序会返回包含了输出结果的logits Tensor。

二、loss

loss() 也是graph的一部分,输入两个参数,神经网络的分类结果和labels正确结果。进行比较,计算损失。

def loss(logits, labels):labels = tf.to_int64(labels)cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits, name='xentropy')return tf.reduce_mean(cross_entropy, name='xentropy_mean')

这个函数分为三步

  • 1.先将labels转换成所需要的格式
    tf.to_int64(labels)这个操作可以将labels抓换成指定的格式1-hot labels,
    1-hot labels:例如,如果类标识符为“3”,那么该值就会被转换为:
    [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
    将inference的判断结果和labels进行比较
  • 2.利用函数tf.nn.sparse_softmax_cross_entropy_with_logits 计算交叉熵
  • 3.计算一个batch的平均loss
loss = tf.reduce_mean(cross_entropy, name='xentropy_mean')

tf.reduce_mean函数(可跨维度的计算平均值),计算batch维度(第一维度)下交叉熵(cross entropy)的平均值,将将该值作为总损失

三、training

  • input: loss tensor , learning rate
    主要分为四步
  • 1、 创建一个summarizer, 用来更新损失,summary的值会被写在events file里面
    tf.summary.scalar('loss', loss)
  • 2、创建一个optimizer优化器对象tf.train.GradientDescentOptimizer(设置学习率)
  • 3、创建global_step变量 ,用于记录全局训练步骤的单值
  • 4、开始优化 optimizer.minimize(输入loss 和 global step)
  • return train_op

四、evaluation

输入网络的分类结果和labels,和loss函数的输入一样

def evaluation(logits, labels):"""Evaluate the quality of the logits at predicting the label.Args:logits: Logits tensor, float - [batch_size, NUM_CLASSES].labels: Labels tensor, int32 - [batch_size], with values in therange [0, NUM_CLASSES).Returns:A scalar int32 tensor with the number of examples (out of batch_size)that were predicted correctly."""# For a classifier model, we can use the in_top_k Op.# It returns a bool tensor with shape [batch_size] that is true for# the examples where the label is in the top k (here k=1)# of all logits for that example.correct = tf.nn.in_top_k(logits, labels, 1)# Return the number of true entries.return tf.reduce_sum(tf.cast(correct, tf.int32))

fully_connected_feed.py

一旦图建立完之后,就可以在循环训练和评估
tensorflow/tensorflow/examples/tutorials/mnist/fully_connected_feed.py

总体步骤

1、设置输入
定义placeholder,函数def placeholder_inputs(batch_size)
2、开始训练run_rainning

  • 读入数据集
  • 建立图
  • 创建session
  • 初始化
  • 开始循环训练
    • check status
    • do evaluation

1.place holder

2.the graph

建图中所有的操作都是在with tf.Graph().as_default()下进行的
tf.graph可能会执行所有的ops,可以包含多个图,创建多个线程
我们只需要一个single graph

3.session

在定义完图后,需要创建一个会话session来开启这个图
- 创建session sess=tf.session()
- 创建initializer, initializer=tf.global_variables_initializer
- sess.run(initializer) 会自动初始化所有的变量

4.training loop

在变量初始化完成之后,就可以开始训练了
最简单的训练过程就以下两行代码

with step in xrange(FLAGS.max_Step)sess.run(train_op)

但是本例子要复杂一点,读入的数据每一步都要进行切分,以适应之前生成的place_holder

(1).fill_feed_dict

先让image_feed和labels_feed去向dataset索要下一次训练的一个batchsize的数据

images_feed, labels_feed = data_set.next_batch(FLAGS.batch_size,FLAGS.fake_data)

再讲这个数据整合成一个python字典的形式,image_placeholder 和labels_placeholder作为字典的key, image_feed和labels_feed作为字典的value

feed_dict = {images_placeholder: images_feed,labels_placeholder: labels_feed,
}

(2).检查训练状态

see.run 在每一步训练之后都会取得两个值,loss 和train_op,(train_op不返回任何值,discard)
,所以会得到每一步的loss
每训练100次,check一下,输出loss
每训练1000次,进行evaluation,将生成的model保存一下

(3).do_eval

计算整个epoch的精度

  true_count = 0  # Counts the number of correct predictions.steps_per_epoch = data_set.num_examples // FLAGS.batch_sizenum_examples = steps_per_epoch * FLAGS.batch_sizefor step in xrange(steps_per_epoch):feed_dict = fill_feed_dict(data_set,images_placeholder,labels_placeholder)true_count += sess.run(eval_correct, feed_dict=feed_dict)precision = float(true_count) / num_examplesprint('  Num examples: %d  Num correct: %d  Precision @ 1: %0.04f' %(num_examples, true_count, precision))

代码

建图步骤

  • Generate placeholders for the images and labels.
  • Build a Graph that computes predictions from the inference model.
  • Add to the Graph the Ops for loss calculation.
  • Add to the Graph the Ops that calculate and apply gradients.
  • Add the Op to compare the logits to the labels during evaluation.
  • Build the summary Tensor based on the TF collection of Summaries.
  • Add the variable initializer Op.
  • Create a saver for writing training checkpoints.
  • Create a session for running Ops on the Graph.
  • Instantiate a SummaryWriter to output summaries and the Graph.
  • And then after everything is built:Run the Op to initialize the variables.
    Start the training loop.

读取数据

def run_training():"""Train MNIST for a number of steps."""# Get the sets of images and labels for training, validation, and# test on MNIST.data_sets = input_data.read_data_sets(FLAGS.input_data_dir, FLAGS.fake_data)

开始建图

  # Tell TensorFlow that the model will be built into the default Graph.with tf.Graph().as_default():# Generate placeholders for the images and labels.images_placeholder, labels_placeholder = placeholder_inputs(FLAGS.batch_size)# Build a Graph that computes predictions from the inference model.logits = mnist.inference(images_placeholder,FLAGS.hidden1,FLAGS.hidden2)# Add to the Graph the Ops for loss calculation.loss = mnist.loss(logits, labels_placeholder)# Add to the Graph the Ops that calculate and apply gradients.train_op = mnist.training(loss, FLAGS.learning_rate)# Add the Op to compare the logits to the labels during evaluation.eval_correct = mnist.evaluation(logits, labels_placeholder)# Build the summary Tensor based on the TF collection of Summaries.summary = tf.summary.merge_all()# Add the variable initializer Op.init = tf.global_variables_initializer()# Create a saver for writing training checkpoints.saver = tf.train.Saver()# Create a session for running Ops on the Graph.sess = tf.Session()# Instantiate a SummaryWriter to output summaries and the Graph.summary_writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)# And then after everything is built:# Run the Op to initialize the variables.sess.run(init)

开始循环训练

    # Start the training loop.for step in xrange(FLAGS.max_steps):start_time = time.time()# Fill a feed dictionary with the actual set of images and labels# for this particular training step.feed_dict = fill_feed_dict(data_sets.train,images_placeholder,labels_placeholder)# Run one step of the model.  The return values are the activations# from the `train_op` (which is discarded) and the `loss` Op.  To# inspect the values of your Ops or variables, you may include them# in the list passed to sess.run() and the value tensors will be# returned in the tuple from the call._, loss_value = sess.run([train_op, loss],feed_dict=feed_dict)duration = time.time() - start_time# Write the summaries and print an overview fairly often.if step % 100 == 0:# Print status to stdout.print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration))# Update the events file.summary_str = sess.run(summary, feed_dict=feed_dict)summary_writer.add_summary(summary_str, step)summary_writer.flush()# Save a checkpoint and evaluate the model periodically.if (step + 1) % 1000 == 0 or (step + 1) == FLAGS.max_steps:checkpoint_file = os.path.join(FLAGS.log_dir, 'model.ckpt')saver.save(sess, checkpoint_file, global_step=step)# Evaluate against the training set.print('Training Data Eval:')do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.train)# Evaluate against the validation set.print('Validation Data Eval:')do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.validation)# Evaluate against the test set.print('Test Data Eval:')do_eval(sess,eval_correct,images_placeholder,labels_placeholder,data_sets.test)

tensorflow学习(2)TensorFlow Mechanics 101相关推荐

  1. tensorflow学习笔记——使用TensorFlow操作MNIST数据(1)

    续集请点击我:tensorflow学习笔记--使用TensorFlow操作MNIST数据(2) 本节开始学习使用tensorflow教程,当然从最简单的MNIST开始.这怎么说呢,就好比编程入门有He ...

  2. 深度学习与TensorFlow

    深度学习与TensorFlow DNN(深度神经网络算法)现在是AI社区的流行词.最近,DNN 在许多数据科学竞赛/Kaggle 竞赛中获得了多次冠军. 自从 1962 年 Rosenblat 提出感 ...

  3. 深度学习调用TensorFlow、PyTorch等框架

    深度学习调用TensorFlow.PyTorch等框架 一.开发目标目标 提供统一接口的库,它可以从C++和Python中的多个框架中运行深度学习模型.欧米诺使研究人员能够在自己选择的框架内轻松建立模 ...

  4. TensorFlow学习笔记——实现经典LeNet5模型

    TensorFlow实现LeNet-5模型 文章目录 TensorFlow实现LeNet-5模型 前言 一.什么是TensorFlow? 计算图 Session 二.什么是LeNet-5? INPUT ...

  5. 在浏览器中进行深度学习:TensorFlow.js (四)用基本模型对MNIST数据进行识别

    2019独角兽企业重金招聘Python工程师标准>>> 在了解了TensorflowJS的一些基本模型的后,大家会问,这究竟有什么用呢?我们就用深度学习中被广泛使用的MINST数据集 ...

  6. Tensorflow学习资源

    欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习.深度学习的知识! 作者:AI小昕 在之前的Tensorflow系列文章中,我们教大家 ...

  7. TensorFlow 深度学习笔记 TensorFlow实现与优化深度神经网络

    TensorFlow 深度学习笔记 TensorFlow实现与优化深度神经网络 转载请注明作者:梦里风林 Github工程地址:https://github.com/ahangchen/GDLnote ...

  8. 【干货】史上最全的Tensorflow学习资源汇总,速藏!

    一 .Tensorflow教程资源: 1)适合初学者的Tensorflow教程和代码示例:(https://github.com/aymericdamien/TensorFlow-Examples)该 ...

  9. tensorflow学习函数笔记

    为什么80%的码农都做不了架构师?>>>    [TensorFlow教程资源](https://my.oschina.net/u/3787228/blog/1794868](htt ...

最新文章

  1. phpadmin试用
  2. maven初学者常见错误汇总(三)
  3. mongo mysql 聚合性能_Mongodb和Mysql的性能分析
  4. Azure SQL 数据库引入了新的服务级别
  5. fullbnt matlab,FullBNT学习笔记之一(matlab)
  6. 把canvas标签里的图像下载成本地图片文件
  7. android 获取元素的下标_Appium中定位方式by_android_uiautomator
  8. linux下,查找命令的使用
  9. sysbench数据库性能压测详解
  10. 2022华为软件精英挑战赛比赛经历
  11. Oracle数据库中的数据类型
  12. 关于黑客,你了解多少?----黑客入门学习(常用术语+DOS操作)
  13. 增量式编码器和绝对式编码器,ABI信号和UVW信号、编码器PWM信号
  14. 脚手架vue-cli系列二:vue-cli的工程模板与构建工具
  15. SAP ZSD008:Change SO Item Split
  16. linux 查看端口 程序,Linux查看程序端口占用情况
  17. 机器学习中的AUC理解
  18. 改变 Word正文 底色
  19. win7 aero特效 如何打开
  20. Ajax 查询手机号码归属地

热门文章

  1. lol6月五日服务器维护,LOL英雄联盟6月20日停机更新公告 服务器维护商品上架
  2. ARM架构的标准软硬件系统渐成形
  3. 分析appstore审核失败的真实案例及解决办法
  4. 机战OGS简单金手指制作
  5. 字幕文件批量重命名脚本 —— Linux
  6. 发包技术实现SEO快排原理解密
  7. 忘记密码(通过手机验证码找回设置)自己写
  8. Java 集合深入理解(11):LinkedList
  9. ios10怎么设置电池颜色_iOS10省电设置技巧 iOS10怎么最省电
  10. IDEA中,maven项目下,lombok插件 ,添加lombok.jar, Maven项目下lombok依赖配置