[Machine Learning] Softmax Regression: Algorithm Principles and a Java Implementation

  • 1. Softmax Regression Algorithm Principles
    • 1.1 Sample Probability
    • 1.2 Loss Function
    • 1.3 Training the Model with Gradient Descent
  • 2. Java Implementation
  • References

Logistic Regression is a linear binary classifier; Softmax Regression is its generalization to multi-class problems in which any two classes are linearly separable (see Reference 1).

1. Softmax Regression Algorithm Principles

1.1 Sample Probability

Suppose there are $n$ samples $\{x_1, x_2, \cdots, x_n\}$, each described by $m$ features, with $j$ label classes. To map a sample onto the $j$ classes, the weight matrix $W = [w_1, w_2, \cdots, w_j]$ must therefore have dimensions $m \times j$.
The probability that sample $x_i$ belongs to class $k$ ($k \in \{1, 2, \cdots, j\}$; the code below stores class labels 0-based) is:

$$p(y_i = k \mid x_i; W) = \frac{e^{w_k^T x_i}}{\sum_{l=1}^{j} e^{w_l^T x_i}}$$
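
As a bridge to the implementation in Section 2, here is a minimal sketch (my own addition, not code from the linked repository) of this probability for one sample. Subtracting the maximum score before exponentiating is a standard numerical-stability trick that leaves the result unchanged:

// Minimal sketch: p(y = k | x; W) for one sample x of length m,
// with weight matrix W of size m x j. Subtracting the max score
// before exponentiating avoids overflow without changing the result.
public static double[] softmax(double[] x, double[][] W) {
    int j = W[0].length;
    double[] score = new double[j];
    for (int k = 0; k < j; k++) {          // score[k] = w_k^T x
        for (int n = 0; n < x.length; n++) {
            score[k] += x[n] * W[n][k];
        }
    }
    double max = score[0];
    for (int k = 1; k < j; k++) {
        max = Math.max(max, score[k]);
    }
    double[] p = new double[j];
    double sum = 0;
    for (int k = 0; k < j; k++) {
        p[k] = Math.exp(score[k] - max);   // e^{w_k^T x} up to a shared factor
        sum += p[k];
    }
    for (int k = 0; k < j; k++) {
        p[k] /= sum;                       // normalize so the entries sum to 1
    }
    return p;
}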

Multiplying the per-class probabilities together, selected by an indicator, gives the probability of sample $x_i$ in a single expression:

$$p(y_i \mid x_i; W) = \prod_{l=1}^{j} \left( \frac{e^{w_l^T x_i}}{\sum_{s=1}^{j} e^{w_s^T x_i}} \right)^{I\{y_i = l\}}$$

where $I\{y_i = l\} = 1$ when sample $x_i$ belongs to class $l$, and $I\{y_i = l\} = 0$ otherwise. For example, with $j = 3$ and $y_i = 2$, only the $l = 2$ factor has exponent 1, so the product reduces to $p(y_i = 2 \mid x_i; W)$.

1.2 Loss Function

As described in my earlier posts, the loss function of a probability-based machine learning algorithm is the negative log-likelihood.
The likelihood function is:

$$L_W = \prod_{i=1}^{n} p(y_i \mid x_i; W) = \prod_{i=1}^{n} \prod_{l=1}^{j} \left( \frac{e^{w_l^T x_i}}{\sum_{s=1}^{j} e^{w_s^T x_i}} \right)^{I\{y_i = l\}}$$

The loss function is therefore:

$$l_W = -\frac{1}{n} \left[ \sum_{i=1}^{n} \sum_{l=1}^{j} I\{y_i = l\} \log \frac{e^{w_l^T x_i}}{\sum_{s=1}^{j} e^{w_s^T x_i}} \right]$$
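
As a sanity check, the loss can be computed directly from the per-sample probabilities. A minimal sketch, reusing the hypothetical softmax helper above (0-based integer labels, as in the code of Section 2):

// Minimal sketch: average negative log-likelihood over n samples.
// feature is n x m; label[i] is the 0-based class of sample i.
public static double loss(double[][] feature, int[] label, double[][] W) {
    double sum = 0;
    for (int i = 0; i < feature.length; i++) {
        double[] p = softmax(feature[i], W);  // class probabilities for x_i
        sum -= Math.log(p[label[i]]);         // -log p(y_i | x_i; W)
    }
    return sum / feature.length;
}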

1.3 Training the Model with Gradient Descent

The descent direction for the $k$-th column $w_k$ of the weight matrix is given by the gradient $\frac{\partial l_W}{\partial w_k}$. Differentiating the loss, a sample's contribution takes one of two forms.
For the samples with $y_i = k$:

$$\frac{\partial l_W}{\partial w_k} = -\frac{1}{n} \sum_{i=1}^{n} \left[ \frac{\sum_{s=1}^{j} e^{w_s^T x_i} - e^{w_k^T x_i}}{\sum_{s=1}^{j} e^{w_s^T x_i}} \cdot x_i \right]$$

For the samples with $y_i \ne k$:

$$\frac{\partial l_W}{\partial w_k} = -\frac{1}{n} \sum_{i=1}^{n} \left[ \frac{-e^{w_k^T x_i}}{\sum_{s=1}^{j} e^{w_s^T x_i}} \cdot x_i \right]$$

The first fraction is $1 - p(y_i = k \mid x_i; W)$ and the second is $-p(y_i = k \mid x_i; W)$, so the two cases combine into:

$$\frac{\partial l_W}{\partial w_k} = -\frac{1}{n} \sum_{i=1}^{n} \left[ x_i \cdot \left( I\{y_i = k\} - p(y_i = k \mid x_i; W) \right) \right]$$
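
The combined form is what makes the implementation compact: each sample contributes $x_i$ scaled by the residual $I\{y_i = k\} - p(y_i = k \mid x_i; W)$. A minimal sketch of one full-batch descent step under the same conventions (again a hypothetical helper, not the training class below):

// Minimal sketch: one batch gradient-descent step. W is m x j and is
// updated in place; grad[d][k] accumulates the gradient of the loss.
public static void step(double[][] feature, int[] label, double[][] W, double rate) {
    int n = feature.length, m = W.length, j = W[0].length;
    double[][] grad = new double[m][j];
    for (int i = 0; i < n; i++) {
        double[] p = softmax(feature[i], W);
        for (int k = 0; k < j; k++) {
            // residual = I{y_i = k} - p(y_i = k | x_i; W)
            double residual = (label[i] == k ? 1.0 : 0.0) - p[k];
            for (int d = 0; d < m; d++) {
                grad[d][k] -= feature[i][d] * residual / n;
            }
        }
    }
    for (int d = 0; d < m; d++) {
        for (int k = 0; k < j; k++) {
            W[d][k] -= rate * grad[d][k];  // move against the gradient
        }
    }
}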

2. Java Implementation

The complete Java code and sample data: https://github.com/shiluqiang/Softmax-Regression-java
First, load the sample features and labels.

import java.util.*;
import java.io.*;

public class LoadData {

    // Load the sample features; the label column is replaced by a constant bias term
    public static double[][] Loadfeature(String filename) throws IOException {
        File f = new File(filename);
        FileInputStream fip = new FileInputStream(f);                    // build the FileInputStream
        InputStreamReader reader = new InputStreamReader(fip, "UTF-8");  // build the InputStreamReader
        StringBuffer sb = new StringBuffer();
        while (reader.ready()) {
            sb.append((char) reader.read());
        }
        reader.close();
        fip.close();
        // Convert the stream contents into a single string
        String sb1 = sb.toString();
        // Split into rows; trim guards against an empty trailing line
        String[] a = sb1.trim().split("\n");
        int n = a.length;
        System.out.println("Number of rows: " + n);
        // The number of columns comes from the first row
        String[] a0 = a[0].split("\t");
        int m = a0.length;
        System.out.println("Number of columns: " + m);
        double[][] feature = new double[n][m];
        for (int i = 0; i < n; i++) {
            String[] tmp = a[i].trim().split("\t");
            for (int j = 0; j < m; j++) {
                if (j == m - 1) {
                    feature[i][j] = 1.0;  // bias term in place of the label column
                } else {
                    feature[i][j] = Double.parseDouble(tmp[j]);
                }
            }
        }
        return feature;
    }

    // Load the sample labels from the last column of each row
    public static double[] LoadLabel(String filename) throws IOException {
        File f = new File(filename);
        FileInputStream fip = new FileInputStream(f);                    // build the FileInputStream
        InputStreamReader reader = new InputStreamReader(fip, "UTF-8");  // same encoding as above
        StringBuffer sb = new StringBuffer();
        while (reader.ready()) {
            sb.append((char) reader.read());
        }
        reader.close();
        fip.close();
        String sb1 = sb.toString();
        String[] a = sb1.trim().split("\n");
        int n = a.length;
        String[] a0 = a[0].split("\t");
        int m = a0.length;
        double[] Label = new double[n];
        for (int i = 0; i < n; i++) {
            String[] tmp = a[i].trim().split("\t");
            Label[i] = Double.parseDouble(tmp[m - 1]);
        }
        return Label;
    }

    // Count the number of distinct labels by sorting a copy and counting jumps
    public static int LabelNum(double[] Label) {
        int n = Label.length;
        double[] LabelTmp = new double[n];
        System.arraycopy(Label, 0, LabelTmp, 0, n);
        int labelNum = 1;
        Arrays.sort(LabelTmp);
        for (int i = 1; i < n; i++) {
            if (LabelTmp[i] != LabelTmp[i - 1]) {
                labelNum++;
            }
        }
        return labelNum;
    }
}
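
Reading the two loaders together fixes the expected file layout: each row of the input is tab-separated, with the feature values first and the 0-based class label in the last column; Loadfeature then overwrites that last column with a constant 1.0, a bias term absorbed into the weight matrix. An illustrative two-feature layout (these values are made up, not taken from the repository's SoftInput.txt):

1.2	0.7	2
0.4	3.1	0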

Next, optimize the model weights with gradient descent.

public class SRtrainGradientDescent {
    int paraNum;          // number of weight parameters per class (m)
    double rate;          // learning rate
    int samNum;           // number of samples
    double[][] feature;   // sample feature matrix
    double[] Label;       // sample labels
    int maxCycle;         // maximum number of iterations
    int labelNum;         // number of label classes

    // Constructor
    public SRtrainGradientDescent(double[][] feature, double[] Label, int paraNum,
                                  double rate, int samNum, int maxCycle, int labelNum) {
        this.feature = feature;
        this.Label = Label;
        this.maxCycle = maxCycle;
        this.paraNum = paraNum;
        this.rate = rate;
        this.samNum = samNum;
        this.labelNum = labelNum;
    }

    // Initialize the weight matrix
    public double[][] ParaInitialize(int paraNum, int labelNum) {
        double[][] W = new double[paraNum][labelNum];
        for (int i = 0; i < paraNum; i++) {
            for (int j = 0; j < labelNum; j++) {
                W[i][j] = 1.0;
            }
        }
        return W;
    }

    // Numerator of the hypothesis: errMatrix[i][k] = exp(w_k^T x_i)
    public double[][] err(double[][] W, double[][] feature) {
        double[][] errMatrix = new double[feature.length][W[0].length];
        for (int i = 0; i < feature.length; i++) {
            for (int j = 0; j < W[0].length; j++) {
                double tmp = 0;
                for (int n = 0; n < W.length; n++) {
                    tmp = tmp + feature[i][n] * W[n][j];
                }
                errMatrix[i][j] = Math.exp(tmp);
            }
        }
        return errMatrix;
    }

    // Negative of the hypothesis denominator: errsum[i] = -sum_k exp(w_k^T x_i)
    public double[] errSum(double[][] errMatrix) {
        double[] errsum = new double[errMatrix.length];
        for (int i = 0; i < errMatrix.length; i++) {
            double tmp = 0;
            for (int j = 0; j < errMatrix[0].length; j++) {
                tmp = tmp - errMatrix[i][j];
            }
            errsum[i] = tmp;
        }
        return errsum;
    }

    // Negated hypothesis matrix: errResult[i][k] = -p(y_i = k | x_i; W)
    public double[][] errFunction(double[][] errMatrix, double[] errsum) {
        double[][] errResult = new double[errMatrix.length][errMatrix[0].length];
        for (int i = 0; i < errMatrix.length; i++) {
            for (int j = 0; j < errMatrix[0].length; j++) {
                errResult[i][j] = errMatrix[i][j] / errsum[i];
            }
        }
        return errResult;
    }

    // Current value of the loss function
    public double cost(double[] Label, double[][] errMatrix, double[] errsum, int samNum) {
        double sum_cost = 0;
        for (int i = 0; i < samNum; i++) {
            int m = (int) Label[i];
            if ((errMatrix[i][m] / (-errsum[i])) > 0) {
                sum_cost -= Math.log(errMatrix[i][m] / (-errsum[i]));
            }
        }
        return sum_cost / samNum;
    }

    public double[][] Update(double[][] feature, double[] Label, int maxCycle, double rate,
                             int paraNum, int labelNum, int samNum) {
        // Initialize the weight matrix
        double[][] weights = ParaInitialize(paraNum, labelNum);
        // Iteratively optimize the weight matrix
        for (int i = 0; i < maxCycle; i++) {
            // Numerator of the hypothesis
            double[][] errMatrix = err(weights, feature);
            // Negative of the hypothesis denominator
            double[] errsum = errSum(errMatrix);
            if (i % 10 == 0) {
                double cost = cost(Label, errMatrix, errsum, samNum);
                System.out.println("Loss after iteration " + i + ": " + cost);
            }
            // errResult[i][k] = -p(y_i = k | x_i; W)
            double[][] errResult = errFunction(errMatrix, errsum);
            // Add 1 to the true class so errResult[i][k] = I{y_i = k} - p(y_i = k | x_i; W)
            for (int j = 0; j < samNum; j++) {
                int m = (int) Label[j];
                errResult[j][m] += 1;
            }
            // delt_weights is the negative gradient -d(l_W)/d(w_k) for each weight
            double[][] delt_weights = new double[paraNum][labelNum];
            for (int iter1 = 0; iter1 < paraNum; iter1++) {
                for (int iter2 = 0; iter2 < labelNum; iter2++) {
                    double tmp = 0;
                    for (int iter3 = 0; iter3 < samNum; iter3++) {
                        tmp = tmp + feature[iter3][iter1] * errResult[iter3][iter2];
                    }
                    delt_weights[iter1][iter2] = tmp / samNum;
                }
            }
            // Adding the negative gradient is a gradient-descent step
            for (int iter1 = 0; iter1 < paraNum; iter1++) {
                for (int iter2 = 0; iter2 < labelNum; iter2++) {
                    weights[iter1][iter2] = weights[iter1][iter2] + rate * delt_weights[iter1][iter2];
                }
            }
        }
        return weights;
    }
}
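
A note on the sign convention above, since it is easy to misread: errSum returns the negative of the softmax denominator, so errFunction produces $-p(y_i = k \mid x_i; W)$. Adding 1 to the true class's entry then turns errResult into exactly the residual $I\{y_i = k\} - p(y_i = k \mid x_i; W)$ from Section 1.3, which makes delt_weights the negative gradient. The update weights += rate * delt_weights is therefore a genuine descent step, even though it looks like ascent.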

Then, test the model.

public class SRTest {
    // Return the index of the largest element in a row
    public static int MaxSearch(double[] array) {
        int pointer = 0;
        double tmp = array[0];  // start from the first element (not 0) so negative scores are handled
        for (int j = 1; j < array.length; j++) {
            if (array[j] > tmp) {
                tmp = array[j];
                pointer = j;
            }
        }
        return pointer;
    }

    // Compute the predicted class for each sample
    public static double[] SRtest(int labelNum, int samNum, int paraNum,
                                  double[][] feature, double[][] weights) {
        double[][] pre_results = new double[samNum][labelNum];
        for (int i = 0; i < samNum; i++) {
            for (int j = 0; j < labelNum; j++) {
                double tmp = 0;
                for (int n = 0; n < paraNum; n++) {
                    tmp += feature[i][n] * weights[n][j];
                }
                pre_results[i][j] = tmp;  // raw score w_j^T x_i
            }
        }
        double[] results = new double[samNum];
        for (int m = 0; m < samNum; m++) {
            results[m] = MaxSearch(pre_results[m]);
        }
        return results;
    }
}
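
Note that SRtest never normalizes the scores into probabilities: since $e^z$ is monotonically increasing and the softmax denominator is shared by all classes, the class with the largest raw score $w_k^T x_i$ is also the class with the largest probability, so taking the argmax of the scores suffices for prediction.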

After that, save the model weight matrix and the predictions.

import java.io.*;

public class SaveModelResults {
    // Save the weight matrix, tab-separated, one row per line
    public static void savemodel(String filename, double[][] W) throws IOException {
        File f = new File(filename);
        FileOutputStream fip = new FileOutputStream(f);                    // build the FileOutputStream
        OutputStreamWriter writer = new OutputStreamWriter(fip, "UTF-8");  // build the OutputStreamWriter
        int n = W.length;
        int m = W[0].length;
        StringBuffer sb = new StringBuffer();
        for (int j = 0; j < n - 1; j++) {
            for (int i = 0; i < m - 1; i++) {
                sb.append(String.valueOf(W[j][i]));
                sb.append("\t");
            }
            sb.append(String.valueOf(W[j][m - 1]));
            sb.append("\n");
        }
        // Last row, without a trailing newline
        for (int i = 0; i < m - 1; i++) {
            sb.append(String.valueOf(W[n - 1][i]));
            sb.append("\t");
        }
        sb.append(String.valueOf(W[n - 1][m - 1]));
        String sb1 = sb.toString();
        writer.write(sb1);
        writer.close();
        fip.close();
    }

    // Save the predicted labels, one per line
    public static void saveresults(String filename, double[] results) throws IOException {
        File f = new File(filename);
        FileOutputStream fip = new FileOutputStream(f);                    // build the FileOutputStream
        OutputStreamWriter writer = new OutputStreamWriter(fip, "UTF-8");  // build the OutputStreamWriter
        int n = results.length;
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < n; i++) {
            sb.append(results[i]);
            sb.append("\n");
        }
        String sb1 = sb.toString();
        writer.write(sb1);
        writer.close();
        fip.close();
    }
}

Finally, the main class:

import java.io.IOException;

public class SRMain {
    public static void main(String[] args) throws IOException {
        // Input file
        String filename = "SoftInput.txt";
        // Load the sample features and labels
        double[][] feature = LoadData.Loadfeature(filename);
        double[] Label = LoadData.LoadLabel(filename);
        int labelNum = LoadData.LabelNum(Label);
        // Parameter settings
        int samNum = feature.length;
        int paraNum = feature[0].length;
        double rate = 0.04;
        int maxCycle = 10000;
        // Train the Softmax Regression model
        SRtrainGradientDescent SR = new SRtrainGradientDescent(feature, Label, paraNum, rate, samNum, maxCycle, labelNum);
        double[][] weights = SR.Update(feature, Label, maxCycle, rate, paraNum, labelNum, samNum);
        // Save the model
        String model_path = "wrights.txt";
        SaveModelResults.savemodel(model_path, weights);
        // Test the model
        double[] results = SRTest.SRtest(labelNum, samNum, paraNum, feature, weights);
        String results_path = "results.txt";
        SaveModelResults.saveresults(results_path, results);
    }
}
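
Assuming the five classes sit in the default package alongside SoftInput.txt, which is the layout the repository link suggests (an assumption on my part), a run would look like:

javac LoadData.java SRtrainGradientDescent.java SRTest.java SaveModelResults.java SRMain.java
java SRMain

Training prints the loss every 10 iterations; the weight matrix is then written to wrights.txt and the predicted labels to results.txt.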

References

1. 《Python机器学习实战》 (Python Machine Learning in Action)
