看了一下之前做过的有关深度学习的实验,发现InceptionV3这个模型还没有用到,虽然并没有自己实现该网络模型,但是先学习一下它的原理,再利用迁移学习测试一下它的模型准确率,也不失为一种不错的学习方法。
本次实验利用InceptionV3网络模型,实现水果识别。

1.导入库

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os,pathlib,PILgpus = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0], True)# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号

2.数据加载

原数据集中给出了训练集、验证集、测试集三个子文件夹,就不需要人为的划分了。每个子文件夹中包含21类水果。

data_dir_train = "E:/tmp/.keras/datasets/fruit_rec/fruits-360-original-size/fruits-360-original-size/Training"
data_dir_test = "E:/tmp/.keras/datasets/fruit_rec/fruits-360-original-size/fruits-360-original-size/Test"
data_dir_validation = "E:/tmp/.keras/datasets/fruit_rec/fruits-360-original-size/fruits-360-original-size/Validation"data_dir_train = pathlib.Path(data_dir_train)
data_dir_test = pathlib.Path(data_dir_test)
data_dir_validation = pathlib.Path(data_dir_validation)all_images_paths = list(data_dir_train.glob('*'))
all_images_paths = [str(path) for path in all_images_paths]
all_label_names = [path.split("\\")[8].split(".")[0] for path in all_images_paths]

超参数的设置

height = 256
width = 256
epochs =10
batch_size = 32

分别构建训练集、验证集和测试集的ImageDataGenerator,并进行数据预处理

train_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,rotation_range=45,shear_range=0.2,zoom_range=0.2,horizontal_flip=True
)
test_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255
)
validation_data_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255,validation_split=0.2
)train_ds = train_data_gen.flow_from_directory(directory=data_dir_train,target_size=(height,width),shuffle=True,batch_size=batch_size,class_mode='categorical'
)
test_ds = test_data_gen.flow_from_directory(directory=data_dir_test,target_size=(height,width),shuffle=True,batch_size=batch_size,class_mode='categorical'
)
validation_ds = validation_data_gen.flow_from_directory(directory=data_dir_validation,target_size=(height,width),shuffle=True,batch_size=batch_size,class_mode='categorical'
)

整理后的数据如下所示:

3.InceptionV3网络

InceptionV3模型是谷歌Inception系列里面的第三代模型,相比于其它神经网络模型,Inception网络最大的特点在于将神经网络层与层之间的卷积运算进行了拓展。
就像VGG,AlexNet网络,它就是一直垂直卷积下来的,一层接着一层。
ResNet则是创新性的引入了残差网络的概念,使得靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分,后面的特征层的内容会有一部分由其前面的某一层线性贡献。

Google Inception Net在2014年的 ImageNet Large Scale Visual Recognition Competition (ILSVRC)中取得第一名,该网络以结构上的创新取胜,通过采用全局平均池化层取代全连接层,极大的降低了参数量,是非常实用的模型,一般称该网络模型为Inception V1。随后的Inception V2中,引入了Batch Normalization方法,加快了训练的收敛速度。在Inception V3模型中,通过将二维卷积层拆分成两个一维卷积层,不仅降低了参数数量,同时减轻了过拟合现象。参考链接

Inception网络采用不同大小的卷积核,使得存在不同大小的感受野,最后实现拼接达到不同尺度特征的融合。

整体结构图如下所示:

模型搭建:
在VGG系列的模型搭建时,利用迁移学习得到网络模型后,会将trainable设置为False,也就意味着前面的参数是不能够训练的。但是InceptionV3网络以及ResNet网络,由于引入了BN层,因此不能直接将trainable设置为False,博主在参考别人的博客时,有的博主提出在整个网络模型搭建完成后,再将trainable设置为False,但是博主的实验效果并不好,在训练集上面的准确率非常高,但是测试集上准确率特别低。因此直接去掉了将trainable设置为False这一步骤,模型的准确率得到提高。

base_model = tf.keras.applications.InceptionV3(weights = 'imagenet',include_top = False,pooling = None,input_shape = (height,width,3))
model = tf.keras.Sequential()
model.add(base_model)
model.add(tf.keras.layers.GlobalAveragePooling2D())
model.add(tf.keras.layers.Dense(128,activation='relu',kernel_initializer=tf.keras.initializers.glorot_normal(seed=32))
)
model.add(tf.keras.layers.Dense(64,activation='relu',kernel_initializer=tf.keras.initializers.glorot_normal(seed=33))
)
model.add(tf.keras.layers.Dense(21,activation='softmax',kernel_initializer=tf.keras.initializers.glorot_normal(seed=3))
)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),loss = "categorical_crossentropy",metrics = ['accuracy']
)

实验效果如下所示,经过10个epoch之后,模型准确率在90%左右。

利用验证集绘制混淆矩阵,关于混淆矩阵的代码,参考我之前的博客即可。

努力加油a啊

