tiny-dnn是一个基于DNN的深度学习开源库,它的License是BSD 3-Clause。之前名字是tiny-cnn是基于CNN的,tiny-dnn与tiny-cnn相关又增加了些新层。此开源库很活跃,几乎每天都有新的提交,因此下面详细介绍下tiny-dnn在windows7 64bit vs2013的编译及使用。

1.      从https://github.com/tiny-dnn/tiny-dnn 下载源码:

$ git clone https://github.com/tiny-dnn/tiny-dnn.git 版本号为6281c1b,更新日期2016.12.03

2.      源文件中已经包含了vs2013工程,vc/vc12/tiny-dnn.sln,默认是win32的,这里新建一个x64的控制台工程tiny-dnn;

3.      仿照源工程,将相应.h文件加入到新控制台工程中,新加一个test_tiny-dnn.cpp文件;

4.      仿照examples/mnist中test.cpp和train.cpp文件中的代码添加测试代码;

#include "funset.hpp"
#include <string>
#include <algorithm>
#include "tiny_dnn/tiny_dnn.h"static void construct_net(tiny_dnn::network<tiny_dnn::sequential>& nn)
{// connection table [Y.Lecun, 1998 Table.1]
#define O true
#define X falsestatic const bool tbl[] = {O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O};
#undef O
#undef X// by default will use backend_t::tiny_dnn unless you compiled// with -DUSE_AVX=ON and your device supports AVX intrinsicstiny_dnn::core::backend_t backend_type = tiny_dnn::core::default_engine();// construct nets: C: convolution; S: sub-sampling; F: fully connectednn << tiny_dnn::convolutional_layer<tiny_dnn::activation::tan_h>(32, 32, 5, 1, 6,  // C1, 1@32x32-in, 6@28x28-outtiny_dnn::padding::valid, true, 1, 1, backend_type)<< tiny_dnn::average_pooling_layer<tiny_dnn::activation::tan_h>(28, 28, 6, 2)   // S2, 6@28x28-in, 6@14x14-out<< tiny_dnn::convolutional_layer<tiny_dnn::activation::tan_h>(14, 14, 5, 6, 16, // C3, 6@14x14-in, 16@10x10-outconnection_table(tbl, 6, 16),tiny_dnn::padding::valid, true, 1, 1, backend_type)<< tiny_dnn::average_pooling_layer<tiny_dnn::activation::tan_h>(10, 10, 16, 2)  // S4, 16@10x10-in, 16@5x5-out<< tiny_dnn::convolutional_layer<tiny_dnn::activation::tan_h>(5, 5, 5, 16, 120, // C5, 16@5x5-in, 120@1x1-outtiny_dnn::padding::valid, true, 1, 1, backend_type)<< tiny_dnn::fully_connected_layer<tiny_dnn::activation::tan_h>(120, 10,        // F6, 120-in, 10-outtrue, backend_type);
}static void train_lenet(const std::string& data_dir_path)
{// specify loss-function and learning strategytiny_dnn::network<tiny_dnn::sequential> nn;tiny_dnn::adagrad optimizer;construct_net(nn);std::cout << "load models..." << std::endl;// load MNIST datasetstd::vector<tiny_dnn::label_t> train_labels, test_labels;std::vector<tiny_dnn::vec_t> train_images, test_images;tiny_dnn::parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte", &train_labels);tiny_dnn::parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte", &train_images, -1.0, 1.0, 2, 2);tiny_dnn::parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte", &test_labels);tiny_dnn::parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte", &test_images, -1.0, 1.0, 2, 2);std::cout << "start training" << std::endl;tiny_dnn::progress_display disp(static_cast<unsigned long>(train_images.size()));tiny_dnn::timer t;int minibatch_size = 10;int num_epochs = 30;optimizer.alpha *= static_cast<tiny_dnn::float_t>(std::sqrt(minibatch_size));// create callbackauto on_enumerate_epoch = [&](){std::cout << t.elapsed() << "s elapsed." << std::endl;tiny_dnn::result res = nn.test(test_images, test_labels);std::cout << res.num_success << "/" << res.num_total << std::endl;disp.restart(static_cast<unsigned long>(train_images.size()));t.restart();};auto on_enumerate_minibatch = [&](){disp += minibatch_size;};// trainingnn.train<tiny_dnn::mse>(optimizer, train_images, train_labels, minibatch_size, num_epochs, on_enumerate_minibatch, on_enumerate_epoch);std::cout << "end training." << std::endl;// test and show resultsnn.test(test_images, test_labels).print_detail(std::cout);// save network model & trained weightsnn.save(data_dir_path + "/LeNet-model");
}// rescale output to 0-100
template <typename Activation>
static double rescale(double x)
{Activation a;return 100.0 * (x - a.scale().first) / (a.scale().second - a.scale().first);
}static void convert_image(const std::string& imagefilename, double minv, double maxv, int w, int h, tiny_dnn::vec_t& data)
{tiny_dnn::image<> img(imagefilename, tiny_dnn::image_type::grayscale);tiny_dnn::image<> resized = resize_image(img, w, h);// mnist dataset is "white on black", so negate requiredstd::transform(resized.begin(), resized.end(), std::back_inserter(data),[=](uint8_t c) { return (255 - c) * (maxv - minv) / 255.0 + minv; });
}int test_dnn_mnist_train()
{std::string data_dir_path = "E:/GitCode/NN_Test/data";train_lenet(data_dir_path);return 0;
}int test_dnn_mnist_predict()
{std::string model { "E:/GitCode/NN_Test/data/LeNet-model" };std::string image_path { "E:/GitCode/NN_Test/data/images/"};int target[10] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };tiny_dnn::network<tiny_dnn::sequential> nn;nn.load(model);for (int i = 0; i < 10; i++) {std::string str = std::to_string(i);str += ".png";str = image_path + str;// convert imagefile to vec_ttiny_dnn::vec_t data;convert_image(str, -1.0, 1.0, 32, 32, data);// recognizeauto res = nn.predict(data);std::vector<std::pair<double, int> > scores;// sort & print top-3for (int j = 0; j < 10; j++)scores.emplace_back(rescale<tiny_dnn::tan_h>(res[j]), j);std::sort(scores.begin(), scores.end(), std::greater<std::pair<double, int>>());for (int j = 0; j < 3; j++)fprintf(stdout, "%d: %f;  ", scores[j].second, scores[j].first);fprintf(stderr, "\n");// save outputs of each layerfor (size_t j = 0; j < nn.depth(); j++) {auto out_img = nn[j]->output_to_image();auto filename = image_path + std::to_string(i) + "_layer_" + std::to_string(j) + ".png";out_img.save(filename);}// save filter shape of first convolutional layerauto weight = nn.at<tiny_dnn::convolutional_layer<tiny_dnn::tan_h>>(0).weight_to_image();auto filename = image_path + std::to_string(i) + "_weights.png";weight.save(filename);fprintf(stdout, "the actual digit is: %d, correct digit is: %d \n\n", scores[0].second, target[i]);}return 0;
}

