原文地址:http://blog.sina.com.cn/s/blog_6982136301015asd.html

SoftMax回归可以用来进行两种以上的分类,很是神奇!实现过程实在有点坎坷,主要是开始写代码的时候理解并不透彻,而且思路不清晰,引以为戒吧!

SoftMax Regression属于指数家族,证明见( http://cs229.stanford.edu/notes/cs229-notes1.pdf 及http://ufldl.stanford.edu/wiki/index.php/Softmax_Regression),最后得出的结论是:

最后写出似然估计: 

之后采用梯度或牛顿方法进行逼近:

注意,上面公式采用的是批量梯度,随机梯度当然也是可以的。

参数theta的更新如下:
要注意的是,theta[j]是一个向量。

实验还是参考大牛pennyliang(http://blog.csdn.net/pennyliang/article/details/7048291),代码如下:

  1. #include <iostream>
  2. #include <cmath>
  3. #include <assert.h>
  4. using namespace std;
  5. const int K = 2;//有K+1类
  6. const int M = 9;//训练集大小
  7. const int N = 4;//特征数
  8. double x[M][N]={{1,47,76,24}, //include x0=1
  9. {1,46,77,23},
  10. {1,48,74,22},
  11. {1,34,76,21},
  12. {1,35,75,24},
  13. {1,34,77,25},
  14. {1,55,76,21},
  15. {1,56,74,22},
  16. {1,55,72,22},
  17. };
  18. double y[M]={1,
  19. 1,
  20. 1,
  21. 2,
  22. 2,
  23. 2,
  24. 3,
  25. 3,
  26. 3,};
  27. double theta[K][N]={
  28. {0.3,0.3,0.01,0.01},
  29. {0.5,0.5,0.01,0.01}}; // include theta0
  30. double h_value[K];//h(x)向量值
  31. //求exp(QT*x)
  32. double fun_eqx(double* x, double* q)
  33. {
  34. double sum = 0;
  35. for (int i = 0; i < N; i++)
  36. {
  37. sum += x[i] * q[i];
  38. }
  39. return pow(2.718281828, sum);
  40. }
  41. //求h向量
  42. void h(double* x)
  43. {
  44. int i;
  45. double sum = 1;//之前假定theta[K+1]={0},所以exp(Q[K+1]T*x)=1
  46. for (i = 0; i < K; i++)
  47. {
  48. h_value[i] = fun_eqx(x, theta[i]);
  49. sum += h_value[i];
  50. }
  51. assert(sum != 0);
  52. for (i = 0; i < K; i++)
  53. {
  54. h_value[i] /= sum;
  55. }
  56. }
  57. void modify_stochostic()
  58. {
  59. //随机梯度下降,训练参数
  60. int i, j, k;
  61. for (j = 0; j < M; j ++)
  62. {
  63. h(x[j]);
  64. for (i = 0; i < K; i++)
  65. {
  66. for (k = 0; k < N; k++)
  67. {
  68. theta[i][k] += 0.001 * x[j][k] *  ((y[j] == i+1?1:0) - h_value[i]);
  69. }
  70. }
  71. }
  72. }
  73. void modify_batch()
  74. {
  75. //批量梯度下降,训练参数
  76. int i, j, k ;
  77. for (i = 0; i < K; i++)
  78. {
  79. double sum[N] = {0.0};
  80. for (j = 0; j < M; j++)
  81. {
  82. h(x[j]);
  83. for (k = 0; k < N; k++)
  84. {
  85. sum[k] += x[j][k] * ((y[j] == i+1?1:0) - h_value[i]);
  86. }
  87. }
  88. for (k = 0; k < N; k++)
  89. {
  90. theta[i][k] += 0.001 * sum[k] / N;
  91. }
  92. }
  93. }
  94. void train(void)
  95. {
  96. int i;
  97. for (i = 0; i < 10000; i++)
  98. {
  99. //modify_stochostic();
  100. modify_batch();
  101. }
  102. }
  103. void predict(double* pre)
  104. {
  105. //输出预测向量
  106. int i;
  107. for (i = 0; i < K; i++)
  108. h_value[i] = 0;
  109. train();
  110. h(pre);
  111. for (i = 0; i < K; i++)
  112. cout << h_value[i] << " ";
  113. cout << 1 - h_value[0] - h_value[1] << endl;
  114. }
  115. int main(void)
  116. {
  117. for (int i=0; i < M; i++)
  118. {
  119. predict(x[i]);
  120. }
  121. cout << endl;
  122. double pre[] = {1,20, 80, 50 };
  123. predict(pre);
  124. return 0;
  125. }
代码实现了批量梯度和随机梯度两种方法,实验最后分别将训练样本带入进行估计,迭代10000次的结果为:
stochastic:
0.999504 0.000350044 0.000145502
0.997555 0.00242731 1.72341e-005
0.994635 1.24138e-005 0.00535281
2.59353e-005 0.999974 6.07695e-017
0.00105664 0.998943 -1.09071e-016
4.98481e-005 0.99995 3.45318e-017
0.0018048 1.56509e-012 0.998195
0.000176388 1.90889e-015 0.999824
0.000169041 8.42073e-016 0.999831
batch:
0.993387 0.00371185 0.00290158
0.991547 0.0081696 0.000283336
0.979246 0.000132495 0.0206216
0.000630111 0.99937 4.9303e-014
0.00378715 0.996213 9.37462e-014
0.000299602 0.9997 3.50739e-017
0.00759726 2.60939e-010 0.992403
0.0006897 1.09856e-012 0.99931
0.000545117 5.19157e-013 0.999455
可见随机梯度收敛的更快。
对于预测来说,输出结果每行的三个数表示是:对于输入来说,是1 2 3三类的概率分别是多少。

Machine Learning系列实验--SoftMax Regression相关推荐

  1. 【机器学习|数学基础】Mathematics for Machine Learning系列之图论(8):割边、割集、割点

    文章目录 前言 系列文章 3.2 割边.割集.割点 3.2.1 割边与割集 定理3.4 推论3.4 定理3.5 补充知识 定义3.3:割集 定义3.4 定理3.6 生成树与割集的对比 3.2.2 割点 ...

  2. 【机器学习|数学基础】Mathematics for Machine Learning系列之矩阵理论(14):向量范数及其性质

    目录 前言 往期文章 4.1 向量范数及其性质 4.1.1 向量范数的概念及P-范数 定义4.1 例1 向量的几种范数 4.1.2 n n n维线性空间 V V V上的向量范数等价性 定理4.1.1 ...

  3. 【机器学习|数学基础】Mathematics for Machine Learning系列之线性代数(20):用配方法化二次型为标准形

    目录 前言 往期文章 5.6 用配方法化二次型为标准形 题目一 题目二 结语 前言 Hello!小伙伴! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出-   自我介绍 ଘ(੭ˊᵕˋ)੭ ...

  4. 【机器学习|数学基础】Mathematics for Machine Learning系列之线性代数(10):向量组及其线性组合

    文章目录 前言 往期文章 4.1 向量组及其线性组合 定义1 定义2 定理1 定义3 定理2 推论 举例 例 1 例2 定理3 小结 结语 前言 Hello!小伙伴! 非常感谢您阅读海轰的文章,倘若文 ...

  5. 【机器学习|数学基础】Mathematics for Machine Learning系列之线性代数(21):正定二次型

    目录 前言 往期文章 5.7 正定二次型 定理9:惯性定理 定义10 定理10 推论 定理11:赫尔维茨定理 举例 例17 结语 前言 Hello!小伙伴! 非常感谢您阅读海轰的文章,倘若文中有错误的 ...

  6. 【机器学习|数学基础】Mathematics for Machine Learning系列之线性代数(26):线性变换的矩阵表达式

    目录 前言 往期文章 6.5 线性变换的矩阵表达式 定义6 定理2 定义7 举例 例11 结语 前言 Hello!小伙伴! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出-   自我介绍 ...

  7. 【机器学习|数学基础】Mathematics for Machine Learning系列之图论(9):匹配的概念

    文章目录 前言 系列文章 5.1 匹配的概念 定义5.1 定义5.2 定义 5.3 结语 前言 Hello!小伙伴! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出-   自我介绍 ଘ(੭ ...

  8. 【机器学习|数学基础】Mathematics for Machine Learning系列之矩阵理论(25):幂级数(补充知识)

    目录 前言 往期文章 幂级数 一.函数项级数的概念 定义:(函数项)无穷级数 幂级数及其收敛性 幂级数 定理1(阿贝尔定理) 推论 定理2 结语 前言 Hello!小伙伴! 非常感谢您阅读海轰的文章, ...

  9. 【机器学习|数学基础】Mathematics for Machine Learning系列之矩阵理论(17):函数矩阵的微分和积分

    目录 前言 往期文章 5.2 函数矩阵的微分和积分 5.2.1 函数矩阵对自变量的微分和积分 定义5.3:函数矩阵 定义5.4:函数矩阵的微分 单元函数矩阵的一些性质 例1 定义5.5 函数矩阵的积分 ...

  10. 【机器学习|数学基础】Mathematics for Machine Learning系列之线性代数(19):二次型及其标准形

    目录 前言 往期文章 5.5 二次型及其标准形 定义8:二次型 定义9:合同 定理8 推论 举例 例14 结语 前言 Hello!小伙伴! 非常感谢您阅读海轰的文章,倘若文中有错误的地方,欢迎您指出- ...

最新文章

  1. ICLR2020放榜 687篇入选34篇得满分! 且看OpenReview数据图文详解
  2. Codeforces Round #441 Div. 2题解
  3. 助力区域性银行突破困局,网易云信入选爱分析报告典型案例
  4. 一个小小指针,竟把Linux内核攻陷了!
  5. pyqy5——控件2
  6. 拾牙的2021年秋招总结(大概会有帮助?)
  7. Ubuntu 16.04 UUID 开机自动挂载硬盘
  8. ideahtml标签不提示_仓储物流加速,电子标签亮灯拣选系统的优势
  9. 语言程序设计 郭有强_「概念篇8」程序语言如何被计算机理解?靠猜?那就搞笑了...
  10. UI2CODE智能生成flutter代码--整体架构 资料下载
  11. Java中private修饰变量的继承问题
  12. 数据结构和算法9——哈希表
  13. [幽默漫画]对于程序猿来说deadline很容易搞定!
  14. linux获取cpu数量函数,Linux上获取CPU Core个数的实现
  15. css浮动与清除浮动相关总结(附图解、实例)
  16. JDK1.6安装_BouncyCastle JCE扩展加密算法解决JDK1.6 sftp连接openssh8.6Algorithm negotiation fail问题
  17. LIO-SAM后端中的回环检测及位姿计算
  18. Python爬网易云音乐的那些事
  19. 一招教你不用任何软件就能知道谁动过你的电脑并做了哪些详细的操作,比查看Recent文件访问记录更厉害的方法开机自动运行PSR录制截取电脑操作
  20. 屏幕撕裂及掉帧原因与解决方案

热门文章

  1. 响应式微服务 in java 译 十六 Deploying a Microservice in OpenShift
  2. MySql的基本操作以及以后开发经常使用的常用指令
  3. bzoj4423[AMPPZ2013]Bytehattan
  4. 寄存器位读写,结构体位域定义,位域操作,位操作
  5. SCCM2012系列之十,SCCM2012软件分发
  6. InputStreamReader
  7. mysql在查询结果列表前添加一列递增的序号列(最简)
  8. AngularJS中$apply
  9. 一些CFD名词缩写的含义(持续更新中)
  10. Runloop与autoreleasePool联系