这是个我想干很久的事情了。之前研究tensorflow on spark, DL4j 都没有成功。所以这里首先讲一下我做这件事情的流程。模型的部署,首先你得有一个模型。这里假设你有了一个keras模型,假设你保存了一个keras 的.h5模型

python 准备阶段

你需要通过以下代码将keras h5的模型转化为pb文件

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True):
""":param session: 需要转换的tensorflow的session:param keep_var_names:需要保留的variable,默认全部转换constant:param output_names:output的名字:param clear_devices:是否移除设备指令以获得更好的可移植性:return:"""from tensorflow.python.framework.graph_util import convert_variables_to_constantsgraph = session.graphwith graph.as_default():freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or []))output_names = output_names or []# 如果指定了output名字,则复制一个新的Tensor,并且以指定的名字命名if len(output_names) > 0:for i in range(len(output_names)):# 当前graph中复制一个新的Tensor,指定名字tf.identity(model.outputs[i], name=output_names[i])output_names += [v.op.name for v in tf.global_variables()]input_graph_def = graph.as_graph_def()if clear_devices:for node in input_graph_def.node:node.device = ""frozen_graph = convert_variables_to_constants(session, input_graph_def,output_names, freeze_var_names)return frozen_graph

from keras import backend as K
import tensorflow as tf
from keras.models import load_model
model = load_model("models/model.h5")
print(model.input.op.name)
print(model.output.op.name)
print(model)
# 自定义output_names
frozen_graph = freeze_session(K.get_session(), output_names=["output"])
tf.train.write_graph(frozen_graph, "./", "model.pb", as_text=False)

如果你用的是tensorflow模型可以从这里开始看

这两行代码打印的是tensorflow模型的输入和输出,这个输入和输出的名字在后面java读入模型的时候有用到

print(model.input.op.name)
print(model.output.op.name)

java 部分才是细节多的地方

java部分,用java的好处在于可以把程序和资源打成jar包。不依赖于集群的资源。

类必须实现Serializable虚接口,否则不能将其写成UDF函数。

其实这一部分也可以用scala来写。这样类中的某些不能序列化的成员变量,比如Graph对象,可以通过 with Serializable来实现

public class Predictor implements Serializable 

java中资源文件的读取

资源在jar包当中时,需要在<build>标签中添加如下的内容

<resources>
<resource>
<directory>src/main/resources</directory>
<targetPath>resource</targetPath>
</resource>
</resources>

因为jar包是一个单独的文件,在打成jar包了以后,不能像在IDEA中那样运行,需要通过createTempFile和FileOutputStream的方式将内容读取出来

File labelEncoderFile = null;
String resource = "/resource/label_encoder.json";
URL res = getClass().getResource(resource);
if (res.getProtocol().equals("jar")) {try {InputStream input = getClass().getResourceAsStream(resource);labelEncoderFile = File.createTempFile("tempfile", ".tmp");OutputStream out = new FileOutputStream(labelEncoderFile);
int read;
byte[] bytes = new byte[1024];while ((read = input.read(bytes)) != -1) {out.write(bytes, 0, read);
}
out.close();
labelEncoderFile.deleteOnExit();
}catch (IOException ex) {ex.printStackTrace();
}
}else {//this will probably work in your IDE, but not from a JARlabelEncoderFile = new File(res.getFile());
}if (labelEncoderFile != null && !labelEncoderFile.exists()) {throw new RuntimeException("Error: File " + labelEncoderFile + " not found!");
}

python 预处理文件的保存和java对预处理文件的恢复

一个重要的问题是,在python中的模型预处理阶段也需要放到java当中,以实现端到端的

pipeline ml flow.对于文本处理来说,有两个,一个是tokenizer,一个是LabelEncoder

我个人比较建议使用json进行读写,因为dict转json比较容易,java也比较容易读取这个内容

