前言

对CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题,其任务是对一组大小为32x32的RGB图像进行分类,这些图像涵盖了10个类别:
飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船以及卡车。

首先来看下cifar10数据集:

这里面一共有五个训练文件,一个测试文件。网上的教程大多都是需要以下五个文件,在这里自己实现了单文件的训练代码。代码需要提前下载好cifar10数据,CIFAR-10 python version版本的哦~

源代码

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 24 09:43:25 2018@author: new
"""
import numpy as np
import os
import sys
import keras.backend as K
from six.moves import cPickle
import cv2
import numpy as np
import tensorflow as tf
import keras
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import SGD, Adadelta, Adagrad
from keras.utils import np_utils, generic_utils
from keras.utils import plot_model
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sioos.environ["CUDA_VISIBLE_DEVICES"] = "0"
gpu_options = tf.GPUOptions(allow_growth=True)
sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))#【0】设置超参
batch_size = 32
num_classes = 10
epochs = 5
data_augmentation = Truedef load_batch(fpath, label_key='labels'):f = open(fpath, 'rb')if sys.version_info < (3,):d = cPickle.load(f)else:d = cPickle.load(f, encoding='bytes')# decode utf8d_decoded = {}for k, v in d.items():d_decoded[k.decode('utf8')] = vd = d_decodedf.close()data = d['data']labels = d[label_key]data = data.reshape(data.shape[0], 3, 32, 32)return data, labels
def load_data():dirname = 'C:/Users/new/Desktop/cifar-10-batches-py'num_train_samples = 50000x_train = np.empty((num_train_samples, 3, 32, 32), dtype='uint8')y_train = np.empty((num_train_samples,), dtype='uint8')for i in range(1, 6):fpath = os.path.join(dirname, 'data_batch_' + str(i))(x_train[(i - 1) * 10000: i * 10000, :, :, :],y_train[(i - 1) * 10000: i * 10000]) = load_batch(fpath)fpath = os.path.join(dirname, 'test_batch')x_test, y_test = load_batch(fpath)y_train = np.reshape(y_train, (len(y_train), 1))y_test = np.reshape(y_test, (len(y_test), 1))if K.image_data_format() == 'channels_last':x_train = x_train.transpose(0, 2, 3, 1)x_test = x_test.transpose(0, 2, 3, 1)return (x_train, y_train), (x_test, y_test)(x_train,  y_train), (x_test, y_test)=load_data()
print('x_train shape:', x_train.shape)
print('y_train shape:', y_train.shape)
print('x_test shape:', x_test.shape)
print('y_test shape:', y_test.shape)plt.figure(1)
plt.imshow(x_train[0]) # 显示第一张训练图片
plt.figure(2)
plt.imshow(x_test[0])  # 显示第一张测试图片# 【3】将标签转化成 one-hot 编码
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)# 【4】构建深度CNN序贯模型
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))print(model.summary())                              # 打印模型概况# 【5】编译模型
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)#初始化一个 RMSprop 优化器
model.compile(loss='categorical_crossentropy',optimizer=opt,metrics=['accuracy'])# 【6】数据预处理/增强+模型训练
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255if not data_augmentation:print('Not using data augmentation.')model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,validation_data=(x_test, y_test),)
else:print('Using real-time data augmentation.')# ImageDataGenerator:图片生成器,用以生成一个batch的图像数据,训练时该函数会无限生成数据# 直到达到规定的epoch次数。图片生成(CPU)和训练(GPU)并行执行。datagen = ImageDataGenerator(featurewise_center=False,  samplewise_center=False,  featurewise_std_normalization=False,  samplewise_std_normalization=False,  zca_whitening=False, rotation_range=0,        # 随机旋转的角度范围width_shift_range=0.1,   # 随机水平偏移的幅度范围height_shift_range=0.1,  horizontal_flip=True,    # 随机水平翻转vertical_flip=False)     datagen.fit(x_train)         # 计算样本的统计信息,进行数据预处理(如去中心化,标准化)model.fit_generator(datagen.flow(x_train, y_train,           # datagen.flow()不断生成一个batch的数据用于模型训练batch_size=batch_size),epochs=epochs,validation_data=(x_test, y_test),workers=4)# 【7】保存模型以及权重
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'
if not os.path.isdir(save_dir):os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)# 【8】测试集评估模型
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])

实验结果

Using TensorFlow backend.
x_train shape: (50000, 32, 32, 3)
y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3)
y_test shape: (10000, 1)
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_1 (Conv2D)            (None, 32, 32, 32)        896
_________________________________________________________________
activation_1 (Activation)    (None, 32, 32, 32)        0
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 30, 30, 32)        9248
_________________________________________________________________
activation_2 (Activation)    (None, 30, 30, 32)        0
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 15, 15, 32)        0
_________________________________________________________________
dropout_1 (Dropout)          (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_3 (Conv2D)            (None, 15, 15, 64)        18496
_________________________________________________________________
activation_3 (Activation)    (None, 15, 15, 64)        0
_________________________________________________________________
conv2d_4 (Conv2D)            (None, 13, 13, 64)        36928
_________________________________________________________________
activation_4 (Activation)    (None, 13, 13, 64)        0
_________________________________________________________________
max_pooling2d_2 (MaxPooling2 (None, 6, 6, 64)          0
_________________________________________________________________
dropout_2 (Dropout)          (None, 6, 6, 64)          0
_________________________________________________________________
flatten_1 (Flatten)          (None, 2304)              0
_________________________________________________________________
dense_1 (Dense)              (None, 512)               1180160
_________________________________________________________________
activation_5 (Activation)    (None, 512)               0
_________________________________________________________________
dropout_3 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_2 (Dense)              (None, 10)                5130
_________________________________________________________________
activation_6 (Activation)    (None, 10)                0
=================================================================
Total params: 1,250,858
Trainable params: 1,250,858
Non-trainable params: 0
_________________________________________________________________
None
Using real-time data augmentation.
Epoch 1/5
1563/1563 [==============================] - 102s 65ms/step - loss: 1.8342 - acc: 0.3229 - val_loss: 1.5518 - val_acc: 0.4325
Epoch 2/5
1563/1563 [==============================] - 120s 77ms/step - loss: 1.5533 - acc: 0.4310 - val_loss: 1.4069 - val_acc: 0.4883
Epoch 3/5
1563/1563 [==============================] - 108s 69ms/step - loss: 1.4322 - acc: 0.4846 - val_loss: 1.2653 - val_acc: 0.5508
Epoch 4/5
1563/1563 [==============================] - 107s 68ms/step - loss: 1.3429 - acc: 0.5180 - val_loss: 1.1613 - val_acc: 0.5869
Epoch 5/5
1563/1563 [==============================] - 107s 69ms/step - loss: 1.2704 - acc: 0.5454 - val_loss: 1.1002 - val_acc: 0.6138
Saved trained model at C:\Users\new\Desktop\chapter_2\saved_models\keras_cifar10_trained_model.h5
10000/10000 [==============================] - 6s 611us/step
Test loss: 1.1002309656143188
Test accuracy: 0.6138

由于迭代的次数比较少,所以测试集上的准确率不是太高,可以多迭代几次试下哦~~~

加载模型进行预测

model = load_model('C:/Users/new/Desktop/chapter_2/saved_models/keras_cifar10_trained_model.h5')
print('test after load: ', model.predict(x_test[0:2])) 

测试后的结果:

test after load:  [[0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+000.00000000e+00 2.44679976e-34 0.00000000e+00 0.00000000e+001.05497485e-29 0.00000000e+00][0.00000000e+00 1.46545753e-08 0.00000000e+00 0.00000000e+000.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+001.00000000e+00 0.00000000e+00]]

这里是one-hot向量,最大的那个就是预测出的类别~~~

keras训练cifar10数据集源代码相关推荐

  1. 基于Keras搭建cifar10数据集训练预测Pipeline

    基于Keras搭建cifar10数据集训练预测Pipeline 钢笔先生关注 0.5412019.01.17 22:52:05字数 227阅读 500 Pipeline 本次训练模型的数据直接使用Ke ...

  2. 【深度学习】训练CIFAR-10数据集实现分类加测试

    网上有很多博主写的训练CIFAR-10的代码,本次只是单纯记录一下自己调试的一个程序,对于初学深度学习的小白可以参考,如有不对,请多多见谅!!! 一.CIFAR-10数据集由10个类的60000个32 ...

  3. [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  4. 使用caffe自带模型训练cifar10数据集

      前面训练了mnist数据集!但caffe自带的数据集还有cifar10数据集.同样cifar10数据集也是分类数据集,共分10类.cifar10数据集中包含60000张32x32的彩色图片.(其中 ...

  5. 深度学习:使用pytorch训练cifar10数据集(基于Lenet网络)

    文档基于b站视频:https://www.bilibili.com/video/BV187411T7Ye 流程 model.py --定义LeNet网络模型 train.py --加载数据集并训练,训 ...

  6. keras笔记(4)-使用Keras训练大规模数据集

    简介 官方提供的.flow_from_directory(directory)函数可以读取并训练大规模训练数据,基本可以满足大部分需求,可以参考我的笔记.但是在有些场合下,需要自己读取大规模数据以及对 ...

  7. LeNet训练Cifar-10数据集代码详解以及输出结果

    首先讲一下交叉熵损失函数,里面包含了Softmax函数和NLL损失函数 接下来讲一下NLL损失函数 Legative Log Likelihood Loss,中文名称是最大似然或者log似然代价函数 ...

  8. MXNet学习:试用卷积-训练CIFAR-10数据集

    第一次用卷积,看的别人的模型跑的CIFAR-10,不过吐槽一下...我觉着我的965m加速之后比我的cpu算起来没快多少..正确率64%的样子,没达到模型里说的75%,不知道问题出在哪里 import ...

  9. 深度学习训练的时候gpu占用0_26秒单GPU训练CIFAR10,Jeff Dean也点赞的深度学习优化技巧...

    选自myrtle.ai 机器之心编译机器之心编辑部 26 秒内用 ResNet 训练 CIFAR10?一块 GPU 也能这么干.近日,myrtle.ai 科学家 David Page 提出了一大堆针对 ...

  10. 【小白学习keras教程】二、基于CIFAR-10数据集训练简单的MLP分类模型

    @Author:Runsen 分类任务的MLP 当目标(y)是离散的(分类的) 对于损失函数,使用交叉熵:对于评估指标,通常使用accuracy 数据集描述 CIFAR-10数据集包含10个类中的60 ...

最新文章

  1. 【Android 安全】DEX 加密 ( DEX 加密使用到的相关工具 | dx 工具 | zipalign 对齐工具 | apksigner 签名工具 )
  2. 二级菜单--竖排---HTML
  3. ajax刷新数据库数据,ajax删除数据刷新数据库
  4. c语言为什么有这么多的编程环境?_为什么98%的程序员学编程都会从C语言开始?...
  5. spring的@ControllerAdvice注解
  6. sudo apt-get常用命令
  7. OpenShift 之 Quarkus(1)创建第一个Quarkus应用
  8. L1-002. 打印沙漏-PAT团体程序设计天梯赛GPLT
  9. 【PDF转换 编辑】 推荐几个好用的pdf相关的网址和软件
  10. 如何用计算机tan角度换算,tan角度换算(tan值求角度计算器)
  11. 【计量经济学】固定效应、随机效应、相关随机效应
  12. 精美的wordpress企业主题模板
  13. win系统服务器白名单,win10系统如何添加白名单 windows10下添加白名单的方法
  14. vue控制台报错Extraneous non-props attributes (class) were passed to component but could not be automatica
  15. Linux怎么同步另一台设备的时间
  16. Python 送你一棵圣诞树
  17. 什么是千行代码缺陷率?
  18. mysql忘记root密码如何重新设置
  19. RH850从0搭建Autosar开发环境【1】- 如何创建Davinci Configurator配置工程
  20. 广告平台的商业模式,行业分析

热门文章

  1. 计算机网络与应用第三次笔记
  2. 问题-[致命错误] Project1.dpr(1): Unit not found: 'System.pas' or binary equivalents (DCU,DPU)
  3. Javascript中相同Function使用多个名称
  4. 小议 - 来自《XX时代XX公司》的笔试编程题目
  5. [转]六步使用ICallbackEventHandler实现无刷新回调
  6. JAVA随机数生成 | Math.random()方法 | 随机生成int、double类型
  7. eclipse maven 打war包的几种方式
  8. jenkins安装与自动部署详细说明
  9. Thingsboard 3.1.0 - REST API
  10. Sql语句查询某列A相同值的另一列B最大值的数据