Layer 功能:

是所有的网络层的基类,其中,定义了一些通用的接口,比如前馈,反馈,reshape,setup等。

#ifndef CAFFE_LAYER_H_
#define CAFFE_LAYER_H_#include <algorithm>
#include <string>
#include <vector>#include "caffe/blob.hpp"
#include "caffe/common.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/device_alternate.hpp"namespace caffe {// 功能:所有的网络层的基类,定义的所有的网络层的通用接口。// 前馈接口,必须实现// 反馈接口,需要的时候实现,计算梯度。
template <typename Dtype>
class Layer {public:/*** 每个网络层需要自己定义它的setup而不需要构造函数*/explicit Layer(const LayerParameter& param): layer_param_(param) {//通过网络层参数来构造网络层phase_ = param.phase();if (layer_param_.blobs_size() > 0) {blobs_.resize(layer_param_.blobs_size());for (int i = 0; i < layer_param_.blobs_size(); ++i) {blobs_[i].reset(new Blob<Dtype>());blobs_[i]->FromProto(layer_param_.blobs(i));}}}// 析构函数virtual ~Layer() {}/*** 实现一些通用的设置功能** @param bottom 网络层的输入的shape* @param top 网络层的输出,需要被reshape* 调用 LayerSetUp 来对每一个网络层进行特殊化的处理,* 调用reshape top* 设置 数值权重* 这个方法可以不被重载。*/void SetUp(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top) {CheckBlobCounts(bottom, top);LayerSetUp(bottom, top);Reshape(bottom, top);SetLossWeights(top);}/*** @brief 设置一些层相关的设置,定义的层需要实现这个方法以及Reshape方法*///设置网络层virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top) {}/*** @brief 调整top blob以适应bottom blob。*/virtual void Reshape(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top) = 0;/*** @brief 给定 bottom blobs, 计算 top blobs 以及 loss.* 每一个网络层都需要定义cpu版本的前馈,可选gpu版本的前馈*///前馈inline Dtype Forward(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top);/*** @brief 给定 top blob 的梯度, 计算 bottom blob 梯度.* @param propagate_down 向量,长度为ibottom   的个数,每个索引值表示是是否将损失梯度值反馈到该bottom中*///反馈inline void Backward(const vector<Blob<Dtype>*>& top,const vector<bool>& propagate_down,const vector<Blob<Dtype>*>& bottom);/*** @brief 返回可学习的参数 blobs.*/vector<shared_ptr<Blob<Dtype> > >& blobs() {return blobs_;}/*** @brief 返回网络层参数*/const LayerParameter& layer_param() const { return layer_param_; }//序列化virtual void ToProto(LayerParameter* param, bool write_diff = false);/*** @brief 返回指定索引的标量损失值。*/inline Dtype loss(const int top_index) const {return (loss_.size() > top_index) ? loss_[top_index] : Dtype(0);}/*** @brief 设置网络层制定索引位置的loss*/inline void set_loss(const int top_index, const Dtype value) {if (loss_.size() <= top_index) {loss_.resize(top_index + 1, Dtype(0));}loss_[top_index] = value;}/*** @brief 返回网络层名字,字符串描述u*/virtual inline const char* type() const { return ""; }//Bottom的blob的确切数目virtual inline int ExactNumBottomBlobs() const { return -1; }//Bottom blob的最小数目virtual inline int MinBottomBlobs() const { return -1; }//Botttom的确切数目virtual inline int MaxBottomBlobs() const { return -1; }//Top Blob的确切数目virtual inline int ExactNumTopBlobs() const { return -1; }//最小的blob的数目virtual inline int MinTopBlobs() const { return -1; }// 最大的blob的数目virtual inline int MaxTopBlobs() const { return -1; }// 是否bottom 和top的数目相同virtual inline bool EqualNumBottomTopBlobs() const { return false; }// 是否自动Top blobvirtual inline bool AutoTopBlobs() const { return false; }//查询某一个bottom是否强制bpvirtual inline bool AllowForceBackward(const int bottom_index) const {return true;}//查询某一个blob是否bpinline bool param_propagate_down(const int param_id) {return (param_propagate_down_.size() > param_id) ?param_propagate_down_[param_id] : false;}//设置某一个blob是否bp。inline void set_param_propagate_down(const int param_id, const bool value) {if (param_propagate_down_.size() <= param_id) {param_propagate_down_.resize(param_id + 1, true);}param_propagate_down_[param_id] = value;}protected:// 网络层参数LayerParameter layer_param_;// 模式Phase phase_;//用blob来存储一系列向量vector<shared_ptr<Blob<Dtype> > > blobs_;//是否bp的向量vector<bool> param_propagate_down_;//存储top的lossvector<Dtype> loss_;//cpu版本的前馈实现virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top) = 0;//gpu版本的前馈实现virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top) {// LOG(WARNING) << "Using CPU code as backup.";return Forward_cpu(bottom, top);}//cpu版本的前馈实现virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,const vector<bool>& propagate_down,const vector<Blob<Dtype>*>& bottom) = 0;//gpu版本的反馈实现virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,const vector<bool>& propagate_down,const vector<Blob<Dtype>*>& bottom) {// LOG(WARNING) << "Using CPU code as backup.";Backward_cpu(top, propagate_down, bottom);}// 核查bootom和top的大小是否与该layer层指定的一致。virtual void CheckBlobCounts(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top) {if (ExactNumBottomBlobs() >= 0) {CHECK_EQ(ExactNumBottomBlobs(), bottom.size())<< type() << " Layer takes " << ExactNumBottomBlobs()<< " bottom blob(s) as input.";}if (MinBottomBlobs() >= 0) {CHECK_LE(MinBottomBlobs(), bottom.size())<< type() << " Layer takes at least " << MinBottomBlobs()<< " bottom blob(s) as input.";}if (MaxBottomBlobs() >= 0) {CHECK_GE(MaxBottomBlobs(), bottom.size())<< type() << " Layer takes at most " << MaxBottomBlobs()<< " bottom blob(s) as input.";}if (ExactNumTopBlobs() >= 0) {CHECK_EQ(ExactNumTopBlobs(), top.size())<< type() << " Layer produces " << ExactNumTopBlobs()<< " top blob(s) as output.";}if (MinTopBlobs() >= 0) {CHECK_LE(MinTopBlobs(), top.size())<< type() << " Layer produces at least " << MinTopBlobs()<< " top blob(s) as output.";}if (MaxTopBlobs() >= 0) {CHECK_GE(MaxTopBlobs(), top.size())<< type() << " Layer produces at most " << MaxTopBlobs()<< " top blob(s) as output.";}if (EqualNumBottomTopBlobs()) {CHECK_EQ(bottom.size(), top.size())<< type() << " Layer produces one top blob as output for each "<< "bottom blob input.";}}// 用blob初始化损失权重。inline void SetLossWeights(const vector<Blob<Dtype>*>& top) {const int num_loss_weights = layer_param_.loss_weight_size();if (num_loss_weights) {CHECK_EQ(top.size(), num_loss_weights) << "loss_weight must be ""unspecified or specified once per top blob.";for (int top_id = 0; top_id < top.size(); ++top_id) {const Dtype loss_weight = layer_param_.loss_weight(top_id);if (loss_weight == Dtype(0)) { continue; }this->set_loss(top_id, loss_weight);const int count = top[top_id]->count();Dtype* loss_multiplier = top[top_id]->mutable_cpu_diff();caffe_set(count, loss_weight, loss_multiplier);}}}DISABLE_COPY_AND_ASSIGN(Layer);
};  // class Layer// 前馈,根据caffe的mode 调用相对应的cpu实现或者是gpu实现,并且计算损失函数值。
template <typename Dtype>
inline Dtype Layer<Dtype>::Forward(const vector<Blob<Dtype>*>& bottom,const vector<Blob<Dtype>*>& top) {Dtype loss = 0;Reshape(bottom, top);switch (Caffe::mode()) {case Caffe::CPU:Forward_cpu(bottom, top);for (int top_id = 0; top_id < top.size(); ++top_id) {if (!this->loss(top_id)) { continue; }const int count = top[top_id]->count();const Dtype* data = top[top_id]->cpu_data();const Dtype* loss_weights = top[top_id]->cpu_diff();loss += caffe_cpu_dot(count, data, loss_weights);}break;case Caffe::GPU:Forward_gpu(bottom, top);
#ifndef CPU_ONLYfor (int top_id = 0; top_id < top.size(); ++top_id) {if (!this->loss(top_id)) { continue; }const int count = top[top_id]->count();const Dtype* data = top[top_id]->gpu_data();const Dtype* loss_weights = top[top_id]->gpu_diff();Dtype blob_loss = 0;caffe_gpu_dot(count, data, loss_weights, &blob_loss);loss += blob_loss;}
#endifbreak;default:LOG(FATAL) << "Unknown caffe mode.";}return loss;
}//反向传播梯度,根据Caffe的mode是在GPU还是CPU,调用相对应版本的函数
//propagate_down 用于控制对应的层是否bp
template <typename Dtype>
inline void Layer<Dtype>::Backward(const vector<Blob<Dtype>*>& top,const vector<bool>& propagate_down,const vector<Blob<Dtype>*>& bottom) {switch (Caffe::mode()) {case Caffe::CPU:Backward_cpu(top, propagate_down, bottom);break;case Caffe::GPU:Backward_gpu(top, propagate_down, bottom);break;default:LOG(FATAL) << "Unknown caffe mode.";}
}// 序列化网络层参数到协议缓存,最终是调用blob写入协议缓存。
template <typename Dtype>
void Layer<Dtype>::ToProto(LayerParameter* param, bool write_diff) {param->Clear();param->CopyFrom(layer_param_);param->clear_blobs();for (int i = 0; i < blobs_.size(); ++i) {blobs_[i]->ToProto(param->add_blobs(), write_diff);}
}}  // namespace caffe#endif  // CAFFE_LAYER_H_
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • 64
  • 65
  • 66
  • 67
  • 68
  • 69
  • 70
  • 71
  • 72
  • 73
  • 74
  • 75
  • 76
  • 77
  • 78
  • 79
  • 80
  • 81
  • 82
  • 83
  • 84
  • 85
  • 86
  • 87
  • 88
  • 89
  • 90
  • 91
  • 92
  • 93
  • 94
  • 95
  • 96
  • 97
  • 98
  • 99
  • 100
  • 101
  • 102
  • 103
  • 104
  • 105
  • 106
  • 107
  • 108
  • 109
  • 110
  • 111
  • 112
  • 113
  • 114
  • 115
  • 116
  • 117
  • 118
  • 119
  • 120
  • 121
  • 122
  • 123
  • 124
  • 125
  • 126
  • 127
  • 128
  • 129
  • 130
  • 131
  • 132
  • 133
  • 134
  • 135
  • 136
  • 137
  • 138
  • 139
  • 140
  • 141
  • 142
  • 143
  • 144
  • 145
  • 146
  • 147
  • 148
  • 149
  • 150
  • 151
  • 152
  • 153
  • 154
  • 155
  • 156
  • 157
  • 158
  • 159
  • 160
  • 161
  • 162
  • 163
  • 164
  • 165
  • 166
  • 167
  • 168
  • 169
  • 170
  • 171
  • 172
  • 173
  • 174
  • 175
  • 176
  • 177
  • 178
  • 179
  • 180
  • 181
  • 182
  • 183
  • 184
  • 185
  • 186
  • 187
  • 188
  • 189
  • 190
  • 191
  • 192
  • 193
  • 194
  • 195
  • 196
  • 197
  • 198
  • 199
  • 200
  • 201
  • 202
  • 203
  • 204
  • 205
  • 206
  • 207
  • 208
  • 209
  • 210
  • 211
  • 212
  • 213
  • 214
  • 215
  • 216
  • 217
  • 218
  • 219
  • 220
  • 221
  • 222
  • 223
  • 224
  • 225
  • 226
  • 227
  • 228
  • 229
  • 230
  • 231
  • 232
  • 233
  • 234
  • 235
  • 236
  • 237
  • 238
  • 239
  • 240
  • 241
  • 242
  • 243
  • 244
  • 245
  • 246
  • 247
  • 248
  • 249
  • 250
  • 251
  • 252
  • 253
  • 254
  • 255
  • 256
  • 257
  • 258
  • 259
  • 260
  • 261
  • 262
  • 263
  • 264
  • 265
  • 266
  • 267
  • 268
  • 269
  • 270
  • 271
  • 272
  • 273
  • 274
  • 275
  • 276
  • 277
  • 278
  • 279
  • 280
  • 281
  • 282
  • 283
  • 284
  • 285
  • 286
  • 287
  • 288
  • 289
  • 290
  • 291
  • 292
  • 293
  • 294
  • 295
  • 296
  • 297
  • 298
  • 299
  • 300
  • 301
  • 302
  • 303
  • 304
  • 305
  • 306
  • 307
  • 308
  • 309
  • 310
  • 311
  • 312
  • 313
  • 314

版权声明:本文为博主原创文章,未经博主允许不得转载。

【Caffe代码解析】Layer网络层相关推荐

