Keras 是一个用 Python 编写的高级神经网络 API,它能够以 TensorFlow, CNTK, 或者 Theano 作为后端运行。Keras 的开发重点是支持快速的实验。有时候我们在使用keras设计好模型后,需要在其他平台进行运行,这时候我们就需要将keras h5 model转换为TensorFlow pb model,因为keras只是一个Python的高级库,而TensorFlow能够支持多平台的运行。

环境

Python 3.6

Keras 2.2.2

Tensorflow-gpu 1.8.0

Keras to Tensorflow

测试数据:

from keras.datasets import imdb

def get_data():

max_features = 20000

# cut texts after this number of words

# (among top max_features most common words)

maxlen = 100

print('Loading data...')

(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)

print(x_train.shape, 'train sequences')

print(x_test.shape, 'test sequences')

print('Pad sequences (samples x time)')

x_train = sequence.pad_sequences(x_train, maxlen=maxlen)

x_test = sequence.pad_sequences(x_test, maxlen=maxlen)

print('x_train shape:', x_train.shape)

print('x_test shape:', x_test.shape)

y_train = np.array(y_train)

y_test = np.array(y_test)

return x_train, x_test, y_train, y_test

生成一个keras模型进行训练,获得模型和对应的权重文件:

from keras.layers import Conv1D, GlobalMaxPooling1D, Embedding, Dense, Dropout

from keras.datasets import imdb

from keras.preprocessing import sequence

from keras.models import Sequential

def gen_keras_model(x_train, x_test, y_train, y_test, train=False):

inp = Input(shape=(100,))

x = Embedding(20000, 50)(inp)

x = Dropout(0.2)(x)

x = Conv1D(250, 3, padding='valid', activation='relu', strides=1)(x)

x = GlobalMaxPooling1D()(x)

x = Dense(250, activation='relu')(x)

x = Dropout(0.2)(x)

x = Dense(1, activation='sigmoid')(x)

model = Model(inputs=inp, outputs=x)

if train:

model.compile(loss='binary_crossentropy',

optimizer='adam',

metrics=['accuracy'])

model.fit(x_train, y_train,

batch_size=32,

epochs=2,

validation_data=(x_test, y_test))

model.save_weights('model.h5')

return model

if __name__ == '__main__':

x_train, x_test, y_train, y_test = get_data()

model = gen_keras_model(x_train, x_test, y_train, y_test, True)

下面的函数将keras model转换为Tensorflow pb文件:

首先构建一个Session与空的计算图,将这个计算图设置为默认的计算图。

获取keras model的输出节点,将这个输出节点与节点名在这个计算图中进行绑定。

使用convert_variables_to_constants函数保存数输出节点,函数会自动推导计算图并将计算图中的变量取值以常量的形式保存。在保存模型文件的时候,我们只是导出了GraphDef部分,GraphDef保存了从输入层到输出层的计算过程。

最后向指定目录写入pb文件。

如果你的graph使用了Keras的learning phase(在训练和测试中行为不同),你首先要做的事就是在graph中硬编码你的工作模式(设为0,即测试模式),该工作通过:1)使用Keras的后端注册一个learning phase常量,2)重新构建模型,来完成。

import tensorflow as tf

from keras import backend as K

from tensorflow.python.framework import graph_util, graph_io

def export_graph(model, export_path):

input_names = model.input_names

if not tf.gfile.Exists(export_path):

tf.gfile.MakeDirs(export_path)

with K.get_session() as sess:

init_graph = sess.graph

with init_graph.as_default():

out_nodes = []

for i in range(len(model.outputs)):

out_nodes.append("output_" + str(i + 1))

tf.identity(model.output[i], "output_" + str(i + 1))

init_graph = sess.graph.as_graph_def()

main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)

graph_io.write_graph(main_graph, export_path, name='model.pb', as_text=False)

return input_names, out_nodes

if __name__ == '__main__':

x_train, x_test, y_train, y_test = get_data()

learning_phase = 0

K.set_learning_phase(learning_phase)

model = gen_keras_model(x_train, x_test, y_train, y_test, learning_phase)

model.load_weights('model.h5')

