环境: Ubuntu 18.04, tensorflow 2.4.1

该版本是全连接网络的优化版本,采用了卷积神经网络,参考。
可以看到BATCH_SIZE没有改动,而EPOCH明显减少。
但是EPOCH一轮的时间明显增长了很多。

代码

import tensorflow as tf
import numpy as np
from tensorflow import kerasEPOCH = 15
BATCH_SIZE = 128
VERBOSE = 1
NUM_CLASSES = 10# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)
y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)## build network
model = tf.keras.Sequential([tf.keras.Input(shape=(28, 28, 1)),tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),tf.keras.layers.Flatten(),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),]
)model.summary()## compile network
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
model.fit(x_train, y_train, batch_size=BATCH_SIZE, epochs=EPOCH, validation_split=0.1)## validation  0.991
val_loss,val_acc = model.evaluate(x_test, y_test, verbose=VERBOSE)
print("Test loss: ", val_loss)
print("Test accuracy: ", val_acc)

model summary解读

输入层是 (,28,28,1)4维张量,分别是数量,长度,宽度,通道(灰度)

第一层conv2d

  • kernel_size=(3,3)
    因为输入时长宽相等,我们就看一边。 没有指定strides,默认为1,
    (28−3)/1+1=26(28-3)/1+1=26(28−3)/1+1=26
  • filter=32
    表示有32个kernel,每个kernel在原张量上卷积后,都会有一个结果,所以最后一维就是32.
    这个就是第一层conv2dshape(,26,26,32)的由来。
  • param个数
    ((核宽∗核高)∗通道数+1)∗卷积核数((核宽*核高)*通道数+1)*卷积核数 ((核宽∗核高)∗通道数+1)∗卷积核数注:+1是因为一个bias
    ((3∗3)∗1+1)∗32=320((3*3)*1+1)*32=320((3∗3)∗1+1)∗32=320

第二层pooling

它主要是做收缩,所以不产生新的param。
pool_size=(2, 2),所以shape变成 (,13,13,32)

第三层conv2d

  • kernel_size=(3, 3)
    (13−3)/1+1=11(13-3)/1+1=11(13−3)/1+1=11
  • filter=64
    所以shape为(,11,11,64)
  • param个数
    ((核宽∗核高)∗通道数+1)∗卷积核数( ( 核宽 *核高)*通道数+1)* 卷积核数((核宽∗核高)∗通道数+1)∗卷积核数 注:这一层的通道数就是上层卷积核数
    ((3∗3)∗32+1)∗64=18496((3*3)*32+1)*64=18496((3∗3)∗32+1)∗64=18496

Flatten层

其实就是将多维张量变为一维输入,所以
5∗5∗64=16005*5*64=16005∗5∗64=1600

dropout层都改变shape

最后dense层

  • param个数
    (通道数+1)∗输出数( 通道数+1)* 输出数(通道数+1)∗输出数
    (1600+1)∗10=16010(1600+1)*10=16010(1600+1)∗10=16010

Total Param直观地体现了整体训练的时长

占用资源怎么样?

在云端上使用单虚拟核,1G内存是跑不了的,主要是内存不够。mnist跑起来大概需要1G+内存。


一些尝试

上面的样例,在15epoch跑下来,accuracy大致为0.991, 它比全连接网络确实准确了很多。

那么究竟增加卷积层带来了多少优化呢,在这里我们尝试做一些改动试一下。


TEST 1. 去除一层卷积层


我们把两层卷积层中的第二层去除,看一下效果。
accuracy: 0.9831
准确度降低了约0.7%,从summary来看,时间(训练params个数)反而增加了。

另外,因为mnist数据集中的图片较小,卷积层对它特征提取效果与层数应该增加没有太大关联。

TEST 2. 去除每层卷积层的Pooling


accuracy: 0.9968
准确度基本没有变化,但是时间(参数)指数级增长了

TEST 3. 去除Dropout层


accuracy: 0.9868
根据前面的说明,可以看到dropout并不改变model的shape也不会减少训练的参数个数。
所以时间虽然跟TEST 2一样,但是效果没有上面的好

TEST 4. 只留一层Conv并修改Conv层kernel参数

kernel_size=(3, 3) 改为kernel_size=(5, 5)
将kernel增加,是会增加该层的param的,所以虽然整体只有一层Conv2D,但是最终的训练参数比标准网络要多一个数量级

accuracy: 0.9871结果跟TEST 1差不多

TEST 5 只留一层Conv和Pooling并修改Conv层kernel参数


与 TEST 4的区别就是,保留了Pooling层,其他都一致。

accuracy: 0.987结果也差不多

但可见参数数据量明显下降一个数量级,Pooling的作用还是很凸显的。

TEST 6. 只留一层Conv和Pooling,并修改Conv层filter参数


filter=32 改为filter=8

accuracy: 0.9809

TEST 7. 只留一层Conv和Pooling,修改pooling参数


pool_size=(2, 2) 改为pool_size=(5, 5))

accuracy: 0.977这个结果较原网络差了快2%,但是训练的参数最少,比原网络少了一个数量级

