Python 作为当前机器学习中使用最多的一门编程语言,有很多对应的机器学习库,最常用的莫过于 scikit-learn 了。本文我们介绍下如何使用sklearn进行实时预测。先来看下典型的机器学习工作流。

绿色方框圈出来的表示将数据切分为训练集和测试集。

红色方框的上半部分表示对训练数据进行特征处理,然后再对处理后的数据进行训练,生成 model。

红色方框的下半部分表示对测试数据进行特征处理,然后使用训练得到的 model 进行预测。

红色方框的右下角部分表示对模型进行评估,评估可以分为离线和在线。


典型的 ML 模型

介绍完了典型的机器学习工作流了之后,来看下典型的 ML 模型。

import numpy as npimport pandas as pdfrom sklearn.datasets import load_irisfrom sklearn.ensemble import RandomForestClassifier

# 加载鸢尾花数据iris = load_iris()# 创建包含特征名称的 DataFramedf = pd.DataFrame(iris.data, columns=iris.feature_names)df['species'] = pd.Categorical.from_codes(iris.target, iris.target_names)

# 生成标记,切分训练集、测试集df['is_train'] = np.random.uniform(0, 1, len(df)) <= .75train, test = df[df['is_train']==True], df[df['is_train']==False]

# 生成 X 和 yfeatures = df.columns[:4]y = pd.factorize(train['species'])[0]

model = RandomForestClassifier(n_jobs=2)

# 训练模型model.fit(train[features], y)# 预测数据model.predict(test[features])

上面的模型对鸢尾花数据进行训练生成一个模型,之后该模型对测试数据进行预测,预测结果为每条数据属于哪种类别。

模型的保存和加载

上面我们已经训练生成了模型,但是如果我们程序关闭后,保存在内存中的模型对象也会随之消失,也就是说下次如果我们想要使用模型预测时,需要重新进行训练。如何解决这个问题呢?

很简单,既然内存中的对象会随着程序的关闭而消失,我们能不能将训练好的模型保存成文件。如果需要预测的话,直接从文件中加载生成模型呢?答案是可以的。

sklearn 提供了 joblib 模型,能够实现完成模型的保存和加载。

from sklearn.externals import joblib

# 保存模型到 model.joblib 文件joblib.dump(model, "model.joblib" ,compress=1)

# 加载模型文件,生成模型对象new_model = joblib.load("model.joblib")

new_pred_data = [[0.5, 0.4, 0.7, 0.1]]# 使用加载生成的模型预测新样本new_model.predict(new_pred_data)


构建实时预测

前面说到的运行方式是在离线环境中运行,在真实世界中,我们很多时候需要在线实时预测。一种解决方案是将模型服务化,在我们这个场景就是,我告诉你一个鸢尾花的 sepal_length, sepal_width, petal_length, petal_width 之后,你能够快速告诉我这个鸢尾花的类型,借助 flask 等 web 框架,开发一个 web service,实现实时预测。

因为依赖于 flask 框架,没有安装的需要安装下:

pip install flask

创建一个 ml_web.py 文件,内容如下:

# coding=utf-8from urlparse import urljoin

import flaskfrom flask import Flask, request, url_for, Responsefrom sklearn.externals import joblib

app = Flask(__name__)

# 加载模型model = joblib.load("model.joblib")

@app.route("/", methods=["GET"])def index():    with app.test_request_context():        # 生成每个函数监听的url以及该url的参数        result = {"predict_iris": {"url": url_for("predict_iris"),                                   "params": ["sepal_length", "sepal_width", "petal_length", "petal_width"]}}

        result_body = flask.json.dumps(result)

        return Response(result_body, mimetype="application/json")

@app.route("/ml/predict_iris", methods=["GET"])def predict_iris():    request_args = request.args

    # 如果没有传入参数,返回提示信息    if not request_args:        result = {            "message": "请输入参数:sepal_length, sepal_width, petal_length, petal_width"        }        result_body = flask.json.dumps(result, ensure_ascii=False)        return Response(result_body, mimetype="application/json")

    # 获取请求参数    sepal_length = float(request_args.get("sepal_length", "-1"))    sepal_width = float(request_args.get("sepal_width", "-1"))    petal_length = float(request_args.get("petal_length", "-1"))    petal_width = float(request_args.get("petal_width", -1))

    # 构建特征矩阵    vec = [[sepal_length, sepal_width, petal_length, petal_width]]    print("vec: {0}".format(vec))

    # 生成预测结果    predict_result = int(model.predict(vec)[0])    print("predict_result: {0}".format(predict_result))

    # 构造返回数据    result = {        "features": {            "sepal_length": sepal_length,            "sepal_width": sepal_width,            "petal_length": petal_length,            "petal_width": petal_width        },        "result": predict_result    }

    result_body = flask.json.dumps(result, ensure_ascii=False)    return Response(result_body,  mimetype="application/json")

if __name__ == "__main__":    app.run(port=8000)

在命令行启动它:

$ python ml_web.py * Running on http://127.0.0.1:8000/ (Press CTRL+C to quit)

在 PostMan(也可以在浏览器中打开) 中打开 http://127.0.0.1:8000/ml/predict_iris ,得到以下结果:

可以看到,这里提示我们输入 sepal_length, sepal_width, petal_length, petal_width 参数,所以我们需要添加上参数重新构造一个请求 url:http://127.0.0.1:8000/ml/predict_iris?sepal_length=10&sepal_width=1&petal_length=3&petal_width=2。

