TensorFlow Serving makes it quick to deploy TensorFlow models and expose them through a gRPC or REST API.

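The official recommendation is to deploy with Docker, and there is also a complete tutorial covering everything from training to deployment: Servers: TFX for TensorFlow Serving. This post is simply an exercise that follows that tutorial. Working through it helps in understanding the whole path from training a TensorFlow model to serving it.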

Preparing the environment

Set up a TensorFlow environment and import the dependencies:

import sys
# Confirm that we're using Python 3
assert sys.version_info.major == 3, 'Oops, not running Python 3. Use Runtime > Change runtime type'

import tensorflow as tf
from tensorflow import keras

# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess

print(f'TensorFlow version: {tf.__version__}')
print(f'TensorFlow GPU support: {tf.test.is_built_with_gpu_support()}')

physical_gpus = tf.config.list_physical_devices('GPU')
print(physical_gpus)
for gpu in physical_gpus:
    # memory growth must be set before GPUs have been initialized
    tf.config.experimental.set_memory_growth(gpu, True)
logical_gpus = tf.config.experimental.list_logical_devices('GPU')
print(len(physical_gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")

TensorFlow version: 2.4.1
TensorFlow GPU support: True
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
1 Physical GPUs, 1 Logical GPUs

Creating the model

Load the Fashion MNIST dataset:

fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
# scale the values to 0.0 to 1.0
train_images = train_images / 255.0
test_images = test_images / 255.0
# reshape for feeding into the model
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
print('\ntrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
print('test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))

train_images.shape: (60000, 28, 28, 1), of float64
test_images.shape: (10000, 28, 28, 1), of float64
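The sketch below makes the preprocessing step concrete by applying the same two transformations to a synthetic batch. The random array stands in for Fashion MNIST, and the helper name `preprocess` is an assumption for illustration only. Dividing by 255.0 maps the uint8 pixels into [0.0, 1.0] and promotes them to float64. The reshape adds the trailing channel axis that `Conv2D` expects.

```python
import numpy as np


def preprocess(images):
    """Hypothetical helper mirroring the steps above: scale to [0, 1], add a channel axis."""
    scaled = images / 255.0  # uint8 / float -> float64, values in [0.0, 1.0]
    return scaled.reshape(scaled.shape[0], 28, 28, 1)  # (N, 28, 28) -> (N, 28, 28, 1)


# Synthetic stand-in for fashion_mnist.load_data() output: 4 grayscale 28x28 uint8 images
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(4, 28, 28), dtype=np.uint8)

batch = preprocess(fake_images)
print(batch.shape, batch.dtype)  # (4, 28, 28, 1) float64
```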

Train the model with a very simple CNN:

model = keras.Sequential([
    keras.layers.Conv2D(input_shape=(28, 28, 1), filters=8, kernel_size=3,
                        strides=2, activation='relu', name='Conv1'),
    keras.layers.Flatten(),
    keras.layers.Dense(10, name='Dense')
])
model.summary()
testing = False
epochs = 5
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=epochs)
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy: {}'.format(test_acc))

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
Conv1 (Conv2D)               (None, 13, 13, 8)         80
_________________________________________________________________
flatten (Flatten)            (None, 1352)              0
_________________________________________________________________
Dense (Dense)                (None, 10)                13530
=================================================================
Total params: 13,610
Trainable params: 13,610
Non-trainable params: 0
_________________________________________________________________
Epoch 1/5
1875/1875 [==============================] - 3s 722us/step - loss: 0.7387 - sparse_categorical_accuracy: 0.7449
Epoch 2/5
1875/1875 [==============================] - 1s 793us/step - loss: 0.4561 - sparse_categorical_accuracy: 0.8408
Epoch 3/5
1875/1875 [==============================] - 1s 720us/step - loss: 0.4097 - sparse_categorical_accuracy: 0.8566
Epoch 4/5
1875/1875 [==============================] - 1s 718us/step - loss: 0.3899 - sparse_categorical_accuracy: 0.8636
Epoch 5/5
1875/1875 [==============================] - 1s 719us/step - loss: 0.3673 - sparse_categorical_accuracy: 0.8701
313/313 [==============================] - 0s 782us/step - loss: 0.3937 - sparse_categorical_accuracy: 0.8630

Test accuracy: 0.8629999756813049
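The shapes and parameter counts in the summary can be checked by hand:

- For a "valid" (unpadded) convolution, which is Keras' default, the output side length is floor((28 - 3) / 2) + 1 = 13.
- Each of the 8 filters has 3 × 3 × 1 weights plus a bias, which gives 80 parameters.
- Flattening yields 13 × 13 × 8 = 1352 features.
- The Dense layer has 1352 × 10 + 10 = 13530 parameters.

The short sketch below reproduces these numbers. The function names are illustrative, not part of Keras.

```python
def conv2d_output_size(size, kernel, stride):
    """Spatial output size of a 'valid' (no padding) convolution."""
    return (size - kernel) // stride + 1


def conv2d_params(kernel, in_channels, filters):
    """One kernel x kernel x in_channels weight tensor per filter, plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_features, units):
    """Weight matrix plus bias vector."""
    return in_features * units + units


side = conv2d_output_size(28, kernel=3, stride=2)         # 13
conv_params = conv2d_params(3, in_channels=1, filters=8)  # 80
flat = side * side * 8                                    # 1352
dense = dense_params(flat, 10)                            # 13530
print(side, conv_params, flat, dense, conv_params + dense)  # 13 80 1352 13530 13610
```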

Saving the model

Save the model in SavedModel format and put a version number in the export path, so that TensorFlow Serving can select a model version when serving.

# Fetch the Keras session and save the model
# The signature definition is defined by the input and output tensors,
# and stored with the default serving key
import tempfile

MODEL_DIR = os.path.join(tempfile.gettempdir(), 'tfx')
version = 1
export_path = os.path.join(MODEL_DIR, str(version))
print('export_path = {}\n'.format(export_path))

tf.keras.models.save_model(model, export_path, overwrite=True, include_optimizer=True,
                           save_format=None, signatures=None, options=None)

print('\nSaved model:')
!ls -l {export_path}

export_path = /tmp/tfx/1

INFO:tensorflow:Assets written to: /tmp/tfx/1/assets

Saved model:
total 88
drwxr-xr-x 2 john john  4096 Apr 13 15:10 assets
-rw-rw-r-- 1 john john 78169 Apr 13 15:12 saved_model.pb
drwxr-xr-x 2 john john  4096 Apr 13 15:12 variables
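TensorFlow Serving treats each numeric subdirectory under the model base path as a model version. By default it serves the highest-numbered one. Exporting a retrained model therefore only requires writing it to a new, larger version number.

The sketch below shows that bookkeeping with the standard library only. It uses two hypothetical helpers, `list_versions` and `next_export_path`, and assumes that non-numeric entries should be ignored. It does not call TensorFlow and is not part of the tutorial.

```python
import os
import tempfile


def list_versions(model_base_path):
    """Return the integer version numbers found as subdirectories of model_base_path.

    Non-numeric entries are skipped (assumption: only numeric names count as versions).
    """
    versions = []
    for name in os.listdir(model_base_path):
        if name.isdigit() and os.path.isdir(os.path.join(model_base_path, name)):
            versions.append(int(name))
    return sorted(versions)


def next_export_path(model_base_path):
    """Hypothetical helper: path for the next version (1 if none exist yet)."""
    versions = list_versions(model_base_path)
    next_version = versions[-1] + 1 if versions else 1
    return os.path.join(model_base_path, str(next_version))


# Usage with a throwaway directory standing in for /tmp/tfx
base = tempfile.mkdtemp()
for v in ("1", "2", "not_a_version"):
    os.makedirs(os.path.join(base, v))

print(list_versions(base))                       # [1, 2]
print(os.path.basename(next_export_path(base)))  # 3
```

In the notebook, `export_path = next_export_path(MODEL_DIR)` could replace the hard-coded `version = 1` on later runs.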

Inspecting the model

Use the saved_model_cli tool to inspect the model's MetaGraphDefs (the models) and SignatureDefs (the methods you can call):

!saved_model_cli show --dir '/tmp/tfx/1' --all
2021-04-13 15:12:29.433576: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['__saved_model_init_op']:
  The given SavedModel SignatureDef contains the following input(s):
  The given SavedModel SignatureDef contains the following output(s):
    outputs['__saved_model_init_op'] tensor_info:
        dtype: DT_INVALID
        shape: unknown_rank
        name: NoOp
  Method name is:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['Conv1_input'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 28, 28, 1)
        name: serving_default_Conv1_input:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['Dense'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 10)
        name: StatefulPartitionedCall:0
  Method name is: tensorflow/serving/predict

Defined Functions:
  Function Name: '__call__'
    Option #1
      Callable with:
        Argument #1
          Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #2
      Callable with:
        Argument #1
          inputs: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='inputs')
        Argument #2
          DType: bool
          Value: False
        Argument #3
          DType: NoneType
          Value: None
    Option #3
      Callable with:
        Argument #1
          inputs: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='inputs')
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
    Option #4
      Callable with:
        Argument #1
          Conv1_input: TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32, name='Conv1_input')
        Argument #2
          DType: bool
          Value: True
        Argument #3
          DType: NoneType
          Value: None
...
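The `serving_default` signature above describes the model's interface:

- It takes one input, `Conv1_input`, of shape (-1, 28, 28, 1), where -1 is the batch dimension.
- It returns `Dense` logits of shape (-1, 10).

Before sending requests, it can help to check on the client side that the data matches this shape. The following minimal sketch uses a hypothetical helper, `matches_signature`. The expected shape is copied by hand from the CLI output rather than read from the SavedModel.

```python
import numpy as np

# Input shape of the 'serving_default' signature, copied from the saved_model_cli output;
# -1 marks the variable batch dimension.
SIGNATURE_INPUT_SHAPE = (-1, 28, 28, 1)


def matches_signature(batch, expected=SIGNATURE_INPUT_SHAPE):
    """Return True if batch's shape fits expected, treating -1 as 'any size'."""
    if batch.ndim != len(expected):
        return False
    return all(e == -1 or e == s for e, s in zip(expected, batch.shape))


good = np.zeros((3, 28, 28, 1), dtype=np.float32)
bad = np.zeros((3, 28, 28), dtype=np.float32)  # missing the channel axis
print(matches_signature(good), matches_signature(bad))  # True False
```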

Deploying the model

Installing Serving

echo "deb [arch=amd64] http://storage.googleapis.com/tensorflow-serving-apt stable tensorflow-model-server tensorflow-model-server-universal" | sudo tee /etc/apt/sources.list.d/tensorflow-serving.list && \
curl https://storage.googleapis.com/tensorflow-serving-apt/tensorflow-serving.release.pub.gpg | sudo apt-key add -
sudo apt update
sudo apt install tensorflow-model-server

Starting Serving

Start TensorFlow Serving and expose a REST API:

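  • rest_api_port: the port that accepts REST requests.
  • model_name: a name of your choosing, which becomes part of the REST request URL.
  • model_base_path: the directory holding the model (its numbered version subdirectories live here).

nohup tensorflow_model_server \
  --rest_api_port=8501 \
  --model_name=fashion_model \
  --model_base_path="/tmp/tfx" >server.log 2>&1 &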
$ tail server.log
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-04-13 15:12:10.706648: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:206] Restoring SavedModel bundle.
2021-04-13 15:12:10.726722: I external/org_tensorflow/tensorflow/core/platform/profile_utils/cpu_utils.cc:112] CPU Frequency: 2599990000 Hz
2021-04-13 15:12:10.756506: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:190] Running initialization op on SavedModel bundle at path: /tmp/tfx/1
2021-04-13 15:12:10.759935: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:277] SavedModel load for tags { serve }; Status: success: OK. Took 110653 microseconds.
2021-04-13 15:12:10.760277: I tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc:59] No warmup data file found at /tmp/tfx/1/assets.extra/tf_serving_warmup_requests
2021-04-13 15:12:10.760486: I tensorflow_serving/core/loader_harness.cc:87] Successfully loaded servable version {name: fashion_model version: 1}
2021-04-13 15:12:10.763938: I tensorflow_serving/model_servers/server.cc:371] Running gRPC ModelServer at 0.0.0.0:8500 ...
[evhttp_server.cc : 238] NET_LOG: Entering the event loop ...
2021-04-13 15:12:10.765308: I tensorflow_serving/model_servers/server.cc:391] Exporting HTTP/REST API at:localhost:8501 ...
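The `--model_name` value becomes part of the REST URL, and a version can optionally be pinned in the path. Both forms appear in the requests later in this post: `/v1/models/fashion_model:predict` and `/v1/models/fashion_model/versions/1:predict`.

Here is a small sketch that builds those URLs. The helper name `predict_url` and its defaults (localhost, port 8501 to match `--rest_api_port`) are assumptions for illustration.

```python
def predict_url(model_name, version=None, host="localhost", port=8501):
    """Build a TensorFlow Serving REST predict URL.

    Without a version, the server chooses one itself (the latest under the
    default policy); with a version, that exact version is requested.
    """
    base = f"http://{host}:{port}/v1/models/{model_name}"
    if version is not None:
        base += f"/versions/{version}"
    return base + ":predict"


print(predict_url("fashion_model"))
# http://localhost:8501/v1/models/fashion_model:predict
print(predict_url("fashion_model", version=1))
# http://localhost:8501/v1/models/fashion_model/versions/1:predict
```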

Accessing the service

Display a random test image:

def show(idx, title):
    plt.figure()
    plt.imshow(test_images[idx].reshape(28, 28))
    plt.axis('off')
    plt.title('\n\n{}'.format(title), fontdict={'size': 16})

import random
rando = random.randint(0, len(test_images) - 1)
show(rando, 'An Example Image: {}'.format(class_names[test_labels[rando]]))

Build a JSON payload containing the three images to predict:

import json
data = json.dumps({"signature_name": "serving_default", "instances": test_images[0:3].tolist()})
print('Data: {} ... {}'.format(data[:50], data[len(data)-52:]))

Data: {"signature_name": "serving_default", "instances": ...  [0.0], [0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]]]}
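The payload above uses the row ("instances") format, where each list element is one example. The TensorFlow Serving REST API also accepts a columnar ("inputs") format. For a model with a single input like this one, the batch tensor is passed as-is, and the results come back under `outputs` rather than `predictions`.

The sketch below builds both request bodies from a NumPy batch. The zero-filled array is a stand-in for `test_images[0:3]`.

```python
import json

import numpy as np

# Stand-in for test_images[0:3]: three 28x28x1 float images
batch = np.zeros((3, 28, 28, 1), dtype=np.float64)

# Row format: one list entry per example (what the post sends)
row_body = json.dumps({"signature_name": "serving_default",
                       "instances": batch.tolist()})

# Columnar format: the model's single input tensor passed as one batch
col_body = json.dumps({"signature_name": "serving_default",
                       "inputs": batch.tolist()})

decoded = json.loads(row_body)["instances"]
print(len(decoded), len(decoded[0]), len(decoded[0][0]), len(decoded[0][0][0]))  # 3 28 28 1
```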

REST requests

Predict with the latest model version:

!pip install -q requests
import requests
headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/fashion_model:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']

show(0, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(
    class_names[np.argmax(predictions[0])], np.argmax(predictions[0]), class_names[test_labels[0]], test_labels[0]))
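The final `Dense` layer has no activation, and the loss was built with `from_logits=True`, so the values under `predictions` are raw logits, not probabilities. `np.argmax` works on logits directly. If a confidence score is wanted, a softmax turns the logits into probabilities. Here is a minimal sketch that uses made-up logits in place of a real server response.

```python
import numpy as np

class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


# Made-up logits for one image, standing in for predictions[0] from the server
logits = np.array([0.1, -1.2, 0.3, 0.0, 0.2, -2.0, 0.5, -1.5, 4.0, -0.7])
probs = softmax(logits)
top = int(np.argmax(probs))
print(class_names[top], round(float(probs[top]), 3))  # prints 'Bag' and its probability
```

Softmax is monotonic, so the predicted class is the same whether argmax is taken over logits or probabilities.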

Predict with a specific model version:

headers = {"content-type": "application/json"}
json_response = requests.post('http://localhost:8501/v1/models/fashion_model/versions/1:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']
for i in range(0, 3):
    show(i, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(
        class_names[np.argmax(predictions[i])], np.argmax(predictions[i]), class_names[test_labels[i]], test_labels[i]))
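The snippets above index `['predictions']` directly. If the server rejects the request, for example because of a wrong model name or input shape, that raises a bare `KeyError`. TensorFlow Serving's REST API reports failures as a JSON object with an `error` field, so a slightly more defensive parser can surface the server's message instead.

Here is a sketch with a hypothetical `parse_predictions` helper. It is exercised on hand-written response bodies rather than a live server.

```python
import json


def parse_predictions(response_text):
    """Return the 'predictions' list from a REST predict response.

    Raises RuntimeError carrying the server's message if the body
    contains an 'error' field instead.
    """
    body = json.loads(response_text)
    if "error" in body:
        raise RuntimeError(f"TensorFlow Serving error: {body['error']}")
    return body["predictions"]


ok = '{"predictions": [[0.1, 2.3], [1.5, -0.4]]}'
print(parse_predictions(ok))  # [[0.1, 2.3], [1.5, -0.4]]

failed = '{"error": "Servable not found for request"}'  # hand-written example body
try:
    parse_predictions(failed)
except RuntimeError as exc:
    print(exc)  # TensorFlow Serving error: Servable not found for request
```

With the live server, `parse_predictions(json_response.text)` would take the place of `json.loads(json_response.text)['predictions']`.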

GoCoding: experience shared from personal practice. Follow the official account for more.
