scikit-learn系列之如何存储和导入机器学习模型

如何存储和导入机器学习模型

找到一个准确的机器学习模型,你的项目并没有完成。本文中你将学习如何使用scikit-learn来存储和导入机器学习模型。你可以把你的模型保持到文件中,然后再导入内存进行预测。

1. 用Pickle敲定你的模型

Pickle是python中一种标准的序列化对象的方法。你可以使用pickle操作来序列化你的机器学习算法,保存这种序列化的格式到一个文件中。稍后你可以导入这个文件反序列化你的模型,用它进行新的预测。以下的例子向你展示:如何使用Pima Indians onset of diabetes数据集,训练一个logistic回归模型,保存模型到文件,导入模型对未知数据进行预测。运行以下代码把模型存入你工作路径中的finalized_model.sav,导入模型,用未知数据评估模型的准确率。

# Save Model Using Pickle

import pandas

from sklearn import model_selection

from sklearn.linear_model import LogisticRegression

import pickle

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"

names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']

dataframe = pandas.read_csv(url, names=names)

array = dataframe.values

X = array[:,0:8]

Y = array[:,8]

test_size = 0.33

seed = 7

X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=test_size, random_state=seed)

# Fit the model on 33%

model = LogisticRegression()

model.fit(X_train, Y_train)

# save the model to disk

filename = 'finalized_model.sav'

pickle.dump(model, open(filename, 'wb'))

# some time later...

# load the model from disk

loaded_model = pickle.load(open(filename, 'rb'))

result = loaded_model.score(X_test, Y_test)

print(result)

2. 用joblib敲定你的模型

Joblib 是SciPy生态的一部分,为管道化python的工作提供的工具。它提供了存储和导入python对象的工具,可以对Numpy数据结构进行有效的利用。这对于要求很多参数和存储整个数据集的算法(比如K-Nearest Neighbors)很有帮助。以下代码向你展示:如何使用Pima Indians onset of diabetes数据集,训练一个logistic回归模型,使用joblib保存模型到文件,导入模型对未知数据进行预测。运行以下代码把模型存入你工作路径中的finalized_model.sav,也会创建一个文件保存Numpy数组,导入模型,用未知数据评估模型的准确率。

# Save Model Using joblib

import pandas

from sklearn import model_selection

from sklearn.linear_model import LogisticRegression

from sklearn.externals import joblib

url = "https://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"

names = ['preg', 'plas', 'pres', 'skin', 'test', 'mass', 'pedi', 'age', 'class']

dataframe = pandas.read_csv(url, names=names)

array = dataframe.values

X = array[:,0:8]

Y = array[:,8]

test_size = 0.33

seed = 7

X_train, X_test, Y_train, Y_test = model_selection.train_test_split(X, Y, test_size=test_size, random_state=seed)

# Fit the model on 33%

model = LogisticRegression()

model.fit(X_train, Y_train)

# save the model to disk

filename = 'finalized_model.sav'

joblib.dump(model, filename)

# some time later...

# load the model from disk

loaded_model = joblib.load(filename)

result = loaded_model.score(X_test, Y_test)

print(result)

3. 保存模型的几点提醒

当你存储你的机器学习模型时,需要考虑以下重要问题。一定要记住,记录下你的工具版本,以便于重构环境。

1. python的版本:记录下python的版本。需要相同大版本号的python来序列化和反序列化模型。

2. 库的版本:主要的库的版本要保持一致,不仅限于Numpy和scikit-learn的版本。

3. 手动序列化:你可能想要手动的输出你的模型参数以便于你可以直接把他们用在scikit-learn或者其他的平台。确实学习算法参数实现比算法本身实现要难得多。如果你有能力也可以自己写代码来导出参数。

4. 知识点:

model_selection.train_test_split

pickle.dump, pickle.load

joblib.dump, joblib.load

