PyTorch转Caffe模型

这里使用的工具来自:PytorchToCaffe,主体代码如下:
PytorchToCaffe-master
|——Caffe
|           |——caffe.proto
|           |——caffe_pb2.py
|           |——caffe_net.py
|           |——layer_param.py
|——example
|           |——xxx_pytorch_to_caffe.py
|——pytorch_to_caffe.py

使用前需要将caffe.proto和caffe_pb2.py替换成自己的。

一、代码执行入口:xxx_pytorch_to_caffe.py

以alexnet_pytorch_to_caffe.py为例,代码如下:

import sys
sys.path.insert(0,'.')
import torch
from torch.autograd import Variable
from torchvision.models.alexnet import alexnet
import pytorch_to_caffeif __name__=='__main__':name='alexnet'net=alexnet(True)input=Variable(torch.ones([1,3,226,226]))pytorch_to_caffe.trans_net(net,input,name)pytorch_to_caffe.save_prototxt('{}.prototxt'.format(name))pytorch_to_caffe.save_caffemodel('{}.caffemodel'.format(name))

这块是整个转模型代码的入口,首先定义自己的网络对象和一个输入(假数据),调用trans_net生成xxx.prototxt和xxx.caffemodel,最后保存。

二、主体功能代码:pytorch_to_caffe.py

1、trans_net

def trans_net(net,input_var,name='TransferedPytorchModel'):print('Starting Transform, This will take a while')log.init([input_var])log.cnet.net.name=namelog.cnet.net.input.extend([log.blobs(input_var)])log.cnet.net.input_dim.extend(input_var.size())global NET_INITTEDNET_INITTED=Truefor name,layer in net.named_modules():layer_names[layer]=nameprint("torch ops name:", layer_names)out = net.forward(input_var)print('Transform Completed')

log对象记录了转换的caffe模型中各层的名字、网络结构、模型参数等信息,起初先填写xxx.prototxt的输入dim信息。
log的类如下:

class TransLog(object):def __init__(self):"""doing init() with inputs Variable before using it"""self.layers={}self.detail_layers={}  self.detail_blobs={}  self._blobs=Blob_LOG()self._blobs_data=[]self.cnet=caffe_net.Caffemodel('')self.debug=Truedef init(self,inputs):""":param inputs: is a list of input variables"""self.add_blobs(inputs)#后面代码省略。。。

然后调用net对象的forward函数执行一次推理,在推理过程中会按照网络结构分别调用各层的底层函数,如运算到卷积层(nn.Conv2d)时,会调用torch.nn.functional.conv2d,但在pytorch_to_caffe.py中torch.nn.functional.conv2d被替换成了自定义的函数_conv2d,在该函数中会根据pytorch中定义的卷积层信息生成caffe的卷积层网络结构信息,同时保存卷积层权重和偏置参数:

def _conv2d(raw,input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):#通过卷积层的输入blob对应的标识符查找其blob名并打印,该blob名会在上一层中被记录print('conv: ',log.blobs(input))x=raw(input,weight,bias,stride,padding,dilation,groups)name=log.add_layer(name='conv')#记录该卷积层输出blob名log.add_blobs([x],name='conv_blob')#定义caffe的xxx.prototxt中的网络层layer=caffe_net.Layer_param(name=name, type='Convolution',bottom=[log.blobs(input)], top=[log.blobs(x)])layer.conv_param(x.size()[1],weight.size()[2:],stride=_pair(stride),pad=_pair(padding),dilation=_pair(dilation),bias_term=bias is not None,groups=groups)#保存卷积层权重和偏置参数if bias is not None:layer.add_data(weight.cpu().data.numpy(),bias.cpu().data.numpy())else:layer.param.convolution_param.bias_term=Falselayer.add_data(weight.cpu().data.numpy())log.cnet.add_layer(layer)return x
F.conv2d=Rp(F.conv2d,_conv2d)

2、Rp

# 核心组件,通过该类,实现对torch的function中的operators的输入,输出以及参数的读取
class Rp(object):def __init__(self,raw,replace,**kwargs):# replace the raw function to replace functionself.obj=replaceself.raw=rawdef __call__(self,*args,**kwargs):if not NET_INITTED:return self.raw(*args,**kwargs)for stack in traceback.walk_stack(None):if 'self' in stack[0].f_locals:layer=stack[0].f_locals['self']if layer in layer_names:log.pytorch_layer_name=layer_names[layer]print(layer_names[layer])breakout=self.obj(self.raw,*args,**kwargs)# if isinstance(out,Variable):#     out=[out]return out

实现自定义函数和torch.nn.functional中的函数的替换,如:

F.conv2d=Rp(F.conv2d,_conv2d)
F.linear=Rp(F.linear,_linear)
F.relu=Rp(F.relu,_relu)
F.leaky_relu=Rp(F.leaky_relu,_leaky_relu)
F.max_pool2d=Rp(F.max_pool2d,_max_pool2d)
F.avg_pool2d=Rp(F.avg_pool2d,_avg_pool2d)
F.adaptive_avg_pool2d = Rp(F.adaptive_avg_pool2d,_adaptive_avg_pool2d)
F.dropout=Rp(F.dropout,_dropout)
F.threshold=Rp(F.threshold,_threshold)
F.prelu=Rp(F.prelu,_prelu)
F.batch_norm=Rp(F.batch_norm,_batch_norm)
F.instance_norm=Rp(F.instance_norm,_instance_norm)
F.softmax=Rp(F.softmax,_softmax)
F.conv_transpose2d=Rp(F.conv_transpose2d,_conv_transpose2d)
F.interpolate = Rp(F.interpolate,_interpolate)
F.sigmoid = Rp(F.sigmoid,_sigmoid)
F.tanh = Rp(F.tanh,_tanh)
F.tanh = Rp(F.tanh,_tanh)
F.hardtanh = Rp(F.hardtanh,_hardtanh)

此外,有些版本的pytorch中的个别层底层调用的不是torch.nn.functional中的函数,而是torch,如sigmoid等,所以此时就要替换torch中的函数,如:

torch.split=Rp(torch.split,_split)
torch.max=Rp(torch.max,_max)
torch.cat=Rp(torch.cat,_cat)
torch.div=Rp(torch.div,_div)

三、遇到过的问题

1、ModuleNotFoundError: No module named 'google'
原因:python中没有安装protobuf

2、AttributeError: 'PoolingParameter' object has no attribute 'ceil_mode'
原因:使用的caffe中pooling层没有ceil_mode参数
解决方法:在caffe的pooling层添加该参数及相应的源码
1)在caffe.protode PoolingParameter中添加ceil_mode
2)修改pooling_layer.hpp中PoolingLayer类

3、TypeError: _avg_pool2d() takes from 3 to 7 postinal arguments but 8 were given
原因:avg pool层参数不对
解决方法:修改pytorch_to_caffe.py中_avg_pool2d

#def _avg_pool2d(raw,input, kernel_size, stride = None, padding = 0, ceil_mode = False, count_include_pad = True):
def _avg_pool2d(raw,input, kernel_size, stride = None, padding = 0, ceil_mode = False, count_include_pad = True, divisor_override = None):x = raw(input, kernel_size, stride, padding, ceil_mode, count_include_pad)_pool('ave',raw,input, x, kernel_size, stride, padding,ceil_mode)return x

4、如果训练时模型开了多卡训练,推理时也需要这样,即net = nn.DataParallel(net).cuda(),但由于数据被拷贝了多份,第一个卷积层拿到的blob地址不再是输入的了,导致找不到这个blob。
解决方法:去掉net = nn.DataParallel(net).cuda(),且将
net.load_state_dict(checkpoint)
checkpoint = torch.load("xxx.ckpt")
换成net.load_state_dict({k.replace('module.',''):v for k,v in torch.load("xxx.ckpt").items()})

5、当前版本的sigmoid是通过torch.sigmoid()实现的,不是F.sigmoid,所以需要添加torch.sigmoid = Rp(torch.sigmoid,_sigmoid),原先只有F.sigmoid = Rp(F.sigmoid,_sigmoid)

6、layer_name = log.add_layer(name="expand_as")#, with_num=True),log.add_layer参数没有with_num

7、当前的caffe不支持双线性插值层,自定义了一个插值层。重新定义_interpolate2()来取代F.interpolate,即:F.interpolate = Rp(F.interpolate,_interpolate2),_interpolate2()代码如下:

def _interpolate2(raw, input,size=None, scale_factor=None, mode='nearest', align_corners=None):x = raw(input,size , scale_factor ,mode)layer_name = log.add_layer(name='interpolate')top_blobs = log.add_blobs([x], name='interpolate_blob'.format(type))layer = caffe_net.Layer_param(name=layer_name, type='Interp',bottom=[log.blobs(input)], top=top_blobs)layer.interp_param(scale_factor= scale_factor)log.cnet.add_layer(layer)return x

8、1.9版本pytorch的torch.cat中的维度参数名称为dim,不是dimension,因此将def _cat(raw,inputs, dimension=0)换成def _cat(raw,inputs, dim=0)。

