OpenCV3.3中给出了逻辑回归(logistic regression)的实现,即cv::ml::LogisticRegression类,类的声明在include/opencv2/ml.hpp文件中,实现在modules/ml/src/lr.cpp文件中,它既支持两分类,也支持多分类,其中:

(1)、cv::ml::LogisticRegression类继承自cv::ml::StateModel,而cv::ml::StateModel又继承自cv::Algorithm;

(2)、setLearningRate函数用来设置学习率,getLearningRate函数用来获取学习率值;

(3)、setIterations函数用来设置迭代次数,getIterations函数用来获取迭代次数值;

(4)、setRegularization函数用来设置采用哪种正则化方法,目前支持两种L1 norm和L2 norm,正则化方法主要用来防止过拟合,getRegularization函数用来获取采用哪种正则化方法;

(5)、setTrainMethod函数用来设置采用哪种训练方法,目前支持两种Batch和Mini-Batch, getTrainMethod函数用来获取采用哪种训练方法;

(6)、setMiniBatchSize函数用来设置在Mini-Batch梯度下降训练方法中每一个step采集的训练样本数,getMiniBatchSize函数用来获取每一个step采集的训练样本数;

(7)、setTermCriteria函数用来设置终止训练的条件,包括迭代次数和期望的精度,getTermCriteria用来获取终止训练的条件;

(8)、get_learnt_thetas函数用来获取训练参数;

(9)、create函数为static, new一个LogisticRegressionImpl用来创建一个LogisticRegression对象;

(10)、train函数(使用基类StatModel中的)进行训练;

(11)、predict函数用于预测;

(12)、save函数(使用基类Algorithm中的)保存已训练好的model,支持xml,yaml,json格式;

(13)、load函数用来load已训练好的model;

以下为两分类测试代码:训练数据集为从MNIST中train中随机选取的0、1各10个图像;测试数据集为从MNIST中test中随机选取的0、1各10个图像,如下图,其中第一排前10个0用于训练,后10个0用于测试;第二排前10个1用于训练,后10个1用于测试:

#include "opencv.hpp"
#include <string>
#include <vector>
#include <memory>
#include <opencv2/opencv.hpp>
#include <opencv2/ml.hpp>
#include "common.hpp"// Logistic Regression ///
static void show_image(const cv::Mat& data, int columns, const std::string& name)
{cv::Mat big_image;for (int i = 0; i < data.rows; ++i) {big_image.push_back(data.row(i).reshape(0, columns));}cv::imshow(name, big_image);cv::waitKey(0);
}static float calculate_accuracy_percent(const cv::Mat& original, const cv::Mat& predicted)
{return 100 * (float)cv::countNonZero(original == predicted) / predicted.rows;
}int test_opencv_logistic_regression_train()
{const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" };cv::Mat data, labels, result;for (int i = 1; i < 11; ++i) {const std::vector<std::string> label{ "0_", "1_" };for (const auto& value : label) {std::string name = std::to_string(i);name = image_path + value + name + ".jpg";cv::Mat image = cv::imread(name, 0);if (image.empty()) {fprintf(stderr, "read image fail: %s\n", name.c_str());return -1;}data.push_back(image.reshape(0, 1));}}data.convertTo(data, CV_32F);//show_image(data, 28, "train data");std::unique_ptr<float[]> tmp(new float[20]);for (int i = 0; i < 20; ++i) {if (i % 2 == 0) tmp[i] = 0.f;else tmp[i] = 1.f;}labels = cv::Mat(20, 1, CV_32FC1, tmp.get());cv::Ptr<cv::ml::LogisticRegression> lr = cv::ml::LogisticRegression::create();lr->setLearningRate(0.00001);lr->setIterations(100);lr->setRegularization(cv::ml::LogisticRegression::REG_DISABLE);lr->setTrainMethod(cv::ml::LogisticRegression::MINI_BATCH);lr->setMiniBatchSize(1);CHECK(lr->train(data, cv::ml::ROW_SAMPLE, labels));const std::string save_file{ "E:/GitCode/NN_Test/data/logistic_regression_model.xml" }; // .xml, .yaml, .jsonslr->save(save_file);return 0;
}int test_opencv_logistic_regression_predict()
{const std::string image_path{ "E:/GitCode/NN_Test/data/images/digit/handwriting_0_and_1/" };cv::Mat data, labels, result;for (int i = 11; i < 21; ++i) {const std::vector<std::string> label{ "0_", "1_" };for (const auto& value : label) {std::string name = std::to_string(i);name = image_path + value + name + ".jpg";cv::Mat image = cv::imread(name, 0);if (image.empty()) {fprintf(stderr, "read image fail: %s\n", name.c_str());return -1;}data.push_back(image.reshape(0, 1));}}data.convertTo(data, CV_32F);//show_image(data, 28, "test data");std::unique_ptr<int[]> tmp(new int[20]);for (int i = 0; i < 20; ++i) {if (i % 2 == 0) tmp[i] = 0;else tmp[i] = 1;}labels = cv::Mat(20, 1, CV_32SC1, tmp.get());const std::string model_file{ "E:/GitCode/NN_Test/data/logistic_regression_model.xml" };cv::Ptr<cv::ml::LogisticRegression> lr = cv::ml::LogisticRegression::load(model_file);lr->predict(data, result);fprintf(stdout, "predict result: \n");std::cout << "actual: " << labels.t() << std::endl;std::cout << "target: " << result.t() << std::endl;fprintf(stdout, "accuracy: %.2f%%\n", calculate_accuracy_percent(labels, result));return 0;
}

