声明

来源于莫烦Python:Classification 分类学习


分类问题的通俗理解

 分类和回归在于输出变量的类型上。通俗来讲,连续变量预测,如预测房价问题,属于回归问题; 离散变量预测,如把东西分成几类,属于分类问题。


代码

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# number 1 to 10 data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)def add_layer(inputs, in_size, out_size, activation_function=None,):# add one more layer and return the output of this layerWeights = tf.Variable(tf.random_normal([in_size, out_size]))biases = tf.Variable(tf.zeros([1, out_size]) + 0.1,)Wx_plus_b = tf.matmul(inputs, Weights) + biasesif activation_function is None:outputs = Wx_plus_belse:outputs = activation_function(Wx_plus_b,)return outputsdef compute_accuracy(v_xs, v_ys):global predictiony_pre = sess.run(prediction, feed_dict={xs: v_xs})correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_ys,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})return result# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 784]) # 28x28
ys = tf.placeholder(tf.float32, [None, 10])# add output layer
prediction = add_layer(xs, 784, 10,  activation_function=tf.nn.softmax)# the error between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1]))       # loss
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(100)  # 分批进行学习,时间短,效果不一定比整套的差sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})if i % 50 == 0:print(compute_accuracy(mnist.test.images, mnist.test.labels))

代码释义

1. MNIST 数据

首先准备数据(MNIST库)

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

MNIST库是手写体数字库,差不多是这样子的

数据中包含55000张训练图片,每张图片的分辨率是28×28,所以我们的训练网络输入应该是28×28=784个像素数据。

搭建网络

xs = tf.placeholder(tf.float32, [None, 784]) # 28x28

每张图片都表示一个数字,所以我们的输出是数字0到9,共10类。

ys = tf.placeholder(tf.float32, [None, 10])

调用add_layer函数搭建一个最简单的训练网络结构,只有输入层和输出层。

prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax)

其中输入数据是784个特征,输出数据是10个特征,激励采用softmax函数,网络结构图是这样子的

2. Cross entropy loss

loss函数(即最优化目标函数)选用交叉熵函数。交叉熵用来衡量预测值和真实值的相似程度,如果完全相同,它们的交叉熵等于零。

cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) # loss

train方法(最优化算法)采用梯度下降法。

train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
sess.run(tf.global_variables_initializer())

3. 训练

现在开始train,每次只取100张图片,免得数据太多训练太慢。

batch_xs, batch_ys = mnist.train.next_batch(100)
sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})

每训练50次输出一下预测精度

if i % 50 == 0:print(compute_accuracy(mnist.test.images, mnist.test.labels))

输出结果如下:

0.1615
0.665
0.7495
0.7905
0.808
0.8262
0.8359
0.8432
0.8486
0.8538
0.8579
0.8592
0.8619
0.8613
0.8645
0.8658
0.8702
0.8719
0.8734
0.8754

Classification 分类学习相关推荐

  1. Classification分类学习

    #Classification分类学习import tensorflow as tfwimport numpy as np#from tensorflow.examples.tutorials.mni ...

  2. 莫烦 python_5.1 莫烦 Python Classification 分类学习

    Classification 分类学习 作者: Mark JingNB 编辑: Morvan 学习资料: 这次我们会介绍如何使用TensorFlow解决Classification(分类)问题. 之前 ...

  3. matlab 分类学习工具箱 Classification Learner

    转载:https://blog.csdn.net/qq_27914913/article/details/71436838 在matlab中,既有各种分类器的训练函数,比如"fitcsvm& ...

  4. matlab 分类学习工具箱 Classification Learner的使用及导出其生成的图,混淆矩阵confusion matrix的画法

    声明:转自https://blog.csdn.net/qq_27914913/article/details/71436838 https://blog.csdn.net/evil_xue/artic ...

  5. [Python人工智能] 六.TensorFlow实现分类学习及MNIST手写体识别案例

    从本专栏开始,作者正式开始研究Python深度学习.神经网络及人工智能相关知识.前一篇文章讲解了Tensorboard可视化的基本用法,并绘制整个神经网络及训练.学习的参数变化情况:本篇文章将通过Te ...

  6. 监督学习之分类学习:支持向量机

    监督学习之分类学习:支持向量机 如果想了解更多的知识,可以去我的机器学习之路 The Road To Machine Learning通道 Introduction 分类学习是最为常见的监督学习问题, ...

  7. Matlab自带的分类学习工具箱(SVM、决策树、Knn等分类器)

    在matlab中,既有各种分类器的训练函数,比如"fitcsvm",也有图形界面的分类学习工具箱,里面包含SVM.决策树.Knn等各类分类器,使用非常方便.接下来讲讲如何使用. 启 ...

  8. 机器学习笔记-多分类学习,类别不平衡,决策树

    读书笔记 多分类学习 基本思想:拆解法:将多分类任务拆解为若干个二分类任务求解,先对这些问题经拆分,为拆分出的每个二分类任务训练一个分类器,测试时,对这些分类器的预测结果进行集成以获得最终的多分类结果 ...

  9. 独家思维导图!让你秒懂李宏毅2020机器学习(二)—— Classification分类

    独家思维导图!让你秒懂李宏毅2020机器学习(二)-- Classification分类 在上一篇文章我总结了李老师Introduction和regression的具体内容,即1-4课的内容,这篇我将 ...

最新文章

  1. OC指示符assign、atomic、nonatomic、copy、retain、strong、week的解释
  2. signature验证/salt验证/token验证的作用
  3. android文件的写入与读取---简单的文本读写context.openFileInput() context.openFileOutput()...
  4. Spring(19)——Profile(二)
  5. Nutch的日志系统
  6. Linux:让普通用户临时性获得root用户权限
  7. Scanner获取用户输入
  8. leetcode1045. 买下所有产品的客户(SQL)
  9. go newscanner判断文件读取结束_Go单元测试-testing
  10. mysql 联合质检_第三次全国国土调查-统一时点更新阶段数据库质检规则业务细则解释(三)...
  11. Odoo10参考系列--混合而有用的类
  12. 黑客可利用 TeamViewer 缺陷远程窃取系统密码
  13. L1-009. N个数求和-PAT团体程序设计天梯赛GPLT
  14. Microsoft Office 2008 for Mac Service Pack 1 更新后无法启动程序问题解决方案
  15. RK3288_Android7.1添加两个gpio的按键
  16. 基于SpringBoot进销存ERP管理系统,源代码分享
  17. 网易服务器维护,网易:方便玩家 各大区服务器维护详细时间表列
  18. dell屏幕亮度调节不了_戴尔笔记本无法调节亮度怎么办?如何找回调节亮度?
  19. Win 10出现bitlocke恢复,蓝屏错误代码0x1600007e
  20. 【前端】js轮播图,简洁代码,一目了然

热门文章

  1. RTK如何进行面积测量,跟攻略学就对了
  2. 日本房产泡沫的崩塌,虽然很长,希望80后仔细阅读
  3. 移动支付服务Dwolla宣布10美元以下交易不收费
  4. 【软考】下午题 解题思路总结
  5. 红颜知己和蓝颜知己的区别
  6. [渝粤教育] 南京交通职业技术学院 计算机基础 参考 资料
  7. 发现美,创造美,拥有美^_^.
  8. 更改计算机休眠,win 7 无法设置自动休眠时间
  9. 智慧社区的现状分析及发展前景
  10. 域组策略与本地组策略