jupyter notebook: https://github.com/Penn000/NN/blob/master/notebook/LeNet/LeNet.ipynb

LeNet训练MNIST

1 import warnings
2 warnings.filterwarnings('ignore')  # 不打印 warning
3
4 import tensorflow as tf
5 import numpy as np
6 import os

加载MNIST数据集

分别加载MNIST训练集、测试集、验证集

1 from tensorflow.examples.tutorials.mnist import input_data
2
3 mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
4 X_train, y_train = mnist.train.images, mnist.train.labels
5 X_test, y_test = mnist.test.images, mnist.test.labels
6 X_validation, y_validation = mnist.validation.images, mnist.validation.labels

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
1 print("Image Shape: {}".format(X_train.shape))
2 print("label Shape: {}".format(y_train.shape))
3 print()
4 print("Training Set:   {} samples".format(len(X_train)))
5 print("Validation Set: {} samples".format(len(X_validation)))
6 print("Test Set:       {} samples".format(len(X_test)))

Image Shape: (55000, 784)
label Shape: (55000, 10)Training Set:   55000 samples
Validation Set: 5000 samples
Test Set:       10000 samples

数据处理

由于LeNet的输入为32x32xC(C为图像通道数),而MNIST每张图像的尺寸为28x28,所以需要对图像四周进行填充,并添加一维,使得每幅图像的形状为32x32x1。

1 # 使用0对图像四周进行填充
2 X_train = np.array([np.pad(X_train[i].reshape((28, 28)), (2, 2), 'constant')[:, :, np.newaxis] for i in range(len(X_train))])
3 X_validation = np.array([np.pad(X_validation[i].reshape((28, 28)), (2, 2), 'constant')[:, :, np.newaxis] for i in range(len(X_validation))])
4 X_test = np.array([np.pad(X_test[i].reshape((28, 28)), (2, 2), 'constant')[:, :, np.newaxis] for i in range(len(X_test))])
5
6 print("Updated Image Shape: {}".format(X_train.shape))

Updated Image Shape: (55000, 32, 32, 1)

MNIST数据展示

 1 import random
 2 import numpy as np
 3 import matplotlib.pyplot as plt
 4 %matplotlib inline
 5
 6 index = random.randint(0, len(X_train))
 7 image = X_train[index].squeeze().reshape((32, 32))
 8
 9 plt.figure(figsize=(2,2))
10 plt.imshow(image, cmap="gray")
11 print(y_train[index])

[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]

LeNet网络结构

Input

The LeNet architecture accepts a 32x32xC image as input, where C is the number of color channels. Since MNIST images are grayscale, C is 1 in this case. LeNet的输入为32x32xC的图像,C为图像的通道数。在MNIST中,图像为灰度图,因此C等于1。

Architecture

Layer 1: Convolutional. 输出为28x28x6的张量。

Activation. 激活函数。

Pooling. 输出为14x14x6的张量。

Layer 2: Convolutional. 输出为10x10x16的张量。

Activation. 激活函数。

Pooling. 输出为5x5x16的张量。

Flatten. 将张量展平为一维向量,使用tf.contrib.layers.flatten可以实现。

Layer 3: Fully Connected. 输出为120长度的向量。

Activation. 激活函数。

Layer 4: Fully Connected. 输出为84长度的向量。

Activation. 激活函数。

Layer 5: Fully Connected (Logits). 输出为10长度的向量。

1 # 卷积层
2 def conv_layer(x, filter_shape, stride, name):
3     with tf.variable_scope(name):
4         W = tf.get_variable('weights', shape=filter_shape, initializer=tf.truncated_normal_initializer())
5         b = tf.get_variable('biases', shape=filter_shape[-1], initializer=tf.zeros_initializer())
6     return tf.nn.conv2d(x, W, strides=stride, padding='VALID', name=name) + b

1 # 全连接层
2 def fc_layer(x, in_size, out_size, name):
3     with tf.variable_scope(name):
4         W = tf.get_variable('weights', shape=(in_size, out_size), initializer=tf.truncated_normal_initializer())
5         b = tf.get_variable('biases', shape=(out_size), initializer=tf.zeros_initializer())
6
7     return tf.nn.xw_plus_b(x, W, b, name=name)

1 def relu_layer(x, name):
2     return tf.nn.relu(x, name=name)

 1 from tensorflow.contrib.layers import flatten
 2
 3 def LeNet(x):
 4     conv1 = conv_layer(x, filter_shape=(5, 5, 1, 6), stride=[1, 1, 1, 1], name='conv1')
 5     relu1 = relu_layer(conv1, 'relu1')
 6     max_pool1 = max_pool_layer(relu1,  kernel_size=[1, 2, 2, 1], stride=[1, 2, 2, 1], name='max_pool1')
 7
 8     conv2 = conv_layer(max_pool1, filter_shape=(5, 5, 6, 16), stride=[1, 1, 1, 1], name='conv2')
 9     relu2 = relu_layer(conv2, 'relu2')
