# -*- coding: utf-8 -*-
'''
Created on 2016年4月1日@author: LIU
'''
import sys
import numpy
import matplotlib.pylab as plt
import numpy as np
import random
from scipy.linalg import norm
import PIL.Image
from utils import *class RBM(object):def __init__(self, input=None, n_visible=2, n_hidden=3, \W=None, hbias=None, vbias=None, rng=None):self.n_visible = n_visible  # num of units in visible (input) layerself.n_hidden = n_hidden    # num of units in hidden layerif rng is None:rng = numpy.random.RandomState(1234)if W is None:a = 1. / n_visibleinitial_W = numpy.array(rng.uniform(  # initialize W uniformly(随机生成实数在-a-a之间)low=-a,high=a,size=(n_visible, n_hidden)))W = initial_Wif hbias is None:hbias = numpy.zeros(n_hidden)  # initialize h bias 0if vbias is None:vbias = numpy.zeros(n_visible)  # initialize v bias 0self.rng = rngself.input = inputself.W = Wself.hbias = hbiasself.vbias = vbiasdef contrastive_divergence(self, lr=0.1, k=1, input=None):if input is not None:self.input = input''' CD-ks算法 '''ph_mean, ph_sample = self.sample_h_given_v(self.input)chain_start = ph_sample#实现一步吉布斯采样通过给隐层采样for step in xrange(k):if step == 0:nv_means, nv_samples,\nh_means, nh_samples = self.gibbs_hvh(chain_start)else:nv_means, nv_samples,\nh_means, nh_samples = self.gibbs_hvh(nh_samples)# chain_end = nv_samplesself.W += lr * (numpy.dot(self.input.T, ph_mean)- numpy.dot(nv_samples.T, nh_means))self.vbias += lr * numpy.mean(self.input - nv_samples, axis=0)self.hbias += lr * numpy.mean(ph_mean - nh_means, axis=0)# cost = self.get_reconstruction_cross_entropy()# return cost# 通过给出显层单元推断出隐层单元的    #计算隐层单元的激活率通过给出显层,得到一个采样通过给他们的def sample_h_given_v(self, v0_sample):h1_mean = self.propup(v0_sample)h1_sample = self.rng.binomial(size=h1_mean.shape,   # discrete: binomialn=1,p=h1_mean)return [h1_mean, h1_sample]#一一步吉布斯采样通过从隐层率开始def sample_v_given_h(self, h0_sample):v1_mean = self.propdown(h0_sample)v1_sample = self.rng.binomial(size=v1_mean.shape,   # discrete: binomialn=1,p=v1_mean)return [v1_mean, v1_sample]def propup(self, v):pre_sigmoid_activation = numpy.dot(v, self.W) + self.hbiasreturn sigmoid(pre_sigmoid_activation)def propdown(self, h):pre_sigmoid_activation = numpy.dot(h, self.W.T) + self.vbiasreturn sigmoid(pre_sigmoid_activation)#转换函数主要功能是通过给定的隐层采样来执行cd更新def gibbs_hvh(self, h0_sample):v1_mean, v1_sample = self.sample_v_given_h(h0_sample)h1_mean, h1_sample = self.sample_h_given_v(v1_sample)return [v1_mean, v1_sample,h1_mean, h1_sample]#计算重构误差     def get_reconstruction_cross_entropy(self):pre_sigmoid_activation_h = numpy.dot(self.input, self.W) + self.hbiassigmoid_activation_h = sigmoid(pre_sigmoid_activation_h)pre_sigmoid_activation_v = numpy.dot(sigmoid_activation_h, self.W.T) + self.vbiassigmoid_activation_v = sigmoid(pre_sigmoid_activation_v)cross_entropy =  - numpy.mean(numpy.sum(self.input * numpy.log(sigmoid_activation_v) +(1 - self.input) * numpy.log(1 - sigmoid_activation_v),axis=1))return cross_entropydef reconstruct(self, v):h = sigmoid(numpy.dot(v, self.W) + self.hbias)reconstructed_v = sigmoid(numpy.dot(h, self.W.T) + self.vbias)return reconstructed_vdef readData(path):data = []for line in open(path, 'r'):ele = line.split(' ')tmp = []for e in ele:if e != '':tmp.append(float(e.strip(' ')))data.append(tmp)return data
def test_rbm(learning_rate=0.1, k=1, training_epochs=50):
#     data = numpy.array([[1,1,1,0,0,0],
#                         [1,0,1,0,0,0],
#                         [1,1,1,0,0,0],
#                         [0,0,1,1,1,0],
#                         [0,0,1,1,0,0],
#                         [0,0,1,1,1,0]])data = readData('data.txt')data = np.array(data)data = data.transpose()rng = numpy.random.RandomState(123)# construct RBM
#     rbm = RBM(input=data, n_visible=6, n_hidden=2, rng=rng)rbm = RBM(input=data, n_visible=784, n_hidden=2, rng=rng)# trainfor epoch in xrange(training_epochs):rbm.contrastive_divergence(lr=learning_rate, k=k)cost = rbm.get_reconstruction_cross_entropy()print >> sys.stderr, 'Training epoch %d, cost is ' % epoch, cost# test
#     v = numpy.array([[1, 1, 0, 0, 0, 0],
#                      [0, 0, 0, 1, 1, 0]])v=data[1,:]print rbm.reconstruct(v)if __name__ == "__main__":test_rbm()
# -*- coding: utf-8 -*-
'''
Created on 2016年4月1日@author: LIU
'''
import numpy
numpy.seterr(all='ignore')def sigmoid(x):return 1. / (1 + numpy.exp(-x))def dsigmoid(x):return x * (1. - x)# def tanh(x):
#     return numpy.tanh(x)
#
# def dtanh(x):
#     return 1. - x * x
#
# def softmax(x):
#     e = numpy.exp(x - numpy.max(x))  # prevent overflow
#     if e.ndim == 1:
#         return e / numpy.sum(e, axis=0)
#     else:
#         return e / numpy.array([numpy.sum(e, axis=1)]).T  # ndim = 2
#
#
# def ReLU(x):
#     return x * (x > 0)
#
# def dReLU(x):
#     return 1. * (x > 0)

