batch_sampler是生产随机样本patch的方法,是一种常用的数据增量(DataAugment)策略。具体说来,它从训练数据图像中随机选取一个满足限制条件的区域。这里有两点,一个是随机选取,另一个就是得满足限制条件。限制条件在后文中将做具体说明,下面先看一个AnnotatedDataLayer层中关于batch_sampler参数配置的一个样例。

annotated_data_param {batch_sampler {max_sample: 1max_trials: 1}batch_sampler {sampler {min_scale: 0.3max_scale: 1min_aspect_ratio: 0.5max_aspect_ratio: 2}sample_constraint {min_jaccard_overlap: 0.1 }max_sample: 1 #訪采样器所要生成的patch的数量(这里是1)max_trials: 50 #最多可以尝试max_trails次采样以产生满足限制条件的patch}batch_sampler {sampler {min_scale: 0.3max_scale: 1min_aspect_ratio: 0.5max_aspect_ratio: 2}sample_constraint {min_jaccard_overlap: 0.3}max_sample: 1max_trials: 50}
}

一般会配置不同参数的多个batch_sampler,每个batch_sampler都会对训练图像进行采样,而每个batch_sampler可产生的最大patch数量由max_sample参数决定。像上面的参数配置中max_sample都为1,也就是一个batch_sampler采样得到一个满足限制条件的patch就够了。因为有限制条件的存在,允许你最多尝试max_trials次。当然,要是运气不好的话有可能没有一个满足限制条件的patch。

batch_sampler入口函数为GenerateBatchSamples,它在annotated_data_layer.cpp中的load_batch函数中被调用。load_batch顾名思义就是加载一个batch的训练数据。代码实现上比较直白,一个for循环逐一加载,加满一个batch为止。GenerateBatchSamples就置于for循环中。

for (size_t entry=0; entry<batch_size; ++entry) {// get an anno_datumshared_ptr<AnnotatedDatum> anno_datum = reader->full_pop(qid, "Waiting for data");size_t item_id = anno_datum->record_id() % batch_size; //record_id??data reader 里面设的if (item_id == 0UL) {current_batch_id = anno_datum->record_id() / batch_size;}AnnotatedDatum distort_datum;AnnotatedDatum expand_datum;if (transform_param.has_distort_param()) {distort_datum.CopyFrom(*anno_datum); //来来来,先copy个数据//这个this->bdt(thread_id)是取出thread_id线程中的DataTransformer对象this->bdt(thread_id)->DistortImage(anno_datum->datum(), distort_datum.mutable_datum());if (transform_param.has_expand_param()) {/*expandimage是缩小图片,达到zoom out的效果:先做一个比原图大的画布,然后随机找一个放原图的位置将图片镶嵌进去.既然要做一个大的画布,那必然要做数据的填充,通常是使用均值填充*/this->bdt(thread_id)->ExpandImage(distort_datum, &expand_datum);} else {expand_datum = distort_datum;}} else {if (transform_param.has_expand_param()) {this->bdt(thread_id)->ExpandImage(*anno_datum, &expand_datum);} else {expand_datum = *anno_datum;}}AnnotatedDatum sampled_datum;if (batch_samplers_.size() > 0) {// Generate sampled bboxes from expand_datum.vector<NormalizedBBox> sampled_bboxes;//针对expand_datum的batch sampleGenerateBatchSamples(expand_datum, batch_samplers_, &sampled_bboxes); if (sampled_bboxes.size() > 0) {// Randomly pick a sampled bbox and crop the expand_datum.int rand_idx = caffe_rng_rand() % sampled_bboxes.size();//虽然你可能生成了很多sampled_bboxes,但还是只从中挑出一个this->bdt(thread_id)->CropImage(expand_datum, sampled_bboxes[rand_idx], &sampled_datum);} else {sampled_datum = expand_datum;}} else {sampled_datum = expand_datum;}/*...省略....*/
}

GenerateBatchSamples的第一个传入参数有点意思,代码中是expand_datum,它是怎么来的?请见下面的草稿图!

在训练过程中数据加载程序从数据库(如:lmdb)中加载训练数据保存于anno_datum中,在进行batch_sampler之前根据参数配置可能会进行distort和expand变换。关于distort,可以参见我另一篇文章(数据增量之DistortImage)。 expand image是缩小训练图片以达到zoom out的效果。它先建一个比原图大的画布,然后随机找一个位置将原图镶嵌进去。既然要做一个大的画布,那必然要做数据的填充,通常是使用均值填充的方式。

void GenerateBatchSamples(const AnnotatedDatum& anno_datum,const vector<BatchSampler>& batch_samplers,vector<NormalizedBBox>* sampled_bboxes) {sampled_bboxes->clear();vector<NormalizedBBox> object_bboxes;GroupObjectBBoxes(anno_datum, &object_bboxes); //获取gt box//对于每个采样器生成多个boxfor (int i = 0; i < batch_samplers.size(); ++i) {// Use original image as the source for sampling//见caffe.proto, optional bool use_original_image = 1 [default = true];if (batch_samplers[i].use_original_image()) { NormalizedBBox unit_bbox;unit_bbox.set_xmin(0);unit_bbox.set_ymin(0);unit_bbox.set_xmax(1);unit_bbox.set_ymax(1);GenerateSamples(unit_bbox, object_bboxes, batch_samplers[i],sampled_bboxes);}}
}

在GenerateBatchSamples函数中首先会调用GroupObjectBBoxes来获取当前图像的gt,并存入到object_bboxes中。

void GroupObjectBBoxes(const AnnotatedDatum& anno_datum,vector<NormalizedBBox>* object_bboxes) {object_bboxes->clear();for (int i = 0; i < anno_datum.annotation_group_size(); ++i) {//Each group contains annotation for a particular classconst AnnotationGroup& anno_group = anno_datum.annotation_group(i);for (int j = 0; j < anno_group.annotation_size(); ++j) {const Annotation& anno = anno_group.annotation(j);object_bboxes->push_back(anno.bbox());}}
}

理解这段逻辑可以参考文章:https://blog.csdn.net/hjxu2016/article/details/83900459,作者所画的示意图其实就已经一目了然,我就不展开说了。

之后依次取出设置的每个batch_sampler,并调用GenerateSamples生成sample patch。

void GenerateSamples(const NormalizedBBox& source_bbox,const vector<NormalizedBBox>& object_bboxes,const BatchSampler& batch_sampler,vector<NormalizedBBox>* sampled_bboxes) {int found = 0;//每个采样器最多尝试max_trials次for (int i = 0; i < batch_sampler.max_trials(); ++i) {//max_sample参数控制要生成的满足条件的patch的个数,通常是1,也就是一旦有一个满足条件的就退出if (batch_sampler.has_max_sample() &&found >= batch_sampler.max_sample()) {break; }// Generate sampled_bbox in the normalized space [0, 1].NormalizedBBox sampled_bbox;SampleBBox(batch_sampler.sampler(), &sampled_bbox); //随机生成一个box// Transform the sampled_bbox w.r.t. source_bbox.LocateBBox(source_bbox, sampled_bbox, &sampled_bbox); //转换为在单位box中的坐标// Determine if the sampled bbox is positive or negative by the constraint.//看看是否满足限制条件,所的有gt与生成的box计算IoU,是否满足条件?//只要有一个目标box满足条件返回就是真if (SatisfySampleConstraint(sampled_bbox, object_bboxes,batch_sampler.sample_constraint())) {++found;sampled_bboxes->push_back(sampled_bbox);}}
}

说了这么多,整个batch_sampler中最核心的就是这个SampleBBox函数,因为正是由訪函数来完成实际的sample,随机生成一个patch。上面示例配置中的诸多参数也都在訪函数中得以应用。

void SampleBBox(const Sampler& sampler, NormalizedBBox* sampled_bbox) {// Get random scale.CHECK_GE(sampler.max_scale(), sampler.min_scale()); //max_sclae当然要大于等于min_scaleCHECK_GT(sampler.min_scale(), 0.f);CHECK_LE(sampler.max_scale(), 1.f);float scale;//产生一个介于[min_scale,max_scale]的随机值caffe_rng_uniform(1, sampler.min_scale(), sampler.max_scale(), &scale); // Get random aspect ratio.CHECK_GE(sampler.max_aspect_ratio(), sampler.min_aspect_ratio());CHECK_GT(sampler.min_aspect_ratio(), 0.f);CHECK_LT(sampler.max_aspect_ratio(), FLT_MAX); //这个值是多少float aspect_ratio;//产生一个宽高比的随机值caffe_rng_uniform(1, sampler.min_aspect_ratio(), sampler.max_aspect_ratio(),&aspect_ratio);//这是干啥?aspect_ratio = std::max<float>(aspect_ratio, std::pow(scale, 2.f));aspect_ratio = std::min<float>(aspect_ratio, 1.f / std::pow(scale, 2.f));// 有点类似于ssd中anchor的计算// Figure out bbox dimension.float bbox_width = scale * sqrt(aspect_ratio);float bbox_height = scale / sqrt(aspect_ratio);// Figure out top left coordinates.// 确定好左上角的坐标,咦,是这么个道理float w_off = 0.f, h_off = 0.f;if (bbox_width < 1.f) {caffe_rng_uniform(1, 0.f, 1.f - bbox_width, &w_off);}if (bbox_height < 1.f) {caffe_rng_uniform(1, 0.f, 1.f - bbox_height, &h_off);}sampled_bbox->set_xmin(w_off);sampled_bbox->set_ymin(h_off);sampled_bbox->set_xmax(w_off + bbox_width);sampled_bbox->set_ymax(h_off + bbox_height);
}

参数min_scale和max_scale限定了采样的patch的尺寸是介于[min_scale,max_scale]之间的一个随机值。这里都是用的归一化了的值,也就是相对于原图的比例。而min_aspect_ratio,max_aspect_ratio限定了采样的patch的宽高比是介于[min_aspect_ratio,max_aspect_ratio]之间的一个随机值。代码中分别用变量scale和aspect_ratio表示随机得到的尺寸和宽高比,然后据此计算出采样patch的宽和高。现在相当于patch的形状和大小已经固定好了,接下来就是采样的位置。只要确定了左上角的起始位置即可。当然,这个起始位置也是按一定条件随机得到的。最终将使用(xmin,ymin,xmax,ymax)表示的sample patch存入sample_bbox中。

bool SatisfySampleConstraint(const NormalizedBBox& sampled_bbox,const vector<NormalizedBBox>& object_bboxes,const SampleConstraint& sample_constraint) {bool has_jaccard_overlap = sample_constraint.has_min_jaccard_overlap() ||sample_constraint.has_max_jaccard_overlap();bool has_sample_coverage = sample_constraint.has_min_sample_coverage() ||sample_constraint.has_max_sample_coverage(); //通常没有,是啥?bool has_object_coverage = sample_constraint.has_min_object_coverage() ||sample_constraint.has_max_object_coverage(); //通常没有,是啥?bool satisfy = !has_jaccard_overlap && !has_sample_coverage &&!has_object_coverage;if (satisfy) {// By default, the sampled_bbox is "positive" if no constraints are defined.return true;}// Check constraints.bool found = false;for (int i = 0; i < object_bboxes.size(); ++i) {const NormalizedBBox& object_bbox = object_bboxes[i];// Test jaccard overlap.if (has_jaccard_overlap) {//计算IoUconst float jaccard_overlap = JaccardOverlap(sampled_bbox, object_bbox);if (sample_constraint.has_min_jaccard_overlap() &&jaccard_overlap < sample_constraint.min_jaccard_overlap()) {//小于最小iou条件,跳过continue;}if (sample_constraint.has_max_jaccard_overlap() &&jaccard_overlap > sample_constraint.max_jaccard_overlap()) {//为毛会有最大iou限制条件这一说?为毛要去限制最大iou?continue;}//这里只要有一个目标box与sample box满足IoU条件found就为truefound = true;}// Test sample coverage.if (has_sample_coverage) {const float sample_coverage = BBoxCoverage(sampled_bbox, object_bbox);if (sample_constraint.has_min_sample_coverage() &&sample_coverage < sample_constraint.min_sample_coverage()) {continue;}if (sample_constraint.has_max_sample_coverage() &&sample_coverage > sample_constraint.max_sample_coverage()) {continue;}found = true;}// Test object coverage.if (has_object_coverage) {const float object_coverage = BBoxCoverage(object_bbox, sampled_bbox);if (sample_constraint.has_min_object_coverage() &&object_coverage < sample_constraint.min_object_coverage()) {continue;}if (sample_constraint.has_max_object_coverage() &&object_coverage > sample_constraint.max_object_coverage()) {continue;}found = true;}//有一个found就算是trueif (found) {return true;}}return found;
}

前文中所说的限制条件在函数SatisfySampleConstraint中进行判断。

【NVCaffe源码分析】数据增量之batch_sampler相关推荐

  1. 【VUE】源码分析 - 数据劫持的基本原理

    tips:本系列博客的代码部分(示例等除外),均出自vue源码内容,版本为2.6.14.但是为了增加易读性,会对不相关内容做选择性省略.如果大家想了解完整的源码,建议自行从官方下载.https://g ...

  2. 【NVCaffe源码分析】数据增量之DistortImage

    distort image作为NVCaffe一项常用的数据增量策略,其参数(distort_param)配置大体如下: distort_param {brightness_prob: 0.5brigh ...

  3. 风讯dotNETCMS源码分析—数据存取篇

    前几天突然对CMS感兴趣,就去下载了风讯dotNETCMS源码.当前版本是dotnetcms1.0 sp5免费版,风讯的官方主页上可以下载. 用Visual Studio 2008打开后,初步分析了它 ...

  4. android+小米文件管理器源码,[MediaStore]小米文件管理器android版源码分析——数据来源...

    打开小米的文件管理器,我们很快会看到如下图所示的界面: 其中,会把各种文件分类显示.并且显示出每种文件的个数. 这是怎么做到的呢?当然不是每次启动都查询sdcard和应用程序data目录文件啦,那样实 ...

  5. Nginx源码分析--数据对齐posix_memalign和memalign函数

    posix_memalign函数() /*  * 背景:  *      1)POSIX 1003.1d  *      2)POSIX 标明了通过malloc( ), calloc( ), 和 re ...

  6. galler3d的源码分析——数据来源

    我们这里主要讲本地sd卡的数据,pisaca看情况后续再作分析. 数据操作设计的类包括:CacheService,MediaFeed,LocalDataSource,DiskCache,MediaIt ...

  7. Beanstalk源码分析--数据结构设计

    概述 beanstalk是多年前使用过的一个分布式任务队列,通过C实现,十分高效.和Redis(默认)的事件驱动框架一样,都是通过异步的epoll来实现,所以,能够高效的处理大量请求. 但不知什么原因 ...

  8. ConcurrentHashMap的源码分析-数据迁移阶段的实现分析

    通过分配好迁移的区间之后,开始对数据进行迁移.在看这段代码之前,先来了解一下原理 synchronized (f) {//对数组该节点位置加锁,开始处理数组该位置的迁移工作 if (tabAt(tab ...

  9. Memcached源码分析 - 内存存储机制Slabs(5)

    Memcached源码分析 - 网络模型(1) Memcached源码分析 - 命令解析(2) Memcached源码分析 - 数据存储(3) Memcached源码分析 - 增删改查操作(4) Me ...

最新文章

  1. 2018-4-15摘录笔记,《网络表征学习前沿与实践》 崔鹏以及《网络表征学习中的基本问题初探》 王啸 崔鹏 朱文武
  2. 兼容ie9以下css3,hover和圆角(htc)
  3. java 微信多媒体文件_java微信开发之上传下载多媒体文件
  4. Linux显示5 9行的数据,Linux复习
  5. MVC架构中的Repository模式 个人理解
  6. C#中文件及文件夾的遍历
  7. Codeforces 754E:Dasha and cyclic table
  8. LeetCode 127. 单词接龙(图的BFS/双向BFS)
  9. es 创建索引_es的基本原理和操作文档
  10. 29.3. phpMyAdmin - MySQL web administration tool
  11. 【全套完结】数字电子技术基础——全套实验手册及仿真工艺实习【建议保存】
  12. 本地搭建Redis集群 windows(图文详解)
  13. swftool pdf2swf使用
  14. Qos限速、流量监管、流量整形原理和实验(华为设备)
  15. Windows 10 微信双开或多开【PC端】
  16. [buuctf][Zer0pts2020]easy strcmp
  17. 关于ubuntu自带英文版firefox浏览器,安装evernote剪藏插件总是登录到国际版及firefox插件无法下载
  18. 求根计算机在线,在线一元方程求解计算工具-一元函数自动求解在线计算器
  19. 基于asp.net房屋按揭贷款管理系统
  20. 云计算 python PXE+KS无人值守安装

热门文章

  1. C# 打印文档(word文档)
  2. Linux多功能下载机(Arias2)
  3. 写一个判断素数的函数(isprime),在主函数输入一个正整数,输出是否是素数的信息。提示:int main(){int x=23; if (isprime(x)) print
  4. linux 16.04 密码,诡异的 登录 Linux / Ubuntu 16.04 系统 时, 系统提示 登录密码错误 之谜 !...
  5. CDH6.3.1安装指南
  6. LC 电路串联谐振与并联谐振
  7. KindEditor实现多图片上传
  8. uniapp同意使用,不同意退出APP
  9. Cobalt Strike(三)DNS beacon 的使用与原理
  10. 在浏览器拉起应用的方式