前一段时间,一直在忙框架方面的工作,偶尔也会帮业务同学去优化优化使用TensorFlow的代码,也加上之前看了dmlc/relay,nnvm的代码,觉得蛮有意思,也想分别看下TensorFlow的Graph IR、PaddlePaddle的Graph IR,上周五,看代码看的正津津有味的时候,看到某个数据竞赛群里面讨论东西,不记得具体内容,大概说的是框架的代码实现, 有几位算法大佬说看底层源码比较麻烦,因为比较早从框架,这块代码通常都还能看,问题都不大,和群里小伙伴吹水了半天之后,感觉是可以写篇如何看TensorFlow或者其他框架底层源码的劝退文了。

利其器

首先,一定是要找个好工作来看源码,很多人推荐vs code、sublime,我试过vs code+bazel的,好像也不错,但是后面做c++适应了clion之后,除了资源要求比较多,还是蛮不错的,使用c++一般推荐使用cmake来看编译项目,但是TensorFlow是bazel的,无法直接支持,最开始,这边是自己写简单的cmake,能够实现简单的代码跳转,但是涉及到比如protobuf之类的编译过后产生的文件无法跳转,比较麻烦,不够纯粹,很早之前知道clion有bazel的组件,但是不知道为啥一直搞不通,上周找时间再试了试,发现竟然通了,使用之后,这才是看tf源码的真正方式:

首先,选择合适版本的bazel,千万不能太高,也不能太低,这里我拉的是TF2.0的代码,使用bazel 0.24.0刚刚好,切记千万别太高也比太低, 千万别太高也比太低,千万别太高也比太低


其次,clion上选择bazel的插件

第三步,./configure,然后按你的意图选择合适的编译配置


第四步,导入bazel项目:File=>Import Bazel Project

经过上面几步之后,接下来就要经过比较长时间的等待,clion会导入bazel项目,然后编译整个项目,这个耗时视你机器和网络而定(顺便提一句,最好保证比较畅通的访问github的网络,另外由于上面targets:all,会编译TensorFlow所有的项目,如果你知道是什么意思,可以自己修改,如果不知道的话我先不提了,默认就好,期间会有很多Error出现,放心,问题不大,因为会默认编译所有的模块)
经过上面之后,我们就可以愉快的看代码啦,连protobuf生成的文件都很开心的跳转啦

极简版c++入门

TensorFlow大部分人都知道,底层是c++写的,然后外面包了一层python的api,既然底层是c++写的,那么用c++也是可以用来训练模型的,大部分人应该都用过c++或者java去载入frozen的模型,然后做serving应用在业务系统上,应该很少人去使用c++来训练模型,既然我们这里要读代码,我们先尝试看看用c++写模型,文件路径如下图:


主要函数就那么几个:CreateGraphDef, ConcurrentSteps, ConcurrentSessions:

CreateGraphDef 构造计算图

GraphDef CreateGraphDef() {// TODO(jeff,opensource): This should really be a more interesting// computation.  Maybe turn this into an mnist model instead?Scope root = Scope::NewRootScope();using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)// A = [3 2; -1 0].  Using Const<float> means the result will be a// float tensor even though the initializer has integers.auto a = Const<float>(root, {{3, 2}, {-1, 0}});// x = [1.0; 1.0]auto x = Const(root.WithOpName("x"), {{1.f}, {1.f}});// y = A * xauto y = MatMul(root.WithOpName("y"), a, x);// y2 = y.^2auto y2 = Square(root, y);// y2_sum = sum(y2).  Note that you can pass constants directly as// inputs.  Sum() will automatically create a Const node to hold the// 0 value.auto y2_sum = Sum(root, y2, 0);// y_norm = sqrt(y2_sum)auto y_norm = Sqrt(root, y2_sum);// y_normalized = y ./ y_normDiv(root.WithOpName("y_normalized"), y, y_norm);GraphDef def;TF_CHECK_OK(root.ToGraphDef(&def));return def;
}

