Tensorflow2+训练CIFAR10
一、下载数据集并展示
CIFAR-10 是由 Hinton 的学生 Alex Krizhevsky 和 Ilya Sutskever 整理的一个用于识别普适物体的小型数据集。一共包含 10 个类别的 RGB 彩色图 片:飞机( a叩lane )、汽车( automobile )、鸟类( bird )、猫( cat )、鹿( deer )、狗( dog )、蛙类( frog )、马( horse )、船( ship )和卡车( truck )。图片的尺寸为 32×32 ,数据集中一共有 50000 张训练圄片和 10000 张测试图片。 CIFAR-10 的图片样例如图所示。
与 MNIST 数据集中目比, CIFAR-10 具有以下不同点:
• CIFAR-10 是 3 通道的彩色 RGB 图像,而 MNIST 是灰度图像。
• CIFAR-10 的图片尺寸为 32×32, 而 MNIST 的图片尺寸为 28×28,比 MNIST 稍大。
• 相比于手写字符, CIFAR-10 含有的是现实世界中真实的物体,不仅噪声很大,而且物体的比例、 特征都不尽相同,这为识别带来很大困难。
import matplotlib.pyplot as plt
import tensorflow as tf
from keras import datasets, layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
%config Completer.use_jedi = False(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer','dog', 'frog', 'horse', 'ship', 'truck']plt.figure(figsize=(10, 10))
for i in range(10):plt.subplot(5, 5, i + 1)plt.xticks([])plt.yticks([])plt.grid(False)plt.imshow(train_images[i], cmap=plt.cm.binary)plt.xlabel(class_names[train_labels[i][0]])
plt.show()#查看图片信息
print('图片尺寸为:',train_images[0].shape)
print('训练集图片个数为:',len(train_images))
print('测试集图片个数为:',len(test_images))
如果使用代码下载失败,那么去到cifar10数据集下载地址:https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz,将下载后的文件存放在 ~./keras/datasets目录下,~表示当前用户路径。
二、构建模型
# 构造网络模型
model = models.Sequential([tf.keras.layers.Conv2D(32, (3,3), activation='relu', input_shape=(32, 32, 3)),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(64, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),tf.keras.layers.Conv2D(128, (3,3), activation='relu'),tf.keras.layers.MaxPooling2D(2,2),#转换为一维tf.keras.layers.Flatten(),tf.keras.layers.Dense(128, activation='relu'),tf.keras.layers.Dropout(0.5),tf.keras.layers.Dense(10, activation='softmax'),
])# 查看网络结构
model.summary()
三、定义损失函数优化器
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])
四、数据增强
注意此处只对训练数据集做随机翻转、随机裁剪、平移等,测试集只需归一化。
train_image = ImageDataGenerator(rescale=1/255,#随机翻转rotation_range=40,#平移width_shift_range=0.2,height_shift_range=0.2,#随机裁剪shear_range=0.2,#随机缩放zoom_range=0.2,horizontal_flip=True,fill_mode='nearest'
)test_image = ImageDataGenerator(rescale=1/255,
)
五、模型训练
history = model.fit(train_images, train_labels, epochs=20,validation_data=(test_images, test_labels))
六、绘制acc
# 测试模型并绘制loss图(history的使用)
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.ylim([0.0, 1.0])
plt.legend(loc='lower right')
plt.show()
注:此处只做流程演示并未调整参数,可以自行优化。
Tensorflow2+训练CIFAR10相关推荐
- [深度学习-实践]Tensorflow 2.x应用ResNet SeNet网络训练cifar10数据集的模型在测试集上准确率 86%-87%,含完整代码
环境 tensorflow 2.1 最好用GPU Cifar10数据集 CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题.任务的目标对一组32x32 RGB的图像进行分类,这个数据集涵 ...
- 深度学习训练的时候gpu占用0_26秒单GPU训练CIFAR10,Jeff Dean也点赞的深度学习优化技巧...
选自myrtle.ai 机器之心编译机器之心编辑部 26 秒内用 ResNet 训练 CIFAR10?一块 GPU 也能这么干.近日,myrtle.ai 科学家 David Page 提出了一大堆针对 ...
- 【深度学习】训练CIFAR-10数据集实现分类加测试
网上有很多博主写的训练CIFAR-10的代码,本次只是单纯记录一下自己调试的一个程序,对于初学深度学习的小白可以参考,如有不对,请多多见谅!!! 一.CIFAR-10数据集由10个类的60000个32 ...
- 图解半监督学习FixMatch,只用10张标注图片训练CIFAR10
2020-05-25 11:20:08 作者:amitness 编译:ronghuaiyang 导读 仅使用10张带有标签的图像,它在CIFAR-10上的中位精度为78%,最大精度为84%,来看看是怎 ...
- tensorflow2 训练和预测使用不同的输出层、获取权重参数
目标: youtubeNet通过训练tensorflow2时设置不同的激活函数,训练和预测采用不同的分支,然后可以在训练和测试时,把模型进行分离,得到训练和预测时,某些层的参数不同.可以通过类似迁移学 ...
- 使用caffe自带模型训练cifar10数据集
前面训练了mnist数据集!但caffe自带的数据集还有cifar10数据集.同样cifar10数据集也是分类数据集,共分10类.cifar10数据集中包含60000张32x32的彩色图片.(其中 ...
- matlab训练cifar10,认识CIFAR-10数据集
CIFAR-10是一个广泛使用的标准数据集,里面包含了各种阿猫阿狗阿汽车--为了在后续学习实验中用好它,首先需要认识了解一下. 把tensorflow官方model下的cifar10文件复制到工作区, ...
- 深度学习:使用pytorch训练cifar10数据集(基于Lenet网络)
文档基于b站视频:https://www.bilibili.com/video/BV187411T7Ye 流程 model.py --定义LeNet网络模型 train.py --加载数据集并训练,训 ...
- 现代卷积神经网络(NiN),并使用NIN训练CIFAR10的分类
专栏:神经网络复现目录 本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet). ...
最新文章
- Gradle 简单使用
- 误删/var/lib/dpkg/info,文件解决方案(是否完全解决,不确定)
- offer该怎么选:大公司or小公司?高薪or期权?
- java development kie_java – 直接从存储库加载Drools/KIE Workbench工件
- Centos 6 搭建安装 Gitlab
- jdbc:mysql:replication_使用Mysql的Replication功能实现数据库同步
- 003.ASP.NET MVC集中管理Session
- 这个小众副业,一次200,有人月入3万!
- 乡镇政府网络智能办公系统(乡镇OA)应用【乡镇信息化经验】
- 计算机配件价格上涨,显卡涨价风声再起 PC配件涨价什么时候是个头
- 单目标跟踪——个人笔记
- 基于寒武纪CNCodec 做视频编解码遇到的一些问题
- [ZZ]AppiumForWindows 菜鸟计划合集
- 怎么在PDF上修改文字,PDF修改文字的步骤
- 【iot-manager】(1)IOT商业化和未来竞争、淘汰还在继续,需要折腾一个开源IOT系统,参考Rancher开源Octopus:IoT设备管理系统做一个物联网管理系统
- No connection could be made because the target machine actively refused it 127.0.0.1:8888
- props传递对象_vue组件中使用props传递数据的实例详解
- python中true是什么意思_Python解惑之True和False详解
- Java for Web学习笔记(三五):自定义tag(3)TLDS和Tag Handler
- 双 JK 触发器 74LS112 逻辑功能。真值表_D触发器示例
热门文章
- android手写汉字,Android 手写输入的实现(保存涂鸦文字)
- WinMount开发者刘涛涛
- SecureCRT 5.1.3 及注册码
- eclipse中文版本转英文版
- c++获取计算机注册码,在c++中,如何能获得计算机的机器码?
- 微软飞行模拟器android,微软飞行模拟器2020
- COOX基础培训之SCADA(三)
- 按字母A——Z排列的中国城市(地级市)json数据
- flink onTimer定时器实现定时需求
- 【cocos2d-x 手游研发----怪物智能AI】