第一步:读入数据

# 导入必要的库
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasetsimport numpy as np
import matplotlib.pyplot as plt# 获取数据集
(x, x_lable), (y, y_lable) = datasets.cifar10.load_data()

对数据进行归一化处理可以使得数据运算速度加快,同时减少异常数据带来的影响。本次数据集为图片,分布范围为0~255,仅需要将每一个数值除以255即可将数据集归纳到0~1之间。

x = x/255
y = y/255

第二步:设置回调函数

在模型训练过程中,我们无法对模型进行相关性的操作,此时就需要使用到tensflow中的回调函数了。

我们可以指定一个很大的epoch(训练轮数),当验证集的损失值在一定次数内都没有降低时,代表模型已经运行到了最优值附近,当前的学习率已经无法使得梯度继续下降了,此时就通过回调函数终止模型的训练。

# 提前结束训练
earlyStop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=25)

第一个参数代表监控模型训练过程中的参数

第二个参数代表能够容忍模型监控值没有下降的次数。

在模型的训练过程中,有时会出现一个良好的模型,此时我们可以通过回调函数保存这个模型。

# 设置模型保存节点
checkpoint_save_path = '.\\tmp\\model_4.h5'
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,monitor='val_loss', mode='min',save_best_only=True)

第一个参数代表保存模型的路径

第二个参数代表监模型的值,

第三个参数代表保存值减小时更新模型

第四个参数表示仅保存最好的模型。

回调函数中还有很多有用的函数,比如控制学习率下降的函数,都在 tf.keras.callbacks中,有兴趣的可以自行了解。

同时,也可以自定义类,制作满足特定需求的回调函数。例如,提示当前训练伦次

class PrintEpochs(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs):print('当前轮次:', epoch)

第三步:构建模型

模型中有很多超参数,不同的参数都会对模型有一定影响,此处我训练过三种模型,最终选择最好的一种进行了模型的优化。

model_4 = tf.keras.Sequential([tf.keras.Input(shape=(x.shape[1:])),tf.keras.layers.BatchNormalization(),tf.keras.layers.Conv2D(filters=96, kernel_size=3, padding='same', activation='relu'),tf.keras.layers.MaxPooling2D(pool_size=2),tf.keras.layers.Dropout(0.2),tf.keras.layers.BatchNormalization(),tf.keras.layers.Conv2D(filters=64,kernel_size=3, padding='same', activation='relu'),tf.keras.layers.MaxPooling2D(pool_size=2),tf.keras.layers.Dropout(0.2),tf.keras.layers.BatchNormalization(),tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),tf.keras.layers.Conv2D(filters=32, kernel_size=3, padding='same', activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Flatten(),tf.keras.layers.BatchNormalization(),tf.keras.layers.Dense(units=128, activation='relu'),tf.keras.layers.Dropout(0.2),tf.keras.layers.Dense(units=32, activation='relu'),tf.keras.layers.Dropout(0.1),tf.keras.layers.Dense(units=10, activation='softmax')
])

可以用过    model_4.summary()    查看模型结构。

本次模型共有37w的参数,模型相对较小。

第四步:训练模型

# 训练数据
history_4 = model_4.fit(x,x_lable,epochs=200,verbose=1,callbacks=[earlyStop,cp_callback],validation_split=0.2,batch_size=128)

第一个参数代表传入的训练集

第二个参数代表传入的训练集标签

第三个参数代表模型训练的轮次

第四个参数代表输出训练日志

第五个参数代表调用的回调函数

第六个参数代表训练集验证集的比例

第七个参数代表每批训练模型的大小

部分与运行结果:

查看训练好的两个模型:

# 最低损失值数据
new_model = tf.keras.models.load_model(checkpoint_save_path)
test_loss, test_acc = new_model.evaluate(y,  y_lable)
print('\nTest best accuracy:', test_acc)# 最后一次运行数据
test_loss, test_acc = model_4.evaluate(y,  y_lable)
print('\nTest last accuracy:', test_acc)

