本篇文章转载自博客园,作者: 刘建平Pinard

在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法优化的PMML文件大多数时候很笨拙,因此本文我们专门讨论下tensorflow机器学习模型的跨平台上线的方法。

一、tensorflow模型的跨平台上线的备选方案

tensorflow模型的跨平台上线的备选方案一般有三种:即PMML方式,tensorflow serving方式,以及跨语言API方式。PMML方式的主要思路在上一篇以及讲过。这里唯一的区别是转化生成PMML文件需要用一个Java库jpmml-tensorflow来完成,生成PMML文件后,跨语言加载模型和其他PMML模型文件基本类似。

tensorflow serving是tensorflow 官方推荐的模型上线预测方式,它需要一个专门的tensorflow服务器,用来提供预测的API服务。如果你的模型和对应的应用是比较大规模的,那么使用tensorflow serving是比较好的使用方式。但是它也有一个缺点,就是比较笨重,如果你要使用tensorflow serving,那么需要自己搭建serving集群并维护这个集群。所以为了一个小的应用去做这个工作,有时候会觉得麻烦。

跨语言API方式是本文要讨论的方式,它会用tensorflow自己的Python API生成模型文件,然后用tensorflow的客户端库比如Java或C++库来做模型的在线预测。下面我们会给一个生成生成模型文件并用tensorflow Java API来做在线预测的例子。

二、训练模型并生成模型文件

我们这里给一个简单的逻辑回归并生成逻辑回归tensorflow模型文件的例子。完整代码参见我的github:https://github.com/ljpzzz/machinelearning/blob/master/model-in-product/tensorflow-java

首先,我们生成了一个6特征,3分类输出的4000个样本数据。

import numpy as npimport matplotlib.pyplot as plt%matplotlib inlinefrom sklearn.datasets.samples_generator import make_classificationimport tensorflow as tfX1, y1 = make_classification(n_samples=4000, n_features=6, n_redundant=0,                             n_clusters_per_class=1, n_classes=3)

接着我们构建tensorflow的数据流图,这里要注意里面的两个名字,第一个是输入x的名字input,第二个是输出prediction_labels的名字output,这里的这两个名字可以自己取,但是后面会用到,所以要保持一致。

learning_rate = 0.01training_epochs = 600batch_size = 100x = tf.placeholder(tf.float32, [None, 6],name='input') # 6 featuresy = tf.placeholder(tf.float32, [None, 3]) # 3 classesW = tf.Variable(tf.zeros([6, 3]))b = tf.Variable(tf.zeros([3]))# softmax回归pred = tf.nn.softmax(tf.matmul(x, W) + b, name="softmax") cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)prediction_labels = tf.argmax(pred, axis=1, name="output")init = tf.global_variables_initializer()

接着就是训练模型了,代码比较简单,毕竟只是一个演示:

sess = tf.Session()sess.run(init)y2 = tf.one_hot(y1, 3)y2 = sess.run(y2)for epoch in range(training_epochs):    _, c = sess.run([optimizer, cost], feed_dict={x: X1, y: y2})    if (epoch+1) % 10 == 0:        print ("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(c))    print ("优化完毕!")correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y2, 1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))acc = sess.run(accuracy, feed_dict={x: X1, y: y2})print (acc)

打印输出我这里就不写了,大家可以自己去试一试。接着就是关键的一步,存模型文件了,注意要用convert_variables_to_constants这个API来保存模型,否则模型参数不会随着模型图一起存下来。

graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output"])tf.train.write_graph(graph, '.', 'rf.pb', as_text=False)

至此,我们的模型文件rf.pb已经被保存下来了,下面就是要跨平台上线了。 

三、模型文件在Java平台上线

这里我们以Java平台的模型上线为例,C++的API上线我没有用过,这里就不写了。我们需要引入tensorflow的java库到我们工程的maven或者gradle文件。这里给出maven的依赖如下,版本可以根据实际情况选择一个较新的版本。

<dependency>    <groupId>org.tensorflowgroupId>    <artifactId>tensorflowartifactId>    <version>1.7.0version>dependency>

接着就是代码了,这个代码会比JPMML的要简单,我给出了4个测试样本的预测例子如下,一定要注意的是里面的input和output要和训练模型的时候对应的节点名字一致。

import org.tensorflow.*;import org.tensorflow.Graph;import java.io.IOException;import java.nio.file.Files;import java.nio.file.Paths;/** * Created by 刘建平pinard on 2018/7/1. */public class TFjavaDemo {    public static void main(String args[]){        byte[] graphDef = loadTensorflowModel("D:/rf.pb");        float inputs[][] = new float[4][6];        for(int i = 0; i< 4; i++){            for(int j =0; j< 6;j++){                if(i<2) {                    inputs[i][j] = 2 * i - 5 * j - 6;                }                else{                    inputs[i][j] = 2 * i + 5 * j - 6;                }            }        }        Tensor input = covertArrayToTensor(inputs);        Graph g = new Graph();        g.importGraphDef(graphDef);        Session s = new Session(g);        Tensor result = s.runner().feed("input", input).fetch("output").run().get(0);        long[] rshape = result.shape();        int rs = (int) rshape[0];        long realResult[] = new long[rs];        result.copyTo(realResult);        for(long a: realResult ) {            System.out.println(a);        }    }    static private byte[] loadTensorflowModel(String path){        try {            return Files.readAllBytes(Paths.get(path));        } catch (IOException e) {            e.printStackTrace();        }        return null;    }    static private TensorcovertArrayToTensor(float inputs[][]){        return Tensors.create(inputs);    }}

我的预测输出是1,1,0,0,供大家参考。

四、一点小结

对于tensorflow来说,模型上线一般选择tensorflow serving或者client API库来上线,前者适合于较大的模型和应用场景,后者则适合中小型的模型和应用场景。因此算法工程师使用在产品之前需要做好选择和评估。

往期精彩:

深度学习多种模型评估指标介绍 - 附sklearn实现

干货 | Attention注意力机制超全综述

Tensorflow常用函数使用说明及实例简记

机器学习中优化相关理论知识简述

自己动手实现一个神经网络多分类器

Transformer 模型的 PyTorch 实现

干货 | NLP中的十个预训练模型

干货|一文弄懂机器学习中偏差和方差

FastText原理和文本分类实战,看这一篇就够了

Transformer模型细节理解及Tensorflow实现

GPT,GPT2,Bert,Transformer-XL,XLNet论文阅读速递

机器学习算法篇:最大似然估计证明最小二乘法合理性

Word2vec, Fasttext, Glove, Elmo, Bert, Flair训练词向量教程+数据+源码

别偷偷摸摸的在看,有用就点个好看呀

机器学习基于skcilearn tensorflow电子书_Tensorflow机器学习模型的跨平台上线相关推荐

  1. tensorflow机器学习模型的跨平台上线

    在用PMML实现机器学习模型的跨平台上线中,我们讨论了使用PMML文件来实现跨平台模型上线的方法,这个方法当然也适用于tensorflow生成的模型,但是由于tensorflow模型往往较大,使用无法 ...

  2. 用PMML实现机器学习模型的跨平台上线

    在机器学习用于产品的时候,我们经常会遇到跨平台的问题.比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环 ...

  3. java调用pmml_用PMML实现机器学习模型的跨平台上线

    在机器学习用于产品的时候,我们经常会遇到跨平台的问题.比如我们用Python基于一系列的机器学习库训练了一个模型,但是有时候其他的产品和项目想把这个模型集成进去,但是这些产品很多只支持某些特定的生产环 ...

  4. ML机器学习基于树的家族

    ML机器学习基于树的家族 目录 决策树模型与学习 特征选择 决策树的生成 3.1 决策树的编程实现 3.2 画出决策树的方式 决策树的剪枝 DBGT 随机森林 参考资料: 机器学习实战 统计学习方法 ...

  5. 干货 | 基于贝叶斯推断的分类模型 机器学习你会遇到的“坑”

    本文转载自公众号"读芯术"(ID:AI_Discovery) 本文3153字,建议阅读8分钟. 本文讲解了在学习基于贝叶斯推断的分类模型中,我们需要的准备和方法. 数学准备 概率: ...

  6. ML之ME/LF:基于不同机器学习框架(sklearn/TF)下算法的模型评估指标(损失函数)代码实现及其函数(Scoring/metrics)代码实现(仅代码)

    ML之ME/LF:基于不同机器学习框架(sklearn/TF)下算法的模型评估指标(损失函数)代码实现及其函数(Scoring/metrics)代码实现(仅代码) 目录 单个评价指标各种框架下实现 1 ...

  7. 基于贝叶斯推断的分类模型 机器学习你会遇到的“坑”

    链接:贝叶斯推断分类 数学准备 概率:事件不确定性程度的量化,概率越大,表示事件发生的可能性越大. 条件概率:P(A|B),在条件B下,发生A的概率. 联合概率:P(A,B),A事件与B事件同时发生的 ...

  8. 基于机器学习技术的用户行为分析:当前模型和应用研究综述(A survey for user behavior analysis based on machine learning technique)

    A survey for user behavior analysis based on machine learning techniques: current models and applica ...

  9. 机器学习实验笔记-基于信用卡数据建立行为评分模型的机器学习方法

    基于信用卡数据建立行为评分模型的机器学习方法 很久之前的一个答疑, 应该不会再影响评分了, 记录以供复习. 数据集与代码放在CSDN下载区域, 也可以留言索要. https://download.cs ...

最新文章

  1. oracle执行减法,oracle时间的加法和减法
  2. 小白的算法初识课堂(part4)--快速排序
  3. UDP通讯接收案例(组播方式)
  4. loj 6083.「美团 CodeM 资格赛」数码
  5. tab控件的使用心得
  6. 【毕业设计】JSP数据库连接池的研究与实现(源代码+论文)
  7. iOS开发日记9-终端命令
  8. 易语言c语言哪个做游戏脚本,游戏简易脚本制作教程
  9. SaaS微信小程序电商系统,一键生成小程序【源码分享】
  10. 被称为“Google 最大黑科技”,开发谷歌大脑,这位 AI 掌门人到底有多牛?
  11. matlab中 概率密度估计ksdensity,k-s检验kstest和kstest2(单/双样本检验数据是否符合某种分布)
  12. 方舟同步服务器信息,方舟服务器备份和数据库备份
  13. photoshop-photoshop记录
  14. 《App研发录》读书笔记
  15. 计算机科学丛书收藏,计算机科学丛书:机器学习
  16. python实现自动批量下载邮箱附件--GUI
  17. 静态时序分析—串扰延迟分析(Crosstalk Delay Analysis)
  18. [转贴]金庸的九家著名公司
  19. 国密sm2 js加密后台解密,sm3 js、后台加密,sm4 后台加密
  20. 从MySQL中导出表中数据_用命令从mysql中导出/导入表结构及数据

热门文章

  1. 深入浅出JVM-GC过程
  2. 微信小程序 - 用户进入客服会话会在右下角显示可能要发送的小程序提示
  3. Unable to delete directory: D:\Downloads\githubdownfive\tianxmyapp\library\
  4. Jmeter日志输出和日志级别设置
  5. superset 图标调整
  6. Module build failed (from ./node_modules/sass-loader/lib/loader.js):
  7. Eclipse里git提交冲突rejected – non-fast-forward
  8. 0网卡开启_中标麒麟Linux v7系统下设置双网卡bond或team绑定详细过程
  9. cdn加载vue很慢_Vue.js 项目打包优化实践
  10. Linux监控工具介绍系列——smem