1 前言

ResNet 是一种经典的图像识别领域模型,在 2015 年图像识别领域多个竞赛中排行第一,并且性能上相较第二有大幅提升。在这篇文章里,我们就站在巨人们的肩膀上,搭建一个基于 ResNet 识别花卉图片(Oxford 102 Flowers)的神经网络吧。

2 ResNet 简介

在 ResNet 以前,由于存在梯度消失和梯度爆炸的问题,神经网路层数越深,网络越难以训练,导致深层网络的准确度出现下降。

ResNet 通过引入残差块(Residual block),将 a[l]添加到第二个 ReLU 过程中,直接建立 a[l]与 a[l+2]之间的隔层联系。表达式如下:

论文[1]作者推测模型对残差的拟合优化会比对随机权重的拟合更加容易(因为baseline就是恒等映射),所以在极端状况下,残差块的中间层没有激活,即W≈0,b≈0,则有:

残差块示例

所以这种构造方式保证了深层的网络比浅层包含了更多(至少恒等)的图像信息。多个残差块推挤在一起,便形成了一个残差网络。

残差网络和普通深度神经网络对比

3 用 ResNet 构造分类模型

在下列 demo 中,我们使用 keras 已有的 ResNet50预训练模型,对 Oxford 102 Flowers 数据集中的 10 种花卉图片进行多分类任务模型的构造。在工程上我们只需要修改 ResNet50 顶部的全连接层,对输入的图片数据进行裁剪,旋转,放大等数据增强,训练所有模型参数即可。代码如下:

import os
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model, load_model
from keras.applications import ResNet50
from keras.optimizers import Adam
from keras.layers import Flatten, Dense, Dropout, Input
from keras.callbacks import EarlyStopping, ModelCheckpoint
import numpy as np
import mathdef fc_block(X,units,dropout,stage):fc_name = 'fc' + str(stage)X = Dense(units,activation ='elu',name = fc_name)(X)X = Dropout(dropout)(X)return Xdef ResNet50_transfer():#call base_modelbase_model = ResNet50(include_top=False,weights="imagenet",input_tensor= Input(shape=img_size + (3,)))# freeze resnet layers' paramsfor layer in base_model.layers:layer.trainable = False# top architectureX = base_model.outputX = Flatten()(X)X = Dropout(0.4)(X)X = fc_block(X,fc_layer_units[0],dropout = 0.4,stage = 1)X = fc_block(X,fc_layer_units[1],dropout = 0.4,stage = 2)# output layerX = Dense(len(classes),activation='softmax',name = 'fc3_output')(X)# create modelmodel = Model(inputs = base_model.input,outputs = X, name = 'ResNet50_transfer')return modeldef generate_data(train_path,valid_path):# generate & augment training datatrain_datagen = ImageDataGenerator(rotation_range=30., shear_range=0.2, zoom_range=0.2, horizontal_flip=True)train_datagen.mean = np.array([123.675, 116.28 , 103.53], dtype=np.float32).reshape((3, 1, 1))train_data = train_datagen.flow_from_directory(train_path, target_size=img_size, classes=None)# generate training datavalid_datagen = ImageDataGenerator()valid_datagen.mean = np.array([123.675, 116.28 , 103.53], dtype=np.float32).reshape((3, 1, 1))valid_data = train_datagen.flow_from_directory(valid_path, target_size=img_size, classes=None)return train_data, valid_datadef call_back():early_stopping = EarlyStopping(verbose=1, patience=10, monitor='val_loss')model_checkpoint = ModelCheckpoint(filepath='102flowersmodel.h5', verbose=1, save_best_only=True, monitor='val_loss')callbacks = [early_stopping, model_checkpoint]return callbacks# path_to_img: 'dataset/flower_data_10/train/1//image_06734.jpg'
train_path = 'dataset/flower_data_10/train'
valid_path = 'dataset/flower_data_10/valid'nb_epoch = 20
batch_size = 32
img_size = (224,224)# output classes
classes = list(map(str,[1,2,3,4,5,6,7,8,9,10]))
rgb_mean = [123.68, 116.779, 103.939]
fc_layer_units = [512,64]model = ResNet50_transfer()
model.compile(loss='categorical_crossentropy', optimizer=Adam(lr=1e-5), metrics=['accuracy'])
train_data, valid_data = generate_data(train_path,valid_path)
callbacks = call_back()
model.fit_generator(train_data, steps_per_epoch= math.ceil(train_data.samples / batch_size), epochs=nb_epoch,validation_data=valid_data, validation_steps=math.ceil(valid_data.samples / batch_size),callbacks=callbacks)

经过 20 个 epoch 的训练后,验证集的准确度已经达到了 0.8837。

4 小结

本文章简单地介绍了 ResNet 的特点,以及提供了搭建图片分类模型的代码模板。显卡配置较高的同学可以尝试搭建不同规模的 ResNet 网络观察网络深度对模型性能的影响;对于图像识别模型感兴趣的同学推荐细读 ResNet 论文: Deep Residual Learning for Image Recognition。

参考资料

[1] Deep Residual Learning for Image Recognition: https://arxiv.org/abs/1512.03385

