2022.08.16

本文使用Tensorflow中集成的TensorRT进行模型转换。

不需要安装TensorRT的Python库,但是TensorRT还需要安装,需要用到的包是libnvinfer,如若不安装TensorRT,需要使用apt-get install libnvinfer进行安装。

关于TensorRT的安装,请参考:

为Tensorflow安装TensorRT(tar)_January_Cao的博客-CSDN博客

0. 文件夹构成

├── compare_models.py
├── convert.py
├── data

│   ├── download_images.sh
│   ├── img0.jpg
│   ├── img1.jpg
│   ├── img2.jpg
│   └── img3.jpg

├── model_check.py
└── save_model.py

1. 准备数据

也可以随意从网上下载图片。

或者执行【sh download_images.sh】下载图片至data文件夹

wget -O img0.jpg https://res.cloudinary.com/roundglass/image/upload/q_auto/ar_16:9,c_fill,w_1250/f_auto,g_auto/v1639561104/rg/collective/media/vyy8zc9ww3gqsjkqy9gk.jpgwget -O img1.jpg https://www.fundacion-affinity.org/sites/default/files/los-10-sonidos-principales-del-gato-fa.jpgwget -O img2.jpg https://hips.hearstapps.com/hmg-prod.s3.amazonaws.com/images/dog-puppy-on-garden-royalty-free-image-1586966191.jpg?crop=1.00xw:0.669xh;0,0.190xh&resize=1200:*wget -O img3.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/0/05/Parrot.jpg/1280px-Parrot.jpg

2. 模型数据下载和推理,保存模型为saved_model

执行以下代码,可以看到对图片的推理结果,并把模型以saved_model的形式保存到‘resnet50_saved_model’路径下。

# save_model.py
import numpy as npfrom tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions# load model
model = ResNet50(weights='imagenet')for i in range(4):img_path = './data/img{}.jpg'.format(i)img = image.load_img(img_path, target_size=(224, 224))x = image.img_to_array(img)x = np.expand_dims(x, axis=0)x = preprocess_input(x)# predictionpreds = model.predict(x)print('{} - Predicted: {}'.format(img_path, decode_predictions(preds, top=3)[0]))# save model to resnet50_saved_model path
model.save('resnet50_saved_model') 

3. 查看结果

1/1 [==============================] - 2s 2s/step
./data/img0.jpg - Predicted: [('n01742172', 'boa_constrictor', 0.8889929), ('n01729322', 'hognose_snake', 0.04273584), ('n01734418', 'king_snake', 0.032792028)]
1/1 [==============================] - 0s 14ms/step
./data/img1.jpg - Predicted: [('n02124075', 'Egyptian_cat', 0.53852856), ('n02127052', 'lynx', 0.19898652), ('n02123159', 'tiger_cat', 0.1614549)]
1/1 [==============================] - 0s 14ms/step
./data/img2.jpg - Predicted: [('n02113799', 'standard_poodle', 0.5422854), ('n02093647', 'Bedlington_terrier', 0.44937962), ('n02091134', 'whippet', 0.002195063)]
1/1 [==============================] - 0s 14ms/step
./data/img3.jpg - Predicted: [('n01818515', 'macaw', 0.7934733), ('n01820546', 'lorikeet', 0.11080903), ('n01828970', 'bee_eater', 0.08086483)]

4. 利用保存的saved_model推测结果

导入刚刚保存的‘resnet50_saved_model’,结果和之前一致。

# model_check.py
import numpy as np
import tensorflow as tffrom tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions# load model
model = tf.keras.models.load_model('resnet50_saved_model')for i in range(4):img_path = './data/img{}.jpg'.format(i)img = image.load_img(img_path, target_size=(224, 224))x = image.img_to_array(img)x = np.expand_dims(x, axis=0)x = preprocess_input(x)# predictionpreds = model.predict(x)print('{} - Predicted: {}'.format(img_path, decode_predictions(preds, top=3)[0]))

