本文实例讲述了Python实现的径向基(RBF)神经网络。分享给大家供大家参考,具体如下:

from numpy import array, append, vstack, transpose, reshape, \

dot, true_divide, mean, exp, sqrt, log, \

loadtxt, savetxt, zeros, frombuffer

from numpy.linalg import norm, lstsq

from multiprocessing import Process, Array

from random import sample

from time import time

from sys import stdout

from ctypes import c_double

from h5py import File

def metrics(a, b):

return norm(a - b)

def gaussian (x, mu, sigma):

return exp(- metrics(mu, x)**2 / (2 * sigma**2))

def multiQuadric (x, mu, sigma):

return pow(metrics(mu,x)**2 + sigma**2, 0.5)

def invMultiQuadric (x, mu, sigma):

return pow(metrics(mu,x)**2 + sigma**2, -0.5)

def plateSpine (x,mu):

r = metrics(mu,x)

return (r**2) * log(r)

class Rbf:

def __init__(self, prefix = 'rbf', workers = 4, extra_neurons = 0, from_files = None):

self.prefix = prefix

self.workers = workers

self.extra_neurons = extra_neurons

# Import partial model

if from_files is not None:

w_handle = self.w_handle = File(from_files['w'], 'r')

mu_handle = self.mu_handle = File(from_files['mu'], 'r')

sigma_handle = self.sigma_handle = File(from_files['sigma'], 'r')

self.w = w_handle['w']

self.mu = mu_handle['mu']

self.sigmas = sigma_handle['sigmas']

self.neurons = self.sigmas.shape[0]

def _calculate_error(self, y):

self.error = mean(abs(self.os - y))

self.relative_error = true_divide(self.error, mean(y))

def _generate_mu(self, x):

n = self.n

extra_neurons = self.extra_neurons

# TODO: Make reusable

mu_clusters = loadtxt('clusters100.txt', delimiter='\t')

mu_indices = sample(range(n), extra_neurons)

mu_new = x[mu_indices, :]

mu = vstack((mu_clusters, mu_new))

return mu

def _calculate_sigmas(self):

neurons = self.neurons

mu = self.mu

sigmas = zeros((neurons, ))

for i in xrange(neurons):

dists = [0 for _ in xrange(neurons)]

for j in xrange(neurons):

if i != j:

dists[j] = metrics(mu[i], mu[j])

sigmas[i] = mean(dists)* 2

# max(dists) / sqrt(neurons * 2))

return sigmas

def _calculate_phi(self, x):

C = self.workers

neurons = self.neurons

mu = self.mu

sigmas = self.sigmas

phi = self.phi = None

n = self.n

def heavy_lifting(c, phi):

s = jobs[c][1] - jobs[c][0]

for k, i in enumerate(xrange(jobs[c][0], jobs[c][1])):

for j in xrange(neurons):

# phi[i, j] = metrics(x[i,:], mu[j])**3)

# phi[i, j] = plateSpine(x[i,:], mu[j]))

# phi[i, j] = invMultiQuadric(x[i,:], mu[j], sigmas[j]))

phi[i, j] = multiQuadric(x[i,:], mu[j], sigmas[j])

# phi[i, j] = gaussian(x[i,:], mu[j], sigmas[j]))

if k % 1000 == 0:

percent = true_divide(k, s)*100

print(c, ': {:2.2f}%'.format(percent))

print(c, ': Done')

# distributing the work between 4 workers

shared_array = Array(c_double, n * neurons)

phi = frombuffer(shared_array.get_obj())

phi = phi.reshape((n, neurons))

jobs = []

workers = []

p = n / C

m = n % C

for c in range(C):

jobs.append((c*p, (c+1)*p + (m if c == C-1 else 0)))

worker = Process(target = heavy_lifting, args = (c, phi))

workers.append(worker)

worker.start()

for worker in workers:

worker.join()

return phi

def _do_algebra(self, y):

phi = self.phi

w = lstsq(phi, y)[0]

os = dot(w, transpose(phi))

return w, os

# Saving to HDF5