input_names, output_names = export_graph(model, 'model')

在Python Tensorflow环境下进行测试

首先在Session与Graph中读入pb文件,构建计算图。

然后根据输入张量与输出张量的张量名来获取到对应的张量,这里一定要加上:0。比如input_1:0是张量的名称而input_1表示的是节点的名称。

最后使用常规的Tensorflow操作来运行模型。

import numpy as np

import tensorflow as tf

from sklearn.metrics import accuracy_score

def run_graph(pb_file_path, input_name, output_name, x_test, y_test):

tf.reset_default_graph()

sess = tf.Session()

with tf.gfile.FastGFile(pb_file_path, 'rb') as f:

graph_def = tf.GraphDef()

graph_def.ParseFromString(f.read())

sess.graph.as_default()

tf.import_graph_def(graph_def, name='')

#输入

input_x = sess.graph.get_tensor_by_name('{}:0'.format(input_name))

#输出

op = sess.graph.get_tensor_by_name('{}:0'.format(output_name))

#预测结果

pred = []

for x in x_test:

res = sess.run(op, {input_x: x.reshape(1, -1)})

pred.append(res[0])

pred = np.array([1 if p > 0.5 else 0 for p in pred])

acc = accuracy_score(y_test, pred)

print('Accuracy:{}'.format(acc))

if __name__ == '__main__':

x_train, x_test, y_train, y_test = get_data()

learning_phase = 0

K.set_learning_phase(learning_phase)

model = gen_keras_model(x_train, x_test, y_train, y_test, learning_phase)

model.load_weights('model.h5')

input_names, output_names = export_graph(model, 'model')

pred = run_graph('model\model.pb', input_names[0], output_names[0], x_test, y_test)

输出如下:

Using TensorFlow backend.

Loading data...

(25000,) train sequences

(25000,) test sequences

Pad sequences (samples x time)

x_train shape: (25000, 100)

x_test shape: (25000, 100)

INFO:tensorflow:Froze 7 variables.

Converted 7 variables to const ops.

Accuracy:0.84388

在JavaTensorflow环境下进行测试

在 Windows 上安装按照以下步骤在 Windows 上安装适用于 Java 的 TensorFlow:

下载 libtensorflow.jar,这是 TensorFlow Java 归档 (JAR)。

解压缩该 .zip 文件。

配置到IDEA的External Libraries中。

eebda9dbecf8

setting

在Java中使用PB文件的代码如下,我们随机生成一个数组作为输入的张量进行测试。整个流程与Python下类似,需要注意的是生成输入张量时数组类型需要定义为float类型,不然会出现以下错误:

Exception in thread "main" java.lang.IllegalArgumentException: Expects arg[0] to be float but double is provided

Java下的测试代码:

import org.tensorflow.Graph;

import org.tensorflow.Session;

import org.tensorflow.Tensor;

import java.io.IOException;

import java.nio.file.Files;

import java.nio.file.Paths;

import java.util.Arrays;

public class TFTest {

public static void main(String[] args) throws IOException {

String path = "E:\\Documents\\Desktop\\code\\glu\\model\\model.pb";

float[][] input = new float[1][100];

for (int i=0; i < 100; i++){

input[0][i] = (float) (Math.random() * 100);

}

try (Graph graph = new Graph()){

graph.importGraphDef(Files.readAllBytes(Paths.get(path)));

try (Session sess = new Session(graph)){

try (Tensor x = Tensor.create(input);

Tensor y = sess.runner().feed("input_1", x).fetch("output_1").run().get(0)){

float[] res = (float[]) y.copyTo(new float[1]);

System.out.println(Arrays.toString(y.shape()));

System.out.println(Arrays.toString(res));

}

}

}

}

}

输出结果如下:

[1]

[0.088513985]

Process finished with exit code 0