10     max_pool2 = max_pool_layer(relu2,  kernel_size=[1, 2, 2, 1], stride=[1, 2, 2, 1], name='max_pool1')
11
12     flat = flatten(max_pool2)
13
14     fc3 = fc_layer(flat, 400, 120, name='fc3')
15     relu3 = relu_layer(fc3, 'relu3')
16
17     fc4 = fc_layer(relu3, 120, 84, name='fc4')
18     relu4 = relu_layer(fc4, 'relu4')
19
20     logits = fc_layer(relu4, 84, 10, name='fc5')
21
22     return logits

TensorFlow设置

 1 EPOCHS = 10
 2 BATCH_SIZE = 128
 3 log_dir = './log/'
 4
 5 x = tf.placeholder(tf.float32, (None, 32, 32, 1))
 6 y = tf.placeholder(tf.int32, (None, 10))
 7
 8 # 定义损失函数
 9 logits = LeNet(x)
10 cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)
11 loss = tf.reduce_mean(cross_entropy)
12 train = tf.train.AdamOptimizer(learning_rate=0.01).minimize(loss)

训练

 1 from sklearn.utils import shuffle
 2 import shutil
 3 log_dir = './logs/'
 4 if os.path.exists(log_dir):
 5     shutil.rmtree(log_dir)
 6 os.makedirs(log_dir)
 7 train_writer = tf.summary.FileWriter(log_dir+'train/')
 8 valid_writer = tf.summary.FileWriter(log_dir+'valid/')
 9
10 ckpt_path = './ckpt/'
11 saver = tf.train.Saver()
12
13 with tf.Session() as sess:
14     sess.run(tf.global_variables_initializer())
15     n_samples = len(X_train)
16
17     step = 0
18     for i in range(EPOCHS):
19         X_train, y_train = shuffle(X_train, y_train) # 打乱数据
20         # 使用mini-batch训练
21         for offset in range(0, n_samples, BATCH_SIZE):
22             end = offset + BATCH_SIZE
23             batch_x, batch_y = X_train[offset:end], y_train[offset:end]
24             sess.run(train, feed_dict={x: batch_x, y: batch_y})
25
26             train_loss = sess.run(loss, feed_dict={x: batch_x, y: batch_y})
27             train_summary = tf.Summary(value=[
28                 tf.Summary.Value(tag="loss", simple_value=train_loss)
29             ])
30             train_writer.add_summary(train_summary, step)
31             train_writer.flush()
32             step += 1
33
34         # 每个epoch使用验证集对网络进行验证
35         valid_loss = sess.run(loss, feed_dict={x: X_validation, y: y_validation})
36         valid_summary = tf.Summary(value=[
37                 tf.Summary.Value(tag="loss", simple_value=valid_loss)
38         ])
39         valid_writer.add_summary(valid_summary, step)
40         valid_writer.flush()
41
42         print('epoch', i, '>>> loss:', valid_loss)
43
44     # 保存模型
45     saver.save(sess, ckpt_path + 'model.ckpt')
46     print("Model saved")

epoch 0 >>> validation loss: 39.530758
epoch 1 >>> validation loss: 19.649899
epoch 2 >>> validation loss: 11.780323
epoch 3 >>> validation loss: 8.7316675
epoch 4 >>> validation loss: 6.396747
epoch 5 >>> validation loss: 5.4544454
epoch 6 >>> validation loss: 4.5326686
epoch 7 >>> validation loss: 3.5578024
epoch 8 >>> validation loss: 3.2353864
epoch 9 >>> validation loss: 3.5096574
Model saved

训练和验证的loss曲线

测试

1 correct = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1))
2 accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
3
4 with tf.Session() as sess:
5     saver.restore(sess, tf.train.latest_checkpoint('./ckpt'))
6
7     test_accuracy = sess.run(accuracy, feed_dict={x: X_test, y: y_test})
8     print("Test Accuracy = {}".format(test_accuracy))

INFO:tensorflow:Restoring parameters from ./ckpt/model.ckpt
Test Accuracy = 0.9574000239372253

转载于:https://www.cnblogs.com/Penn000/p/10251187.html

