关于多分类

我们常见的逻辑回归、SVM等常用于解决二分类问题,对于多分类问题,比如识别手写数字,它就需要10个分类,同样也可以用逻辑回归或SVM,只是需要多个二分类来组成多分类,但这里讨论另外一种方式来解决多分类——softmax。

关于softmax

如何多分类

从下图看,神经网络中包含了输入层,然后通过两个特征层处理,最后通过softmax分析器就能得到不同条件下的概率,这里需要分成三个类别,最终会得到y=0、y=1、y=2的概率值。

继续看下面的图,三个输入通过softmax后得到一个数组[0.05 , 0.10 , 0.85],这就是soft的功能。

代价函数

使用场景

在多分类场景中可以用softmax也可以用多个二分类器组合成多分类,比如多个逻辑分类器或SVM分类器等等。该使用softmax还是组合分类器,主要看分类的类别是否互斥,如果互斥则用softmax,如果不是互斥的则使用组合分类器。

下面是使用tensorflow构建softmax分类器的代码:

# -*- coding: UTF-8 -*-
#!/usr/bin/python#从网上download到minist相关的数据以及处理函数
import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)import tensorflow as tf#a classifier with only softmax layer#x是一个N * 784的矩阵,784指的是28 * 28 的图片拉伸为一行
#N为批数,由用户在运行时指定。
#x存放的是每一批的训练数据,是不断变更的,因此需要用到tf中的feed方法,因此在这里只用占位符
#None指的是这个维度的值是任意的,在这里是输入的batch的大小是任意的
x = tf.placeholder(tf.float32, [None, 784])#用于存放真实标签
y_ = tf.placeholder("float", [None,10])#variable是tensorflow中可以被修改的变量
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))#一行代码实现softmax的前向传播,y是预测输出
y = tf.nn.softmax(tf.matmul(x,W) + b)#训练模型,这里用的是交叉熵,交叉熵被认为是比较好的loss function
#reduce_sum是求张量所有元素总和的求和函数。log(y)和*(这里是点乘)都是逐个元素进行的
#从这里也可以看出tensorflow中的loss function 是需要自己定义的
cross_entropy = -tf.reduce_sum(y_*tf.log(y))#用梯度下降法优化交叉熵
#其他优化算法也是一行代码,详情查阅文档,0.01指的是学习率
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)#启动session,初始化变量,这里的图用的是默认图
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)#训练模型
for i in range(1000):batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})#测试并输出准确率
#tf.argmax 能给出某个tensor对象在某一维上的其数据最大值所在的索引值。由于标签向量是由0,1组成,因此最大值#1所在的索引位置就是类别标签,比如tf.argmax(y,1)返回的是模型对于任一输入x预测到的标签值,
#而 tf.argmax(y_,1) 代表正确的标签,我们可以用 tf.equal 来检测我们的预测是否真实标签匹配(索引位置一样表#示匹配)。
#y的维度可选值是[0, 1]
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))#上行代码会给我们一组布尔值。为了确定正确预测项的比例,我们可以把布尔值转换成浮点数,然后取平均值。例如,#[True, False, True, True] 会变成 [1,0,1,1] ,取平均值(reduce_mean)后得到 0.75.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))#将预测集输入并输出准确率
print sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels})#总结:placeholder用于存放外部输入的tensor,variable用于存放内部自己变化的tensor