  1. 梳理caffe代码layer(五)

    Layer(层)是Caffe中最庞大最繁杂的模块.由于Caffe强调模块化设计,因此只允许每个layer完成一类特定的计算,例如convolution操作.pooling.非线性变换.内积运算,以及数 ...

  2. EfficientNet(ICML 2019)原理与代码解析

    paper:EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks code:mmclassification ...

  3. DeepLearning | 图注意力网络Graph Attention Network(GAT)论文、模型、代码解析

    本篇博客是对论文 Velikovi, Petar, Cucurull, Guillem, Casanova, Arantxa,et al. Graph Attention Networks, 2018 ...

  4. Caffe代码导读(5):对数据集进行Testing

    转载自: Caffe代码导读(5):对数据集进行Testing - 卜居 - 博客频道 - CSDN.NET http://blog.csdn.net/kkk584520/article/detail ...

  5. Caffe代码导读(0):路线图

    转载自: Caffe代码导读(0):路线图 - 卜居 - 博客频道 - CSDN.NET http://blog.csdn.net/kkk584520/article/details/41681085 ...

  6. Python编写caffe代码

    有时候,我们需要将网络使用caffe代码实现,人工手写容易出问题.可以使用Python完成网络编写. 卷积层: def generate_conv_layer_no_bias(name, bottom ...

  7. [GCN] 代码解析 of GitHub:Semi-supervised classification with graph convolutional networks

    本文解析的代码是论文Semi-Supervised Classification with Graph Convolutional Networks作者提供的实现代码. 原GitHub:Graph C ...

  8. Caffe编写Python layer

    Caffe编写Python layer 在使用caffe做训练的时候,通常的做法是把数据转为lmdb格式,然后在train.prototxt中指定,最后在开始训练,转为lmdb格式的优点是读取数据更高 ...

  9. Rescue-Prime hash STARK 代码解析

    1. 引言 前序博客有: STARK入门知识 STARKs and STARK VM: Proofs of Computational Integrity STARK中的FRI代码解析 Rescue- ...

  10. Polygon zkEVM的pil-stark Fibonacci状态机代码解析

    1. 引言 前序博客有: Polygon zkEVM的pil-stark Fibonacci状态机初体验 STARKs and STARK VM: Proofs of Computational In ...

最新文章

  1. oracle 主机名改ip,[oracle 10.2]主机名或者IP地址改变造成的dbconsole服务无法启动解决...
  2. jQuery 入门教程(1): 概述
  3. PAT甲级 -- 1148 Werewolf - Simple Version (20 分)
  4. linux cd -目录,linux cd
  5. 论文首页下划线怎么对齐_毕业论文标准格式要求是什么样的?
  6. Redis 5.0.8+常见面试题(单线程还是多线程、先更新缓存还是数据库、雪崩穿透击穿解决办法...)
  7. 金融项目app业务及测试策略
  8. mysql 建表语句
  9. 微软智能云Azure在华新增数据中心区域正式启用
  10. 2018年中山大学计算机考研初试经验贴
  11. 不知道rar压缩包密码可以解密么,rar压缩包有密码怎么解开?
  12. xss.haozi.me靶场详解
  13. 几种常用的软件测试工具
  14. Pr 入门系列之十三:添加字幕
  15. PPP拨号和NDIS拨号的区别:
  16. 筛选鉴定与已经基因启动子相互作用的DNA结合蛋白-DNA Pull Down实验原理,技术流程
  17. 通达信标记符号_通达信指标中赋值符号“:”、“=”、“:=”区别?
  18. 营销学习思维导图模板
  19. GPU内存分明没人占用但是分配不了内存的解决办法
  20. integral函数Opencv源码理解-leetcode动态规划的使用场景

热门文章

  1. 把grid第一列设置为行号
  2. 通过OAuth方式与docker hub v2 API交互
  3. 微信小程序MQTT客户端的坑
  4. CentOS6.8下实现配置配额
  5. 本人新书推荐《linux运维之道》
  6. django视图(views)
  7. BUAA OO 2019 第一单元作业总结
  8. Spring AOP动态代理原理与实现方式
  9. 浅谈C++ Lambda 表达式(简称LB)
  10. Paypal开源nodejs框架研究(二)KrakenJs之Enrouten