声明:本文通过CNN实现mnist例子总结了TensorFlow 1.12的相关API。代码来源于《Learning TensorFlow》这本书,API查阅了TensorFlow官网API


作者: SixAbs
摘要

  本文通过对经典深度学习的入门示例“mnist手写体数字识别”API进行总结,其目的是使自己在初学时候熟悉TensorFlow相关API,同时熟悉TensorFlow的基本使用。为了直奔主题,本文忽略了对CNN相关知识的详述,首先直接给出了用TensorFlow实现mnist手写体数字识别的代码,然后直接依次罗列其中的API并给出解释和加以拓展。

  关键词:TensorFlow;API r1.12;mnist;


文章目录

  • 1 TensorFlow实现mnist的代码
  • 2 API总结
    • 2.1 截断正太分布
    • 2.2 Variable
    • 2.3 constant常量
    • 2.4 tf.placeholder() 占位符
    • 2.5 tf.reshape()
    • 2.6 tf.conv2d() 二维卷积操作
    • 2.7 tf.nn.max_pool() 最大池化
    • 2.8 tf.nn.relu() 修正线性单元
    • 2.9 tf.nn.dropout()
    • 其他以后再做总结

1 TensorFlow实现mnist的代码

"""
CNN手写体数字识别
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'# 先定义好需要用到的函数
def weight_variable(shape):initial = tf.truncated_normal(shape, stddev=0.1)  return tf.Variable(initial)  def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')def max_pool_2x2(x):return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')def conv_layer(input, shape):W = weight_variable(shape)b = bias_variable([shape[3]])return tf.nn.relu(conv2d(input, W) + b)def full_layer(input, size):in_size = int(input.get_shape()[1])W = weight_variable([in_size, size])b = bias_variable([size])return tf.matmul(input, W) + bx = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])x_image = tf.reshape(x, [-1, 28, 28, 1])
conv1 = conv_layer(x_image, shape=[5, 5, 1, 32])
conv1_pool = max_pool_2x2(conv1)conv2 = conv_layer(conv1_pool, shape=[5, 5, 32, 64])
conv2_pool = max_pool_2x2(conv2)conv2_flat = tf.reshape(conv2_pool, [-1, 7*7*64])
full_1 = tf.nn.relu(full_layer(conv2_flat, 1024))keep_prob = tf.placeholder(tf.float32)
full1_drop = tf.nn.dropout(full_1, keep_prob=keep_prob)y_conv = full_layer(full1_drop, 10)sess = tf.Session()
tf.summary.FileWriter('log/', sess.graph)mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_conv, labels=y_))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))STEPS = 1001
with tf.Session() as sess:sess.run(tf.global_variables_initializer())for i in range(STEPS):batch = mnist.train.next_batch(50)if i % 100 == 0:train_accuracy = sess.run(accuracy, feed_dict={x: batch[0],y_: batch[1],keep_prob: 1.0})print("step {}, training accuracy {}".format(i, train_accuracy))sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})X = mnist.test.images.reshape(10, 1000, 784)Y = mnist.test.labels.reshape(10, 1000, 10)test_accuracy = np.mean([sess.run(accuracy, feed_dict={x: X[i], y_: Y[i], keep_prob: 1.0})for i in range(10)])print("test accuracy: {}".format(test_accuracy))

2 API总结

2.1 截断正太分布
tf.random.truncated_normal(shape,  # A 1-D integer Tensor or Python array. The shape of the output tensor.mean=0.0,  # A 0-D Tensor or Python value of type dtype. The mean of the truncated normal distribution.stddev=1.0,  # A 0-D Tensor or Python value of type dtype. The standard deviation of the normal distribution, before truncation.dtype=tf.float32,  # The type of the output.seed=None,  # A Python integer. Used to create a random seed for the distribution. See tf.set_random_seed for behavior.name=None  # A name for the operation (optional).
)

截断正太分布常用于给训练参数(如:权值,偏置)产生一些初值。

2.2 Variable

  变量(Variable)是特殊的张量(Tensor),它的值可以是一个任何类型和形状的张量。与其他张量不同,变量存在于单个 session.run 调用的上下文之外,也就是说,变量存储的是持久张量,当训练模型时,用变量来存储和更新参数。除此之外,在调用operator之前,所有变量都应被显式地初始化过。Variable其实是一个python类,该类的构造函数其实包含了很多参数:

__init__(initial_value=None,trainable=True,collections=None,validate_shape=True,caching_device=None,name=None,variable_def=None,dtype=None,expected_shape=None,import_scope=None,constraint=None,use_resource=None,synchronization=tf.VariableSynchronization.AUTO,aggregation=tf.VariableAggregation.NONE
)

其中initial_value是传入的初始化值,官网是这样描述的:

initial_value: A Tensor, or Python object convertible to a Tensor, which is the initial value for the Variable. The initial value must have a shape specified unless validate_shape is set to False. Can also be a callable with no argument that returns the initial value when called. In that case, dtype must be specified. (Note that initializer functions from init_ops.py must first be bound to a shape before being used here.)
  变量的初始值可以是一个张量,或者是可转换为张量的Python对象。初始值必须具有指定的形状,除非validate_shape参数设置为False。也可以是一个无参数调用,调用时返回初始值。在这种情况下,dtype必须指定。(请注意,init_ops.py中的初始化函数在使用之前必须先绑定到一个形状。

  变量创建时使用tf.Variable(), 在使用前需要为初始化数据分类内存,这时需要给sess.run()传入一个tf.global_variables_initializer()。

init = tf.global_variables_initializer()
sess.run(init)

  和其他张量对象一样,Variables只有在运行模型时才会计算。同时重用同一个variable为了提高效率,我们可以调用tf.get_variables()。
  其他更详细的信息直接参看官方api。

2.3 constant常量
tf.constant(value,  # A constant value (or list) of output type dtype.dtype=None,  # The type of the elements of the resulting tensor.shape=None,  # Optional dimensions of resulting tensor.name='Const',  # Optional name for the tensor.verify_shape=False  #  Boolean that enables verification of a shape of values.
)

根据函数参数信息,可以发现value是必须传的参数。
与tf.fill()比较:

tf.constant differs from tf.fill in a few ways:

  1. tf.constant supports arbitrary constants, not just uniform scalar Tensors like tf.fill.
  2. tf.constant creates a Const node in the computation graph with the exact value at graph construction time. On the other hand, tf.fill creates an Op in the graph that is expanded at runtime.
  3. Because tf.constant only embeds constant values in the graph, it does not support dynamic shapes based on other runtime Tensors, whereas tf.fill does.
2.4 tf.placeholder() 占位符

  TensorFlow已经为我们指定内置结构用于供给输入值,这些结构称为占位符。 占位符可以被认为是空变量,并将在随后填充数据。 我们首先使用它们来构建我们的图形,并且只有在执行它时才使用输入数据。

tf.placeholder(dtype,  # 指定占位符数据类型shape=None,  # shape指定输入的shape,当某一维指定位None,表示这一维可以是任意值。如x = tf.placeholder(tf.float32, shape=[None, 784]),这里None表示这个维度可以是任意大小,通常用于表示样本数量。name=None  # 名称
)
2.5 tf.reshape()
tf.reshape(  # 将给定的tensor的形状转换为指定的shapetensor,shape,name=None
)

注意shape参数可以有一个-1,表示缺省值(自适应),就是先根据其他维度调整,到时tensor总维度乘积除以其他几个维度乘积,就是缺省的维度大小。如:

a = tf.placeholder(tf.float32, shape=[1, 24])
print(a.get_shape())
b = tf.reshape(a, [-1, 3, 4])
print(b.get_shape())
# out =
# (1, 24)
# (2, 3, 4)

b的第一维设为 − 1 -1 1,通过自适应(如果不能整除将会报错),第一维reshape后为 ( 1 × 24 ) / ( 3 × 4 ) = 2 (1\times 24)/(3\times4) = 2 (1×24)/(3×4)=2。特殊的 shape=[-1]表示将tensor展成一维。

2.6 tf.conv2d() 二维卷积操作
tf.nn.conv2d(input,   # 输入张量data_format默认为 [batch, in_height, in_width, in_channels], 数值类型必须是half, bfloat16, float32, float64filter,  # 滤波器(核)[filter_height, filter_width, in_channels, out_channels]strides,  # 移动步长:[batch, height, weight, channel]分别表示对应的移动步长padding,  # 边缘补0,设置为‘SAME’添加后产生的特征图和输入维度一样大use_cudnn_on_gpu=True,  # bool类型,是否使用cudnn加速,默认为truedata_format='NHWC',  # 输入输出数据格式:默认为 [batch, height, width, channels]dilations=[1, 1, 1, 1],  # 每一维与data_format对应。如果设置为k> 1,则该维度上的每个滤镜元素之间将有k-1个跳过的单元格。可以做出中空滤波器的效果,用相同数量的参数获得更大的感受野name=None   # 名字
)
2.7 tf.nn.max_pool() 最大池化
tf.nn.max_pool(value,  # data_format格式的4维张量,一般是卷积后的feture mapksize,  # 池化窗口大小,取一个四维向量,一般是[1, height, width, 1],因为我们不想在batch和channels上做池化strides,  # 和卷积类似,窗口在每一个维度上滑动的步长,一般也是[1, stride,stride, 1]padding,  # 和卷积类似,可以取'VALID' 或者'SAME'data_format='NHWC', name=None
)
2.8 tf.nn.relu() 修正线性单元
tf.nn.relu(features,  # 张量name=None  # 名称
)
2.9 tf.nn.dropout()
tf.nn.dropout(x,  # float类型的tensorkeep_prob,  # float类型,每个元素被保留下来的概率,设置神经元被选中的概率,在初始化时keep_prob是一个占位符, keep_prob = tf.placeholder(tf.float32) 。tensorflow在run时设置keep_prob具体的值,例如keep_prob: 0.5noise_shape=None, # 一个1维的int32张量,代表了随机产生“保留/丢弃”标志的shapeseed=None,  # 整形变量,随机数种子name=None
)
其他以后再做总结
tf.matmul(a,b)  # 矩阵相乘a*b
tf.Session()  # 在会话中启动图
tf.summary.FileWriter()  # 将摘要协议缓冲区写入事件文件
tf.reduce_mean() # 默认reduce_mean(x)对所有元素求均值,指定减小的维度用axis属性
tf.equal() # 判断张量相等
tf.argmax() # 返回张量中的最大值
tf.cast() # 将tensor转型为新的类型

mnist手写体识别中用到的TensorFlow API总结相关推荐

  1. TensorRT(3)-C++ API使用:mnist手写体识别

    本节将介绍如何使用tensorRT C++ API 进行网络模型创建. 1 使用C++ API 进行 tensorRT 模型创建 还是通过 tensorRT官方给的一个例程来学习. 还是mnist手写 ...

  2. 【人工智能项目】MNIST手写体识别实验及分析

    [人工智能项目]MNIST数据集实验报告 这是之前接的小作业,现在分享出来,给大家以学习!!! [人工智能项目]MNIST手写体识别实验及分析 1.实验内容简述 1.1 实验环境 本实验采用的软硬件实 ...

  3. TensorRT(2)-基本使用:mnist手写体识别

    结合 tensorRT官方给出的一个例程,介绍tensorRT的使用. 这个例程是mnist手写体识别.例程位于目录: /usr/src/tensorrt/samples/sampleMNIST 文件 ...

  4. python模拟手写笔迹_pytorch实现MNIST手写体识别

    本文实例为大家分享了pytorch实现MNIST手写体识别的具体代码,供大家参考,具体内容如下 实验环境 pytorch 1.4 Windows 10 python 3.7 cuda 10.1(我笔记 ...

  5. R︱Softmax Regression建模 (MNIST 手写体识别和文档多分类应用)

    本文转载自经管之家论坛, R语言中的Softmax Regression建模 (MNIST 手写体识别和文档多分类应用) R中的softmaxreg包,发自2016-09-09,链接:https:// ...

  6. 2021年人工神经网络第四次作业 - 第二题MNIST手写体识别

    简 介: ※MNIST数据集合是深度学习基础训练数据集合.改数据集合可以使用稠密前馈神经网络训练,也可以使用CNN.本文采用了单隐层BP网络和LeNet网络对于MNIST数据集合进行测试.实验结果标明 ...

  7. python神经网络案例——CNN卷积神经网络实现mnist手写体识别

    分享一个朋友的人工智能教程.零基础!通俗易懂!风趣幽默!还带黄段子!大家可以看看是否对自己有帮助:点击打开 全栈工程师开发手册 (作者:栾鹏) python教程全解 CNN卷积神经网络的理论教程参考 ...

  8. python神经网络案例——FC全连接神经网络实现mnist手写体识别

    全栈工程师开发手册 (作者:栾鹏) python教程全解 FC全连接神经网络的理论教程参考 http://blog.csdn.net/luanpeng825485697/article/details ...

  9. [Python人工智能] 六.TensorFlow实现分类学习及MNIST手写体识别案例

    从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇文章讲解了Tensorboard可视化的基本用法,并绘制整个神经网络及训练.学习的参数变化情况:本篇文章将通过Te ...

最新文章

  1. Windows10为什么自带Linux,一直没有发现原来 Win10 内置了一个 Linux
  2. js获取浏览器高和宽的基本信息:屏幕信息
  3. 异步GridView(ASPxGridView) 特点介绍(2) - 筛选(Filter)、弹出编辑(Editing)
  4. 查看系统各个进程打开的文件描述符数量
  5. 通道Channel-使用NIO 写入数据
  6. [Head First Java] - 给线程命名
  7. C++细节系列(零):零散记录
  8. Spring学习(20)--- Schema-based AOP(基于配置的AOP实现) -- 配置切入点pointcut
  9. jquery Demo 以及code
  10. VR AR体验或成2017圣丹斯电影节“新主角”
  11. MySQL中的float和decimal类型有什么区别
  12. php 获取城市列表接口,省份城市区域列表
  13. OneNote for win10 登录不了
  14. 数据分析师真实的工作是怎样的,这篇文章带你看他们的职责
  15. flash在C#中的应用
  16. 广州大学2021计算机组成原理课程设计实验报告
  17. 实用工具WGestures全局鼠标手势
  18. iframe标签全屏
  19. 微信小程序背景图片background无法在手机端显示问题解决方案
  20. SAP中图文展示分摊和分配的区别

热门文章

  1. SAP 库存查询函数
  2. CAD二次开发--二维多段线Polyline与三维多段线Polyline3d创建总结
  3. 公元前到现在的所有朝代
  4. vue3 Ant Design Vue DatePicker 默认当前年
  5. ew 135 G3 森海塞尔无线话筒
  6. Swift JSON与模型转换(HandyJSON封装)
  7. 多功能输入法――内码转换模块设计与实现(1)
  8. 汽车功能安全研究:主机厂和供应商的ISO26262布局
  9. 2025年渗透率超17%!高阶智驾受热捧,供应商扎堆入场
  10. CAx软件开发技术专题:后处理可视化常用算法