由于实验室事情缘故,需要将Python写的神经网络转成Java版本的,但是python中的numpy等啥包也不知道在Java里面对应的是什么工具,所以索性直接寻找一个现成可用的Java神经网络框架,于是就找到了JOONE,JOONE是一个神经网络的开源框架,使用的是BP算法进行迭代计算参数,使用起来比较方便也比较实用,下面介绍一下JOONE的一些使用方法。

JOONE需要使用一些外部的依赖包,这在官方网站上有,也可以在这里下载。将所需的包引入工程之后,就可以进行编码实现了。

首先看下完整的程序,这个是上面那个超链接给出的程序,应该是官方给出的一个示例吧,因为好多文章都用这个,这其实是神经网络训练一个异或计算器:

import org.joone.engine.*;

import org.joone.engine.learning.*;

import org.joone.io.*;

import org.joone.net.*;

/*

*

* JOONE实现

*

* */

public class XOR_using_NeuralNet implements NeuralNetListener

{

private NeuralNet nnet = null;

private MemoryInputSynapse inputSynapse, desiredOutputSynapse;

LinearLayer input;

SigmoidLayer hidden, output;

boolean singleThreadMode = true;

// XOR input

private double[][] inputArray = new double[][]

{

{ 0.0, 0.0 },

{ 0.0, 1.0 },

{ 1.0, 0.0 },

{ 1.0, 1.0 } };

// XOR desired output

private double[][] desiredOutputArray = new double[][]

{

{ 0.0 },

{ 1.0 },

{ 1.0 },

{ 0.0 } };

/**

* @param args

*            the command line arguments

*/

public static void main(String args[])

{

XOR_using_NeuralNet xor = new XOR_using_NeuralNet();

xor.initNeuralNet();

xor.train();

xor.interrogate();

}

/**

* Method declaration

*/

public void train()

{

// set the inputs

inputSynapse.setInputArray(inputArray);

inputSynapse.setAdvancedColumnSelector(" 1,2 ");

// set the desired outputs

desiredOutputSynapse.setInputArray(desiredOutputArray);

desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");

// get the monitor object to train or feed forward

Monitor monitor = nnet.getMonitor();

// set the monitor parameters

monitor.setLearningRate(0.8);

monitor.setMomentum(0.3);

monitor.setTrainingPatterns(inputArray.length);

monitor.setTotCicles(5000);

monitor.setLearning(true);

long initms = System.currentTimeMillis();

// Run the network in single-thread, synchronized mode

nnet.getMonitor().setSingleThreadMode(singleThreadMode);

nnet.go(true);

System.out.println(" Total time=  "

+ (System.currentTimeMillis() - initms) + "  ms ");

}

private void interrogate()

{

double[][] inputArray = new double[][]

{

{ 1.0, 1.0 } };

// set the inputs

inputSynapse.setInputArray(inputArray);

inputSynapse.setAdvancedColumnSelector(" 1,2 ");

Monitor monitor = nnet.getMonitor();

monitor.setTrainingPatterns(4);

monitor.setTotCicles(1);

monitor.setLearning(false);

MemoryOutputSynapse memOut = new MemoryOutputSynapse();

// set the output synapse to write the output of the net

if (nnet != null)

{

nnet.addOutputSynapse(memOut);

System.out.println(nnet.check());

nnet.getMonitor().setSingleThreadMode(singleThreadMode);

nnet.go();

for (int i = 0; i

{

double[] pattern = memOut.getNextPattern();

System.out.println(" Output pattern # " + (i + 1) + " = "

+ pattern[0]);

}

System.out.println(" Interrogating Finished ");

}

}

/**

* Method declaration

*/

protected void initNeuralNet()

{

// First create the three layers

input = new LinearLayer();

hidden = new SigmoidLayer();

output = new SigmoidLayer();

// set the dimensions of the layers

input.setRows(2);

hidden.setRows(3);

output.setRows(1);

input.setLayerName(" L.input ");

hidden.setLayerName(" L.hidden ");

output.setLayerName(" L.output ");

// Now create the two Synapses

FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */

FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */

// Connect the input layer whit the hidden layer

input.addOutputSynapse(synapse_IH);

hidden.addInputSynapse(synapse_IH);

// Connect the hidden layer whit the output layer

hidden.addOutputSynapse(synapse_HO);

output.addInputSynapse(synapse_HO);

// the input to the neural net

inputSynapse = new MemoryInputSynapse();

input.addInputSynapse(inputSynapse);

// The Trainer and its desired output

desiredOutputSynapse = new MemoryInputSynapse();

TeachingSynapse trainer = new TeachingSynapse();

trainer.setDesired(desiredOutputSynapse);

// Now we add this structure to a NeuralNet object

nnet = new NeuralNet();

nnet.addLayer(input, NeuralNet.INPUT_LAYER);

nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);

nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);

nnet.setTeacher(trainer);

output.addOutputSynapse(trainer);

nnet.addNeuralNetListener(this);

}

