前言

本文针对业余范围的Pytorch模型部署,类似各位想把自己开发的深度学习模型上线web端demo等等。

大家比较熟悉的Python框架主要有flask,使用flask部署上线深度学习模型过程简单,只需要在应用初始化时构造模型(model),在视图函数里调用模型前馈推理(inference)即可。这种做法的弊端在于,服务器将每个请求独立加载到GPU执行,而熟悉深度学习的朋友们都知道将输入张量堆叠成batch才能更有效利用GPU资源。

因此本文介绍的Sanic框架能便利地将多个用户请求集合成一个batch执行,执行完后再拆分开将结果返回给用户。本文将以cycleGAN的执行为例,其他model处理过程相似。

本文引用的代码来自《DeepLearning with Pytorch》,讲解为博主原创。

Sanic框架简介

Sanic框架利用了Python 3.6版本后新增的异步功能,是一个支持 async/await 语法的异步无阻塞框架。

Sanic框架与Flask框架同为Python环境下轻量级web框架,功能用法类似,已有flask基础的读者会很快适应本文,因此本文将重点介绍使用Sanic框架部署pytorch程序的思路,具体Web框架的执行细节可以参阅其他资料。Sanic框架入门指南可以在这里获得,这里也有相关函数介绍,官方文档在这里。

初始化Sanic应用

# 这里与flask框架类似,初始化一个Sanic app对象
app = Sanic(__name__)
# 定义设备
device = torch.device('cuda:0')
# 定义一些全局参数
MAX_QUEUE_SIZE = 3  # 队列最大长度,即容许等待的最大队列长度,超过则返回“too busy”
MAX_BATCH_SIZE = 2  # 单个batch内最大容量,一般取决于显存
MAX_WAIT = 1  # 从第一个batch被加入队列起的最大等待时间

构造模型

这部分我们初始化一个ModelRunner类,类中将会构造pytorch模型,构造栈等等

class ModelRunner:def __init__(self, model_name):self.model_name = model_nameself.queue = []self.queue_lock = None# 将这里修改为构造自己的model,确保output=self.model(input)self.model = get_pretrained_model(self.model_name, map_location=device)# 这是一个事件,顾名思义,代表当下是否需要执行一批(一个batch)的推理self.needs_processing = None# 为needs_processing事件准备的计时器self.needs_processing_timer = None

构造一个循环执行的函数model_runner

ModelRunner内的函数,这个函数创建了协程锁和协程事件(触发后就执行推理的事件)后,就进入无限循环,事件每触发一次,就会执行一次推理。

    async def model_runner(self):"""在Sanic主循环(app.loop)启动之后执行一些后台任务"""# 初始化了协程锁和协程事件# asyncio.Lock 参见 https://docs.python.org/zh-cn/3/library/asyncio-sync.html#lock# asyncio.Event 参见 https://docs.python.org/zh-cn/3/library/asyncio-sync.html#asyncio.Eventself.queue_lock = asyncio.Lock(loop=app.loop)  # 实现一个用于 asyncio 任务的互斥锁,保证对共享资源的独占访问(即访问时保证独占app.loop)self.needs_processing = asyncio.Event(loop=app.loop)logger.info("started model runner for {}".format(self.model_name))# 以下代码将一直执行while True:await self.needs_processing.wait()  # 等待,直至needs_processing事件被触发,即开始处理一个batch# 清空needs_processing事件和计时器,等待下一次用self.needs_processing.clear()if self.needs_processing_timer is not None:self.needs_processing_timer.cancel()self.needs_processing_timer = Noneasync with self.queue_lock:  # 相当于等待,直到保证独占app.loop,并上协程锁,执行完后释放锁if self.queue:  # 队列非空,计算最大等待时间longest_wait = app.loop.time() - self.queue[0]["time"]else:  # 队列空longest_wait = Nonelogger.debug("launching processing. queue size: {}. longest wait: {}".format(len(self.queue), longest_wait))# 确保队列长度未溢出to_process = self.queue[:MAX_BATCH_SIZE]del self.queue[:len(to_process)]self.schedule_processing_if_needed()# 将等待队列中的输入张量堆叠成一个Batchbatch = torch.stack([t["input"] for t in to_process], dim=0)# 在新线程中执行模型推理result = await app.loop.run_in_executor(None, functools.partial(self.run_model, batch))# 将结果拆开,放到队列中各个请求(字典)对应的output键值中for t, r in zip(to_process, result):t["output"] = rt["done_event"].set()del to_processdef run_model(self, batch):"""执行模型的函数,这里只需要一步,根据实际需要调整"""return self.model(batch.to(device)).to('cpu')