RBM代码Python相关推荐

  1. python搞笑代码-python有趣代码

    广告关闭 腾讯云11.11云上盛惠 ,精选热门产品助力上云,云服务器首年88元起,买的越多返的越多,最高返5000元! 前言本月将更新八篇python有趣系列文章. 本系列通过多个有趣案例,讲解pyt ...

  2. python一千行入门代码-Python 有哪些一千行左右的经典练手项目?

    谢邀.据我了解,没有千行左右的「经典」练手项目.但是我可以推荐一些练手项目.这些项目来着 教你阅读Python开源项目代码 - Python之美 - 知乎专栏 : 和工作中看别人代码差不多,基本每个人 ...

  3. python基础代码事例-推公式到写代码-python基础

    推公式到写代码-python基础 希望你能像看小说看杂文一样的心情看完这一系列,因为学习不总是枯燥的,希望像聊天一样娓娓道来. 专辑系列的阅读对象是那些懂些高等数学和线性代数,但没有经过编码训练的人. ...

  4. 具体knn算法概念参考knn代码python实现

    具体knn算法概念参考knn代码python实现 上面是参考<机器学习实战>的代码,和knn的思想 # _*_ encoding=utf8 _*_ import numpy as np i ...

  5. 奇异值分解SVD数学原理及代码(Python)

    奇异值分解SVD数学原理及代码(Python) 首先简单介绍一下什么是正交矩阵(酉矩阵) 如果 或 其中,E为单位矩阵,或,则n阶实矩阵A称为正交矩阵.正交矩阵是实数特殊化的酉矩阵,因此总是属于正规矩 ...

  6. 拉格朗日插值代码python实现(不掉包)

    拉格朗日插值代码python实现(不掉包) 今天我们来讲一下,使用拉格朗日插值公式进行插值,通过python实现 那么拉格朗日插值公式是什么样的呢? 百度百科定义如下: 当然如果你没有看懂的话,可以再 ...

  7. OTB官方评估代码python版本--评估自己跟踪器,对比其他跟踪器

    OTB官方评估代码python版本--评估自己跟踪器,对比其他跟踪器 代码环境准备 环境安装 数据集准备 跑自己跟踪器 结果格式准备 生成json文件并画图 Bonus OTB数据集是目标跟踪领域里面 ...

  8. python最简单的爬虫代码,python小实例一简单爬虫

    python新手求助 关于爬虫的简单例子 #coding=utf-8from bs4 import BeautifulSoupwith open('', 'r') as file: fcontent ...

  9. OTB官方评估代码python版本

    可参考:OTB官方评估代码python版本--评估自己跟踪器,对比其他跟踪器 博主写的很好,按照步骤可以运行 以下有几点注意的地方 1.我是用ubuntu系统,创建虚拟环境安装的python=2.7. ...

最新文章

  1. string来存放二进制数据
  2. 5G NR 同步过程
  3. Codeforces Beta Round #11 B. Jumping Jack 思维
  4. Spark之 使用SparkSql操作mysql和DataFrame的Scala实现
  5. linux pe大小,lvm中的pe默认是4M 最大能支持多大 1T?2T
  6. html table列平均,html table 列求和
  7. 实验三 类和对象
  8. ruby nokogiri 数据抓取
  9. .net2.0中对config文件的操作方法总结
  10. VC连接SQL2005
  11. Atitit 高性能架构法艾提拉著作 目录 1. 前期可以立即使用的技术 2 2. 分离法 3 2.1. Web db分离 3 2.2. 读写分离 4 2.3. CDN加速技术 4 2.4. 动静分
  12. 2019华为机试题 消息扩散
  13. 弘辽科技:复购率太低怎么办呢?
  14. mysql中 怎么插入反斜杠_MySQL中如何插入反斜杠,反斜杠被吃掉,反斜杠转义(转)...
  15. python弹性碰撞次数圆周率_期末作业 - 作业部落 Cmd Markdown 编辑阅读器
  16. 数字人正走进现实!AI大脑+高颜值
  17. 硬盘安装Fedora 12
  18. R语言因子型数值转数值型
  19. Codeforces633C Spy Syndrome 2 (单词Trie)
  20. 盘点闪电网络将在2020年爆发的九大理由

热门文章

  1. 语文字典计算机基础术语,2017年北京师范大学汉语文化学院893专业综合三(古代汉语、计算机基础)考研导师圈点必考题汇编...
  2. 主菜单中显示未定义标识符_mfc里提示的未定义标识符
  3. 全解一款六面体结构化网格划分利器-NUMECA IGG
  4. Flink教程(31)- Flink网络流控及反压
  5. 2022年京东春晚摇一摇分15亿红包活动
  6. .net软件工程师面试
  7. 学习OpenCV3:判断两条直线平行,并计算平行距离
  8. MySQL 索引最左匹配原则
  9. 在Wireshark中过滤UDS和OBD诊断ISO13400(DoIP)数据
  10. 软件测试经典面试题总结文库,软件测试经典面试题总结