PyTorch转Caffe模型相关推荐

  1. linux caffe生成的模型,深度学习之pytorch转caffe转ncnn模型转换(三)

    搭建caffe平台: 先在Linux系统下搭建caffe环境,安装依赖包: sudo apt-get install libprotobuf-dev libleveldb-dev libsnappy- ...

  2. 【pytorch速成】Pytorch图像分类从模型自定义到测试

    文章首发于微信公众号<与有三学AI> [pytorch速成]Pytorch图像分类从模型自定义到测试 前面已跟大家介绍了Caffe和TensorFlow,链接如下. [caffe速成]ca ...

  3. 一步一步教你如何将 yolov3/yolov4 转为 caffe 模型

    实际工作中,目标检测 yolov3 或者 yolov4 模型移植到 AI 芯片中,经常需要将其先转换为 caffe1.x 模型,大家可能或多或少也有这方面的需求.例如华为海思 NNIE 只支持caff ...

  4. 【深度学习】Keras vs PyTorch vs Caffe:CNN实现对比

    作者 | PRUDHVI VARMA 编译 | VK 来源 | Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度 ...

  5. Pytorch版本YOLOv3模型转Darknet weights模型然后转caffemodel再转wk模型在nnie上面推理

    Pytorch版本YOLOv3模型转darknet weights模型然后转caffemodel再转wk模型在nnie上面推理 文章目录 Pytorch版本YOLOv3模型转darknet weigh ...

  6. 从.caffemodel/.caffemodel.h5提取Caffe模型的参数

    系列博客目录:Caffe转Pytorch模型系列教程 概述 目录 一.通用的提取参数方法 1.编译Caffe 2.打印.caffemodel的网络参数 3.保存.caffemodel的网络参数 二.提 ...

  7. Keras vs PyTorch vs Caffe:CNN实现对比

    作者|PRUDHVI VARMA 编译|VK 来源|Analytics Indiamag 在当今世界,人工智能已被大多数商业运作所应用,而且由于先进的深度学习框架,它非常容易部署.这些深度学习框架提供 ...

  8. 【Python】Caffe 模型转换 Caffe2 模型 (支持多输入 / 多输出)

    Model Translator from Caffe to Caffe2 用于将 Caffe 模型转换为对应 Caffe2 模型的 Python 脚本 官方提供了一个基础版本,经修改和优化后,已支持 ...

  9. pytorch转caffe的一些经历

    caffe是比较老的框架了,pytorch还不火的时候,还是比较流行的,有些比较著名的如人脸识别网络如centerloss,目标检测网络mtcnn.ssd,OCR识别都有对应的caffe版本.但有几个 ...

最新文章

  1. 增强迪基-福勒检验(ADF检验、augmented Dickey-Fuller test)是什么?解决了什么问题?
  2. (已解决)IT IS POSSIBLE THAT SOMEONE IS DOING SOMETHING NASTY。Someone could be eavesdropping on you
  3. ios视频硬解异常总结,12911总结
  4. rsync命令使用方法
  5. 家用使用计算机组装,不能再简单了!家用电脑DIY组装实操
  6. 指定范围内每个数的所有真约数
  7. 一起来看小米发布会!
  8. 时间操作(Java版)—获取给定时间与当前系统时间的差值(以毫秒为单位)
  9. Java的GUI学习六(Action事件)
  10. 国民经济行业分类 GB/T 4754-2017 最新2017版 代码整理
  11. 2016天猫快消母婴行业双11商家大会
  12. Unity使用Aspose.Words创建表格和UI截图一起插入到Word中并保存到本地的一种解决方案
  13. JAVANBA论坛系统计算机毕业设计Mybatis+系统+数据库+调试部署
  14. 普陀寺里的穿白T恤的奥特曼 2012年9月8日
  15. [Codeforces 274E]:Mirror Room(模拟)
  16. Next.js基本使用
  17. 《Python编程从入门到实践》(第2版)第二章 习题答案
  18. 西瓜口袋拼团商城系统搭建相关问题
  19. sports.php什么意思,使用Yahoo Fantasy Sports API的PHP和JSON
  20. flash+AS3制作的倒计时效果

热门文章

  1. 行稳致远,共建IDC产业优质生态圈
  2. LNG储罐用什么材料保温?
  3. 【Java SE】(五)方法和递归
  4. 如何使用家庭网络运行Aleo Prover
  5. 计算机名称内的隶属于无法更改,win7计算机名称怎么修改-修改win7计算机名称的方法 - 河东软件园...
  6. Android_弹钢琴(多媒体应用)
  7. CentOS8.5版本安装Docker报错,版本太新太多坑了
  8. 怎么把一张暗的照片调亮_太暗的照片怎么调整?
  9. css 划对号,css3画个圆圈里的对号
  10. EasyExcel自定义Converter解决性别转换问题