tensorflow2 serving
tensorflow 模型训练部署为tfserving 服务
有以下三部
1 模型训练保存为savemodel
2 保存的模型在docker 部署服务。
3 在调用http 接口,进行模型推理。
1 模型训练保存为 models.save_model
import sys
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess#1 创建模型
(train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data()# scale the values to 0.0 to 1.0
train_images = train_images / 255.0
test_images = test_images / 255.0# reshape for feeding into the model
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']print('\ntrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
print('test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))
# 2 训练和评估模型
model = keras.Sequential([keras.layers.Conv2D(input_shape=(28,28,1), filters=8, kernel_size=3, strides=2, activation='relu', name='Conv1'),keras.layers.Flatten(),keras.layers.Dense(10, name='Dense')
])
model.summary()testing = False
epochs = 5model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=epochs)test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy: {}'.format(test_acc))# 3 保存模型 为 SaveModel 格式
MODEL_DIR = './model'
version = 1
export_path = os.path.join(MODEL_DIR, str(version))
print('export_path = {}\n'.format(export_path))
tf.keras.models.save_model(model,export_path,overwrite=True,include_optimizer=True,save_format=None,signatures=None,options=None
)print('\nSaved model:')
4检查模型
使用saved_model_cli命令查看savemodel保存后的的模型和方法
python D:\pythonapp\anacondas\envs\torchenv\Lib\site-packages\tensorflow\python\tools\saved_model_cli.py show --dir model\1 --all
MetaGraphDef with tag-set: ‘serve’ contains the following SignatureDefs:
signature_def['__saved_model_init_op']:The given SavedModel SignatureDef contains the following input(s):The given SavedModel SignatureDef contains the following output(s):outputs['__saved_model_init_op'] tensor_info:dtype: DT_INVALIDshape: unknown_rankname: NoOpMethod name is:signature_def['serving_default']:The given SavedModel SignatureDef contains the following input(s):inputs['Conv1_input'] tensor_info:dtype: DT_FLOATshape: (-1, 28, 28, 1)name: serving_default_Conv1_input:0The given SavedModel SignatureDef contains the following output(s):outputs['Dense'] tensor_info:dtype: DT_FLOATshape: (-1, 10)name: StatefulPartitionedCall:0Method name is: tensorflow/serving/predict
2 docker 运行tfserving 服务
1 docker 拉取tfserving 镜像。
docker pull tensorflow/serving
2 将上面保存的模型放到某个目录下
我是在windows下训练的模型,将保存在model路径下的模型 放在了/opt/tfserving下。
3 . 构建模型和tserving 的链接,启动服务。
docker run -p 8501:8501 --mount type=bind,source=/opt/tfserving/model/,target=/models/model -e MODEL_NAME=model -t tensorflow/serving
4 模型提供的服务请求默认为 http://localhost:8501/v1/models/model:predict
在windows 下请求该接口。结果如图
import requests
iheaders = {"content-type": "application/json"}
json_response = requests.post('http://192.168.10.100:8501/v1/models/model:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']for i in range(0,3):show(i, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(class_names[np.argmax(predictions[i])], np.argmax(predictions[i]), class_names[test_labels[i]], test_labels[i]))
3 完整代码
1 模型训练保存为savemodel (可以使用saved_model_cli命令查看savemodel保存后的的模型和方法)后面预测时会使用到
2 保存的模型在docker 部署服务。(由于我的是window训练,linux服务,不能实现热更新,模型默认加载最新模型,也可以指定加载)。
3 在调用http 接口,进行模型推理。
完整代码
"""
part one 模型训练保存
"""
import sys
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
import os
import subprocess(train_images, train_labels), (test_images, test_labels) = keras.datasets.fashion_mnist.load_data()# scale the values to 0.0 to 1.0
train_images = train_images / 255.0
test_images = test_images / 255.0# reshape for feeding into the model
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1)
test_images = test_images.reshape(test_images.shape[0], 28, 28, 1)class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']print('\ntrain_images.shape: {}, of {}'.format(train_images.shape, train_images.dtype))
print('test_images.shape: {}, of {}'.format(test_images.shape, test_images.dtype))
model = keras.Sequential([keras.layers.Conv2D(input_shape=(28,28,1), filters=8, kernel_size=3, strides=2, activation='relu', name='Conv1'),keras.layers.Flatten(),keras.layers.Dense(10, name='Dense')
])
model.summary()testing = False
epochs = 5model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=[keras.metrics.SparseCategoricalAccuracy()])
model.fit(train_images, train_labels, epochs=epochs)test_loss, test_acc = model.evaluate(test_images, test_labels)
print('\nTest accuracy: {}'.format(test_acc))MODEL_DIR = './model'
version = 1
export_path = os.path.join(MODEL_DIR, str(version))
print('export_path = {}\n'.format(export_path))
tf.keras.models.save_model(model,export_path,overwrite=True,include_optimizer=True,save_format=None,signatures=None,options=None
)print('\nSaved model:')"""
part two
模型服务部署此部分在Linux完成
docker run -p 8501:8501 --mount type=bind,source=/opt/tfserving/model/,target=/models/model -e MODEL_NAME=model -t tensorflow/serving
""""""
part three 模型推理测试
"""
def show(idx, title):plt.figure()plt.imshow(test_images[idx].reshape(28,28))plt.axis('off')plt.title('\n\n{}'.format(title), fontdict={'size': 16})import json
data = json.dumps({"signature_name": "serving_default", "instances": test_images[0:1].tolist()})
print('Data: {} ... {}'.format(data[:50], data[len(data)-52:]))import requests
headers = {"content-type": "application/json"}
json_response = requests.post(r'http://192.168.10.100:8501/v1/models/model:predict', data=data, headers=headers)
predictions = json.loads(json_response.text)['predictions']
show(0, 'The model thought this was a {} (class {}), and it was actually a {} (class {})'.format(class_names[np.argmax(predictions[0])], np.argmax(predictions[0]), class_names[test_labels[0]], test_labels[0]))
tensorflow2 serving相关推荐
- 【NLP】NLP实战篇之tensorflow2.0快速入门
修改上版代码格式问题.Tensorflow2.0跟Keras紧密结合,相比于1.0版本,2.0可以更快上手,并且能更方便找到需要的api.本文中以IMDB文本分类为例,简单介绍了从数据下载.预处理.建 ...
- [深度学习] tensorflow1.x和tensorflow2.x对比与总结
tensorflow1.x和tensorflow2.x对比与总结 1. 主要区别有如下几点 1.0. 易于使用(Ease of use) 1.1. 使用Eager模式(Eager Execution) ...
- TensorFlow2.0(一)--简介与环境搭建
简介与环境搭建 1. TensorFlow是什么 2. TensorFlow1.0与2.0架构 3. TensorFlow环境配置 1. TensorFlow是什么 TensorFlow是Google ...
- Tensorflow2.0
Tensorflow2.0 Tensorflow 简介 Tensorflow是什么 Google开源软件库 采用数据流图,用于数值计算 支持多平台 GPU CPU 移动设备 最初用于深度学习,变得通用 ...
- tensorflow2.0(简介)
tensorflow简介 一.tensorflow是什么 采用数据流图,用于数值计算 支持多种平台–GPU,CPU,移动设备 最初用于深度学习,越来越通用–只要能描述为数据流图就可以用tensorfl ...
- TensorFlow2 -官方教程 :保存和恢复模型
文章目录 准备工作:安装,导入,获取数据集,定义model 在训练期间保存模型(以 checkpoints 形式保存) Checkpoint 回调用法 checkpoint 回调选项 这些文件是什么? ...
- tf.saved_model.save模型导出、TensorFlow Serving模型部署、TensorBoard中的HParams 超参数调优
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 4.11 综合案例:模型导出与部署 学习目标 目标 掌握Ten ...
- Tensorflow2.0——新世界的大门
由于框架的持续改善与更新,目前是一个比较好的机会来通过Tensorflow入门深度学习. 那些年追过的编程语言和框架 犹记得十年前刚刚接触编程的时候,我还是一个懵懂的少年,在计算机的世界里跌跌撞撞.为 ...
- tensorflow2.0基础简介
tensorflow2.0简介 1.tensorflow 2.0基础知识简介 tensorflow2.0是谷歌在2019年3月份发布更新的一款到端开源机器学习平台,其目的在于优化tensorflow1 ...
- Anaconda3+python3.7.10+TensorFlow2.3.0+PyQt5环境搭建
Anaconda3+python3.7.10+TensorFlow2.3.0+PyQt5环境搭建 一.Anaconda 创建 python3.7环境 1.进入 C:\Users\用户名 目录下,找到 ...
最新文章
- centos5 db_load 命令无法使用
- microsoft mysql下载_Microsoft SQL Server 2018
- 预计2021年电视出货量有所上涨,网络推广外包之下OLED大肆布局
- sql server tcp 信号灯超时时间已到_「图文详解」TCP为啥要3次握手和4次挥手?3次挥手不行吗?...
- mongo数据库CRUD
- 服务器克隆机网络端口排错
- 方法覆盖(重写)和方法重载
- 网络协议:超时与重传机制
- 超全开放 API 免费调用,这款 API 管理工具太香了!
- 设计一个矩形类rectangle_使用Python super()为您的类增强
- python程序画漂亮图_用python画图代码:正弦图像、多轴图等案例
- spss数据分析_排序数据_计算变量
- 定义类,super的使用,super的使用
- mysql的group by语句不会产生_MySQL:为什么查询列表中多了它,GROUP BY语句就会报错呢?...
- Windows注册表内容详解(转载)
- excel设置斑马线
- Android 百度地图marker中图片不显示的解决方案
- SEO外链收录:锚文本外链代发排名
- 对Scanner.hasNext的总结
- python 使用pymssql的基本总结
热门文章
- NNDL 作业7:第五章课后题(1×1 卷积核 | CNN BP)
- 【Unity好用插件】PSD文件转UI插件——Psd 2 Unity uGUI Pro ★★★完整过程
- android 视频上传网络异常,App上传视频(或大文件)失败怎么办?
- 输入法公司Kika完成2.2亿B+轮融资 猎豹移动领投
- 软件架构-解密电商系统商品模块业务
- python floor是什么意思_python里floor怎么用
- Retrofit的封装
- GIS方法类期刊和论文的综述(Introduction)怎么写?
- android透明背景边框线
- Testbench的激励添加和书写技巧