Tensorflow概要
Tensorflow是google的分布式机器学习系统,其既是一个实现机器学习算法的接口,同时也是执行机器学习算法的框架。它前段支持python、C++,Go,Java等多种开发语言,后端使用C++,CUDA等写成。
Tensorflow使用数据流式图来规划计算流程,它可以将计算映射到不同的硬件和操作系统平台。
Tensorflow中的计算可以表示为一个有向图,或称计算图,其中每一个运算操作将作为一个节点。节点与节点之间的连接称为边。这个计算图描述了数据的计算流程,它也负责维护和更新状态,用户可以对计算图的分支进行条件控制或循环操作。计算图的每一个节点可以有任意多个输入和任意多个输出,每一个节点描述了一种运算操作,节点可以算是运算操作的实例化。在计算图的边中流动(flow)的数据被称为张量(tensor),故得名Tensorflow。而tensor的数据类型,可以是事先定义的,也可以根据计算图的结构推断得到。有一类特殊的边中没有数据流动,这种边是依赖控制(control dependencies),作用是让它的起始节点执行完之后再执行目标节点,用户可以使用这样的边进行灵活的条件控制,比如限制内存使用的最高峰值。
Softmax Regression
Softmax Regression 的原理很简单,将可以判断为某类的特征相加,然后将这些特征转化为判定是这一类的概率。
本节使用Softmax Regression来实现手写数字集MNIST,代码与详细注释如下:

import tensorflow as tf#载入Tensorflow库
sess=tf.InteractiveSession()#创建一个新的InteractiveSession,使这个session注册为默认的session 具体解释https://blog.csdn.net/qq_14839543/article/details/77822916
x=tf.placeholder(tf.float32,[None,784])#placeholder是输入数据的地方,第一个参数是数据类型,第二个参数代表tensor的shape,也就是数据的尺寸,这里None代表不限条数的输入

接下来要给Softmax Regression模型中的weights和biases创建Variable对象。Variable在模型训练迭代中是持久化的,它可以长期存在并且在每轮迭代中被更新。

W=tf.Variable(tf.zeros([784,10]))#权重W的shape是[784,10],784是特征的维数,而后面的10代表有10类
b=tf.Variable(tf.zeros([10]))#偏差b是一维的零向量

接下来实现Softmax Regression算法,我们回忆一下上面提到的公式:y=softmax(Wx+b)。改写成Tensorflow的语言就是下面这行代码:

y=tf.nn.softmax(tf.matmul(x,W)+b)#Softmax是tf.nn下面的一个函数,tf.matmul是Tensoflow中的矩阵乘法函数

cross-entropy通常作为多分类问题的损失函数。

#定义cross-entropy
y_tf.placeholder(tf.float32,[None,10])
cross_entropy=tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))
#定义优化算法-随机梯度下降法SGD
train_step=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#优化目标设定为cross_entropy,学习速率为0.5
#使用Tensorflow的全局参数初始化器
tf.global_variables_initializer().run()
#迭代地执行训练操作train_step
for i in range(10000):batch_xs,batch_ys=mnist.train.next_batch(100)train_step.run({x:batch_xs,y_:batch_ys})
#对模型的准确率进行验证
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
#tf.argmax是从一个tensor中寻找最大值的序号,tf.argmax(y,1)就是求各个预测的数字中概率最大的那一个,而tf.argmax(y_,1)则是找样本的真实数字类别。tf.equal判断预测的数字类别是否就是正确的类别
#统计全部样本预测的accuracy
accuary=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#tf.cast转化数据类型
print(accuary.eval({x:mnist.test.images,y_:mnist.test.labels}))
#eval()函数解释
https://www.runoob.com/python/python-func-eval.html

通过使用TensoFlow实现了一个简单的机器学习算法Softmax Regression,这可以算作是一个没有隐含层的最浅的神经网络。整个流程我们做的事情可以分为4个部分:
(1)定义算法公式,也就是神经网络forward时的计算。
(2)定义loss,选定迭代器,并指定优化器优化loss。
(3)迭代地对数据进行训练。
(4)在测试集或验证集上对准确率进行评测。
这几个步骤是我们使用TensorFlow进行算法设计、训练的核心步骤,也将会贯穿之后其他类型神经网络的章节。实际上,我们定义的这些公式其实只是Computation Graph,在执行这行代码时,计算还没有实际发生,只有等调用run方法,并feed数据时计算才真正执行。比如cross entropy、train step、accuracy等都是计算图中的节点,而不是数据结果,可以通过调用run方法执行这些节点或者说运算操作来获取结果。