其实这个函数是在Sanic应用被创建后就被调用的,在外层代码中有这样的安排:

app.add_task(style_transfer_runner.model_runner())

构造视图函数

为处理用户请求准备的路由函数,HTTP请求与响应的基础知识大家可以查阅其他资料。

路由函数代表了从用户发出HTTP请求后的操作,即包括接收用户以字节形式上传的图片、图片预处理、放到model中执行、返回给用户执行结果。

@app.route('/image', methods=['PUT'], stream=True)
async def image(request):try:print(request.headers)content_length = int(request.headers.get('content-length', '0'))MAX_SIZE = 2 ** 22  # 接收图片的最大大小,这里设置为:10MBif content_length:if content_length > MAX_SIZE:raise HandlingError("Too large")data = bytearray(content_length)else:data = bytearray(MAX_SIZE)pos = 0while True:# so this still copies too much stuff.data_part = await request.stream.read()if data_part is None:breakdata[pos: len(data_part) + pos] = data_partpos += len(data_part)if pos > MAX_SIZE:raise HandlingError("Too large")# 对图片流数据使用PIL打开并做必要预处理,预处理过程要根据自己的模型调整im = PIL.Image.open(io.BytesIO(data))im = torchvision.transforms.functional.resize(im, (228, 228))im = torchvision.transforms.functional.to_tensor(im)if im.dim() != 3 or im.size(0) < 3 or im.size(0) > 4:raise HandlingError("need rgb image")# 真正的核心代码,使用runner函数处理输入得到结果out_im = await style_transfer_runner.process_input(im)# 将结果使用IO流输出给用户out_im = torchvision.transforms.functional.to_pil_image(out_im)imgByteArr = io.BytesIO()out_im.save(imgByteArr, format='JPEG')return sanic.response.raw(imgByteArr.getvalue(), status=200,content_type='image/jpeg')except HandlingError as e:# 错误处理return sanic.response.text(e.handling_msg, status=e.handling_code)

沟通路由函数与model runner的process_input

