tensorflow 模型训练部署为tfserving 服务

有以下三部
1 模型训练保存为savemodel
2 保存的模型在docker 部署服务。
3 在调用http 接口,进行模型推理。

1 模型训练保存为 models.save_model

import sys
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess#1 创建模型
(train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data()# scale the values to 0.0 to 1.0
train_images = train_images / 255.0
test_images = test_images / 255.0# reshape for feeding into the model
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']print('\ntrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
print('test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))
# 2 训练和评估模型
model = keras.Sequential([keras.layers.Conv2D(input_shape=(28,28,1), filters=8, kernel_size=3, strides=2, activation='relu', name='Conv1'),keras.layers.Flatten(),keras.layers.Dense(10, name='Dense')
])
model.summary()testing = False
epochs = 5model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=epochs)test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy: {}'.format(test_acc))# 3 保存模型 为 SaveModel 格式
MODEL_DIR = './model'
version = 1
export_path = os.path.join(MODEL_DIR, str(version))
print('export_path = {}\n'.format(export_path))
tf.keras.models.save_model(model,export_path,overwrite=True,include_optimizer=True,save_format=None,signatures=None,options=None
)print('\nSaved model:')

4检查模型
使用saved_model_cli命令查看savemodel保存后的的模型和方法

python D:\pythonapp\anacondas\envs\torchenv\Lib\site-packages\tensorflow\python\tools\saved_model_cli.py show --dir model\1 --all
MetaGraphDef with tag-set: ‘serve’ contains the following SignatureDefs:

signature_def['__saved_model_init_op']:The given SavedModel SignatureDef contains the following input(s):The given SavedModel SignatureDef contains the following output(s):outputs['__saved_model_init_op'] tensor_info:dtype: DT_INVALIDshape: unknown_rankname: NoOpMethod name is:signature_def['serving_default']:The given SavedModel SignatureDef contains the following input(s):inputs['Conv1_input'] tensor_info:dtype: DT_FLOATshape: (-1, 28, 28, 1)name: serving_default_Conv1_input:0The given SavedModel SignatureDef contains the following output(s):outputs['Dense'] tensor_info:dtype: DT_FLOATshape: (-1, 10)name: StatefulPartitionedCall:0Method name is: tensorflow/serving/predict

2 docker 运行tfserving 服务

1 docker 拉取tfserving 镜像。
docker pull tensorflow/serving

2 将上面保存的模型放到某个目录下
我是在windows下训练的模型,将保存在model路径下的模型 放在了/opt/tfserving下。

3 . 构建模型和tserving 的链接,启动服务。

docker run -p 8501:8501   --mount type=bind,source=/opt/tfserving/model/,target=/models/model   -e MODEL_NAME=model -t tensorflow/serving

4 模型提供的服务请求默认为 http://localhost:8501/v1/models/model:predict
在windows 下请求该接口。结果如图


import requests
iheaders = {"content-type": "application/json"}
json_response = requests.post('http://192.168.10.100:8501/v1/models/model:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']for i in range(0,3):show(i, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(class_names[np.argmax(predictions[i])], np.argmax(predictions[i]), class_names[test_labels[i]], test_labels[i]))

3 完整代码

1 模型训练保存为savemodel (可以使用saved_model_cli命令查看savemodel保存后的的模型和方法)后面预测时会使用到
2 保存的模型在docker 部署服务。(由于我的是window训练,linux服务,不能实现热更新,模型默认加载最新模型,也可以指定加载)。
3 在调用http 接口,进行模型推理。

完整代码

"""
part one 模型训练保存
"""
import sys
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess(train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data()# scale the values to 0.0 to 1.0
train_images = train_images / 255.0
test_images = test_images / 255.0# reshape for feeding into the model
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']print('\ntrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
print('test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))
model = keras.Sequential([keras.layers.Conv2D(input_shape=(28,28,1), filters=8, kernel_size=3, strides=2, activation='relu', name='Conv1'),keras.layers.Flatten(),keras.layers.Dense(10, name='Dense')
])
model.summary()testing = False
epochs = 5model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=epochs)test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy: {}'.format(test_acc))MODEL_DIR = './model'
version = 1
export_path = os.path.join(MODEL_DIR, str(version))
print('export_path = {}\n'.format(export_path))
tf.keras.models.save_model(model,export_path,overwrite=True,include_optimizer=True,save_format=None,signatures=None,options=None
)print('\nSaved model:')"""
part two
模型服务部署此部分在Linux完成
docker run -p 8501:8501   --mount type=bind,source=/opt/tfserving/model/,target=/models/model   -e MODEL_NAME=model -t tensorflow/serving
""""""
part three 模型推理测试
"""
def show(idx, title):plt.figure()plt.imshow(test_images[idx].reshape(28,28))plt.axis('off')plt.title('\n\n{}'.format(title), fontdict={'size': 16})import json
data = json.dumps({"signature_name": "serving_default", "instances": test_images[0:1].tolist()})
print('Data: {} ... {}'.format(data[:50], data[len(data)-52:]))import requests
headers = {"content-type": "application/json"}
json_response = requests.post(r'http://192.168.10.100:8501/v1/models/model:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']
show(0, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(class_names[np.argmax(predictions[0])], np.argmax(predictions[0]), class_names[test_labels[0]], test_labels[0]))

tensorflow2 serving相关推荐

  1. 【NLP】NLP实战篇之tensorflow2.0快速入门

    修改上版代码格式问题.Tensorflow2.0跟Keras紧密结合,相比于1.0版本,2.0可以更快上手,并且能更方便找到需要的api.本文中以IMDB文本分类为例,简单介绍了从数据下载.预处理.建 ...

  2. [深度学习] tensorflow1.x和tensorflow2.x对比与总结

    tensorflow1.x和tensorflow2.x对比与总结 1. 主要区别有如下几点 1.0. 易于使用(Ease of use) 1.1. 使用Eager模式(Eager Execution) ...

  3. TensorFlow2.0(一)--简介与环境搭建

    简介与环境搭建 1. TensorFlow是什么 2. TensorFlow1.0与2.0架构 3. TensorFlow环境配置 1. TensorFlow是什么 TensorFlow是Google ...

  4. Tensorflow2.0

    Tensorflow2.0 Tensorflow 简介 Tensorflow是什么 Google开源软件库 采用数据流图,用于数值计算 支持多平台 GPU CPU 移动设备 最初用于深度学习,变得通用 ...

  5. tensorflow2.0(简介)

    tensorflow简介 一.tensorflow是什么 采用数据流图,用于数值计算 支持多种平台–GPU,CPU,移动设备 最初用于深度学习,越来越通用–只要能描述为数据流图就可以用tensorfl ...

  6. TensorFlow2 -官方教程 :保存和恢复模型

    文章目录 准备工作:安装,导入,获取数据集,定义model 在训练期间保存模型(以 checkpoints 形式保存) Checkpoint 回调用法 checkpoint 回调选项 这些文件是什么? ...

  7. tf.saved_model.save模型导出、TensorFlow Serving模型部署、TensorBoard中的HParams 超参数调优

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 4.11 综合案例:模型导出与部署 学习目标 目标 掌握Ten ...

  8. Tensorflow2.0——新世界的大门

    由于框架的持续改善与更新,目前是一个比较好的机会来通过Tensorflow入门深度学习. 那些年追过的编程语言和框架 犹记得十年前刚刚接触编程的时候,我还是一个懵懂的少年,在计算机的世界里跌跌撞撞.为 ...

  9. tensorflow2.0基础简介

    tensorflow2.0简介 1.tensorflow 2.0基础知识简介 tensorflow2.0是谷歌在2019年3月份发布更新的一款到端开源机器学习平台,其目的在于优化tensorflow1 ...

  10. Anaconda3+python3.7.10+TensorFlow2.3.0+PyQt5环境搭建

    Anaconda3+python3.7.10+TensorFlow2.3.0+PyQt5环境搭建 一.Anaconda 创建 python3.7环境 1.进入 C:\Users\用户名 目录下,找到 ...

最新文章

  1. centos5 db_load 命令无法使用
  2. microsoft mysql下载_Microsoft SQL Server 2018
  3. 预计2021年电视出货量有所上涨,网络推广外包之下OLED大肆布局
  4. sql server tcp 信号灯超时时间已到_「图文详解」TCP为啥要3次握手和4次挥手?3次挥手不行吗?...
  5. mongo数据库CRUD
  6. 服务器克隆机网络端口排错
  7. 方法覆盖(重写)和方法重载
  8. 网络协议:超时与重传机制
  9. 超全开放 API 免费调用,这款 API 管理工具太香了!
  10. 设计一个矩形类rectangle_使用Python super()为您的类增强
  11. python程序画漂亮图_用python画图代码:正弦图像、多轴图等案例
  12. spss数据分析_排序数据_计算变量
  13. 定义类,super的使用,super的使用
  14. mysql的group by语句不会产生_MySQL:为什么查询列表中多了它,GROUP BY语句就会报错呢?...
  15. Windows注册表内容详解(转载)
  16. excel设置斑马线
  17. Android 百度地图marker中图片不显示的解决方案
  18. SEO外链收录:锚文本外链代发排名
  19. 对Scanner.hasNext的总结
  20. python 使用pymssql的基本总结

热门文章

  1. NNDL 作业7:第五章课后题(1×1 卷积核 | CNN BP)
  2. 【Unity好用插件】PSD文件转UI插件——Psd 2 Unity uGUI Pro ★★★完整过程
  3. android 视频上传网络异常,App上传视频(或大文件)失败怎么办?
  4. 输入法公司Kika完成2.2亿B+轮融资 猎豹移动领投
  5. 软件架构-解密电商系统商品模块业务
  6. python floor是什么意思_python里floor怎么用
  7. Retrofit的封装
  8. GIS方法类期刊和论文的综述(Introduction)怎么写?
  9. android透明背景边框线
  10. Testbench的激励添加和书写技巧