keras手写数字识别--入门
程序
由于mnist数据集直接使用
(x_train, y_train), (x_test, y_test) = mnist.load_data()
这种加载方式,有时候由于网络原因,很难加载成功。为此,可以直接通过地址其地址下载下来。然后使用numpy加载一下数据就行。
# -*- coding: utf-8 -*-
import keras
# from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop
import matplotlib.pyplot as plt
import numpy as np
batch_size = 128
num_classes = 10
epochs = 20#由于使用程序下载很困难,这里手动下载导入数据
# the data, shuffled and split between train and test sets
# (x_train, y_train), (x_test, y_test) = mnist.load_data()path='F:/program_work/python_work/KerasTest/data/mnist.npz'
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
f.close()x_train = x_train.reshape(60000, 784).astype('float32')
x_test = x_test.reshape(10000, 784).astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')# convert class vectors to binary class matrices
# label为0~9共10个类别,keras要求格式为binary class matricesy_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)# 全连接模型
model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))model.summary()#损失函数使用交叉熵
model.compile(loss='categorical_crossentropy',optimizer=RMSprop(),metrics=['accuracy'])
#模型估计
model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,verbose=1,validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Total loss on Test Set:', score[0])
print('Accuracy of Testing Set:', score[1])#预测
result = model.predict_classes(x_test)
correct_indices = np.nonzero(result == y_test)[0]
incorrect_indices = np.nonzero(result != y_test)[0]
plt.figure()
for i, correct in enumerate(correct_indices[:9]):plt.subplot(3,3,i+1)plt.imshow(x_test[correct].reshape(28,28), cmap='gray', interpolation='none')plt.title("Predicted {}, Class {}".format(result[correct], y_test[correct]))plt.figure()
for i, incorrect in enumerate(incorrect_indices[:9]):plt.subplot(3,3,i+1)plt.imshow(x_test[incorrect].reshape(28,28), cmap='gray', interpolation='none')plt.title("Predicted {}, Class {}".format(result[incorrect], y_test[incorrect]))plt.show()
上面程序中,我们可以查看一些训练集的例子。如下图所示:
训练结果为:
关于全连接的理解,可以参考李宏毅的ppt。
损失函数通常使用的有以下两种。
对应的程序为:
model.compile(loss='mse',optimizer=RMSprop(),metrics=['accuracy'])
model.compile(loss='categorical_crossentropy',optimizer=RMSprop(),metrics=['accuracy'])
同时,模型的激活函数也有其他的,如ReLU,sigmoid等。对应的程序调整为:
model.add(Dense(num_classes, activation='relu'))
model.add(Dense(num_classes, activation='sigmoid'))
优化方式也可以调整为其他的,如Adam()或者SGD()等,对应的程序可以调整为:
model.compile(loss='categorical_crossentropy',optimizer=Adam(),metrics=['accuracy'])
model.compile(loss='categorical_crossentropy',optimizer=SGD(lr=0.1),metrics=['accuracy'])
keras手写数字识别--入门相关推荐
- 深度学习--TensorFlow(项目)Keras手写数字识别
目录 效果展示 基础理论 1.softmax激活函数 2.神经网络 3.隐藏层及神经元最佳数量 一.数据准备 1.载入数据集 2.数据处理 2-1.归一化 2-2.独热编码 二.神经网络拟合 1.搭建 ...
- 从手写数字识别入门深度学习丨MNIST数据集详解
就像无数人从敲下"Hello World"开始代码之旅一样,许多研究员从"MNIST数据集"开启了人工智能的探索之路. MNIST数据集(Mixed Natio ...
- TensorFlow8-mnist手写数字识别入门
分类问题的损失函数为什么一般不用MSE?MSE在逻辑回归中可能具有多个局部最优点 不能用梯度下降算法
- keras从入门到放弃(十三)卷积神经网络处理手写数字识别
今天来一个cnn例子 手写数字识别,因为是图像数据 import keras from keras import layers import numpy as np import matplotlib ...
- keras从入门到放弃(十)手写数字识别训练
导入手写数字识别 import keras from keras import layers import matplotlib.pyplot as plt %matplotlib inline im ...
- 深度学习入门实例——基于keras的mnist手写数字识别
本文介绍了利用keras做mnist数据集的手写数字识别. 参考网址 http://www.cnblogs.com/lc1217/p/7132364.html mnist数据集中的图片为28*28的单 ...
- Keras搭建CNN(手写数字识别Mnist)
MNIST数据集是手写数字识别通用的数据集,其中的数据是以二进制的形式保存的,每个数字是由28*28的矩阵表示的. 我们使用卷积神经网络对这些手写数字进行识别,步骤大致为: 导入库和模块 我们导入Se ...
- MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测
Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...
- 使用tf.keras搭建mnist手写数字识别网络
使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...
最新文章
- 怎么让wordpress用sqlite3 搭建轻量级博客系统
- Linux下使用终端调试Python程序:pudb
- 天翼云从业认证课后习题(3.1天翼云计算产品)
- Linux 查看CPU信息、机器型号等硬件信息
- SAP CRM webclient ui drop down list key mode
- 《vSphere性能设计:性能密集场景下CPU、内存、存储及网络的最佳设计实践》一1.2.2 内存...
- url 参数传递的两种方式_VB编程中的传值与传址两种参数传递方式,你清楚吗?...
- 测试用例组织结构_用例和组织结构
- GridView position = 0重复加载的问题
- 从零开始攻略PHP(9)——错误和异常处理
- Xcode8上传app一直显示正在处理
- MySQL双主机双Master方案测试
- 【首发】'k4' 宏病毒专杀 原创新型excel宏病毒专杀工具
- hualinux 学生党 建议:读书就是为了社会目标做准备
- 台式电脑桌面没有计算机图标,电脑桌面图标全部消失怎么恢复 电脑桌面图标设置随意放置的方法...
- 全面开创城市数字经济新时代
- 学习笔记:图像分割之深度学习场景分割(2015开始)综述之前是手工特征
- Java中的Math函数常用方法都在这里
- 【visual studio】visual studio 2022 无法 复制黏贴
- 解决error ‘XXX‘ is not defined no-undef且项目没有eslintrc.js文件问题
热门文章
- 立冬了 广州还是夏天
- JS 数据结构之旅 :通过JS实现栈、队列、二叉树、二分搜索树、AVL树、Trie树、并查集树、堆
- 你好,了解一下Java 14带来的一系列新功能
- Spring EclipseLink NoSQL - 使用MongoDB和Oracle NoSQL DB构建
- Hystrix面试 - 深入 Hystrix 断路器执行原理
- BGP——BGP优化技术(总结+配置)
- Centos 7 定时关机
- JS天气插件(最全)
- Github+jsDelivr为脚本/图片等静态文件加速的全球CDN
- C#LeetCode刷题之#704-二分查找(Binary Search)