softmax实现多分类算法推导及代码实现相关推荐

  1. 深入解析GBDT二分类算法(附代码实现)

    目录: GBDT分类算法简介 GBDT二分类算法 2.1 逻辑回归的对数损失函数 2.2 GBDT二分类原理 GBDT二分类算法实例 手撕GBDT二分类算法 4.1 用Python3实现GBDT二分类 ...

  2. 支持向量机 SVM 算法推导优缺点 代码实现 in Python

    1.基本思想 前面讲到的Logistic Regression在拟合过程,实际上关注所有样本点的贡献,即寻找这么一个超平面,使得正例的特征远大于0,负例的特征远小于0,强调在全部训练数据上达到这一目标 ...

  3. BP神经网络算法推导及代码实现笔记

    文章目录 一.前言 二.科普 三.通往沙漠的入口: 神经元是什么,有什么用: 四.茫茫大漠第一步: 激活函数是什么,有什么用 五.沙漠中心的风暴:BP(Back Propagation)算法 1. 神 ...

  4. BP神经网络算法推导及代码实现笔记zz

    一. 前言: 作为AI入门小白,参考了一些文章,想记点笔记加深印象,发出来是给有需求的童鞋学习共勉,大神轻拍! [毒鸡汤]:算法这东西,读完之后的状态多半是 --> "我是谁,我在哪? ...

  5. AI从入门到放弃:BP神经网络算法推导及代码实现笔记

    作者 | @Aloys (腾讯员工,后台工程师) 本文授权转自腾讯的知乎专栏 ▌一. 前言: 作为AI入门小白,参考了一些文章,想记点笔记加深印象,发出来是给有需求的童鞋学习共勉,大神轻拍! [毒鸡汤 ...

  6. 分类算法-KNN(原理+代码+结果)

    KNN,即K最邻近算法,是数据挖掘分类技术中比较简单的方法之一,简单来说,就是根据"最邻近"这一特征对样本进行分类. 1.K-means和KNN区别 K-means是一种比较经典的 ...

  7. 基于deap脑电数据集的脑电情绪识别二分类算法(附代码)

    想尝试一下脑电情绪识别的各个二分类算法. 代码主要分为三部分:快速傅里叶变换处理(fft).数据预处理.以及各个模型处理. 采用的模型包括:决策树.SVM.KNN三个模型(模型采用的比较简单,可以直接 ...

  8. 机器学习——K近邻分类算法及python代码实现

    <机器学习:公式推导与代码实践>鲁伟著读书笔记. K近邻(K-nearest neighbor,K-NN)算法是一种经典的监督学习的分类方法.K近邻算法是依据新样本与k个与其相邻最近的样本 ...

  9. 统计学习导论之R语言应用(四):分类算法R语言代码实战

    统计学习导论之R语言应用(ISLR) 参考资料: The Elements of Statistical Learning An Introduction to Statistical Learnin ...

最新文章

  1. Ubuntu 64bit 安装 ulipad4.1
  2. sock 文件方式控制宿主机_浅析Docker运行安全
  3. ftp服务器搭建遇到的问题
  4. java 8 optional 类,Java8新特性-Optional类
  5. C++关键字--volatile
  6. 移民火星住哪?盖房的事就交给AI机器人Justin吧
  7. Python 将时间戳转换为本地时间并进行格式化
  8. plc通过无线通讯连接服务器,多个plc无线通讯方案
  9. Navicat Premium 12 安装教程 + 注册机 Navicat_Keygen_Patch_v5.0_By_DFoX_CHS [附资源]
  10. erf函数 matlab,关于ODE45初值问题和erf函数的问题
  11. 订单系统设计,消息队列幂等处理思路
  12. windows无法完成安装 若要在此计算机上安装_Win10无法启动,主引导记录(MBR)损坏,用这个方法快速修复...
  13. unity游戏开发之游戏过审后 国行PS4将在3月20日发售
  14. homelede软路由设置方法_低成本组装一台LEDE软路由
  15. HOW2J.CN - 学习笔记(类和对象)
  16. 该网页无法正常运作 目前无法处理此请求。 HTTP ERROR 500
  17. 第十一篇 ANDROID 系统网络连接和管理机制与架构
  18. mui-app开发之项目类型概览
  19. office 论文 页码_WORD目录,页眉,页脚,页码设置技巧 为你的毕业论文收藏吧-以Word 2013演示...
  20. Tomcat——配置Tomcat的服务端口(默认端口8080)

热门文章

  1. [LNOI] 相逢是问候 || 扩展欧拉函数+线段树
  2. Linux学习-YUM 在线升级机制
  3. json格式的简单转换
  4. 递归和对面对象编程初步
  5. Oracle表空间基础(4)
  6. 【Java】springboot学习笔记二
  7. 靠谱测试人员需具备业务分析能力
  8. golang 关闭gc 并手动gc_Golang垃圾回收 屏障技术
  9. 关联规则java apriori_关联规则算法(The Apriori algorithm)详解
  10. stm32实验报告心得体会_嵌入式第9次实验报告