论文阅读: Spatial Dual-Modality Graph Reasoning for Key Information Extraction (关键信息提取算法)
目录
- 引言
- SDMG-R整体结构
- 双模态融合模块
- Backbone部分
- Head部分
- 融合模块
- 文本节点与边权重获得部分
- 图推理模块
- 分类模块
- 总结
引言
- 文档图像中的关键信息提取任务(Key Information Extraction, KIE)是实现办公场景自动化的一项重要任务。
- 如果从OCR场景来看,KIE任务可以作为对OCR提取结果内容的结构化抽取来使用。
- 这次介绍的论文是Spatial Dual-Modality Graph Reasoning for Key Information Extraction,该论文是在去年3月份挂在arxiv上的,但是具体作者机构暂时没有给出。从论文arxiv主页上来看,其官方Code的链接是MMOCR。大胆猜测,难道作者是商汤的?这就不得而知了。
- 值得一提的是,PaddleOCR最近也集成了该算法,看来关键信息抽取这一研究方向得到了工业界的一些重视,距离真正落地使用不会太远了。
- 由于该算法在MMOCR和PaddleOCR中均有实现,考虑到跑通示例程序的便利性,这里以PaddleOCR中实现为例,指出相应的实现源码,用作学习之用。
SDMG-R整体结构
- 从结构图来看,论文的思路比较清晰。整体结构可分为三个模块:双模态融合模块、图推理模块和分类模块三个。
双模态融合模块
该模块结合视觉特征和文本特征。其中视觉特征vi{v_{i}}vi来自U-Net和ROI-Pooling提取所得,文本特征ti{t_{i}}ti则是通过Bi-LSTM提取得到的。两个不同模态的特征通过Kronecker乘积操作得以融合。这样得以充分利用图像的一维和二维信息,这也是Dual Modality名称的由来。
Backbone部分
在PaddleOCR中backbone:U-Net,位于ppocr/modeling/backbones/kie_unet_sdmgr.py,主要包括U-Net和ROI Align两部分。对应的源码如下(省略部分代码),重点在于forward
部分:
class Kie_backbone(nn.Layer):def __init__(self, in_channels, **kwargs):super(Kie_backbone, self).__init__()self.out_channels = 16self.img_feat = UNet()self.maxpool = nn.MaxPool2D(kernel_size=7)def bbox2roi(self, bbox_list):passdef pre_process(self, img, relations, texts, gt_bboxes, tag, img_size):passdef forward(self, inputs):# img shape: [1, 3, 1024, 1024]img = inputs[0]relations, texts = inputs[1], inputs[2]gt_bboxes, tag, img_size = inputs[3], inputs[5], inputs[-1]# 前处理img, relations, texts, gt_bboxes = self.pre_process(img, relations, texts, gt_bboxes, tag, img_size)# 此时img shape: [1, 3, 512, 512]x = self.img_feat(img)# output x shape: [1, 16, 512, 512]boxes, rois_num = self.bbox2roi(gt_bboxes)feats = paddle.fluid.layers.roi_align(x,boxes,spatial_scale=1.0,pooled_height=7,pooled_width=7,rois_num=rois_num)# feats shape: [26, 16, 7, 7]feats = self.maxpool(feats).squeeze(-1).squeeze(-1)# output feats shape: [26, 16]return [relations, texts, feats]
Head部分
该部分是文本特征的提取,以及backbone部分提取所得图像特征和文本特征的合并部分。这部分主要集中在PaddleOCR代码中head部分下。
class SDMGRHead(nn.Layer):def __init__(self,in_channels,num_chars=92,visual_dim=16,fusion_dim=1024,node_input=32,node_embed=256,edge_input=5,edge_embed=256,num_gnn=2,num_classes=26,bidirectional=False):super().__init__()# 融合模块self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)self.node_embed = nn.Embedding(num_chars, node_input, 0)hidden = node_embed // 2 if bidirectional else node_embed# 单层LSTM模块self.rnn = nn.LSTM(input_size=node_input, hidden_size=hidden, num_layers=1)# 图推理模块self.edge_embed = nn.Linear(edge_input, edge_embed)self.gnn_layers = nn.LayerList([GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])# 分类模块self.node_cls = nn.Linear(node_embed, num_classes)self.edge_cls = nn.Linear(edge_embed, 2)def forward(self, input, targets):relations, texts, x = inputnode_nums, char_nums = [], []for text in texts:node_nums.append(text.shape[0])char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))max_num = max([char_num.max() for char_num in char_nums])all_nodes = paddle.concat([paddle.concat([text, paddle.zeros((text.shape[0], max_num - text.shape[1]))], -1)for text in texts])temp = paddle.clip(all_nodes, min=0).astype(int)embed_nodes = self.node_embed(temp)rnn_nodes, _ = self.rnn(embed_nodes)b, h, w = rnn_nodes.shapenodes = paddle.zeros([b, w])all_nums = paddle.concat(char_nums)valid = paddle.nonzero((all_nums > 0).astype(int))temp_all_nums = (paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)temp_all_nums = paddle.expand(temp_all_nums, [temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]])temp_all_nodes = paddle.gather(rnn_nodes, valid)N, C, A = temp_all_nodes.shapeone_hot = F.one_hot(temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])one_hot = paddle.multiply(temp_all_nodes, one_hot.astype("float32")).sum(axis=1,keepdim=True)t = one_hot.expand([N, 1, A]).squeeze(1)nodes = paddle.scatter(nodes, valid.squeeze(1), t)# 图像特征和文本特征融合if x is not None:nodes = self.fusion([x, nodes])all_edges = paddle.concat([rel.reshape([-1, rel.shape[-1]]) for rel in relations])embed_edges = self.edge_embed(all_edges.astype('float32'))embed_edges = F.normalize(embed_edges)# 将节点特征和边的权重信息整合到一起# 图推理模块for gnn_layer in self.gnn_layers:nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)# 分类模块node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)return node_cls, edge_cls
- 不过,从MMOCR中实现的代码来看,并没有采用Bi-LSTM,而是采用的是LSTM。这一点从Issue #491中,可以看到开发者对此的说法是:使用Bi-LSTM和单独使用LSTM并没有多大区别。
- 同时,我又去比对了PaddleOCR中此部分的实现,同样也是采用的LSTM。严重怀疑,PaddleOCR的该算法代码是转写的MMOCR的。
论文阅读: Spatial Dual-Modality Graph Reasoning for Key Information Extraction (关键信息提取算法)相关推荐
- 论文阅读课1-Attention Guided Graph Convolutional Networks for Relation Extraction(关系抽取,图卷积,ACL2019,n元)
文章目录 abstract 1.introduction 1.1 dense connection+GCN 1.2 效果突出 1.3 contribution 2.Attention Guided G ...
- 【论文阅读笔记】Myers的O(ND)时间复杂度的高效的diff算法
前言 之前咱们三个同学做了个Simple-SCM,我负责那个Merge模块,也就是对两个不同分支的代码进行合并.当时为了简便起见,遇到文件冲突的时候,就直接按照文件的更改日期来存储,直接把更改日期较新 ...
- 论文阅读:Entangled Watermarks as a Defense against Model Extraction
论文阅读:Entangled Watermarks as a Defense against Model Extraction 这里给大家分享一篇有关模型水印的论文.这篇文章2021年发布在the P ...
- 论文阅读:DuEE:A Large-Scale Dataset for Chinese Event Extraction in Real-World Scenarios(附数据集地址)
论文阅读:DuEE:A Large-Scale Dataset for Chinese Event Extraction in Real-World Scenarios 基于现实场景的大规模中文事件抽 ...
- 【论文阅读】基于视图的图卷积神经网络3D物体形状识别算法
原文地址:点击访问 本期,为大家推送CVPR 2020一篇关于图神经网络与3D相关的文章.自我感觉挺有趣的,有兴趣的同学推荐一读. 论文题目:View-GCN: View-based Graph Co ...
- 论文阅读笔记:Towards Fine-Grained Reasoning for Fake News Detection
Towards Fine-Grained Reasoning for Fake News Detection author={Jin, Yiqiao and Wang, Xiting and Yang ...
- 论文阅读笔记《Neural Graph Matching Network: Learning Lawler’s Quadratic Assignment Problem With Extension》
核心思想 该文提出一种图匹配神经网络用于解决Lawler's形式的二次分配问题,并将其推广到超图匹配和多图匹配领域.在之前的文章中,我们介绍过图匹配问题通常被定义为一种二次分配问题(QAP),通常 ...
- 【论文阅读ACL2020】Leveraging Graph to Improve Abstractive Multi-Document Summarization
题目:Leveraging Graph to Improve Abstractive Multi-Document Summarization (基于图表示的生成式多文档摘要方法 ) 会议:ACL20 ...
- 论文阅读:A Novel Graph based Trajectory Predictor with Pseudo Oracle
A Novel Graph based Trajectory Predictor with Pseudo Oracle 摘要 1 引言 2 相关工作 3 PROPOSED METHOD IV. EXP ...
最新文章
- 一打在2019年亮相的迷人科技项目:飞行汽车、子弹头列车、登月、……
- day1学python Hello Python
- 第3章 Python 数字图像处理(DIP) - 灰度变换与空间滤波13 - 平滑低通滤波器 -盒式滤波器核
- mwc校准油门_编写下载服务器。 第五部分:油门下载速度
- mysql 导出数据库中的某张数据表_mysql 导出数据库中的某张数据表
- 【20171005】Luogu P1164 小A点菜
- 信息学奥赛一本通 1141:删除单词后缀 | OpenJudge NOI 1.7 20
- Hibernate 学习-1
- jmeter_linux下运行
- 机器学习基础知识之概率论基础详解
- Github中那些迷之缩写?LGTM?
- 把已有项目转换成Visual Studio的解决方案
- 前端高效开发必备——常用js框架和第三方插件
- 逆水寒能不能网页预约服务器,不是说《逆水寒》凉了吗,为什么新服预约不到1小时就满了?...
- VB中的界面设计原则和编程技巧
- 微信转账php开发心得
- 从零开始开发Android相机app(三)简单介绍图像滤镜功能
- Android 7.0 Vold工作流程
- 罗斯柴尔德家族:“大道无形”世界首富
- 【论文阅读】DPLVO: Direct Point-Line Monocular Visual Odometry
热门文章
- 论文阅读课1-Attention Guided Graph Convolutional Networks for Relation Extraction(关系抽取,图卷积,ACL2019,n元)