目录

  • 引言
  • SDMG-R整体结构
  • 双模态融合模块
    • Backbone部分
    • Head部分
    • 融合模块
  • 文本节点与边权重获得部分
  • 图推理模块
  • 分类模块
  • 总结

引言

  • 文档图像中的关键信息提取任务(Key Information Extraction, KIE)是实现办公场景自动化的一项重要任务。
  • 如果从OCR场景来看,KIE任务可以作为对OCR提取结果内容的结构化抽取来使用。
  • 这次介绍的论文是Spatial Dual-Modality Graph Reasoning for Key Information Extraction,该论文是在去年3月份挂在arxiv上的,但是具体作者机构暂时没有给出。从论文arxiv主页上来看,其官方Code的链接是MMOCR。大胆猜测,难道作者是商汤的?这就不得而知了。
  • 值得一提的是,PaddleOCR最近也集成了该算法,看来关键信息抽取这一研究方向得到了工业界的一些重视,距离真正落地使用不会太远了。
  • 由于该算法在MMOCR和PaddleOCR中均有实现,考虑到跑通示例程序的便利性,这里以PaddleOCR中实现为例,指出相应的实现源码,用作学习之用。

SDMG-R整体结构

  • 从结构图来看,论文的思路比较清晰。整体结构可分为三个模块:双模态融合模块图推理模块分类模块三个。

双模态融合模块

该模块结合视觉特征和文本特征。其中视觉特征vi{v_{i}}vi来自U-NetROI-Pooling提取所得,文本特征ti{t_{i}}ti则是通过Bi-LSTM提取得到的。两个不同模态的特征通过Kronecker乘积操作得以融合。这样得以充分利用图像的一维和二维信息,这也是Dual Modality名称的由来。

Backbone部分

在PaddleOCR中backbone:U-Net,位于ppocr/modeling/backbones/kie_unet_sdmgr.py,主要包括U-NetROI Align两部分。对应的源码如下(省略部分代码),重点在于forward部分:

class Kie_backbone(nn.Layer):def __init__(self, in_channels, **kwargs):super(Kie_backbone, self).__init__()self.out_channels = 16self.img_feat = UNet()self.maxpool = nn.MaxPool2D(kernel_size=7)def bbox2roi(self, bbox_list):passdef pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):passdef forward(self, inputs):# img shape: [1, 3, 1024, 1024]img = inputs[0]relations, texts = inputs[1], inputs[2]gt_bboxes, tag, img_size = inputs[3], inputs[5], inputs[-1]# 前处理img, relations, texts, gt_bboxes = self.pre_process(img, relations, texts, gt_bboxes, tag, img_size)# 此时img shape: [1, 3, 512, 512]x = self.img_feat(img)# output x shape: [1, 16, 512, 512]boxes, rois_num = self.bbox2roi(gt_bboxes)feats = paddle.fluid.layers.roi_align(x,boxes,spatial_scale=1.0,pooled_height=7,pooled_width=7,rois_num=rois_num)# feats shape: [26, 16, 7, 7]feats = self.maxpool(feats).squeeze(-1).squeeze(-1)# output feats shape: [26, 16]return [relations, texts, feats]
Head部分

该部分是文本特征的提取,以及backbone部分提取所得图像特征和文本特征的合并部分。这部分主要集中在PaddleOCR代码中head部分下。

class SDMGRHead(nn.Layer):def __init__(self,in_channels,num_chars=92,visual_dim=16,fusion_dim=1024,node_input=32,node_embed=256,edge_input=5,edge_embed=256,num_gnn=2,num_classes=26,bidirectional=False):super().__init__()# 融合模块self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)self.node_embed = nn.Embedding(num_chars, node_input, 0)hidden = node_embed // 2 if bidirectional else node_embed# 单层LSTM模块self.rnn = nn.LSTM(input_size=node_input, hidden_size=hidden, num_layers=1)# 图推理模块self.edge_embed = nn.Linear(edge_input, edge_embed)self.gnn_layers = nn.LayerList([GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])# 分类模块self.node_cls = nn.Linear(node_embed, num_classes)self.edge_cls = nn.Linear(edge_embed, 2)def forward(self, input, targets):relations, texts, x = inputnode_nums, char_nums = [], []for text in texts:node_nums.append(text.shape[0])char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))max_num = max([char_num.max() for char_num in char_nums])all_nodes = paddle.concat([paddle.concat([text, paddle.zeros((text.shape[0], max_num - text.shape[1]))], -1)for text in texts])temp = paddle.clip(all_nodes, min=0).astype(int)embed_nodes = self.node_embed(temp)rnn_nodes, _ = self.rnn(embed_nodes)b, h, w = rnn_nodes.shapenodes = paddle.zeros([b, w])all_nums = paddle.concat(char_nums)valid = paddle.nonzero((all_nums > 0).astype(int))temp_all_nums = (paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)temp_all_nums = paddle.expand(temp_all_nums, [temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]])temp_all_nodes = paddle.gather(rnn_nodes, valid)N, C, A = temp_all_nodes.shapeone_hot = F.one_hot(temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])one_hot = paddle.multiply(temp_all_nodes, one_hot.astype("float32")).sum(axis=1,keepdim=True)t = one_hot.expand([N, 1, A]).squeeze(1)nodes = paddle.scatter(nodes, valid.squeeze(1), t)# 图像特征和文本特征融合if x is not None:nodes = self.fusion([x, nodes])all_edges = paddle.concat([rel.reshape([-1, rel.shape[-1]]) for rel in relations])embed_edges = self.edge_embed(all_edges.astype('float32'))embed_edges = F.normalize(embed_edges)# 将节点特征和边的权重信息整合到一起# 图推理模块for gnn_layer in self.gnn_layers:nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)# 分类模块node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)return node_cls, edge_cls