5.      运行程序,train时,运行结果如下图所示,准确率达到99%以上:

6.  对生成的model进行测试,通过画图工具,每个数字生成一张图像,共10幅,如下图:

7. 通过导入train时生成的model,对这10张图像进行识别,识别结果如下图,其中0,8,9被误识别为2,2,1.

GitHub:https://github.com/fengbingchun/NN_Test

深度学习开源库tiny-dnn的使用(MNIST)相关推荐

  1. 10倍!微软开源深度学习优化库DeepSpeed,可训练1000亿参数模型

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 编辑:Sophia 计算机视觉联盟  报道  | 公众号 CVLianMeng 转载于 :微软 AI博士笔记系列推荐 ...

  2. 微软开源深度学习优化库 DeepSpeed,可训练 1000 亿参数的模型

    人工智能的最新趋势是,更大的自然语言模型可以提供更好的准确性,但是由于成本.时间和代码集成的障碍,较大的模型难以训练.微软日前开源了一个深度学习优化库 DeepSpeed,通过提高规模.速度.可用性并 ...

  3. 基于PyTorch、易上手,细粒度图像识别深度学习工具库Hawkeye开源

    转载自丨机器之心 鉴于当前领域内尚缺乏该方面的深度学习开源工具库,南京理工大学魏秀参教授团队用时近一年时间,开发.打磨.完成了 Hawkeye--细粒度图像识别深度学习开源工具库,供相关领域研究人员和 ...

  4. 【完结】12大深度学习开源框架(caffe,tf,pytorch,mxnet等)快速入门项目

    这是一篇总结文,给大家来捋清楚12大深度学习开源框架的快速入门,这是有三AI的GitHub项目,欢迎大家star/fork. https://github.com/longpeng2008/yousa ...

  5. 【完结】给新手的12大深度学习开源框架快速入门项目

    文/编辑 | 言有三 这是一篇总结文,给大家来捋清楚12大深度学习开源框架的快速入门,这是有三AI的GitHub项目,欢迎大家star/fork. https://github.com/longpen ...

  6. AI学习笔记(九)从零开始训练神经网络、深度学习开源框架

    AI学习笔记之从零开始训练神经网络.深度学习开源框架 从零开始训练神经网络 构建网络的基本框架 启动训练网络并测试数据 深度学习开源框架 深度学习框架 组件--张量 组件--基于张量的各种操作 组件- ...

  7. 12大深度学习开源框架(caffe,tensorflow,pytorch,mxnet等)汇总详解

    这是一篇总结文,给大家来捋清楚12大深度学习开源框架的快速入门,这是有三AI的GitHub项目,欢迎大家star/fork. https://github.com/longpeng2008/yousa ...

  8. 【github干货】主流深度学习开源框架从入门到熟练

    文章首发于微信公众号<有三AI> [github干货]主流深度学习开源框架从入门到熟练 今天送上有三AI学院第一个github项目 01项目背景 目前深度学习框架呈百家争鸣之态势,光是为人 ...

  9. 深度学习-14:知名的深度学习开源架构和项目

    深度学习-14:知名的深度学习开源架构和项目 深度学习原理与实践(开源图书)-总目录 人工智能artificial intelligence,AI是科技研究中最热门的方向之一.像IBM.谷歌.微软.F ...