定义graph 节点 root, 然后定义常数变量a (shape为2*2), x (shape为2* 1),然后 y = A * x, y2 = y.2, y2_sum = sum(y2), y_norm = sqrt(y2_sum), y_normlized = y ./ y_norm。代码很简洁, 看起来一目了然,
然后是ConcurrentSteps

void ConcurrentSteps(const Options* opts, int session_index) {// Creates a session.SessionOptions options;std::unique_ptr<Session> session(NewSession(options));GraphDef def = CreateGraphDef();if (options.target.empty()) {graph::SetDefaultDevice(opts->use_gpu ? "/device:GPU:0" : "/cpu:0", &def);}TF_CHECK_OK(session->Create(def));// Spawn M threads for M concurrent steps.const int M = opts->num_concurrent_steps;std::unique_ptr<thread::ThreadPool> step_threads(new thread::ThreadPool(Env::Default(), "trainer", M));for (int step = 0; step < M; ++step) {step_threads->Schedule([&session, opts, session_index, step]() {// Randomly initialize the input.Tensor x(DT_FLOAT, TensorShape({2, 1}));auto x_flat = x.flat<float>();x_flat.setRandom();std::cout << "x_flat: " << x_flat << std::endl;Eigen::Tensor<float, 0, Eigen::RowMajor> inv_norm =x_flat.square().sum().sqrt().inverse();x_flat = x_flat * inv_norm();// Iterations.std::vector<Tensor> outputs;for (int iter = 0; iter < opts->num_iterations; ++iter) {outputs.clear();TF_CHECK_OK(session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs));CHECK_EQ(size_t{2}, outputs.size());const Tensor& y = outputs[0];const Tensor& y_norm = outputs[1];// Print out lambda, x, and y.std::printf("%06d/%06d %sn", session_index, step,DebugString(x, y).c_str());// Copies y_normalized to x.x = y_norm;}});}// Delete the threadpool, thus waiting for all threads to complete.step_threads.reset(nullptr);TF_CHECK_OK(session->Close());
}

新建一个session,然后设置10个线程来计算,来执行:

std::vector<Tensor> outputs;for (int iter = 0; iter < opts->num_iterations; ++iter) {outputs.clear();TF_CHECK_OK(session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs));CHECK_EQ(size_t{2}, outputs.size());const Tensor& y = outputs[0];const Tensor& y_norm = outputs[1];// Print out lambda, x, and y.std::printf("%06d/%06d %sn", session_index, step,DebugString(x, y).c_str());// Copies y_normalized to x.x = y_norm;}

每次计算之后,x=y_norm,这里的逻辑其实就是为了计算矩阵A的最大eigenvalue, 重复执行x = y/y_norm; y= A*x;
编译:

bazel build //tensorflow/cc:tutorials_example_trainer

执行结果,前面不用太care是我打印的一些调试输出:

简单的分析

上面简单的c++入门实例之后,可以抽象出TensorFlow的逻辑:

  1. 构造graphdef,使用TensorFlow本身的Graph API,利用算子去构造一个逻辑计算的graph,可以试上述简单地计算eigenvalue,也可以是复杂的卷积网络,这里是涉及到Graph IR的东西,想要了解的话,我建议先看下nnvm和relay,才会有初步的概念;
  2. 用于构造graphdef的各种操作,比如上述将达到的Square、MatMul,这些操作可以是自己写的一些数学操作也可以是TensorFlow本身封装一些数学计算操作,可以是MKL的封装,也可以是cudnn的封装,当然也可以是非数学库,如TFRecord的读取;
  3. Session的构造,新建一个session,然后用于graph外与graph内部的数据交互:session->Run({{"x", x}}, {"y:0", "y_normalized:0"}, {}, &outputs));这里不停地把更新的x王graph里喂来计算y与y_normalized,然后将x更新为y_normalized;