os_h5 = os_handle.create_dataset('os', data = os)

def train(self, x, y):

self.n = x.shape[0]

## Initialize HDF5 caches

prefix = self.prefix

postfix = str(self.n) + '-' + str(self.extra_neurons) + '.hdf5'

name_template = prefix + '-{}-' + postfix

phi_handle = self.phi_handle = File(name_template.format('phi'), 'w')

os_handle = self.w_handle = File(name_template.format('os'), 'w')

w_handle = self.w_handle = File(name_template.format('w'), 'w')

mu_handle = self.mu_handle = File(name_template.format('mu'), 'w')

sigma_handle = self.sigma_handle = File(name_template.format('sigma'), 'w')

## Mu generation

mu = self.mu = self._generate_mu(x)

self.neurons = mu.shape[0]

print('({} neurons)'.format(self.neurons))

# Save to HDF5

mu_h5 = mu_handle.create_dataset('mu', data = mu)

## Sigma calculation

print('Calculating Sigma...')

sigmas = self.sigmas = self._calculate_sigmas()

# Save to HDF5

sigmas_h5 = sigma_handle.create_dataset('sigmas', data = sigmas)

print('Done')

## Phi calculation

print('Calculating Phi...')

phi = self.phi = self._calculate_phi(x)

print('Done')

# Saving to HDF5

print('Serializing...')

phi_h5 = phi_handle.create_dataset('phi', data = phi)

del phi

self.phi = phi_h5

print('Done')

## Algebra

print('Doing final algebra...')

w, os = self.w, _ = self._do_algebra(y)

# Saving to HDF5

w_h5 = w_handle.create_dataset('w', data = w)

os_h5 = os_handle.create_dataset('os', data = os)

## Calculate error

self._calculate_error(y)

print('Done')

def predict(self, test_data):

mu = self.mu = self.mu.value

sigmas = self.sigmas = self.sigmas.value

w = self.w = self.w.value

print('Calculating phi for test data...')

phi = self._calculate_phi(test_data)

os = dot(w, transpose(phi))

savetxt('iok3834.txt', os, delimiter='\n')

return os

@property

def summary(self):

return '\n'.join( \

['-----------------',

'Training set size: {}'.format(self.n),

'Hidden layer size: {}'.format(self.neurons),

'-----------------',

'Absolute error : {:02.2f}'.format(self.error),

'Relative error : {:02.2f}%'.format(self.relative_error * 100)])

def predict(test_data):

mu = File('rbf-mu-212243-2400.hdf5', 'r')['mu'].value

sigmas = File('rbf-sigma-212243-2400.hdf5', 'r')['sigmas'].value

w = File('rbf-w-212243-2400.hdf5', 'r')['w'].value

n = test_data.shape[0]

neur = mu.shape[0]

mu = transpose(mu)

mu.reshape((n, neur))

phi = zeros((n, neur))

for i in range(n):

for j in range(neur):

phi[i, j] = multiQuadric(test_data[i,:], mu[j], sigmas[j])

os = dot(w, transpose(phi))

savetxt('iok3834.txt', os, delimiter='\n')

return os

希望本文所述对大家Python程序设计有所帮助。