python 导入模型_scikit-learn系列之如何存储和导入机器学习模型相关推荐

  1. scikit-learn系列之如何存储和导入机器学习模型

    scikit-learn系列之如何存储和导入机器学习模型 如何存储和导入机器学习模型 找到一个准确的机器学习模型,你的项目并没有完成.本文中你将学习如何使用scikit-learn来存储和导入机器学习 ...

  2. python 训练好的模型怎么保存_如何保存训练好的机器学习模型

    保存训练好的机器学习模型 当我们训练好一个model后,下次如果还想用这个model,我们就需要把这个model保存下来,下次直接导入就好了,不然每次都跑一遍,训练时间短还好,要是一次跑好几天的那怕是 ...

  3. 利用colab保存模型_在Google Colab上训练您的机器学习模型中的“后门”

    利用colab保存模型 Note: This post is for educational purposes only. 注意:此职位仅用于教育目的. In this post, I would f ...

  4. 模型效果差?我建议你掌握这些机器学习模型的超参数优化方法

    模型优化是机器学习算法实现中最困难的挑战之一.机器学习和深度学习理论的所有分支都致力于模型的优化. 机器学习中的超参数优化旨在寻找使得机器学习算法在验证数据集上表现性能最佳的超参数.超参数与一般模型参 ...

  5. 不要再「外包」AI 模型了!最新研究发现:有些破坏机器学习模型安全的「后门」无法被检测到...

    来源:AI科技评论 作者:王玥.刘冰一.黄楠 编辑:陈彩娴 一个不可检测的「后门」,随之涌现诸多潜伏问题,我们距离「真正的」机器安全还有多远? 试想一下,一个植入恶意「后门」的模型,别有用心的人将它隐 ...

  6. 机器学习模型训练_您打算什么时候重新训练机器学习模型

    机器学习模型训练 You may find a lot of tutorials which would help you build end to end Machine Learning pipe ...

  7. ONNX系列三 --- 使用ONNX使PyTorch AI模型可移植

    目录 PyTorch简介 导入转换器 快速浏览模型 将PyTorch模型转换为ONNX 摘要和后续步骤 参考文献 下载源547.1 KB 系列文章列表如下: ONNX系列一 --- 带有ONNX的便携 ...

  8. python机器学习手写字体识别_Python 3 利用机器学习模型 进行手写体数字检测

    0.引言 介绍了如何生成手写体数字的数据,提取特征,借助 sklearn 机器学习模型建模,进行识别手写体数字 1-9 模型的建立和测试. 用到的几种模型: 1. LR,Logistic Regres ...

  9. 新兵训练营系列课程——海量数据存储基础

    2019独角兽企业重金招聘Python工程师标准>>> 新兵训练营系列课程--海量数据存储基础 2015年8月12日 09:24 阅读 16831 微博平台研发作为微博的底层数据及业 ...

最新文章

  1. 一文了解神经网络的基本原理
  2. 那些德艺双馨的网站列表-updating
  3. jsp中jsp中群发邮件群发邮件
  4. ARM之MMU工作原理分析
  5. c#串口程序接收数据并打印_C#程序可打印各种数据类型的大小
  6. webbench接口并发测试
  7. 认真学习系列:编译原理——B站笔记
  8. JavaEE系统架构师学习路线
  9. 【转】ASP.NET AJAX入门系列(8):使用ScriptManager控件
  10. macos 10.15.2 iso镜像网盘下载
  11. c3p0连接池配置及实现详解
  12. 计算机网络安全讲座心得,网络安全知识培训心得体会
  13. 明尼苏达大学计算机工程,关于美国明尼苏达大学电气与计算机工程系洪明毅博后学术报告的通知...
  14. 交换机 Port-Channel(端口汇聚)的 配置
  15. 服务器网卡修复工具,网络适配器无法启动如何修复 不能启动的处理办法
  16. PCB 铜厚厚度和线宽的选择
  17. 计算机基础及ms office应用,计算机基础及MS OFFICE应用(2020年版)/全国计算机等级考试一级教程...
  18. C++ include头文件之后为什么还要在编译的时候加--lxxx
  19. CVPR 2018 论文解读集锦
  20. 医院CRM管理中随访的重要性

热门文章

  1. 经典Robocode例子代码--SnippetBot
  2. 删除链接文件 linux,rm 删除链接文件的问题
  3. 2018年中考计算机考试成绩,2018年中考录取分数汇总,35所初中考成绩看2019中考!...
  4. 谷歌应用商店现木马程序、百万WiFi路由器面临漏洞风险|12月6日全球网络安全热点
  5. 电力系统的延时功率流 (CPF)的计算【 IEEE-14节点】(Matlab代码实现)
  6. 用python玩转数据慕课答案第三周_大学慕课用Python玩转数据章节测验答案
  7. Vi/vim编辑文件无法保存和退出的解决方法
  8. 电容触摸测试MCU的灵活性
  9. c#串口模拟互发数据(COM1-COM2)
  10. 【每日新闻】2017年亚马逊研发投入排世界第一,超过华为、BAT 总和 | 数人云宣布与UMCloud合并