下载地址:DeepLearningToolBox
参考博客原文:https://blog.csdn.net/u010025211/article/details/50582693

1. DBN基础知识

DBN 是由多层 RBM 组成的一个神经网络,它既可以被看作一个生成模型,也可以当作判别模型,其训练过程是:使用非监督贪婪逐层方法去预训练获得权值。

训练过程:

  1. 首先充分训练第一个 RBM;
  2. 固定第一个 RBM 的权重和偏移量,然后使用其隐性神经元的状态,作为第二个 RBM 的输入向量;
  3. 充分训练第二个 RBM 后,将第二个 RBM 堆叠在第一个 RBM 的上方;
  4. 重复以上三个步骤任意多次;
  5. 如果训练集中的数据有标签,那么在顶层的 RBM 训练时,这个 RBM 的显层中除了显性神经元,还需要有代表分类标签的神经元,一起进行训练:
    a) 假设顶层 RBM 的显层有 500 个显性神经元,训练数据的分类一共分成了 10 类;
    b) 那么顶层 RBM 的显层有 510 个显性神经元,对每一训练训练数据,相应的标签神经元被打开设为 1,而其他的则被关闭设为 0。
  6. DBN 被训练好后如下图: (示意)

    图 1 训练好的深度信念网络。

图中的绿色部分就是在最顶层 RBM 中参与训练的标签。注意调优 (FINE-TUNING) 过程是一个判别模型

调优过程 (Fine-Tuning) :

生成模型使用 Contrastive Wake-Sleep 算法进行调优,其算法过程是:

  1. 除了顶层 RBM,其他层 RBM 的权重被分成向上的认知权重和向下的生成权重;
  2. Wake 阶段:认知过程,通过外界的特征和向上的权重 (认知权重) 产生每一层的抽象表示 (结点状态) ,并且使用梯度下降修改层间的下行权重 (生成权重) 。也就是“如果现实跟我想象的不一样,改变我的权重使得我想象的东西就是这样的”。
  3. Sleep 阶段:生成过程,通过顶层表示 (醒时学得的概念) 和向下权重,生成底层的状态,同时修改层间向上的权重。也就是“如果梦中的景象不是我脑中的相应概念,改变我的认知权重使得这种景象在我看来就是这个概念”。

使用过程 :

  1. 使用随机隐性神经元状态值,在顶层 RBM 中进行足够多次的吉布斯抽样;
  2. 向下传播,得到每层的状态。

二、代码部分

test_example_DBN

