基于 CNN的 fashion MNIST图像分类

  • fashion MNIST图像分类
    • 数据集简介
    • 数据的预处理
    • CNN简介和构建
    • 模型部分代码
    • CNN实验结果
    • 致谢


fashion MNIST图像分类

数据集简介

  在 2017年8月份,德国研究机构Zalando Research在GitHub上推出了一个全新的数据集——fashion MNIST数据集,其中训练集包含60000个样例,测试集包含 10000个样例,分为10类,每一类的样本训练样本数量和测试样本数量相同。该数据集是一个替代 MNIST手写数据集的图像数据,比 MNIST数据复杂一些。该数据集的数据量较小,适用于用来验证某个算法可否正常运行和机器学习的入门。数据集的样本都来自日常穿着的衣裤鞋包,每个都是 28× 28的灰度图像,其中总共有10类标签,每张图像都有各自的标签,分别是T-shirt/top(T恤)、Trouser(裤子)、Pullover(套衫)、Dress(裙子)、Coat(外套)等10个服装类型,值得注意的是该数据集在我对数据集预处理时,发现其中有一些数据标签错误,这导致了数据在训练之后得不到很高的准确率。

数据的预处理

  在进行分类之前,需要对数据进行归一化。归一化可以加快训练网络的收敛性,归纳统一样本的统计分布特性,这里将数据归一化到0-1之间,使之符合概率分布,也能够加速梯度下降求解除最优解。归一化也可以使用其他的归一算法。归一化代码如下:
x_train, x_test = x_train / 255.0, x_test / 255.0

  本实验采用交叉验证的方法,交叉验证可以缓解单独测试结果的片面性和训练数据不足的问题。在数据预处理之后,需要对数据进行切分,切分为训练集和验证集以及测试集。将50000张测试集的数据中切割48000张图片进行训练,将2000张图片进行测试,最后将10000张图进行测试。具体的宿数据集切分代码如下所示:
#数据集切分
x_valid,x_train=x_train[:2000],x_train[2000:]
y_valid,y_train=y_train[:2000],y_train[2000:]

import tensorflow as tf
import os
from tensorflow import keras
import numpy as np
from matplotlib import pyplot as plt
import time
import math
import random
import numpy
from tqdm import tqdm_notebook as tqdm
from tensorflow.keras.utils import plot_model#数据预处理
def pretreatment():'''@introduction:数据预处理,将数据归一化,如果必要,可以进行数据混洗;其次将数据分割               @return : x_train 训练数据y_train 训练数据标签x_valid 验证数据y_valid 验证数据标签x_test 测试数据y_test 测试数据标签'''np.set_printoptions(threshold=np.inf)#控制台输出所有的值,不需要省略号fashion = tf.keras.datasets.fashion_mnist(x_train, y_train), (x_test, y_test) = fashion.load_data()x_train = x_train.reshape([-1,28,28,1])/255.0x_test = x_test.reshape([-1,28,28,1])/255.0#数据集切分x_valid,x_train=x_train[:2000],x_train[2000:]y_valid,y_train=y_train[:2000],y_train[2000:]return x_valid,x_train,y_valid,y_train,x_test,y_test

CNN简介和构建

  卷积神经网络(CNN)是1998年纽约大学的Yann Lecun在基于感受野和神经认知机的基础上提出的,这个神经网络模型在图像处理的领域取得了巨大成功。该网络相对于传统的神经网络的优势在于它能够通过卷积等一系列操作提取出图像的特征,使具有这些特征的网络能够很好的泛化能力和抽象表达能力。最为关键的是,卷积神经网络为了缓解过拟合和网络参数过多得导致的训练困难,这个模型引入了局部连接,权值共享等非常高明的策略。CNN比较关键在卷积和池化以及非线性映射。卷积层能够有效的提取出图像集的特征,这是提供系统准确率的关键。池化能够有效的减少数据量,并且缓解模型的过拟合问题,池化采用的具体方式则是降采样。非线性映射能够从概率统计的角度提升系统的非线性能力,因为在实际的应用中,线性系统的情况适用性不够宽泛。本实验模型图如下

模型部分代码

