TensorFlow实现Softmax Regression识别手写数字

本文是按照黄文坚、唐源所著的《TensorFlow实战》一书,进行编写。在TensorFlow实战之余,力求简洁地讲清当中涉及机器学习、深度学习的原理,而并非只是简单调用TensorFlow中的Python api!

1 Softmax Regression原理

要讲Softmax回归,就要先讲Logistic回归。定性地来讲,Softmax回归是Logitic回归的拓展。一般而言,Logistic常用于二分类问题上,而Softmax回归则可用于多分类的问题上,比如后面实现的手写数字识别。

这里在使用TensorFlow实现Softmax回归识别手写数字之前,简单地讲解一下Softmax回归的原理。此处讲解以和Logistic回归对比为主。

Logistic回归和Softmax回归的对比:

前提:训练集由m个已标记(label)的样本构成:,其中,输入特征

无论是使用Logitic回归,还是Softmax回归,对于J(θ)的最小化问题,目前还没有闭式解法(即没有严格的公式,给出任意自变量就可以求出因变量的方法)。因此,我们可使用梯度下降法等,进行迭代优化。此时,需要求偏导,然后,调整学习速率进行权重更新。

2 算法实现流程

A 加载MNIST数据集

MNIST数据集下载网站:http://yann.lecun.com/exdb/mnist/,TensorFlow中有函数可以自动下载MNIST数据,如果下好的话,运行会更快一些。

B 初始化参数

在TensorFlow中使用placeholder初始化自变量x,使用Variable初始化权重w和偏置b。

C 构造模型

此程序使用的是Softmax模型,在TensorFlow中直接可调用

D 迭代训练参数

代价函数使用交叉熵的思想,梯度下降进行优化,多次迭代进行参数更新。

E 显示在测试集中的准确率

3 编程实现

# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
# 本程序是tensorflow中的基本例程: 使用softmax回归实现手写数字识别
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # 只显示errormnist = input_data.read_data_sets("./MNIST_data/",one_hot=True)
print(mnist.train.images.shape,mnist.train.labels.shape)
print(mnist.test.images.shape,mnist.test.labels.shape)
print(mnist.validation.images.shape,mnist.validation.labels.shape)
#print(mnist.train.labels[0])sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32,[None,784]) # placeholder可指定数据类型
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,w)+b) # 预测值y_ = tf.placeholder(tf.float32,[None,10]) # 真实值
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)tf.global_variables_initializer().run()for i in range(1000):batch_xs,batch_ys = mnist.train.next_batch(100)train_step.run({x:batch_xs,y_:batch_ys})#print(sess.run(y,feed_dict={x:batch_xs})) # 显示每一次Softmax回归的结果,即每一类别的概率值
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))

4 实验结果

显示部分Softmax回归求出来的值,可以从下图可见,其求出的值是一个概率值。

识别率大致为92%

TensorFlow实现Softmax相关推荐

  1. TensorFlow MNIST (Softmax)

    代码 import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_ ...

  2. tensorflow之softmax

    softmax 就是把值做个映射,映射到0-1之间,并且映射之后,和为1. 举个例子: tt = tf.constant([1.0,2.0]) y = tf.nn.softmax(tt)with tf ...

  3. TensorFlow实战之Softmax Regression识别手写数字

       本文根据最近学习TensorFlow书籍网络文章的情况,特将一些学习心得做了总结,详情如下.如有不当之处,请各位大拿多多指点,在此谢过. 一.相关概念 1.MNIST MNIST(Mixed N ...

  4. tensorflow问题

    20210121 ImportError: No module named 'tensorflow.python' https://stackoverflow.com/questions/414156 ...

  5. Tensorflow快速入门2--实现手写数字识别

    Tensorflow快速入门2–实现手写数字识别 环境:  虚拟机ubuntun16.0.4  Tensorflow 版本:0.12.0(仅使用cpu下) Tensorflow安装见:  http:/ ...

  6. tensorflow保存模型和加载模型的方法(Python和Android)

    tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...

  7. softmax实现多分类算法推导及代码实现

    关于多分类 我们常见的逻辑回归.SVM等常用于解决二分类问题,对于多分类问题,比如识别手写数字,它就需要10个分类,同样也可以用逻辑回归或SVM,只是需要多个二分类来组成多分类,但这里讨论另外一种方式 ...

  8. tensorflow实战学习笔记(1)

    tensorflow提供了三种不同的加速神经网路训练的并行计算模式 (一)数据并行: (二)模型并行: (三)流水线并行: 主流深度学习框架对比(2017): 第一章 Tensorflow实现Soft ...

  9. tensor如何实现转置_转置()TensorFlow中的函数

    转置是TensorFlow中提供的函数.此函数用于转置输入张量.语法:转置(input_tensor,perm,conjugate)参数:input_tensor:顾名思义,它是要转置的张量.类型:T ...

最新文章

  1. Yolov5系列AI常见数据集(1)车辆,行人,自动驾驶,人脸,烟雾
  2. VS2013在Release情况下使用vector有时候会崩溃的一个可能原因
  3. arcgis for javascript ArcGISDynamicMapServiceLayer 过滤图层点
  4. 【转】Unity利用WWW http传输Json数据
  5. C#之判断Mysql数据库表是否存在
  6. UE4官方文档学习笔记材质篇——分层材质
  7. 需求与商业模式分析-2-商业模式类型
  8. 制图利器—MapGIS10.5制图版体验
  9. Excel文件编辑保护如何取消?
  10. 手机电脑怎么上P站-国内版pixiv你可知晓
  11. 行列式计算程序(基于Python)
  12. 海康威视摄像头用yolo检测行人的一些问题
  13. twig html不转义,twig输出转义
  14. 一种物联网型的电能监控排插
  15. 宜信微服务架构落地及其演进|分享实录
  16. 建立工资计算系统(2)
  17. Nature Neuroscience:利用深度神经网络进行基于磁共振的眼动追踪
  18. 小啊呜产品读书笔记001:《邱岳的产品手记-11》第21讲 产品案例分析:Fabulous的精致养成
  19. 初链-解读初链黄皮书
  20. 27岁了,老大不小了,转载一篇文章作年度回顾

热门文章

  1. handler java_Java中以handler命名的类有什么含义吗?
  2. 软件测试面试题整理附答案小总结
  3. python输入一组数据找出被七除余一的数_【数学竞赛】七年级数学思维探究(4)信息技术中的数学问题(含答案)...
  4. 支付宝蜻蜓设备---修改HID模式的输出格式
  5. hive(spark-sql) -e -f -d以及传参数, sh并行
  6. 毕业生求职网用例说明文档
  7. Qt编写物联网管理平台40-类型种类
  8. [译|转]ESX 3.5中使用QLogic QLE 220 HBA卡
  9. CMOS图像传感器基础知识和参数理解
  10. 马来西亚理科大学 计算机 校区,马来西亚理科大学在马来西亚是一个怎样的存在?...