mnist学习实例(2)相关推荐

  1. mnist学习实例(1)

    环境: Ubuntu 18.04, tensorflow 2.4.1 mnist是Yann Lecun大神的手写数据,数据中的数字都是28X28的图像,每个像素点是[0-255]的值 其中训练数据为6 ...

  2. 涵盖 14 大主题!最完整的 Python 学习实例集来了!

    机器学习.深度学习最简单的入门方式就是基于 Python 开始编程实战.最近闲逛 GitHub,发现了一个非常不错的 Python 学习实例集,完全是基于 Python 来实现包括 ML.DL 等领域 ...

  3. ajax请求返回json实例,Jquery Ajax 学习实例2 向页面发出请求 返回JSon格式数据

    一.AjaxJson.aspx 处理业务数据,产生JSon数据,供JqueryRequest.aspx调用,代码如下: protected void Page_Load(object sender, ...

  4. php实训总结00字,说明的比较细的php 正则学习实例

    说明的比较细的php 正则学习实例 "^The": 匹配以 "The"开头的字符串; "of despair$": 匹配以 "of ...

  5. 深度学习之生成对抗网络(1)博弈学习实例

    深度学习之生成对抗网络(1)博弈学习实例 博弈学习实例  在 生成对抗网络(Generative Adversarial Network,简称GAN)发明之前,变分自编码器被认为是理论完备,实现简单, ...

  6. 从入门到入土:Python爬虫学习|实例练手|爬取猫眼榜单|Xpath定位标签爬取|代码

    此博客仅用于记录个人学习进度,学识浅薄,若有错误观点欢迎评论区指出.欢迎各位前来交流.(部分材料来源网络,若有侵权,立即删除) 本人博客所有文章纯属学习之用,不涉及商业利益.不合适引用,自当删除! 若 ...

  7. 从入门到入土:Python爬虫学习|实例练手|爬取百度翻译|Selenium出击|绕过反爬机制|

    此博客仅用于记录个人学习进度,学识浅薄,若有错误观点欢迎评论区指出.欢迎各位前来交流.(部分材料来源网络,若有侵权,立即删除) 本人博客所有文章纯属学习之用,不涉及商业利益.不合适引用,自当删除! 若 ...

  8. 从入门到入土:Python爬虫学习|实例练手|爬取新浪新闻搜索指定内容|Xpath定位标签爬取|代码注释详解

    此博客仅用于记录个人学习进度,学识浅薄,若有错误观点欢迎评论区指出.欢迎各位前来交流.(部分材料来源网络,若有侵权,立即删除) 本人博客所有文章纯属学习之用,不涉及商业利益.不合适引用,自当删除! 若 ...

  9. 从入门到入土:Python爬虫学习|实例练手|爬取百度产品列表|Xpath定位标签爬取|代码注释详解

    此博客仅用于记录个人学习进度,学识浅薄,若有错误观点欢迎评论区指出.欢迎各位前来交流.(部分材料来源网络,若有侵权,立即删除) 本人博客所有文章纯属学习之用,不涉及商业利益.不合适引用,自当删除! 若 ...

最新文章

  1. 性价比超高:苹果发布了新数据集,助力室内场景理解
  2. 解题报告:luogu P2341 受欢迎的牛(Tarjan算法,强连通分量判定,缩点,模板)
  3. leetcode算法题--最低票价★
  4. ST05 跟踪SQL
  5. html抓取成xml,使用XML包将html表抓取到R数据帧中
  6. 50行代码的MVVM,感受闭包的艺术
  7. web 富文本编辑器总结
  8. redis 内存碎片清理
  9. handlersocket mysql,MySQL插件HandlerSocket
  10. FPGA经典设计思想
  11. Android 电话的国家代码
  12. 图片优化——质量与性能的博弈
  13. 《富爸爸穷爸爸》:为什么你很穷
  14. 【IEEE】IEEE论文接收后proof(校样)全流程实例讲解
  15. React Fullpage
  16. 操作系统安全防护技术
  17. python 中文转拼音原理_Python中文转拼音
  18. 继承 extends
  19. 自我实现tcmalloc的项目简化版本
  20. java.io.IOException: 远程主机强迫关闭了一个现有的连接

热门文章

  1. JavaScript权威Douglas Crockford:代码阅读和每个人都该学的编程
  2. 【问题收录】ImportError No module named MySQLdb 问题解决
  3. android布局共享,布局共享(如所有ACTIVITY拥有相同的布局部分,比如ACTIONBAR,在BASEACTIVITY中写入布局)...
  4. 在线作图|如何绘制一张哑铃图
  5. 在线作图|2分钟在线绘制RDA图
  6. QIIME 2教程. 29参考数据库DataResources(2020.11)
  7. ggbiplot-最好看的PCA作图:样品PCA散点+分组椭圆+主成分丰度和相关
  8. python使用matplotlib可视化:设置坐标轴的范围、设置主次坐标轴刻度、坐标轴刻度显示样式、坐标轴刻度数颜色、小数点位数、坐标轴刻度网格线、线条类型、数据点形状标签、文本字体、颜色、大小等
  9. R语言使用ggplot2包geom_jitter()函数绘制分组(strip plot,一维散点图)带状图(改变分组次序)实战
  10. R语言ggplot2可视化:ggplot2可视化密度图(显示数据密集区域)、ggplot2可视化密度图(对数坐标):log10比例的收入密度图突出了在常规密度图中很难看到的收入分布细节