with open('label_encoder.json', 'w') as f:json.dump(le_dict, f)with open('word_index.json', 'w') as f:json.dump(tokenizer.word_index, f)

下边是java从json的自己构造的labelEncoder的代码。LabelEncoder 和tokenizer 实际上就是一个java 的Map<String, String>。所以转换成了map也就实现了我自己的LabelEncoder

String labelEncoderContent = null;
try {labelEncoderContent = FileUtils.readFileToString(labelEncoderFile, "UTF-8");
}catch (IOException e) {e.printStackTrace();
}Map<String, String> outputToLabel = new HashMap<String, String>();
ObjectMapper labelEncoderMapper = new ObjectMapper();
try {outputToLabel = labelEncoderMapper.readValue(labelEncoderContent, new TypeReference<HashMap<String, String>>() {});
}catch (Exception e) {e.printStackTrace();
}

Predictor预测类中的两个关键成员变量

Predictor是我的预测java类,该类有两个很关键的成员变量,一个是this.graph,另外一个是this.sess 这两个类是用来读取pb文件,并预测的主要工具

而graph没有实现Serializable接口,所以不能在构造函数当中对其初始化。否则就不能广播,也就没有办法写成UDF。所以我在预测阶段判断了成员变量是否为空。

//this.graph是否为空,从而导入图和session
if (this.graph == null) {this.graph = new Graph();this.graph.importGraphDef(this.pbBytes);
}
if (this.sess == null) {this.sess = new Session(this.graph);
}

当我们的sess就位了以后就是使用tensorflow 的java接口来进行预测了。注意要在pom当中添加tensorflow的依赖

终于到预测部分了