深度学习之基于InceptionV3实现水果识别相关推荐

  1. 【深度学习】基于caffe的表情识别(二):数据集介绍及处理

    <基于caffe的表情识别>系列文章索引:http://blog.csdn.net/pangyunsheng/article/details/79434263 一.数据集介绍 在本实验中我 ...

  2. 深度学习之基于CNN实现天气识别

    其实和猫狗大战还有上一篇博客的代码差不太多,但是中间出现了新的问题. 1.导入库 import numpy as np import tensorflow as tf import os,PIL im ...

  3. 【深度学习】基于Keras的手写体识别

    from keras import models from keras import layers from keras.datasets import mnist# 搭建网络 network = m ...

  4. 基于深度学习的番茄叶部病害识别模型

    基于深度学习的番茄叶部病害识别模型 1.研究思路 为实现番茄叶病特征的自动 提取,并提高识别准确率,提出一种基于深度学习的番茄叶病识别模型.该模型基于卷积神经网络对番茄叶部病害特征进行自动提取,获得高 ...

  5. 一种基于深度学习的遥感图像分类及农田识别方法

    文章针对现有的神经网络收敛速度慢.识别准确率不高的缺点,提出了一种基于卷积神经网络的遥感图像农田分类及识别方法.该算法使用较大的卷积核,有效地提取梯度信息:设计深度为6层的卷积神经网络,提高了网络的分 ...

  6. 基于深度学习的近红外掌纹识别原型系统设计与实现

    基于深度学习的近红外掌纹识别原型系统设计与实现 一.绪论 二.深度学习知识 三.Tensorflow 四.卷积神经网络 五.掌纹识别理论 掌纹图像采集 掌纹图像预处理 掌纹特征提取 掌纹特征匹配 掌纹 ...

  7. 基于深度学习的高精度家禽猪检测识别系统(PyTorch+Pyside6+YOLOv5模型)

    摘要:基于深度学习的高精度家禽猪检测识别系统可用于日常生活中或野外来检测与定位家禽猪目标,利用深度学习算法可实现图片.视频.摄像头等方式的家禽猪目标检测识别,另外支持结果可视化与图片或视频检测结果的导 ...

  8. 基于深度学习的场景文本检测和识别(Scene Text Detection and Recognition)综述

    1. 引言 文字是人类最重要的创作之一,它使人们在时空上可以有效地.可靠的传播或获取信息. 场景中的文字的检测和识别对我们理解世界很有帮助,它应用在图像搜索.即时翻译.机器人导航.工业自动化等领域. ...

  9. 基于深度学习的高精度牙齿健康检测识别系统(PyTorch+Pyside6+YOLOv5模型)

    摘要:基于深度学习的高精度牙齿健康检测识别系统可用于日常生活中检测牙齿健康状况,利用深度学习算法可实现图片.视频.摄像头等方式的牙齿目标检测识别,另外支持结果可视化与图片或视频检测结果的导出.本系统采 ...

最新文章

  1. mysql 二级什么意思_MySQL二级等级考试归纳——PHP篇
  2. java.lang.OutOfMemoryError: Java heap space解决方法
  3. php中参数传值的三种方法,php cli传递参数的方法
  4. Android反编工具的使用-Android Killer
  5. 汇编语言TEXTEQU伪指令
  6. 财务管理专业应该报计算机二级哪个科目,我是应该报计算机二级还是三级呢
  7. 如何实现listbox选项,然后双击鼠标实现选项的删除
  8. JS对以对象组成的数组去重
  9. 开放式的Video Captioning,中科院自动化所提出基于“检索-复制-生成”的网络
  10. 深度学习自学(九):Alexnet解读
  11. Docker学习之数据管理
  12. 关于空间域到频率域的转换
  13. 利用Python+xarray实现遥感数据——海表温度的经验正交函数(EOF)分解——xarray学习文档02
  14. java计算机毕业设计BS景区票务管理系统设计与实现源码+mysql数据库+系统+lw文档+部署
  15. office转pdf和图片实现在线预览
  16. 快递鸟智选物流API对接流程
  17. 一个理财小白如何挑选靠谱的网络理财产品?
  18. Pandoc中的Markdown语法
  19. 物联网智能硬件与嵌入式系统
  20. 大公司研发部门普遍存在的问题(日常吐槽)

热门文章

  1. Winsock属性 方法介绍
  2. php 实现百度坐标转换,PHP中腾讯与百度进行坐标转换
  3. 驱动人生2008_驱动人生致敬深圳经济特区建立四十周年!
  4. Keras和TensorFlow的关系和区别
  5. Android使用adb命令安装应用-连接usb
  6. python分类算法的应用_07-机器学习_(lineage回归分类算法与应用) ---没用
  7. Android开发之Java和Kotlin混合开发互相跳转报错的问题
  8. Flutter开发之实现沉浸式状态栏的效果
  9. Error:Could not find appcompat-v7.aar (com.android.support:appcompat-v7:26.1.0). Searched in the fol
  10. React性能优化:immutability-helper