1. python的处理

整个模型的源码在此:https://github.com/shelleyHLX/tensorflow_java

多谢star

首先训练一个模型,代码如下

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.python.framework import graph_util## -1到1之间随机数 100个
train_X = np.linspace(-1, 1, 100)
train_Y = 2*train_X + np.random.randn(*train_X.shape)*0.1# 显示模拟数据点plt.plot(train_X, train_Y, 'ro', label='test')
plt.legend()
plt.show()# 创建模型
# 占位符
X = tf.placeholder("float",name='X')
Y = tf.placeholder("float",name='Y')# 模型参数
# W初始化为-1到1之间的一个数字
W = tf.Variable(tf.random_normal([1]), name="weight")
# b初始化为0 也是一维  定义变量
b = tf.Variable(tf.zeros([1]), name="bias")# 前向结构   mulpiply两个数 相乘
z = tf.multiply(X, W) + b
op = tf.add(tf.multiply(X, W),b,name='results')
# 反向优化
cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)# 初始化所有变量
init = tf.global_variables_initializer()# 定义参数
training_epochs = 20
display_step = 2def moving_avage(a, w=10):if len(a) < w:return a[:]return [val if idx<w else sum(a[(idx-w):idx])/w for idx, val in enumerate(a)]saver = tf.train.Saver()# 启动session
with tf.Session() as sess:sess.run(init)# 存放批次值和损失值plotdata = {"batchsize": [], "loss": []}# 向量模型输入数据for epoch in range(training_epochs):for(x, y) in zip(train_X, train_Y):sess.run(optimizer, {X:x, Y:y})# 显示训练中的详细信息if epoch % display_step == 0:loss = sess.run(cost, {X:train_X, Y:train_Y})print("Epoch:", epoch+1, "cost=", loss, "W=", sess.run(W), "b=",sess.run(b))if not (loss == "NA"):plotdata["batchsize"].append(epoch)plotdata["loss"].append(loss)print("Finished!")#保存模型saver.save(sess, "model/first")print("cost =", sess.run(cost,  feed_dict={X:train_X, Y:train_Y}), "W=", sess.run(W), "b=", sess.run(b))const_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,["results"])with tf.gfile.FastGFile("model/first.pb",mode='wb') as f:f.write(const_graph.SerializeToString())# 图形显示plt.plot(train_X, train_Y, 'ro', label='Original data')plt.plot(train_X, sess.run(W)*train_X+sess.run(b),label='Filttedline')plt.legend()plt.show()plotdata["avgloss"] = moving_avage(plotdata["loss"])# plt.figure(1)plt.subplot(211)plt.plot(plotdata["batchsize"],plotdata["avgloss"], 'b--')plt.xlabel('Minibatch number')plt.ylabel('Loss')plt.title('Minibatch run vs, Trainging loss')plt.show()print("x=0.2, z=", sess.run(z, {X:0.2}))

测试模型:

from tensorflow.python.platform import gfile
import tensorflow as tfsess = tf.Session()with gfile.FastGFile('model/first.pb','rb') as f:graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def,name='')sess.run(tf.global_variables_initializer())print(sess.run('weight:0'))
print(sess.run('bias:0'))input_x = sess.graph.get_tensor_by_name('X:0')op = sess.graph.get_tensor_by_name('results:0')ret = sess.run(op, feed_dict={input_x: 2})print(ret)

2 java的处理

新建一个maven项目

把模型加入项目中.

在pom.xml设置tensorflow,第一次使用会下载.

在xin/src/test/java/com.xin.tf_java.xin新建一个java类:abcd.java

内容如下:

package com.xin.tf_java.xin;import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import org.apache.commons.io.IOUtils;import javax.imageio.ImageIO;import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.Output;
import org.tensorflow.Session;
import org.tensorflow.Shape;
import org.tensorflow.Tensor;
import org.apache.commons.io.IOUtils; public class abcd {public static void main(String[] args) throws FileNotFoundException, IOException {// TODO Auto-generated method stubtry (Graph graph = new Graph()) {byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("model/first.pb"));graph.importGraphDef(graphBytes);try (Session session = new Session(graph)) {Tensor<?> out = session.runner().feed("X", Tensor.create(2.0f)).fetch("results").run().get(0);float[] r = new float[1];out.copyTo(r);System.out.println(r[0]);}}}}

要把commons-io-2.6.jar加入;下载位置:http://commons.apache.org/proper/commons-io/download_io.cgi

change project compliance and jre to 1.7照做就可以

右键运行

reference:

https://my.oschina.net/yjwxh/blog/2874957

java调用tensorflow训练好的模型相关推荐

  1. 如何用java语言调用tensorflow训练好的模型

    1.TensorFlow的训练模型在Android和Java的应用及调用 2.tensorflow的python离线训练java在线预测方案 3.tensorflow训练的模型在java中的使用 4. ...

  2. java加载tensorflow训练的PB模型记录

    java加载tensorflow训练的PB模型记录 python训练 1. 模型的输入输出定义 2. 训练时保存模型的方法 java加载模型 1.maven依赖 2. Java代码实例 tensor注 ...

  3. 将tensorflow训练好的模型移植到Android (MNIST手写数字识别)

    将tensorflow训练好的模型移植到Android (MNIST手写数字识别) [尊重原创,转载请注明出处]https://blog.csdn.net/guyuealian/article/det ...

  4. 使用PaddleFluid和TensorFlow训练序列标注模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

  5. 将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite)

    将TensorFlow训练好的模型迁移到Android APP上(TensorFlowLite) 1. 写在前面   最近在做一个数字手势识别的APP(关于这个项目,我会再写一篇博客仔细介绍,博客地址 ...

  6. 基于TensorFlow训练花朵识别模型的源码和Demo

    基于TensorFlow训练花朵识别模型的源码和Demo 转发来源: https://blog.csdn.net/Anymake_ren/article/details/80550684 下面就通过对 ...

  7. 如何调用 caffe 训练好的模型对输入图片进行测试

    如何调用 caffe 训练好的模型对输入图片进行测试 该部分包括两篇文章 win10 下 caffe 的第一个测试程序(附带详细讲解) 主要讲解如何利用 caffe 来训练模型. 如何调用 caffe ...

  8. TensorFlow 调用预训练好的模型—— Python 实现

    1. 准备预训练好的模型 TensorFlow 预训练好的模型被保存为以下四个文件 data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如 ...

  9. OpenCV的dnn模块调用TesorFlow训练的MoblieNet模型

    七月 上海| 高性能计算之GPU CUDA培训 7月27-29日三天密集式学习  快速带你入门阅读全文> 正文共2073个字,2张图,预计阅读时间10分钟. 一.初得模型 那是一个月之前的事情了 ...

最新文章

  1. 这个赛道能超车几次?
  2. 企业开发与社交开发相辅相成
  3. Java设计模式 创建模式-单态模式(Singleton)
  4. QT的QAxFactory类的使用
  5. BZOJXXXX: [IOI2000]邮局——四边形不等式优化初探
  6. 面试官:要不讲讲 Cookie、Session、Token、JWT之间的区别?
  7. geek_How-To Geek正在寻找安全作家
  8. 商业银行如何进行分布式数据库选型思考
  9. 大动作!华为海思注册资本从6亿增加到20亿
  10. 获取表数据_大数据抽取解决方案——kettle分页循环
  11. 基于Java保险员工管理系统的设计与实现
  12. EAS 后台事务配置
  13. crontab 每周五_关于linux:如何在星期天每周运行crontab作业
  14. Spring动态代理的两种区别
  15. NetSuite导出CSV文件用Excel打开是乱码
  16. GitHub建立个人网站(一)
  17. 计算计算机系统包括哪些内容,什么是MIPS计算机系统的运算器
  18. vue krpano 视角监听
  19. 报错:The path is not a valid path to the xxx kernel headers.
  20. 【基础知识】~ FIFO

热门文章

  1. 43.【Java 实现验证码获取 C++实现密码加密和删除和QQ登入系统】
  2. 嵌入式以及嵌入式行业的基本信息
  3. 设计一个对银行账户余额操作的简单程序(Java)
  4. Python之跳出语句(break,continue)
  5. Hexo+GitHub Pages搭建个人博客( 0 基础、小白值得一看--实力软文!)- 初行ᵀᵀᴴ
  6. 移动游戏开发商50强(世界)
  7. 一个基于.Net Core 开源的物联网基础平台
  8. 辞职时被老板叫去谈话挽留,怎样避免被套路
  9. QQ空间打不开,IE里无法运行脚本的解决方案 转自:spookfox.cublog.cn
  10. IoT黑板报0113:你天天在扫的二维码其实是日本人发明的