%%  ex1 train a 100 hidden unit RBM and visualize its weights
rand('state',0)
dbn.sizes = [100];
opts.numepochs =   1;
opts.batchsize = 100;
opts.momentum  =   0;
opts.alpha     =   1;
dbn = dbnsetup(dbn, train_x, opts);
dbn = dbntrain(dbn, train_x, opts);
figure; visualize(dbn.rbm{1}.W');   %  Visualize the RBM weights

第一个例子是训练含有100个隐层单元的RBM,然后可视化权重。方法和之前将的训练RBM来降维是类似的。
可视化权重结果:

 %%  ex2 train a 100-100 hidden unit DBN and use its weights to initialize a NNrand('state',0)%train dbndbn.sizes = [100 100];opts.numepochs =   1;opts.batchsize = 100;opts.momentum  =   0;opts.alpha     =   1;dbn = dbnsetup(dbn, train_x, opts);dbn = dbntrain(dbn, train_x, opts);%unfold dbn to nn
nn = dbnunfoldtonn(dbn, 10);
nn.activation_function = 'sigm';%train nn
opts.numepochs =  1;
opts.batchsize = 100;
nn = nntrain(nn, train_x, train_y, opts);
[er, bad] = nntest(nn, test_x, test_y);
assert(er < 0.10, 'Too big error');

dbnsetup

直接分层初始化每一层的rbm(受限波尔兹曼机(Restricted Boltzmann Machines, RBM)), 同样,W,b,c是参数,vW,vb,vc是更新时用到的与momentum的变量

 for u = 1 : numel(dbn.sizes) - 1dbn.rbm{u}.alpha    = opts.alpha;dbn.rbm{u}.momentum = opts.momentum;dbn.rbm{u}.W  = zeros(dbn.sizes(u + 1), dbn.sizes(u));dbn.rbm{u}.vW = zeros(dbn.sizes(u + 1), dbn.sizes(u));dbn.rbm{u}.b  = zeros(dbn.sizes(u), 1);dbn.rbm{u}.vb = zeros(dbn.sizes(u), 1);dbn.rbm{u}.c  = zeros(dbn.sizes(u + 1), 1);dbn.rbm{u}.vc = zeros(dbn.sizes(u + 1), 1);end

dbntrain

function dbn = dbntrain(dbn, x, opts)n = numel(dbn.rbm);//对每一层的rbm进行训练dbn.rbm{1} = rbmtrain(dbn.rbm{1}, x, opts);for i = 2 : nx = rbmup(dbn.rbm{i - 1}, x);dbn.rbm{i} = rbmtrain(dbn.rbm{i}, x, opts); end
end

首先映入眼帘的是对第一层进行rbmtrain(),后面每一层在train之前用了rbmup, rbmup其实就是简单的一句sigm(repmat(rbm.c’, size(x, 1), 1) + x * rbm.W’);也就是上面那张图从v到h计算一次,公式是Wx+c.

rbmtrain

   for i = 1 : opts.numepochs //迭代次数kk = randperm(m);err = 0;for l = 1 : numbatchesbatch = x(kk((l - 1) * opts.batchsize + 1 : l * opts.batchsize), :);v1 = batch;h1 = sigmrnd(repmat(rbm.c', opts.batchsize, 1) + v1 * rbm.W');            //gibbs sampling的过程v2 = sigmrnd(repmat(rbm.b', opts.batchsize, 1) + h1 * rbm.W);h2 = sigm(repmat(rbm.c', opts.batchsize, 1) + v2 * rbm.W');//Contrastive Divergence 的过程 //这和《Learning Deep Architectures for AI》里面写cd-1的那段pseudo code是一样的c1 = h1' * v1;c2 = h2' * v2;//关于momentum,请参看Hinton的《A Practical Guide to Training Restricted Boltzmann Machines》//它的作用是记录下以前的更新方向,并与现在的方向结合下,跟有可能加快学习的速度rbm.vW = rbm.momentum * rbm.vW + rbm.alpha * (c1 - c2)     / opts.batchsize;    rbm.vb = rbm.momentum * rbm.vb + rbm.alpha * sum(v1 - v2)' / opts.batchsize;rbm.vc = rbm.momentum * rbm.vc + rbm.alpha * sum(h1 - h2)' / opts.batchsize;//更新值rbm.W = rbm.W + rbm.vW;rbm.b = rbm.b + rbm.vb;rbm.c = rbm.c + rbm.vc;err = err + sum(sum((v1 - v2) .^ 2)) / opts.batchsize;end
end

dbnunfoldtonn

DBN的每一层训练完成后自然还要把参数传递给一个大的NN,这就是这个函数的作用.在这里DBN就相当于预训练网络,然后将训练好的参数赋给NN结构。

function nn = dbnunfoldtonn(dbn, outputsize)
%DBNUNFOLDTONN Unfolds a DBN to a NN
%   outputsize是你的目标输出label,比如在MINST就是10,DBN只负责学习feature
%   或者说初始化Weight,是一个unsupervised learning,最后的supervised还得靠NNif(exist('outputsize','var'))size = [dbn.sizes outputsize];elsesize = [dbn.sizes];endnn = nnsetup(size);%把每一层展开后的Weight拿去初始化NN的Weight%注意dbn.rbm{i}.c拿去初始化了bias项的值for i = 1 : numel(dbn.rbm)nn.W{i} = [dbn.rbm{i}.c dbn.rbm{i}.W];end
end

最后用NN来train(fine-tune)就可以了。只要理解了多层RBM,DBN就不是问题了。

深度置信网络基础知识及程序代码相关推荐

  1. 【总结】关于玻尔兹曼机(BM)、受限玻尔兹曼机(RBM)、深度玻尔兹曼机(DBM)、深度置信网络(DBN)理论总结和代码实践

    近期学习总结 前言 玻尔兹曼机(BM) 波尔兹曼分布推导过程 吉布斯采样 受限玻尔兹曼机(RBM) 能量函数 CD学习算法 代码实现受限玻尔兹曼机 深度玻尔兹曼机(DBM) 代码实现深度玻尔兹曼机 深 ...

  2. 【零散知识】受限波兹曼机(restricted Boltzmann machine,RBM)和深度置信网络(deep belief network,DBN)

    前言: { 最近一直在想要不要去线下的英语学习机构学英语 (本人的英语口语能力实在是低).如果我想完成今年的年度计划,那么今年就没时间学英语了. 这次的内容是之前落下的深度置信网络(deep beli ...

  3. 深度学习基础--不同网络种类--深度置信网络(DBN)

    深度置信网络(DBN)   RBM的作用就是用来生成似然分布的互补先验分布,使得其后验分布具有因子形式.   因此,DBN算法解决了Wake-Sleep算法表示分布难以匹配生成分布的难题,通过RBM使 ...

  4. 计算机基础知识对程序员来说有多重要?

    数据结构和算法,操作系统,编译原理,计算机组成原理这些课程对普通程序员来说是否需要去学习?会带来哪些帮助? 我们依次来了解这几门课程是在工作中有啥用,回答有点长,请保持耐心:) 一.数据结构与算法 正 ...

  5. 基于深度学习的安卓恶意应用检测----------android manfest.xml + run time opcode, use 深度置信网络(DBN)...

    基于深度学习的安卓恶意应用检测 from:http://www.xml-data.org/JSJYY/2017-6-1650.htm 苏志达, 祝跃飞, 刘龙     摘要: 针对传统安卓恶意程序检测 ...

  6. 收藏100个网络基础知识

    100 个网络基础知识普及,看完成半个网络高手! 1)什么是链接? 链接是指两个设备之间的连接.它包括用于一个设备能够与另一个设备通信的电缆类型和协议. 2)OSI 参考模型的层次是什么? 有 7 个 ...

  7. 第二十六期:100 个网络基础知识普及,看完成半个网络高手

    本篇文章是关于100个网络基础知识普及,看完成半个网络高手!下面,我们一起来看. 作者:佚名来源 本篇文章是关于100个网络基础知识普及,看完成半个网络高手!下面,我们一起来看. 1)什么是链接? 链 ...

  8. 计算机知识太多了,计算机基础知识对程序员来说有多重要?

    原标题:计算机基础知识对程序员来说有多重要? 科班和培训生同比于自学者的优势就在于这些计算机专业的核心课程(数据结构与算法这种不管科班培训都要学的不算):离散数学.编译原理.计算机组成原理.操作与系统 ...

  9. Python 3深度置信网络(DBN)在Tensorflow中的实现MNIST手写数字识别

    任何程序错误,以及技术疑问或需要解答的,请扫码添加作者VX:1755337994 使用DBN识别手写体 传统的多层感知机或者神经网络的一个问题: 反向传播可能总是导致局部最小值. 当误差表面(erro ...

  10. 深度学习 --- 基于RBM的深度置信网络DBN-DNN详解

    上一节我们详细的讲解了受限玻尔兹曼机RBM的原理,详细阐述了该网络的快速学习原理以及算法过程,不懂受限玻尔兹曼机的建议先好好理解上一节的内容,本节主要讲解的是使用RBM组成深层神经网络的深度置信网络D ...

最新文章

  1. blktrace 工具集使用 及其实现原理
  2. [自带避雷针]DropShadowEffect导致内存暴涨
  3. scala成长之路(2)对象和类
  4. python使用界面-如何使用Python建立有窗口、按钮之类的图形界面
  5. myeclipse导入外部javaweb项目
  6. git bash学习3 -简单杂乱知识点记录
  7. 编译原理词/语法分析
  8. Ansible自动化运维基础-------ad-hoc
  9. 去java文件 注释_去除java文件中注释部分
  10. android 华为手机灭屏搜索不到蓝牙_华为Mate 30更新EMUI10.1.0.132版本,新增10项实用功能...
  11. 腾讯云mysql升级失败怎么办_本地连接腾讯云Mysql失败问题
  12. Java中的Flyweight设计模式
  13. 用ES6 Generator替代回调函数
  14. 企业文化海报设计模板,企业文化经典标语挂图素材
  15. 【C语言】C语言实现strcmp库函数
  16. Cell:清华程功组揭示皮肤菌群的一种气味挥发物促进黄病毒感染宿主吸引蚊虫...
  17. token什么意思中文在C语言中,token是什么意思_token中文意思_token英译汉_英汉词典...
  18. pid倒立摆matlab,基于MATLAB的直线一级倒立摆的PID控制研究
  19. HZNU2012图解
  20. HBuilderXHBuilder连接雷电模拟器

热门文章

  1. html5中左浮动怎么写代码,html浮动详解(代码实例)
  2. win10环境下matlab2017b编译运行c++文件步骤
  3. 联想 Thinkserver TS250服务器RAID1 重建测试
  4. IOS测试的一般流程和注意事项
  5. Redis 实战场景
  6. php安装libpng,求助:libpng编译问题
  7. java学习练习预埋件配筋计算
  8. Android 使用gson完成Json转map,json转单个对象,json转数组
  9. jspstudy oracle,SQL不走索引的几种常见情况
  10. STM32官方USB库下载指南