java调用tensorflow训练好的模型
1. python的处理
整个模型的源码在此:https://github.com/shelleyHLX/tensorflow_java
多谢star
首先训练一个模型,代码如下
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.framework import graph_util## -1到1之间随机数 100个
train_X = np.linspace(-1, 1, 100)
train_Y = 2*train_X + np.random.randn(*train_X.shape)*0.1# 显示模拟数据点plt.plot(train_X, train_Y, 'ro', label='test')
plt.legend()
plt.show()# 创建模型
# 占位符
X = tf.placeholder("float",name='X')
Y = tf.placeholder("float",name='Y')# 模型参数
# W初始化为-1到1之间的一个数字
W = tf.Variable(tf.random_normal([1]), name="weight")
# b初始化为0 也是一维 定义变量
b = tf.Variable(tf.zeros([1]), name="bias")# 前向结构 mulpiply两个数 相乘
z = tf.multiply(X, W) + b
op = tf.add(tf.multiply(X, W),b,name='results')
# 反向优化
cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)# 初始化所有变量
init = tf.global_variables_initializer()# 定义参数
training_epochs = 20
display_step = 2def moving_avage(a, w=10):if len(a) < w:return a[:]return [val if idx<w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]saver = tf.train.Saver()# 启动session
with tf.Session() as sess:sess.run(init)# 存放批次值和损失值plotdata = {"batchsize": [], "loss": []}# 向量模型输入数据for epoch in range(training_epochs):for(x, y) in zip(train_X, train_Y):sess.run(optimizer, {X:x, Y:y})# 显示训练中的详细信息if epoch % display_step == 0:loss = sess.run(cost, {X:train_X, Y:train_Y})print("Epoch:", epoch+1, "cost=", loss, "W=", sess.run(W), "b=",sess.run(b))if not (loss == "NA"):plotdata["batchsize"].append(epoch)plotdata["loss"].append(loss)print("Finished!")#保存模型saver.save(sess, "model/first")print("cost =", sess.run(cost, feed_dict={X:train_X, Y:train_Y}), "W=", sess.run(W), "b=", sess.run(b))const_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,["results"])with tf.gfile.FastGFile("model/first.pb",mode='wb') as f:f.write(const_graph.SerializeToString())# 图形显示plt.plot(train_X, train_Y, 'ro', label='Original data')plt.plot(train_X, sess.run(W)*train_X+sess.run(b),label='Filttedline')plt.legend()plt.show()plotdata["avgloss"] = moving_avage(plotdata["loss"])# plt.figure(1)plt.subplot(211)plt.plot(plotdata["batchsize"],plotdata["avgloss"], 'b--')plt.xlabel('Minibatch number')plt.ylabel('Loss')plt.title('Minibatch run vs, Trainging loss')plt.show()print("x=0.2, z=", sess.run(z, {X:0.2}))
测试模型:
from tensorflow.python.platform import gfile
import tensorflow as tfsess = tf.Session()with gfile.FastGFile('model/first.pb','rb') as f:graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def,name='')sess.run(tf.global_variables_initializer())print(sess.run('weight:0'))
print(sess.run('bias:0'))input_x = sess.graph.get_tensor_by_name('X:0')op = sess.graph.get_tensor_by_name('results:0')ret = sess.run(op, feed_dict={input_x: 2})print(ret)
2 java的处理
新建一个maven项目
把模型加入项目中.
在pom.xml设置tensorflow,第一次使用会下载.
在xin/src/test/java/com.xin.tf_java.xin新建一个java类:abcd.java
内容如下:
package com.xin.tf_java.xin;import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.apache.commons.io.IOUtils;import javax.imageio.ImageIO;import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.apache.commons.io.IOUtils; public class abcd {public static void main(String[] args) throws FileNotFoundException, IOException {// TODO Auto-generated method stubtry (Graph graph = new Graph()) {byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("model/first.pb"));graph.importGraphDef(graphBytes);try (Session session = new Session(graph)) {Tensor<?> out = session.runner().feed("X", Tensor.create(2.0f)).fetch("results").run().get(0);float[] r = new float[1];out.copyTo(r);System.out.println(r[0]);}}}}
要把commons-io-2.6.jar加入;下载位置:http://commons.apache.org/proper/commons-io/download_io.cgi
change project compliance and jre to 1.7照做就可以
右键运行
reference:
https://my.oschina.net/yjwxh/blog/2874957
java调用tensorflow训练好的模型相关推荐
- 如何用java语言调用tensorflow训练好的模型
1.TensorFlow的训练模型在Android和Java的应用及调用 2.tensorflow的python离线训练java在线预测方案 3.tensorflow训练的模型在java中的使用 4. ...
- java加载tensorflow训练的PB模型记录
java加载tensorflow训练的PB模型记录 python训练 1. 模型的输入输出定义 2. 训练时保存模型的方法 java加载模型 1.maven依赖 2. Java代码实例 tensor注 ...
- 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)
将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...
- 使用PaddleFluid和TensorFlow训练序列标注模型
专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...
- 将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)
将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite) 1. 写在前面 最近在做一个数字手势识别的APP(关于这个项目,我会再写一篇博客仔细介绍,博客地址 ...
- 基于TensorFlow训练花朵识别模型的源码和Demo
基于TensorFlow训练花朵识别模型的源码和Demo 转发来源: https://blog.csdn.net/Anymake_ren/article/details/80550684 下面就通过对 ...
- 如何调用 caffe 训练好的模型对输入图片进行测试
如何调用 caffe 训练好的模型对输入图片进行测试 该部分包括两篇文章 win10 下 caffe 的第一个测试程序(附带详细讲解) 主要讲解如何利用 caffe 来训练模型. 如何调用 caffe ...
- TensorFlow 调用预训练好的模型—— Python 实现
1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如 ...
- OpenCV的dnn模块调用TesorFlow训练的MoblieNet模型
七月 上海| 高性能计算之GPU CUDA培训 7月27-29日三天密集式学习 快速带你入门阅读全文> 正文共2073个字,2张图,预计阅读时间10分钟. 一.初得模型 那是一个月之前的事情了 ...
最新文章
- 这个赛道能超车几次?
- 企业开发与社交开发相辅相成
- Java设计模式 创建模式-单态模式(Singleton)
- QT的QAxFactory类的使用
- BZOJXXXX: [IOI2000]邮局——四边形不等式优化初探
- 面试官:要不讲讲 Cookie、Session、Token、JWT之间的区别?
- geek_How-To Geek正在寻找安全作家
- 商业银行如何进行分布式数据库选型思考
- 大动作!华为海思注册资本从6亿增加到20亿
- 获取表数据_大数据抽取解决方案——kettle分页循环
- 基于Java保险员工管理系统的设计与实现
- EAS 后台事务配置
- crontab 每周五_关于linux:如何在星期天每周运行crontab作业
- Spring动态代理的两种区别
- NetSuite导出CSV文件用Excel打开是乱码
- GitHub建立个人网站(一)
- 计算计算机系统包括哪些内容,什么是MIPS计算机系统的运算器
- vue krpano 视角监听
- 报错:The path is not a valid path to the xxx kernel headers.
- 【基础知识】~ FIFO
热门文章
- 43.【Java 实现验证码获取 C++实现密码加密和删除和QQ登入系统】
- 嵌入式以及嵌入式行业的基本信息
- 设计一个对银行账户余额操作的简单程序(Java)
- Python之跳出语句(break,continue)
- Hexo+GitHub Pages搭建个人博客( 0 基础、小白值得一看--实力软文!)- 初行ᵀᵀᴴ
- 移动游戏开发商50强(世界)
- 一个基于.Net Core 开源的物联网基础平台
- 辞职时被老板叫去谈话挽留,怎样避免被套路
- QQ空间打不开,IE里无法运行脚本的解决方案 转自:spookfox.cublog.cn
- IoT黑板报0113:你天天在扫的二维码其实是日本人发明的