论文:http://www.robots.ox.ac.uk/~szheng/papers/CRFasRNN.pdf,
CRF as RNN论文的代码在https://github.com/torrvision/crfasrnn可以找到。
有一个在线的demo可以演示http://www.robots.ox.ac.uk/~szheng/crfasrnndemo

这篇博文主要是记录自己对CRF as RNN中的 MultiStageMeanfieldLayer 的解读。涉及到的文件有multi_stage_meanfield的头文件与实现、meanfield的头文件与实现。

这个代码是基于老版本的caffe,大部分的层的头文件都在vision_layers.hpp中,
对应的位置是class MultiStageMeanfieldLayer 和 class MeanfieldIteration,比较简单,MultiStageMeanfieldLayer才是真正的层,而MeanfieldIteration是一个辅助类,直接看实现。

层运算的入口便是LayerSetUp,前面都是成员变量的初始化,接着是读取spatial.par和bilateral.par。 然后是计算spatial_kernel,直接调用了
compute_spatial_kernel()函数:

template <typename Dtype>
void MultiStageMeanfieldLayer<Dtype>::compute_spatial_kernel(float* const output_kernel) {for (int p = 0; p < num_pixels_; ++p) {output_kernel[2*p] = static_cast<float>(p % width_) / theta_gamma_;output_kernel[2*p + 1] = static_cast<float>(p / width_) / theta_gamma_;}
}

这个功能很简单,就是用一个2倍于像素点个数的矩阵,存储 (列/theta_gamma_,行/theta_gamma_)的kernel.
接下来就是将spatial_lattice_初始化。然后将后面计算需要的一元项先分配内存。由于需要使用多次的meanfield,所以接下来就为每个meanfield进行了一次初始化。就这样,层就可以启动了。

接下来就是Forward_cpu

/*** Performs filter-based mean field inference given the image and unaries.** bottom[0] - Unary terms* bottom[1] - Softmax input/Output from the previous iteration (a copy of the unary terms if this is the first stage).* bottom[2] - RGB images** top[0] - Output of the mean field inference (not normalized).*/
template <typename Dtype>
void MultiStageMeanfieldLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top) {split_layer_bottom_vec_[0] = bottom[0];split_layer_->Forward(split_layer_bottom_vec_, split_layer_top_vec_);// Initialize the bilateral lattices.bilateral_lattices_.resize(num_);for (int n = 0; n < num_; ++n) {compute_bilateral_kernel(bottom[2], n, bilateral_kernel_buffer_.get());bilateral_lattices_[n].reset(new ModifiedPermutohedral());bilateral_lattices_[n]->init(bilateral_kernel_buffer_.get(), 5, num_pixels_);// Calculate bilateral filter normalization factors.Dtype* norm_output_data = bilateral_norms_.mutable_cpu_data() + bilateral_norms_.offset(n);bilateral_lattices_[n]->compute(norm_output_data, norm_feed_.get(), 1);for (int i = 0; i < num_pixels_; ++i) {norm_output_data[i] = 1.f / (norm_output_data[i] + 1e-20f);}}for (int i = 0; i < num_iterations_; ++i) {meanfield_iterations_[i]->PrePass(this->blobs_, &bilateral_lattices_, &bilateral_norms_);meanfield_iterations_[i]->Forward_cpu();}
}

功能就是让前面的多次meanfield每一个跑一次。

下面是Backward_cpu()

/*** Backprop through filter-based mean field inference.*/
template<typename Dtype>
void MultiStageMeanfieldLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,const vector<Blob<Dtype>*>& bottom) {for (int i = (num_iterations_ - 1); i >= 0; --i) {meanfield_iterations_[i]->Backward_cpu();}vector<bool> split_layer_propagate_down(1, true);split_layer_->Backward(split_layer_top_vec_, split_layer_propagate_down, split_layer_bottom_vec_);// Accumulate diffs from mean field iterations.for (int blob_id = 0; blob_id < this->blobs_.size(); ++blob_id) {Blob<Dtype>* cur_blob = this->blobs_[blob_id].get();if (this->param_propagate_down_[blob_id]) {caffe_set(cur_blob->count(), Dtype(0), cur_blob->mutable_cpu_diff());for (int i = 0; i < num_iterations_; ++i) {const Dtype* diffs_to_add = meanfield_iterations_[i]->blobs()[blob_id]->cpu_diff();caffe_axpy(cur_blob->count(), Dtype(1.), diffs_to_add, cur_blob->mutable_cpu_diff());}}}
}

开始就是让每个MeanfieldIteration进行一个Backward_cpu。然后有两个for循环,第一个就是循环所有的blob,第二个就是把每个blob的所有迭代时的diff相加,放到对应blob的diff中。