float[][] index_seqs = new float[1][MAX_LEN];
try (Tensor x = Tensor.create(index_seqs);
// input是输入的name,output是输出的nameTensor y = sess.runner()
.feed("input_1_3", x)
.feed("dropout_1/keras_learning_phase", Tensor.create(false))
.fetch("dense_1_3/Softmax").run().get(0)) {float[][] result = new float[1][2033];
y.copyTo(result);

result 就是我们想要的模型输出向量

scala part

而在scala部分 最主要的是org.apache.spark.sql.functions.udf的使用

import org.apache.spark.sql.functions.udf
val predictor:Predictor = Predictor.getInstance()
def predictWithProbability= { goods_name:String => predictor.predictWithProbability(name)}
val predictWithProbabilityUDF = udf(predictWithProbability)
val predictedDataSet:DataFrame = predictDataSet.withColumn("result", predictWithProbabilityUDF(predictDataSet.col("name")))

这样就能够得到结果了

打成jar包_keras, tensorflow模型部署通过jar包部署到spark环境攻略相关推荐

  1. DL框架之Tensorflow:深度学习框架Tensorflow的简介、安装、使用方法之详细攻略

    DL框架之Tensorflow:深度学习框架Tensorflow的简介.安装.使用方法之详细攻略 目录 Tensorflow的简介 1.描述 2.TensorFlow的六大特征 3.了解Tensorf ...

  2. DL:神经网络算法简介之耗算力的简介、原因、经典模型耗算力计算、GPU使用之详细攻略

    DL:神经网络算法简介之耗算力的简介.原因.经典模型耗算力计算.GPU使用之详细攻略 目录 神经网络算法耗算力的简介 神经网络算法耗算力的原因 神经网络算法耗算力的经典模型耗算力计算 1.AlexNe ...

  3. Ubuntu:Ubuntu下安装Anaconda和Tensorflow的简介、入门、安装流程之详细攻略

    Ubuntu:Ubuntu下安装Anaconda和Tensorflow的简介.入门.安装流程之详细攻略 目录 安装流程 1.安装nvidia显卡驱动 2.安装cuda8 3.安装Cudnn 4.Ana ...

  4. Python编程语言学习:包导入和模块搜索路径简介、使用方法之详细攻略

    Python编程语言学习:包导入和模块搜索路径简介.使用方法之详细攻略 目录 包导入和模块搜索路径简介 1.Pyhon搜索模块路径的机制 2.自定义配置搜索路径

  5. NPM:nodejs官方包管理工具的简介、安装、使用方法之详细攻略

    NPM:nodejs官方包管理工具的简介.安装.使用方法之详细攻略 目录 NPM之nodejs官方包管理工具的简介 NPM之nodejs官方包管理工具的安装 NPM之nodejs官方包管理工具的使用方 ...

  6. Py之pipenv:Python包的管理利器pipenv简介、安装、使用方法详细攻略

    Py之pipenv:Python包的管理利器pipenv简介.安装.使用方法详细攻略 目录 pipenv简介 pipenv安装 pipenv使用方法 pipenv简介 Python开发者应该听过pip ...

  7. Py之qrcode:Python包之qrcode的简介、安装、使用方法之详细攻略

    Py之qrcode:Python包之qrcode的简介.安装.使用方法之详细攻略 目录 qrcode简介 qrcode的安装 qrcode的使用方法 qrcode简介 二维码简称 QR Code(Qu ...

  8. Py之matplotlib:python包之matplotlib库图表绘制包的简介、安装、使用方法(matplotlib颜色大全)详细攻略

    Py之matplotlib:python包之matplotlib库图表绘制包的简介.安装.使用方法(matplotlib颜色大全)详细攻略 目录 matplotlib简介 matplotlib安装 m ...

  9. CV:Win10下深度学习框架安装之Tensorflow/tensorflow_gpu+Cuda+Cudnn(最清楚/最快捷)之详细攻略(图文教程)

    CV:Win10下深度学习框架安装之Tensorflow/tensorflow_gpu+Cuda+Cudnn(最清楚/最快捷)之详细攻略(图文教程) 导读 本人在Win10下安装深度学习框架Tenso ...

最新文章

  1. 三种基本排序的实现及其效率对比:冒泡排序、选择排序和插入排序
  2. 面试官: 谈谈什么是守护线程以及作用 ?
  3. FreeRTOS临界区应用与总结
  4. https跨域到http问题解决
  5. python爬取分页数据
  6. python中的 descriptor
  7. nlp5-n-gram/语言模型(数据平滑方法
  8. PyTorch 1.7 发布:支持 CUDA 11、FFT 新 API、及 Windows 分布式训练
  9. LiveNVR高性能稳定RTSP、Onvif探测流媒体服务配置通道接入海康、大华等摄像机进行全终端无插件直播...
  10. 简单记录一下做的项目过程中踩过的坑
  11. 照片宽高比怎么设置_2019年中级会计报名照片上传完整攻略
  12. CSRF--跨站请求伪造
  13. c35是什么意思_混凝土C35P6是什么意思
  14. css3多列布局(columnz),多列布局相关属性
  15. 看这里→大数据工程技术人员系列课程—《大数据工程技术人员-大数据基础技术》正式上线!...
  16. linux 怎么卸载glib,glib的安装
  17. 《支付宝对接之-当面付》
  18. 地平线:面向规模化量产的智能驾驶系统和软件开发
  19. 「普通人VS程序员」电脑还可以这样关机,神操作 建议阅读
  20. g团最多的服务器,艾泽拉斯服务器 5人BWL G团赏析(二)

热门文章

  1. java B2B2C Springcloud电子商城系统-Ribbon设计原理
  2. live555源码分析----RSTPServer创建过程分析
  3. 《高性能JavaScript》(读书笔记)
  4. 如何进行有效的需求调研
  5. 怎样使破解网页的禁止复制黏贴
  6. Delphi-网络编程-UDP聊天程序(转)
  7. wordpress 后台,登录,注册开启https的重写规则
  8. 笔记本打字不知道按了什么键,打字老出现数字?
  9. 云计算推进企业管理深化,私有云将会深入企业
  10. hive -f 传递参数