本文是基于TensorRT 5.0.2基础上,关于其内部的end_to_end_tensorflow_mnist例子的分析和介绍。

1 引言

假设当前路径为:

TensorRT-5.0.2.6/samples

其对应当前例子文件目录树为:

# tree pythonpython
├── common.py
├── end_to_end_tensorflow_mnist
│   ├── model.py
│   ├── README.md
│   ├── requirements.txt
│   └── sample.py

2 基于tensorflow生成模型

其中只有2个文件:

  • model:该文件包含简单的训练模型代码
  • sample:该文件使用UFF mnist模型去创建一个TensorRT inference engine

首先介绍下model.py

# 该脚本包含一个简单的模型训练过程
import tensorflow as tf
import numpy as np'''main中第一步:获取数据集 '''
def process_dataset():# 导入mnist数据集# 手动下载aria2c -x 16 https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz# 将mnist.npz移动到~/.keras/datasets/#  tf.keras.datasets.mnist.load_data会去读取~/.keras/datasets/mnist.npz,而不从网络下载(x_train, y_train),(x_test, y_test) = tf.keras.datasets.mnist.load_data()x_train, x_test = x_train / 255.0, x_test / 255.0# Reshape NUM_TRAIN = 60000NUM_TEST = 10000x_train = np.reshape(x_train, (NUM_TRAIN, 28, 28, 1))x_test = np.reshape(x_test, (NUM_TEST, 28, 28, 1))return x_train, y_train, x_test, y_test'''main中第二步:构建模型 '''
def create_model():model = tf.keras.models.Sequential()model.add(tf.keras.layers.InputLayer(input_shape=[28,28, 1]))model.add(tf.keras.layers.Flatten())model.add(tf.keras.layers.Dense(512, activation=tf.nn.relu))model.add(tf.keras.layers.Dense(10, activation=tf.nn.softmax))model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])return model'''main中第五步:模型存储 '''
def save(model, filename):output_names = model.output.op.namesess = tf.keras.backend.get_session()# freeze graphfrozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_names])# 移除训练的节点frozen_graph = tf.graph_util.remove_training_nodes(frozen_graph)# 保存模型with open(filename, "wb") as ofile:ofile.write(frozen_graph.SerializeToString())def main():''' 1 - 获取数据'''x_train, y_train, x_test, y_test = process_dataset()''' 2 - 构建模型'''model = create_model()''' 3 - 模型训练'''model.fit(x_train, y_train, epochs = 5, verbose = 1)''' 4 - 模型评估'''model.evaluate(x_test, y_test)''' 5 - 模型存储'''save(model, filename="models/lenet5.pb")if __name__ == '__main__':main()

在获得

models/lenet5.pb

之后,执行下述命令,将其转换成uff文件,输出结果如

'''该converter会显示关于input/output nodes的信息,这样你就可以用来在解析的时候进行注册;
本例子中,我们基于tensorflow.keras的命名规则,事先已知input/output nodes名称了 '''[root@30d4bceec4c4 end_to_end_tensorflow_mnist]# convert-to-uff models/lenet5.pb
Loading models/lenet5.pb

3 基于tensorflow的pb文件生成UFF并处理