测试代码中,test_opencv_logistic_regression_train函数用于训练,训练结果会产生一个叫logistic_regression_model.xml的model文件;test_opencv_logistic_regression_predict函数用于预测,预测结果如下,由结果可知,预测全部正确:

GitHub: https://github.com/fengbingchun/NN_Test

OpenCV3.3中逻辑回归(Logistic Regression)使用举例相关推荐

  1. 逻辑回归(Logistic Regression, LR)又称为逻辑回归分析,是分类和预测算法中的一种。通过历史数据的表现对未来结果发生的概率进行预测。例如,我们可以将购买的概率设置为因变量,将用户的

    逻辑回归(Logistic Regression, LR)又称为逻辑回归分析,是分类和预测算法中的一种.通过历史数据的表现对未来结果发生的概率进行预测.例如,我们可以将购买的概率设置为因变量,将用户的 ...

  2. 逻辑回归(Logistic Regression)简介及C++实现

    逻辑回归(Logistic Regression):该模型用于分类而非回归,可以使用logistic sigmoid函数( 可参考:http://blog.csdn.net/fengbingchun/ ...

  3. Coursera公开课笔记: 斯坦福大学机器学习第六课“逻辑回归(Logistic Regression)”

    Coursera公开课笔记: 斯坦福大学机器学习第六课"逻辑回归(Logistic Regression)" 斯坦福大学机器学习第六课"逻辑回归"学习笔记,本次 ...

  4. 斯坦福大学机器学习第四课“逻辑回归(Logistic Regression)”

    斯坦福大学机器学习第四课"逻辑回归(Logistic Regression)" 本次课程主要包括7部分: 1) Classification(分类) 2) Hypothesis R ...

  5. 逻辑回归(logistic regression)的本质——极大似然估计

    文章目录 1 前言 2 什么是逻辑回归 3 逻辑回归的代价函数 4 利用梯度下降法求参数 5 结束语 6 参考文献 1 前言 逻辑回归是分类当中极为常用的手段,因此,掌握其内在原理是非常必要的.我会争 ...

  6. CS229学习笔记(3)逻辑回归(Logistic Regression)

    1.分类问题 你要预测的变量yyy是离散的值,我们将学习一种叫做逻辑回归 (Logistic Regression) 的算法,这是目前最流行使用最广泛的一种学习算法. 从二元的分类问题开始讨论. 我们 ...

  7. 机器学习笔记04:逻辑回归(Logistic regression)、分类(Classification)

    之前我们已经大概学习了用线性回归(Linear Regression)来解决一些预测问题,详见: 1.<机器学习笔记01:线性回归(Linear Regression)和梯度下降(Gradien ...

  8. 札记_ML——《统计学习方法》逻辑回归logistic regression)

    统计学习方法:五. 逻辑回归logistic regression 逻辑回归logistic regression Logistic的起源 1).概念logistic回归又称logistic回归分析, ...

  9. 逻辑回归(Logistic Regression

    6.1 分类问题 参考文档: 6 - 1 - Classification (8 min).mkv 在这个以及接下来的几个视频中,开始介绍分类问题. 在分类问题中,你要预测的变量 y y y 是离散的 ...

最新文章

  1. 使用Prometheus和Grafana实现SLO
  2. Docker操作系统理解
  3. php 单词替换,如何在PHP中替换字符串中的单词?
  4. mplus 软件_Mplus 8.3 Combo Version 多元统计分析软件(Win)
  5. 数据结构练习题之树和图(附答案与解析)
  6. C语言:编写一个程序,从键盘读入一个矩形的两个边的值(整数),求矩形面积
  7. (维基百科LaTeX公式显示异常)解决方法
  8. P2900 [USACO08MAR]土地征用Land Acquisition
  9. Spring整合JavaMail
  10. java提前多久显示,Java当前日期/时间比原始时间提前1小时显示
  11. 5.13 综合案例2.0-火焰检测系统(2.2版本接口有更新)
  12. ​【原型设计】8种原型设计工具介绍​
  13. 搜狗输入法输出特殊符号快捷键
  14. linux修复win10启动失败,win10自动修复失败开不了机解决方法
  15. Python项目-Day26-数据加密-hash加盐加密-token-jwt
  16. 初学python者自学anaconda的正确姿势是什么?
  17. 高新技术企业申报材料汇编
  18. cuda安装正常,nvcc -V却没有任何显示
  19. Install Qualcomm Development Environment
  20. 用技巧] Http请求偶尔超时+总结各种超时死掉的可能和相应的解决办法

热门文章

  1. python基础知识整理 第三节 :函数
  2. OpenCV(十二)漫水填充算法
  3. 【机器视觉案例】(5) AI视觉,远程手势控制虚拟计算器,附python完整代码
  4. C语言截取指定长度子字符串方法
  5. Udacity机器人软件工程师课程笔记(三十二) - 卡尔曼滤波器 - 一维卡尔曼滤波器 - 多维卡尔曼滤波器 - 拓展卡尔曼滤波器(EKF)
  6. 在CentOS上安装TCP协议性能评测工具tcpdive
  7. 关于std::string 在 并发场景下 __grow_by_and_replace free was not allocated 的异常问题
  8. 【Android】基于A星寻路算法的简单迷宫应用
  9. mybatis简化实现思路
  10. oracle与mysql创建表时的区别