一、实验介绍

  • 实验环境:jupyter notebook、Tensorflow.keras

  • 数据集Fashion Mnist与样例代码及相关参考:
    https://www.tensorflow.org/tutorials/keras/classification
    https://zhuanlan.zhihu.com/p/161656714
    https://msd.misuland.com/pd/4146263742822220136
    https://baijiahao.baidu.com/s?id=1653421414340022957&wfr=spider&for=pc

    1、认识数据集
    2、尝试使用DNN、CNN等神经网络模型进行分类
    3、评估、比较各模型的性能

二、Fashion MNIST数据集

  • 包含10个类别的70000个灰度图像。这些图像以低分辨率(28×28像素)展示了单件衣物。
  • Fashion MINIST旨在临时替代经典 MNIST 数据集,后者常被用作计算机视觉机器学习程序的“Hello, World”。MNIST数据集包含手写数字(0、1、2 等)的图像,其格式与您将使用的衣物图像的格式相同。
  • 使用 Fashion MNIST来实现多样化,因为它比常规 MNIST更具挑战性。这两个数据集都相对较小,都用于验证某个算法是否按预期工作。对于代码的测试和调试,它们都是很好的起点。

标签是整数数组,介于0~9之间,以下是这些标签对应的服装类

标签
0 T恤/上衣
1 裤子
2 套头衫
3 连衣裙
4 外套
5 凉鞋
6 衬衫
7 运动鞋
8
9 短靴

三、实验代码和结果截图

1、导入python包

tf.kerasTensorflow中用来构建和训练模型的高级API。

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow import keras
2、全局变量
  • train_imagestrain_labels是模型用来学习的数据,test_imagestest_labels被用来对模型进行测试。

  • 每个图像都会被映射到一个标签。

fashion_minist = keras.datasets.fashion_mnist# 载入Fashion MINIST数据集
# 返回四个28×28的Numpy数组,像素值介于0~255之间,标签是整数数组,介于0~9之间
(train_images,train_labels),(test_images,test_labels) = fashion_minist.load_data()
# 数据集中不包括类名称,将它们存储在class_names中供稍后绘制图像时使用
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat','Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']# 将像素值缩小至0~1之间,对训练集和测试集做相同处理
train_images = train_images / 255.0
test_images = test_images / 255.0# (为了验证数据的格式是否正确,以及是否已准备好构建和训练网络)显示训练集中的前25个图像,并在每个图像的下方显示类名称
plt.figure(figsize = (10,10))
for i in range (25):plt.subplot(5,5,i+1)# 五行五列plt.xticks([])plt.yticks([])# plt.imshow(train_images[i],cmap=plt.cm.gray)plt.imshow(train_images[i],cmap=plt.cm.binary)# cmap参数接受一个值,其中每个值代表一种配色方案plt.xlabel(class_names[train_labels[i]])
plt.show()

3、构建模型:构建神经网络模型需要先配置模型的层,然后再编译模型
3.1 构建DNN模型
  • 该网络的第一层 tf.keras.layers.Flatten 将图像格式从二维数组(28 x 28 像素)转换成一维数组(28 x 28 = 784 像素)。将该层视为图像中未堆叠的像素行并将其排列起来。该层没有要学习的参数,它只会重新格式化数据。

  • 展平像素后,网络会包括两个 tf.keras.layers.Dense 层的序列。它们是密集连接或全连接神经层。第一个 Dense 层有 128 个节点(或神经元)。第二个(也是最后一个)层会返回一个长度为 10 的 logits 数组。每个节点都包含一个得分,用来表示当前图像属于 10 个类中的哪一类。

  • 在深度神经网络中,通常使用一种叫线性整流函数(修正线性单元):Rectified Linear Unit(ReLU),作为神经元的激活函数。 R e L U ( x ) = m a x ( 0 , x ) ReLU(x) = max(0,x) ReLU(x)=max(0,x),即
    R e L U ( x ) = { 0 if  x < 0 x if  x ≥ 0 ReLU(x) = \begin{cases} 0& \text{ if } x<0 \\ x& \text{ if } x≥0 \end{cases} ReLU(x)={0x​ if x<0 if x≥0​

  • 在准备对模型进行训练之前,还需要再对其进行一些设置。以下内容是在模型的编译步骤中添加的:
    ① 损失函数(loss function) :用于测量模型在训练期间的准确率。您会希望最小化此函数,以便将模型“引导”到正确的方向上。
    ② 优化器(optimizer) :决定模型如何根据其看到的数据和自身的损失函数进行更新。
    ③ 指标 (例如:accuracy):用于监控训练和测试步骤。以下示例使用了准确率,即被正确分类的图像的比率。

# 设置层
model = keras.Sequential([keras.layers.Flatten(input_shape = (28,28)),keras.layers.Dense(128, activation = 'relu'),keras.layers.Dense(10)
])# 编译模型
model.compile(optimizer = 'adam',loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics = ['accuracy'])# 训练模型
model.fit(train_images,train_labels,epochs=10)# epochs:迭代次数# 评估准确率
test_loss_dnn, test_acc_dnn = model.evaluate(test_images,test_labels,verbose=2)
print("\nTest accuracy is:",test_acc_dnn)# 进行预测
probability_model = tf.keras.Sequential([model,tf.keras.layers.Softmax()])predictions = probability_model.predict(test_images)# 第一个预测结果
predictions[0]

3.2 构建CNN模型,CNN模型中多了卷积层和池化层,先卷积再池化再卷积再池化,最后全连接层。
# Fashion MINIST数据集是3维的,不符合卷积的输入要求(4维),采用如下命令将3维的输入扩展维4维的输入
train_images = np.expand_dims(train_images, axis=3)
test_images = np.expand_dims(test_images, axis=3)# 设置层
model=keras.models.Sequential([keras.layers.Conv2D(64, (3, 3), activation='relu', input_shape=(28,28,1)),keras.layers.MaxPooling2D(2,2),keras.layers.Conv2D(64,(3,3),activation='relu'),keras.layers.MaxPooling2D(2,2),keras.layers.Flatten(),keras.layers.Dense(128, activation=tf.nn.relu),keras.layers.Dense(10)
])# 编译模型
model.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])# 训练模型
model.fit(train_images,train_labels,epochs=10,validation_data=(test_images,test_labels))# epochs:迭代次数# 评估准确率
test_loss_cnn, test_acc_cnn = model.evaluate(test_images,test_labels,verbose=2)
print("\nTest accuracy is:",test_acc_cnn)

