文章目录

  • 软件版本:
  • pytorch模型转onnx
  • onnx模型转tensorflow
  • tensorflow模型转tflite

软件版本:

tensorflow 2.3.1
pytorch 1.6.0
onnxruntime 1.8.1
cv2 4.5.3
onnx_tf 1.8.0
onnx 1.10.1

pytorch模型转onnx

import cv2
import numpy as np
import torch.onnx
import onnxruntime
import random# 为了保证pytorch每次输出结果相同
def set_seed(seed=1):random.seed(seed)np.random.seed(seed)torch.manaual_seed(seed)torch.cuda.manaual_seed(seed)torch.cuda.manaual_seed_all(seed)os.environ['PYTHONHASHSEED'] = str(seed)torch.backends.cudnn.determinstic = Truetorch.backends.cudnn.benchmark = Falsedef get_img_batch(img_path):# 具体预处理过程应根据模型训练代码具体确定,保持一致input_size = 224expand_size = int(input_size/0.875)img = cv2.imread(img_path)img = img[:,:,::-1]w,h = img.shape[1],img.shape[0]# equals to: transform.Resize(int), resize short side to int, keep ratioif w >= h:ratio = w / hw_ = expand_size * ratioh_ = expand_sizeelse:ratio = h / ww_ = expand_sizeh_ = expand_size * ratioh_,w_ = int(h_),int(w_)img = cv2.resize(img, (w_,h_)) # 注意顺序# equals to: transforms.CenterCrop(int), center square cropw, h = img.shape[1],img.shape[0]midx,midy=int(w/2),int(h/2)cropx,cropy=int(input_size/2),int(input_size/2)img = img[midy-cropy:midy+cropy, midx-cropx:midx+cropx]# normalizemean = torch.tensor([0.485*255,0.456*255,0.406*255]).view(1,3,1,1)std = torch.tensor([0.229*255,0.224*255,0.225*255]).view(1,3,1,1)img_batch = torch.from_numpy(img).float().unsqueeze(0) # 'float32' and expand dimsimg_batch = img_batch.permute(0,3,1,2)img_batch = img_batch.sub_(mean).div_(std)return img_batchdef load_torch_model(backbone_path):pretrained_dict = torch.load(backbone_path)net = models.__dict__['mobilenetv2'](width_mult=1.0)model_dict = net.state_dict()pretrained_dict = {k:v for k,v in pretrained_dict.items() if (k in model_dict)}model_dict.update(pretrained_dict)net.load_state_dict(model_dict)net.eval() # 重要!为了保证pytorch每次输出结果相同return netdef torch_to_onnx(torch_model):batch_size = 1input_shape = (3,224,224)x = torch.ones(batch_size, *input_shape)onnx_path = 'model.onnx'# export and save the modeltorch.onnx.export(torch_model,x,onnx_path,opset_version=12,input_names = ['input'],output_names = ['output'],)# 对比测试结果
def compare_torch_onnx(torch_model,onnx_sess,img_batch):sess_out = onnx_sess.run(None, {'input': img_batch.numpy()})sess_out = sess_out[0].flatten()sess_out = np.array(sess_out, dtype='float32')sess_out = torch.from_numpy(sess_out) # output featureonnx_pred = torch.nn.functional.softmax(sess_out, dim=0)onnx_index = np.argmax(onnx_pred).item() # output class indextorch_pred = torch_model(img_batch).detach().flatten() # featuretorch_pred = torch.nn.functional.softmax(torch_pred, dim=0)torch_pred = np.array(torch_pred, dtype='float32')torch_index = np,argmax(torch_pred).item() # index# 判断转换前后特征值差异np.testing.assert_almost_equal(torch_pred, onnx_pred, decimal=6)if __name__ == '__main__':set_seed()backbone_pth = 'model.pth.tar'onnx_model = onnxruntime.InferenceSeesion('model.onnx', None)torch_model = load_torch_model(backbone_pth)img_path = '1.jpg'img_batch = get_img_batch(img_path)# evaluation

onnx模型转tensorflow

import onnx
from onnx_tf.backend import preparefilename = 'model.onnx'
target_file_path = './tfmodel'
# load onnx model
onnx_model = onnx.load(filename)
tf_rep = prepare(onnx_model)
# save tf model to the path
tf_rep.export_graph(target_file_path)

tensorflow模型转tflite

# 因为上一步保存的模型文件已经是pb格式了,所以不用先转为pb,如果不是pb格式,参考:https://blog.csdn.net/qxqxqzzz/article/details/119668426?spm=1001.2014.3001.5501
def tf_tflite():tf_model_path, tflite_model_path = './tfmodel', 'model.tflite'converter = tf.lite.TFLiteCOnverter.from_saved_model(tf_model_path)converter.target_spec,supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINGS,tf.lite.OpsSet.SELECT_TF_OPS]tflite_model = converter.convert()with open(tflite_model_path, 'wb') as g:g.write(tflite_model)def tflite_prediction(img_batch):tflite_model = 'model.tflite'interpreter = tf.lite.Interpreter(model_path = tflite_model)interpreter.allocate_tensors()input_details = interpreter.get_input_details()output_details = interpreter.get_output_details()interpreter,set_tensor(input_details[0]['index'], img_batch)interpreter.invoke()tflite_pred = interpreter.get_tensor(output_details[0]['index']) # output featuretflite_pred = tf.convert_to_tensor(tflite_pred)tflite_pred = tf.nn.softmax(tflite_pred)print(tf.argmax(tflite_pred, 1)) # output class index

