基于卷积神经网络和迁移学习实现场景图片分类任务
文章目录
- 前言
- 一、数据集
- 二、模型结构
- 模型参数设置
- Dataset
- BatchNormailization
- Learning rate
- Relu
- Sparse-Categorical-Crossentropy
- Early stopping method
- 三,模型结果
- 四,改进模型
- 五,代码实现
- 六,总结
前言
该项目解决的主要问题是场景识别。 具体而言,我们使用的数据集是一个包括15种不同的场景的集合。 该项目的目的是使用数据集来训练卷积神经网络模型,再在测试集上使用模型进行预测。
本项目首先自建一个卷积神经网络框架,然后使用迁移学习加载预训练模型,实现对原框架的改进。另外,本项目包含数据集加载和预处理部分,相比于minisit等成熟数据集,更能体现实验的完整性。
一、数据集
训练集:1500张(15个常见,每个类别100张)
测试集:2987张(随机类别)
下载链接:Link~~~~~~~~~~
将训练集和测试集复制进编辑平台即可开始实验。
二、模型结构
如图所示,我们设计的模型包括5个卷积单元,一个过渡单元,三个全连接的神经网络单元和一个输出单元。
具体结果如下,卷积单元包括2个卷积层,每个卷积层跟随一个BatchNormailization层,最后连接MaxPooling层。
过渡单元由一个卷积层,一个BatchNormailization层和全局平均池化层组成。
全连接单元包括一个全连接层和一个BatchNormailization层。
模型参数设置
Dataset
我们将原始训练集根据8:2的比例划分为训练集和验证集。 该目的是通过训练集完成模型训练,然后根据验证集选择最优的模型参数。不划分测试集的原因有两个。 首先,对于深度学习来说,1500张图片的数据集本就不大,在减少数据集容易出现过拟合现象; 其次,我们无法保证要预测的测试集的数据分布是否与训练集相同,因此对测试集的设置没有很大的意义。
BatchNormailization
BatchNormailization层的作用是对数据进行批量的标准化。 设置较大的batch size可以节省训练时间并提高训练稳定性,但是也会降低模型的泛化能力;标准化是规范输入的数据(使数据分布有一个好的方差和偏差,有助于模型训练)并输出它。 测试后,我们发现如果不使用batchnormoilization层,模型训练时的loss长时间得不到更新(梯度弥散)。因此我们在每个卷积器和完全连接的层后跟随一个BatchNormailization层 。 并将batch size大小设置为32。
Learning rate
学习速率是影响模型收敛的一个重要因素之一。 一般来说,由于数据集不同,所需的学习速率不同(一般是0.1~0.00001)。 测试后,我们将学习效率设置为0.0001。另外,如果要优化学习速率,可以选择渐变的lr,即训练刚开始设置一个较大的lr,以便于模型快速收敛,然后,lr随着训练时间逐渐变小,以便于更准确的找到最优解。
Relu
这里选择relu作为卷积层和全连接层的激活函数。 Relu函数是目前神经网络中最受欢迎和广泛使用的激活函数,相比较与其他激活函数,relu函数使深度学习模型参数的更新更加丝滑(但是,可能会造成部分神经元失活)。值得注意的是,在输出层(最后的全连接层)中我们不选择任何激活函数,我们将在损失函数sparse-categorical-crossentropy中完成激活任务。 在许多实验发现,这种激活传递方法比在输出层添加SoftMax更稳定,有效。
Sparse-Categorical-Crossentropy
sparse-categorical-crossentropy被选为loss函数,相比于MSE等损失函数更适合处理图片的多分类问题。
Early stopping method
在验证集使用早停法,目的是为了防止过度模型过拟合并选择模型的最佳参数。
三,模型结果
如图所示,自建的卷积神经网络模型的最终精度达到0.6146
loss降低到1.6089
接下来,我们可以使用更大且更复杂的模型来获得更高的准确性。 预先训练的模型是一个不错的选择,因为预先接受的模型是在大数据集上训练完成的模型,其参数可以更好的提取图片特征。本项目使用tensorflow2.0框架,下载了预先训练的模型MoBilenetV2,不包括顶部分类层,因为我们将设计添加一个新的分类层,以便更好的贴合项目。
值得注意的是,从模型结果上看,当训练集的准确率已经接近为1时,验证集的准确率才只有0.6;而且训练集的准确率在一开始呈飙升趋势。如此大的差距说明模型已经过拟合。
因此,我们决定通过水平或垂直的随机翻转,裁剪等图像增强的方法来对数据集进行实时增广。对某种程度来说,大大缓解了数据集过拟合问题
四,改进模型
该模型的改进主要有两部分:增加预训练模型和实时的数据增广。
使用的预培训模型是由Google开发的Mo-Bilenetv2网络结构。模型已在ImageNet数据集上预先培训,并已学习1000个常见物体的特征。 因此,模型具有强大的特征提取功能。
下载模型时,我们指定参数include-top = false,即不下载模型最后的全连接层,因为我们只想使用使用预训练模型对图片进行特征提取而不是直接进行分类。 预训练模型的分类能力通常受到原始任务的影响。 如果我们想使用预训练模型,自己构建顶层分类能更好的适应新的分类任务。
具体模型结构如图所示
我们将预先训练的模型视为特征提取器,可以得到(6,6,1280)的特征输出。 特征提取器可以理解为特征映射过程。用一个特征矩阵来表示原始图片。 在新的特征空间中,它更有利于图片分类
我们使用GlobalAveragePooling将特征矩阵打平成一维的特征向量。GlobalAveragePooling是工程经验中最常用的打平方法,相比之下GlobalMaxPooling不够稳定; FullyConnect layer容易造成过拟合而且得到的参数量巨大。
最后添加一个15个神经元的全连接层做为输出层,对15个常见进行分类。
结果如下
模型的最高准确率达到了0.9;其loss和acc的曲线趋势也明显优于自创的卷积神经网络。
五,代码实现
import os
import pathlib
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
os.environ['TF_CPP_MIN_LOG_LEVEL']='2'
import tensorflow as tf
import random
import numpy as np
import sys
import glob
data_root = pathlib.Path(r'C:\Users\86177\.keras\datasets\training')
print(data_root)
for item in data_root.iterdir():print(item)#Get the names of 15 folders
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
print(label_names)#Mark the name of the folder with a serial number (0-14)
label_to_index = dict((name, index) for index, name in enumerate(label_names))#Configure the picture mapping corresponding to the serial number
index_to_label = dict((v,k) for k, v in label_to_index.items())
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
random.shuffle(all_image_paths)
image_count = len(all_image_paths)
print(image_count)#Get the tags of all pictures
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]for path in all_image_paths]
print("First 10 labels indices: ", all_image_labels[:10])#Define loading and preprocessing functions, decode pictures and unify picture sizes
def preprocess_image(image):image = tf.image.decode_jpeg(image, channels=3)image = tf.image.resize(image, [192, 192])image /= 255.0 # normalize to [0,1] range return imagereturn image
def load_and_preprocess_image(path):image = tf.io.read_file(path)return preprocess_image(image)# Divide the training set into training set and test set according to the ratio of 8:2
i = int(len(all_image_paths)*0.8)train_path = all_image_paths[:i]
train_label = all_image_labels[:i]
test_path = all_image_paths[i:]
test_label = all_image_labels[i:]#"From_tensor_slices" method uses tensor slice elements to build a dataset of image paths
path_trainds = tf.data.Dataset.from_tensor_slices(train_path)
#Similarly, reconstruct the label data set and use tf.cast to convert to int64 data type
label_trainds = tf.data.Dataset.from_tensor_slices(tf.cast(train_label, tf.int64))
#Acquire the picture according to the path, and get the picture data set after loading and preprocessing
image_trainds = path_trainds.map(load_and_preprocess_image )
print(image_trainds)
#Pack the picture and its corresponding label to form a training set
train_ds = tf.data.Dataset.zip((image_trainds, label_trainds))
print(train_ds)
print(len(train_ds))#The test set is processed the same as the training set
path_testds = tf.data.Dataset.from_tensor_slices(test_path)
label_testds = tf.data.Dataset.from_tensor_slices(tf.cast(test_label, tf.int64))
image_testds = path_testds.map(load_and_preprocess_image )
print(image_testds)
test_ds = tf.data.Dataset.zip((image_testds, label_testds))
print(test_ds)
print(len(test_ds))# Define a function to convert the training data into a data type that tensorflow can handle
def preprocess(x, y):# [0~1]x = tf.cast(x, dtype=tf.float32)y = tf.cast(y, dtype=tf.int32)return x,y
#Set batchsize to break up the training set to prevent overfitting in the training process
BATCHSIZE = 32
train_db = train_ds.repeat().shuffle(400).map(preprocess).batch(BATCHSIZE)
test_db = test_ds.shuffle(400).map(preprocess).batch(BATCHSIZE)#Because the training set data is too small, the function is defined in the training phase to perform real-time augmentation of the data to prevent overfitting
def augment_data(image, label):print("扩展数据调用!")# Randomly flip the picture horizontallyimage = tf.image.random_flip_left_right(image)# Randomly set the contrast of the pictureimage = tf.image.random_contrast(image, lower=0.0, upper=1.0)# # Randomly flip the picture verticallyimage = tf.image.random_flip_up_down(image)# # Randomly set the brightness of the pictureimage = tf.image.random_brightness(image, max_delta=0.5)# # Randomly set the chroma of the pictureimage = tf.image.random_hue(image, max_delta=0.3)# # Randomly set the saturation of the pictureimage = tf.image.random_saturation(image, lower=0.3, upper=0.5)return image,labeltrain_db = train_db.map(augment_data)# #Self-created convolutional neural network model structure
# model = tf.keras.Sequential([
#
# tf.keras.layers.Conv2D(64, (3, 3), input_shape=(192, 192, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.MaxPooling2D(),
# tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.MaxPooling2D(),
# tf.keras.layers.Conv2D(256, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Conv2D(256, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.MaxPooling2D(),
# tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.MaxPooling2D(),
# tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Conv2D(512, (3, 3), activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.GlobalAveragePooling2D(),
# # tf.keras.layers.Flatten(),
# tf.keras.layers.Dense(1024, activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Dense(512, activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Dense(128, activation='relu'),
# tf.keras.layers.BatchNormalization(),
# tf.keras.layers.Dense(15)
# ])# Migrate the MobileNetV2 model without loading the top layer
base_model = tf.keras.applications.mobilenet_v2.MobileNetV2(include_top=False, weights='imagenet',input_shape=(192, 192, 3))
inputs = tf.keras.layers.Input(shape=(192, 192, 3))#Define the function to change the value range of the data set, because MobileNetV2 accepts the input data value range is [-1,1], and our previous pre-function maps the quantized value of the picture to [0,1]
def change_range(image,label):return 2*image-1,label
keras_ds = train_db.map(change_range)
keras_vds = test_db.map(change_range)#Define model structure
model = tf.keras.Sequential([base_model,tf.keras.layers.GlobalAveragePooling2D(),tf.keras.layers.Dense(15)])
#Assemble model, set optimizer, loss function and performance evaluation index
model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.0001),#optimizer=tf.keras.optimizers.SGD(lr=0.005, momentum=0.9),loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])
model.summary()#Run model
model.fit(keras_ds,epochs=20,steps_per_epoch=37,validation_data=keras_vds,validation_steps=9)
# Evaluate model
loss, accuracy = model.evaluate(keras_vds)
print('\ntest loss', loss)
print('accuracy', accuracy)#Define loading test set data function
def load_and_preprocess_image2(path):image = tf.io.read_file(path)image = tf.image.decode_jpeg(image, channels=3)image = tf.image.resize(image, [192, 192])image = tf.cast(image, tf.float32)image = image / 255.0return imagef = open('out7.txt', 'a+')
for i in range(0, 2988): # the final image is 2987.jpg.test_img = glob.glob("./data/testing/%s" % i + ".jpg")if (i != 1314 and i != 2938 and i != 2962):BATCHSIZE2 = 1path_trainds = tf.data.Dataset.from_tensor_slices(test_img)image_trainds = path_trainds.map(load_and_preprocess_image2)test_tensor = image_trainds.batch(BATCHSIZE2)# Predict the test set datapred = model.predict(test_tensor)A = index_to_label.get(np.argmax(pred))# Write the prediction results into a txt filename = str(i) + '.jpg'print(name, A.lower(), file=f)f.close()
六,总结
本项目基于卷积神经网络和迁移模型实现了场景分类的任务。代码实现是基于tensorflow2.0框架的,其中有详细的注释解释。有不明白的地方欢迎留言~
基于卷积神经网络和迁移学习实现场景图片分类任务相关推荐
- 基于卷积神经网络与迁移学习的油茶病害图像识别
基于卷积神经网络与迁移学习的油茶病害图像识别 1.研究思路 利用深度卷积神经网络强大的特征学习和特征表达能力来自动学习油茶病害特征,并借助迁移学习方法将AlexNet模型在ImageNet图像数据集上 ...
- 基于卷积神经网络及迁移学习的掌纹识别
- Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类)
Python深度学习实例--基于卷积神经网络的小型数据处理(猫狗分类) 1.卷积神经网络 1.1卷积神经网络简介 1.2卷积运算 1.3 深度学习与小数据问题的相关性 2.下载数据 2.1下载原始数据 ...
- 第八届“泰迪杯”数据挖掘挑战赛C题“泰迪杯”奖论文(基于卷积神经网络及集成学习的网络问政平台留言文本挖掘与分析)
目 录 第一章 引言 1.1挖掘背景 1.2挖掘意义 1.3问题描述 第二章 群众留言分类 2.1数据准备 2.1.1数据描述 2.1.2数据预处理 2.2特征提取 2.3建立模型 2.3.1卷积神经 ...
- 基于卷积神经网络方法的英文短文本情感分类(Python)
摘要:互联网的快速发展,使得每个人表现自己,发表言论更加的自由和便利.Twitter.Facebook等应用软件为大众提供了表达自身情感的一个平台.情感分类,可以简单地表示为喜欢,厌恶和中性,也渐渐受 ...
- 什么是神经网络在object detection的应用?cascade classifier,卷积神经网络,迁移学习
首先 输入,positive image:大黄蜂 negative image:大黄蜂的背景 输出,有多大的概率是大黄蜂 用的是卷积神经网络 卷积神经网络的分类器 在matlab里面是一个xml文件 ...
- 基于MATLAB的Alexnet迁移学习进行猫狗分类(数据集:Kaggle)
基本介绍 软件:Matlab R2018b 数据集:Kaggle猫狗数据集 网络:AlexNet 前期准备 数据集 Kaggle猫狗数据集猫与狗用于训练的图片(train)分别12500张,每张图片的 ...
- 基于ResNET50模型进行迁移学习构建中药饮片分类Web App
本文主要介绍如何利用深度学习迁移方法进行中药分类的设计的过程 1.数据采集 大量有效的中药图片是宝贵的资源,采用自己拍照的方式收集非常耗时,可以从借助搜索引擎从网络抓取中药材图片,方法如下 (1)安装 ...
- 四川大学计算机学院琚生根教授,基于卷积神经网络和自注意力机制的文本分类模型...
Abstract: The wordlevel shallow convolutional neural network (CNN) model has achieved good performa ...
- 记一次 基于 卷积神经网络(CNN)的 验证码图片识别
前几天搭建好了tensorflow2的环境,今天来试验一下神奇的机器学习. 先简单编写一个java程序,收集了10000多个验证码图片,全部进行人工标注(训练素材点击下载),其中600多个用来检验预测 ...
最新文章
- 【原创】关于部门月会(二)
- linux文泉驿字体调用,使用文泉驿点阵字体解决Linux中文化问题
- 二叉搜索树介绍及其接口说明
- nginx做服务器入口_Nginx实现http反向代理
- Https 加密原理分析
- 学习Jsoup(三)
- Java Web 获取客户端真实IP
- pandas—pandas.DataFrame.iterrows的使用
- 跑酷游戏的一些bug总结(滥用FixedUpdate的坑)
- 互联网进入网盘新时代
- PicGo搭建Gitee图库
- 美通企业日报 | 英特尔20亿美元收购AI芯片制造商;嘉吉投资扩建河北嘉好粮油...
- PCI Expansion ROMs
- Windows Workflow Foundation中实现人工活动的demo,按照XPDL规范的实现
- jmeter源码解读
- python 获取一年中所有工作日列表来辅助计算工作时间内的时间差
- 使用jsoup入门java爬虫 案例
- HTML-用户登录界面
- 从业务出发,来谈谈策略模式
- SpringBoot项目中怎么保证提供的接口不会被调崩