程序

由于mnist数据集直接使用

(x_train, y_train), (x_test, y_test) = mnist.load_data()

这种加载方式,有时候由于网络原因,很难加载成功。为此,可以直接通过地址其地址下载下来。然后使用numpy加载一下数据就行。

# -*- coding: utf-8 -*-
import keras
# from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.optimizers import RMSprop
import matplotlib.pyplot as plt
import numpy as np
batch_size = 128
num_classes = 10
epochs = 20#由于使用程序下载很困难,这里手动下载导入数据
# the data, shuffled and split between train and test sets
# (x_train, y_train), (x_test, y_test) = mnist.load_data()path='F:/program_work/python_work/KerasTest/data/mnist.npz'
f = np.load(path)
x_train, y_train = f['x_train'], f['y_train']
x_test, y_test = f['x_test'], f['y_test']
f.close()x_train = x_train.reshape(60000, 784).astype('float32')
x_test = x_test.reshape(10000, 784).astype('float32')
x_train /= 255
x_test /= 255
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')# convert class vectors to binary class matrices
# label为0~9共10个类别,keras要求格式为binary class matricesy_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)# 全连接模型
model = Sequential()
model.add(Dense(512, activation='relu', input_shape=(784,)))
model.add(Dropout(0.2))
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(num_classes, activation='softmax'))model.summary()#损失函数使用交叉熵
model.compile(loss='categorical_crossentropy',optimizer=RMSprop(),metrics=['accuracy'])
#模型估计
model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,verbose=1,validation_data=(x_test, y_test))
score = model.evaluate(x_test, y_test, verbose=0)
print('Total loss on Test Set:', score[0])
print('Accuracy of Testing Set:', score[1])#预测
result = model.predict_classes(x_test)
correct_indices = np.nonzero(result == y_test)[0]
incorrect_indices = np.nonzero(result != y_test)[0]
plt.figure()
for i, correct in enumerate(correct_indices[:9]):plt.subplot(3,3,i+1)plt.imshow(x_test[correct].reshape(28,28), cmap='gray', interpolation='none')plt.title("Predicted {}, Class {}".format(result[correct], y_test[correct]))plt.figure()
for i, incorrect in enumerate(incorrect_indices[:9]):plt.subplot(3,3,i+1)plt.imshow(x_test[incorrect].reshape(28,28), cmap='gray', interpolation='none')plt.title("Predicted {}, Class {}".format(result[incorrect], y_test[incorrect]))plt.show()

上面程序中,我们可以查看一些训练集的例子。如下图所示:

训练结果为:

关于全连接的理解,可以参考李宏毅的ppt。

损失函数通常使用的有以下两种。

对应的程序为:

model.compile(loss='mse',optimizer=RMSprop(),metrics=['accuracy'])
model.compile(loss='categorical_crossentropy',optimizer=RMSprop(),metrics=['accuracy'])

同时,模型的激活函数也有其他的,如ReLU,sigmoid等。对应的程序调整为:

model.add(Dense(num_classes, activation='relu'))
model.add(Dense(num_classes, activation='sigmoid'))

优化方式也可以调整为其他的,如Adam()或者SGD()等,对应的程序可以调整为:

model.compile(loss='categorical_crossentropy',optimizer=Adam(),metrics=['accuracy'])
model.compile(loss='categorical_crossentropy',optimizer=SGD(lr=0.1),metrics=['accuracy'])