四、总结

本案例同样的迭代10次,使用CNN比使用DNN运行时间长的多,但CNN的准确度更高,可自行设置不同的迭代次数,比较两种模型的运行时间、准确度等(除迭代次数,激活函数的选择和优化器的选择等都对准确度有影响)。

[Tensorflow]服装图像数据集分类:使用DNN、CNN模型相关推荐

  1. 猫狗图像数据集上的深度学习模型性能对比

    LeNet模型简介 1. LeNet LeNet-5由七层组成(不包括输入层),每一层都包含可训练权重.通过卷积.池化等操作进行特征提取,最后利用全连接实现分类识别,下面是他的网络结构示意图: C:卷 ...

  2. 【小白学习Keras教程】四、Keras基于数字数据集建立基础的CNN模型

    @Author:Runsen 文章目录 基本卷积神经网络(CNN) 加载数据集 1.创建模型 2.卷积层 3. 激活层 4. 池化层 5. Dense(全连接层) 6. Model compile & ...

  3. 迁移学习:如何使用TensorFlow对图像进行分类

    导言 在机器学习环境中,迁移学习是一种技术,使我们能够重用已经训练的模型并将其用于另一个任务.图像分类是将图像作为输入并为其分配具有概率的类(通常是标签)的过程.这个过程使用深度学习模型,即深度神经网 ...

  4. 使用增强的图像多分类模型

    欢迎关注 "小白玩转Python",发现更多 "有趣" 什么是图像增强? 图像增强,解决数据有限的问题.图像增强是一种通过在数据集中人工扩大训练数据集大小的技术 ...

  5. 基于卷积神经网络(CNN)模型的垃圾分类设计与实现

    本篇博客主要内容如下: 目录 项目背景 数据集介绍 模型构建与训练 结果分析 结果对比分析 项目背景 如何通过垃圾分类管理,最大限度地实现垃圾资源利用,减少垃圾处置量,改善生存环境质量,是当前世界各国 ...

  6. Tensorflow之CNN实现CIFAR-10图像的分类python

    这个还是18年做的,当时被老师逼着三天速成,也是无奈的很呀,哭唧唧.但是现在想想还是老师逼迫的时候效率高哈哈哈哈哈,感谢努力push我们的老师~ CNN原理 卷积神经网络(Convolutional ...

  7. [Python图像识别] 五十.Keras构建AlexNet和CNN实现自定义数据集分类详解

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  8. [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98%+

    [TensorFlow深度学习入门]实战九·用CNN做科赛网TibetanMNIST藏文手写数字数据集准确率98.8%+ 我们在博文,使用CNN做Kaggle比赛手写数字识别准确率99%+,在此基础之 ...

  9. (转)CNN基于Tensorflow实现cifar10数据集80-准确率

    https://www.jianshu.com/p/a2c1016faa95 数据导入和预处理 本文使用的是CIFAR10的数据集.CIFAR10包含了十个类型的图片,有60000张大小为32x32的 ...

最新文章

  1. Android权限处理分类
  2. jmeter性能测试入门简介
  3. dirty_background_ration 与 /proc/sys/vm/dirty_ratio
  4. UGUI_UGUI事件系统简述及使用方法总结
  5. 直播安装vnc连接树莓
  6. Hadoop集群启动、初体验
  7. 利用python编写祝福_手把手|教你用Python换个姿势,送狗年祝福语
  8. Learning Scrapy笔记(五)- Scrapy登录网站
  9. Cassandra 权威指南
  10. HALCON 21.11:深度学习笔记---有监督训练(6)
  11. SQL Server数据行的物理空间分配
  12. 网页无法打开...因为:net:ERR_CLEARTEXT_NOT_PERMITTED错误解决办法
  13. linux启动项加命令,启动项 命令(linux 添加开机启动项的三种方法)
  14. Linux下编译build的命令,Linux内核编译中build目录设置
  15. 利用HbuilderX制作简单网页: HTML5期末大作业——html5漫画风格个人主页
  16. 管理职工工资属于计算机什么应用领域,计算机练习题
  17. 魔兽巨龙追猎者服务器微信群,魔兽世界怀旧服开门情况一览,60多个区已缴满...
  18. kestrel web服务器性能对比,netcore高性能Web服务器Kestrel分析(示例代码)
  19. 使用html5制作表格
  20. 有什么比较好用的视频录像软件

热门文章

  1. 阳光点歌系统服务器说明书,天行阳光机顶盒点歌系统安装及配置说明-20210409030429.doc-原创力文档...
  2. API (DOM - 事件高级)
  3. linux三剑客-grep详解
  4. 已解决error:stray‘\243‘in program异常的正确解决方法,亲测有效!!!
  5. OJ笔记 18718 航行
  6. 光线追踪是怎么影响渲染速度的,什么显卡可以支持?
  7. android 离线帮助文档
  8. java单例模式之懒汉式与饿汉式
  9. P6专题:如何在 P6 中使用赢得值/挣值管理
  10. 摸鱼之谈----项羽之死