tiny-cnn是一个基于CNN的开源库,它的License是BSD 3-Clause。作者也一直在维护更新,对进一步掌握CNN很有帮助,因此下面介绍下tiny-cnn在windows7 64bit vs2013的编译及使用。

1.      从https://github.com/nyanp/tiny-cnn下载源码:

$ git clone https://github.com/nyanp/tiny-cnn.git  版本号为77d80a8,更新日期2016.01.22

2.      源文件中已经包含了vs2013工程,vc/tiny-cnn.sln,默认是win32的,examples/main.cpp需要OpenCV的支持,这里新建一个x64的控制台工程tiny-cnn;

3.      仿照源工程,将相应.h文件加入到新控制台工程中,新加一个test_tiny-cnn.cpp文件;

4.      将examples/mnist中test.cpp和train.cpp文件中的代码复制到test_tiny-cnn.cpp文件中;

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <tiny_cnn/tiny_cnn.h>
#include <opencv2/opencv.hpp>using namespace tiny_cnn;
using namespace tiny_cnn::activation;// rescale output to 0-100
template <typename Activation>
double rescale(double x)
{Activation a;return 100.0 * (x - a.scale().first) / (a.scale().second - a.scale().first);
}void construct_net(network<mse, adagrad>& nn);
void train_lenet(std::string data_dir_path);
// convert tiny_cnn::image to cv::Mat and resize
cv::Mat image2mat(image<>& img);
void convert_image(const std::string& imagefilename, double minv, double maxv, int w, int h, vec_t& data);
void recognize(const std::string& dictionary, const std::string& filename, int target);int main()
{//trainstd::string data_path = "D:/Download/MNIST";train_lenet(data_path);//teststd::string model_path = "D:/Download/MNIST/LeNet-weights";std::string image_path = "D:/Download/MNIST/";int target[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };for (int i = 0; i < 10; i++) {char ch[15];sprintf(ch, "%d", i);std::string str;str = std::string(ch);str += ".png";str = image_path + str;recognize(model_path, str, target[i]);}std::cout << "ok!" << std::endl;return 0;
}void train_lenet(std::string data_dir_path) {// specify loss-function and learning strategynetwork<mse, adagrad> nn;construct_net(nn);std::cout << "load models..." << std::endl;// load MNIST datasetstd::vector<label_t> train_labels, test_labels;std::vector<vec_t> train_images, test_images;parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte",&train_labels);parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte",&train_images, -1.0, 1.0, 2, 2);parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte",&test_labels);parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte",&test_images, -1.0, 1.0, 2, 2);std::cout << "start training" << std::endl;progress_display disp(train_images.size());timer t;int minibatch_size = 10;int num_epochs = 30;nn.optimizer().alpha *= std::sqrt(minibatch_size);// create callbackauto on_enumerate_epoch = [&](){std::cout << t.elapsed() << "s elapsed." << std::endl;tiny_cnn::result res = nn.test(test_images, test_labels);std::cout << res.num_success << "/" << res.num_total << std::endl;disp.restart(train_images.size());t.restart();};auto on_enumerate_minibatch = [&](){disp += minibatch_size;};// trainingnn.train(train_images, train_labels, minibatch_size, num_epochs,on_enumerate_minibatch, on_enumerate_epoch);std::cout << "end training." << std::endl;// test and show resultsnn.test(test_images, test_labels).print_detail(std::cout);// save networksstd::ofstream ofs("D:/Download/MNIST/LeNet-weights");ofs << nn;
}void construct_net(network<mse, adagrad>& nn) {// connection table [Y.Lecun, 1998 Table.1]
#define O true
#define X falsestatic const bool tbl[] = {O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O};
#undef O
#undef X// construct netsnn << convolutional_layer<tan_h>(32, 32, 5, 1, 6)  // C1, 1@32x32-in, 6@28x28-out<< average_pooling_layer<tan_h>(28, 28, 6, 2)   // S2, 6@28x28-in, 6@14x14-out<< convolutional_layer<tan_h>(14, 14, 5, 6, 16,connection_table(tbl, 6, 16))              // C3, 6@14x14-in, 16@10x10-in<< average_pooling_layer<tan_h>(10, 10, 16, 2)  // S4, 16@10x10-in, 16@5x5-out<< convolutional_layer<tan_h>(5, 5, 5, 16, 120) // C5, 16@5x5-in, 120@1x1-out<< fully_connected_layer<tan_h>(120, 10);       // F6, 120-in, 10-out
}void recognize(const std::string& dictionary, const std::string& filename, int target) {network<mse, adagrad> nn;construct_net(nn);// load netsstd::ifstream ifs(dictionary.c_str());ifs >> nn;// convert imagefile to vec_tvec_t data;convert_image(filename, -1.0, 1.0, 32, 32, data);// recognizeauto res = nn.predict(data);std::vector<std::pair<double, int> > scores;// sort & print top-3for (int i = 0; i < 10; i++)scores.emplace_back(rescale<tan_h>(res[i]), i);std::sort(scores.begin(), scores.end(), std::greater<std::pair<double, int>>());for (int i = 0; i < 3; i++)std::cout << scores[i].second << "," << scores[i].first << std::endl;std::cout << "the actual digit is: " << scores[0].second << ", correct digit is: "<<target<<std::endl;// visualize outputs of each layer//for (size_t i = 0; i < nn.depth(); i++) {// auto out_img = nn[i]->output_to_image();//  cv::imshow("layer:" + std::to_string(i), image2mat(out_img));//}visualize filter shape of first convolutional layer//auto weight = nn.at<convolutional_layer<tan_h>>(0).weight_to_image();//cv::imshow("weights:", image2mat(weight));//cv::waitKey(0);
}// convert tiny_cnn::image to cv::Mat and resize
cv::Mat image2mat(image<>& img) {cv::Mat ori(img.height(), img.width(), CV_8U, &img.at(0, 0));cv::Mat resized;cv::resize(ori, resized, cv::Size(), 3, 3, cv::INTER_AREA);return resized;
}void convert_image(const std::string& imagefilename,double minv,double maxv,int w,int h,vec_t& data) {auto img = cv::imread(imagefilename, cv::IMREAD_GRAYSCALE);if (img.data == nullptr) return; // cannot open, or it's not an imagecv::Mat_<uint8_t> resized;cv::resize(img, resized, cv::Size(w, h));// mnist dataset is "white on black", so negate requiredstd::transform(resized.begin(), resized.end(), std::back_inserter(data),[=](uint8_t c) { return (255 - c) * (maxv - minv) / 255.0 + minv; });
}

