[PaddleSeg源码阅读] PaddleSeg 导出静态图 export.py 文件中的道道
周末去泰山玩♂耍,周六晚上10点半开始爬,周日上午10点26回到住的地方躺下,整整12个小时!!
我一个人爬完全程有些慢,不过起码我不是逃兵
下山到最后几段的时候,脸上摆出极其夸张和痛苦的表情,有几个阿姨和小姑娘看见我直笑,好吧,我看见他们笑我也笑了hhhh
来两张云海镇楼!!
在将PaddleSeg的模型导出为 onnx 或者 trt 之前,首先要将动态图模型导出为静态图模型,之前没怎么注意这个文件,后来出现问题了,才看这个文件,还是有些小细节的
- 预处理的 Transform 部分,有没有被导出到模型中? (按理说是不会的,实际也是不会的)
- 后处理的 argmax 和 softmax(这个模型也可以有,后处理也可以有),这是分割的,如果是目标检测,还涉及到NMS部分往哪里加
OKKKK,现在咱开始看源码,
就是这个文件,里边结构很简单:
- parse_args 函数,用来解析外部传来的参数
- main 函数,进行所有的操作
- SavedSegmentationNet 类,用来加上后处理的类
- PostPorcesser 类,用来将后处理
1. parse_args 函数
"--config", # config 文件,至少要有 export 那一项
export:transforms:- type: Resizetarget_size: [224, 224]- type: Normalize
'--save_dir', # 就是保存的路径
'--model_path', # 动态图模型的路径
'--without_argmax', # 是否不在网络末端添加argmax算子。由于PaddleSeg组网默认返回logits,为部署模型可以直接获取预测结果,我们默认在网络末端添加argmax算子
'--with_softmax', # 在网络末端添加softmax算子。由于PaddleSeg组网默认返回logits,如果想要部署模型获取概率值,可以置为True
"--input_shape", # 设置导出模型的输入shape,比如传入--input_shape 1 3 1024 1024。如果不设置input_shape,默认导出模型的输入shape是[-1, 3, -1, -1]
以上部分参考自:
https://github.com/PaddlePaddle/PaddleSeg/blob/release/2.5/docs/model_export_cn.md
2. SavedSegmentationNet 类
class SavedSegmentationNet(paddle.nn.Layer):def __init__(self, net, without_argmax=False, with_softmax=False):super().__init__()self.net = netself.post_processer = PostPorcesser(without_argmax, with_softmax)def forward(self, x):outs = self.net(x)outs = self.post_processer(outs)return outs
参数 net 就是 PaddleSeg 的动态图模型实例 paddle.nn.Layer
可以看到 self.post_processer
专门用来解决后处理的 softmax 和 argmax 的操作
3. PostPorcesser 类
class PostPorcesser(paddle.nn.Layer):def __init__(self, without_argmax, with_softmax):super().__init__()self.without_argmax = without_argmaxself.with_softmax = with_softmaxdef forward(self, outs):new_outs = []for out in outs:if self.with_softmax:out = paddle.nn.functional.softmax(out, axis=1)if not self.without_argmax:out = paddle.argmax(out, axis=1)new_outs.append(out)return new_outs
通过两个 flag self.without_argmax
和 self.with_softmax
来控制是否添加 argmax 和 softmax
NOTICE: outs
是个列表? 模型明明只返回一个 logit
啊
这里给出一个结论,PaddleSeg 的所有模型在 call
之后,从 forward
返回的都是列表,列表可能会包含多个元素,但第0个元素一定是 logit
这里有两点说明:
关于 logit 是什么,logit 其实就是模型的输出,可以理解为没有 通过softmax 之前的部分
logit 的shape为[bs, cls, w, h]
, bs是batch_size,cls是有几类,wh是宽高
logit 在softmax 之后,每个像素点就有了每一类的概率,即加和为1
而直接取logit (无论是否通过softmax) 最大值的那一类,就是模型预测该像素点的类别关于为何 PaddleSeg 的模型返回值是列表,可以查看这篇博客:
关于PaddleSeg模型返回的都是list这件小事
4. main函数
咱一行一行看吧,今儿周日,我有大把大把的时间hhh,有种初中英语老师领着做阅读理解的感觉hhh
os.environ['PADDLESEG_EXPORT_STAGE'] = 'True' # 添加了一个环境变量? 我其实不太懂这个是做什么的
cfg = Config(args.cfg) # Config 对象
net = cfg.model # 实例化一个model
不知道 net = cfg.model()
为啥是实例化一个model的,可以看一下:
https://blog.csdn.net/HaoZiHuang/article/details/125641772
的前半部分,简单说下就是该函数用了 @property
装饰器
if args.model_path:para_state_dict = paddle.load(args.model_path)net.set_dict(para_state_dict)logger.info('Loaded trained params of model successfully.')if args.input_shape is None:shape = [None, 3, None, None]
else:shape = args.input_shape
接下来就是读入模型参数,经典的先 load
读入参数字典,然后再set_dict
如果 args.input_shape
有指定,则用指定的,没有则不用,因为静态图要指定shape, 所有默认为[None, 3, None, None]
这里插入一句:
什么是动态图和静态图?
在深度学习模型构建上,飞桨框架支持动态图编程和静态图编程两种方式,其代码编写和执行方式均存在差异。
动态图编程: 采用 Python 的编程风格,解析式地执行每一行网络代码,并同时返回计算结果。在 模型开发 章节中,介绍的都是动态图编程方式。
静态图编程: 采用先编译后执行的方式。需先在代码中预定义完整的神经网络结构,飞桨框架会将神经网络描述为 Program 的数据结构,并对 Program 进行编译优化,再调用执行器获得计算结果。
动态图编程体验更佳、更易调试,但是因为采用 Python 实时执行的方式,开销较大,在性能方面与 C++ 有一定差距;静态图调试难度大,但是将前端 Python 编写的神经网络预定义为 Program描述,转到 C++ 端重新解析执行,脱离了 Python 依赖,往往执行性能更佳,并且预先拥有完整网络结构也更利于全局优化。
以上摘自Paddle官方文档:
https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/jit/index_cn.html
if not args.without_argmax or args.with_softmax:new_net = SavedSegmentationNet(net, args.without_argmax,args.with_softmax)
else:new_net = net
如果不需要 without_argmax
和 with_softmax
参数,则直接返回之前的 net
,就无需后处理了
注意一个是 without,一个是with,PaddleSeg 导出这里,默认会给咱加上 argmax
new_net.eval()
new_net = paddle.jit.to_static(new_net,input_spec=[paddle.static.InputSpec(shape=shape, dtype='float32')])save_path = os.path.join(args.save_dir, 'model')
paddle.jit.save(new_net, save_path)
终于到了导出环节,后两行是导出静态图模型后的保存环节,注意保存函数为paddle.jit.save
而不像动态图可以这样保存:
param = model.state_dict()
path = 'model.pdparams'
paddle.save(param, path)
导出函数为 paddle.jit.to_static
, 第一个参数为动态图模型,第二个参数input_spec
用来指定,动态图模型输入的shape, 第三个参数 build_strategy
, 相对高级,我也没用过
参考自:
https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/jit/to_static_cn.html#to-static
这一步之后就导出完毕了,这个 to_static
可以用来做装饰器,这是官方的demo:
import paddle
from paddle.jit import to_static@to_static
def func(x):if paddle.mean(x) < 0:x_v = x - 1else:x_v = x + 1return x_vx = paddle.ones([1, 2], dtype='float32')
x_v = func(x)
print(x_v) # [[2. 2.]]
也就是说,可以给 forward
函数头上直接加一个装饰器@to_static
就可以加速
该装饰器将函数内的动态图API转化为静态图API。此装饰器自动处理静态图模式下的Program和Executor,并将结果作为动态图Tensor返回。
如果被装饰的函数里面调用其他动态图函数,被调用的函数也会被转化为静态图函数。
若 to_static 以装饰器形式使用,则被装饰函数默认会被解析为此参数值,无需显式指定。
yml_file = os.path.join(args.save_dir, 'deploy.yaml')with open(yml_file, 'w') as file:transforms = cfg.export_config.get('transforms', [{'type': 'Normalize'}])data = {'Deploy': {'transforms': transforms,'model': 'model.pdmodel','params': 'model.pdiparams'}}yaml.dump(data, file)logger.info(f'Model is saved in {args.save_dir}.')
最后几行就是写 Deploy.yml
文件,值得说的就一句话:
transforms = cfg.export_config.get('transforms', [{'type': 'Normalize'}])
看到这个 cfg.export_config
对象有个get方法,可以反着猜一下,也许是个dict
打印一下,他就是个dict
,进入Config 的源代码看看:
@property
def export_config(self) -> Dict:return self.dic.get('export', {})
显然,self.dic
就是那个原始的 yaml 读进来之后的字典,export_config
返回 export 那里字典
之后,cfg.export_config.get
, 如果没有 transforms
的 key, 则返回 {'type': 'Normalize'}
OK了,export.py 文件终于说完了
[PaddleSeg源码阅读] PaddleSeg 导出静态图 export.py 文件中的道道相关推荐
- [PaddleSeg源码阅读] PaddleSeg Validation 中添加 Boundary IoU的计算(3)——添加Boundary IoU
经过前面: PaddleSeg Validation 中添加 Boundary IoU的计算(1)--val.py文件细节提示 PaddleSeg Validation 中添加 Boundary Io ...
- [PaddleSeg 源码阅读] PaddleSeg计算 mIoU
这是我改成的numpy版本的,应该可以直接用,下边是 Paddle 版本的 def calculate_area(pred, label, num_classes, ignore_index=255) ...
- IDEA源码阅读利器 — UML类图插件Diagram
来源:https://www.cnblogs.com/deng-cc/p/6927447.html 最近正好也没什么可忙的,就回过头来鼓捣过去的知识点,到 Servlet 部分时,以前学习的时候硬是把 ...
- Flask源码阅读-第四篇(flask\app.py)
flask.app该模块2000多行代码,主要完成应用的配置.初始化.蓝图注册.请求装饰器定义.应用的启动和监听,其中以下方法可以重点品读和关注 def setupmethod(f): @setupm ...
- CI框架源码阅读笔记8 控制器Controller.php
最近时间有些紧,源码阅读系列更新有些慢.鉴于Controller中代码比较少,本次Blog先更新该文件的源码分析. 在经过路由分发之后,实际的应用Controller接管用户的所有请求,并负责与用户数 ...
- Dubbo注册协议原理以及源码阅读
前言 继上次小编所讲RPC协议暴露服务并且远程调用之后,小编这次给大家带来注册中心协议整体流程原理以及源码精讲,Dubbo协议服务暴露与引用以及源码分析文章中,远程服务暴露可以只通过RPC协议即可,那 ...
- Octopus 源码阅读(一)
Octopus 源码阅读--fs部分 开源代码 bitmap.cpp bitmap中的代码基本上没啥好说的,比较清楚.不过不解的是为什么在初始化的时候要统计freecount,理论上buffer不是应 ...
- SpringMVC源码阅读:过滤器
SpringMVC源码阅读:过滤器 目录 1.前言 2.源码分析 3.自定义过滤器 3.1 自定义过滤器继承OncePerRequestFilter 3.2 自定义过滤器实现Filter接口 4.过滤 ...
- 代码分析:NASM源码阅读笔记
NASM源码阅读笔记 NASM(Netwide Assembler)的使用文档和代码间的注释相当齐全,这给阅读源码 提供了很大的方便.按作者的说法,这是一个模块化的,可重用的x86汇编器, 而且能够被 ...
最新文章
- asp连接mysql odbc,在ASP中连接MySQL数据库的方法,最好的通过ODBC方法
- ansible的调用使用
- mstar v56几路hdmi_Android TV : Mstar平台 GPIO 调试
- 提交官方MapReduce作业到YARN
- 计算机行政考试题库,2014香港特别行政区计算机等级考试试题 二级ACCESS考试题库...
- Vue二次封装axios为插件使用
- 李开复:算法是内功,程序员别冷落算法!
- 关于Spring Cloud Config服务器介绍
- B2B 网关软件 以新颖的模式 让企业步入新常态
- BZOJ 1606: [Usaco2008 Dec]Hay For Sale 购买干草
- VS/C#添加chart控件
- C++ 详解快速排序代码
- 前端的长度单位有哪些
- HTML+CSS大作业——动画漫展学习资料电影模板(6页) 网页设计作业 _ 动漫网页设计作业,网页设计作业 _ 动漫网页设计成品,网页设计作业 _ 动漫网页设计成品模板下载
- 曾是谷歌程序员,抛下百万年薪创业,4年成就7亿用户,今身价百亿!
- ehvierwer登录与不登录_微信上不去了怎么办,峰哥教你微信登录不上去的办法
- 万兴pdf编辑解压后打不开_PDF文档无法编辑的原因和解决方案
- return的返回用法
- js 数组转json,json转数组
- c++中 #defin的基本意思
热门文章
- 超星集团武汉研发中心面试题
- castep 编译安装说明
- Cocos Creator子游戏动态下载实现(大厅+子游戏模式)
- 计算机里最常用的概念
- 已解决UnicodeEncodeError: ‘ascii‘ codec can‘t encode characters in position 18-20: ordinal not in range
- 华大HC32F460RTC时钟实验
- java简单atm_Java实现简单银行ATM功能
- 深入理解最强桌面地图控件GMAP.NET ---离线地图
- IE添加可信任站点,启用ActiveX插件批处理
- IOS9 UISearchBar详解