1/1 [==============================] - 2s 2s/step
./data/img0.jpg - Predicted: [('n01742172', 'boa_constrictor', 0.8889929), ('n01729322', 'hognose_snake', 0.04273584), ('n01734418', 'king_snake', 0.032792028)]
1/1 [==============================] - 0s 20ms/step
./data/img1.jpg - Predicted: [('n02124075', 'Egyptian_cat', 0.53852856), ('n02127052', 'lynx', 0.19898652), ('n02123159', 'tiger_cat', 0.1614549)]
1/1 [==============================] - 0s 15ms/step
./data/img2.jpg - Predicted: [('n02113799', 'standard_poodle', 0.5422854), ('n02093647', 'Bedlington_terrier', 0.44937962), ('n02091134', 'whippet', 0.002195063)]
1/1 [==============================] - 0s 16ms/step
./data/img3.jpg - Predicted: [('n01818515', 'macaw', 0.7934733), ('n01820546', 'lorikeet', 0.11080903), ('n01828970', 'bee_eater', 0.08086483)]

5. 转换为精度为Float32模型

执行以下代码把刚刚的‘resnet50_saved_model’利用TensorRT优化模型,并把优化好的模型保存在‘resnet50_saved_model_TFTRT_FP32’

# convert.py
from tensorflow.python.compiler.tensorrt import trt_convert as trtconversion_params = trt.DEFAULT_TRT_CONVERSION_PARAMS._replace(precision_mode=trt.TrtPrecisionMode.FP32)converter = trt.TrtGraphConverterV2(input_saved_model_dir='resnet50_saved_model',conversion_params=conversion_params)
converter.convert()
converter.save(output_saved_model_dir='resnet50_saved_model_TFTRT_FP32')

6. 运行时间对比

比较保存的saved_model‘resnet50_saved_model’和用TensorRT优化后的模型‘resnet50_saved_model_TFTRT_FP32’的Throughput(每秒钟处理的图像数),结果如下,优化后的模型是之前的1.7倍。

resnet50_saved_model resnet50_saved_model_TFTRT_FP32
Throughput[images/s] 340 577
# compare_models.py
import timeimport numpy as npimport tensorflow as tf
from tensorflow.python.saved_model import tag_constants, signature_constants
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.resnet50 import preprocess_inputmodels = ['resnet50_saved_model', 'resnet50_saved_model_TFTRT_FP32']for model in models:saved_model = tf.keras.models.load_model(model)infer = saved_model.signatures['serving_default']batch_size = 8batched_input = np.zeros((batch_size, 224, 224, 3), dtype=np.float32)for i in range(batch_size):img_path = 'data/img%d.jpg' % (i % 4)img = image.load_img(img_path, target_size=(224, 224))x = image.img_to_array(img)x = np.expand_dims(x, axis=0)x = preprocess_input(x)batched_input[i, :] = xbatched_input = tf.constant(batched_input)N_warmup_run = 50N_run = 1000elapsed_time = []for i in range(N_warmup_run):preds = infer(batched_input)for i in range(N_run):start_time = time.time()preds = infer(batched_input)end_time = time.time()elapsed_time = np.append(elapsed_time, end_time - start_time)print('Throughput: {:.0f} images/s'.format(N_run * batch_size / elapsed_time.sum()))

如果把精度改为FLOAT16之后,处理速度会更快。