定义在ModelRunner类中,将单个用户请求堆叠成batch,判断是否执行一个batch,等待执行后返回结果分发给用户

    async def process_input(self, input):"""路由函数将请求导引至这里,这里将会把请求集合成batch(或者报告正忙"""our_task = {"done_event": asyncio.Event(loop=app.loop),  # 代表这批任务是否被处理完的事件"input": input,"time": app.loop.time()}async with self.queue_lock:# 若队列已满则报告正忙if len(self.queue) >= MAX_QUEUE_SIZE:raise HandlingError("I'm too busy", code=503)# 加入队列self.queue.append(our_task)logger.debug("enqueued task. new queue size {}".format(len(self.queue)))# 决定是否需要安排一波processself.schedule_processing_if_needed()# 等待处理完await our_task["done_event"].wait()# 返回结果return our_task["output"]

判断是否需要推理一个batch的schedule_processing_if_needed

显然,当队列长度已满或达到最大等待时间,需要送一个batch进模型推理,这里的self.needs_processing.set()就代表了设置当下需要处理。

    def schedule_processing_if_needed(self):"""这个函数决定是否需要安排一波模型推理"""if len(self.queue) >= MAX_BATCH_SIZE:# 若队列长度已满,则安排一波推理(将needs_processing设为触发,即为设置需要处理这个batch)logger.debug("next batch ready when processing a batch")self.needs_processing.set()elif self.queue:# 若队列长度未满但队列非空,当队列中第一个请求达到最大等待时间时,触发needs_processing事件执行logger.debug("queue nonempty when processing a batch, setting next timer")self.needs_processing_timer = app.loop.call_at(self.queue[0]["time"] + MAX_WAIT, self.needs_processing.set)
以上是各部分的分别介绍,整体代码如下
import sys
import asyncio
import itertools
import functools
from sanic import Sanic
from sanic.response import json, text
from sanic.log import logger
from sanic.exceptions import ServerErrorimport sanic
import threading
import PIL.Image
import io
import torch
import torchvision
from .cyclegan import get_pretrained_model# 这里与flask框架类似,初始化一个Sanic app对象
app = Sanic(__name__)device = torch.device('cuda:0')
# 定义一些全局参数
MAX_QUEUE_SIZE = 3  # 队列最大长度,即容许等待的最大队列长度,超过则返回“too busy”
MAX_BATCH_SIZE = 2  # 单个batch内最大容量,一般取决于显存
MAX_WAIT = 1  # 从第一个batch被加入队列起的最大等待时间class HandlingError(Exception):def __init__(self, msg, code=500):super().__init__()self.handling_code = codeself.handling_msg = msgclass ModelRunner:def __init__(self, model_name):self.model_name = model_nameself.queue = []self.queue_lock = None# 将这里修改为构造自己的model,确保output=self.model(input)self.model = get_pretrained_model(self.model_name, map_location=device)# 这是一个事件,顾名思义,代表当下是否需要执行一批(一个batch)的推理self.needs_processing = None# 为needs_processing事件准备的计时器self.needs_processing_timer = Nonedef schedule_processing_if_needed(self):"""这个函数决定是否需要安排一波模型推理"""if len(self.queue) >= MAX_BATCH_SIZE:# 若队列长度已满,则安排一波推理(将needs_processing设为触发,即为设置需要处理这个batch)logger.debug("next batch ready when processing a batch")self.needs_processing.set()elif self.queue:# 若队列长度未满但队列非空,当队列中第一个请求达到最大等待时间时,触发needs_processing事件执行logger.debug("queue nonempty when processing a batch, setting next timer")self.needs_processing_timer = app.loop.call_at(self.queue[0]["time"] + MAX_WAIT, self.needs_processing.set)async def process_input(self, input):"""路由函数将请求导引至这里,这里将会把请求集合成batch(或者报告正忙"""our_task = {"done_event": asyncio.Event(loop=app.loop),  # 代表这批任务是否被处理完的事件"input": input,"time": app.loop.time()}async with self.queue_lock:# 若队列已满则报告正忙if len(self.queue) >= MAX_QUEUE_SIZE:raise HandlingError("I'm too busy", code=503)# 加入队列self.queue.append(our_task)logger.debug("enqueued task. new queue size {}".format(len(self.queue)))# 决定是否需要安排一波processself.schedule_processing_if_needed()# 等待处理完await our_task["done_event"].wait()# 返回结果return our_task["output"]def run_model(self, batch):"""执行模型的函数,这里只需要一步,根据实际需要调整"""return self.model(batch.to(device)).to('cpu')async def model_runner(self):"""在Sanic主循环(app.loop)启动之后执行一些后台任务"""# 初始化了协程锁和协程事件# asyncio.Lock 参见 https://docs.python.org/zh-cn/3/library/asyncio-sync.html#lock# asyncio.Event 参见 https://docs.python.org/zh-cn/3/library/asyncio-sync.html#asyncio.Eventself.queue_lock = asyncio.Lock(loop=app.loop)  # 实现一个用于 asyncio 任务的互斥锁,保证对共享资源的独占访问(即访问时保证独占app.loop)self.needs_processing = asyncio.Event(loop=app.loop)logger.info("started model runner for {}".format(self.model_name))# 以下代码将一直执行while True:await self.needs_processing.wait()  # 等待,直至needs_processing事件被触发,即开始处理一个batch# 清空needs_processing事件和计时器,等待下一次用self.needs_processing.clear()if self.needs_processing_timer is not None:self.needs_processing_timer.cancel()self.needs_processing_timer = Noneasync with self.queue_lock:  # 相当于等待,直到保证独占app.loop,并上协程锁,执行完后释放锁if self.queue:  # 队列非空,计算最大等待时间longest_wait = app.loop.time() - self.queue[0]["time"]else:  # 队列空longest_wait = Nonelogger.debug("launching processing. queue size: {}. longest wait: {}".format(len(self.queue), longest_wait))# 确保队列长度未溢出to_process = self.queue[:MAX_BATCH_SIZE]del self.queue[:len(to_process)]self.schedule_processing_if_needed()# 将等待队列中的输入张量堆叠成一个Batchbatch = torch.stack([t["input"] for t in to_process], dim=0)# 在新线程中执行模型推理result = await app.loop.run_in_executor(None, functools.partial(self.run_model, batch))# 将结果拆开,放到队列中各个请求(字典)对应的output键值中for t, r in zip(to_process, result):t["output"] = rt["done_event"].set()del to_processstyle_transfer_runner = ModelRunner(sys.argv[1])@app.route('/image', methods=['PUT'], stream=True)
async def image(request):try:print(request.headers)content_length = int(request.headers.get('content-length', '0'))MAX_SIZE = 2 ** 22  # 接收图片的最大大小,这里设置为:10MBif content_length:if content_length > MAX_SIZE:raise HandlingError("Too large")data = bytearray(content_length)else:data = bytearray(MAX_SIZE)pos = 0while True:# so this still copies too much stuff.data_part = await request.stream.read()if data_part is None:breakdata[pos: len(data_part) + pos] = data_partpos += len(data_part)if pos > MAX_SIZE:raise HandlingError("Too large")# 对图片流数据使用PIL打开并做必要预处理,预处理过程要根据自己的模型调整im = PIL.Image.open(io.BytesIO(data))im = torchvision.transforms.functional.resize(im, (228, 228))im = torchvision.transforms.functional.to_tensor(im)if im.dim() != 3 or im.size(0) < 3 or im.size(0) > 4:raise HandlingError("need rgb image")# 真正的核心代码,使用runner函数处理输入得到结果out_im = await style_transfer_runner.process_input(im)# 将结果使用IO流输出给用户out_im = torchvision.transforms.functional.to_pil_image(out_im)imgByteArr = io.BytesIO()out_im.save(imgByteArr, format='JPEG')return sanic.response.raw(imgByteArr.getvalue(), status=200,content_type='image/jpeg')except HandlingError as e:# 错误处理return sanic.response.text(e.handling_msg, status=e.handling_code)# 在Sanic主循环(app.loop)启动之后执行一些后台任务
app.add_task(style_transfer_runner.model_runner())
# 启动Sanic应用
app.run(host="0.0.0.0", port=8000, debug=True)

Sanic框架下部署Pytorch模型相关推荐

  1. 学习笔记|Flask部署Pytorch模型+Gunicorn+Docker

    一.使用Flask部署Pytorch模型 其实原理很简单,我们希望使用一个已经训练好的pytorch模型,用它做预测或生成.我们的模型部署在服务器上,客户端可以通过http request调用我们部署 ...

  2. Keras框架下的保存模型和加载模型

    在Keras框架下训练深度学习模型时,一般思路是在训练环境下训练出模型,然后拿训练好的模型(即保存模型相应信息的文件)到生产环境下去部署.在训练过程中我们可能会遇到以下情况: 需要运行很长时间的程序在 ...

  3. 在C++平台上部署PyTorch模型流程+踩坑实录

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导读 本文主要讲解如何将pytorch的模型部署到c++平台上的模 ...

  4. 经验 | 在C++平台上部署PyTorch模型流程+踩坑实录

    点上方蓝字计算机视觉联盟获取更多干货 在右上方 ··· 设为星标 ★,与你不见不散 仅作学术分享,不代表本公众号立场,侵权联系删除 转载于:作者丨火星少女@知乎 来源丨https://zhuanlan ...

  5. 西北乱跑娃 --- bottle框架部署pytorch模型

    一.前言 在网站荡了很多关于深度学习的代码,包括在码云.github和各种博客.但是作者毕竟是一个文科生,也没有怎么接触过深度学习分类的代码,加上那些文章中的代码文件太多,极其晦涩难懂所以在之前的半年 ...

  6. 在Android设备部署PyTorch模型

    Pytorch Mobile Android Demo 1 HelloWorldApp 1 模型准备 2 源码分析 3 读取图片数据 4 读取模型 5 将图像转换为Tensor 6 运行模型 7 处理 ...

  7. 解锁新姿势-使用TensorRT部署pytorch模型

    一.整体流程概览 使用pytorch训练模型,生成*.pth文件 将*.pth转换成onnx模型 在tensorrt中加载onnx模型,并转换成trt的object 在trt中使用第三步转换的obje ...

  8. flask部署pytorch模型

    项目代码: https://pan.baidu.com/s/1-FdTk7XjryvUsZR9CW9T3g 提取码:6uo5 该项目上传至阿里云仓库:docker--构建自己的项目(阿里云仓库)| d ...

  9. 基于C++的PyTorch模型部署

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 引言 PyTorch作为一款端到端的深度学习框架,在1.0版本之后 ...

最新文章

  1. 【BZOJ 3160】 3160: 万径人踪灭 (FFT)
  2. caioj 1077 动态规划入门(非常规DP1:筷子)
  3. (原创) 对饱和状态NPN晶体管内部机制的理解分析
  4. Android bootchart分析
  5. Qt文档阅读笔记|Qt工作笔记-QMutexLocker的使用(抛出异常也能解锁)
  6. [转]Linux之ACL权限
  7. (31)FPGA面试题系统最高速度计算方法
  8. 3d激光雷达开发(基于统计滤波)
  9. python爬虫知乎问答
  10. Http协议及其实现httpd
  11. 提取pdf目录的方法
  12. python tkinter.Text 高级用法 -- 设计功能齐全的文本编辑器
  13. linux开源项目github,GitHub 上的优质 Linux 开源项目,真滴牛逼!
  14. python_习题练习_5_小游戏《唐僧大战白骨精》
  15. Python学习指南(看完不迷路)
  16. \Qt5\\bin\\d3dcompiler_47.dll
  17. 以下是一些提供技术专利申请模板的中文网站,供您参考
  18. 操作系统虚拟存储器实验---Python实现
  19. 三种实现分布式锁的方式
  20. Solidity之事件

热门文章

  1. java添加@Data注解
  2. 新零售发展蓝海|全球无人零售货柜与无人便利店趋势兴起
  3. win7升级Win10之360百度等升级助手均因系统未激活不可升级
  4. 初识OpenGL (-)坐标系统(Coordinate System)
  5. 点餐系统---------软件工程课程设计
  6. Unity 音频插件 - MasterAudio 实现音频管理系统
  7. 少年不惧岁月长,奋楫笃行国学香
  8. 程序猿的怎么软件园蹦出来
  9. 墨者WordPress插件漏洞分析溯源
  10. 环洋市场调研-2021年全球颜料红48:2行业调研及趋势分析报告