# 该例子使用UFF MNIST 模型去创建一个TensorRT Inference Engine
from random import randint
from PIL import Image
import numpy as npimport pycuda.driver as cuda
import pycuda.autoinit # 该import会让pycuda自动管理CUDA上下文的创建和清理工作import tensorrt as trtimport sys, os
# sys.path.insert(1, os.path.join(sys.path[0], ".."))
# import common# 这里将common中的GiB和find_sample_data,allocate_buffers,do_inference等函数移动到该py文件中,保证自包含。
def GiB(val):'''以GB为单位,计算所需要的存储值,向左位移10bit表示KB,20bit表示MB '''return val * 1 << 30def find_sample_data(description="Runs a TensorRT Python sample", subfolder="", find_files=[]):'''该函数就是一个参数解析函数。Parses sample arguments.Args:description (str): Description of the sample.subfolder (str): The subfolder containing data relevant to this samplefind_files (str): A list of filenames to find. Each filename will be replaced with an absolute path.Returns:str: Path of data directory.Raises:FileNotFoundError'''# 为了简洁,这里直接将路径硬编码到代码中。data_root = kDEFAULT_DATA_ROOT = os.path.abspath("/TensorRT-5.0.2.6/python/data/")subfolder_path = os.path.join(data_root, subfolder)if not os.path.exists(subfolder_path):print("WARNING: " + subfolder_path + " does not exist. Using " + data_root + " instead.")data_path = subfolder_path if os.path.exists(subfolder_path) else data_rootif not (os.path.exists(data_path)):raise FileNotFoundError(data_path + " does not exist.")for index, f in enumerate(find_files):find_files[index] = os.path.abspath(os.path.join(data_path, f))if not os.path.exists(find_files[index]):raise FileNotFoundError(find_files[index] + " does not exist. ")if find_files:return data_path, find_fileselse:return data_path
#-----------------TRT_LOGGER = trt.Logger(trt.Logger.WARNING)class ModelData(object):MODEL_FILE = os.path.join(os.path.dirname(__file__), "models/lenet5.uff")INPUT_NAME ="input_1"INPUT_SHAPE = (1, 28, 28)OUTPUT_NAME = "dense_1/Softmax"'''main中第二步:构建engine'''
def build_engine(model_file):with trt.Builder(TRT_LOGGER) as builder, \builder.create_network() as network, \trt.UffParser() as parser:builder.max_workspace_size = GiB(1)# 解析 Uff 网络parser.register_input(ModelData.INPUT_NAME, ModelData.INPUT_SHAPE)parser.register_output(ModelData.OUTPUT_NAME)parser.parse(model_file, network)# 构建并返回一个enginereturn builder.build_cuda_engine(network)'''main中第三步 '''
def allocate_buffers(engine):inputs = []outputs = []bindings = []stream = cuda.Stream()for binding in engine:size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_sizedtype = trt.nptype(engine.get_binding_dtype(binding))# 分配host和device端的bufferhost_mem = cuda.pagelocked_empty(size, dtype)device_mem = cuda.mem_alloc(host_mem.nbytes)# 将device端的buffer追加到device的bindings.bindings.append(int(device_mem))# Append to the appropriate list.if engine.binding_is_input(binding):inputs.append(HostDeviceMem(host_mem, device_mem))else:outputs.append(HostDeviceMem(host_mem, device_mem))return inputs, outputs, bindings, stream'''main中第四步 '''
# 从pagelocked_buffer.中读取测试样本
def load_normalized_test_case(data_path, pagelocked_buffer, case_num=randint(0, 9)):test_case_path = os.path.join(data_path, str(case_num) + ".pgm")# Flatten该图像成为一个1维数组,然后归一化,并copy到host端的 pagelocked内存中.img = np.array(Image.open(test_case_path)).ravel()np.copyto(pagelocked_buffer, 1.0 - img / 255.0)return case_num'''main中第五步:执行inference '''
# 该函数可以适应多个输入/输出;输入和输出格式为HostDeviceMem对象组成的列表
def do_inference(context, bindings, inputs, outputs, stream, batch_size=1):# 将数据移动到GPU[cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]# 执行inference.context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)# 将结果从 GPU写回到host端[cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]# 同步streamstream.synchronize()# 返回host端的输出结果return [out.host for out in outputs]def main():''' 1 - 寻找模型文件'''data_path = find_sample_data(description="Runs an MNIST network using a UFF model file", subfolder="mnist")model_file = ModelData.MODEL_FILE''' 2 - 基于build_engine函数构建engine'''with build_engine(model_file) as engine:''' 3 - 分配buffer并创建一个流'''inputs, outputs, bindings, stream = allocate_buffers(engine)with engine.create_execution_context() as context:''' 4 - 读取测试样本,并归一化'''case_num = load_normalized_test_case(data_path, pagelocked_buffer=inputs[0].host)''' 5 - 执行inference,do_inference函数会返回一个list类型,此处只有一个元素'''[output] = do_inference(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)pred = np.argmax(output)print("Test Case: " + str(case_num))print("Prediction: " + str(pred))if __name__ == '__main__':main()

结果如:

转载于:https://www.cnblogs.com/shouhuxianjian/p/10525000.html

TensorRTSamplePython[end_to_end_tensorflow_mnist]相关推荐

  1. 【模型加速】TensorRT安装、测试及常见问题

    ■ 安装过程 一.安装依赖环境 ● Ubuntu 20.04 ● CUDA 11.1 ● cuDNN 8.0.4 ● python 3.8.5 – 可以通过命令查看cuda.cudnn.python版 ...

  2. 加速深度学习在线部署,TensorRT安装及使用教程

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 1 什么是TensorRT 一般的深度学习项目,训练时为了加快速度,会使用多GPU分布式训练. ...

  3. 深度学习之格式转换笔记(三):keras(.hdf5)模型转TensorFlow(.pb) 转TensorRT(.uff)格式

    环境: tensorflow1.15,cuda10.0,cudnn7.6.4 将keras训练好保存的.hdf5格式模型转为tensorflow的.pb模型,然后转为tensorrt支持的uff格式. ...

最新文章

  1. mysql hy000 1005,mysql – ERROR 1005(HY000):无法创建表(errno:150)
  2. Xamarin Essentials教程安全存储SecureStorage
  3. linux中的jobs命令,Linux 中 jobs 命令详解
  4. Eclipse直接运行算法第4版例子(重定向和读取指定路径文件)
  5. Python突击(一)
  6. iPhone开发视频教程 Objective-C部分 (51课时)
  7. word手写字体以假乱真_常用的100个Word快捷键
  8. 2016 网易校招内推C/C++第二场8.6
  9. Go语言 for循环案例
  10. 无线路由器文件服务器,D-LINK路由器局域网文件共享详解
  11. echarts 设置地图默认缩放比例 尺寸
  12. Java开发自学教程!java应届生面试自我介绍
  13. 010项目沟通管理和干系人管理
  14. changelog 生成  npm install -g conventional-changelog-cli
  15. 给iOS App减肥
  16. Esri携“新一代Web GIS”亮相中国地理信息产业大会
  17. iOS开发面试攻略(KVO、KVC、多线程、锁、runloop、计时器)
  18. surface pro3深度linux,Microsoft Surface Pro 3 (简体中文)
  19. 基于Java的学生学费支付系统
  20. 光速读懂ElasticSearch

热门文章

  1. 一些Iphone sqlite 的包装类
  2. 利用自定义事件实现不同窗体间的通讯 -- C#篇
  3. mysql writing to net_mysql 提示 Writing to net_MySQL
  4. mysql求和 子查询_MySQL:子查询中的值总和
  5. Windows安装Linux, (WSL)Windows Subsystem for Linux
  6. 协同过滤算法_一文带你了解协同过滤的前世今生
  7. 【渝粤教育】国家开放大学2018年秋季 0689-22T老年心理健康 参考试题
  8. 【渝粤教育】21秋期末考试物权法10774k1
  9. 美发布《2025年的数学科学》报告
  10. (转)Linux中的screen命令使用