GraphDef这一套,太过复杂,不适合演示如何看TF源码,建议大家先有一定的基础知识之后,再看,这里我们摘出一些算法同学感兴趣的,比如Square这个怎么在TF当中实现以及绑定到对应操作

  1. 代码中直接跳转到Square类,如下图;

2.很明显看到Square类的定义,其构造函数,接收一个scope还有一个input, 然后我们找下具体实现,如下图:

3.同目录下, http://math_ops.cc,看实现逻辑,我们是构造一个名为Square的op,然后往scope里更新,既然如此,肯定是预先有保存名为Square的op,接下来我们看下图:

4.这里讲functor::square注册到"Square"下,且为UnaryOp,这个我不知道怎么解释,相信用过eigen的人都知道,不知道的话去google下,很容易理解,且支持各种数据类型;

5.那么看起来,square的实现就在functor::square,我们再进去看看,集成base模板类,且看起来第二个模板参数为其实现的op,再跳转看看:

 6.最后,我们到达了最终的实现逻辑:operator()和packetOp,也看到了最终的实现,是不是没有想象的那么难。

更重要一点

看完了上面那些,基本上会知道怎么去看TensorFlow的一些基础的代码,如果你了解graph ir这套,可以更深入去理解下,这个过程中,如果对TensorFlow各个文件逻辑感兴趣,不妨去写写测试用例,TensorFlow很多源码文件都有对应的test用例,我们可以通过Build文件来查看,比如我想跑下http://client_session_test.cc这里的测试用例


我们看一下Build文件中

这里表明了对应的编译规则,然后我们只需要

bazel build //tensorflow/cc:client_client_session_test

然后运行相应的测试程序即可

更更重要的一点

上面把如何看TensorFlow代码的小经验教给各位,但是其实这个只是真正的开始,无论TensorFlow、MXNet、PaddlePaddle异或是TVM这些,单纯去看代码,很难理解深刻其中原理,需要去找相关行业的paper,以及找到行业的精英去请教,去学习。目前网上ml system的资料还是蛮多的,有点『乱花迷人眼』的感觉,也没有太多的课程来分享这块的工作,十分期望这些框架的官方分享这些框架的干货,之后我也会在学习中总结一些资料,有机会的话分享给大家。最后,这些东西确实是很复杂,作者在这块也是还是懵懵懂懂,希望能花时间把这些内在的东西搞清楚,真的还蛮有意思的。

也欢迎大家关注我的同名微信公众号 小石头的码疯窝(xiaoshitou_ml_tech),或者通过公众号加我的个人微信进行讨论