#模型训练
def model_train(x_train,y_train,x_valid,y_valid,choice=False):'''@introduction:训练模型,根据需求保存模型@parameter:   x_train 训练数据y_train 训练数据标签x_valid 验证数据y_valid 验证标签choice 选择是否读取保存的模型,默认值为不读取@return : model和history'''#模型处理def model_read(model):save_path = "CNN_new_model.h5"if os.path.exists(save_path):#判断文件是否存在print('\n')print('-------------------******模型读入*****----------------------')model=tf.keras.models.load_model(save_path)  else:print('文件不存在')return modeldef model_save(model):save_path = "CNN_new_model.h5"model.save(save_path)#模型定义model = keras.Sequential([   #(-1,28,28,1)->(-1,28,28,32)    keras.layers.Conv2D(input_shape=(28, 28, 1),filters=32,kernel_size=5,strides=1,padding='same'),     # Padding method),    #(-1,28,28,32)->(-1,14,14,32)    keras.layers.MaxPool2D(pool_size=2,strides=2,padding='same'),    #(-1,14,14,32)->(-1,14,14,64)    keras.layers.Conv2D(filters=64,kernel_size=3,strides=1,padding='same'),     # Padding method),    #(-1,14,14,64)->(-1,7,7,64)   keras.layers.MaxPool2D(pool_size=2,strides=2,padding='same'),    #(-1,7,7,64)->(-1,7*7*64)    keras.layers.Flatten(),    #(-1,7*7*64)->(-1,256)    keras.layers.Dense(256, activation=tf.nn.relu),    #(-1,256)->(-1,10)    keras.layers.Dense(10, activation=tf.nn.softmax)])model.summary()#配置参数#引入指数衰减学习率或不引入ad1 = tf.keras.optimizers.Adam(lr=0.0001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)ad2 = 'adam'model.compile(optimizer=ad2,loss='sparse_categorical_crossentropy',metrics=['sparse_categorical_accuracy'])if choice:model=model_read(model)history = model.fit(x_train, y_train , batch_size=128, epochs=2, validation_data=(x_valid, y_valid), validation_freq=1)#history = model.fit(x_train, y_train, epochs=2, validation_data=(x_valid, y_valid))model_save(model)else:history = model.fit(x_train, y_train,epochs=5, validation_data=(x_valid, y_valid))model_save(model)return history,modeldef model_get():'''@introduction: 模型读取@return: model 已经训练好的模型'''save_path = "CNN_model.h5"if os.path.exists(save_path):#判断文件是否存在print('\n>>>>>>>>>>>>模型读入>>>>>>>>>>>>>>>>>>>\n')model=tf.keras.models.load_model(save_path)  plot_model(model, to_file='CNN_model.jpg',show_shapes=True)for i in tqdm(range(100)):if(i==99):print('模型读入成功!')else:print('文件不存在')return model#数据可视化
def figshow(history):'''introductoin:  数据可视化,画出训练损失函数和验证集上的准确率'''# 显示训练集和验证集的acc和loss曲线print('\n\n\n')print('----------------------------------------------图像绘制-------------------------------------')acc = history.history['sparse_categorical_accuracy']val_acc = history.history['val_sparse_categorical_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']plt.subplots(figsize=(8,6))plt.plot(acc, label='Training Accuracy')plt.plot(val_acc, label='Validation Accuracy')plt.title('Training and Validation Accuracy')plt.legend()plt.show()plt.subplots(figsize=(8,6))plt.plot(loss, label='Training Loss')plt.plot(val_loss, label='Validation Loss')plt.title('Training and Validation Loss')plt.legend()plt.show()#测试集测试
def data_test(x_test,y_test,model):'''@introduction :输出模型在测试集上的top1准确率和top2准确率@parameter :x_test  测试集数据y_test  测试数据标签model   训练好的模型@return :top1_acc   top1标准下的准确率top2_acc   top2标准下的准确率'''print('****************************test******************************')loss, acc = model.evaluate(x_test, y_test)top1_acc = acc#print("test_accuracy:{:5.2f}%".format(100 * acc))y_pred = model.predict(x_test)k_b = tf.math.top_k(y_pred,2).indicesidx=0acc=0.0for i in k_b:if y_test[idx] in i.numpy():acc=acc+1idx=idx+1top2_acc=acc/y_test.shape[0] print('top1准确率:{0}\ntop2准确率:{1}'.format(top1_acc,top2_acc))return top1_acc,top2_acc#随机测试20张图片
def random_test(x_test,y_test,model):'''@introduction: 产生20个不同整数作为下标索引,输出预测值与真实值,两者比较@parameter : x_test  测试集数据y_test  测试集标签model  训练好的模型'''def randomNums(a, b, n):#产生n个不同的随机整数all_num = list(range(a, b))res = []while n:numpy.random.seed(0)index = math.floor(numpy.random.uniform() * len(all_num))res.append(all_num[index])del all_num[index]n -= 1return restest_idx = randomNums(0, 10000,20)  #测试集大小为10000,索引范围是[0,10000)pred_img=[]  #预测的图片d_label=[]   #预测图片的标签for i in test_idx:pred_img.append(x_test[i])d_label.append(y_test[i])pred_img = numpy.array(pred_img)#转换为ndarray格式pred_prob = model.predict(pred_img)  #预测概率pred_label = numpy.argmax(pred_prob,1)  #索引labels = ['T恤','裤子','套头衫','连衣裙','外套','凉鞋','衬衫','运动鞋','包','靴子']plt.figure(figsize=(14,14)) #显示前10张图像,并在图像上显示类别for i in range(20):   plt.subplot(4,5,i+1)plt.grid(False)plt.imshow(pred_img[i,:,:,0],cmap=plt.cm.binary)plt.rcParams['font.sans-serif']=['SimHei'] #用来正常显示中文标签t = labels[pred_label[i]]+'('+labels[d_label[i]]+')'plt.title(t) plt.show()if __name__ == '__main__':x_valid,x_train,y_valid,y_train,x_test,y_test = pretreatment()#model = model_get()history,model = model_train(x_train,y_train,x_valid,y_valid,True)#history,model = model_train(x_train,y_train,x_valid,y_valid)figshow(history)#data_test(x_test,y_test,model)#random_test(x_test,y_test,model)