运行结果:

此处我们选择model_4用于预测图片信息。

保存模型

model_4.save('.\\tmp\\model_4_794.h5')

查看模型训练过程中的代价函数(损失值)与准确度的变化过程

## 训练过程中的可视化
loss_train_val = history_4.history['loss']
loss_test_val = history_4.history['val_loss']sparse_categorical_accuracy = history_4.history['sparse_categorical_accuracy']
val_sparse_categorical_accuracy = history_4.history['val_sparse_categorical_accuracy']plt.figure(figsize=(12,5))
plt.subplot(121)
plt.plot(loss_train_val,label='train')
plt.plot(loss_test_val,label='test')
plt.title("loss")
plt.xlabel("epoch")
plt.legend()
plt.subplot(122)
plt.plot(sparse_categorical_accuracy,label='train')
plt.plot(val_sparse_categorical_accuracy,label='test')
plt.title("sparse_categorical_accuracy")
plt.xlabel("epoch")
plt.legend()plt.show()

运行结果:

准确率和损失值都区域一定的值,训练集与验证集大致符合,训练效果不错。

第五步:调用模型

new_model = tf.keras.models.load_model('.\\tmp\\model_4_794.h5')

使用模型进行预测:

# 预测照片
pre = new_model.predict(y[:20])

同时生成0~10(不包括10)的数组备用

tik = [i for i in range(0,10)]

防止plt绘制中文时出现乱码,指定画布大小

# 防止中文乱码
plt.rcParams['font.sans-serif'] = ['SimHei']# 设置画布大小
plt.figure(figsize=(12,40))

第六步:分栏可视化

获取预测标签与真实标签,若预测正确则绘制柱状图为绿色,若绘制错误则绘制柱状图为红色。

左侧绘制为预测图片,标题为预测值与真实值

右侧绘制为预测概率柱状图,标题为预测概率与预测值

for i in range(1,20,2):# 真实值标签rel_lable = int(y_lable[i])pre_lable = np.argmax(pre[i])if rel_lable==pre_lable:color_s = 'g'else:color_s = 'r'# 绘制预测图片plt.subplot(10,2,i)plt.imshow(y[i])plt.xticks([])plt.yticks([])plt.xlabel(f'预测值为{pre_lable}(真实值为{rel_lable})',size=14)# 绘制标签plt.subplot(10,2,i+1) # 预测值标签plt.bar(tik,pre[i],color=color_s)plt.xticks(tik)plt.title(f"预测 %.2f  为 :{np.argmax(pre[i])}     真实值为:{rel_lable}"%(np.max(pre[i])))plt.show()

部分运行结果:

PS:

制作不易,一键三连。

