文章索引

代码:PRIS-CV/PMG-Progressive-Multi-Granularity-Training
论文:Fine-Grained Visual Classification via Progressive Multi-Granularity Training of Jigsaw Patches

说明

Fine-Grained Visual Classification viaProgressive Multi-Granularity Training ofJigsaw Patches代码解读,本文只针对PMG中重点代码记录分析,算是对文章更深层的理解,具体模型可以参考原论文。

正文


首先,作者先把图片进行切割,切割的尺寸大小是渐渐增大的,对应渐进式训练,在论文中阐述的是用不同的step训练,代码如下:

#train.py
# 数据处理代码省略....
# 从训练开始for epoch in range(start_epoch, nb_epoch):print('\nEpoch: %d' % epoch)net.train()# 这里设置了5个loss,分别对应1-3个step的loss和一个concat的losss,最后train_losss表示总losstrain_loss = 0train_loss1 = 0train_loss2 = 0train_loss3 = 0train_loss4 = 0correct = 0total = 0idx = 0for batch_idx, (inputs, targets) in enumerate(trainloader):idx = batch_idxif inputs.shape[0] < batch_size:continueif use_cuda:inputs, targets = inputs.to(device), targets.to(device)inputs, targets = Variable(inputs), Variable(targets)# update learning ratefor nlr in range(len(optimizer.param_groups)):optimizer.param_groups[nlr]['lr'] = cosine_anneal_schedule(epoch, nb_epoch, lr[nlr])# Step 1optimizer.zero_grad()inputs1 = jigsaw_generator(inputs, 8) #这里其实就是论文中说的拼图生成器,输入一张原图,根据第二个参数进行不同程度的切割再拼图output_1, _, _, _ = netp(inputs1) # 调用网络,具体为什么是四个输出,这四个输出有什么作用后面解释loss1 = CELoss(output_1, targets) * 1loss1.backward()optimizer.step()# Step 2optimizer.zero_grad()inputs2 = jigsaw_generator(inputs, 4)_, output_2, _, _ = netp(inputs2)loss2 = CELoss(output_2, targets) * 1loss2.backward()optimizer.step()# Step 3optimizer.zero_grad()inputs3 = jigsaw_generator(inputs, 2)_, _, output_3, _ = netp(inputs3)loss3 = CELoss(output_3, targets) * 1loss3.backward()optimizer.step()# Step 4optimizer.zero_grad()_, _, _, output_concat = netp(inputs)concat_loss = CELoss(output_concat, targets) * 2concat_loss.backward()optimizer.step()#  training log_, predicted = torch.max(output_concat.data, 1)total += targets.size(0)correct += predicted.eq(targets.data).cpu().sum()train_loss += (loss1.item() + loss2.item() + loss3.item() + concat_loss.item())train_loss1 += loss1.item()train_loss2 += loss2.item()train_loss3 += loss3.item()train_loss4 += concat_loss.item()# 下面打印输出省略

很明显,在train.py中,作者通过jigswa_generator()生成不同尺度的拼图(inputs1,inputs2,inputs3,inputs),再分别把这他们输入到网络netp,得到不同尺度的输出,再分别做了一个loss,最后把所有loss加起来做一个loss,这就是这段代码的具体含义,至于这几个inputs输入到网络中发生了什么,就要看看PMG网络模型了:

首先加载模型,可以从下面的代码开出,使用的基础模型是resnet50。这里resnet50会输出不同层的tensor,分别是第一到四层的tensor,分别表示了不同层的特征。再通过调用PMG将net模型传入。

def load_model(model_name, pretrain=True, require_grad=True):print('==> Building model..')if model_name == 'resnet50_pmg':net = resnet50(pretrained=pretrain)for param in net.parameters():param.requires_grad = require_gradnet = PMG(net, 512, 200)return net

首先PMG定义了很多block,直接看forwoard调用过程,先调用model也就是传进来的resnet模型,得到5个不同层的tensor输出,可以发现作者只去了后三层的输出,即xf3, xf4, xf5,丢弃xf1,xf2,具体为什么丢弃文章中没有做解释,猜测是前面层噪声大,加与不加作用不大。然后把这三个tensor进行卷积,得到xl1,xl2,xl3,最后对这三个tensor分别卷积分类。x_concat 就是把这几个tensor结合起来。