CNN实验结果

 模型采用分步进行训练能够尽可能的让神经网络学习到更多特征,所以实验分为两步,第一步先使用默认的学习率进行训练,神经网络迭代5次得到初步的模型。模型在训练集和测试集上的准确率在不断的提升,且系统的损失函数都呈现下降趋势。Top1和Top准确率分别为90.83%和97.74%,神经网络的预测准确率有所上升。
1.初步训练和验证准确率

2.初步训练和验证的损失函数

3.初步Top1和Top2准确率

 实验进行第二次训练,能够有效的提高预测准确率。首先,将第一次训练的模型读出,再较小的学习率的基础上进行学习,能够尽可能的求得系统的最优值。具体实验操作是将学习率降低为0.001,再将第一步的模型迭代2次后得到最后的模型。最终再训练完成之后,模型的Top1和Top2准确率分别为92.44%和98.21%,此时的Top2也提高了,说明数据更加集中于准确值,模型的预测准确率进一步提高。在测试集上随机输入20个数据进行预测,实验显示的20张图片中,第一张图片预测错误,其他20张图片预测准确。这说明,模型在一定程度上能够较为准确的对数据集进行fashion_MNIST进行分类处理。
1.最终模型准确率
2.最终模型Top1和Top2准确率

3.随机输出模型预测

致谢

感谢诸君观看,如果感觉有用的话,点个赞吧!

