环境

tensorflow 2.1
最好用GPU

Cifar10数据集

CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题。任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵盖了10个类别:飞机, 汽车, 鸟, 猫, 鹿, 狗, 青蛙, 马, 船以及卡车。

下面代码仅仅只是做显示Cifar10数据集用

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tfdef showPic(X_train, y_train):# 看看数据集中的一些样本:每个类别展示一些classes = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']num_classes = len(classes)samples_per_class = 7for y, cls in enumerate(classes):idxs = np.flatnonzero(y_train == y)# 一个类别中挑出一些idxs = np.random.choice(idxs, samples_per_class, replace=False)for i, idx in enumerate(idxs):plt_idx = i * num_classes + y + 1plt.subplot(samples_per_class, num_classes, plt_idx)plt.imshow(X_train[idx].astype('uint8'))plt.axis('off')if i == 0:plt.title(cls)plt.show()
if __name__ == '__main__':(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()showPic(x_train, y_train)

模型

ResNet
SeNet
用ResNet 和SeNet网络训练Cifar10

训练数据:Cifar10
训练集上准确率:96%左右
验证集上准确率:87%左右
测试集上准确率86%-87%
训练时间在GPU上:一小时多
权重大小:5.08 MB

训练的历程

普通网络(65%左右)-> 数据增强(70%左右)->模型增强(进入ResNet 和SeNet) 80%左右 -> 模型的结构做了调整(86%)
开始的时候我也用tensorlfow 1.4训练过Cifar10. 但是没有跑出理想的准确率,总是在70%左右。后来也没有想过模型上增强直接跳到tensorflow2.1了。

下一步准备加入inception网络试一试结果如何

训练集上和验证集上训练结果

343/351 [============================>.] - ETA: 0s - loss: 0.1013 - sparse_categorical_accuracy: 0.9641
345/351 [============================>.] - ETA: 0s - loss: 0.1012 - sparse_categorical_accuracy: 0.9641
347/351 [============================>.] - ETA: 0s - loss: 0.1011 - sparse_categorical_accuracy: 0.9642
349/351 [============================>.] - ETA: 0s - loss: 0.1010 - sparse_categorical_accuracy: 0.9643
351/351 [==============================] - 15s 44ms/step - loss: 0.1008 - sparse_categorical_accuracy: 0.9643 - val_loss: 0.5324 - val_sparse_categorical_accuracy: 0.8682

下面是测试集上的结果

79/79 - 2s - loss: 0.4225 - sparse_categorical_accuracy: 0.8708
[0.42247119277149814, 0.8708]

下面是完整的代码,运行前建一下这个目录weights3_6,不想写代码自动化建了。
如果要训练Cifar100,直接把cifar10 改成cifar100就可以了。不需要改其它地方

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import time as time
import tensorflow.keras.preprocessing.image as image
import matplotlib.pyplot as plt
import osdef senet_block(inputs, ratio):shape = inputs.shapechannel_out = shape[-1]# print(shape)# (2, 28, 28, 32) , [1,28,28,1], [1,28,28,1]squeeze = layers.GlobalAveragePooling2D()(inputs)# [2, 1, 1, 32]# print(squeeze.shape)# 第二层,全连接层# [2,32]# print(squeeze.shape)shape_result = layers.Flatten()(squeeze)# print(shape_result.shape)# [32,2]shape_result = layers.Dense(int(channel_out / ratio), activation='relu')(shape_result)# shape_result = layers.BatchNormalization()(shape_result)# [2,32]shape_result = layers.Dense(channel_out, activation='sigmoid')(shape_result)# shape_result = layers.BatchNormalization()(shape_result)# 第四层,点乘# print('heres2')excitation_output = tf.reshape(shape_result, [-1, 1, 1, channel_out])# print(excitation_output.shape)h_output = excitation_output * inputsreturn h_outputdef res_block(input, input_filter, output_filter):res_x = layers.Conv2D(filters=output_filter, kernel_size=(3, 3), activation='relu', padding='same')(input)res_x = layers.BatchNormalization()(res_x )res_x = senet_block(res_x, 8)res_x = layers.Conv2D(filters=output_filter, kernel_size=(3, 3), activation=None, padding='same')(res_x )res_x = layers.BatchNormalization()(res_x )res_x = senet_block(res_x, 8)if input_filter == output_filter:identity = inputelse: #需要升维或者降维identity = layers.Conv2D(filters=output_filter, kernel_size=(1,1), padding='same')(input)x = layers.Add()([identity, res_x])output = layers.Activation('relu')(x)return outputdef my_model():inputs = keras.Input(shape=(32,32,3), name='img')h1 = layers.Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1), padding='same', activation='relu')(inputs)h1 = layers.BatchNormalization()(h1)h1 = senet_block(h1, 8)block1_out = res_block(h1, 16, 32)block1_out = layers.MaxPool2D(pool_size=(2, 2))(block1_out)# Resnet blockblock2_out = res_block(block1_out, 32,64)block2_out = layers.MaxPool2D(pool_size=(2, 2))(block2_out)block3_out = res_block(block2_out, 64, 128)block4_out = layers.MaxPool2D(pool_size=(2, 2))(block3_out)block4_out = res_block(block4_out, 128, 256)h3 = layers.GlobalAveragePooling2D()(block4_out)h3 = layers.Flatten()(h3)h3 = layers.BatchNormalization()(h3)h3 = layers.Dense(64, activation='relu')(h3)h3 = layers.BatchNormalization()(h3)outputs = layers.Dense(10, activation='softmax')(h3)deep_model = keras.Model(inputs, outputs, name='resnet')deep_model.compile(optimizer=keras.optimizers.Adam(),loss=keras.losses.SparseCategoricalCrossentropy(),#metrics=['accuracy'])metrics=[keras.metrics.SparseCategoricalAccuracy()])deep_model.summary()#keras.utils.plot_model(deep_model, 'my_resNet.png', show_shapes=True)return deep_modelcurrent_max_loss = 9999def train_my_model(deep_model):(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()train_datagen = image.ImageDataGenerator(rescale=1 / 255,rotation_range=40,  # 角度值,0-180.表示图像随机旋转的角度范围width_shift_range=0.2,  # 平移比例,下同height_shift_range=0.2,shear_range=0.2,  # 随机错切变换角度zoom_range=0.2,  # 随即缩放比例horizontal_flip=True,  # 随机将一半图像水平翻转fill_mode='nearest'  # 填充新创建像素的方法)test_datagen = image.ImageDataGenerator(rescale=1 / 255)validation_datagen = image.ImageDataGenerator(rescale=1 / 255)train_generator = train_datagen.flow(x_train[:45000], y_train[:45000], batch_size=128)# train_generator = train_datagen.flow(x_train, y_train, batch_size=128)validation_generator = validation_datagen.flow(x_train[45000:], y_train[45000:], batch_size=128)test_generator = test_datagen.flow(x_test, y_test, batch_size=128)begin_time = time.time()if os.path.isfile('./weights3_6/model.h5'):print('load weight')deep_model.load_weights('./weights3_6/model.h5')def save_weight(epoch, logs):global current_max_lossif(logs['val_loss'] is not None and logs['val_loss']< current_max_loss):current_max_loss = logs['val_loss']print('save_weight', epoch, current_max_loss)deep_model.save_weights('./weights3_6/model.h5')batch_print_callback = keras.callbacks.LambdaCallback(on_epoch_end=save_weight)callbacks = [tf.keras.callbacks.EarlyStopping(patience=4, monitor='loss'),batch_print_callback,# keras.callbacks.ModelCheckpoint('./weights/model.h5', save_best_only=True),tf.keras.callbacks.TensorBoard(log_dir='logs3_6')]print(train_generator[0][0].shape)history = deep_model.fit_generator(train_generator, steps_per_epoch=351, epochs=200, callbacks=callbacks,validation_data=validation_generator, validation_steps=39, initial_epoch = 0)result = deep_model.evaluate_generator(test_generator, verbose=2)print(result)print('time', time.time() - begin_time)def show_result(history):plt.plot(history.history['loss'])plt.plot(history.history['val_loss'])plt.plot(history.history['sparse_categorical_accuracy'])plt.plot(history.history['val_sparse_categorical_accuracy'])plt.legend(['loss', 'val_loss', 'sparse_categorical_accuracy', 'val_sparse_categorical_accuracy'],loc='upper left')plt.show()print(history)show_result(history)def predict_module(deep_model):(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()import numpy as npif os.path.isfile('./weights3_6/model.h5'):print('load weight')deep_model.load_weights('./weights3_6/model.h5')print(y_test[0:20])for i in range(20):img = x_test[i][np.newaxis, :]/255y_ = deep_model.predict(img)v  = np.argmax(y_)print(v, y_test[i])def test_module(deep_model):(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()test_datagen = image.ImageDataGenerator(rescale=1 / 255)test_generator = test_datagen.flow(x_test, y_test, batch_size=128)begin_time = time.time()if os.path.isfile('./weights3_6/model.h5'):print('load weight')deep_model.load_weights('./weights3_6/model.h5')result = deep_model.evaluate_generator(test_generator, verbose=2)print(result)print('time', time.time() - begin_time)if __name__ == '__main__':deep_model = my_model()train_my_model(deep_model)#predict_module(deep_model)#test_module(deep_model)

参考训练结果

执行下面命令,访问http://localhost:6006/查看训练过程中,准确率和损失函数变化过程,

tensorboard --logdir=logs3_6


[深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码相关推荐

  1. [深度学习-TF2实践]应用Tensorflow2.x训练ResNet,SeNet和Inception模型在cifar10,测试集上准确率88.6%

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  2. 【PyTorch】深度学习实践之CNN高级篇——实现复杂网络

    本文目录 1. 串行的网络结构 2. GoogLeNet 2.1 结构分析 2.2 代码实现 2.3 结果 3. ResNet 3.1 网络分析 3.2 代码实现 3.3 结果 课后练习1:阅读并实现 ...

  3. 【深度学习】MLP/LeNet/AlexNet/GoogLeNet/ResNet在三个不同数据集上的分类效果实践

    本文是深度学习课程的实验报告 使用了MLP/LeNet/AlexNet/GoogLeNet/ResNet五个深度神经网络模型结构和MNIST.Fashion MNIST.HWDB1三个不同的数据集,所 ...

  4. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(CIFAR10 数据集篇)!

    今天我们将使用 Pytorch 来继续实现 LeNet-5 模型,并用它来解决 CIFAR10 数据集的识别. 正文开始! 二.使用LeNet-5网络结构创建CIFAR-10识别分类器 LeNet-5 ...

  5. 【深度学习】训练CIFAR-10数据集实现分类加测试

    网上有很多博主写的训练CIFAR-10的代码,本次只是单纯记录一下自己调试的一个程序,对于初学深度学习的小白可以参考,如有不对,请多多见谅!!! 一.CIFAR-10数据集由10个类的60000个32 ...

  6. 《解析卷积神经网络—深度学习实践手册》—学习笔记

    书籍链接 百度网盘 谷歌云盘 绪论 机器学习是人工智能的一个分支,它致力于研究如何通过计算的手段,利用经验(experience)来改善计算机系统自身的性能.通过从经验中获取知识(knowledge) ...

  7. 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...

  8. 在浏览器中进行深度学习:TensorFlow.js (十二)异常检测算法

    2019独角兽企业重金招聘Python工程师标准>>> 异常检测是机器学习领域常见的应用场景,例如金融领域里的信用卡欺诈,企业安全领域里的非法入侵,IT运维里预测设备的维护时间点等. ...

  9. 《PyTorch 深度学习实践》第10讲 卷积神经网络(基础篇)

    文章目录 1 卷积层 1.1 torch.nn.Conv2d相关参数 1.2 填充:padding 1.3 步长:stride 2 最大池化层 3 手写数字识别 该专栏内容为对该视频的学习记录:[&l ...

最新文章

  1. C++ Char int string关系
  2. CComboBox 类详细说明
  3. pod install 失败 Сocoapods trunk URL couldn't be downloaded
  4. ERP系统模块完全解析──工作中心
  5. ORACLE定时备份
  6. 3、oracle数据库的语法基础
  7. WinForm中DataGridView的TextBoxColumm换行
  8. java去掉的行_Java实现去掉每行的行号
  9. 字节跳动2019春招笔试——找零(JavaScript)
  10. linux之tar使用技巧
  11. sql关联查询子表的第一条_SQLAlchemy(8)惰性查询
  12. 如何更改java应用程序标题栏默认图标
  13. Java——包装器类
  14. 黑客技术之初学者编程入门
  15. CTC算法论文阅读笔记:Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurren
  16. 81、【backtrader基金策略】如果每周定投一次,在周几定投收益率更高?
  17. win7计算机开机启动项设置,如何设置WIN7开机启动项?
  18. iOS UDID与UUID
  19. 关于动物识别论文的阅读笔记——青鳉鱼的个体识别和“面部反转效应”
  20. IDEA 2019注册码(激活码)

热门文章

  1. dojo 加载自定义module的路径问题
  2. Linux下查看物理CPU、逻辑CPU和CPU核数
  3. 你好,了解一下Java 14带来的一系列新功能
  4. 行为设计模式 - 责任链设计模式
  5. JAVA创建一个私有域_使用java基础反射访问私有域、方法和构造函数
  6. 04737 c++ 自学考试2019版 第五章程序设计题 1
  7. zsh关于.zprofile .zlogin .zshrc .zshenv文件中环境变量的加载
  8. 【HTML】HTML5中的Web Notification桌面通知
  9. python使用md5加密_如何使用Python构建加密机器人并将其连接到Facebook Messenger
  10. 领略ES10的新功能