深度学习开源库tiny-dnn的使用(MNIST)
tiny-dnn是一个基于DNN的深度学习开源库,它的License是BSD 3-Clause。之前名字是tiny-cnn是基于CNN的,tiny-dnn与tiny-cnn相关又增加了些新层。此开源库很活跃,几乎每天都有新的提交,因此下面详细介绍下tiny-dnn在windows7 64bit vs2013的编译及使用。
1. 从https://github.com/tiny-dnn/tiny-dnn 下载源码:
$ git clone https://github.com/tiny-dnn/tiny-dnn.git 版本号为6281c1b,更新日期2016.12.03
2. 源文件中已经包含了vs2013工程,vc/vc12/tiny-dnn.sln,默认是win32的,这里新建一个x64的控制台工程tiny-dnn;
3. 仿照源工程,将相应.h文件加入到新控制台工程中,新加一个test_tiny-dnn.cpp文件;
4. 仿照examples/mnist中test.cpp和train.cpp文件中的代码添加测试代码;
#include "funset.hpp"
#include <string>
#include <algorithm>
#include "tiny_dnn/tiny_dnn.h"static void construct_net(tiny_dnn::network<tiny_dnn::sequential>& nn)
{// connection table [Y.Lecun, 1998 Table.1]
#define O true
#define X falsestatic const bool tbl[] = {O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O};
#undef O
#undef X// by default will use backend_t::tiny_dnn unless you compiled// with -DUSE_AVX=ON and your device supports AVX intrinsicstiny_dnn::core::backend_t backend_type = tiny_dnn::core::default_engine();// construct nets: C: convolution; S: sub-sampling; F: fully connectednn << tiny_dnn::convolutional_layer<tiny_dnn::activation::tan_h>(32, 32, 5, 1, 6, // C1, 1@32x32-in, 6@28x28-outtiny_dnn::padding::valid, true, 1, 1, backend_type)<< tiny_dnn::average_pooling_layer<tiny_dnn::activation::tan_h>(28, 28, 6, 2) // S2, 6@28x28-in, 6@14x14-out<< tiny_dnn::convolutional_layer<tiny_dnn::activation::tan_h>(14, 14, 5, 6, 16, // C3, 6@14x14-in, 16@10x10-outconnection_table(tbl, 6, 16),tiny_dnn::padding::valid, true, 1, 1, backend_type)<< tiny_dnn::average_pooling_layer<tiny_dnn::activation::tan_h>(10, 10, 16, 2) // S4, 16@10x10-in, 16@5x5-out<< tiny_dnn::convolutional_layer<tiny_dnn::activation::tan_h>(5, 5, 5, 16, 120, // C5, 16@5x5-in, 120@1x1-outtiny_dnn::padding::valid, true, 1, 1, backend_type)<< tiny_dnn::fully_connected_layer<tiny_dnn::activation::tan_h>(120, 10, // F6, 120-in, 10-outtrue, backend_type);
}static void train_lenet(const std::string& data_dir_path)
{// specify loss-function and learning strategytiny_dnn::network<tiny_dnn::sequential> nn;tiny_dnn::adagrad optimizer;construct_net(nn);std::cout << "load models..." << std::endl;// load MNIST datasetstd::vector<tiny_dnn::label_t> train_labels, test_labels;std::vector<tiny_dnn::vec_t> train_images, test_images;tiny_dnn::parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte", &train_labels);tiny_dnn::parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte", &train_images, -1.0, 1.0, 2, 2);tiny_dnn::parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte", &test_labels);tiny_dnn::parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte", &test_images, -1.0, 1.0, 2, 2);std::cout << "start training" << std::endl;tiny_dnn::progress_display disp(static_cast<unsigned long>(train_images.size()));tiny_dnn::timer t;int minibatch_size = 10;int num_epochs = 30;optimizer.alpha *= static_cast<tiny_dnn::float_t>(std::sqrt(minibatch_size));// create callbackauto on_enumerate_epoch = [&](){std::cout << t.elapsed() << "s elapsed." << std::endl;tiny_dnn::result res = nn.test(test_images, test_labels);std::cout << res.num_success << "/" << res.num_total << std::endl;disp.restart(static_cast<unsigned long>(train_images.size()));t.restart();};auto on_enumerate_minibatch = [&](){disp += minibatch_size;};// trainingnn.train<tiny_dnn::mse>(optimizer, train_images, train_labels, minibatch_size, num_epochs, on_enumerate_minibatch, on_enumerate_epoch);std::cout << "end training." << std::endl;// test and show resultsnn.test(test_images, test_labels).print_detail(std::cout);// save network model & trained weightsnn.save(data_dir_path + "/LeNet-model");
}// rescale output to 0-100
template <typename Activation>
static double rescale(double x)
{Activation a;return 100.0 * (x - a.scale().first) / (a.scale().second - a.scale().first);
}static void convert_image(const std::string& imagefilename, double minv, double maxv, int w, int h, tiny_dnn::vec_t& data)
{tiny_dnn::image<> img(imagefilename, tiny_dnn::image_type::grayscale);tiny_dnn::image<> resized = resize_image(img, w, h);// mnist dataset is "white on black", so negate requiredstd::transform(resized.begin(), resized.end(), std::back_inserter(data),[=](uint8_t c) { return (255 - c) * (maxv - minv) / 255.0 + minv; });
}int test_dnn_mnist_train()
{std::string data_dir_path = "E:/GitCode/NN_Test/data";train_lenet(data_dir_path);return 0;
}int test_dnn_mnist_predict()
{std::string model { "E:/GitCode/NN_Test/data/LeNet-model" };std::string image_path { "E:/GitCode/NN_Test/data/images/"};int target[10] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };tiny_dnn::network<tiny_dnn::sequential> nn;nn.load(model);for (int i = 0; i < 10; i++) {std::string str = std::to_string(i);str += ".png";str = image_path + str;// convert imagefile to vec_ttiny_dnn::vec_t data;convert_image(str, -1.0, 1.0, 32, 32, data);// recognizeauto res = nn.predict(data);std::vector<std::pair<double, int> > scores;// sort & print top-3for (int j = 0; j < 10; j++)scores.emplace_back(rescale<tiny_dnn::tan_h>(res[j]), j);std::sort(scores.begin(), scores.end(), std::greater<std::pair<double, int>>());for (int j = 0; j < 3; j++)fprintf(stdout, "%d: %f; ", scores[j].second, scores[j].first);fprintf(stderr, "\n");// save outputs of each layerfor (size_t j = 0; j < nn.depth(); j++) {auto out_img = nn[j]->output_to_image();auto filename = image_path + std::to_string(i) + "_layer_" + std::to_string(j) + ".png";out_img.save(filename);}// save filter shape of first convolutional layerauto weight = nn.at<tiny_dnn::convolutional_layer<tiny_dnn::tan_h>>(0).weight_to_image();auto filename = image_path + std::to_string(i) + "_weights.png";weight.save(filename);fprintf(stdout, "the actual digit is: %d, correct digit is: %d \n\n", scores[0].second, target[i]);}return 0;
}
5. 运行程序,train时,运行结果如下图所示,准确率达到99%以上:
6. 对生成的model进行测试,通过画图工具,每个数字生成一张图像,共10幅,如下图:
7. 通过导入train时生成的model,对这10张图像进行识别,识别结果如下图,其中0,8,9被误识别为2,2,1.
GitHub:https://github.com/fengbingchun/NN_Test
深度学习开源库tiny-dnn的使用(MNIST)相关推荐
- 10倍!微软开源深度学习优化库DeepSpeed,可训练1000亿参数模型
点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 编辑:Sophia 计算机视觉联盟 报道 | 公众号 CVLianMeng 转载于 :微软 AI博士笔记系列推荐 ...
- 微软开源深度学习优化库 DeepSpeed,可训练 1000 亿参数的模型
人工智能的最新趋势是,更大的自然语言模型可以提供更好的准确性,但是由于成本.时间和代码集成的障碍,较大的模型难以训练.微软日前开源了一个深度学习优化库 DeepSpeed,通过提高规模.速度.可用性并 ...
- 基于PyTorch、易上手,细粒度图像识别深度学习工具库Hawkeye开源
转载自丨机器之心 鉴于当前领域内尚缺乏该方面的深度学习开源工具库,南京理工大学魏秀参教授团队用时近一年时间,开发.打磨.完成了 Hawkeye--细粒度图像识别深度学习开源工具库,供相关领域研究人员和 ...
- 【完结】12大深度学习开源框架(caffe,tf,pytorch,mxnet等)快速入门项目
这是一篇总结文,给大家来捋清楚12大深度学习开源框架的快速入门,这是有三AI的GitHub项目,欢迎大家star/fork. https://github.com/longpeng2008/yousa ...
- 【完结】给新手的12大深度学习开源框架快速入门项目
文/编辑 | 言有三 这是一篇总结文,给大家来捋清楚12大深度学习开源框架的快速入门,这是有三AI的GitHub项目,欢迎大家star/fork. https://github.com/longpen ...
- AI学习笔记(九)从零开始训练神经网络、深度学习开源框架
AI学习笔记之从零开始训练神经网络.深度学习开源框架 从零开始训练神经网络 构建网络的基本框架 启动训练网络并测试数据 深度学习开源框架 深度学习框架 组件--张量 组件--基于张量的各种操作 组件- ...
- 12大深度学习开源框架(caffe,tensorflow,pytorch,mxnet等)汇总详解
这是一篇总结文,给大家来捋清楚12大深度学习开源框架的快速入门,这是有三AI的GitHub项目,欢迎大家star/fork. https://github.com/longpeng2008/yousa ...
- 【github干货】主流深度学习开源框架从入门到熟练
文章首发于微信公众号<有三AI> [github干货]主流深度学习开源框架从入门到熟练 今天送上有三AI学院第一个github项目 01项目背景 目前深度学习框架呈百家争鸣之态势,光是为人 ...
- 深度学习-14:知名的深度学习开源架构和项目
深度学习-14:知名的深度学习开源架构和项目 深度学习原理与实践(开源图书)-总目录 人工智能artificial intelligence,AI是科技研究中最热门的方向之一.像IBM.谷歌.微软.F ...
最新文章
- php趣味小程序,php常用小程序
- 华为交换机堆叠SVF助手(推荐)
- 剑指offer之斐波那契问题(C++/Java双重实现)
- php获取当前世界,php获取网站alexa世界流量排名代码
- phpexcel.php linux,phpexcel在linux系统报错如何解决
- 简约自适应APP下载页源码
- 计算机专业英语常用词汇
- WireMock.NET如何帮助进行.NET Core应用程序的集成测试
- 微软3月补丁星期二修复71个漏洞,其中3个是0day
- 电脑课堂:U盘“无法停止通用卷设备时”的解决方法
- Handler机制的理解与使用
- php mql web开发,自己动手开发多线程异步 MQL5 WebRequest
- 戴尔linux恢复镜像,戴尔 SupportAssist OS Recovery 系统恢复教程
- 苹果手机一直显示搜索服务器,苹果手机safari浏览器搜索页面没有了
- 研究生英语期末复习(Unit3)
- java word转图片(word转pdf再转图片)
- 漂亮的JQUERY SLIDESHOW 磨砂玻璃背景
- 10个最新手机美食APP界面设计欣赏
- 10大全球设计师SNS社区网站
- 【AnySDK】目前对外开放的渠道列表
热门文章
- Opencv中的FaceRecognizer类
- LaneATT调试笔记
- tensorflow object_detect 操作步骤
- 在CentOS 6.9 x86_64的nginx 1.12.2上开启标准模块ngx_http_auth_request_module实录
- Python模块MySQLdb操作mysql出现2019错误:Can't initialize character set utf-8
- 在Unity中制作4种不同的游戏
- 【Linux基础】文件处理实例
- 车辆匹配和平均车速计算
- [dp] Jzoj P5804 简单的序列
- redis实现对账(集合比较)功能