public void cicleTerminated(NeuralNetEvent e)

{

}

public void errorChanged(NeuralNetEvent e)

{

Monitor mon = (Monitor) e.getSource();

if (mon.getCurrentCicle() % 100 == 0)

System.out.println(" Epoch:  "

+ (mon.getTotCicles() - mon.getCurrentCicle()) + "  RMSE: "

+ mon.getGlobalError());

}

public void netStarted(NeuralNetEvent e)

{

Monitor mon = (Monitor) e.getSource();

System.out.print(" Network started for  ");

if (mon.isLearning())

System.out.println(" training. ");

else

System.out.println(" interrogation. ");

}

public void netStopped(NeuralNetEvent e)

{

Monitor mon = (Monitor) e.getSource();

System.out.println(" Network stopped. Last RMSE= "

+ mon.getGlobalError());

}

public void netStoppedError(NeuralNetEvent e, String error)

{

System.out.println(" Network stopped due the following error:  "

+ error);

}

}

现在我会逐步解释上面的程序。

【1】 从main方法开始说起,首先第一步新建一个对象:

XOR_using_NeuralNet xor = new XOR_using_NeuralNet();

【2】然后初始化神经网络:

xor.initNeuralNet();

初始化神经网络的方法中:

// First create the three layers

input = new LinearLayer();

hidden = new SigmoidLayer();

output = new SigmoidLayer();

// set the dimensions of the layers

input.setRows(2);

hidden.setRows(3);

output.setRows(1);

input.setLayerName(" L.input ");

hidden.setLayerName(" L.hidden ");

output.setLayerName(" L.output ");

上面代码解释:

input=new LinearLayer()是新建一个输入层,因为神经网络的输入层并没有训练参数,所以使用的是线性层;

hidden = new SigmoidLayer();这里是新建一个隐含层,使用sigmoid函数作为激励函数,当然你也可以选择其他的激励函数,如softmax激励函数

output则是新建一个输出层

之后的三行代码是建立输入层、隐含层、输出层的神经元个数,这里表示输入层为2个神经元,隐含层是3个神经元,输出层是1个神经元

最后的三行代码是给每个输出层取一个名字。

// Now create the two Synapses

FullSynapse synapse_IH = new FullSynapse(); /* input -> hidden conn. */

FullSynapse synapse_HO = new FullSynapse(); /* hidden -> output conn. */

// Connect the input layer whit the hidden layer

input.addOutputSynapse(synapse_IH);

hidden.addInputSynapse(synapse_IH);

// Connect the hidden layer whit the output layer

hidden.addOutputSynapse(synapse_HO);

output.addInputSynapse(synapse_HO);

上面代码解释:

上面代码的主要作用是将三个层连接起来,synapse_IH用来连接输入层和隐含层,synapse_HO用来连接隐含层和输出层

// the input to the neural net

inputSynapse = new MemoryInputSynapse();

input.addInputSynapse(inputSynapse);

// The Trainer and its desired output

desiredOutputSynapse = new MemoryInputSynapse();