keras手写数字识别--入门相关推荐

  1. 深度学习--TensorFlow(项目)Keras手写数字识别

    目录 效果展示 基础理论 1.softmax激活函数 2.神经网络 3.隐藏层及神经元最佳数量 一.数据准备 1.载入数据集 2.数据处理 2-1.归一化 2-2.独热编码 二.神经网络拟合 1.搭建 ...

  2. 从手写数字识别入门深度学习丨MNIST数据集详解

    就像无数人从敲下"Hello World"开始代码之旅一样,许多研究员从"MNIST数据集"开启了人工智能的探索之路. MNIST数据集(Mixed Natio ...

  3. TensorFlow8-mnist手写数字识别入门

    分类问题的损失函数为什么一般不用MSE?MSE在逻辑回归中可能具有多个局部最优点 不能用梯度下降算法

  4. keras从入门到放弃(十三)卷积神经网络处理手写数字识别

    今天来一个cnn例子 手写数字识别,因为是图像数据 import keras from keras import layers import numpy as np import matplotlib ...

  5. keras从入门到放弃(十)手写数字识别训练

    导入手写数字识别 import keras from keras import layers import matplotlib.pyplot as plt %matplotlib inline im ...

  6. 深度学习入门实例——基于keras的mnist手写数字识别

    本文介绍了利用keras做mnist数据集的手写数字识别. 参考网址 http://www.cnblogs.com/lc1217/p/7132364.html mnist数据集中的图片为28*28的单 ...

  7. Keras搭建CNN(手写数字识别Mnist)

    MNIST数据集是手写数字识别通用的数据集,其中的数据是以二进制的形式保存的,每个数字是由28*28的矩阵表示的. 我们使用卷积神经网络对这些手写数字进行识别,步骤大致为: 导入库和模块 我们导入Se ...

  8. MOOC网深度学习应用开发1——Tensorflow基础、多元线性回归:波士顿房价预测问题Tensorflow实战、MNIST手写数字识别:分类应用入门、泰坦尼克生存预测

    Tensorflow基础 tensor基础 当数据类型不同时,程序做相加等运算会报错,可以通过隐式转换的方式避免此类报错. 单变量线性回归 监督式机器学习的基本术语 线性回归的Tensorflow实战 ...

  9. 使用tf.keras搭建mnist手写数字识别网络

    使用tf.keras搭建mnist手写数字识别网络 目录 使用tf.keras搭建mnist手写数字识别网络 1.使用tf.keras.Sequential搭建序列模型 1.1 tf.keras.Se ...

最新文章

  1. 怎么让wordpress用sqlite3 搭建轻量级博客系统
  2. Linux下使用终端调试Python程序:pudb
  3. 天翼云从业认证课后习题(3.1天翼云计算产品)
  4. Linux 查看CPU信息、机器型号等硬件信息
  5. SAP CRM webclient ui drop down list key mode
  6. 《vSphere性能设计:性能密集场景下CPU、内存、存储及网络的最佳设计实践》一1.2.2 内存...
  7. url 参数传递的两种方式_VB编程中的传值与传址两种参数传递方式,你清楚吗?...
  8. 测试用例组织结构_用例和组织结构
  9. GridView position = 0重复加载的问题
  10. 从零开始攻略PHP(9)——错误和异常处理
  11. Xcode8上传app一直显示正在处理
  12. MySQL双主机双Master方案测试
  13. 【首发】'k4' 宏病毒专杀 原创新型excel宏病毒专杀工具
  14. hualinux 学生党 建议:读书就是为了社会目标做准备
  15. 台式电脑桌面没有计算机图标,电脑桌面图标全部消失怎么恢复 电脑桌面图标设置随意放置的方法...
  16. 全面开创城市数字经济新时代
  17. 学习笔记:图像分割之深度学习场景分割(2015开始)综述之前是手工特征
  18. Java中的Math函数常用方法都在这里
  19. 【visual studio】visual studio 2022 无法 复制黏贴
  20. 解决error ‘XXX‘ is not defined no-undef且项目没有eslintrc.js文件问题

热门文章

  1. 立冬了 广州还是夏天
  2. JS 数据结构之旅 :通过JS实现栈、队列、二叉树、二分搜索树、AVL树、Trie树、并查集树、堆
  3. 你好,了解一下Java 14带来的一系列新功能
  4. Spring EclipseLink NoSQL - 使用MongoDB和Oracle NoSQL DB构建
  5. Hystrix面试 - 深入 Hystrix 断路器执行原理
  6. BGP——BGP优化技术(总结+配置)
  7. Centos 7 定时关机
  8. JS天气插件(最全)
  9. Github+jsDelivr为脚本/图片等静态文件加速的全球CDN
  10. C#LeetCode刷题之#704-二分查找(Binary Search)