TensorFlow实现Softmax
TensorFlow实现Softmax Regression识别手写数字
本文是按照黄文坚、唐源所著的《TensorFlow实战》一书,进行编写。在TensorFlow实战之余,力求简洁地讲清当中涉及机器学习、深度学习的原理,而并非只是简单调用TensorFlow中的Python api!
1 Softmax Regression原理
要讲Softmax回归,就要先讲Logistic回归。定性地来讲,Softmax回归是Logitic回归的拓展。一般而言,Logistic常用于二分类问题上,而Softmax回归则可用于多分类的问题上,比如后面实现的手写数字识别。
这里在使用TensorFlow实现Softmax回归识别手写数字之前,简单地讲解一下Softmax回归的原理。此处讲解以和Logistic回归对比为主。
Logistic回归和Softmax回归的对比:
前提:训练集由m个已标记(label)的样本构成:,其中,输入特征
无论是使用Logitic回归,还是Softmax回归,对于J(θ)的最小化问题,目前还没有闭式解法(即没有严格的公式,给出任意自变量就可以求出因变量的方法)。因此,我们可使用梯度下降法等,进行迭代优化。此时,需要求偏导,然后,调整学习速率进行权重更新。
2 算法实现流程
A 加载MNIST数据集
MNIST数据集下载网站:http://yann.lecun.com/exdb/mnist/,TensorFlow中有函数可以自动下载MNIST数据,如果下好的话,运行会更快一些。
B 初始化参数
在TensorFlow中使用placeholder初始化自变量x,使用Variable初始化权重w和偏置b。
C 构造模型
此程序使用的是Softmax模型,在TensorFlow中直接可调用
D 迭代训练参数
代价函数使用交叉熵的思想,梯度下降进行优化,多次迭代进行参数更新。
E 显示在测试集中的准确率
3 编程实现
# -*- coding:utf-8 -*-
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
# 本程序是tensorflow中的基本例程: 使用softmax回归实现手写数字识别
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # 只显示errormnist = input_data.read_data_sets("./MNIST_data/",one_hot=True)
print(mnist.train.images.shape,mnist.train.labels.shape)
print(mnist.test.images.shape,mnist.test.labels.shape)
print(mnist.validation.images.shape,mnist.validation.labels.shape)
#print(mnist.train.labels[0])sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32,[None,784]) # placeholder可指定数据类型
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,w)+b) # 预测值y_ = tf.placeholder(tf.float32,[None,10]) # 真实值
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_*tf.log(y),reduction_indices=[1]))train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)tf.global_variables_initializer().run()for i in range(1000):batch_xs,batch_ys = mnist.train.next_batch(100)train_step.run({x:batch_xs,y_:batch_ys})#print(sess.run(y,feed_dict={x:batch_xs})) # 显示每一次Softmax回归的结果,即每一类别的概率值
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))print(accuracy.eval({x:mnist.test.images,y_:mnist.test.labels}))
4 实验结果
显示部分Softmax回归求出来的值,可以从下图可见,其求出的值是一个概率值。
识别率大致为92%
TensorFlow实现Softmax相关推荐
- TensorFlow MNIST (Softmax)
代码 import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_ ...
- tensorflow之softmax
softmax 就是把值做个映射,映射到0-1之间,并且映射之后,和为1. 举个例子: tt = tf.constant([1.0,2.0]) y = tf.nn.softmax(tt)with tf ...
- TensorFlow实战之Softmax Regression识别手写数字
本文根据最近学习TensorFlow书籍网络文章的情况,特将一些学习心得做了总结,详情如下.如有不当之处,请各位大拿多多指点,在此谢过. 一.相关概念 1.MNIST MNIST(Mixed N ...
- tensorflow问题
20210121 ImportError: No module named 'tensorflow.python' https://stackoverflow.com/questions/414156 ...
- Tensorflow快速入门2--实现手写数字识别
Tensorflow快速入门2–实现手写数字识别 环境: 虚拟机ubuntun16.0.4 Tensorflow 版本:0.12.0(仅使用cpu下) Tensorflow安装见: http:/ ...
- tensorflow保存模型和加载模型的方法(Python和Android)
tensorflow保存模型和加载模型的方法(Python和Android) 一.tensorflow保存模型的几种方法: (1) tf.train.saver()保存模型 使用 tf.train.s ...
- softmax实现多分类算法推导及代码实现
关于多分类 我们常见的逻辑回归.SVM等常用于解决二分类问题,对于多分类问题,比如识别手写数字,它就需要10个分类,同样也可以用逻辑回归或SVM,只是需要多个二分类来组成多分类,但这里讨论另外一种方式 ...
- tensorflow实战学习笔记(1)
tensorflow提供了三种不同的加速神经网路训练的并行计算模式 (一)数据并行: (二)模型并行: (三)流水线并行: 主流深度学习框架对比(2017): 第一章 Tensorflow实现Soft ...
- tensor如何实现转置_转置()TensorFlow中的函数
转置是TensorFlow中提供的函数.此函数用于转置输入张量.语法:转置(input_tensor,perm,conjugate)参数:input_tensor:顾名思义,它是要转置的张量.类型:T ...
最新文章
- Yolov5系列AI常见数据集(1)车辆,行人,自动驾驶,人脸,烟雾
- VS2013在Release情况下使用vector有时候会崩溃的一个可能原因
- arcgis for javascript ArcGISDynamicMapServiceLayer 过滤图层点
- 【转】Unity利用WWW http传输Json数据
- C#之判断Mysql数据库表是否存在
- UE4官方文档学习笔记材质篇——分层材质
- 需求与商业模式分析-2-商业模式类型
- 制图利器—MapGIS10.5制图版体验
- Excel文件编辑保护如何取消?
- 手机电脑怎么上P站-国内版pixiv你可知晓
- 行列式计算程序(基于Python)
- 海康威视摄像头用yolo检测行人的一些问题
- twig html不转义,twig输出转义
- 一种物联网型的电能监控排插
- 宜信微服务架构落地及其演进|分享实录
- 建立工资计算系统(2)
- Nature Neuroscience:利用深度神经网络进行基于磁共振的眼动追踪
- 小啊呜产品读书笔记001:《邱岳的产品手记-11》第21讲 产品案例分析:Fabulous的精致养成
- 初链-解读初链黄皮书
- 27岁了,老大不小了,转载一篇文章作年度回顾
热门文章
- handler java_Java中以handler命名的类有什么含义吗?
- 软件测试面试题整理附答案小总结
- python输入一组数据找出被七除余一的数_【数学竞赛】七年级数学思维探究(4)信息技术中的数学问题(含答案)...
- 支付宝蜻蜓设备---修改HID模式的输出格式
- hive(spark-sql) -e -f -d以及传参数, sh并行
- 毕业生求职网用例说明文档
- Qt编写物联网管理平台40-类型种类
- [译|转]ESX 3.5中使用QLogic QLE 220 HBA卡
- CMOS图像传感器基础知识和参数理解
- 马来西亚理科大学 计算机 校区,马来西亚理科大学在马来西亚是一个怎样的存在?...