class PMG(nn.Module):def __init__(self, model, feature_size, classes_num):super(PMG, self).__init__()self.features = modelself.max1 = nn.MaxPool2d(kernel_size=56, stride=56)self.max2 = nn.MaxPool2d(kernel_size=28, stride=28)self.max3 = nn.MaxPool2d(kernel_size=14, stride=14)self.num_ftrs = 2048 * 1 * 1self.elu = nn.ELU(inplace=True)self.classifier_concat = nn.Sequential(nn.BatchNorm1d(1024 * 3),nn.Linear(1024 * 3, feature_size),nn.BatchNorm1d(feature_size),nn.ELU(inplace=True),nn.Linear(feature_size, classes_num),)self.conv_block1 = nn.Sequential(BasicConv(self.num_ftrs//4, feature_size, kernel_size=1, stride=1, padding=0, relu=True),BasicConv(feature_size, self.num_ftrs//2, kernel_size=3, stride=1, padding=1, relu=True))self.classifier1 = nn.Sequential(nn.BatchNorm1d(self.num_ftrs//2),nn.Linear(self.num_ftrs//2, feature_size),nn.BatchNorm1d(feature_size),nn.ELU(inplace=True),nn.Linear(feature_size, classes_num),)self.conv_block2 = nn.Sequential(BasicConv(self.num_ftrs//2, feature_size, kernel_size=1, stride=1, padding=0, relu=True),BasicConv(feature_size, self.num_ftrs//2, kernel_size=3, stride=1, padding=1, relu=True))self.classifier2 = nn.Sequential(nn.BatchNorm1d(self.num_ftrs//2),nn.Linear(self.num_ftrs//2, feature_size),nn.BatchNorm1d(feature_size),nn.ELU(inplace=True),nn.Linear(feature_size, classes_num),)self.conv_block3 = nn.Sequential(BasicConv(self.num_ftrs, feature_size, kernel_size=1, stride=1, padding=0, relu=True),BasicConv(feature_size, self.num_ftrs//2, kernel_size=3, stride=1, padding=1, relu=True))self.classifier3 = nn.Sequential(nn.BatchNorm1d(self.num_ftrs//2),nn.Linear(self.num_ftrs//2, feature_size),nn.BatchNorm1d(feature_size),nn.ELU(inplace=True),nn.Linear(feature_size, classes_num),)def forward(self, x):#   x = torch.Size([8, 3, 448, 448])# xf1 = torch.Size([8, 64, 112, 112])# xf2 = torch.Size([8, 256, 112, 112])# xf3 = torch.Size([8, 512, 56, 56])# xf4 = torch.Size([8, 1024, 28, 28])# xf5 = torch.Size([8, 2048, 14, 14])xf1, xf2, xf3, xf4, xf5 = self.features(x)xl1 = self.conv_block1(xf3)xl2 = self.conv_block2(xf4)xl3 = self.conv_block3(xf5)xl1 = self.max1(xl1)xl1 = xl1.view(xl1.size(0), -1)xc1 = self.classifier1(xl1)xl2 = self.max2(xl2)xl2 = xl2.view(xl2.size(0), -1)xc2 = self.classifier2(xl2)xl3 = self.max3(xl3)xl3 = xl3.view(xl3.size(0), -1)xc3 = self.classifier3(xl3)x_concat = torch.cat((xl1, xl2, xl3), -1)x_concat = self.classifier_concat(x_concat)return xc1, xc2, xc3, x_concat

所以这里的四个返回值就对应了resnet中不同层的再次卷积得到的结果。

总结:

到这里,整个训练就完成了,最后总结下:所谓的渐进式训练,其实就是把不同层的tensor拿出来,得到分类结果,然后对应train.py中不同的inputs,这几个tensor最后分别看做一个指标,进行训练。其中inputs加了不同的拼图。
和传统分类改进的点:传统分类方法是只把最后一层的输出作为指标进行分类,PMG把中间几层也加入到指标当中,并且加入了不同尺度的拼图,可以让网络更关注细节特征,用一句话说,增加了网络的容错率吧。

Fine-Grained Visual Classification via Progressive Multi-Granularity Training of Jigsaw Patches相关推荐

  1. A Novel Plug-in Module for Fine-Grained Visual Classification学习

    A Novel Plug-in Module for Fine-Grained Visual Classification Po-Yung Chou, Cheng-Hung Lin Member , ...

  2. Context-aware Attentional Pooling (CAP) for Fine-grained Visual Classification

    Context-aware Attentional Pooling (CAP) for Fine-grained Visual Classification 用于细粒度视觉分类的上下文感知注意力池化 ...

  3. WS-DAN:Weakly Supervised Data Augmentation Netowrk for Fine-Grained Visual Classification

    See Better Before Looking Closer: Weakly Supervised Data Augmentation Netowrk for Fine-Grained Visua ...

  4. 论文笔记:See Better Before Looking Closer: WS-DAN for Fine-Grained Visual Classification

    文章目录 0 摘要 1 引言 2 相关工作 3 方法 3.1 弱监督注意力学习 3.2 注意力导向的数据增强 3.3 目标定位和细化 4 实验 4.1 数据集 4.2 实现细节 4.3 精度分布 4. ...

  5. CVPR2018论文笔记: Robust Physical-World Attacks on Deep Learning Visual Classification

    论文百篇计划第二篇,cvpr2018的一篇文章,引用量1800.作者来自密歇根大学安娜堡分校. 最近的研究表明目前DNN容易收到对抗样本的攻击,理解物理世界中的对抗样本对发展弹性学习算法非常重要.我们 ...

  6. 图像处理-State of the Art

    https://github.com/BlinkDL/BlinkDL.github.io 目前常见图像任务的 State-of-the-Art 方法,从 Super-resolution 到 Capt ...

  7. 【AutoAugment】《AutoAugment:Learning Augmentation Policies from Data》

    arXiv-2018 文章目录 1 Background and Motivation 2 Related Work 3 Advantages / Contributions 4 Method 5 E ...

  8. AutoAugment: Learning Augmentation Policies from Data(一种自动数据增强技术)

    谷歌大脑提出自动数据增强方法AutoAugment:可迁移至不同数据集 近日,来自谷歌大脑的研究者在 arXiv 上发表论文,提出一种自动搜索合适数据增强策略的方法 AutoAugment,该方法创建 ...

  9. Fine-grained Classification 论文调研

    目录 细粒度分类综述 论文一 Learning to Navigate for Fine-grained Classification (ECCV2018 from PKU) 1. Abstract ...

最新文章

  1. 关于双目立体视觉的三大基本算法及发展现状的总结
  2. java数组语法_Java 基本语法----数组
  3. 12306 网站的非技术分析
  4. python3精要(40)-数组与矩阵
  5. c语言位运算负数的实例_0基础学习C语言第三章:位运算
  6. 小程序modal控件(显示为弹框) 可有输入框
  7. java函数式 new_Java函数式编程-4.lambda表达式一些高级用法
  8. Python的基本数据类型(1)
  9. SE_01 需求分析
  10. 计算机考试PPT2003好考吗,2014年职称计算机考试PowerPoint2003基本操作试题
  11. JAVA系列-设计模式-中介者模式
  12. Windows下LATEX排版论文攻略—CTeX、JabRef使用心得, 包括 IEEEtran.bst
  13. 用matlab画脑图,思维导图怎么画,画出一副好看的流程图方法是什么
  14. 红米7 自编译不完美 twrp 可root手机
  15. UnsupportedOperationException:setProperty must be overridden by all subclasses of SOAPMessage解决方法有效
  16. s+清辅音,读作对应的浊辅音
  17. Centos7 安装Graylog 5.0收集网络设备运行日志+卸载GrayLog服务
  18. ServU 教程11.1.0.7使用教程
  19. 网件 R6400 TTL 救砖详细 教程
  20. GCSE英语语言考试-角色定位

热门文章

  1. pandas:世界各国GDP数据集数据清洗案例
  2. FFmpeg基本使用
  3. Java后端对接微信支付(微信小程序、APP、PC端扫码)非常全,包含查单、退款
  4. 城市各种服务设施半径
  5. UNIX经典命令详解
  6. 自媒体人必备神器,200w+自媒体人都在用
  7. win10一按右键就闪屏_六种方法教你如何解决win10笔记本屏幕闪烁问题?
  8. 以太坊源码分析-同步之Syncing接口
  9. 宏碁暗影骑士设置u盘启动教程
  10. BC v1.2充电规范