序言

前段时间PaddleOCRv3版本发布,更新了检测和识别模型,性能有很大提升,本着能嫖就嫖的原则,刚出来的第一天就开始嫖上了,虽然新模型的性能相较于之前有较大提升,但是乍一看模型结构复杂了很多,部署起来要麻烦了很多,现阶段paddle框架转其他部署框架只能通过转paddle2onnx再转其他框架实现,所以打算踩下坑,提供import paddle as torch版本的模型:将paddle框架的模型权重转到pytorch上,为部署方案提供多一些选择,转换到pytorch框架上后可以通过从pytorch再转其他部署方式,举个之前的例子:使用pnnx把pytorch模型转ncnn模型。

与之前的模型性能对比:

本项目代码实现基于:

  • https://github.com/WenmuZhou/PytorchOCR
  • https://github.com/PaddlePaddle/PaddleOCR

一、paddle2torch

先说下转换原理,因为paddlepaddle和pytorch都是动态框架,所以转换起来比较简单,对于要转换的paddle模型,我们只需要用torch重新构建相同的网络模型结构,然后将paddle的权重取出,一一对应赋值进每一层。看似过程比较简单,但是毕竟是不同的框架,有些op实现也是不同的,难免会踩很多坑。

在转换之前,我们先看一下PaddleOCRV3相对于上一个版本的模型更新了那些模块:

检测模块

  1. LK-PAN:大感受野的PAN结构
  2. DML:教师模型互学习策略
  3. RSE-FPN:残差注意力机制的FPN结构

识别模块

  • SVTR_LCNet:轻量级文本识别网络
  • GTC:Attention指导CTC训练策略
  • TextConAug:挖掘文字上下文信息的数据增广策略
  • TextRotNet:自监督的预训练模型
  • UDML:联合互学习策略
  • UIM:无标注数据挖掘方案

具体的可以看PPOCRV3官方的技术报告,在这里我们只需要关注我们转换的过程需要注意的那些模块即可

二、检测模型转换

首先是检测模块,检测模块有三部分更新,我们只需要关注RSE-FPN,因为前两个都是在训练过程中蒸馏学习对教师模型的优化。

RSE-FPN(Residual Squeeze-and-Excitation FPN)如下图所示,引入残差结构和通道注意力结构,将FPN中的卷积层更换为通道注意力结构的RSEConv层,进一步提升特征图的表征能力。考虑到PP-OCRv2的检测模型中FPN通道数非常小,仅为96,如果直接用SEblock代替FPN中卷积会导致某些通道的特征被抑制,精度会下降。RSEConv引入残差结构会缓解上述问题,提升文本检测效果。进一步将PP-OCRv2中CML的学生模型的FPN结构更新为RSE-FPN,学生模型的hmean可以进一步从84.3%提升到85.4%:

RSE-FPN pytorch代码实现:

class RSELayer(nn.Module):def __init__(self, in_channels, out_channels, kernel_size, shortcut=True):super(RSELayer, self).__init__()self.out_channels = out_channelsself.in_conv = nn.Conv2d(in_channels=in_channels,out_channels=self.out_channels,kernel_size=kernel_size,padding=int(kernel_size // 2),bias=False)self.se_block = SEBlock(self.out_channels,self.out_channels)self.shortcut = shortcutdef forward(self, ins):x = self.in_conv(ins)if self.shortcut:out = x + self.se_block(x)else:out = self.se_block(x)return outclass RSEFPN(nn.Module):def __init__(self, in_channels, out_channels=256, shortcut=True, **kwargs):super(RSEFPN, self).__init__()self.out_channels = out_channelsself.ins_conv = nn.ModuleList()self.inp_conv = nn.ModuleList()for i in range(len(in_channels)):self.ins_conv.append(RSELayer(in_channels[i],out_channels,kernel_size=1,shortcut=shortcut))self.inp_conv.append(RSELayer(out_channels,out_channels // 4,kernel_size=3,shortcut=shortcut))def _upsample_add(self, x, y):return F.interpolate(x, scale_factor=2) + ydef _upsample_cat(self, p2, p3, p4, p5):p3 = F.interpolate(p3, scale_factor=2)p4 = F.interpolate(p4, scale_factor=4)p5 = F.interpolate(p5, scale_factor=8)return torch.cat([p5, p4, p3, p2], dim=1)def forward(self, x):c2, c3, c4, c5 = xin5 = self.ins_conv[3](c5)in4 = self.ins_conv[2](c4)in3 = self.ins_conv[1](c3)in2 = self.ins_conv[0](c2)out4 = self._upsample_add(in5, in4)out3 = self._upsample_add(out4, in3)out2 = self._upsample_add(out3, in2)p5 = self.inp_conv[3](in5)p4 = self.inp_conv[2](out4)p3 = self.inp_conv[1](out3)p2 = self.inp_conv[0](out2)x = self._upsample_cat(p2, p3, p4, p5)return x

完整的网络分为三部分:Backbone(MobileNetV3)、Neck(RSEFPN)、Head(DBHead),借助于PytorchOCR项目,将这三部分分别实现,然后将网络搭建。

from torch import nn
from det.DetMobilenetV3 import MobileNetV3
from det.DB_fpn import DB_fpn,RSEFPN,LKPAN
from det.DetDbHead import DBHeadbackbone_dict = {'MobileNetV3': MobileNetV3}
neck_dict = {'DB_fpn': DB_fpn,'RSEFPN':RSEFPN,'LKPAN':LKPAN}
head_dict = {'DBHead': DBHead}class DetModel(nn.Module):def __init__(self, config):super().__init__()assert 'in_channels' in config, 'in_channels must in model config'backbone_type = config.backbone.pop('type')assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)neck_type = config.neck.pop('type')assert neck_type in neck_dict, f'neck.type must in {neck_dict}'self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)head_type = config.head.pop('type')assert head_type in head_dict, f'head.type must in {head_dict}'self.head = head_dict[head_type](self.neck.out_channels, **config.head)self.name = f'DetModel_{backbone_type}_{neck_type}_{head_type}'def load_3rd_state_dict(self, _3rd_name, _state):self.backbone.load_3rd_state_dict(_3rd_name, _state)self.neck.load_3rd_state_dict(_3rd_name, _state)self.head.load_3rd_state_dict(_3rd_name, _state)def forward(self, x):x = self.backbone(x)x = self.neck(x)x = self.head(x)return xif __name__=="__main__":db_config = AttrDict(in_channels=3,backbone=AttrDict(type='MobileNetV3', model_name='large',scale=0.5,pretrained=True),neck=AttrDict(type='RSEFPN', out_channels=96),head=AttrDict(type='DBHead'))model = DetModel(db_config)

然后使用paddleOCRV3的文字检测训练模型(注意只能用训练模型),将模型的权重和对应的键值取出,分别对应初始化到torch模型中,完整代码在文后链接。

def load_state(path,trModule_state):"""记载paddlepaddle的参数:param path::return:"""if os.path.exists(path + '.pdopt'):# XXX another hack to ignore the optimizer statetmp = tempfile.mkdtemp()dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))shutil.copy(path + '.pdparams', dst + '.pdparams')state = fluid.io.load_program_state(dst)shutil.rmtree(tmp)else:state = fluid.io.load_program_state(path)# for i, key in enumerate(state.keys()):#     print("{}  {} ".format(i, key))state_dict = {}for i, key in enumerate(state.keys()):if key =="StructuredToParameterName@@":continuestate_dict[trModule_state[i]] = torch.from_numpy(state[key])return state_dict

三、识别模型转换

识别模型的转换相对于检测模型要复杂很多,PP-OCRv3的识别模块是基于文本识别算法SVTR优化。SVTR不再采用RNN结构,通过引入Transformers结构更加有效地挖掘文本行图像的上下文信息,从而提升文本识别能力,上面的诸多识别优化中,我们只需要关注第一个优化:SVTR_LCNet,其他的都是训练过程中的训练技巧,在模型转换的过程中不需要用到。

SVTR_LCNet是针对文本识别任务,将基于Transformer的SVTR网络和轻量级CNN网络PP-LCNet 融合的一种轻量级文本识别网络,整体网络如下所示:

使用该网络,预测速度优于PP-OCRv2的识别模型20%,但是由于没有采用蒸馏策略,该识别模型效果略差。此外,进一步将输入图片规范化高度从32提升到48,预测速度稍微变慢,但是模型效果大幅提升,识别准确率达到73.98%(+2.08%),接近PP-OCRv2采用蒸馏策略的识别模型效果,消融实验过程:

同样的,根据paddle的识别网络结构构建torch网络模型,模型分为三部分:Backbone(LCNet)、Encoder(SVTR Transformers)、Head(MultiHead),其中Encoder部分使用了SVTR的Transformers结构编码:

class EncoderWithSVTR(nn.Module):def __init__(self,in_channels,dims=64,  # XSdepth=2,hidden_dims=120,use_guide=False,num_heads=8,qkv_bias=True,mlp_ratio=2.0,drop_rate=0.1,attn_drop_rate=0.1,drop_path=0.,qk_scale=None):super(EncoderWithSVTR, self).__init__()self.depth = depthself.use_guide = use_guideself.conv1 = ConvBNLayer(in_channels, in_channels // 8, padding=1)self.conv2 = ConvBNLayer(in_channels // 8, hidden_dims, kernel_size=1)self.svtr_block = nn.ModuleList([Block(dim=hidden_dims,num_heads=num_heads,mixer='Global',HW=None,mlp_ratio=mlp_ratio,qkv_bias=qkv_bias,qk_scale=qk_scale,drop=drop_rate,act_layer="Swish",attn_drop=attn_drop_rate,drop_path=drop_path,norm_layer='nn.LayerNorm',epsilon=1e-05,prenorm=False) for i in range(depth)])self.norm = nn.LayerNorm(hidden_dims, eps=1e-6)self.conv3 = ConvBNLayer(hidden_dims, in_channels, kernel_size=1)# last conv-nxn, the input is concat of input tensor and conv3 output tensorself.conv4 = ConvBNLayer(2 * in_channels, in_channels // 8, padding=1)self.conv1x1 = ConvBNLayer(in_channels // 8, dims, kernel_size=1)self.out_channels = dimsself.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):trunc_normal_(m.weight)if isinstance(m, nn.Linear) and m.bias is not None:zeros_(m.bias)elif isinstance(m, nn.LayerNorm):zeros_(m.bias)ones_(m.weight)def forward(self, x):# for use guideif self.use_guide:z = x.clone()z.stop_gradient = Trueelse:z = x# for short cuth = z# reduce dimz = self.conv1(z)z = self.conv2(z)# SVTR global blockB, C, H, W = z.shapez = z.flatten(2).permute([0, 2, 1])for blk in self.svtr_block:z = blk(z)z = self.norm(z)# last stagez = z.reshape([-1, H, W, C]).permute([0, 3, 1, 2])z = self.conv3(z)z = torch.cat((h, z), dim=1)z = self.conv1x1(self.conv4(z))return z

Head部分是一个多头,但是在推理的时候实际上也只用了CTCHead,把训练时候的SARHead去掉了,所以这部分不需要在网络构建时加进去。

class MultiHead(nn.Module):def __init__(self, in_channels, **kwargs):super().__init__()self.out_c = kwargs.get('n_class')self.head_list = kwargs.get('head_list')self.gtc_head = 'sar'# assert len(self.head_list) >= 2for idx, head_name in enumerate(self.head_list):# name = list(head_name)[0]name = head_nameif name == 'SARHead':# sar headsar_args = self.head_list[name]self.sar_head = eval(name)(in_channels=in_channels, out_channels=self.out_c, **sar_args)if name == 'CTC':# ctc neckself.encoder_reshape = Im2Seq(in_channels)neck_args = self.head_list[name]['Neck']encoder_type = neck_args.pop('name')self.encoder = encoder_typeself.ctc_encoder = SequenceEncoder(in_channels=in_channels,encoder_type=encoder_type, **neck_args)# ctc headhead_args = self.head_list[name]self.ctc_head = eval(name)(in_channels=self.ctc_encoder.out_channels,n_class=self.out_c, **head_args)else:raise NotImplementedError('{} is not supported in MultiHead yet'.format(name))def forward(self, x, targets=None):ctc_encoder = self.ctc_encoder(x)ctc_out = self.ctc_head(ctc_encoder, targets)head_out = dict()head_out['ctc'] = ctc_outhead_out['ctc_neck'] = ctc_encoderreturn ctc_out                          # infer   不经过SAR直接返回# # eval mode# print(not self.training)# if not self.training:                 # training#     return ctc_out# if self.gtc_head == 'sar':#     sar_out = self.sar_head(x, targets[1:])#     head_out['sar'] = sar_out#     return head_out# else:#     return head_out

完整的网络构建:

from torch import nnfrom rec.RNN import SequenceEncoder, Im2Seq,Im2Im
from rec.RecSVTR import SVTRNet
from rec.RecMv1_enhance import MobileNetV1Enhancefrom rec.RecCTCHead import CTC,MultiHeadbackbone_dict = {"SVTR":SVTRNet,"MobileNetV1Enhance":MobileNetV1Enhance}
neck_dict = {'PPaddleRNN': SequenceEncoder, 'Im2Seq': Im2Seq,'None':Im2Im}
head_dict = {'CTC': CTC,'Multi':MultiHead}class RecModel(nn.Module):def __init__(self, config):super().__init__()assert 'in_channels' in config, 'in_channels must in model config'backbone_type = config.backbone.pop('type')assert backbone_type in backbone_dict, f'backbone.type must in {backbone_dict}'self.backbone = backbone_dict[backbone_type](config.in_channels, **config.backbone)neck_type = config.neck.pop('type')assert neck_type in neck_dict, f'neck.type must in {neck_dict}'self.neck = neck_dict[neck_type](self.backbone.out_channels, **config.neck)head_type = config.head.pop('type')assert head_type in head_dict, f'head.type must in {head_dict}'self.head = head_dict[head_type](self.neck.out_channels, **config.head)self.name = f'RecModel_{backbone_type}_{neck_type}_{head_type}'def load_3rd_state_dict(self, _3rd_name, _state):self.backbone.load_3rd_state_dict(_3rd_name, _state)self.neck.load_3rd_state_dict(_3rd_name, _state)self.head.load_3rd_state_dict(_3rd_name, _state)def forward(self, x):x = self.backbone(x)x = self.neck(x)x = self.head(x)return xif __name__=="__main__":rec_config = AttrDict(in_channels=3,backbone=AttrDict(type='MobileNetV1Enhance', scale=0.5,last_conv_stride=[1,2],last_pool_type='avg'),neck=AttrDict(type='None'),head=AttrDict(type='Multi',head_list=AttrDict(CTC=AttrDict(Neck=AttrDict(name="svtr",dims=64,depth=2,hidden_dims=120,use_guide=True)),# SARHead=AttrDict(enc_dim=512,max_text_length=70)),n_class=6625))model = RecModel(rec_config)

同样的,加载paddleocrv3的识别训练模型,将权重对应键值取出,初始化到torch模型中,但是这里需要注意的是,paddle中的全链接层和torch中全链接层的权重形状问题,paddle的全链接层赋值到torch的全链接层的时候,权重需要做一个转置transpose():

def load_state(path,trModule_state):"""记载paddlepaddle的参数:param path::return:"""if os.path.exists(path + '.pdopt'):# XXX another hack to ignore the optimizer statetmp = tempfile.mkdtemp()dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))shutil.copy(path + '.pdparams', dst + '.pdparams')state = fluid.io.load_program_state(dst)shutil.rmtree(tmp)else:state = fluid.io.load_program_state(path)# for i, key in enumerate(state.keys()):#     print("{}  {} ".format(i, key))keys = ["head.ctc_encoder.encoder.svtr_block.0.mixer.qkv.weight","head.ctc_encoder.encoder.svtr_block.0.mixer.proj.weight","head.ctc_encoder.encoder.svtr_block.0.mlp.fc1.weight","head.ctc_encoder.encoder.svtr_block.0.mlp.fc2.weight","head.ctc_encoder.encoder.svtr_block.1.mixer.qkv.weight","head.ctc_encoder.encoder.svtr_block.1.mixer.proj.weight","head.ctc_encoder.encoder.svtr_block.1.mlp.fc1.weight","head.ctc_encoder.encoder.svtr_block.1.mlp.fc2.weight","head.ctc_head.fc.weight",]state_dict = {}for i, key in enumerate(state.keys()):if key =="StructuredToParameterName@@":continueif i > 238:j = i-239if j <= 195:if trModule_state[j] in keys:state_dict[trModule_state[j]] = torch.from_numpy(state[key]).transpose(0,1)else:state_dict[trModule_state[j]] = torch.from_numpy(state[key])return state_dict

paddleocr的训练模型链接PaddleOCR:

完整代码已经扔到github上,欢迎白嫖学习。

paddle2torch_PPOCRv3

PPOCRv3模型转pytorch相关推荐

  1. 从LeNet-5 CNN模型入门PyTorch

    从LeNet-5 CNN模型入门PyTorch 1. PyTorch 准备 1.1 PyTorch特点 1.2 PyTorch安装测试 2. 完整代码 2.1 LeNet模型 2.2 训练 2.2 测 ...

  2. 推荐系统 | 基础推荐模型 | 逻辑回归模型 | LS-PLM | PyTorch实现

    基础推荐模型--传送门: 推荐系统 | 基础推荐模型 | 协同过滤 | UserCF与ItemCF的Python实现及优化 推荐系统 | 基础推荐模型 | 矩阵分解模型 | 隐语义模型 | PyTor ...

  3. python pytorch语音识别_PyTorch通过ASR实现语音到文本端的模型以及pytorch语音识别(speech) - pytorch中文网...

    ASR,英文的全称是Automated Speech Recognition,即自动语音识别技术,它是一种将人的语音转换为文本的技术.今天我们主要了解pytorch实现语音到文本的端到端模型. spe ...

  4. 可高效训练超大规模图模型,PyTorch BigGraph是如何做到的?

    选自medium 作者:Jesus Rodriguez 机器之心编译 编辑:Panda Facebook 提出了一种可高效训练包含数十亿节点和数万亿边的图模型的框架 BigGraph 并开源了其 Py ...

  5. 如何将tensorflow模型转PYTORCH模型

    将tensorflow版本的.ckpt模型转成pytorch的.bin模型 - 最咸的鱼 - 博客园

  6. 【NLP】Github标星7.7k+:常见NLP模型的PyTorch代码实现

    推荐github上的一个NLP代码教程:nlp-tutorial,教程中包含常见的NLP模型代码实现(基于Pytorch1.0+),而且教程中的大多数NLP模型都使用少于100行代码. 教程说明 这是 ...

  7. 数行代码训练视频模型,PyTorch视频理解利器出炉

    本文转自机器之心. Facebook人工智能实验室在 PySlowFast 之后时隔两年,携 PyTorchVideo 重入战场. 视频作为当今最被广为使用的媒体形式,已逐渐占超过文字和图片,据了人们 ...

  8. PyTorch 1.0 中文官方教程:使用ONNX将模型从PyTorch传输到Caffe2和移动端

    译者:冯宝宝 在本教程中,我们将介绍如何使用ONNX将PyTorch中定义的模型转换为ONNX格式,然后将其加载到Caffe2中.一旦进入Caffe2,我们就可以运行模型来仔细检查它是否正确导出,然后 ...

  9. object怎么转list_PaddleOCR识别模型转Pytorch全流程记录

    这篇文章主要负责记录自己在转PaddleOCR 模型过程中遇到的问题,以供大家参考. 重要的话说在最前面,以免大家不往下看: 本篇文章是把 "整个" ppocr 模型 转成了 py ...

最新文章

  1. mybatis的面试一对一,一对多,多对多的mapper.xml配置
  2. pwn学习总结(二) —— 基础知识(持续更新)
  3. (转载)cmd-命令大全及详解
  4. 一些当前 Node.js 中最流行 ES6 特性的 benchmark (V8 / Chakra)
  5. Cryptocurrency Blockchain Internship Programme
  6. php 面试靠快速排序,搞定PHP面试 - 常见排序算法及PHP实现
  7. JavaScript中的正则表达式详解
  8. 结合源码探讨Android系统的启动流程
  9. 为计算机构建安全方案,计算机科学系安全管理标准化建设实施方案
  10. 计算机办公软件应用操作,基于计算机Word办公软件的使用及操作流程
  11. 关于pdms中设备参数模板的更新PML代码
  12. text edit model FELIX的理解与python实现
  13. 二层交换机 三层交换机 四层交换机的区别
  14. 内容创业赛道分野,2018紧,2019更紧
  15. 从win10回退到win7的苦逼经历
  16. next_day函数用法
  17. 粗同步 符号同步 matlab,OFDM系统在衰落信道中帧同步算法研究(毕业论文)
  18. DGV:人类基因组结构变异数据库
  19. Qt打开Word、Excel和PPT总结
  20. Android JetPack组件之DataBinding的使用详解

热门文章

  1. Kafka Broker
  2. nginx $1,2,3的含义
  3. jsp+servlet实现商城购物车功能
  4. Win10备份错误代码0x800700e1怎么解决?
  5. form表单提交数据的两种方式——submit直接提交、AJAX提交
  6. PMP一模考试错题集+解析 之 人员
  7. 能发送消息,但是浏览器上不了网?360安全卫士功能推荐
  8. SQL Server Management Studio
  9. 学习笔记—增量式PID详细实现(C语言)
  10. php xmp,xmp可以一直开着吗