DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别

目录

输出结果

实现代码


输出结果

实现代码

from __future__ import print_function
print(__doc__)import numpy as np
import matplotlib.pyplot as plt  from scipy.ndimage import convolve
from sklearn import linear_model, datasets, metrics
from sklearn.cross_validation import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline            def nudge_dataset(X, Y):
direction_vectors = [[[0, 1, 0],[0, 0, 0],[0, 0, 0]],[[0, 0, 0],[1, 0, 0],[0, 0, 0]],[[0, 0, 0],[0, 0, 1],[0, 0, 0]],[[0, 0, 0],[0, 0, 0],[0, 1, 0]]]shift = lambda x, w: convolve(x.reshape((8, 8)), mode='constant',weights=w).ravel()
X = np.concatenate([X] +
[np.apply_along_axis(shift, 1, X, vector)
for vector in direction_vectors])
Y = np.concatenate([Y for _ in range(5)], axis=0)
return X, Ydigits = datasets.load_digits()
X = np.asarray(digits.data, 'float32')
X, Y = nudge_dataset(X, digits.target)
X = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001)X_train, X_test, Y_train, Y_test = train_test_split(X, Y,test_size=0.2,random_state=0) logistic = linear_model.LogisticRegression()
rbm = BernoulliRBM(random_state=0, verbose=True)classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)]) rbm.learning_rate = 0.06
rbm.n_iter = 20
# More components tend to give better prediction performance, but larger fitting time
rbm.n_components = 100
logistic.C = 6000.0classifier.fit(X_train, Y_train)  logistic_classifier = linear_model.LogisticRegression(C=100.0)
logistic_classifier.fit(X_train, Y_train)print()
print("Logistic regression using RBM features:\n%s\n" % (metrics.classification_report(Y_test,classifier.predict(X_test)  )))print("Logistic regression using raw pixel features:\n%s\n" % (
metrics.classification_report(
Y_test,
logistic_classifier.predict(X_test))))plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(rbm.components_):
plt.subplot(10, 10, i + 1)
plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
interpolation='nearest')
plt.xticks(())
plt.yticks(())
plt.suptitle('100 components extracted by RBM', fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)plt.show()

相关文章
DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别

DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别相关推荐

  1. DL之NN:基于(sklearn自带手写数字图片识别数据集)+自定义NN类(三层64→100→10)实现97.5%准确率

    DL之NN:基于(sklearn自带手写数字图片识别数据集)+自定义NN类(三层64→100→10)实现97.5%准确率 目录 输出结果 核心代码 输出结果 核心代码 #DL之NN:基于sklearn ...

  2. DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练、预测(95%)

    DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练.预测(95%) 目录 数据集展示 输出结果 设计代码 数据集展示 先查看sklearn自带di ...

  3. TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线

    TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线 目录 输出结果 设计代码 输出结果 设计代码 ...

  4. DL之RBM:基于RBM实现手写数字图片识别提高准确率

    DL之RBM:基于RBM实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 import numpy as np import matplotlib.pyplot as pl ...

  5. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...

  6. DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测

    DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...

  7. DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化

    DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...

  8. DL之NN:利用(本地数据集50000张数据集)调用自定义神经网络network.py实现手写数字图片识别94%准确率

    DL之NN:利用(本地数据集50000张数据集)调用自定义神经网络network.py实现手写数字图片识别94%准确率 目录 输出结果 代码设计 输出结果 更新-- 代码设计 import mnist ...

  9. DL之NN/CNN:NN算法进阶优化(本地数据集50000张训练集图片),六种不同优化算法实现手写数字图片识别逐步提高99.6%准确率

    DL之NN/CNN:NN算法进阶优化(本地数据集50000张训练集图片),六种不同优化算法实现手写数字图片识别逐步提高99.6%准确率 目录 设计思路 设计代码 设计思路 设计代码 import mn ...

最新文章

  1. 布线技术不断演进满足快速增长的网络需求
  2. Android性能优化——界面流畅度优化
  3. 业务逻辑全写在sql_12306的业务逻辑很复杂么?一条SQL语句搞不定?
  4. 期货市场技术分析06_长期图表和商品指数
  5. 当 AI 遇见经典,科大讯飞发布两款智能笔记本新品!
  6. OpenCV stereo matching 代码
  7. 安卓控件显示等宽字体的办法
  8. 海康/大华实现web直播和回放,也可以直接对接摄像头
  9. npm查找依赖包版本
  10. Linux下tomcat修改端口(80)
  11. 什么是模式识别,模式识别概念的基本介绍
  12. 淘宝数据分享平台战略
  13. 【pandas】--DataFrame数据筛选(二)
  14. 为什么阿里会选择 Flink 作为新一代流式计算引擎?
  15. 【Redis】概述以及启动Redis并进入Redis
  16. [2017纪中10-26]摘Galo 树型背包
  17. Python小爬虫:爬取开心网日记,乐趣无穷
  18. 重学JavaSE 第4章 : 顺序结构、分支语句、循环结构、break, continue, return区别
  19. 计算机可视化视景仿真,计算与仿真、三维设计、图像处理、视景仿真、4k视频剪辑工作站介绍2015版.pptx...
  20. “如影计划” 不带手机也能随身带支付宝

热门文章

  1. Python- 反射 及部份内置属性方法
  2. ssdb主从及双主模型配置和简单管理
  3. Cesium源码编译过程
  4. 说实话,Hibernate 和 MyBatis 哪个更好用?
  5. 又臭又长!流着泪我也要把它给改完!
  6. 面试官:聊聊微信和淘宝扫码登录背后的实现原理?
  7. 2019年,被高估的AI与数据科学该如何发展?
  8. 10个Eclipse珍藏插件推荐
  9. 刚换工作,记录下心得
  10. Java 洛谷 P1008 三连击