TeachingSynapse trainer = new TeachingSynapse();

trainer.setDesired(desiredOutputSynapse);

上面代码解释:

上面的代码是在训练的时候指定输入层的数据和目的输出的数据,

inputSynapse = new MemoryInputSynapse();这里指的是使用了从内存中输入数据的方法,指的是输入层输入数据,当然还有从文件输入的方法,这点在文章后面再谈。同理,desiredOutputSynapse = new MemoryInputSynapse();也是从内存中输入数据,指的是从输入层应该输出的数据

// Now we add this structure to a NeuralNet object

nnet = new NeuralNet();

nnet.addLayer(input, NeuralNet.INPUT_LAYER);

nnet.addLayer(hidden, NeuralNet.HIDDEN_LAYER);

nnet.addLayer(output, NeuralNet.OUTPUT_LAYER);

nnet.setTeacher(trainer);

output.addOutputSynapse(trainer);

nnet.addNeuralNetListener(this);

上面代码解释:

这段代码指的是将之前初始化的构件连接成一个神经网络,NeuralNet是JOONE提供的类,主要是连接各个神经层,最后一个nnet.addNeuralNetListener(this);这个作用是对神经网络的训练过程进行监听,因为这个类实现了NeuralNetListener这个接口,这个接口有一些方法,可以实现观察神经网络训练过程,有助于参数调整。

【3】然后我们来看一下train这个方法:

inputSynapse.setInputArray(inputArray);

inputSynapse.setAdvancedColumnSelector(" 1,2 ");

// set the desired outputs

desiredOutputSynapse.setInputArray(desiredOutputArray);

desiredOutputSynapse.setAdvancedColumnSelector(" 1 ");

上面代码解释:

inputSynapse.setInputArray(inputArray);这个方法是初始化输入层数据,也就是指定输入层数据的内容,inputArray是程序中给定的二维数组,这也就是为什么之前初始化神经网络的时候使用的是MemoryInputSynapse,表示从内存中读取数据

inputSynapse.setAdvancedColumnSelector(" 1,2 ");这个表示的是输入层数据使用的是inputArray的前两列数据。

desiredOutputSynapse这个也同理

Monitor monitor = nnet.getMonitor();

// set the monitor parameters

monitor.setLearningRate(0.8);

monitor.setMomentum(0.3);

monitor.setTrainingPatterns(inputArray.length);

monitor.setTotCicles(5000);

monitor.setLearning(true);

上面代码解释:

这个monitor类也是JOONE框架提供的,主要是用来调节神经网络的参数,monitor.setLearningRate(0.8);是用来设置神经网络训练的步长参数,步长越大,神经网络梯度下降的速度越快,monitor.setTrainingPatterns(inputArray.length);这个是设置神经网络的输入层的训练数据大小size,这里使用的是数组的长度;monitor.setTotCicles(5000);这个指的是设置迭代数目;monitor.setLearning(true);这个true表示是在训练过程。

nnet.getMonitor().setSingleThreadMode(singleThreadMode);

nnet.go(true);

上面代码解释:

nnet.getMonitor().setSingleThreadMode(singleThreadMode);这个指的是是不是使用多线程,但是我不太清楚这里的多线程指的是什么意思

nnet.go(true)表示的是开始训练。

【4】最后来看一下interrogate方法

double[][] inputArray = new double[][]

{

{ 1.0, 1.0 } };

// set the inputs

inputSynapse.setInputArray(inputArray);

inputSynapse.setAdvancedColumnSelector(" 1,2 ");

Monitor monitor = nnet.getMonitor();

monitor.setTrainingPatterns(4);

monitor.setTotCicles(1);

monitor.setLearning(false);

MemoryOutputSynapse memOut = new MemoryOutputSynapse();

// set the output synapse to write the output of the net

if (nnet != null)

