介绍

如果你决心制作一个CNN模型,使其准确性达到95%以上,那么这可能是适合你的博客。

我们将分三部分解决这个问题

  • 迁移学习

  • 数据扩充

  • 处理过拟合和欠拟合问题

迁移学习

迁移学习是通过从已经学习的相关任务中迁移知识来改进新任务中学习的方法。

用简单的话来说,迁移学习的思想是,我们使用从图像分类任务中预先训练的模型,而不是从头开始训练新模型。

为什么要使用迁移学习?

迁移学习是一种优化,是节省时间或获得更好性能的捷径。

通常,在模型开发和评估之前,在领域中使用迁移学习不会有好处。但是在大多数情况下,迁移学习比起从头训练的模型提供更好的结果

迁移学习的主要好处是:

  • 更高的起点:源模型的初始点(在精炼模型之前)比其他方法要高。

  • **更高的斜率:**在对源模型进行训练的过程中,其提高速度为比其他情况更快。

  • **更高的渐近线:**训练后的模型收敛要优于其他方式。

此图总结了上述的3个点,你可以看到将迁移学习应用于模型时,训练从更高的点开始,从而更快地达到更高的准确度。

Tensorflow中的迁移学习

在本教程中,我们将讨论如何使用Tensorflow Hub在Tensorflow模型中使用迁移学习。

Tensorflow Hub是一个收集各种预训练模型的地方,例如ResNet,MobileNet,VGG-16等。它们还具有用于图像分类,语音识别等的不同模型。在Tensorflow Hub中可用的迁移学习模型中最后的输出层将被删除,以便我们可以使用自定义的类数插入输出层。

  • Tensorflow Hub:https://www.tensorflow.org/hub

URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2"
feature_extractor = hub.KerasLayer(URL,input_shape=(IMG_SHAPE, IMG_SHAPE,3))

在这里,我们使用了MobileNet模型,你可以在TensorFlow Hub网站上找到不同的模型。

每种型号都有特定的输入图像大小,将在网站上提及。

在我们的MobileNet模型中,此处提到的图像尺寸为224×224,因此在使用传输模型时,请确保将所有图像的尺寸调整为该特定尺寸。

feature_extractor.trainable = False

在声明你的迁移学习模型后,请确保包含上面的代码,以确保该模型不会再次从头开始进行训练

现在我们可以定义我们的自定义模型:

no_of_output_classes=4
from tensorflow.keras import layers
model = tf.keras.Sequential([feature_extractor,layers.Dense(No_of_output_classes)   # make sure this number is the same number as output classes
])
model.summary()

现在,我们可以像运行任何普通模型一样运行model.compile和model.fit。

数据扩充

拥有大型数据集对于深度学习模型的性能至关重要。但是,我们可以通过增加现有数据来提高模型的性能。它还可以帮助模型对不同类型的图像进行概括。在数据扩充中,我们添加了不同的过滤器或略微更改了已有的图像,例如添加了随机放大,缩小,以随机角度旋转图像,模糊图像等。

这显示了旋转数据的扩充

Tensorflow中的数据增强

如果你在Tensorflow中使用 ImageDataGenerator,则可以轻松应用数据增强

  • ImageDataGenerator:https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator

image_gen_train = ImageDataGenerator(     # here we use the ImageDataGeneratorrescale=1./255,rotation_range=40,width_shift_range=0.2,                # Applaying these all Data Augmentationsheight_shift_range=0.2,shear_range=0.2,zoom_range=0.2,horizontal_flip=True,fill_mode='nearest')

这些是可用的不同数据扩充的示例,更多信息在TensorFlow文档中查看。

  • TensorFlow文档:https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator

然后我们可以将这些增强应用于我们的图像

train_data_gen = image_gen_train.flow_from_directory(batch_size=BATCH_SIZE,     # Batch siz emeans at a time it takes 100directory=train_dir,    # Here we put shuffle= True so tat model doesnt memorise ordershuffle=True,target_size=(IMG_SHAPE,IMG_SHAPE),class_mode='binary')

这里的 train_dir 是我们的训练图像所在的目录路径。

处理过拟合和欠拟合问题

过度拟合

当模型学习训练数据中的细节和噪声时,就会过度拟合,从而对模型在新数据上的性能产生负面影响。

换句话说,过度拟合的模型在训练集上表现良好,而在测试集上表现不佳,这意味着当涉及到新数据时,该模型似乎无法泛化

正如你在过度拟合中看到的,它过于具体地学习训练数据集,并且在给定新数据集时,会对模型产生负面影响。

欠拟合

欠拟合是相反的情况,在这种情况下,模型无法从训练数据中学到足够多的知识,以至于在训练和测试数据集上都做得不好。当没有足够的数据可进行训练时,通常会发生这种情况。

克服过度拟合的方法:

有两种方法可以解决过度拟合问题:

1)使用更多的训练数据

这是克服过度拟合的最简单方法

2)使用数据扩充

数据增强可以帮助你克服过度拟合的问题。上文已深入讨论了数据扩充。

3)知道何时停止训练

换句话说,知道你想要训练模型的时期数在决定模型是否适合方面具有重要作用

你可以通过绘制训练集和验证集的损失或精度与epoch图来获得想要训练模型的确切数字。

如你所见,在早期停止状态之后,验证集损失会增加,但是训练集值会继续减少。在准确的模型中,无论是训练还是验证,准确性都必须降低

所以这里对应于早期停止值的epoch值就是我们的epoch数

这是未过度拟合或未拟合的模型的示例。

结论

通过遵循这些方法,你可以使CNN模型的验证集准确性超过95%。