tensorflow cifar10 分类预测实战相关推荐

  1. 【数据分析与挖掘】基于LightGBM,XGBoost,逻辑回归的分类预测实战:英雄联盟数据(有数据集和代码)

    机器学习-LightGBM 一.LightGBM的介绍与应用 1.1 LightGBM的介绍 1.2 LightGBM的应用 二.数据集来源 三.基于英雄联盟数据集的LightGBM分类实战 Step ...

  2. 基于决策树的分类预测

    1.决策树的介绍 ​ 决策树(decision tree)是一种基本的分类与回归的方法,作为最基础.最常见的有监督学习模型,常被用于解决分类回归问题.本文主要讨论用于分类的决策树.决策树的核心思想是基 ...

  3. cifar-10 分类 tensorflow 代码

    看了黄文坚.唐源的<TensorFlow实战>对mnist分类的cnn教程,然后上网搜了下,发现挺多博主贴了对mnist分类的tensorflow代码,我想在同样框架下测试cifar-10 ...

  4. TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

    TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%) 目录 输出结果 实现代码 输出结果 Successfully downloaded t ...

  5. 【阿旭机器学习实战】【13】决策树分类模型实战:泰坦尼克号生存预测

    [阿旭机器学习实战]系列文章主要介绍机器学习的各种算法模型及其实战案例,欢迎点赞,关注共同学习交流. 本文用机器学习中的决策树分类模型对泰坦尼克号生存项目进行预测. 关于决策树的详细介绍及原理参见前一 ...

  6. PyTorch-09 循环神经网络RNNLSTM (时间序列表示、RNN循环神经网络、RNN Layer使用、时间序列预测案例、RNN训练难题、解决梯度离散LSTM、LSTM使用、情感分类问题实战)

    PyTorch-09 循环神经网络RNN&LSTM (时间序列表示.RNN循环神经网络.RNN Layer使用.时间序列预测案例(一层的预测点的案例).RNN训练难题(梯度爆炸和梯度离散)和解 ...

  7. 机器学习sklearn实战-----随机森林调参乳腺癌分类预测

    机器学习sklearn随机森林乳腺癌分类预测 机器学习中调参的基本思想: 1)非常正确的调参思路和方法 2)对模型评估指标有深入理解 3)对数据的感觉和经验 文章目录 机器学习sklearn随机森林乳 ...

  8. 现代神经网络(VGG),并用VGG16进行实战CIFAR10分类

    专栏:神经网络复现目录 本章介绍的是现代神经网络的结构和复现,包括深度卷积神经网络(AlexNet),VGG,NiN,GoogleNet,残差网络(ResNet),稠密连接网络(DenseNet). ...

  9. 《TensorFlow技术解析与实战》——第3章 可视化TensorFlow 3.1PlayGround

    本节书摘来自异步社区<TensorFlow技术解析与实战>一书中的第3章,第3.1节,作者李嘉璇,更多章节内容可以访问云栖社区"异步社区"公众号查看 第3章 可视化Te ...

最新文章

  1. Python中的类、模块和包究竟是什么?
  2. Javascript在IE中的有趣错误
  3. ProGit-读书简记
  4. 数字电路可控门电路原理(三态/同相/反相、缓冲/驱动电路)
  5. 第四期 | 带学斯坦福CS224n自然语言处理课+带打全球Kaggle比赛(文末重金招募老师!)...
  6. mysql从库故障恢复步骤(删除数据重新同步)
  7. leetcode108 将有序数组转换为二叉搜索树
  8. python pdf转txt保留全部信息_Python 将pdf转换成txt(不处理图片)
  9. windows稀疏文件
  10. 金融级分布式数据库架构设计要点
  11. python怎么读取csv文件-python怎么读取csv文件
  12. 【数据科学】什么是数据分析
  13. Kattis - missinggnomesD Missing Gnomes (思路题)
  14. acr122 java,ACR122U中文开发文档
  15. 神箭手云爬虫-爬取携程【国际】航班/机票信息-利用python解析返回的json文件将信息存储进Mysql数据库
  16. 虚拟账户 FTP 服务器不能上传可下载
  17. ES 矩阵查询(Adjacency matrix aggregation)
  18. python毕业设计 深度学习抽烟行为检测系统 - yolo opencv
  19. 7-5 修理牧场 (25 分)
  20. 国外生活必备的英文词汇

热门文章

  1. java jframe教程_Java JFrame
  2. 如何增加3d渲染的逼真感?提高3d渲染真实感的技巧
  3. 关于【赤峰公交出行】暂停服务的通知
  4. torch.linspace
  5. 「津津乐道播客」#348 厂长来了:AI是怎样帮你挡掉电诈的?
  6. Stream.noneMatch()
  7. java实现拜占庭将军_什么是拜占庭将军问题(一)
  8. JavaScript相册单图放大预览
  9. win10家庭版转专业版后,专业版功能仍然不能用怎么办
  10. 4020mAh电池+4GB大内存 360手机vizza仅售899元