最新文章

  1. php趣味小程序,php常用小程序
  2. 华为交换机堆叠SVF助手(推荐)
  3. 剑指offer之斐波那契问题(C++/Java双重实现)
  4. php获取当前世界,php获取网站alexa世界流量排名代码
  5. phpexcel.php linux,phpexcel在linux系统报错如何解决
  6. 简约自适应APP下载页源码
  7. 计算机专业英语常用词汇
  8. WireMock.NET如何帮助进行.NET Core应用程序的集成测试
  9. 微软3月补丁星期二修复71个漏洞,其中3个是0day
  10. 电脑课堂:U盘“无法停止通用卷设备时”的解决方法
  11. Handler机制的理解与使用
  12. php mql web开发,自己动手开发多线程异步 MQL5 WebRequest
  13. 戴尔linux恢复镜像,戴尔 SupportAssist OS Recovery 系统恢复教程
  14. 苹果手机一直显示搜索服务器,苹果手机safari浏览器搜索页面没有了
  15. 研究生英语期末复习(Unit3)
  16. java word转图片(word转pdf再转图片)
  17. 漂亮的JQUERY SLIDESHOW 磨砂玻璃背景
  18. 10个最新手机美食APP界面设计欣赏
  19. 10大全球设计师SNS社区网站
  20. 【AnySDK】目前对外开放的渠道列表

热门文章

  1. Opencv中的FaceRecognizer类
  2. LaneATT调试笔记
  3. tensorflow object_detect 操作步骤
  4. 在CentOS 6.9 x86_64的nginx 1.12.2上开启标准模块ngx_http_auth_request_module实录
  5. Python模块MySQLdb操作mysql出现2019错误:Can't initialize character set utf-8
  6. 在Unity中制作4种不同的游戏
  7. 【Linux基础】文件处理实例
  8. 车辆匹配和平均车速计算
  9. [dp] Jzoj P5804 简单的序列
  10. redis实现对账(集合比较)功能