python中rbf神经网络包_Python实现的径向基(RBF)神经网络示例相关推荐

  1. python中get函数作用_python get函数有什么作用?示例解析

    这篇文章之中我们来了解一下关于python字典之中的pythonget函数的相关知识,get函数是什么意思,他有什么作用都将会在接下来的文章之中得到解答. 描述 Python 字典(Dictionar ...

  2. python中change的用法_Python Pandas dataframe.pct_change()用法及代码示例

    Python是进行数据分析的一种出色语言,主要是因为以数据为中心的python软件包具有奇妙的生态系统. Pandas是其中的一种,使导入和分析数据更加容易. Pandas dataframe.pct ...

  3. python中mean的用法_Python Pandas dataframe.mean()用法及代码示例

    Python是进行数据分析的一种出色语言,主要是因为以数据为中心的python软件包具有奇妙的生态系统. Pandas是其中的一种,使导入和分析数据更加容易. Pandas dataframe.mea ...

  4. python中convert函数用法_Python Pandas DataFrame.tz_convert用法及代码示例

    Pandas DataFrame是带有标签轴(行和列)的二维大小可变的,可能是异构的表格数据结构.算术运算在行和列标签上对齐.可以将其视为Series对象的dict-like容器.这是 Pandas ...

  5. python中sinh是什么_Python numpy.sinh()用法及代碼示例

    numpy.sinh(x [,out])= ufunc'sin'):此數學函數可幫助用戶計算所有x(作為數組元素)的雙曲正弦值. 等效於1/2 *(np.exp(x)-np.exp(-x))或-1j ...

  6. python中change的用法_python pandas Series.pct_change用法及代码示例

    当前元素与先前元素之间的百分比变化. 默认情况下,计算与前一行的百分比变化.这在比较元素时间序列中的变化百分比时很有用. 参数: periods:int, 默认为 1形成百分比变化所需的时间. fil ...

  7. python中mean的用法_Python Pandas Series.mean()用法及代码示例

    Pandas 系列是带有轴标签的一维ndarray.标签不必是唯一的,但必须是可哈希的类型.该对象同时支持基于整数和基于标签的索引,并提供了许多方法来执行涉及索引的操作. Pandas Series. ...

  8. python中loc的用法_python pandas Series.loc用法及代码示例

    通过标签或布尔数组访问一组行和列. .loc[]主要基于标签,但也可以与布尔数组一起使用. 允许的输入为: 单个标签,例如5或者'a', (注意5被解释为索引的标签,而不是索引的整数位置). 标签的列 ...

  9. python中使用squarify包可视化treemap图:treemap将分层数据显示为一组嵌套矩形,每一组都用一个矩形表示,该矩形的面积与其值成正比

    python中使用squarify包可视化treemap图:treemap将分层数据显示为一组嵌套矩形,每一组都用一个矩形表示,该矩形的面积与其值成正比 目录

最新文章

  1. double和float计算精度不准的问题
  2. c#下各种数据库操作的封装!(支持ACCESS,SQLSERVER,DB2,ORACLE,MYSQL)(四)
  3. HTML怎么实现字体加粗
  4. 360能删除mysql吗_如何彻底删除MYSQL
  5. 一次编译libmono.so的记录
  6. Android持久化存储(4)greenDAO的使用
  7. rmmod: can't change directory to '3.4.39': No such file or directory 解决方法
  8. java 如何忽略异常_java中如何解决异常
  9. JBPM工作流(七)——详解流程图
  10. 追求卓越追求完美规范学习_追求新的黄金比例
  11. case when 多条件_3年前的设计如今被iPhone强推 PITAKA磁吸生态设计的前瞻性到底有多可怕?...
  12. 亚信安全发布《2022年网络安全发展趋势及十大威胁预测》
  13. ASP.NET MVC 重点教程一周年版 第八回 Helper之演化
  14. App专项测试之弱网测试
  15. 人性化的Ruby计数取值
  16. wpf 如何实现窗口浮动_如何实现工作表数据与UserForm窗口的交互,显示第一条记录...
  17. ios8 gps定位不好用
  18. maven安装配置:报错NB: JAVA_HOME should point to a JDK not a JRE
  19. 《生产实习》实习日志——JAVA大数据工程师
  20. MFC控件 --- 旋转控件

热门文章

  1. 什么是APS生产排程系统?
  2. 腾讯视频全网清晰度提升攻坚战
  3. XSS绕过安全狗WAF
  4. Android多媒体功能开发(11)——使用AudioRecord类录制音频
  5. creo 计算机配置,config配置文件乱码,creo配置文件config
  6. 拆机指点杆小红点的线序及PTPM754DR引脚定义
  7. lua_pcall详解
  8. 以牌照搜题为例,简单分析文字切割与识别部分
  9. 服务器驱动用什么工具_驱动、改向滚筒用什么胶板进行包胶?
  10. c语言连接多个字符串,c语言连接多个字符串(strcat函数实现)