LeNet训练MNIST相关推荐

  1. Paddle 环境中 使用LeNet在MNIST数据集实现图像分类

    简 介: 测试了在AI Stuio中 使用LeNet在MNIST数据集实现图像分类 示例.基于可以搭建其他网络程序. 关键词: MNIST,Paddle,LeNet #mermaid-svg-FlRI ...

  2. 训练MNIST数据集模型

    1. 数据集准备 详细信息见: Caffe: LMDB 及其数据转换 mnist是一个手写数字库,由DL大牛Yan LeCun进行维护.mnist最初用于支票上的手写数字识别, 现在成了DL的入门练习 ...

  3. CAFFE学习笔记(一)Caffe_Example之训练mnist

     CAFFE学习笔记(一)Caffe_Example之训练mnist 0.参考文献 [1]caffe官网<Training LeNet on MNIST with Caffe>;  [ ...

  4. 实践详细篇-Windows下使用VS2015编译的Caffe训练mnist数据集

    上一篇记录的是学习caffe前的环境准备以及如何创建好自己需要的caffe版本.这一篇记录的是如何使用编译好的caffe做训练mnist数据集,步骤编号延用上一篇 <实践详细篇-Windows下 ...

  5. NNDL 实验六 卷积神经网络(3)LeNet实现MNIST 手动算子部分

    手写体数字识别是计算机视觉中最常用的图像分类任务,让计算机识别出给定图片中的手写体数字(0-9共10个数字).由于手写体风格差异很大,因此手写体数字识别是具有一定难度的任务. 我们采用常用的手写数字识 ...

  6. TensorFlow精进之路(十四):RNN训练MNIST数据集

    1.概述 前面介绍了RNN,这一节就用tensorflow的RNN来训练MNIST数据集,看看准确率如何. 2.代码实现 2.1.导入数据集 # encoding:utf-8 import tenso ...

  7. LeNet——训练和预测篇

    前言 本学习笔记参考自B站up主霹雳吧啦Wz 代码均来自其github开源项目WZMIAOMIAO/deep-learning-for-image-processing: deep learning ...

  8. 深度学习基础: BP神经网络训练MNIST数据集

    BP 神经网络训练MNIST数据集 不用任何深度学习框架,一起写一个神经网络训练MNIST数据集 本文试图让您通过手写一个简单的demo来讨论 1. 导包 import numpy as np imp ...

  9. pytorch训练MNIST

    本文记录了pytorch训练MNIST数据集的过程,通过本文可熟悉pytorch训练的大体操作过程. 一.导入各种模块 import torch import torch.nn as nn impor ...

最新文章

  1. C#学习笔记8:HTML和CSS基础学习笔记
  2. 禁止COOKIE后对SESSION的影响
  3. setlocal启动批处理文件中环境变量的本地化
  4. Scratch第四十九讲:完美的下落和反弹
  5. [SVN(ubuntu)] ubuntu使用svn
  6. C# 关键字 virtual、override和new的用法
  7. zabbix利用traceroute命令监控主备链路状态
  8. Mysql 的优化方式,都给你整理好了(附思维导图)
  9. 手把手教你安装Latex(保姆级教程)
  10. 元器件(Components)安规标准(UL+IEC)
  11. 图像处理的灰度化和二值化
  12. elementui进度条如何设置_Progress 进度条
  13. 电子商务里的P2P、O2O、P2C、B2C、B2B、C2C是什么?
  14. Linux进程中有xorg,linux – Xorg如何工作?
  15. Canvas画各种线
  16. 安全浏览器无法安装?看这一篇就够了
  17. 从零开始详解应用内支付——商品创建及测试上架
  18. 淘宝,1688,京东店铺所有商品接口分享
  19. python-opencv文件夹中所有视频按顺序截图片并按顺序命名
  20. 主方法外单独的两个类,不能直接互相调用

热门文章

  1. mysql类 php100_PHP100视频教程26:制作自己的PHP+MYSQL的类
  2. 苹果挂端口方法_苹果新系统遭吐槽!SSH 默认规则被破坏,程序员无法登录 Web 服务器......
  3. matlab软件介绍_活动回顾 | 您要的MATLAB课堂总结上线啦!
  4. 卖任小龙java视频,任小龙Java大神之路(第九季 SpringMVC)视频教程叩丁狼教育出品...
  5. 纠错编码基本实验matlab,纠错编码基本实验matlab实现包含源代码
  6. JAVA8常量池监控_深入探索Java常量池
  7. java 数组map_java中 数组 list map之间的互转
  8. git checkout和git reset的一些区别以及配置git简写命令
  9. 如何分析案件的性质_律师如何综合分析一个案件
  10. 网络推广软文之文章更新对网站排名的影响!