DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别
DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别
目录
输出结果
实现代码
输出结果
实现代码
from __future__ import print_function
print(__doc__)import numpy as np
import matplotlib.pyplot as plt from scipy.ndimage import convolve
from sklearn import linear_model, datasets, metrics
from sklearn.cross_validation import train_test_split
from sklearn.neural_network import BernoulliRBM
from sklearn.pipeline import Pipeline def nudge_dataset(X, Y):
direction_vectors = [[[0, 1, 0],[0, 0, 0],[0, 0, 0]],[[0, 0, 0],[1, 0, 0],[0, 0, 0]],[[0, 0, 0],[0, 0, 1],[0, 0, 0]],[[0, 0, 0],[0, 0, 0],[0, 1, 0]]]shift = lambda x, w: convolve(x.reshape((8, 8)), mode='constant',weights=w).ravel()
X = np.concatenate([X] +
[np.apply_along_axis(shift, 1, X, vector)
for vector in direction_vectors])
Y = np.concatenate([Y for _ in range(5)], axis=0)
return X, Ydigits = datasets.load_digits()
X = np.asarray(digits.data, 'float32')
X, Y = nudge_dataset(X, digits.target)
X = (X - np.min(X, 0)) / (np.max(X, 0) + 0.0001)X_train, X_test, Y_train, Y_test = train_test_split(X, Y,test_size=0.2,random_state=0) logistic = linear_model.LogisticRegression()
rbm = BernoulliRBM(random_state=0, verbose=True)classifier = Pipeline(steps=[('rbm', rbm), ('logistic', logistic)]) rbm.learning_rate = 0.06
rbm.n_iter = 20
# More components tend to give better prediction performance, but larger fitting time
rbm.n_components = 100
logistic.C = 6000.0classifier.fit(X_train, Y_train) logistic_classifier = linear_model.LogisticRegression(C=100.0)
logistic_classifier.fit(X_train, Y_train)print()
print("Logistic regression using RBM features:\n%s\n" % (metrics.classification_report(Y_test,classifier.predict(X_test) )))print("Logistic regression using raw pixel features:\n%s\n" % (
metrics.classification_report(
Y_test,
logistic_classifier.predict(X_test))))plt.figure(figsize=(4.2, 4))
for i, comp in enumerate(rbm.components_):
plt.subplot(10, 10, i + 1)
plt.imshow(comp.reshape((8, 8)), cmap=plt.cm.gray_r,
interpolation='nearest')
plt.xticks(())
plt.yticks(())
plt.suptitle('100 components extracted by RBM', fontsize=16)
plt.subplots_adjust(0.08, 0.02, 0.92, 0.85, 0.08, 0.23)plt.show()
相关文章
DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别
DL之RBM:(sklearn自带数据集为1797个样本*64个特征+5倍数据集)深度学习之BRBM模型学习+LR进行分类实现手写数字图识别相关推荐
- DL之NN:基于(sklearn自带手写数字图片识别数据集)+自定义NN类(三层64→100→10)实现97.5%准确率
DL之NN:基于(sklearn自带手写数字图片识别数据集)+自定义NN类(三层64→100→10)实现97.5%准确率 目录 输出结果 核心代码 输出结果 核心代码 #DL之NN:基于sklearn ...
- DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练、预测(95%)
DL之DNN:利用DNN算法对mnist手写数字图片识别数据集(sklearn自带,1797*64)训练.预测(95%) 目录 数据集展示 输出结果 设计代码 数据集展示 先查看sklearn自带di ...
- TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线
TF之CNN:利用sklearn(自带手写数字图片识别数据集)使用dropout解决学习中overfitting的问题+Tensorboard显示变化曲线 目录 输出结果 设计代码 输出结果 设计代码 ...
- DL之RBM:基于RBM实现手写数字图片识别提高准确率
DL之RBM:基于RBM实现手写数字图片识别提高准确率 目录 输出结果 设计代码 输出结果 设计代码 import numpy as np import matplotlib.pyplot as pl ...
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Functional)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 下边两张 ...
- DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测
DL之CNN:利用卷积神经网络算法(2→2,基于Keras的API-Sequential)利用MNIST(手写数字图片识别)数据集实现多分类预测 目录 输出结果 设计思路 核心代码 输出结果 1.10 ...
- DL之DNN:利用DNN【784→50→100→10】算法对MNIST手写数字图片识别数据集进行预测、模型优化
DL之DNN:利用DNN[784→50→100→10]算法对MNIST手写数字图片识别数据集进行预测.模型优化 导读 目的是建立三层神经网络,进一步理解DNN内部的运作机制 目录 输出结果 设计思路 ...
- DL之NN:利用(本地数据集50000张数据集)调用自定义神经网络network.py实现手写数字图片识别94%准确率
DL之NN:利用(本地数据集50000张数据集)调用自定义神经网络network.py实现手写数字图片识别94%准确率 目录 输出结果 代码设计 输出结果 更新-- 代码设计 import mnist ...
- DL之NN/CNN:NN算法进阶优化(本地数据集50000张训练集图片),六种不同优化算法实现手写数字图片识别逐步提高99.6%准确率
DL之NN/CNN:NN算法进阶优化(本地数据集50000张训练集图片),六种不同优化算法实现手写数字图片识别逐步提高99.6%准确率 目录 设计思路 设计代码 设计思路 设计代码 import mn ...
最新文章
- 布线技术不断演进满足快速增长的网络需求
- Android性能优化——界面流畅度优化
- 业务逻辑全写在sql_12306的业务逻辑很复杂么?一条SQL语句搞不定?
- 期货市场技术分析06_长期图表和商品指数
- 当 AI 遇见经典,科大讯飞发布两款智能笔记本新品!
- OpenCV stereo matching 代码
- 安卓控件显示等宽字体的办法
- 海康/大华实现web直播和回放,也可以直接对接摄像头
- npm查找依赖包版本
- Linux下tomcat修改端口(80)
- 什么是模式识别,模式识别概念的基本介绍
- 淘宝数据分享平台战略
- 【pandas】--DataFrame数据筛选(二)
- 为什么阿里会选择 Flink 作为新一代流式计算引擎?
- 【Redis】概述以及启动Redis并进入Redis
- [2017纪中10-26]摘Galo 树型背包
- Python小爬虫:爬取开心网日记,乐趣无穷
- 重学JavaSE 第4章 : 顺序结构、分支语句、循环结构、break, continue, return区别
- 计算机可视化视景仿真,计算与仿真、三维设计、图像处理、视景仿真、4k视频剪辑工作站介绍2015版.pptx...
- “如影计划” 不带手机也能随身带支付宝