tensorflow2.0 CNN fashion MNIST图像分类相关推荐

  1. tensorflow卷积神经网络实战:Fashion Mnist 图像分类与人马分类

    卷积神经网络实战:Fashion Mnist 图像分类与人马分类 一.FashionMnist的卷积神经网络模型 1.卷积VS全连接 2.卷积网络结构 3.卷积模型结构 1)Output Shape ...

  2. 基于tensorflow2.0+CNN实现手势识别(全)

    基于tensorflow2.0+CNN实现手势识别 环境:windows10.pycharm2017.python3.64.tensorflow2.0.opencv3 我在github上分享了代码以及 ...

  3. Pytorch初学实战(一):基于的CNN的Fashion MNIST图像分类

    1.引言 1.1.什么是Pytorch PyTorch是一个开源的Python机器学习库. 1.2.什么是CNN 卷积神经网络(Convolutional Neural Networks)是一种深度学 ...

  4. TensorFlow2.0学习笔记2-tf2.0两种方式搭建神经网络

    目录 一,TensorFlow2.0搭建神经网络八股 1)import  [引入相关模块] 2)train,test  [告知喂入网络的训练集测试集以及相应的标签] 3)model=tf.keras. ...

  5. TensorFlow2.0 教程-图像分类

    TensorFlow2.0 教程-图像分类 Tensorflow 2.0 教程持续更新: https://blog.csdn.net/qq_31456593/article/details/88606 ...

  6. TensorFlow2.0 学习笔记(三):卷积神经网络(CNN)

    欢迎关注WX公众号:[程序员管小亮] 专栏--TensorFlow学习笔记 文章目录 欢迎关注WX公众号:[程序员管小亮] 专栏--TensorFlow学习笔记 一.神经网络的基本单位:神经元 二.卷 ...

  7. python cnn程序_python cnn训练(针对Fashion MNIST数据集)

    本文将和大家一起一步步尝试对Fashion MNIST数据集进行调参,看看每一步对模型精度的影响.(调参过程中,基础模型架构大致保持不变) 废话不多说,先上任务: 模型的主体框架如下(此为拿到的原始代 ...

  8. 〖TensorFlow2.0笔记21〗自定义数据集(宝可精灵数据集)实现图像分类+补充:tf.where!

    自定义数据集(宝可精灵数据集)实现图像分类+补充:tf.where! 文章目录 一. 数据集介绍以及加载 1.1. 数据集简单描述 1.2. 程序实现步骤 1.3. 加载数据的格式 1.4. map函 ...

  9. 【Python深度学习】基于Tensorflow2.0构建CNN模型尝试分类音乐类型(一)

    音乐分类 前言 复现代码 MP3转mel CNN模型 训练结果 总结 前言 我在逛github的时候,偶然发现了一个项目:基于深度学习的音乐推荐.[VikramShenoy97].作者是基于CNN做的 ...

最新文章

  1. 项目/程序的流程走向
  2. 为什么明星公司会选择Go作为编程语言?
  3. 「预警」尽快升级FastJson版本,避免恶意请求导致OOM!
  4. C++默认参数注意事项
  5. 计划策略-10-净需求计划
  6. 诈尸了。不瞒您说,老坑从不填,天天开新坑
  7. jenkins api_接触Jenkins(Hudson)API,第2部分
  8. MySQL 数据类型和 Java 数据类型对照表
  9. Mac使用Homebrew安装Kafka
  10. PCL综述—三维图像处理
  11. input file 上传图片时,文件格式限制
  12. Java Code Examples for org.apache.ibatis.annotations.Insert
  13. 用于jqGrid获取SQL Server中数据的简单分页存储过程及sp_executesql的一点使用方法...
  14. axure 8 表格合并_多人编辑,自动汇总,领导可见所有?用 SeaTable 表格更简单
  15. Dxg——C# 开发笔记整理分类合集【所有的相关记录,都整理在此】
  16. 计算机无法启动print,本地计算机无法启动print spooler服务,错误1069怎么处理
  17. 华为路由器配置备忘录
  18. zencart 模板设计
  19. spring-cloud-starter-bus-kafka利用kafka消息总线实现动态刷新配置
  20. Day.js 一个轻量级的 JavaScript 时间日期处理库

热门文章

  1. 【安信可NB-IoT模组EC系列应用笔记⑨】使用CoAP协议接入OneNET Studio实现数据收发
  2. wordpress支持Markdown
  3. 将路径中的“\\”换成“/”的方法
  4. 使用Glide加载、缓存图片、Gif、解决背景出现浅绿色、GlideModules冲突
  5. Linux中 Nginx+uwsgi部署flask项目 Nginx负载均衡 反向代理
  6. LeetCode 2353. 设计食物评分系统(sortedcontainers)
  7. 物联网平台Node-red初涉——访问搭建的简易服务器
  8. python中trunc函数_Oracle trunc()函数的用法及四舍五入 round函数
  9. php判断范围,范围判断标签
  10. 基于java所写的学生选课管理系统