5.      编译时会提示几个错误,解决方法是:

(1)、error C4996,解决方法:将宏_SCL_SECURE_NO_WARNINGS添加到属性的预处理器定义中;

(2)、调用for_函数时,error C2668,对重载函数的调用不明教,解决方法:将for_中的第三个参数强制转化为size_t类型;

6.      运行程序,train时,运行结果如下图所示:

7.      对生成的model进行测试,通过画图工具,每个数字生成一张图像,共10幅,如下图:

通过导入train时生成的model,对这10张图像进行识别,识别结果如下图,其中6和9被误识为5和1:

GitHub:https://github.com/fengbingchun/NN

tiny-cnn开源库的使用(MNIST)相关推荐

  1. 深度学习开源库tiny-dnn的使用(MNIST)

    tiny-dnn是一个基于DNN的深度学习开源库,它的License是BSD 3-Clause.之前名字是tiny-cnn是基于CNN的,tiny-dnn与tiny-cnn相关又增加了些新层.此开源库 ...

  2. 本周Github项目精选:高效CNN推理库、多款AlphaGo实现

    来源:PaperWeekly 本文共1700字,建议阅读6分钟. 本文为你精选近期Github上的高效CNN推理库.多款AlphaGo实现项目等,一起Star和Fork吧- 01 ELF #Faceb ...

  3. 【开源】Caffe、TensorFlow、MXnet三个开源库对比

    from:http://www.wtoutiao.com/p/1cbxddO.html 最近Google开源了他们内部使用的深度学习框架TensorFlow[1],结合之前开源的MXNet[2]和Ca ...

  4. C++开源库大全(转)

    http://blog.csdn.net/chen19870707/article/details/40427645 程序员要站在巨人的肩膀上,C++拥有丰富的开源库,这里包括:标准库.Web应用框架 ...

  5. 人脸识别开源库face_recognition的简单介绍

    人脸识别开源库face_recognition的简单介绍 原文出处: https://blog.xugaoxiang.com/ai/face-recognition-cnn.html 软硬件环境 ub ...

  6. 开源库 | 监控视频中的目标检测与跟踪

    介绍一份来自卡内基梅隆大学开源的主要用于监控视频中目标检测与跟踪的开源库:Object_Detection_Tracking . 其赢得了 2019 Activities in Extended Vi ...

  7. 赶在 2018 年前推荐 30 个最火爆的开源库

    点击上方"CSDN",选择"置顶公众号" 关键时刻,第一时间送达! 作者简介:杨守乐,CSDN 知名博主,关注 Android.Java 领域,现在主要专注于音 ...

  8. Android开发:开源库集合

    开源库大全 目录 抽屉菜单 ListView WebView SwitchButton 按钮 点赞按钮 进度条 TabLayout 图标 下拉刷新 ViewPager 图表(Chart) 菜单(Men ...

  9. Github安卓流行布局开源库

    抽屉菜单 MaterialDrawer ★7337 - 安卓抽屉效果实现方案 Side-Menu.Android ★3865 - 创意边侧菜单 FlowingDrawer ★1744 - 向右滑动流动 ...

最新文章

  1. 从 Flutter 的视频渲染到 App 落地经验
  2. crc16modbus查表法_查表法计算CRC16校验值
  3. BZOJ4723[POI2017]Flappy Bird——模拟
  4. 零基础如何学好Python?Python有哪些必须学的知识?
  5. 快来围观一下JavaScript的Proxy
  6. txt形式进行传输WebShell图文演示!
  7. 【报告分享】ibm构建认知型企业:实现ai赋能的企业转型.pdf(附下载链接)
  8. windows里面的批处理命令不停地处理同一条命令
  9. MATLAB在声学理论基础中的应用,MATLAB在声学理论基础中的应用
  10. 获得客户端真实IP的方法
  11. 循环小题题库存档(期末复习)
  12. C#中的动态类型(Dynamic)
  13. 亚马逊云科技平台上的无服务器 WebSocket
  14. 线代——余子式和代数余子式
  15. Apache POI读合并单元格
  16. 家装企业如何开展网络营销?
  17. 【项目实战】批量导出excel,并打包zip文件【连载中】
  18. 什么是5G聚合路由器?
  19. 架构师,你需要了解的git知识都在这里了
  20. 魔乐科技安卓开发教程----李兴华----08APPWidget

热门文章

  1. 数字图像处理——第四章 频率域滤波
  2. CornerNet代码解析——损失函数
  3. Unity中创建本地多人游戏完整案例视频教程 Learn To Create A Local Multiplayer Game In Unity
  4. LTE - PRACH 时频资源介绍
  5. blktrace 工具集使用 及其实现原理
  6. ceph bluestore 源码分析:ceph-osd内存查看方式及控制源码分析
  7. Linux文件系统:概览(思维导图)
  8. readelf 读取动态链接表命令
  9. 排序算法之直接插入排序
  10. Python高级函数--map/reduce