Tensorflow实战之实现 Softmax Regression识别手写数字(学习笔记)相关推荐

  1. TensorFlow实战之Softmax Regression识别手写数字

       本文根据最近学习TensorFlow书籍网络文章的情况,特将一些学习心得做了总结,详情如下.如有不当之处,请各位大拿多多指点,在此谢过. 一.相关概念 1.MNIST MNIST(Mixed N ...

  2. 机器学习实战之k-近邻算法识别手写数字(含拍照检验步骤详解)

    哈哈,这是我写的第一篇博客,就此拉开了我的程序员生涯的序幕.希望有缘人看见之后,能够解决你所遇见的问题.废话不多说,开始办正事. 本例中使用K-近邻算法识别手写数字,参考书目:Peter Harrin ...

  3. 小生不才:tensorflow实战01-基于bp神经网络的手写数字识别

    前言 利用搭建网络八股,使用简单的bp神经网络完成手写数字的识别. 搭建过程 导入相应的包 获取数据集,划分数据集和测试集并进行简单处理(归一化等) 对数据进行乱序处理 定义网络结构 选择网络优化器以 ...

  4. 【TensorFlow-windows】(一)实现Softmax Regression进行手写数字识别(mnist)

    博文主要内容有: 1.softmax regression的TensorFlow实现代码(教科书级的代码注释) 2.该实现中的函数总结 平台: 1.windows 10 64位 2.Anaconda3 ...

  5. 02:一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    标签(空格分隔): 王小草Tensorflow笔记 笔记整理者:王小草 笔记整理时间2017年2月24日 Tensorflow官方英文文档地址:https://www.tensorflow.org/g ...

  6. Tensorflow.js||使用 CNN 识别手写数字

    Tensorflow官方的tesorflow.js实操课程 链接为:link 使用 CNN 识别手写数字 文章目录 使用 CNN 识别手写数字 1. 简介 2. 设置操作 3. 加载数据 4. 定义模 ...

  7. tensorflow实现CNN识别手写数字

    上一篇使用TensorFlow识别手写数字,是直接采用了softmax进行多分类,直接将28*28的图片转换成为了784维的向量作为输入,然后通过一个(784,10)的权重,将输入转换成一个10维的向 ...

  8. TensorFlow实现识别手写数字

    当学习一门新的编程语言的时候,我们总是以输出"hello word"作为学习这门编程语言的开始,表示我们开启了这门编程语言的大门.而在机器学习的领域中,识别手写数字就像输出&quo ...

  9. svm手写数字识别_KNN 算法实战篇如何识别手写数字

    上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数据集是一个用于图像处理的数据集,这些数据描绘了 [0, 9] 的数字,我们可以用KNN 算 ...

最新文章

  1. binder-JAVA层机制
  2. Spring-AOP 切点/切面类型和创建切面
  3. Oracle-OLAP和OLTP解读
  4. [摘]一张图 , oracle merge用法:
  5. Freescale MC9S08AW60汇编学习笔记(五)
  6. 介绍一个非常好用的文件服务器 - HFS
  7. yum源查看mysql_获取MySQL各版本yum源 并安装
  8. 004-React入门概述
  9. 计算机毕设-文献摘要,毕设摘要翻译,要人工翻译,不要电脑网站翻译的。
  10. python标准库之logging
  11. 春运公益片“情满回家路”上线 顺风车等出行方式再被呼吁
  12. 基本排序(C语言版)
  13. java比较两个类的值不相同_java 反射---------比较两个相同类型的对象相同属性的属性值是否相同的具体调用...
  14. 坚果种类和营养价值排名
  15. 在计算机编程里pi是什么意思,编程中的术语“钩子”是什么意思?
  16. 【新知实验室 TRTCIM】实时互动课堂最佳实践
  17. 计算机网络实验-网络嗅探器
  18. i.MX6ULL系统移植 | 移植NXP官方linux4.1.15内核
  19. 手机服务器怎么维护,手机维护远程服务器
  20. SpringBoot2.x 集成 七牛云对象存储Kodo

热门文章

  1. 基于RGB-D的语义分割和目标检测介绍
  2. CAN总线学习记录之四:位定时与同步
  3. oracle未选定行大小写_关于Oracle中查询结果为未选定行
  4. 五个最佳的免费网络记事本
  5. C# winform 动物识别专家系统
  6. 微信小程序头部自定义
  7. 去除PDF文件水印方法
  8. spark基于物品的推荐_05.基于商品的协过滤推荐实现
  9. 比AtomicLong还高效的LongAdder源码解析
  10. atoi,itoi,atol,strtol, strtod函数转换