利用TensorRT转换ResNet50相关推荐

  1. 利用TensorRT实现神经网络提速(读取ONNX模型并运行)

    前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家.点击跳转到网站. 前言 这篇文章接着上一篇继续讲解如何具体使用TensorRT. 在之前已经写到过一篇去介绍什么是Te ...

  2. Java 格式转换:利用格式转换实现随机数生成随机 char 字母及 string 字母串

    文章目录 前言 一.char 型与数值型转换规则 二.随机数生成随机字母 三.随机数生成随机字母串 总结 前言 我们都知道在 Java 语言中有八种基本数据类型,而不同数据类型之间的转换你了解多少呢? ...

  3. js 利用canvas转换图片格式并下载图片

    1.利用canvas转换格式 思路很简单,就是在canvas上drawImage,然后再把canvas转换成想要的图片格式 convertImageToCanvas = (image) => { ...

  4. matlab计算单模光纤耦合效率的积分,一种利用模式转换提高单模光纤耦合效率的方法与流程...

    本发明属于无线光通信技术领域,具体涉及一种利用模式转换提高单模光纤耦合效率的方法. 背景技术: 无线光(Free Space Optical Communication,FSOC)通信是一种以光为信号 ...

  5. Android录屏并利用FFmpeg转换成gif(一)录屏

    Android录屏并利用FFmpeg转换成gif(一) 录屏 写博客时经常会希望用一段动画来演示app的行为,目前大多数的做法是在电脑上开模拟器,然后用gif录制软件录制模拟器屏幕,对于非开发人员来讲 ...

  6. Android录屏并利用FFmpeg转换成gif(二)交叉编译FFmpeg源码

    Android录屏并利用FFmpeg转换成gif(二) 写博客时经常会希望用一段动画来演示app的行为,目前大多数的做法是在电脑上开模拟器,然后用gif录制软件录制模拟器屏幕,对于非开发人员来讲这种方 ...

  7. Android录屏并利用FFmpeg转换成gif(三) 在Android中使用ffmpeg命令

    Android录屏并利用FFmpeg转换成gif(三) 写博客时经常会希望用一段动画来演示app的行为,目前大多数的做法是在电脑上开模拟器,然后用gif录制软件录制模拟器屏幕,对于非开发人员来讲这种方 ...

  8. Android录屏并利用FFmpeg转换成gif(四) 将mp4文件转换成gif文件

    Android录屏并利用FFmpeg转换成gif(四) 写博客时经常会希望用一段动画来演示app的行为,目前大多数的做法是在电脑上开模拟器,然后用gif录制软件录制模拟器屏幕,对于非开发人员来讲这种方 ...

  9. tensorrt安装_利用TensorRT对深度学习进行加速

    前言 TensorRT是什么,TensorRT是英伟达公司出品的高性能的推断C++库,专门应用于边缘设备的推断,TensorRT可以将我们训练好的模型分解再进行融合,融合后的模型具有高度的集合度.例如 ...

最新文章

  1. 关于数据库中NULL的描述,下列哪些说法符合《阿里巴巴Java开发手册》
  2. Pycharm初始创建项目和环境搭建(解决aconda库文件引入不全等问题)
  3. java+c#+json+时间_C#与Json时间的转换
  4. python结构嵌套_python2.3嵌套if结构:
  5. 我的第一个ASP类(显示止一篇下一篇文章)
  6. Dlib学习笔记:dlib array2d与 OpenCV Mat互转
  7. java 泰勒级数_鸡群优化算法(CSO)、蜻蜓算法(DA)、乌鸦搜索算法(CSA)、泰勒级数(Taylor series)...
  8. threejs 判断对象是否在可视区内
  9. 强化学习ppt_强化学习和最优控制的十个关键点81页PPT汇总
  10. python隐式调用_c#隐式调用python_C#调用python脚本样例
  11. 强悍的命令行 —— 路径相关
  12. [转载] python删除dataframe行和列
  13. java安卓游戏源码下载_77个安卓游戏 android源码
  14. DzzOffice_flowplayer播放器更改
  15. Python之面向对象-类与 类之间的关系
  16. 网站被黑客攻击怎么办?
  17. 36种漂亮的CSS3网页按钮Button样式
  18. MT8377 MT8389 MT6589 MT6577解析
  19. H5实现3D圣诞树效果
  20. JVM 垃圾收集器 学习笔记(《深入理解java虚拟机》之六 垃圾收集)

热门文章

  1. handlebars用法
  2. Python实现漏斗图的绘制
  3. php gb28181,GB28181测试工具
  4. 阿里云HaaS100驱动LCD液晶屏的方法(含fb.h错误解决方法)
  5. 【AHK】给通达信软件增加F1买入,F2卖出 交易热键(基于中银国际客户端测试)
  6. xampp中tomcat无法启动
  7. 深入浅出:了解前端回流跟重绘
  8. xmlns 命名空间
  9. PyCharm安装和配置
  10. 前端div的隐藏与展示控制