模型转换:pytorch模型转onnx, onnx转tensorflow, tensorflow转tflite相关推荐

  1. 关于在SNPE平台上进行ONNX模型转换DLC模型

    Onnx模型转化DLC模型 简介 在snpe平台上,将onnx模型转换为dlc模型 目录 snpe平台介绍 snpe平台与onnx配置 onnx模型转换dlc 模型量化 关于1.38版本SNPE部署时 ...

  2. 模型转换、模型压缩、模型加速工具汇总

    点击上方"计算机视觉工坊",选择"星标" 干货第一时间送达 编辑丨机器学习AI算法工程 一.场景需求解读   在现实场景中,我们经常会遇到这样一个问题,即某篇论 ...

  3. 【YoloV5 6.0|6.1 部署 TensorRT到torchserve】环境搭建|模型转换|engine模型部署(详细的packet文件编写方法)

    忽然发现,关于部署TensorRT的文章少的可怜,于是乎,决定分享一下我自己关于这部分内容的一些成功实操和心得.还是希望大家可以分享出去,让更多人看到!!! QQ: 1757093754 我的操作环境 ...

  4. sketchup 图片转模型_su模型转换3d模型(如何将3D模型转化为sketchup)

    求救求救~~~3dmax模型 导入su 模型的位置变了 非得... 第一步  全部放到一个图层然后打组 第二部   随便画个box体 第三部  选择建筑点击链接变换 第四步   点击拾取 选择box体 ...

  5. 【Python】Caffe 模型转换 Caffe2 模型 (支持多输入 / 多输出)

    Model Translator from Caffe to Caffe2 用于将 Caffe 模型转换为对应 Caffe2 模型的 Python 脚本 官方提供了一个基础版本,经修改和优化后,已支持 ...

  6. 如何将tensorflow模型转PYTORCH模型

    将tensorflow版本的.ckpt模型转成pytorch的.bin模型 - 最咸的鱼 - 博客园

  7. yolov5模型转换(pt=>onnx=>rknn)和板端验证测试

    测试环境说明: (1)由于模型转换工具需要onnx版本和rknn的tool工具需要的版本相互矛盾需要创建量开发环境,当前测试转换的模型是yolov5_v5.0的模型 (2)由于在搭建开发环境时还存在部 ...

  8. OpenCV转换PyTorch分类模型并使用OpenCV Python启动

    OpenCV转换PyTorch分类模型并使用OpenCV Python启动 转换PyTorch分类模型并使用OpenCV Python启动 目标 介绍 要求 实践 模型转换管道 模型评估 评估模式 测 ...

  9. 模型推理加速系列|如何用ONNX加速BERT特征抽取(附代码)

    简介 近期从事模型推理加速相关项目,所以抽空整理最近的学习经验.本次实验目的在于介绍如何使用ONNXRuntime加速BERT模型推理.实验中的任务是利用BERT抽取输入文本特征,至于BERT在下游任 ...

  10. 飞桨上线万能转换小工具,教你玩转TensorFlow、Caffe等模型迁移

    百度推出飞桨(PaddlePaddle)后,不少开发者开始转向国内的深度学习框架.但是从代码的转移谈何容易,之前的工作重写一遍不太现实,成千上万行代码的手工转换等于是在做一次二次开发. 现在,有个好消 ...

最新文章

  1. Closing Spring root WebApplicationContext
  2. 项目3----云服务器及其提供商
  3. leetcode算法题--完美数
  4. 读----------空乏的估算
  5. IDEA 2020.2 稳定版发布,带来了不少新功能...
  6. 【STM32】RTC相关函数和类型
  7. 推荐系统(3)-协同过滤2-矩阵分解算法
  8. python中的匿名函数lambda
  9. k8s边缘节点_边缘计算,如何啃下集群管理这块硬骨头?
  10. javascript-常用内置对象-随堂
  11. python使用缩进来体现-Python 使用缩进来体现代码之间的逻辑关系 .
  12. honeywell Xenon 1900 usb
  13. XRD测试常见问题及解答(三)
  14. 8位模型计算机设计与仿真
  15. 离散Hopfield神经网络的联想记忆—数字识别
  16. 爬虫入门(3)——拉钩网
  17. 华为交换机、路由器流量统计
  18. 微信关注二维码不显示
  19. win10 文件夹设置区分大小写
  20. 我的mybatis-plus用法,被全公司同事开始悄悄模仿了

热门文章

  1. PyTorch实现:经典网络 NiN
  2. 基于微信小程序的手游租号系统-毕业设计作品
  3. 通信工程专业课和毕业后的工作关系
  4. [转]Oracle 错误大全
  5. Codeforces 1088E Ehab and a component choosing problem(树形DP)
  6. 如何用c语言编写智能照明系统,基于STC89C52单片机的智能照明控制系统方案设计...
  7. PyEcharts学习笔记整理,基于B站千锋教育
  8. 5.软件测试-----自动化测试
  9. JSP文件放在WebContent下和放在WEB-INF下的区别
  10. Python 标准库 functools 模块详解