Building a CIFAR-10 Training and Prediction Pipeline with Keras

By 钢笔先生, 2019.01.17 22:52:05

Pipeline

The training data comes straight from keras.datasets.cifar10.load_data(), and the model is built with the Sequential API.

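The main point of interest is how to use the trained model for actual prediction, which involves a few details worth watching. In addition, Keras's ImageDataGenerator supplies the input data during training; I summarized this class in an earlier post, together with a way to load raw images from the file system.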
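To give a feel for what ImageDataGenerator does to each training image here, the sketch below imitates the two augmentations the training script turns on: random shifts of up to 10% of the width and height with nearest-edge filling (width_shift_range=0.1, height_shift_range=0.1, fill_mode='nearest') and random horizontal flips (horizontal_flip=True). It is a simplified numpy stand-in that assumes whole-pixel shifts; random_shift_flip is a hypothetical name, and Keras itself applies these as affine transforms, so its output will differ in detail.

```python
import numpy as np


def random_shift_flip(image, max_shift=0.1, rng=None):
    """Hypothetical, simplified stand-in for shift + horizontal-flip augmentation.

    image: array of shape (height, width, channels).
    Shifts by a random whole number of pixels (at most max_shift * size in each
    direction), fills the uncovered border by repeating the nearest edge pixel
    (roughly like fill_mode='nearest'), then flips left-right with probability 0.5.
    """
    rng = np.random.default_rng() if rng is None else rng
    h, w = image.shape[:2]
    max_dy, max_dx = int(max_shift * h), int(max_shift * w)
    dy = int(rng.integers(-max_dy, max_dy + 1))
    dx = int(rng.integers(-max_dx, max_dx + 1))
    # For every output pixel, pick the source row/column; clipping to the valid
    # range repeats the edge pixels where the shift uncovers the border
    rows = np.clip(np.arange(h) - dy, 0, h - 1)
    cols = np.clip(np.arange(w) - dx, 0, w - 1)
    out = image[rows][:, cols]
    if rng.random() < 0.5:
        out = out[:, ::-1]  # horizontal flip
    return out


# Usage on a dummy 32x32 RGB image
img = np.arange(32 * 32 * 3, dtype=np.float32).reshape(32, 32, 3)
augmented = random_shift_flip(img, rng=np.random.default_rng(0))
print(augmented.shape)  # (32, 32, 3)
```

For a 32x32 CIFAR-10 image, max_shift=0.1 allows shifts of at most 3 pixels in each direction.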

Building the model

from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
import os

# Hyperparameters
batch_size = 32
num_classes = 10
epochs = 50
data_augmentation = True  # use real-time data augmentation
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices (one-hot encoding).
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
# Build the model
model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same', input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# Initiate the RMSprop optimizer (use the class name RMSprop, which works across Keras versions)
opt = keras.optimizers.RMSprop(lr=0.0001, decay=1e-6)

# Compile the model with RMSprop
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])

# Scale pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
# Without data augmentation
if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
              validation_data=(x_test, y_test), shuffle=True)
# With data augmentation
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and real-time data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        fill_mode='nearest',  # how to fill points outside the input boundaries
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # randomly flip images vertically
        rescale=None,  # rescaling factor (applied before any other transformation)
        preprocessing_function=None,  # function applied to each input
        data_format=None,  # image data format, "channels_first" or "channels_last"
        validation_split=0.0)  # fraction of images reserved for validation (strictly between 0 and 1)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    # (fit_generator is the Keras 2.x API; newer versions take the generator in model.fit.)
    history = model.fit_generator(datagen.flow(x_train, y_train, batch_size=batch_size),
                                  epochs=epochs,
                                  steps_per_epoch=600,
                                  validation_data=(x_test, y_test),
                                  validation_steps=10,
                                  workers=4)
# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
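Two details of the script are easy to miss. keras.utils.to_categorical turns the (n, 1) integer labels into one-hot rows, and steps_per_epoch=600 with batch_size=32 means each augmented epoch draws 600 * 32 = 19,200 images, less than one pass over the 50,000 CIFAR-10 training images. The plain-numpy sketch below mirrors both; to_one_hot and scale_pixels are hypothetical helper names, not Keras functions.

```python
import numpy as np

num_classes = 10
batch_size = 32
n_train = 50000  # CIFAR-10 has 50,000 training images


def to_one_hot(labels, num_classes):
    """Plain-numpy version of the one-hot encoding keras.utils.to_categorical does."""
    labels = np.asarray(labels, dtype=int).ravel()  # cifar10 labels have shape (n, 1)
    return np.eye(num_classes, dtype=np.float32)[labels]


def scale_pixels(images):
    """Same scaling as the script: uint8 values in [0, 255] -> float32 in [0, 1]."""
    return images.astype('float32') / 255


y = np.array([[6], [9], [9]])  # made-up labels in the (n, 1) layout of cifar10
print(to_one_hot(y, num_classes).argmax(axis=1))  # [6 9 9]

# Images seen per epoch with steps_per_epoch=600, versus one full pass
print(600 * batch_size)       # 19200
print(n_train // batch_size)  # 1562 steps for (almost) a full pass
```

If a full pass per epoch is wanted, steps_per_epoch=len(x_train) // batch_size is the usual choice.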

After training, the model is saved as keras_cifar10_trained_model.h5 in the saved_models directory.

Using the trained model

# Load the model together with its trained weights
from keras.models import load_model

model = load_model('./saved_models/keras_cifar10_trained_model.h5')
model.summary()
'''
Layer (type)                 Output Shape              Param #
=================================================================
conv2d_9 (Conv2D)            (None, 32, 32, 32)        896
_________________________________________________________________
activation_13 (Activation)   (None, 32, 32, 32)        0
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 30, 30, 32)        9248
_________________________________________________________________
activation_14 (Activation)   (None, 30, 30, 32)        0
_________________________________________________________________
max_pooling2d_5 (MaxPooling2 (None, 15, 15, 32)        0
_________________________________________________________________
dropout_7 (Dropout)          (None, 15, 15, 32)        0
_________________________________________________________________
conv2d_11 (Conv2D)           (None, 15, 15, 64)        18496
_________________________________________________________________
activation_15 (Activation)   (None, 15, 15, 64)        0
_________________________________________________________________
conv2d_12 (Conv2D)           (None, 13, 13, 64)        36928
_________________________________________________________________
activation_16 (Activation)   (None, 13, 13, 64)        0
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 6, 6, 64)          0
_________________________________________________________________
dropout_8 (Dropout)          (None, 6, 6, 64)          0
_________________________________________________________________
flatten_3 (Flatten)          (None, 2304)              0
_________________________________________________________________
dense_5 (Dense)              (None, 512)               1180160
_________________________________________________________________
activation_17 (Activation)   (None, 512)               0
_________________________________________________________________
dropout_9 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_6 (Dense)              (None, 10)                5130
_________________________________________________________________
activation_18 (Activation)   (None, 10)                0
=================================================================
Total params: 1,250,858
Trainable params: 1,250,858
Non-trainable params: 0
'''
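The parameter counts in the summary can be checked by hand: a Conv2D layer has kernel_h * kernel_w * in_channels * filters weights plus one bias per filter, a Dense layer has inputs * units + units, 'valid' padding shrinks each side by kernel size minus one, and 'same' keeps it. A small pure-Python sketch, with the layer list transcribed from the model above and helper names made up for illustration:

```python
def conv2d(shape, filters, k=3, padding='valid'):
    """Output shape and parameter count of a Conv2D layer with stride 1."""
    h, w, c = shape
    if padding == 'valid':
        h, w = h - k + 1, w - k + 1  # 'same' keeps height and width unchanged
    return (h, w, filters), k * k * c * filters + filters  # weights + one bias per filter


def max_pool(shape, pool=2):
    """MaxPooling2D with stride = pool size and 'valid' padding (floors odd sizes)."""
    h, w, c = shape
    return (h // pool, w // pool, c), 0


def dense(n_in, units):
    """Parameter count of a Dense layer: weight matrix + bias vector."""
    return n_in * units + units


def count_params(input_shape=(32, 32, 3), num_classes=10):
    # Layer stack transcribed from the Sequential model above
    # (Activation, Dropout and Flatten layers have no parameters)
    shape, total = input_shape, 0
    for kind, filters, padding in [('conv', 32, 'same'), ('conv', 32, 'valid'), ('pool', None, None),
                                   ('conv', 64, 'same'), ('conv', 64, 'valid'), ('pool', None, None)]:
        if kind == 'conv':
            shape, params = conv2d(shape, filters, padding=padding)
        else:
            shape, params = max_pool(shape)
        total += params
    flat = shape[0] * shape[1] * shape[2]  # size after Flatten
    total += dense(flat, 512) + dense(512, num_classes)
    return shape, flat, total


print(count_params())  # ((6, 6, 64), 2304, 1250858)
```

The result matches the summary's Flatten size (2304) and total parameter count (1,250,858); note how the 13x13 map is floored to 6x6 by the second pooling layer.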

Classifying a test-set image

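# CIFAR-10 class names, indexed by label id
lst = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

def onehot_to_label(res):
    # Map a one-hot row (e.g. from y_test) back to its class name
    label = ''
    for i in range(len(res[0])):
        if res[0][i] == 1:
            label = lst[i]
    return label

def softmax_to_label(res):
    # Take the class with the highest predicted probability
    index = res[0].argmax()
    return lst[index]

# Classify a test-set image (x_test must already be scaled to [0, 1], as in training)
test_image = x_test[100].reshape([1, 32, 32, 3])
print(test_image.shape)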
res = model.predict(test_image)
label = softmax_to_label(res)
print(label)
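softmax_to_label only reads res[0], the first row of the prediction. model.predict returns one row of 10 probabilities per input image, so for a batch it can be handier to decode every row at once and also look at the runner-up classes. A minimal numpy sketch; the probabilities are made up rather than real model output, and top_k_labels is a hypothetical helper.

```python
import numpy as np

lst = ['airplane', 'automobile', 'bird', 'cat', 'deer',
       'dog', 'frog', 'horse', 'ship', 'truck']


def top_k_labels(probs, k=3):
    """For each row of softmax output (shape (batch, 10)), return the k most
    likely (label, probability) pairs, highest probability first."""
    probs = np.asarray(probs)
    order = np.argsort(probs, axis=1)[:, ::-1][:, :k]  # indices by descending probability
    return [[(lst[i], float(row[i])) for i in idx] for row, idx in zip(probs, order)]


# Made-up probabilities for a batch of two images (each row sums to 1)
fake_res = np.array([
    [0.70, 0.05, 0.05, 0.02, 0.02, 0.02, 0.02, 0.02, 0.08, 0.02],
    [0.01, 0.01, 0.01, 0.50, 0.01, 0.40, 0.02, 0.02, 0.01, 0.01],
])
for row in top_k_labels(fake_res, k=2):
    print(row)
```

Looking at the second-best class helps spot the confusions that are common on CIFAR-10, such as cat versus dog.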

Classifying a local image

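# Load a raw image from disk and classify it
from PIL import Image
from keras.preprocessing.image import img_to_array
import numpy as np

image = Image.open('./images/airplane.jpeg').convert('RGB')  # load the image; force 3 channels
image = image.resize((32, 32))
image = img_to_array(image) / 255.0  # scale to [0, 1] like the training data

# Once loaded, run the prediction
image = image.reshape([1, 32, 32, 3])  # must be reshaped to a 4-D tensor
res = model.predict(image)
label = softmax_to_label(res)
print("The image is: ", label)

# Or wrap the steps in a function
def image_to_array(path):
    image = Image.open(path).convert('RGB')
    image = image.resize((32, 32), Image.NEAREST)  # scales the whole image to the target size; it does not crop
    image = img_to_array(image) / 255.0  # convert to an array scaled to [0, 1]
    image = image.reshape([1, 32, 32, 3])  # reshape to a 4-D tensor
    return image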
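The network only ever saw float32 batches of shape (n, 32, 32, 3) with values in [0, 1]. A raw 0-255 image is accepted silently by model.predict and simply yields unreliable labels, so it may be worth checking inputs before predicting. check_model_input below is a hypothetical guard written for this post, not part of Keras; the thresholds assume the /255 scaling used in training.

```python
import numpy as np


def check_model_input(batch, size=32):
    """Hypothetical guard to run before model.predict.

    The network was trained on float32 batches of shape (n, 32, 32, 3) scaled
    to [0, 1], so a prediction batch should match that.
    """
    batch = np.asarray(batch, dtype=np.float32)
    if batch.ndim != 4 or batch.shape[1:] != (size, size, 3):
        raise ValueError('expected shape (n, %d, %d, 3), got %s' % (size, size, batch.shape))
    if batch.min() < 0.0 or batch.max() > 1.0:
        raise ValueError('pixel values outside [0, 1]; divide raw pixels by 255 first')
    return batch


# A correctly scaled batch of one passes; a raw 0-255 image is rejected
ok = check_model_input(np.random.rand(1, 32, 32, 3))
try:
    check_model_input(np.full((1, 32, 32, 3), 200.0))
except ValueError as err:
    print(err)
```

Failing loudly here is a design choice: silently rescaling would hide the mistake in whatever code produced the batch.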

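Note that the network's input is a tensor and must be reshaped to four dimensions (batch, height, width, channels). During training, images are fed in batches, which are 4-D; a single image is fed the same way, as a batch of size one.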
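The reshape to [1, 32, 32, 3] just adds a batch axis of size one. A short numpy sketch, using random arrays as stand-ins for preprocessed images, shows that np.expand_dims gives the same result and that np.stack builds a batch from several images, which could then be passed to model.predict in a single call.

```python
import numpy as np

# Stand-ins for preprocessed images: shape (32, 32, 3), values already in [0, 1]
rng = np.random.default_rng(0)
img_a = rng.random((32, 32, 3), dtype=np.float32)
img_b = rng.random((32, 32, 3), dtype=np.float32)

# A single image becomes a batch of one; these two forms are equivalent
single_1 = img_a.reshape([1, 32, 32, 3])
single_2 = np.expand_dims(img_a, axis=0)
print(single_1.shape, np.array_equal(single_1, single_2))  # (1, 32, 32, 3) True

# Several images are stacked along a new first axis into one batch;
# model.predict(batch) would then return one row of 10 probabilities per image
batch = np.stack([img_a, img_b], axis=0)
print(batch.shape)  # (2, 32, 32, 3)
```

np.expand_dims has the advantage of not hard-coding the image size, so the same line keeps working if the input resolution changes.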