JAVA调用 keras,在TensorFlow(Python, Java)环境下使用Keras模型相关推荐

  1. flask keras 多线程环境下加载模型

    keras 多线程环境下加载模型 Tensor Tensor is not an element of this graph. 问题场景 keras 使用flask 发布深度学习模型服务,模型有一个定 ...

  2. ATX+Python+uiautomator2环境下进行手机UI自动化测试

    ATX+Python+uiautomator2环境下进行手机UI自动化测试 环境搭建 手机环境初始化 在网页端的UI查看器中查看控件及属性 以下是一些自己测试的脚本 环境搭建 开始配置uiautoma ...

  3. 小勇rust_大规模分布式环境下动态信任模型研究

    李小勇等:大规模分布式环境下动态信任模型研究1519 通过反馈控制机制,动态调节计算节点的信任值的上述参数:(2)提出了用机器学习中强化学习的方法计算信任度,并用惩罚因子对学习因子进行了明确定义,所以 ...

  4. Java练习-----2.对Windows和Linux环境下输入的文件路径格式进行校验

    1.需求 Windows环境下路径格式只能为 D:\Desktop\source Linux环境下路径格式只能为 /data/source 2.结果展示 Linux环境下运行成功,懒得开虚拟机,就不展 ...

  5. java 调用postgresql 函数_从Java调用PostgreSQL中的存储过程

    我编写了一个我想用Java调用的存储过程.但我不认为它能够对我通过的查询做任何事情.以下是我的java代码: String QUERY_LOCATION = "select (license ...

  6. java 调用祖父方法_在Java中调用祖父母方法:您不能

    java 调用祖父方法 在文章保护的重点中,我详细介绍了"受保护"如何扩展"包私有"访问. 我在那儿写道: 你能做的是 覆盖子类中的方法或 使用关键字super ...

  7. java调用kettle例子_Kettle API - Java调用示例

    Kettle API - Java调用示例 对向前兼容性的推荐:如果想要动态地创造Transformation (例如:从元数据),使用XML文件方法(KTR)而不是使用API.XML文件兼容Kett ...

  8. android java调用_关于Android中Java调用外部命令的三种方式

    此所谓三种方式,只是个人认为.本人还是菜鸟初涉,所以有所错误,请指正. 个人认为,Java调用外部命令.无非三种情况: 一.是只执行命令,不考虑返回值. 二.是执行命令的同时,还需要得到返回值. 三. ...

  9. java用户的授权及验证_Java环境下shiro的测试-认证与授权

    Java环境下shiro的测试 1.导入依赖的核心jar包 org.apache.shiro shiro-core 1.3.2 2.认证程序 2.1 构建users配置文件 xxx.ini doGet ...

最新文章

  1. 土木工程真的这么可怕吗?
  2. curl下载失败返回0_curl返回常见错误码
  3. 机器学习的宝典-华校专老师的笔记
  4. 手游引擎Unity和Cocos各有什么优劣?
  5. linux sudo输入密码无法获得锁,Linux系统提示无法获得锁/var/lib/dpkg/lock怎么办?
  6. 记录下返回list给前端 遇到 $ref:$.data.*** 问题
  7. 11g Database Installation flow
  8. 淘宝 NPM 镜像解决软件下载速度慢的问题
  9. 应急响应.windows
  10. 电精2(电神魔傀2) android版本下载
  11. MySQL8中文手册【持续更新】
  12. 书_阿朱_好好看书[转]
  13. java报表技术总结_15个Java的报表工具总结
  14. 因果推断-Uplift Model:Meta Learning
  15. 背篼酥课堂第九课--前端知识、APP知识
  16. 什么是 yum?更改yum源 yum的相关命令
  17. 问题 A: Jugs BFS
  18. 安得指针千万间,大庇天下地址具欢颜(中)
  19. Java实现 LeetCode 152 乘积最大子序列
  20. 对java后端的一些学习建议

热门文章

  1. 【Java】Java对象转换成Map
  2. Queries with streaming sources must be executed with writeStream.start()
  3. Spring : 基于tx标签的声明式事物
  4. SpringMVC框架搭建的步骤
  5. 12c集群日志位置_大数据系列教程006-开启日志聚合功能
  6. Java多线程学习三十六:主内存和工作内存的关系
  7. Java 中的异常处理
  8. 【Hive的高级查询详】
  9. MongoDB的默认用户名和密码是什么?
  10. 多线程执行sql报错处理