tensorflow源码编译教程_极简入门TensorFlow C++源码相关推荐

  1. tensorflow源码编译教程_源码编译安装tensorflow 1.8

    参考官方指南 基本要求 官网测试过的源代码配置如下: image 也就是说,按照这个版本安装的话不应该再产生版本的问题了. 我的配置 ubuntu 16 python 2.7 nccl 2.3 gcc ...

  2. java源码影视源码搭建教程_新版千月影视app源码+搭建教程

    使用notepad++批量替换URL[http://]为你的域名(被替换的域名访问有成品不能发布 需要修改的到前台confing里面查询),替换名称[鲸鹰影视]为你的应用名称: 服务端: 1.将替换好 ...

  3. 安卓系统源码编译系列(一)——下载安卓系统源码教程

    最近需要编译安卓系统,咨询了一个编译过安卓系统的朋友,说是下载源码就得下载两天,于是做好了长期抗战的准备,开始了下载安卓源码的旅程.在刚开始下载时,可以参照的内容只有官方教程,于是跟着官方教程一步一步 ...

  4. tx2+opencv源码编译教程(tx2+opencv4.4.0+opencv_contrib-4.4.0)

    tx2+opencv源码编译教程(tx2+opencv4.4.0+opencv_contrib-4.4.0) 一.卸载TX2上已安装的opencv 打开终端,输入以下指令卸载已经安装的opencv: ...

  5. 超好看的网站极简导航网址网站源码模板

    介绍: 超好看的网站极简导航网址网站源码模板 网盘下载地址: http://kekewl.org/vMD3vuKtwrC 图片:

  6. 最好用的Redis Desktop Manager 0.9.3 版本下载 以及源码编译教程

    文章目录 一.前言 二.编译教程 2.1 [redis destop manager 的源码地址](https://github.com/uglide/RedisDesktopManager) 2.2 ...

  7. tensorflow平台极简方式_TensorFlow极简入门教程

    原标题:TensorFlow极简入门教程 随着 TensorFlow 在研究及产品中的应用日益广泛,很多开发者及研究者都希望能深入学习这一深度学习框架.本文介绍了TensorFlow 基础,包括静态计 ...

  8. 安卓系统源码编译系列(1)——下载安卓系统源码教程

    安卓系统源码编译系列(一)--下载安卓系统源码教程 最近需要编译安卓系统,咨询了一个编译过安卓系统的朋友,说是下载源码就得下载两天,于是做好了长期抗战的准备,开始了下载安卓源码的旅程.在刚开始下载时, ...

  9. excel像素画教程_像素画新手教程:极简像素画角色分析

    摘要:像素画新手教程:极简像素画角色分析 关键词:像素画,新手教程,极简像素画,角色分析 撰文&编辑:三二 教你画像素画首发 | 公众号 pixelart 本文共755个字,阅读大约需要2分钟 ...

最新文章

  1. Ubuntu安装tomcat
  2. JAVA-JSP内置对象之pageContext对象取得不同范围属性
  3. 【我解C语言面试题系列】013 以单词为单位的翻转字符串
  4. 好嗨哟~谷歌量子神经网络新进展揭秘
  5. sql server 2005 在 windows7 报 IIS Feature Requirement 错误。解决办法。
  6. arthas命令整理:基础命令、jvm相关、class相关命令
  7. 支付宝18年账单已出,你消费了多少钱?
  8. 端到端BPM(带有DMN标记)
  9. javax.net.ssl.SSLException: closing inbound before receiving peer‘s close_notif---SpringCloud工作笔记111
  10. RxJava--takeWhile,takeUntil,(附带filter)的特性总结
  11. mvn package时,报错A required class is missing: com/thoughtworks/xstream/io/HierarchicalStreamDriver...
  12. java实训报告_Java实验报告三
  13. 一些常用的IOS开发网站
  14. 华为荣耀路由器虚拟服务器,华为荣耀路由器登录入口设置指南
  15. Linux C 下的socket网络编程
  16. Chrome浏览器安装Axure插件教程
  17. ant design pro中权限组件Authorized的个人学习
  18. UE GamePlay学习笔记
  19. 终端文本编辑神器--Vim命令详解。如何配置使用Vim、Vim插件?
  20. C#时间显示格式(12小时制VS24小时制)

热门文章

  1. 鼠标划过表格行变色效果JS
  2. AutoBench的使用分析
  3. ps cs6 磨皮插件_磨皮就是几秒的事!2020顶级PS一件磨皮插件DR5、Portaiture分享
  4. java documentlistener_java在DocumentListener中更改文档
  5. ML之xgboost :xgboost.plot_importance()函数的解读
  6. WPS:Excel数据表格查询定位技巧之如何设置加重颜色的十字定位(定位数据更加一目了然)
  7. AI开发者大会之计算机视觉技术实践与应用:2020年7月3日《RPA+AI助力政企实现智能时代的人机协同》、《5G风口到来,边缘计算引领数据中心变革》、《数字化时代金融市场与AI算法如何结合?》
  8. AI:2020年6月22日北京智源大会演讲分享之09:40-10:10Mari 教授《基于显式上下文表征的语言处理》、10:10-10:40周明教授《多语言及多模态任务中的预训练模型》
  9. DL之SqueezeNet:SqueezeNet算法的简介(论文介绍)、架构详解、案例应用等配图集合之详细攻略
  10. Py之twisted:Python库之twisted简介、安装、使用方法等详细攻略