PS:

  • 有关caffe数学计算的,可以在math_functions中找到,也可以看看http://www.cnblogs.com/jianyingzhou/p/4444728.html。
  • 关于cblas计算的内容,可以参考http://www.math.utah.edu/software/lapack/lapack-blas.html
  • Blob的解读http://www.tuicool.com/articles/6rUVNf2
  • CRF可以参考这篇http://blog.csdn.net/thesby/article/details/50969788

————————————————-我是分割线2016.06.23———————————————————————————–
我把这个版本的caffe已经merge到了最新的官方版caffe,因为它的原始版本实在太老了。下载地址在此.

CRF as RNN 代码解读相关推荐

  1. 【深度学习】深入浅出CRF as RNN(以RNN形式做CRF后处理)

    [深度学习]深入浅出CRF as RNN(以RNN形式做CRF后处理) 文章目录 1 概述 2 目标 3 思路 4 简述 5 论文原文5.1 Introduction5.2 相关工作5.3 关键步骤 ...

  2. 线性条件随机场代码解读

      NER中CRF是必不可少的环节,特地看了一遍CRF相关理论以及allennlp中CRF的代码,特在这里笔记记录下来! 1.线性CRF简介 1.1一般形式   关于线性条件随机场的详细介绍,请参考李 ...

  3. 深度学习(三十三)CRF as RNN语义分割-未完待续

    CRF as RNN语义分割 原文地址:http://blog.csdn.net/hjimce/article/details/50888915 作者:hjimce 一.相关理论 本篇博文主要讲解文献 ...

  4. PredRNN++:网络结构和代码解读

    已经有很多帖子对PredRNN++的理论和改进效果进行了解读,不再赘述.直接分析结构和代码. Causal LSTM 单元 三层级联结构: 第一层(蓝色框)类似传统的LSTM结构用于更新时间状态C(t ...

  5. 元学习之《Matching Networks for One Shot Learning》代码解读

    元学习系列文章 optimization based meta-learning <Model-Agnostic Meta-Learning for Fast Adaptation of Dee ...

  6. Pytorch LSTM 代码解读及自定义双向 LSTM 算子

    Pytorch LSTM 代码解读及自定义双向 LSTM 算子 1. 理论 关于 LSTM 的理论部分可以参考 Paper Long Short-Term Memory Based Recurrent ...

  7. 200行代码解读TDEngine背后的定时器

    作者 | beyondma来源 | CSDN博客 导读:最近几周,本文作者几篇有关陶建辉老师最新的创业项目-TdEngine代码解读文章出人意料地引起了巨大的反响,原以为C语言已经是昨日黄花,不过从读 ...

  8. 装逼一步到位!GauGAN代码解读来了

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:游璐颖,福州大学,Datawhale成员 AI神笔马良 如何装逼一 ...

  9. Unet论文解读代码解读

    论文地址:http://www.arxiv.org/pdf/1505.04597.pdf 论文解读 网络 架构: a.U-net建立在FCN的网络架构上,作者修改并扩大了这个网络框架,使其能够使用很少 ...

最新文章

  1. Amazon Redshift 架构
  2. python split 倒数第一个_请教一个在python中该如何去掉split之后的第一个单词?
  3. primer3批量设计引物
  4. .NET Core实战项目之CMS 第七章 设计篇-用户权限极简设计全过程
  5. Python:绘图保存时出现空白图像的解决和如何保存图片
  6. 奇虎360大战腾讯QQ 高潮迭起用户受伤
  7. Golang——数据类型使用细节详解
  8. [Node.js]操作mysql
  9. 阿里企业邮箱产品优势、功能、版本介绍
  10. amr文件怎么转换成mp3?
  11. 「武汉理工大学 软件工程复习」第三章 | 软件需求
  12. 表格比手机屏幕宽时不压缩,可左右滚动,格子内容不换行
  13. U盘读不出来的解决办法
  14. MongoMongo简介
  15. golang学习笔记之string转换
  16. MATLAB r2014a 下载+安装+激活
  17. 如何看待996现象,996工作模式是种什么样的体验?
  18. css3宽度变大动画_H5 直播的疯狂点赞动画是如何实现的?
  19. tsqlconnection连接datasnap出现connection closed gracefully错误的解决办法
  20. 中国广电即将放号,或代表着中国移动反攻,联通先慌了

热门文章

  1. C/C++小程序学习:n*n魔方矩阵实现每行、每列、每一对角线上的元素之和相等
  2. 工控攻击,黑客组织GhostSec 称入侵以色列55 家Berghof PLC
  3. navicat中导出数据表结构为word格式
  4. mini2440 linux驱动程序,基于linux的mini2440 led驱动及应用程序
  5. win10任务栏固定图标删不掉
  6. 这 7 门 编程语言最适合新手学习
  7. 创建ROS消息(msg)和服务(srv)
  8. rabbitmq细节说明与效率(三)
  9. Spring Cloud--Sleuth+Zipkin 链路跟踪/订单的流量削峰
  10. 实时性是指计算机多媒体系统中声音及活动,《计算机应用基础》电子教案