该项目的完整代码可在这个GitHub上找到:https://github.com/aromaljosebaby。

☆ END ☆

如果看到这里,说明你喜欢这篇文章,请转发、点赞。微信搜索「uncle_pn」,欢迎添加小编微信「 mthler」,每日朋友圈更新一篇高质量博文。

扫描二维码添加小编↓

以95%的精度构建CNN模型相关推荐

  1. 【Python深度学习】基于Tensorflow2.0构建CNN模型尝试分类音乐类型(二)

    前情提要 基于上文所说 基于Tensorflow2.0构建CNN模型尝试分类音乐类型(一) 我用tf2.0和Python3.7复现了一个基于CNN做音乐分类器.用余弦相似度评估距离的一个音乐推荐模型. ...

  2. 第16课:项目实战——利用 PyTorch 构建 CNN 模型

    上一篇,我们主要介绍了 CNN 的基本概念和模型结构.本文将带领大家使用 PyTorch 一步步搭建 CNN 模型,进行数字图片识别.本案例中,我们选用的是 MNIST 数据集. 总的来说,我们构建分 ...

  3. 【Python深度学习】基于Tensorflow2.0构建CNN模型尝试分类音乐类型(一)

    音乐分类 前言 复现代码 MP3转mel CNN模型 训练结果 总结 前言 我在逛github的时候,偶然发现了一个项目:基于深度学习的音乐推荐.[VikramShenoy97].作者是基于CNN做的 ...

  4. 基于深度学习的轴承故障识别-构建基础的CNN模型

    上回书说到,处理序列的基本深度学习算法分别是循环神经网络(recurrent neural network)和一维卷积神经网络(1D convnet).上篇构建了基础的LSTM模型,这一篇自然轮到CN ...

  5. 从零开始,手把手教你使用Keras和TensorFlow构建自己的CNN模型

    最近学习CNN,搭建CNN模型时看网上鱼龙混杂的博客走了不少歪路,决定自己来总结一下. 注意本教程未必对所有版本有效,请根据需要的版本适当调整.文章中配置的环境是Python 3.8.12 ,Tens ...

  6. [Tensorflow]服装图像数据集分类:使用DNN、CNN模型

    一.实验介绍 实验环境:jupyter notebook.Tensorflow.keras 数据集Fashion Mnist与样例代码及相关参考: https://www.tensorflow.org ...

  7. 深度学习初学者使用Keras构建和部署CNN模型

    https://www.toutiao.com/a6666072496283845134/ 如果你在黑框中画一个图形.您应该从模型中得到一个预测. 恭喜!您已构建并部署了第一个CNN模型.继续尝试写出 ...

  8. 从爬虫构建数据集到CNN模型的验证码识别,一步一步搭建基于Python的PC个人端12306抢票程序

    写在前面:这个程序不是一个人能在短时间内完成的,感谢达纳,王哥的支持帮助.也感谢小平老师,没有压迫,就没有项目. 简介:这是一篇很硬核的Blog, 有一定Python基础的童鞋方能看懂,本程序的主要内 ...

  9. 一文总结经典卷积神经网络CNN模型

    一般的DNN直接将全部信息拉成一维进行全连接,会丢失图像的位置等信息. CNN(卷积神经网络)更适合计算机视觉领域.下面总结从1998年至今的优秀CNN模型,包括LeNet.AlexNet.ZFNet ...

最新文章

  1. SBIO | 西农韦革宏组-大豆土壤细菌门间负向互作影响群落的动态变化和功能
  2. Leetcode 138. 复制带随机指针的链表 解题思路及C++实现
  3. jquery遍历函数siblings()
  4. Boost:BOOST_ASSERT_IS_VOID的测试程序
  5. SRE(Simple Rule Engine) Document
  6. mysql error 1594_【MySQL】解决mysql的 1594 错误-阿里云开发者社区
  7. 【LOJ】#3030. 「JOISC 2019 Day1」考试
  8. mysqli_fetch_row,mysqli_fetch_array,mysqli_fetch_assoc区别
  9. 【转】ASP.NET 2.0中Page事件的执行顺序
  10. 剑指 offer代码解析——面试题39推断平衡二叉树
  11. 孙鑫VC学习笔记:第十七讲 (四) 用邮槽实现进程间的通信
  12. 拓端tecdat|R语言k-means聚类、层次聚类、主成分(PCA)降维及可视化分析鸢尾花iris数据集
  13. 什么是数字证书?数字证书在哪办理?
  14. 华为vrrp默认优先级_【干货】华为vrrp配置
  15. java什么是继承_JAVA中什么是继承?
  16. 以华为2016年笔试题为例,详解牛客网的在线判题系统(OJ模式)
  17. 2016理数全国卷 T21
  18. 浙江中医药大学第十二届大学生程序设计竞赛 部分题解
  19. 前端图片在线转换Base64 图片编码Base64
  20. 用mysql创建职工表_【典型例题】数据库——用MySQL来建立创建员工表;-Go语言中文社区...

热门文章

  1. C64x+ Megamodule概述
  2. CSS实现矩形两边挖半圆
  3. 985计算机报录比高的学校,985/211高报录比院校
  4. tcp/ip网络里的客户端和服务器端 信息交流 与 安全
  5. 一级计算机word视频教学视频教学设计,计算机基础教学设计(word文档).doc
  6. PXI Express外设板信号汇总(更新中)
  7. 利用bastille配置安全的linux系统
  8. 新钛云服多云管理平台用户手册
  9. 2022暑期杭电第十场
  10. 找工作必备:外企面试常见10个问题。