TensorFlow实战minist数据集 softmax回归分类(一)
1、MNIST数据集简介:
MNIST数据集是一个手写体数据集,简单说就是一堆这样东西
MNIST的官网地址是MNIST; 通过阅读官网我们可以知道,这个数据集由四部分组成,分别是:
MNIST数据集主要由一些手写数字的图片和相应标签组成,图片总共分为10类,分别对应0~9十个数字。
如上图所示,每张图片的大小为28×28像素。而标签则由one-hot向量表示,一个one-hot向量除了某一位数字为1外,其余各唯独都是0。比如[1,0,0,0,0,0,0,0,0,0,0]表示数字“0”, [0,0,0,0,0,0,0,0,0,0,1]表示数字“9”,以此类推。
2、下载MNIST数据集
可以使用如下代码下载MNIST数据集到mnist_data文件夹
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("mnist_data", one_hot=True)
下载成功后,可以发现mnist_data文件夹下有以下四个文件:
文件 | 内容 |
---|---|
train-images-idx3-ubyte.gz | 训练集图片 - 55000 张 训练图片, 5000 张 验证图片 |
train-labels-idx1-ubyte.gz | 训练集图片对应的数字标签 |
t10k-images-idx3-ubyte.gz | 测试集图片 - 10000 张 图片 |
t10k-labels-idx1-ubyte.gz | 测试集图片对应的数字标签 |
3、将MNIST数据集保存为图片
为了更直观的了解MNIST数据集,我们可以把上面的训练集文件前50张训练图保存成图片。代码如下:
from tensorflow.examples.tutorials.mnist import input_data
import os
import scipy.misc as sm
import numpy as npmnist = input_data.read_data_sets("mnist_data", one_hot=True)
save_dir = 'mnist_data/image/'
if os.path.exists(save_dir) is False:os.mkdir(save_dir)for i in range(50):image_array = mnist.train.images[i, :]one_hot_label = mnist.train.labels[i, :]label = np.argmax(one_hot_label)image_array = image_array.reshape(28, 28)filename = save_dir + 'image_train_%d_%d.jpg' % (i, label)sm.toimage(image_array).save(filename)
注意:这里有个坑scipy的toimage在scipy1.0版本以后废除了想要运行代码并不报错(AttributeError: module ‘scipy.misc’ has no attribute ‘toimage’)
请使用:pip install scipy==1.0将scipy回归到1.0版本
4、Softmax回归介绍
Softmax回归是一个线性的多类分类模型。对于MNIST数据集的分类问题中,一个有10个类别(0~9),我们希望对输入的图像计算出它属于某个类别的概率,比如属于9的概率是80%,属于1的概率是5%等等,最后模型预测的结果就是概率最大的那个类别。
Softmax公式如下:
Softmax函数的主要功能是将各个类别的“打分”转化成合理的概率值,它将所有的类别转化为0~1之间的概率,而所有类别的概率加起来为1。
假设x是单个样本的特征,W、b是Softmax模型的参数,则计算x属于数字i类别的公式如下:
则整个Softmax模型可以用下面的式子表示:
Softmax回归模型可以用下图解释:
将上图写成等式:
用矩阵乘法和向量加法表示:
5、损失函数和优化器
训练模型的输出值和实际值肯定存在一定偏差,这种偏差越小,表示模型预测越准确,而衡量这种偏差的函数就是损失函数。Softmax回归模型中一般使用交叉熵来做损失函数。
交叉熵是判断一个输出向量与期望向量的接近程度的常用方法之一。它是分类问题中使用比较广的一种损失函数。
假设y是我们预测的概率分布, y’ 是实际的分布,则交叉熵公式为:
既然有损失,那么我们肯定就得去优化以减小这个损失。优化损失的函数有很多,我们这里使用梯度下降法。
6、tensorflow代码实现
首先需要导入tensorflow模块
#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
加载MNIST数据
mnist = input_data.read_data_sets('mnist_data', one_hot=True)
创建占位符和变量,用于存放图片数据和W权重,以及偏置
#创建x占位符,用于临时存放MNIST图片的数据,
# [None, 784]中的None表示不限长度,而784则是一张图片的大小(28×28=784)
x = tf.placeholder(tf.float32, [None, 784])
#W存放的是模型的参数,也就是权重,一张图片有784个像素作为输入数据,而输出为10
#因为(0~9)有10个结果
#b则存放偏置项
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros(10))#y表示Softmax回归模型的输出
y = tf.nn.softmax(tf.matmul(x, W) + b)#_存的是实际图像的标签,即对应于每张输入图片实际的值
y_ = tf.placeholder(tf.float32, [None, 10])
损失函数及优化函数
#定义损失函数,这里用交叉熵来做损失函数,y存的是我们训练的结果,而y_存的是实际标签的值
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y)))#优化函数,这里我们使用梯度下降法进行优化,0.01表示梯度下降优化器的学习率
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
使用Saver将训练结果保存
#将训练结果保存,如果不保存我们这次训练结束后的结果也随着程序运行结束而释放了
saver = tf.train.Saver()
等训练结束后,再用saver.save(sess, ‘./saver/mnist.ckpt’)将会话保存到指定文件即可
创建会话并初始化变量
#上面所做的只是定义算法,并没有真的运行,tensorflow的运行都是在会话(Session)中进行
with tf.Session() as sess:#初始化所有变量tf.global_variables_initializer().run()
开始训练
#开始训练,这里训练一千次
for _ in range(1000):#每次取100张图片数据和对应的标签用于训练batch_xs, batch_ys = mnist.train.next_batch(100)#将取到的数据进行训练sess.run(train_step, feed_dict={x:batch_xs, y_:batch_ys})
检测训练结果,并将模型保存
#检测训练结果,tf.argmax取出数组中最大值的下标,tf.equal再对比下标是否一样即可知道预测是否正确
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
#correct_prediction得到的结果是True或者False的数组
#我们再经过tf.cast将其转为数字的形式,即将[True, True, Flase, Flase]转成[1, 1, 0, 0]
#最后用tf.reduce_mean计算数组中所有元素的平均值,即预测的准确率
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
#开始预测运算,并将准确率输出
print (sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels}))
#最后,将会话保存下来
saver.save(sess, './saver/mnist.ckpt')
预测
# coding: utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'mnist = input_data.read_data_sets('mnist_data', one_hot=True)x = tf.placeholder(tf.float32, [None, 784])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros(10))y = tf.nn.softmax(tf.matmul(x, w) + b)
y_ = tf.placeholder(tf.float32, [None, 10])saver = tf.train.Saver();with tf.Session() as sess:#导入saver.restore(sess, './saver/mnist.ckpt')correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))print (sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
运行结果
可以看到用Softmax模型预测的准确率大概为90.85%左右
7、完整代码
minist
8、总结
训练神经网络的过程:
1、定义神经网络的结构和前向传播输出结果
2、定义损失函数以及选择反向传播优化的算法
3、生成会话(tf.Session)并且在训练数据上反复运行反向传播优化算法
4、保存最好的训练结果
TensorFlow实战minist数据集 softmax回归分类(一)相关推荐
- TensorFlow基于minist数据集实现手写字识别实战的三个模型
手写字识别 model1:输入层→全连接→输出层softmax model2:输入层→全连接→隐含层→全连接→输出层softmax model3:输入层→卷积层1→卷积层2→全连接→dropout层→ ...
- 添加softmax层_PyTorch入门之100行代码实现softmax回归分类
本文首发于公众号[拇指笔记] 1. 使用pytorch实现softmax回归模型 使用pytorch可以更加便利的实现softmax回归模型. 1.1 获取和读取数据 读取小批量数据的方法: 首先是获 ...
- TensorFlow HOWTO 1.4 Softmax 回归
1.4 Softmax 回归 Softmax 回归可以看成逻辑回归在多个类别上的推广. 操作步骤 导入所需的包. import tensorflow as tf import numpy as np ...
- Tensorflow基于minist数据集实现自编码器
Tensorflow实现自编码器 自编码器 Denoising AutoEncoder(去噪自编码器) 自编码器 特征的稀疏表达:使用少量的基本特征组合拼装得到更高层抽象的特征. 如:图像碎片可由少量 ...
- Tensorflow实战之实现 Softmax Regression识别手写数字(学习笔记)
Tensorflow概要 Tensorflow是google的分布式机器学习系统,其既是一个实现机器学习算法的接口,同时也是执行机器学习算法的框架.它前段支持python.C++,Go,Java等多种 ...
- VGGNet tensorflow实战(CIFAR10数据集)
VGGNet的主体思想是什么呢 普通的神经网络可能是一个卷积层后面跟一个pooling层或者不跟pooling层 VGGNet是通过使用3*3的卷积或者1*1的卷积将层次进行加深 所以VGGNet可以 ...
- pytoch人工神经网络基础:最简单的分类(softmax回归+交叉熵分类)
softmax回归分类原理 对于回归问题,可以用模型预测值与真实值比较,用均方误差这样的损失函数表示误差,迭代使误差最小训练模型. 那么分类问题是否可以用线性回归模型预测呢.最简单的方法就是用soft ...
- 跟李沐学深度学习-softmax回归
softmax回归 分类和回归的区别 无校验比例 校验比例 交叉熵 常见损失函数 均方误差 L2 loss 绝对值损失L1 loss 鲁棒损失 图像分类数据集 分类和回归的区别 回归:估计一个连续值 ...
- 09 Softmax 回归 + 损失函数 + 图片分类数据集【动手学深度学习v2】
分类问题 分类问题只关心对正确类的预测 梯度下降理解 https://zhuanlan.zhihu.com/p/335191534(强推) 图像分类数据集 import matplotlib.pypl ...
- 【Pytorch神经网络基础理论篇】 08 Softmax 回归 + 损失函数 + 图片分类数据集
3.4. softmax回归 回归可以用于预测多少的问题. 比如预测房屋被售出价格,或者棒球队可能获得的胜场数,又或者患者住院的天数. 事实上,我们也对分类问题感兴趣:不是问"多少" ...
最新文章
- Modeling System Behavior with Use Case(2)
- 沉迷游戏自学编程,创建游戏帝国,却黯然退场的“鬼才程序员”
- 算法----------字符串的排列(Java版本)
- Python 笔试集(1):关于 Python 链式赋值的坑
- 基于Qt的OpenGL可编程管线学习(9)- X射线
- 3d max 安装和导入rvt模型失败
- 操作系统 实验3【动态分区存储管理】
- .NET开发作业调度(job scheduling) - Quartz.NET
- CRM WebUI的错误消息是如何从后台服务器取出并绘制到前台的
- HikariCP连接池配置
- 无人职守安装的设计与部署
- Spring boot实体类中常用基本注解
- 《推荐系统笔记(七)》因子分解机(FM)和它的推广(FFM、DeepFM)
- proxy_cache的使用
- 关于动态库so的makefile编写
- 超星尔雅不让下载?课件,拿来吧你!
- 【游戏开发实战】重温红白机经典FC游戏,顺便教你快速搭建2D游戏关卡(Tilemap | 场景 | 地图)
- 英语音标学习视频教程
- 战略分析师/商业分析师需要掌握的技能
- python怎么念1001python怎么念-python 星号的使用
热门文章
- Guava库学习:学习Guava Cache(二)Guava caches(2)
- 【Go语言】I/O专题
- redhat 安装 snort
- 虚拟机在教学实验中的应用
- Java线程并发与安全性问题详解
- 浙商证券计算机组成原理,中国海洋大学计算机组成原理期末模拟参考答案.doc...
- 计算机基础知识试题和答案6,计算机基础知识试题及答案选择题(九)
- BZOJ5221[Lydsy2017省队十连测] 偏题
- [线筛五连]线筛莫比乌斯函数
- Redis常用命令、数据类型讲解