{

nnet.addOutputSynapse(memOut);

System.out.println(nnet.check());

nnet.getMonitor().setSingleThreadMode(singleThreadMode);

nnet.go();

for (int i = 0; i

{

double[] pattern = memOut.getNextPattern();

System.out.println(" Output pattern # " + (i + 1) + " = "

+ pattern[0]);

}

System.out.println(" Interrogating Finished ");

}

这个方法相当于测试方法,这里的inputArray是测试数据, 注意这里需要设置monitor.setLearning(false);,因为这不是训练过程,并不需要学习,monitor.setTrainingPatterns(4);这个是指测试的数量,4表示有4个测试数据(虽然这里只有一个)。这里还给nnet添加了一个输出层数据对象,这个对象mmOut是初始测试结果,注意到之前我们初始化神经网络的时候并没有给输出层指定数据对象,因为那个时候我们在训练,而且指定了trainer作为目的输出。

接下来就是输出结果数据了,pattern的个数和输出层的神经元个数一样大,这里输出层神经元的个数是1,所以pattern大小为1.

【5】我们看一下测试结果:

Output pattern # 1 = 0.018303527517809233

表示输出结果为0.01,根据sigmoid函数特性,我们得到的输出是0,和预期结果一致。如果输出层神经元个数大于1,那么输出值将会有多个,因为输出层结果是0|1离散值,所以我们取输出最大的那个神经元的输出值取为1,其他为0

【6】最后我们来看一下神经网络训练过程中的一些监听函数:

cicleTerminated:每个循环结束后输出的信息

errorChanged:神经网络错误率变化时候输出的信息

netStarted:神经网络开始运行的时候输出的信息

netStopped:神经网络停止的时候输出的信息

【7】好了,JOONE基本上内容就是这些。还有一些额外东西需要说明:

1,从文件中读取数据构建神经网络

2.如何保存训练好的神经网络到文件夹中,只要测试的时候直接load到内存中就行,而不用每次都需要训练。

【8】先看第一个问题:

从文件中读取数据:

文件的格式:

0;0;0

1;0;1

1;1;0

0;1;1

中间使用分号隔开,使用方法如下,也就是把上文的MemoryInputSynapse换成FileInputSynapse即可。

fileInputSynapse = new FileInputSynapse();

input.addInputSynapse(fileInputSynapse);

fileDisireOutputSynapse = new FileInputSynapse();

TeachingSynapse trainer = new TeachingSynapse();

trainer.setDesired(fileDisireOutputSynapse);

我们看下文件是如何输出数据的:

private File inputFile = new File(Constants.TRAIN_WORD_VEC_PATH);

fileInputSynapse.setInputFile(inputFile);

fileInputSynapse.setFirstCol(2);//使用文件的第2列到第3列作为输出层输入

fileInputSynapse.setLastCol(3);

fileDisireOutputSynapse.setInputFile(inputFile);

fileDisireOutputSynapse.setFirstCol(1);//使用文件的第1列作为输出数据

fileDisireOutputSynapse.setLastCol(1);

其余的代码和上文的是一样的。

【9】然后看第二个问题:

如何保存神经网络

其实很简单,直接序列化nnet对象就行了,然后读取该对象就是java的反序列化,这个就不多做介绍了,比较简单。但是需要说明的是,保存神经网络的时机一定是在神经网络训练完毕后,可以使用下面代码:

public void netStopped(NeuralNetEvent e) {

Monitor mon = (Monitor) e.getSource();

try {

if (mon.isLearning()) {

saveModel(nnet); //序列化对象

}

} catch (IOException ee) {

// TODO Auto-generated catch block

ee.printStackTrace();

}

java版本lstm_LSTM java 实现相关推荐

  1. MAC自带的Java版本以及Java安装目录查看

    MAC自带的Java版本以及Java安装目录查看 电脑版本 macOS Mojave version 10.14 操作 1.Java版本 在Terminal下输入 java -version 如图 2 ...

  2. linux 修改java版本_Linux 有问必答:如何在 Linux 中改变默认的 Java 版本

    提问:当我尝试在Linux中运行一个Java程序时,我遇到了一个错误.看上去像程序编译所使用的Java版本与我本地的不同.我该如何在Linux上切换默认的Java版本? 当Java程序编译时,编译环境 ...

  3. (004) java后台开发之Eclipse(Neon) 版本安装Java EE插件

    Eclipse Neon Java版本安装Java EE插件 Help→Install New Software 地址:Neon - http://download.eclipse.org/relea ...

  4. java安装版本哪种好_我怎么知道我安装了哪个版本的Java?

    问题描述 我想开始玩java(最终到了可以为android或web编写基本小程序的地步),但是我已经在我的计算机上(从过去的实验中)弄糟了java. 我不确定我拥有哪个版本的Java,并且想知道是否有 ...

  5. 安装了多个java 如何切换java版本

    安装了多个java 如何切换java版本 问题描述 平常用的是java8,最近在学习java的新特性.这就需要从java8往更高的java版本切换.由于还在使用java8,测试完新特性后我需要再切换回 ...

  6. Java 版本变更列表 (Java SE 8 ~ Java SE 18 )

    java 版本变更列表 Java SE 8 Java SE 9 Java SE 10 Java SE 11 Java SE 12 Java SE 13 Java SE 14 Java SE 15 Ja ...

  7. java检查版本_如何检查Java版本

    java检查版本 Sometimes we need to check the Java version while executing a java program. We could be dev ...

  8. 3.1_2 JavaSE入门 P1 【Java基础】Java语言概述、JDK编译

    相关链接 Excel目录 目录 Part1 Java语言概述 1 Java语言概述 1.1 Java发展史 1.2 java应用平台 1.3 跨平台原理 1.4 JVM JRE JDK 1.5 Ora ...

  9. Unable to make protected final java.lang.Class java.lang.ClassLoader.defineClass

    从 Github 中下载 Demo 时遇到问题 java.lang.IllegalStateException: Cannot load configuration class: com.cxytia ...

最新文章

  1. Android Framework学习总结
  2. 《30天自制操作系统》笔记(01)——hello bitzhuwei’s OS!
  3. aes算法的地位_aes算法最后一轮为什么没有列混淆?
  4. 解决eclipse中Mybatis框架下sql语句执行后控制台不显示日志问题
  5. python (八)迭代器、生成器、列表推导式
  6. sap.ui.layout.form.SimpleForm.prototype
  7. 初级第一旬05— 蓝字观试题
  8. shell脚本--批量测试主机连通性ping IP
  9. 筛选DataTable数据的方法
  10. 产品体验报告-美团APP
  11. CmemDC类 的使用方法
  12. 斐讯K2P路由器设置AP模式(大部分路由器通用
  13. 3小时做完3天工作,她是用了什么办法做到的?
  14. 定积分积分换元之区间再现(a+b-x)+一元微积分
  15. 华为MIB关键字IOD查询地址及方法
  16. 使用python绘制奥运五环
  17. gensim 主题模型 seed
  18. 古月居《ROS入门21讲》零基础学习笔记
  19. 悟空CRM(基于jfinal+vue+ElementUI的前后端分离的开源CRM系统)
  20. 我是CSDN福利狮,今天来给大家发福利!

热门文章

  1. linux 挂载exfat u盘 yum,centos挂载exfat u盘
  2. [蓝桥杯][算法训练VIP]麦森数(Java大数+快速幂)
  3. Hakase and Nano(博弈)
  4. 为什么存png还有白色底_PNG的算法原理
  5. Mac osx系统中virtual box 中的Ubuntu系统的全屏显示问题解决
  6. 【论文阅读】Learning Traffic as Images: A Deep Convolutional ... [将交通作为图像学习: 用于大规模交通网络速度预测的深度卷积神经网络](1)
  7. dataframe 删除首尾空格_你敲空格的速度很快,但女人的手不是用来敲空格的!...
  8. python scatter 简书_【挖掘模型】:Python-DBSCAN算法 - 简书
  9. redis可以存多少条数据_在银行存50万元,一年能有多少利息?不工作可以吗?...
  10. P1518 两只塔姆沃斯牛 The Tamworth Two(简单的搜索题)