【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression.cpp
模型实现代码,关键是train函数和predict函数,都很容易。
#include <iostream>
#include <string>
#include <math.h>
#include "LogisticRegression.h"
using namespace std;LogisticRegression::LogisticRegression(int size, // Nint in, // n_inint out // n_out)
{N = size;n_in = in;n_out = out;// initialize W, b// W[n_out][n_in], b[n_out]W = new double*[n_out];for(int i=0; i<n_out; i++) W[i] = new double[n_in];b = new double[n_out];for(int i=0; i<n_out; i++) {for(int j=0; j<n_in; j++) {W[i][j] = 0;}b[i] = 0;}
}LogisticRegression::~LogisticRegression()
{for(int i=0; i<n_out; i++) delete[] W[i];delete[] W;delete[] b;
}void LogisticRegression::train (int *x, // the input from input nodes in training setint *y, // the output from output nodes in training setdouble lr // the learning rate)
{// the probability of P(y|x)double *p_y_given_x = new double[n_out];// the tmp variable which is not necessary being an arraydouble *dy = new double[n_out];// step 1: calculate the output of softmax given inputfor(int i=0; i<n_out; i++) {// initializep_y_given_x[i] = 0;for(int j=0; j<n_in; j++) {// the weight of networksp_y_given_x[i] += W[i][j] * x[j];}// the biasp_y_given_x[i] += b[i];}// the softmax valuesoftmax(p_y_given_x);// step 2: update the weight of networks// w_new = w_old + learningRate * differential (导数)// = w_old + learningRate * x (1{y_i=y} - p_yi_given_x) // = w_old + learningRate * x * (y - p_y_given_x)for(int i=0; i<n_out; i++) {dy[i] = y[i] - p_y_given_x[i];for(int j=0; j<n_in; j++) {W[i][j] += lr * dy[i] * x[j] / N;}b[i] += lr * dy[i] / N;}delete[] p_y_given_x;delete[] dy;
}void LogisticRegression::softmax (double *x)
{double max = 0.0;double sum = 0.0;// step1: get the max in the X vectorfor(int i=0; i<n_out; i++) if(max < x[i]) max = x[i];// step 2: normalization and softmax// normalize -- 'x[i]-max', it's not necessary in traditional LR.// I wonder why it appears here? for(int i=0; i<n_out; i++) {x[i] = exp(x[i] - max);sum += x[i];} for(int i=0; i<n_out; i++) x[i] /= sum;
}void LogisticRegression::predict(int *x, // the input from input nodes in testing setdouble *y // the calculated softmax probability)
{// get the softmax output value given the current networksfor(int i=0; i<n_out; i++) {y[i] = 0;for(int j=0; j<n_in; j++) {y[i] += W[i][j] * x[j];}y[i] += b[i];}softmax(y);
}
转载于:https://www.cnblogs.com/dyllove98/p/3194108.html
【deep learning学习笔记】注释yusugomori的LR代码 --- LogisticRegression.cpp相关推荐
- 笔记 | 吴恩达Coursera Deep Learning学习笔记
向AI转型的程序员都关注了这个号☝☝☝ 作者:Lisa Song 微软总部云智能高级数据科学家,现居西雅图.具有多年机器学习和深度学习的应用经验,熟悉各种业务场景下机器学习和人工智能产品的需求分析.架 ...
- 网上某位牛人的deep learning学习笔记汇总
目录(?)[-] 作者tornadomeet 出处httpwwwcnblogscomtornadomeet 欢迎转载或分享但请务必声明文章出处 Deep learning一基础知识_1 Deep le ...
- CV视觉论文Deep learning学习笔记(一)
论文介绍和监督学习(introduction of paper and supervision of learning) 1. 论文介绍和作者介绍 作者:论文作者是2018年图灵奖得主yoshua B ...
- 【Deep Learning学习笔记】Deep learning for nlp without magic_Bengio_ppt_acl2012
看完180多页的ppt,真心不容易.记得流水账如下: Five reason to explore Deep Learning: 1. learning representation; 2. the ...
- Deep Learning论文笔记之(一)K-means特征学习
Deep Learning论文笔记之(一)K-means特征学习 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感 ...
- Deep Learning论文笔记之(五)CNN卷积神经网络代码理解
Deep Learning论文笔记之(五)CNN卷积神经网络代码理解 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但 ...
- Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现
Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文, ...
- Deep Learning论文笔记之(八)Deep Learning最新综述
Deep Learning论文笔记之(八)Deep Learning最新综述 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感觉看完 ...
- Deep Learning论文笔记之(七)深度网络高层特征可视化
Deep Learning论文笔记之(七)深度网络高层特征可视化 zouxy09@qq.com http://blog.csdn.net/zouxy09 自己平时看了一些论文,但老感 ...
最新文章
- 公司要上监控,选型调研下 Zabbix 和 Prometheus
- c语言通讯录写入文件,学C三个月了,学了文件,用C语言写了个通讯录程序
- linux网页无法连接到服务器,linux – 无法连接到SMTP服务器
- Android “再按一次退出“
- twisted系列教程六–继续重构twisted poetry client
- c++ char*初始化_[零食时间]C/C++ 字符串全家桶(字符串表示/定义、字符串输入输出、易错点等)上半桶...
- heartbeat如何监控程序_一文看懂MyCAT 命令行监控命令,监控调优必备
- js 如何在浏览器中获取当前位置的经纬度
- 安卓zip解压软件_安卓zip文件压缩RAR解压app下载-安卓zip文件压缩RAR解压安卓版 v3.0.4...
- python单词查询_Python实现单词查询文件查找
- 报错 mysql 1194
- 基于图像识别的火灾探测技术
- Android SqlDelight详解和Demo例子
- 支持向量机识别数字集(数据采集+模型训练+预测输出)
- ORACLE进阶(十)start with connect by 实现递归查询
- 自然语言处理NLP之信息检索
- html5 video 隐藏全屏按钮,如何隐藏HTML5视频标签的全屏按钮?
- emoji mysql存储
- 2018年京东春招笔试题
- 使用SBench 6为任意波形发生器创建,捕获和传输波形