Reference: https://github.com/WenmuZhou/DBNet.pytorch

Below is `models/model.py` from zhoujun's DBNet.pytorch.

As you can see, three `from ... import` statements bring in the backbone, the neck, and the head respectively.

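The code listed here is the PyTorch implementation; it can be studied side by side with the implementation in PaddleOCR. OpenCV 4.5 and later versions have also added support for DBNet.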

Reference (the official implementation by the paper's authors): https://github.com/MhLiao/DB/tree/4ac194d0357fd102ac871e37986cb8027ecf094e

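The most important element of the DBNet paper, the approximate step function (differentiable binarization), is implemented in DBHead: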

    def step_function(self, x, y):
        # approximate step function: 1 / (1 + exp(-k * (x - y)))
        return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))

It is called in `forward`:

    def forward(self, x):
        shrink_maps = self.binarize(x)    # probability (shrink) map P
        threshold_maps = self.thresh(x)   # threshold map T
        if self.training:
            # the approximate binary map is only needed for the training loss
            binary_maps = self.step_function(shrink_maps, threshold_maps)
            y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
        else:
            y = torch.cat((shrink_maps, threshold_maps), dim=1)
        return y

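Compare this with the approximate step function in the paper, where $P$ is the probability (shrink) map, $T$ is the learned threshold map, and $k$ is the amplifying factor (the paper sets it to 50, which is also the default `k` here):

$$\hat{B}_{i,j} = \frac{1}{1 + e^{-k\,(P_{i,j} - T_{i,j})}}$$

In the code, `x` plays the role of $P$ and `y` of $T$, so `torch.reciprocal(1 + torch.exp(-self.k * (x - y)))` is exactly this expression, i.e. `torch.sigmoid(self.k * (x - y))`.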
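To see what the amplifying factor does, here is a minimal pure-Python sketch (no PyTorch) that evaluates the same formula on scalars and compares it with a hard step. The names `approx_step`, `hard_step` and `approx_step_grad` are hypothetical and not part of DBNet.pytorch; the threshold 0.3 in the demo is an arbitrary example value.

```python
import math


def approx_step(p, t, k=50.0):
    """Differentiable binarization: 1 / (1 + exp(-k * (p - t))), i.e. sigmoid(k * (p - t))."""
    return 1.0 / (1.0 + math.exp(-k * (p - t)))


def hard_step(p, t):
    """Standard binarization: 1 where the probability reaches the threshold, else 0."""
    return 1.0 if p >= t else 0.0


def approx_step_grad(p, t, k=50.0):
    """Derivative of approx_step w.r.t. p: k * s * (1 - s), largest (k / 4) when p == t."""
    s = approx_step(p, t, k)
    return k * s * (1.0 - s)


if __name__ == "__main__":
    t = 0.3  # example threshold value, chosen only for illustration
    for p in (0.1, 0.25, 0.3, 0.35, 0.5):
        print(f"p={p:.2f}  hard={hard_step(p, t):.0f}  "
              f"approx={approx_step(p, t):.4f}  grad={approx_step_grad(p, t):.4f}")
```

With k = 50 the curve is already close to a hard step, yet its derivative stays finite and peaks at k/4 where P = T, which is what lets the binary map take part in back-propagation during training.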

Apart from this, the other key parts of the paper are mainly in the post-processing.

1. The overall model: model.py

In the `Model` class:

  1. `__init__` builds the backbone, the neck, and the head.
  2. `forward` chains the three parts together.
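The assembly pattern in `__init__` can be sketched without PyTorch. In the minimal sketch below, the `Stub*` classes, the `BACKBONES`/`NECKS`/`HEADS` registries and `build_model` are hypothetical illustrations, not the repo's real `build_backbone`/`build_neck`/`build_head` (whose bodies are not shown in this post). They only show how the `type` key selects a class and how channel counts flow from backbone to neck to head.

```python
# Hypothetical stand-ins for the real backbone / neck / head classes:
# they only record channel counts, which is what Model.__init__ wires together.
class StubBackbone:
    def __init__(self, in_channels=3, pretrained=False):
        self.in_channels = in_channels
        self.out_channels = [64, 128, 256, 512]  # resnet18-style C2..C5 channels


class StubFPN:
    def __init__(self, in_channels, inner_channels=256):
        self.in_channels = in_channels
        self.out_channels = inner_channels


class StubDBHead:
    def __init__(self, in_channels, out_channels=2, k=50):
        self.in_channels = in_channels
        self.k = k


# Hypothetical registries playing the role of build_backbone / build_neck / build_head
BACKBONES = {"resnet18": StubBackbone}
NECKS = {"FPN": StubFPN}
HEADS = {"DBHead": StubDBHead}


def build(registry, type_name, **kwargs):
    if type_name not in registry:
        raise KeyError(f"unknown type {type_name!r}, expected one of {sorted(registry)}")
    return registry[type_name](**kwargs)


def build_model(model_config):
    # copy each sub-config so that popping 'type' does not modify the caller's dict
    cfg = {part: dict(model_config[part]) for part in ("backbone", "neck", "head")}
    backbone_type = cfg["backbone"].pop("type")
    neck_type = cfg["neck"].pop("type")
    head_type = cfg["head"].pop("type")
    backbone = build(BACKBONES, backbone_type, **cfg["backbone"])
    # the neck receives the backbone's output channels, the head receives the neck's
    neck = build(NECKS, neck_type, in_channels=backbone.out_channels, **cfg["neck"])
    head = build(HEADS, head_type, in_channels=neck.out_channels, **cfg["head"])
    name = f"{backbone_type}_{neck_type}_{head_type}"
    return backbone, neck, head, name


if __name__ == "__main__":
    config = {
        "backbone": {"type": "resnet18", "in_channels": 3},
        "neck": {"type": "FPN", "inner_channels": 256},
        "head": {"type": "DBHead", "out_channels": 2, "k": 50},
    }
    _, neck, head, name = build_model(config)
    print(name, neck.in_channels, head.in_channels)
```

The same `name` format (`backbone_neck_head`) is what `Model` stores in `self.name`.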

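As `forward` shows, the output has the same height and width as the input. The head's two stride-2 transposed convolutions already bring the stride-4 FPN feature map back to roughly the input size, and the final bilinear `F.interpolate` to `(H, W)` makes the sizes match exactly.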

    # -*- coding: utf-8 -*-
    # @Time    : 2019/8/23 21:57
    # @Author  : zhoujun
    from addict import Dict
    from torch import nn
    import torch.nn.functional as F

    from models.backbone import build_backbone
    from models.neck import build_neck
    from models.head import build_head


    class Model(nn.Module):
        def __init__(self, model_config: dict):
            """
            PANnet
            :param model_config: model configuration
            """
            super().__init__()
            model_config = Dict(model_config)
            backbone_type = model_config.backbone.pop('type')
            neck_type = model_config.neck.pop('type')
            head_type = model_config.head.pop('type')
            self.backbone = build_backbone(backbone_type, **model_config.backbone)
            self.neck = build_neck(neck_type, in_channels=self.backbone.out_channels, **model_config.neck)
            self.head = build_head(head_type, in_channels=self.neck.out_channels, **model_config.head)
            self.name = f'{backbone_type}_{neck_type}_{head_type}'

        def forward(self, x):
            _, _, H, W = x.size()
            backbone_out = self.backbone(x)
            neck_out = self.neck(backbone_out)
            y = self.head(neck_out)
            # resize the head output back to exactly the input size
            y = F.interpolate(y, size=(H, W), mode='bilinear', align_corners=True)
            return y


    if __name__ == '__main__':
        import torch

        device = torch.device('cpu')
        x = torch.zeros(2, 3, 640, 640).to(device)

        model_config = {
            'backbone': {'type': 'resnest50', 'pretrained': True, "in_channels": 3},
            'neck': {'type': 'FPN', 'inner_channels': 256},  # neck: FPN or FPEM_FFM
            'head': {'type': 'DBHead', 'out_channels': 2, 'k': 50},
        }
        model = Model(model_config=model_config).to(device)
        import time

        tic = time.time()
        y = model(x)
        print(time.time() - tic)
        print(y.shape)
        print(model.name)
        print(model)
        # torch.save(model.state_dict(), 'PAN.pth')
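To check the claim that input and output sizes match, here is a pure-Python sketch that traces one spatial dimension through the network using the standard output-size formulas of `Conv2d`/`MaxPool2d` and `ConvTranspose2d`. The helper names are hypothetical, and the trace assumes that the FPN's 1x1 `ConvBnRelu` reduce layers (from `models.basic`, not shown in this post) keep the spatial size.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of nn.Conv2d / nn.MaxPool2d along one axis (dilation=1, floor rounding)."""
    return (size + 2 * padding - kernel) // stride + 1


def deconv_out(size, kernel=2, stride=2, padding=0):
    """Output length of nn.ConvTranspose2d along one axis (output_padding=0, dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel


def trace_sizes(h):
    """Trace one spatial dimension through the ResNet stem/stages and the DBHead upsampling."""
    s = conv_out(h, 7, 2, 3)    # conv1: 7x7, stride 2, padding 3
    c2 = conv_out(s, 3, 2, 1)   # maxpool: 3x3, stride 2, padding 1; layer1 keeps this size
    c3 = conv_out(c2, 3, 2, 1)  # layer2: the first block has a stride-2 3x3 conv
    c4 = conv_out(c3, 3, 2, 1)  # layer3
    c5 = conv_out(c4, 3, 2, 1)  # layer4
    # the FPN resizes every level to C2's size; DBHead then applies two 2x2 / stride-2 transposed convs
    head = deconv_out(deconv_out(c2))
    return [c2, c3, c4, c5], head


if __name__ == "__main__":
    for h in (640, 630):
        stages, head = trace_sizes(h)
        print(h, stages, head)
```

By this trace, a 640-pixel side gives C2..C5 sizes of 160, 80, 40 and 20 and a head output of 640, so the final interpolation does not change the size; a 630-pixel side comes out of the head as 632, and `F.interpolate` is what resizes it back to 630.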

2. Backbone, using ResNet as the example

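The code below closely follows the official torchvision style; it was most likely adapted from the official implementation, with deformable convolution added. In practice, though, the plain-convolution variants are the ones we usually use.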

The official ResNet implementation may differ slightly between PyTorch versions.

    import math

    import torch.nn as nn
    import torch.utils.model_zoo as model_zoo

    BatchNorm2d = nn.BatchNorm2d

    __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'deformable_resnet18', 'deformable_resnet50',
               'resnet152']

    model_urls = {
        'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
        'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
        'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
        'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    }


    def constant_init(module, constant, bias=0):
        nn.init.constant_(module.weight, constant)
        # a conv built with bias=False still has a `bias` attribute (set to None), so check the value
        if getattr(module, 'bias', None) is not None:
            nn.init.constant_(module.bias, bias)


    def conv3x3(in_planes, out_planes, stride=1):
        """3x3 convolution with padding"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)


    class BasicBlock(nn.Module):
        expansion = 1

        def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
            super(BasicBlock, self).__init__()
            self.with_dcn = dcn is not None
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = BatchNorm2d(planes)
            self.relu = nn.ReLU(inplace=True)
            self.with_modulated_dcn = False
            if not self.with_dcn:
                self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
            else:
                from torchvision.ops import DeformConv2d
                deformable_groups = dcn.get('deformable_groups', 1)
                offset_channels = 18  # 2 (x and y offsets) * 3 * 3 kernel positions
                self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, kernel_size=3, padding=1)
                self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, bias=False)
            self.bn2 = BatchNorm2d(planes)
            self.downsample = downsample
            self.stride = stride

        def forward(self, x):
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            # out = self.conv2(out)
            if not self.with_dcn:
                out = self.conv2(out)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.bn2(out)

            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual
            out = self.relu(out)

            return out


    class Bottleneck(nn.Module):
        expansion = 4

        def __init__(self, inplanes, planes, stride=1, downsample=None, dcn=None):
            super(Bottleneck, self).__init__()
            self.with_dcn = dcn is not None
            self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
            self.bn1 = BatchNorm2d(planes)
            self.with_modulated_dcn = False
            if not self.with_dcn:
                self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
            else:
                deformable_groups = dcn.get('deformable_groups', 1)
                from torchvision.ops import DeformConv2d
                offset_channels = 18
                self.conv2_offset = nn.Conv2d(planes, deformable_groups * offset_channels, stride=stride,
                                              kernel_size=3, padding=1)
                self.conv2 = DeformConv2d(planes, planes, kernel_size=3, padding=1, stride=stride, bias=False)
            self.bn2 = BatchNorm2d(planes)
            self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
            self.bn3 = BatchNorm2d(planes * 4)
            self.relu = nn.ReLU(inplace=True)
            self.downsample = downsample
            self.stride = stride
            self.dcn = dcn
            self.with_dcn = dcn is not None

        def forward(self, x):
            residual = x

            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)

            # out = self.conv2(out)
            if not self.with_dcn:
                out = self.conv2(out)
            else:
                offset = self.conv2_offset(out)
                out = self.conv2(out, offset)
            out = self.bn2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.bn3(out)

            if self.downsample is not None:
                residual = self.downsample(x)

            out += residual
            out = self.relu(out)

            return out


    class ResNet(nn.Module):
        def __init__(self, block, layers, in_channels=3, dcn=None):
            self.dcn = dcn
            self.inplanes = 64
            super(ResNet, self).__init__()
            self.out_channels = []
            self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
            self.bn1 = BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dcn=dcn)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dcn=dcn)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dcn=dcn)

            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                    m.weight.data.normal_(0, math.sqrt(2. / n))
                elif isinstance(m, BatchNorm2d):
                    m.weight.data.fill_(1)
                    m.bias.data.zero_()
            if self.dcn is not None:
                # zero offsets: each deformable conv starts out behaving like a regular conv
                for m in self.modules():
                    if isinstance(m, Bottleneck) or isinstance(m, BasicBlock):
                        if hasattr(m, 'conv2_offset'):
                            constant_init(m.conv2_offset, 0)

        def _make_layer(self, block, planes, blocks, stride=1, dcn=None):
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                    BatchNorm2d(planes * block.expansion),
                )

            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample, dcn=dcn))
            self.inplanes = planes * block.expansion
            for i in range(1, blocks):
                layers.append(block(self.inplanes, planes, dcn=dcn))
            self.out_channels.append(planes * block.expansion)
            return nn.Sequential(*layers)

        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)

            x2 = self.layer1(x)   # 1/4 of the input resolution
            x3 = self.layer2(x2)  # 1/8
            x4 = self.layer3(x3)  # 1/16
            x5 = self.layer4(x4)  # 1/32

            return x2, x3, x4, x5


    def resnet18(pretrained=True, **kwargs):
        """Constructs a ResNet-18 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
        if pretrained:
            assert kwargs.get('in_channels', 3) == 3, 'in_channels must be 3 when pretrained is True'
            print('load from imagenet')
            model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
        return model


    def deformable_resnet18(pretrained=True, **kwargs):
        """Constructs a ResNet-18 model with deformable conv.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(BasicBlock, [2, 2, 2, 2], dcn=dict(deformable_groups=1), **kwargs)
        if pretrained:
            assert kwargs.get('in_channels', 3) == 3, 'in_channels must be 3 when pretrained is True'
            print('load from imagenet')
            model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
        return model


    def resnet34(pretrained=True, **kwargs):
        """Constructs a ResNet-34 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
        if pretrained:
            assert kwargs.get('in_channels', 3) == 3, 'in_channels must be 3 when pretrained is True'
            model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
        return model


    def resnet50(pretrained=True, **kwargs):
        """Constructs a ResNet-50 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
        if pretrained:
            assert kwargs.get('in_channels', 3) == 3, 'in_channels must be 3 when pretrained is True'
            model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
        return model


    def deformable_resnet50(pretrained=True, **kwargs):
        """Constructs a ResNet-50 model with deformable conv.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(Bottleneck, [3, 4, 6, 3], dcn=dict(deformable_groups=1), **kwargs)
        if pretrained:
            assert kwargs.get('in_channels', 3) == 3, 'in_channels must be 3 when pretrained is True'
            model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
        return model


    def resnet101(pretrained=True, **kwargs):
        """Constructs a ResNet-101 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
        if pretrained:
            assert kwargs.get('in_channels', 3) == 3, 'in_channels must be 3 when pretrained is True'
            model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
        return model


    def resnet152(pretrained=True, **kwargs):
        """Constructs a ResNet-152 model.
        Args:
            pretrained (bool): If True, returns a model pre-trained on ImageNet
        """
        model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
        if pretrained:
            assert kwargs.get('in_channels', 3) == 3, 'in_channels must be 3 when pretrained is True'
            model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
        return model


    if __name__ == '__main__':
        import torch

        x = torch.zeros(2, 3, 640, 640)
        net = deformable_resnet50(pretrained=False)
        y = net(x)
        for u in y:
            print(u.shape)

        print(net.out_channels)
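The `out_channels` list that the neck receives is filled in by `_make_layer`. The pure-Python sketch below (hypothetical helper names, no PyTorch needed) replays that bookkeeping to show the C2..C5 channel counts of a BasicBlock network (expansion 1) versus a Bottleneck network (expansion 4), when a shortcut gets a downsample branch, and where the hard-coded `offset_channels = 18` comes from.

```python
def plan_stages(expansion, layers):
    """Replay ResNet._make_layer bookkeeping: per stage, input/output channels, stride,
    and whether a 1x1 downsample branch is created for the shortcut."""
    inplanes = 64  # channels after the 7x7 stem conv
    plan = []
    for planes, blocks, stride in zip((64, 128, 256, 512), layers, (1, 2, 2, 2)):
        out = planes * expansion
        needs_downsample = stride != 1 or inplanes != out
        plan.append({"in": inplanes, "out": out, "blocks": blocks,
                     "stride": stride, "downsample": needs_downsample})
        inplanes = out
    return plan


def offset_channels(kernel_size=3, deformable_groups=1):
    """Channels predicted by conv2_offset: an (x, y) offset per kernel position per group."""
    return 2 * kernel_size * kernel_size * deformable_groups


if __name__ == "__main__":
    for name, expansion, layers in (("resnet18", 1, [2, 2, 2, 2]),
                                    ("resnet50", 4, [3, 4, 6, 3])):
        plan = plan_stages(expansion, layers)
        print(name, [s["out"] for s in plan], [s["downsample"] for s in plan])
    print(offset_channels())
```

For ResNet-50 even `layer1` gets a downsample branch: its stride is 1, but the channel count changes from 64 to 256, so the identity shortcut cannot be added directly.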

3. Neck, using FPN as the example

    # -*- coding: utf-8 -*-
    # @Time    : 2019/9/13 10:29
    # @Author  : zhoujun
    import torch
    import torch.nn.functional as F
    from torch import nn

    from models.basic import ConvBnRelu


    class FPN(nn.Module):
        def __init__(self, in_channels, inner_channels=256, **kwargs):
            """
            :param in_channels: channel counts of the backbone outputs (C2..C5)
            :param kwargs:
            """
            super().__init__()
            inplace = True
            self.conv_out = inner_channels
            inner_channels = inner_channels // 4
            # reduce layers
            self.reduce_conv_c2 = ConvBnRelu(in_channels[0], inner_channels, kernel_size=1, inplace=inplace)
            self.reduce_conv_c3 = ConvBnRelu(in_channels[1], inner_channels, kernel_size=1, inplace=inplace)
            self.reduce_conv_c4 = ConvBnRelu(in_channels[2], inner_channels, kernel_size=1, inplace=inplace)
            self.reduce_conv_c5 = ConvBnRelu(in_channels[3], inner_channels, kernel_size=1, inplace=inplace)
            # Smooth layers
            self.smooth_p4 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
            self.smooth_p3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)
            self.smooth_p2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, padding=1, inplace=inplace)

            self.conv = nn.Sequential(
                nn.Conv2d(self.conv_out, self.conv_out, kernel_size=3, padding=1, stride=1),
                nn.BatchNorm2d(self.conv_out),
                nn.ReLU(inplace=inplace)
            )
            self.out_channels = self.conv_out

        def forward(self, x):
            c2, c3, c4, c5 = x
            # Top-down
            p5 = self.reduce_conv_c5(c5)
            p4 = self._upsample_add(p5, self.reduce_conv_c4(c4))
            p4 = self.smooth_p4(p4)
            p3 = self._upsample_add(p4, self.reduce_conv_c3(c3))
            p3 = self.smooth_p3(p3)
            p2 = self._upsample_add(p3, self.reduce_conv_c2(c2))
            p2 = self.smooth_p2(p2)

            x = self._upsample_cat(p2, p3, p4, p5)
            x = self.conv(x)
            return x

        def _upsample_add(self, x, y):
            # nearest-neighbour resize of x to y's spatial size, then element-wise add
            return F.interpolate(x, size=y.size()[2:]) + y

        def _upsample_cat(self, p2, p3, p4, p5):
            # bring every level to p2's (1/4) resolution and stack along the channel axis
            h, w = p2.size()[2:]
            p3 = F.interpolate(p3, size=(h, w))
            p4 = F.interpolate(p4, size=(h, w))
            p5 = F.interpolate(p5, size=(h, w))
            return torch.cat([p2, p3, p4, p5], dim=1)
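The two helper methods and the channel arithmetic of the FPN can be illustrated on tiny single-channel 2-D lists. This is a pure-Python sketch with hypothetical function names; `nearest_resize` uses the index rule floor(dst * in_size / out_size), which is the rule behind `F.interpolate`'s default `'nearest'` mode (up to floating-point rounding in PyTorch's own computation).

```python
def nearest_resize(grid, out_h, out_w):
    """Nearest-neighbour resize of a 2-D list: source index = floor(dst * in_size / out_size)."""
    in_h, in_w = len(grid), len(grid[0])
    return [[grid[min(i * in_h // out_h, in_h - 1)][min(j * in_w // out_w, in_w - 1)]
             for j in range(out_w)]
            for i in range(out_h)]


def upsample_add(x, y):
    """Single-channel counterpart of FPN._upsample_add: resize x to y's size, then add."""
    up = nearest_resize(x, len(y), len(y[0]))
    return [[a + b for a, b in zip(row_up, row_y)] for row_up, row_y in zip(up, y)]


def fpn_channels(inner_channels=256, levels=4):
    """Per-level channels after the 1x1 reduce convs, and channels after concatenation."""
    per_level = inner_channels // 4
    concat = per_level * levels
    return per_level, concat


if __name__ == "__main__":
    coarse = [[1, 2], [3, 4]]              # e.g. a 2x2 p5
    fine = [[10] * 4 for _ in range(4)]    # e.g. a 4x4 reduced c4
    for row in upsample_add(coarse, fine):
        print(row)
    print(fpn_channels(256), fpn_channels(250))
```

With the default `inner_channels=256`, each level is reduced to 64 channels and the four levels concatenate back to 256, matching the input channels of `self.conv`. A value such as 250 would give 4 x 62 = 248 channels while `self.conv` expects 250, so `inner_channels` should be a multiple of 4.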

4. DBHead

    # -*- coding: utf-8 -*-
    # @Time    : 2019/12/4 14:54
    # @Author  : zhoujun
    import torch
    from torch import nn


    class DBHead(nn.Module):
        def __init__(self, in_channels, out_channels, k=50):
            # note: out_channels is not used below; the number of output maps depends on self.training
            super().__init__()
            self.k = k
            self.binarize = nn.Sequential(
                nn.Conv2d(in_channels, in_channels // 4, 3, padding=1),
                nn.BatchNorm2d(in_channels // 4),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(in_channels // 4, in_channels // 4, 2, 2),
                nn.BatchNorm2d(in_channels // 4),
                nn.ReLU(inplace=True),
                nn.ConvTranspose2d(in_channels // 4, 1, 2, 2),
                nn.Sigmoid())
            self.binarize.apply(self.weights_init)

            self.thresh = self._init_thresh(in_channels)
            self.thresh.apply(self.weights_init)

        def forward(self, x):
            shrink_maps = self.binarize(x)
            threshold_maps = self.thresh(x)
            if self.training:
                binary_maps = self.step_function(shrink_maps, threshold_maps)
                y = torch.cat((shrink_maps, threshold_maps, binary_maps), dim=1)
            else:
                y = torch.cat((shrink_maps, threshold_maps), dim=1)
            return y

        def weights_init(self, m):
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.kaiming_normal_(m.weight.data)
            elif classname.find('BatchNorm') != -1:
                m.weight.data.fill_(1.)
                m.bias.data.fill_(1e-4)

        def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
            in_channels = inner_channels
            if serial:
                in_channels += 1
            self.thresh = nn.Sequential(
                nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
                nn.BatchNorm2d(inner_channels // 4),
                nn.ReLU(inplace=True),
                self._init_upsample(inner_channels // 4, inner_channels // 4, smooth=smooth, bias=bias),
                nn.BatchNorm2d(inner_channels // 4),
                nn.ReLU(inplace=True),
                self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
                nn.Sigmoid())
            return self.thresh

        def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
            if smooth:
                inter_out_channels = out_channels
                if out_channels == 1:
                    inter_out_channels = in_channels
                module_list = [
                    nn.Upsample(scale_factor=2, mode='nearest'),
                    nn.Conv2d(in_channels, inter_out_channels, 3, 1, 1, bias=bias)]
                if out_channels == 1:
                    # padding=0 keeps the spatial size (a 1x1 conv with padding=1 would add 2 pixels)
                    module_list.append(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1,
                                                 padding=0, bias=True))
                # unpack the list: nn.Sequential expects modules, not a Python list
                return nn.Sequential(*module_list)
            else:
                return nn.ConvTranspose2d(in_channels, out_channels, 2, 2)

        def step_function(self, x, y):
            return torch.reciprocal(1 + torch.exp(-self.k * (x - y)))
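The shape flow through `binarize` (and, with the same layout, `thresh`) can be summarized in a few lines. In this pure-Python sketch, `dbhead_shapes` is a hypothetical helper; it follows the code shown above, in which the constructor's `out_channels` argument is accepted but unused, and the number of output maps is decided by `self.training` (3 during training, 2 at inference).

```python
def dbhead_shapes(in_channels, h, w, training):
    """Trace (channels, height, width) through DBHead.binarize for an FPN feature map
    of size in_channels x h x w, and return the shape of the head's concatenated output."""
    mid = in_channels // 4
    stages = [
        ("Conv2d 3x3 + BN + ReLU", (mid, h, w)),                     # padding=1 keeps h, w
        ("ConvTranspose2d 2x2/s2 + BN + ReLU", (mid, 2 * h, 2 * w)),  # doubles h, w
        ("ConvTranspose2d 2x2/s2 + Sigmoid", (1, 4 * h, 4 * w)),      # doubles again, 1 channel
    ]
    # thresh has the same layout, so both maps end at the same size;
    # training appends the approximate binary map as a third channel
    maps = 3 if training else 2
    return stages, (maps, 4 * h, 4 * w)


if __name__ == "__main__":
    stages, out = dbhead_shapes(256, 160, 160, training=True)
    for name, shape in stages:
        print(f"{name:36s} -> {shape}")
    print("train output:", out)
    print("eval output:", dbhead_shapes(256, 160, 160, training=False)[1])
```

For a 256-channel, 160 x 160 FPN output (a 640 x 640 input), this gives a 640 x 640 head output, consistent with the final `F.interpolate` in `Model.forward` being only a safety net for sizes that are not multiples of 32.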
