模型实现代码,关键是train函数和predict函数,都很容易。

#include <iostream>
#include <string>
#include <math.h>
#include "LogisticRegression.h"
using namespace std;LogisticRegression::LogisticRegression(int size,    // Nint in,     // n_inint out      // n_out)
{N = size;n_in = in;n_out = out;// initialize W, b// W[n_out][n_in], b[n_out]W = new double*[n_out];for(int i=0; i<n_out; i++) W[i] = new double[n_in];b = new double[n_out];for(int i=0; i<n_out; i++) {for(int j=0; j<n_in; j++) {W[i][j] = 0;}b[i] = 0;}
}LogisticRegression::~LogisticRegression()
{for(int i=0; i<n_out; i++) delete[] W[i];delete[] W;delete[] b;
}void LogisticRegression::train (int *x,    // the input from input nodes in training setint *y,    // the output from output nodes in training setdouble lr    // the learning rate)
{// the probability of P(y|x)double *p_y_given_x = new double[n_out];// the tmp variable which is not necessary being an arraydouble *dy = new double[n_out];// step 1: calculate the output of softmax given inputfor(int i=0; i<n_out; i++) {// initializep_y_given_x[i] = 0;for(int j=0; j<n_in; j++) {// the weight of networksp_y_given_x[i] += W[i][j] * x[j];}// the biasp_y_given_x[i] += b[i];}// the softmax valuesoftmax(p_y_given_x);// step 2: update the weight of networks// w_new = w_old + learningRate * differential (导数)//      = w_old + learningRate * x (1{y_i=y} - p_yi_given_x) //      = w_old + learningRate * x * (y - p_y_given_x)for(int i=0; i<n_out; i++) {dy[i] = y[i] - p_y_given_x[i];for(int j=0; j<n_in; j++) {W[i][j] += lr * dy[i] * x[j] / N;}b[i] += lr * dy[i] / N;}delete[] p_y_given_x;delete[] dy;
}void LogisticRegression::softmax (double *x)
{double max = 0.0;double sum = 0.0;// step1: get the max in the X vectorfor(int i=0; i<n_out; i++) if(max < x[i]) max = x[i];// step 2: normalization and softmax// normalize -- 'x[i]-max', it's not necessary in traditional LR.// I wonder why it appears here? for(int i=0; i<n_out; i++) {x[i] = exp(x[i] - max);sum += x[i];} for(int i=0; i<n_out; i++) x[i] /= sum;
}void LogisticRegression::predict(int *x,   // the input from input nodes in testing setdouble *y   // the calculated softmax probability)
{// get the softmax output value given the current networksfor(int i=0; i<n_out; i++) {y[i] = 0;for(int j=0; j<n_in; j++) {y[i] += W[i][j] * x[j];}y[i] += b[i];}softmax(y);
}

转载于:https://www.cnblogs.com/dyllove98/p/3194108.html

【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression.cpp相关推荐

  1. 笔记 | 吴恩达Coursera Deep Learning学习笔记

    向AI转型的程序员都关注了这个号☝☝☝ 作者:Lisa Song 微软总部云智能高级数据科学家,现居西雅图.具有多年机器学习和深度学习的应用经验,熟悉各种业务场景下机器学习和人工智能产品的需求分析.架 ...

  2. 网上某位牛人的deep learning学习笔记汇总

    目录(?)[-] 作者tornadomeet 出处httpwwwcnblogscomtornadomeet 欢迎转载或分享但请务必声明文章出处 Deep learning一基础知识_1 Deep le ...

  3. CV视觉论文Deep learning学习笔记(一)

    论文介绍和监督学习(introduction of paper and supervision of learning) 1. 论文介绍和作者介绍 作者:论文作者是2018年图灵奖得主yoshua B ...

  4. 【Deep Learning学习笔记】Deep learning for nlp without magic_Bengio_ppt_acl2012

    看完180多页的ppt,真心不容易.记得流水账如下: Five reason to explore Deep Learning: 1. learning representation; 2. the ...

  5. Deep Learning论文笔记之(一)K-means特征学习

    Deep Learning论文笔记之(一)K-means特征学习 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文,但老感 ...

  6. Deep Learning论文笔记之(五)CNN卷积神经网络代码理解

    Deep Learning论文笔记之(五)CNN卷积神经网络代码理解 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文,但 ...

  7. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  8. Deep Learning论文笔记之(八)Deep Learning最新综述

    Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...

  9. Deep Learning论文笔记之(七)深度网络高层特征可视化

    Deep Learning论文笔记之(七)深度网络高层特征可视化 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文,但老感 ...

最新文章

  1. 公司要上监控,选型调研下 Zabbix 和 Prometheus
  2. c语言通讯录写入文件,学C三个月了,学了文件,用C语言写了个通讯录程序
  3. linux网页无法连接到服务器,linux – 无法连接到SMTP服务器
  4. Android “再按一次退出“
  5. twisted系列教程六–继续重构twisted poetry client
  6. c++ char*初始化_[零食时间]C/C++ 字符串全家桶(字符串表示/定义、字符串输入输出、易错点等)上半桶...
  7. heartbeat如何监控程序_一文看懂MyCAT 命令行监控命令,监控调优必备
  8. js 如何在浏览器中获取当前位置的经纬度
  9. 安卓zip解压软件_安卓zip文件压缩RAR解压app下载-安卓zip文件压缩RAR解压安卓版 v3.0.4...
  10. python单词查询_Python实现单词查询文件查找
  11. 报错 mysql 1194
  12. 基于图像识别的火灾探测技术
  13. Android SqlDelight详解和Demo例子
  14. 支持向量机识别数字集(数据采集+模型训练+预测输出)
  15. ORACLE进阶(十)start with connect by 实现递归查询
  16. 自然语言处理NLP之信息检索
  17. html5 video 隐藏全屏按钮,如何隐藏HTML5视频标签的全屏按钮?
  18. emoji mysql存储
  19. 2018年京东春招笔试题
  20. 使用SBench 6为任意波形发生器创建,捕获和传输波形

热门文章

  1. 一个简单易用的导出Excel类
  2. 使用Apriori进行关联分析(一)
  3. tornado 简易教程
  4. 3.1 采购管理规划
  5. DataGridView控件中显示图片及其注意事项 【z】
  6. freebsd点到点的ipsec ***
  7. Struts+DAO框架搭建完成!(源码)
  8. Cisco路由器故障诊断技术(3)
  9. 网络传播动力学_通过简单的规则传播动力
  10. 数据治理 主数据 元数据_我们对数据治理的误解