通过带Flask的REST API在Python中部署PyTorch
在本文中,将使用Flask来部署PyTorch模型,并用讲解用于模型推断的 REST API。特别是,将部署一个预训练的DenseNet 121模 型来检测图像。
备注: 可在GitHub上获取本文用到的完整代码
这是在生产中部署PyTorch模型的系列教程中的第一篇。到目前为止,以这种方式使用Flask是开始为PyTorch模型提供服务的最简单方法, 但不适用于具有高性能要求的用例。因此: * 如果已经熟悉TorchScript,则可以直接进入的Loading a TorchScript Model in C++教程。 * 如果首先需要复习TorchScript,请查看的Intro a TorchScript教程。

1.定义API 将首先定义API端点、请求和响应类型。的API端点将位于/ predict,它接受带有包含图像的file参数的HTTP POST请求。响应 将是包含预测的JSON响应: ```buildoutcfg {“class_id”: “n02124075”, “class_name”: “Egyptian_cat”}

2.依赖(包)

运行下面的命令来下载需要的依赖:

$ pip install Flask==1.0.3 torchvision-0.3.0
3.简单的Web服务器
以下是一个简单的Web服务器,摘自Flask文档
from flask import Flask
app = Flask(__name__)@app.route('/')
def hello():return 'Hello World!'
将以上代码段保存在名为app.py的文件中,现在可以通过输入以下内容来运行Flask开发服务器:
$ FLASK_ENV=development FLASK_APP=app.py flask run
当在web浏览器中访问http://localhost:5000/时,会收到文本Hello World的问候!
将对以上代码片段进行一些更改,以使其适合的API定义。首先,将重命名predict方法。将端点路径更新为/predict。 由于图像文件将通过HTTP POST请求发送,因此将对其进行更新,使其也仅接受POST请求:
@app.route('/predict', methods=['POST'])
def predict():return 'Hello World!'
还将更改响应类型,以使其返回包含ImageNet类的id和name的JSON响应。更新后的app.py文件现在为:
from flask import Flask, jsonify
app = Flask(__name__)@app.route('/predict', methods=['POST'])
def predict():return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})
4.推理
在下一部分中,将重点介绍编写推理代码。这将涉及两部分,第一部分是准备图像,以便可以将其馈送到DenseNet;第二部分,将编 写代码以从模型中获取实际的预测。
4.1 准备图像
DenseNet模型要求图像为尺寸为224 x 224的 3 通道RGB图像。还将使用所需的均值和标准偏差值对图像张量进行归一化。可以点击 这里来了解更多关于它的内容。
将使用来自torchvision库的transforms来建立转换管道,该转换管道可根据需要转换图像。可以在此处 阅读有关转换的更多信息。
import ioimport torchvision.transforms as transforms
from PIL import Imagedef transform_image(image_bytes):my_transforms = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])image = Image.open(io.BytesIO(image_bytes))return my_transforms(image).unsqueeze(0)
上面的方法以字节为单位获取图像数据,应用一系列变换并返回张量。要测试上述方法,请以字节模式读取图像文件(首先将../_static/img/ sample_file.jpeg替换为计算机上文件的实际路径),然后查看是否获得了张量:
with open("../_static/img/sample_file.jpeg", 'rb') as f:image_bytes = f.read()tensor = transform_image(image_bytes=image_bytes)print(tensor)
•   输出结果:
tensor([[[[ 0.4508,  0.4166,  0.3994,  ..., -1.3473, -1.3302, -1.3473],[ 0.5364,  0.4851,  0.4508,  ..., -1.2959, -1.3130, -1.3302],[ 0.7077,  0.6392,  0.6049,  ..., -1.2959, -1.3302, -1.3644],...,[ 1.3755,  1.3927,  1.4098,  ...,  1.1700,  1.3584,  1.6667],[ 1.8893,  1.7694,  1.4440,  ...,  1.2899,  1.4783,  1.5468],[ 1.6324,  1.8379,  1.8379,  ...,  1.4783,  1.7352,  1.4612]],[[ 0.5728,  0.5378,  0.5203,  ..., -1.3704, -1.3529, -1.3529],[ 0.6604,  0.6078,  0.5728,  ..., -1.3004, -1.3179, -1.3354],[ 0.8529,  0.7654,  0.7304,  ..., -1.3004, -1.3354, -1.3704],...,[ 1.4657,  1.4657,  1.4832,  ...,  1.3256,  1.5357,  1.8508],[ 2.0084,  1.8683,  1.5182,  ...,  1.4657,  1.6583,  1.7283],[ 1.7458,  1.9384,  1.9209,  ...,  1.6583,  1.9209,  1.6408]],[[ 0.7228,  0.6879,  0.6531,  ..., -1.6476, -1.6302, -1.6476],[ 0.8099,  0.7576,  0.7228,  ..., -1.6476, -1.6476, -1.6650],[ 1.0017,  0.9145,  0.8797,  ..., -1.6476, -1.6650, -1.6999],...,[ 1.6291,  1.6291,  1.6465,  ...,  1.6291,  1.8208,  2.1346],[ 2.1868,  2.0300,  1.6814,  ...,  1.7685,  1.9428,  2.0125],[ 1.9254,  2.0997,  2.0823,  ...,  1.9428,  2.2043,  1.9080]]]])
4.2 预测
现在将使用预训练的DenseNet 121模型来预测图像的类别。将使用torchvision库中的一个库,加载模型并进行推断。在此示例中, 将使用预训练的模型,但可以对自己的模型使用相同的方法。在这个教程 中了解有关加载模型的更多信息。
from torchvision import models# 确保使用`pretrained`作为`True`来使用预训练的权重:
model = models.densenet121(pretrained=True)
# 由于仅将模型用于推理,因此请切换到“eval”模式:
model.eval()def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)return y_hat
张量y_hat将包含预测的类的id的索引。但是,需要一个易于阅读的类名。为此,需要一个类id来命名映射。将该文件 下载为imagenet_class_index.json并记住它的保存位置(或者,如果按照本文中的确切步骤操作,请将其保存在tutorials/_static中)。 此文件包含ImageNet类的id到ImageNet类的name的映射。将加载此JSON文件并获取预测索引的类的name。
import jsonimagenet_class_index = json.load(open('../_static/imagenet_class_index.json'))def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)predicted_idx = str(y_hat.item())return imagenet_class_index[predicted_idx]
在使用字典imagenet_class_index之前,首先将张量值转换为字符串值,因为字典imagenet_class_index中的keys是字符串。将 测试上述方法:
with open("../_static/img/sample_file.jpeg", 'rb') as f:image_bytes = f.read()print(get_prediction(image_bytes=image_bytes))
•   输出结果:
['n02124075', 'Egyptian_cat']
会得到这样的一个响应:
['n02124075', 'Egyptian_cat']
数组中的第一项是ImageNet类的id,第二项是人类可读的name。
注意:是否注意到模型变量不是get_prediction方法的一部分?或者为什么模型是全局变量?就内存和计算而言,加载模型可能是 一项昂贵的操作。如果将模型加载到get_prediction方法中,则每次调用该方法时都会不必要地加载该模型。由于正在构建Web服务 器,因此每秒可能有成千上万的请求,因此不应该浪费时间为每个推断重复加载模型。因此,仅将模型加载到内存中一次。在生 产系统中,必须有效利用计算以能够大规模处理请求,因此通常应在处理请求之前加载模型。
5.将模型集成到的API服务器中
在最后一部分中,将模型添加到Flask API服务器中。由于的API服务器应该获取图像文件,因此将更新predict方法以从请求中 读取文件:
from flask import request@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':# 从请求中获得文件file = request.files['file']# 转化为字节img_bytes = file.read()class_id, class_name = get_prediction(image_bytes=img_bytes)return jsonify({'class_id': class_id, 'class_name': class_name})
app.py文件现已完成。以下是完整版本;将路径替换为保存文件的路径,它的运行应是如下:
import io
import jsonfrom torchvision import models
import torchvision.transforms as transforms
from PIL import Image
from flask import Flask, jsonify, requestapp = Flask(__name__)
imagenet_class_index = json.load(open('<PATH/TO/.json/FILE>/imagenet_class_index.json'))
model = models.densenet121(pretrained=True)
model.eval()def transform_image(image_bytes):my_transforms = transforms.Compose([transforms.Resize(255),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])image = Image.open(io.BytesIO(image_bytes))return my_transforms(image).unsqueeze(0)def get_prediction(image_bytes):tensor = transform_image(image_bytes=image_bytes)outputs = model.forward(tensor)_, y_hat = outputs.max(1)predicted_idx = str(y_hat.item())return imagenet_class_index[predicted_idx]@app.route('/predict', methods=['POST'])
def predict():if request.method == 'POST':file = request.files['file']img_bytes = file.read()class_id, class_name = get_prediction(image_bytes=img_bytes)return jsonify({'class_id': class_id, 'class_name': class_name})if __name__ == '__main__':app.run()
让测试一下的web服务器,运行:
$ FLASK_ENV=development FLASK_APP=app.py flask run
可以使用requests库来发送一个POST请求到的app:
import requestsresp = requests.post("http://localhost:5000/predict",files={"file": open('<PATH/TO/.jpg/FILE>/cat.jpg','rb')})
打印resp.json()会显示下面的结果:
{"class_id": "n02124075", "class_name": "Egyptian_cat"}
6.下一步工作
编写的服务器非常琐碎,可能无法完成生产应用程序所需的一切。因此,可以采取一些措施来改善它:
•   端点/predict假定请求中总会有一个图像文件。这可能不适用于所有请求。的用户可能发送带有其它参数的图像,或者根本不发送任何图像。
•   用户也可以发送非图像类型的文件。由于没有处理错误,因此这将破坏的服务器。添加显式的错误处理路径来引发异常,这将使 能够更好地处理错误的输入
•   即使模型可以识别大量类别的图像,也可能无法识别所有图像。增强实现以处理模型无法识别图像中的任何情况的情况。
•   在开发模式下运行Flask服务器,该服务器不适合在生产中进行部署。可以查看教程 以在生产环境中部署Flask服务器。
•   还可以通过创建一个带有表单的页面来添加UI,该表单可以拍摄图像并显示预测。查看类似项目的演示及其源代码。
•   在本文中,仅展示了如何构建可以一次返回单个图像预测的服务。可以修改服务以能够一次返回多个图像的预测。此外,service-streamer 库自动将对服务的请求排队,并将它们采样到可用于模型的min-batches中。可以查看此教程。
•   最后,鼓励在页面顶部查看链接到的有关部署PyTorch模型的其它教程。

通过带Flask的REST API在Python中部署PyTorch相关推荐

  1. python实现简单的api接口-python中接口的实现实例

    接口基础知识: 简单说下接口测试,现在常用的2种接口就是http api和rpc协议的接口,今天主要说:http api接口是走http协议通过路径来区分调用的方法,请求报文格式都是key-value ...

  2. python的re模块是自带的吗_python内置模块手册 python中的re模块是自带的吗

    python3有哪些内置模块 python内置模块无法调用,pycharm环境,怎么办 python内置模块无法调用,pycharm环境是设置错误造成的,解决方法为: 点击左上角的file菜单 在下拉 ...

  3. VirusTotal api 在 python 中的 URL,域名使用

    URL 发送并扫描URL 首先发送扫描一个url,要向https://www.virustotal.com/vtapi/v2/url/scan 发送一个http post 请求, 其中api 接受请求 ...

  4. 『Python学习笔记』Python中的异步Web框架之fastAPI介绍RestAPI

    Python中的异步Web框架之fastAPI介绍&RestAPI 文章目录 一. fastAPI简要介绍 1.1. 安装 1.2. 创建 1.3. get方法 1.4. post方法 1.5 ...

  5. Python中的Optional和带默认值的参数

    文章目录 带默认值的参数 Typing.Optional类 Optional[X]等价于Union[X, None] 带默认值的参数 在Python中的类或者函数中,若参数在声明时附带了它的默认值,则 ...

  6. 用python中的turtle库绘制一些有趣的图

    最近有个在读大学的女生,想要我帮忙用python画几个图,在画的过程中觉得有些图还挺有意思的,分享给大家.    1 图1    第一个图是蚊香,感兴趣的小伙伴可以自己尝试在python中用turtl ...

  7. 在C++中部署python深度学习-学习笔记

    文章目录 一.简介 二.思路 三.深度学习部署平台和模型部署框架 3.1 部署平台 3.2 部署框架 四.基于TorchScript的PyTorch模型部署 4.1 TorchScript 1.Tra ...

  8. Python中应用决策树算法预测客户等级

    ​机器学习越来越多地在企业应用,本文跟大家分享一个采用python,应用决策树算法对跨国食品超市顾客等级进行预测的具体案例.如果想先行了解决策树算法原理,可以阅读文章决策树-ID3算法和C4.5算法. ...

  9. 【Python常用函数】一文让你彻底掌握Python中的pivot_table函数

    任何事情都是由量变到质变的过程,学习Python也不例外.只有把一个语言中的常用函数了如指掌了,才能在处理问题的过程中得心应手,快速地找到最优方案.本文和你一起来探索Python中的pivot_tab ...

最新文章

  1. 中国大学生创业报告发布
  2. ant学习笔记之(ant执行命令的详细参数和Ant自带的系统属性)
  3. 孙子算经余数C语言,行测数量关系备考:探索《孙子算经》之剩余定理
  4. oracle多个instance,Oracle 数据库EM访问多个Instance
  5. 从flink-example分析flink组件(1)WordCount batch实战及源码分析
  6. 「后端小伙伴来学前端了」Vue中Props配合自定义方法实现组件间的通信
  7. MTK Code Sync Clone
  8. linux eclipse java_从Linux终端编译运行Eclipse Java项目
  9. 华三防火墙h3cf100配置双宽带_H3C新一代F100系列防火墙评测报告
  10. zzulioj 1120: 最值交换
  11. 转子接地保护原理_转子一点接地保护原理示意图
  12. Mysql事务探索及其在Django中的实践(二)
  13. css之背景图固定大小不变、不重复、充满整个页面
  14. Swagger使用总结
  15. 数据存储与容灾(第2版)主编 鲁先志 武春岭综合训练答案
  16. JQ树形菜单加表格混合使用:treeTable组件使用
  17. 济南oracle 认证费用,济南ORACLE管理培训价格
  18. 安卓手机投屏软件_适合智能电视手机投屏的软件
  19. LeetCode第一题——曼哈顿距离
  20. 移动硬盘和电脑内置硬盘使用时的区别

热门文章

  1. 2021-2027全球与中国奶牛冻精市场现状及未来发展趋势
  2. 通过聚合数据API获取微信精选文章
  3. 用python快速画小猪佩奇
  4. 【Spring】spring基于注解的声明式事务控制
  5. python 如何获取当前系统的时间
  6. 王道考研 计算机网络笔记 第五章:传输层
  7. 最新Spring整合MyBatis详解教程
  8. AIFramework框架Jittor特性(下)
  9. Auto ML自动特征工程
  10. 利用硅光子学的移动心脏监护仪