再次请求得到的结果如下:

可以看到,模型返回的结果为 2,也就是说模型认为这个鸢尾花的类别是 2。


总结

在真实世界中,我们经常需要将模型进行服务化。这里我们借助 flask 框架,将 sklearn 训练后生成的模型文件加载到内存中,针对每次请求传入不同的特征来实时返回不同的预测结果。

作者:1or0,专注于机器学习研究。

声明:本文为公众号 AI派 投稿,版权归对方所有。

如何构建真实世界可用的 ML 模型?相关推荐

  1. 移动应用AI化成新战场?详解苹果最新Core ML模型构建基于机器学习的智能应用...

    Google刚刚息鼓,苹果又燃战火!这一战,来自移动应用的AI化之争. 近日,苹果发布专为移动端优化的Core ML后,移动开发者对此的需求到底有多强烈?去年大获成功的AI应用Prisma又能告诉我们 ...

  2. 构建高可用LVS + keepalived+httpd和双主模型的keepalived方案

    ↑构建高可用LVS + keepalived+httpd和双主模型的keepalived方案↑ 标签:web服务器 拓扑图 模型 检测 软件 原创作品,允许转载,转载时请务必以超链接形式标明文章 原始 ...

  3. 使用YOLO Core ML模型构建对象检测iOS应用(七)

    目录 在我们的应用程序中添加模型 在捕获的视频帧上运行目标检测 绘制边界框 实际应用 下一步? 总目录 将ONNX对象检测模型转换为iOS Core ML(一) 解码Core ML YOLO对象检测器 ...

  4. apache beam_Apache Beam ML模型部署

    apache beam This blog post builds on the ideas started in three previous blog posts. 这篇博客文章基于之前 三篇 博 ...

  5. 使用Amazon SageMaker构建高质量AI作画模型Stable Diffusion

    使用Amazon SageMaker构建高质量AI作画模型Stable Diffusion 0. 前言 1. Amazon SageMaker 与机器学习 1.1 机器学习流程 1.2 Amazon ...

  6. 几行代码搞定ML模型,低代码机器学习Python库正式开源

    公众号关注 "视学算法" 设为 "星标",消息即可送达! 机器之心报道 机器之心编辑部 PyCaret 库支持在「低代码」环境中训练和部署有监督以及无监督的机器 ...

  7. 太赞了!NumPy 手写所有主流 ML 模型,由普林斯顿博士后 David Bourgin打造的史上最强机器学习基石项目!...

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 用 NumPy 手写所有主流 ML 模型,普林斯顿博士后 David Bourgi ...

  8. NumPy 手写所有主流 ML 模型,由普林斯顿博士后 David Bourgin打造的史上最强机器学习基石项目!...

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 用 NumPy 手写所有主流 ML 模型,普林斯顿博士后 David Bourgi ...

  9. AutoML大提速,谷歌开源自动化寻找最优ML模型新平台

    为了帮助研究者自动.高效地开发最佳机器学习模型,谷歌开源了一个不针对特定领域的 AutoML 平台.该平台基于 TensorFlow 构建,非常灵活,既可以找出最适合给定数据集和问题的架构,也能够最小 ...

最新文章

  1. SOAR SQL进行优化和改写的自动化工具
  2. 管道、通道、管程的区别
  3. Memory Ordering
  4. RabbitMQ系列-顺序消费模式和迅速消息发送模式
  5. 体验.NET Core使用IKVM对接Java
  6. Docker基础入门及示例
  7. python几种括号表示的类型
  8. MySQL多实例配置
  9. Dubbo视频教程《基于Dubbo的分布式系统架构视频教程》----课程列表
  10. 【书影观后感 八】《周期》万事皆周期
  11. Android解决Can't create handler inside thread that has not called Looper.prepare()
  12. 基于Tensorflow 2.x手动复现BERT
  13. mysql没有exe_MySQL解压之后没有exe程序,怎么解决,怎么安装访问
  14. CCF 201709-2 公共钥匙盒 (Java 100分)
  15. echarts图例颜色与地图底色
  16. gcc/gdb/make/动/静态链接库介绍
  17. 敏捷(Agile)是什么?--参加优普丰CSM认证培训有感
  18. JMeter,LoadRunner,软件压力测试?
  19. pdf压缩工具哪个好用?pdf压缩工具推荐?
  20. js 正则输入验证 整数 两位小数 三位小数

热门文章

  1. linux挂载硬盘_Linux把内存挂载成硬盘提高读写速度-内存虚拟盘
  2. 可穿戴医疗设备行业调研报告 - 市场现状分析与发展前景预测
  3. SpringBoot动态切换数据源-快速集成多数据源的启动器
  4. IT-游戏 学习资源思维导图(持续更新,欢迎关注点赞加评论)
  5. js金额格式化最简单方法 JS对货币格式化,js钱三位一隔,javascript货币格式化
  6. 给所有的input设置 autocomplete=off
  7. sql sever 2008 使用SSMS实现自动备份,每周一次,完整备份
  8. 老式Windows桌面的终结:Windows 11来了,DaaS还会远吗?
  9. 微软 CEO 纳德拉痛失爱子
  10. 1.6 万亿参数你怕了吗?谷歌大脑语言模型速度是 T5 速度的 7 倍