一、下载数据集并展示

CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。

与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:

• CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。

• CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。

• 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。

import matplotlib.pyplot as plt
import tensorflow as tf
from keras import datasets, layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
%config Completer.use_jedi = False(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']plt.figure(figsize=(10, 10))
for i in range(10):plt.subplot(5, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i][0]])
plt.show()#查看图片信息
print('图片尺寸为:',train_images[0].shape)
print('训练集图片个数为:',len(train_images))
print('测试集图片个数为:',len(test_images))

如果使用代码下载失败,那么去到cifar10数据集下载地址:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz,将下载后的文件存放在 ~./keras/datasets目录下,~表示当前用户路径。

二、构建模型

# 构造网络模型
model = models.Sequential([tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(32, 32, 3)),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(64, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(128, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),#转换为一维tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(10, activation='softmax'),
])# 查看网络结构
model.summary()

三、定义损失函数优化器

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

四、数据增强

注意此处只对训练数据集做随机翻转、随机裁剪、平移等,测试集只需归一化。

train_image = ImageDataGenerator(rescale=1/255,#随机翻转rotation_range=40,#平移width_shift_range=0.2,height_shift_range=0.2,#随机裁剪shear_range=0.2,#随机缩放zoom_range=0.2,horizontal_flip=True,fill_mode='nearest'
)test_image = ImageDataGenerator(rescale=1/255,
)

五、模型训练

history = model.fit(train_images, train_labels, epochs=20,validation_data=(test_images, test_labels))

六、绘制acc

# 测试模型并绘制loss图(history的使用)
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.0, 1.0])
plt.legend(loc='lower right')
plt.show()

注:此处只做流程演示并未调整参数,可以自行优化。

Tensorflow2+训练CIFAR10相关推荐

  1. [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码

    环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...

  2. 深度学习训练的时候gpu占用0_26秒单GPU训练CIFAR10,Jeff Dean也点赞的深度学习优化技巧...

    选自myrtle.ai 机器之心编译机器之心编辑部 26 秒内用 ResNet 训练 CIFAR10?一块 GPU 也能这么干.近日,myrtle.ai 科学家 David Page 提出了一大堆针对 ...

  3. 【深度学习】训练CIFAR-10数据集实现分类加测试

    网上有很多博主写的训练CIFAR-10的代码,本次只是单纯记录一下自己调试的一个程序,对于初学深度学习的小白可以参考,如有不对,请多多见谅!!! 一.CIFAR-10数据集由10个类的60000个32 ...

  4. 图解半监督学习FixMatch,只用10张标注图片训练CIFAR10

    2020-05-25 11:20:08 作者:amitness 编译:ronghuaiyang 导读 仅使用10张带有标签的图像,它在CIFAR-10上的中位精度为78%,最大精度为84%,来看看是怎 ...

  5. tensorflow2 训练和预测使用不同的输出层、获取权重参数

    目标: youtubeNet通过训练tensorflow2时设置不同的激活函数,训练和预测采用不同的分支,然后可以在训练和测试时,把模型进行分离,得到训练和预测时,某些层的参数不同.可以通过类似迁移学 ...

  6. 使用caffe自带模型训练cifar10数据集

      前面训练了mnist数据集!但caffe自带的数据集还有cifar10数据集.同样cifar10数据集也是分类数据集,共分10类.cifar10数据集中包含60000张32x32的彩色图片.(其中 ...

  7. matlab训练cifar10,认识CIFAR-10数据集

    CIFAR-10是一个广泛使用的标准数据集,里面包含了各种阿猫阿狗阿汽车--为了在后续学习实验中用好它,首先需要认识了解一下. 把tensorflow官方model下的cifar10文件复制到工作区, ...

  8. 深度学习:使用pytorch训练cifar10数据集(基于Lenet网络)

    文档基于b站视频:https://www.bilibili.com/video/BV187411T7Ye 流程 model.py --定义LeNet网络模型 train.py --加载数据集并训练,训 ...

  9. 现代卷积神经网络(NiN),并使用NIN训练CIFAR10的分类

    专栏:神经网络复现目录 本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet). ...

最新文章

  1. Gradle 简单使用
  2. 误删/var/lib/dpkg/info,文件解决方案(是否完全解决,不确定)
  3. offer该怎么选:大公司or小公司?高薪or期权?
  4. java development kie_java – 直接从存储库加载Drools/KIE Workbench工件
  5. Centos 6 搭建安装 Gitlab
  6. jdbc:mysql:replication_使用Mysql的Replication功能实现数据库同步
  7. 003.ASP.NET MVC集中管理Session
  8. 这个小众副业,一次200,有人月入3万!
  9. 乡镇政府网络智能办公系统(乡镇OA)应用【乡镇信息化经验】
  10. 计算机配件价格上涨,显卡涨价风声再起 PC配件涨价什么时候是个头
  11. 单目标跟踪——个人笔记
  12. 基于寒武纪CNCodec 做视频编解码遇到的一些问题
  13. [ZZ]AppiumForWindows 菜鸟计划合集
  14. 怎么在PDF上修改文字,PDF修改文字的步骤
  15. 【iot-manager】(1)IOT商业化和未来竞争、淘汰还在继续,需要折腾一个开源IOT系统,参考Rancher开源Octopus:IoT设备管理系统做一个物联网管理系统
  16. No connection could be made because the target machine actively refused it 127.0.0.1:8888
  17. props传递对象_vue组件中使用props传递数据的实例详解
  18. python中true是什么意思_Python解惑之True和False详解
  19. Java for Web学习笔记(三五):自定义tag(3)TLDS和Tag Handler
  20. 双 JK 触发器 74LS112 逻辑功能。真值表_D触发器示例

热门文章

  1. android手写汉字,Android 手写输入的实现(保存涂鸦文字)
  2. WinMount开发者刘涛涛
  3. SecureCRT 5.1.3 及注册码
  4. eclipse中文版本转英文版
  5. c++获取计算机注册码,在c++中,如何能获得计算机的机器码?
  6. 微软飞行模拟器android,微软飞行模拟器2020
  7. COOX基础培训之SCADA(三)
  8. 按字母A——Z排列的中国城市(地级市)json数据
  9. flink onTimer定时器实现定时需求
  10. 【cocos2d-x 手游研发----怪物智能AI】