深度学习 101-搭建 ResNet 识别鲜花图像相关推荐

  1. 采用keras深度学习框架搭建卷积神经网络模型实现垃圾分类,基于树莓派上进行实时视频流的垃圾识别源代码

    一.项目概述 简介:该垃圾分类项目主要在于对各种垃圾进行所属归类,本次项目采用keras深度学习框架搭建卷积神经网络模型实现图像分类,最终移植在树莓派上进行实时视频流的垃圾识别. 前期:主要考虑PC端 ...

  2. 《基于深度学习的加密流量识别研究》-2022毕设笔记

    参考文献: 基于深度学习的网络流量分类及异常检测方法研究_王伟 基于深度学习的加密流量分类技术研究与实现_马梦叠 基于深度学习的加密流量识别研究综述及展望_郭宇斌 基于深度学习的加密流量算法识别研究_ ...

  3. 【实战】深度学习构建人脸面部表情识别系统

    实战:深度学习构建人脸面部表情识别系统 一.表情数据集 数据集采用了kaggle面部表情识竞赛的人脸表情识别数据集. https://www.kaggle.com/c/challenges-in-re ...

  4. 【手写汉字识别】基于深度学习的脱机手写汉字识别技术研究

    写在前面 最近一段时间在为本科毕业设计做一些知识储备,方向与手写识别的系统设计相关,在看到一篇2019年题为<基于深度学习的脱机手写汉字识别技术研究>的工学硕士论文后,感觉收获比较大,准备 ...

  5. 基于深度学习的农作物病虫害识别

    摘要:我国有广阔的农作物种植面积,其中病虫害对农作物产量的影响最大,当农作物得了病虫害时,其整体生理机能会大大下降从而导致植株瘦小,无法达到最优生产状态从而产量不高经济效益低.因此农民需要多关注农作物 ...

  6. 一种基于深度学习的目标检测提取视频图像关键帧的方法

    摘要:针对传统的关键帧提取方法误差率高.实时性差等问题,提出了一种基于深度学习的目标检测提取视频图像关键帧的方法,分类提取列车头部.尾部及车身所在关键帧.在关键帧提取过程中,重点研究了基于SIFT特征 ...

  7. 基于深度学习的农作物病虫害识别系统

    1 简介 今天向大家介绍一个帮助往届学生完成的毕业设计项目,基于深度学习的农作物病虫害识别系统. ABSTRACT 及时.准确地诊断植物病害,对于防止农业生产的损失和农产品的损失或减少具有重要作用.为 ...

  8. Python基于深度学习yolov5的扑克牌识别

    Python基于深度学习yolov5的扑克牌识别(附带源码) 源程序来源于本人参与开发的一个网络扑克牌小游戏的图像识别.AI分析,AI出牌的小项目,做完后和大家分享一下扑克牌自动识别模块制作的过程. ...

  9. 基于深度学习的犬种识别软件(YOLOv5清新界面版,Python代码)

    摘要:基于深度学习的犬种识别软件用于识别常见多个犬品种,基于YOLOv5算法检测犬种,并通过界面显示记录和管理,智能辅助人们辨别犬种.本文详细介绍博主自主开发的犬种检测系统,在介绍算法原理的同时,给出 ...

最新文章

  1. 鲁棒,抗遮挡的对柔性手抓取的物体6D姿态估计
  2. Almost sorted interval
  3. asp.net gridview 为什么只显示一行数据_为什么中位数(大多数时候)比平均值好
  4. Java容器---List
  5. hbuilderX里uniapp和php,使用 DCloud 工具 HBuilder X 开发 uni-app 项目踩过的一些坑
  6. sprint敏捷开发
  7. 怎么设置代理服务器IP上网
  8. ML-Agents训练智能AI使用技巧
  9. 爬取北邮人论坛美食帖子
  10. Windows防火墙添加禁用规则——以禁用微信为例
  11. 【Linux问题栏】虚拟机中无法识别电脑摄像头和usb摄像头
  12. VMware ESXI上开虚机玩KVM
  13. JDK17的下载安装与配置(详细教程)
  14. CART分类与回归树
  15. 【Raft】学习九:成员变更ConfChangeV2
  16. c语言:模拟用户密码登录
  17. C++ map用法总结(整理)
  18. 竖流式沉淀池集水槽设计计算_各类沉淀池的设计要点都在这了!(建议收藏)...
  19. Cloud一分钟 |金立董事长赌博输超100亿;韩国全国4万工人大罢工;当当网李国庆力挺俞敏洪...
  20. gt-p7500 Android 4,三星GT-P7500 线刷固件包可救砖 刷机教程

热门文章

  1. Cauchy-Buniakowsky-Schwarz 积分形式证明
  2. html h1标签什么意思,什么是html h1标签?html h1标签使用方法的详细介绍
  3. SecurityError: Error #2123: 安全沙箱冲突,对NetStream使用BitmapData.draw()时出现的
  4. 软件定义汽车时代下,智能汽车软件架构逐步向 SOA 演进
  5. ((n % m) + m) % m; 是做什么
  6. 找工作到底是去国企还是先去互联网
  7. 尽管对领导力的定义众说纷纭
  8. CityEngine之cga语法------------NIL()(结束循环)
  9. C++和NASM联合编译
  10. LitePal+RecyclerView+checkBox实现便签功能(仿小米便签)