PyTorch转Caffe模型
PyTorch转Caffe模型
这里使用的工具来自:PytorchToCaffe,主体代码如下:
PytorchToCaffe-master
|——Caffe
| |——caffe.proto
| |——caffe_pb2.py
| |——caffe_net.py
| |——layer_param.py
|——example
| |——xxx_pytorch_to_caffe.py
|——pytorch_to_caffe.py
使用前需要将caffe.proto和caffe_pb2.py替换成自己的。
一、代码执行入口:xxx_pytorch_to_caffe.py
以alexnet_pytorch_to_caffe.py为例,代码如下:
import sys
sys.path.insert(0,'.')
import torch
from torch.autograd import Variable
from torchvision.models.alexnet import alexnet
import pytorch_to_caffeif __name__=='__main__':name='alexnet'net=alexnet(True)input=Variable(torch.ones([1,3,226,226]))pytorch_to_caffe.trans_net(net,input,name)pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))
这块是整个转模型代码的入口,首先定义自己的网络对象和一个输入(假数据),调用trans_net生成xxx.prototxt和xxx.caffemodel,最后保存。
二、主体功能代码:pytorch_to_caffe.py
1、trans_net
def trans_net(net,input_var,name='TransferedPytorchModel'):print('Starting Transform, This will take a while')log.init([input_var])log.cnet.net.name=namelog.cnet.net.input.extend([log.blobs(input_var)])log.cnet.net.input_dim.extend(input_var.size())global NET_INITTEDNET_INITTED=Truefor name,layer in net.named_modules():layer_names[layer]=nameprint("torch ops name:", layer_names)out = net.forward(input_var)print('Transform Completed')
log对象记录了转换的caffe模型中各层的名字、网络结构、模型参数等信息,起初先填写xxx.prototxt的输入dim信息。
log的类如下:
class TransLog(object):def __init__(self):"""doing init() with inputs Variable before using it"""self.layers={}self.detail_layers={} self.detail_blobs={} self._blobs=Blob_LOG()self._blobs_data=[]self.cnet=caffe_net.Caffemodel('')self.debug=Truedef init(self,inputs):""":param inputs: is a list of input variables"""self.add_blobs(inputs)#后面代码省略。。。
然后调用net对象的forward函数执行一次推理,在推理过程中会按照网络结构分别调用各层的底层函数,如运算到卷积层(nn.Conv2d)时,会调用torch.nn.functional.conv2d,但在pytorch_to_caffe.py中torch.nn.functional.conv2d被替换成了自定义的函数_conv2d,在该函数中会根据pytorch中定义的卷积层信息生成caffe的卷积层网络结构信息,同时保存卷积层权重和偏置参数:
def _conv2d(raw,input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):#通过卷积层的输入blob对应的标识符查找其blob名并打印,该blob名会在上一层中被记录print('conv: ',log.blobs(input))x=raw(input,weight,bias,stride,padding,dilation,groups)name=log.add_layer(name='conv')#记录该卷积层输出blob名log.add_blobs([x],name='conv_blob')#定义caffe的xxx.prototxt中的网络层layer=caffe_net.Layer_param(name=name, type='Convolution',bottom=[log.blobs(input)], top=[log.blobs(x)])layer.conv_param(x.size()[1],weight.size()[2:],stride=_pair(stride),pad=_pair(padding),dilation=_pair(dilation),bias_term=bias is not None,groups=groups)#保存卷积层权重和偏置参数if bias is not None:layer.add_data(weight.cpu().data.numpy(),bias.cpu().data.numpy())else:layer.param.convolution_param.bias_term=Falselayer.add_data(weight.cpu().data.numpy())log.cnet.add_layer(layer)return x
F.conv2d=Rp(F.conv2d,_conv2d)
2、Rp
# 核心组件,通过该类,实现对torch的function中的operators的输入,输出以及参数的读取
class Rp(object):def __init__(self,raw,replace,**kwargs):# replace the raw function to replace functionself.obj=replaceself.raw=rawdef __call__(self,*args,**kwargs):if not NET_INITTED:return self.raw(*args,**kwargs)for stack in traceback.walk_stack(None):if 'self' in stack[0].f_locals:layer=stack[0].f_locals['self']if layer in layer_names:log.pytorch_layer_name=layer_names[layer]print(layer_names[layer])breakout=self.obj(self.raw,*args,**kwargs)# if isinstance(out,Variable):# out=[out]return out
实现自定义函数和torch.nn.functional中的函数的替换,如:
F.conv2d=Rp(F.conv2d,_conv2d)
F.linear=Rp(F.linear,_linear)
F.relu=Rp(F.relu,_relu)
F.leaky_relu=Rp(F.leaky_relu,_leaky_relu)
F.max_pool2d=Rp(F.max_pool2d,_max_pool2d)
F.avg_pool2d=Rp(F.avg_pool2d,_avg_pool2d)
F.adaptive_avg_pool2d = Rp(F.adaptive_avg_pool2d,_adaptive_avg_pool2d)
F.dropout=Rp(F.dropout,_dropout)
F.threshold=Rp(F.threshold,_threshold)
F.prelu=Rp(F.prelu,_prelu)
F.batch_norm=Rp(F.batch_norm,_batch_norm)
F.instance_norm=Rp(F.instance_norm,_instance_norm)
F.softmax=Rp(F.softmax,_softmax)
F.conv_transpose2d=Rp(F.conv_transpose2d,_conv_transpose2d)
F.interpolate = Rp(F.interpolate,_interpolate)
F.sigmoid = Rp(F.sigmoid,_sigmoid)
F.tanh = Rp(F.tanh,_tanh)
F.tanh = Rp(F.tanh,_tanh)
F.hardtanh = Rp(F.hardtanh,_hardtanh)
此外,有些版本的pytorch中的个别层底层调用的不是torch.nn.functional中的函数,而是torch,如sigmoid等,所以此时就要替换torch中的函数,如:
torch.split=Rp(torch.split,_split)
torch.max=Rp(torch.max,_max)
torch.cat=Rp(torch.cat,_cat)
torch.div=Rp(torch.div,_div)
三、遇到过的问题
1、ModuleNotFoundError: No module named 'google'
原因:python中没有安装protobuf
2、AttributeError: 'PoolingParameter' object has no attribute 'ceil_mode'
原因:使用的caffe中pooling层没有ceil_mode参数
解决方法:在caffe的pooling层添加该参数及相应的源码
1)在caffe.protode PoolingParameter中添加ceil_mode
2)修改pooling_layer.hpp中PoolingLayer类
3、TypeError: _avg_pool2d() takes from 3 to 7 postinal arguments but 8 were given
原因:avg pool层参数不对
解决方法:修改pytorch_to_caffe.py中_avg_pool2d
#def _avg_pool2d(raw,input, kernel_size, stride = None, padding = 0, ceil_mode = False, count_include_pad = True):
def _avg_pool2d(raw,input, kernel_size, stride = None, padding = 0, ceil_mode = False, count_include_pad = True, divisor_override = None):x = raw(input, kernel_size, stride, padding, ceil_mode, count_include_pad)_pool('ave',raw,input, x, kernel_size, stride, padding,ceil_mode)return x
4、如果训练时模型开了多卡训练,推理时也需要这样,即net = nn.DataParallel(net).cuda(),但由于数据被拷贝了多份,第一个卷积层拿到的blob地址不再是输入的了,导致找不到这个blob。
解决方法:去掉net = nn.DataParallel(net).cuda(),且将
net.load_state_dict(checkpoint)
checkpoint = torch.load("xxx.ckpt")
换成net.load_state_dict({k.replace('module.',''):v for k,v in torch.load("xxx.ckpt").items()})
5、当前版本的sigmoid是通过torch.sigmoid()实现的,不是F.sigmoid,所以需要添加torch.sigmoid = Rp(torch.sigmoid,_sigmoid),原先只有F.sigmoid = Rp(F.sigmoid,_sigmoid)
6、layer_name = log.add_layer(name="expand_as")#, with_num=True),log.add_layer参数没有with_num
7、当前的caffe不支持双线性插值层,自定义了一个插值层。重新定义_interpolate2()来取代F.interpolate,即:F.interpolate = Rp(F.interpolate,_interpolate2),_interpolate2()代码如下:
def _interpolate2(raw, input,size=None, scale_factor=None, mode='nearest', align_corners=None):x = raw(input,size , scale_factor ,mode)layer_name = log.add_layer(name='interpolate')top_blobs = log.add_blobs([x], name='interpolate_blob'.format(type))layer = caffe_net.Layer_param(name=layer_name, type='Interp',bottom=[log.blobs(input)], top=top_blobs)layer.interp_param(scale_factor= scale_factor)log.cnet.add_layer(layer)return x
8、1.9版本pytorch的torch.cat中的维度参数名称为dim,不是dimension,因此将def _cat(raw,inputs, dimension=0)换成def _cat(raw,inputs, dim=0)。
PyTorch转Caffe模型相关推荐
- linux caffe生成的模型,深度学习之pytorch转caffe转ncnn模型转换(三)
搭建caffe平台: 先在Linux系统下搭建caffe环境,安装依赖包: sudo apt-get install libprotobuf-dev libleveldb-dev libsnappy- ...
- 【pytorch速成】Pytorch图像分类从模型自定义到测试
文章首发于微信公众号<与有三学AI> [pytorch速成]Pytorch图像分类从模型自定义到测试 前面已跟大家介绍了Caffe和TensorFlow,链接如下. [caffe速成]ca ...
- 一步一步教你如何将 yolov3/yolov4 转为 caffe 模型
实际工作中,目标检测 yolov3 或者 yolov4 模型移植到 AI 芯片中,经常需要将其先转换为 caffe1.x 模型,大家可能或多或少也有这方面的需求.例如华为海思 NNIE 只支持caff ...
- 【深度学习】Keras vs PyTorch vs Caffe:CNN实现对比
作者 | PRUDHVI VARMA 编译 | VK 来源 | Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度 ...
- Pytorch版本YOLOv3模型转Darknet weights模型然后转caffemodel再转wk模型在nnie上面推理
Pytorch版本YOLOv3模型转darknet weights模型然后转caffemodel再转wk模型在nnie上面推理 文章目录 Pytorch版本YOLOv3模型转darknet weigh ...
- 从.caffemodel/.caffemodel.h5提取Caffe模型的参数
系列博客目录:Caffe转Pytorch模型系列教程 概述 目录 一.通用的提取参数方法 1.编译Caffe 2.打印.caffemodel的网络参数 3.保存.caffemodel的网络参数 二.提 ...
- Keras vs PyTorch vs Caffe:CNN实现对比
作者|PRUDHVI VARMA 编译|VK 来源|Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度学习框架提供 ...
- 【Python】Caffe 模型转换 Caffe2 模型 (支持多输入 / 多输出)
Model Translator from Caffe to Caffe2 用于将 Caffe 模型转换为对应 Caffe2 模型的 Python 脚本 官方提供了一个基础版本,经修改和优化后,已支持 ...
- pytorch转caffe的一些经历
caffe是比较老的框架了,pytorch还不火的时候,还是比较流行的,有些比较著名的如人脸识别网络如centerloss,目标检测网络mtcnn.ssd,OCR识别都有对应的caffe版本.但有几个 ...
最新文章
- 增强迪基-福勒检验(ADF检验、augmented Dickey-Fuller test)是什么?解决了什么问题?
- (已解决)IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY。Someone could be eavesdropping on you
- ios视频硬解异常总结,12911总结
- rsync命令使用方法
- 家用使用计算机组装,不能再简单了!家用电脑DIY组装实操
- 指定范围内每个数的所有真约数
- 一起来看小米发布会!
- 时间操作(Java版)—获取给定时间与当前系统时间的差值(以毫秒为单位)
- Java的GUI学习六(Action事件)
- 国民经济行业分类 GB/T 4754-2017 最新2017版 代码整理
- 2016天猫快消母婴行业双11商家大会
- Unity使用Aspose.Words创建表格和UI截图一起插入到Word中并保存到本地的一种解决方案
- JAVANBA论坛系统计算机毕业设计Mybatis+系统+数据库+调试部署
- 普陀寺里的穿白T恤的奥特曼 2012年9月8日
- [Codeforces 274E]:Mirror Room(模拟)
- Next.js基本使用
- 《Python编程从入门到实践》(第2版)第二章 习题答案
- 西瓜口袋拼团商城系统搭建相关问题
- sports.php什么意思,使用Yahoo Fantasy Sports API的PHP和JSON
- flash+AS3制作的倒计时效果
热门文章
- 行稳致远,共建IDC产业优质生态圈
- LNG储罐用什么材料保温?
- 【Java SE】(五)方法和递归
- 如何使用家庭网络运行Aleo Prover
- 计算机名称内的隶属于无法更改,win7计算机名称怎么修改-修改win7计算机名称的方法 - 河东软件园...
- Android_弹钢琴(多媒体应用)
- CentOS8.5版本安装Docker报错,版本太新太多坑了
- 怎么把一张暗的照片调亮_太暗的照片怎么调整?
- css 划对号,css3画个圆圈里的对号
- EasyExcel自定义Converter解决性别转换问题