pytorchOCR之PSEnet

论文链接
官方代码

论文解读这里就不做了,网上很多。这里只对项目代码解读。

标签制作

  • 借用论文里的图,如图所示,需要生成若干个(自己设定,论文中为6)黑白图,文字部分为白即为1,背景部分为黑即为0. 白色最大的为文字分割图,最小的文中叫做kernel图,通过这样可以分开临近的文本。
  • 在ptocr/dataloader/DetLoad/MakeSegMap.py里的
def shrink(self,bboxes, rate, max_shr=20):rate = rate * rateshrinked_bboxes = []for bbox in bboxes:area = plg.Polygon(bbox).area()peri = self.perimeter(bbox)pco = pyclipper.PyclipperOffset()pco.AddPath(bbox, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)offset = min((int)(area * (1 - rate) / (peri + 0.001) + 0.5), max_shr)shrinked_bbox = pco.Execute(-offset)if len(shrinked_bbox) == 0:shrinked_bboxes.append(bbox)continueshrinked_bbox = np.array(shrinked_bbox)[0]shrinked_bbox = np.array(shrinked_bbox)if shrinked_bbox.shape[0] <= 2:shrinked_bboxes.append(bbox)continueshrinked_bboxes.append(shrinked_bbox)return np.array(shrinked_bboxes)

通过这个函数将标注框缩小,得到每个缩小的框。最后用opencv生成分割图。

模型解读

该检测方法是基于分割,论文使用FPN作为分割网络,其中backbone为resnet50,参看
ptocr/model/backbone/det_resnet.py部分代码如下

 def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x2 = self.layer1(x)x3 = self.layer2(x2)x4 = self.layer3(x3)x5 = self.layer4(x4)return x2, x3, x4, x5

经过该backbone返回四个map(x2,x3,x4,x5),分别为原图的1/4,1/8,1/16,1/32.此四个map 进入ptocr/model/head/det_FPNHead.py,如下:
该部分是fpn不同深度的map融合部分

self.toplayer = ConvBnRelu(in_channels[-1], inner_channels, kernel_size=1, stride=1,padding=0,bias=bias)  # Reduce channels
# Smooth layers
self.smooth1 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias)
self.smooth2 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias)
self.smooth3 = ConvBnRelu(inner_channels, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias)
# Lateral layers
self.latlayer1 = ConvBnRelu(in_channels[-2], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias)
self.latlayer2 = ConvBnRelu(in_channels[-3], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias)
self.latlayer3 = ConvBnRelu(in_channels[-4], inner_channels, kernel_size=1, stride=1, padding=0,bias=bias)
# Out map
self.conv_out = ConvBnRelu(inner_channels * 4, inner_channels, kernel_size=3, stride=1, padding=1,bias=bias)

在config的yaml中需要设置in_channels和inner_channels,其中in_channels分别对应着不同尺度输出map(x2,x3,x4,x5)的channel数目,如果你想改变backbone,这里也要根据实际情况做相应改变,inner_channels可以随意设置,但是一般根据backbone来调整。

def forward(self, x):c2, c3, c4, c5 = x##p5 = self.toplayer(c5)c4 = self.latlayer1(c4)p4 = upsample_add(p5, c4)p4 = self.smooth1(p4)c3 = self.latlayer2(c3)p3 = upsample_add(p4, c3)p3 = self.smooth2(p3)c2 = self.latlayer3(c2)p2 = upsample_add(p3, c2)p2 = self.smooth3(p2)##p3 = upsample(p3, p2)p4 = upsample(p4, p2)p5 = upsample(p5, p2)fuse = torch.cat((p2, p3, p4, p5), 1)fuse = self.conv_out(fuse)return fuse

这里操作就是将深层map向上做插值和上一层的map做融合,最后将不同尺度的map进行concat,论文中对此有描述。至此FPN部分完成。于是进入ptocr/model/segout/det_PSE_segout.py

class SegDetector(nn.Module):def __init__(self,inner_channels=256,classes=7):super(SegDetector,self).__init__()self.binarize = nn.Conv2d(inner_channels,classes,1,1,0)def forward(self, x,img):x = self.binarize(x)x = upsample(x,img)if self.training:pre_batch = dict(pre_text=x[:,0])pre_batch['pre_kernel'] = x[:,1:]return pre_batchreturn x

这里就是输出分割图,并把分割图插值成原图大小,这里输出7个分割图,其中第0个为最大对应着图片中文字的分割图,依次不断减小,kernel就是最小的一个分割图即第6个kernel图作用就是用来区分密集文本。

loss 部分

这里用到了分割常用的dice loss,在ptocr/model/loss/basical_loss.py如下:

class DiceLoss(nn.Module):def __init__(self,eps=1e-6):super(DiceLoss,self).__init__()self.eps = epsdef forward(self,pre_score,gt_score,train_mask):pre_score = pre_score.contiguous().view(pre_score.size()[0], -1)gt_score = gt_score.contiguous().view(gt_score.size()[0], -1)train_mask = train_mask.contiguous().view(train_mask.size()[0], -1)pre_score = pre_score * train_maskgt_score = gt_score * train_maska = torch.sum(pre_score * gt_score, 1)b = torch.sum(pre_score * pre_score, 1) + self.epsc = torch.sum(gt_score * gt_score, 1) + self.epsd = (2 * a) / (b + c)dice_loss = torch.mean(d)return 1 - dice_loss

这里共需要三个输入,一个网络输出的7个图,一个标签制作好的7个图,以及这七个图的train_mask,这里train_mask的作用就是使得部分像素不参与loss计算(即这部分的loss为0)。
这里用到了ohem如下:

def ohem_single(score, gt_text, training_mask):pos_num = (int)(np.sum(gt_text > 0.5)) - (int)(np.sum((gt_text > 0.5) & (training_mask <= 0.5)))if pos_num == 0:# selected_mask = gt_text.copy() * 0 # may be not goodselected_mask = training_maskselected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')return selected_maskneg_num = (int)(np.sum(gt_text <= 0.5))neg_num = (int)(min(pos_num * 3, neg_num))if neg_num == 0:selected_mask = training_maskselected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')return selected_maskneg_score = score[gt_text <= 0.5]neg_score_sorted = np.sort(-neg_score)threshold = -neg_score_sorted[neg_num - 1]selected_mask = ((score >= threshold) | (gt_text > 0.5)) & (training_mask > 0.5)selected_mask = selected_mask.reshape(1, selected_mask.shape[0], selected_mask.shape[1]).astype('float32')return selected_mask

这里就是选取负样本中loss排序大的,选择正负样本为1:3,假如正样本有3个,负样本像素就要选9个。选择loss最大的九个。

说明:文中图均来自论文

pytorchOCR之PSEnet相关推荐

  1. pytorchOCR之目录层级结构说明

    pytorchOCR之目录层级结构说明 目录层级结构如下 │ finetune_prune_model.sh │ infer.sh │ make.sh │ README.md │ requiremen ...

  2. 连通域最小外接矩形算法原理_基于分割的文本检测算法之PSENet/PAN/DBNet

    1. 文本检测难点 文本内包含文本,艺术字体,任意方向 ,曲线文字 ,多语言,其他环境因素等是文本检测中的难点 2. 分割 问题1: 语义分割模型是对pixel进行分类,所以理论上讲,可以检测不规则的 ...

  3. 用tensorflow还原PSENet网络

    # PSENet_tensorflow PSENet的tensorflow复现源代码地址:https://github.com/liuheng92/tensorflow_PSENet 参考CSDN博客 ...

  4. fpga实战训练精粹pdf_tensorflow版PSENet 文本检测模型训练和测试

    向AI转型的程序员都关注了这个号??? 机器学习AI算法工程   公众号:datayx psenet核心是为了解决基于分割的算法不能区分相邻文本的问题,以及对任意形状文本的检测问题. psenet依然 ...

  5. CVPR 2019 | 文本检测算法PSENet解读与开源实现

    点击我爱计算机视觉标星,更快获取CVML新技术 作者:刘恒 编辑:CV君 PSENet文本检测算法来自论文<Shape Robust Text Detection with Progressiv ...

  6. 高精度PSEnet文本检测在windows/linux运行教程

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx PSEnet核心是为了解决不能区分相邻文本的问题,以及对任意形状文本的检测问题.PSEnet依 ...

  7. PSENet PANNet DBNet 三个文本检测算法异同

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 这三个文本检测算法都是segment base算法,通过由下而上的方式,先对text进行seg ...

  8. tensorflow版PSENet 文本检测模型训练和测试

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx psenet核心是为了解决基于分割的算法不能区分相邻文本的问题,以及对任意形状文本的检测问题. ...

  9. PSENet原理介绍

        通常OCR中,文字检测都是由目标检测继承而来,目标检测大多都是基于先验框的(anchor base),近期出现的no-anchor模式本质上也是基于先验框的.anchor-base模式在目标检 ...

最新文章

  1. 告别2019:属于深度学习的十年,那些我们必须知道的经典
  2. 存储过程—导出table数据为inser sqlt语句
  3. dm9000 driver 1
  4. springboot pom文件添加mysql组件_SpringBoot+Mybatis 通过databaseIdProvider支持多数据库
  5. 英文名字的取法 分享
  6. python和台达plc通讯_台达PLC通信协议ModbusASCIIDVP
  7. 金融统计分析与挖掘实战3.1-3.2
  8. java小数位-DecimalFormat(转)
  9. Mybatis中的collection和association一关系
  10. UE4源码下载慢的解决方案--代理法
  11. 计算机包括台式机和笔记本,笔记本电脑与台式机怎样连接
  12. 如何从Apple电子钱包中删除旧登机牌
  13. springboot+vue+nodejs多用户网上图书商城系统-含卖家功能java
  14. 博客大赛,我的一场生意一场梦
  15. hdu5544 Ba Gua Zhen(高斯消元)
  16. 中国红霉素市场深度分析与投资前景调研报告2022-2028年
  17. STM32F103C8T6进行DAC播放
  18. 【已解决】ansible 命令报错 Error -5 while decompressing data: incomplete or truncated stream
  19. 石头机器人红灯快闪_5.1南宁上演“科幻大片”!各闹市街头惊现“机器人快闪”...
  20. 334个地级市名单_334个地级市名单_334个地级市的“基层”演出,李志说“如果我死了,得留下点什么”......

热门文章

  1. [EXP]CVE-2019-9621 Zimbra小于8.8.11 远程代码执行漏洞 XXE GetShell Exploit
  2. 计算机我们一起学猫叫谱子,学猫叫简谱-小潘潘-我们一起学猫叫,一起喵喵喵喵喵...
  3. Loading class 'com.mysql.jdbc.Driver', This is deprecated. The new driver class is'com.cj.jdbc.Driv'
  4. 视频教程- 设计讲师吴刚 2019-4-27 12:23:55 【吴刚大讲堂】Photoshop(PS)CC2-Photoshop
  5. 李跃喊了两年的“三新”临盆 中移动做企业社交?没戏!
  6. Java 面向对象基础和进阶
  7. PayPal外贸生意经--外贸零售之节日经济
  8. 爱斯维尔模板错误总结——坑爹货!
  9. linux 恢复修改文件内容,Linux备份及恢复及Linux文件权限详解
  10. 改变心理学的40项研究