这次的结果是没有想到的,利用官方的Inception_ResNet_V2模型识别效果差到爆,应该是博主自己的问题,但是不知道哪儿出错了。
本次实验分别基于自己搭建的Inception_ResNet_V2和CNN网络实现交通标志识别,准确率很高。

1.导入库

import tensorflow as tf
import matplotlib.pyplot as plt
import os,PIL,pathlib
import pandas as pd
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers,models
from tensorflow.keras import layers, models, Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Dropout,BatchNormalization,Activation
from tensorflow.keras.layers import MaxPooling2D, AveragePooling2D, Concatenate, Lambda,GlobalAveragePooling2D
from tensorflow.keras import backend as K

2.导入数据

数据形式如下所示:

其实images中包含5998张交通标志的图片,其中一共是58种类别。

annotations中是各个图片的名称以及它所代表的种类,一共是58种。

#图片预处理
def preprocess_image(image):##归一化&&调整图片大小image = tf.image.decode_jpeg(image,channels=3)image = tf.image.resize(image,[299,299])return image/255.0
def load_and_preprocess_image(path):#根据路径读入图片image = tf.io.read_file(path)return preprocess_image(image)
#导入数据
data_dir = "E:/tmp/.keras/datasets/trasig_photos/images"
data_dir = pathlib.Path(data_dir)
#导入训练数据的图片路径以及标签
train = pd.read_csv("E:/tmp/.keras/datasets/trasig_photos/annotations.csv")
#图片所在的主路径
img_dir = "E:/tmp/.keras/datasets/trasig_photos/images/"
#训练数据的标签
train_image_label = [i for i in train["category"]]
train_label_ds = tf.data.Dataset.from_tensor_slices(train_image_label)
#训练数据的路径既每一张图片的具体路径
train_image_paths = [img_dir+i for i in train["file_name"]]
#加载图片路径
train_path_ds = tf.data.Dataset.from_tensor_slices(train_image_paths)
#加载图片数据
train_image_ds = train_path_ds.map(load_and_preprocess_image,num_parallel_calls=tf.data.experimental.AUTOTUNE)
#将图片与路径对应进行打包
image_label_ds = tf.data.Dataset.zip((train_image_ds,train_label_ds))

训练集与测试集的划分:

train_ds = image_label_ds.take(5000).shuffle(1000)
test_ds = image_label_ds.skip(5000).shuffle(1000)train_ds = train_ds.batch(batch_size)#设置batch_size
train_ds = train_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test_ds = test_ds.batch(batch_size)
test_ds = test_ds.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

检查图片是否被打乱

被打乱了,达到预期的目标。

3.CNN网络

CNN模型是博主自己搭建的,包含三层卷积池化层+Flatten+二层全连接层。其中池化层前两个用的MaxPooling,最后一个用的AveragePooling。

model = models.Sequential([tf.keras.layers.Conv2D(32,(3,3),activation='relu',input_shape=(299,299,3)),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(64,(3,3),activation='relu'),tf.keras.layers.MaxPooling2D(),tf.keras.layers.Conv2D(128,(3,3),activation='relu'),tf.keras.layers.AveragePooling2D(),tf.keras.layers.Flatten(),tf.keras.layers.Dense(1000,activation='relu'),tf.keras.layers.Dense(58,activation='softmax')
])

优化器的设置,与上篇博客的设置无异。

#优化器的设置
initial_learning_rate = 1e-4
lr_sch = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=initial_learning_rate,decay_steps=100,decay_rate=0.96,staircase=True
)

模型编译&&训练

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_sch),loss='sparse_categorical_crossentropy',metrics=['accuracy']
)
history = model.fit(train_ds,validation_data=test_ds,epochs=epochs
)

实验结果如下所示:

其中训练集的准确率已经接近100%,测试集的准确率95%左右,这是在batch_size=8,epoch=5的情况下。增加epoch的情况下,准确率应该会提高。

4.Inception_ResNet_V2网络

模型的搭建参考大神**K同学啊**的博客。
网络模型如下所示:

网络模型搭建:

def conv2d_bn(x, filters, kernel_size, strides=1, padding='same', activation='relu', use_bias=False, name=None):x = Conv2D(filters, kernel_size, strides=strides, padding=padding, use_bias=use_bias, name=name)(x)if not use_bias:bn_axis = 1 if K.image_data_format() == 'channels_first' else 3bn_name = None if name is None else name + '_bn'x = BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)if activation is not None:ac_name = None if name is None else name + '_ac'x = Activation(activation, name=ac_name)(x)return xdef inception_resnet_block(x, scale, block_type, block_idx, activation='relu'):if block_type == 'block35':branch_0 = conv2d_bn(x, 32, 1)branch_1 = conv2d_bn(x, 32, 1)branch_1 = conv2d_bn(branch_1, 32, 3)branch_2 = conv2d_bn(x, 32, 1)branch_2 = conv2d_bn(branch_2, 48, 3)branch_2 = conv2d_bn(branch_2, 64, 3)branches = [branch_0, branch_1, branch_2]elif block_type == 'block17':branch_0 = conv2d_bn(x, 192, 1)branch_1 = conv2d_bn(x, 128, 1)branch_1 = conv2d_bn(branch_1, 160, [1, 7])branch_1 = conv2d_bn(branch_1, 192, [7, 1])branches = [branch_0, branch_1]elif block_type == 'block8':branch_0 = conv2d_bn(x, 192, 1)branch_1 = conv2d_bn(x, 192, 1)branch_1 = conv2d_bn(branch_1, 224, [1, 3])branch_1 = conv2d_bn(branch_1, 256, [3, 1])branches = [branch_0, branch_1]else:raise ValueError('Unknown Inception-ResNet block type. ''Expects "block35", "block17" or "block8", ''but got: ' + str(block_type))block_name = block_type + '_' + str(block_idx)mixed = Concatenate(name=block_name + '_mixed')(branches)up = conv2d_bn(mixed, K.int_shape(x)[3], 1, activation=None, use_bias=True, name=block_name + '_conv')x = Lambda(lambda inputs, scale: inputs[0] + inputs[1] * scale,output_shape=K.int_shape(x)[1:],arguments={'scale': scale},name=block_name)([x, up])if activation is not None:x = Activation(activation, name=block_name + '_ac')(x)return xdef InceptionResNetV2(input_shape=[299, 299, 3], classes=1000):inputs = Input(shape=input_shape)# Stem blockx = conv2d_bn(inputs, 32, 3, strides=2, padding='valid')x = conv2d_bn(x, 32, 3, padding='valid')x = conv2d_bn(x, 64, 3)x = MaxPooling2D(3, strides=2)(x)x = conv2d_bn(x, 80, 1, padding='valid')x = conv2d_bn(x, 192, 3, padding='valid')x = MaxPooling2D(3, strides=2)(x)# Mixed 5b (Inception-A block)branch_0 = conv2d_bn(x, 96, 1)branch_1 = conv2d_bn(x, 48, 1)branch_1 = conv2d_bn(branch_1, 64, 5)branch_2 = conv2d_bn(x, 64, 1)branch_2 = conv2d_bn(branch_2, 96, 3)branch_2 = conv2d_bn(branch_2, 96, 3)branch_pool = AveragePooling2D(3, strides=1, padding='same')(x)branch_pool = conv2d_bn(branch_pool, 64, 1)branches = [branch_0, branch_1, branch_2, branch_pool]x = Concatenate(name='mixed_5b')(branches)# 10次 Inception-ResNet-A blockfor block_idx in range(1, 11):x = inception_resnet_block(x, scale=0.17, block_type='block35', block_idx=block_idx)# Reduction-A blockbranch_0 = conv2d_bn(x, 384, 3, strides=2, padding='valid')branch_1 = conv2d_bn(x, 256, 1)branch_1 = conv2d_bn(branch_1, 256, 3)branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding='valid')branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)branches = [branch_0, branch_1, branch_pool]x = Concatenate(name='mixed_6a')(branches)# 20次 Inception-ResNet-B blockfor block_idx in range(1, 21):x = inception_resnet_block(x, scale=0.1, block_type='block17', block_idx=block_idx)# Reduction-B blockbranch_0 = conv2d_bn(x, 256, 1)branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding='valid')branch_1 = conv2d_bn(x, 256, 1)branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding='valid')branch_2 = conv2d_bn(x, 256, 1)branch_2 = conv2d_bn(branch_2, 288, 3)branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding='valid')branch_pool = MaxPooling2D(3, strides=2, padding='valid')(x)branches = [branch_0, branch_1, branch_2, branch_pool]x = Concatenate(name='mixed_7a')(branches)# 10次 Inception-ResNet-C blockfor block_idx in range(1, 10):x = inception_resnet_block(x, scale=0.2, block_type='block8', block_idx=block_idx)x = inception_resnet_block(x, scale=1., activation=None, block_type='block8', block_idx=10)x = conv2d_bn(x, 1536, 1, name='conv_7b')x = GlobalAveragePooling2D(name='avg_pool')(x)x = Dense(classes, activation='softmax', name='predictions')(x)# 创建模型model = Model(inputs, x, name='inception_resnet_v2')return model
model = InceptionResNetV2([299, 299, 3], 58)

相比起VGG系列的网络,该网络复杂了太多了。而且运行速度真的是不敢恭维。但是模型的准确率是真的高。

总结:
本次实验真的体会到了硬件问题对于深度学习的局限性。昨晚用InceptionResNetV2的官方模型开始跑,跑了5-6个小时,才跑了6的epoch,而且每个epoch的准确率在30%左右,真的是心态炸裂。参考大神的博客,利用自己搭建的模型开始跑,这次epochs设置的是5,跑了5个小时左右,才算是跑完。

本次实验InceptionResNetV2模型的效果虽然好,但是自己搭建的CNN模型的准确率照样很好,所以InceptionResNetV2的优势在本次实验中并没有完全体现出来,但是它的强大还不是一般的CNN网络可以比拟的。

希望路过的大佬可以用官方的InceptionResNetV2模型跑一下,如果效果很好,可以私聊我分享一下经验,感谢!!

努力加油a啊

深度学习之基于Inception_ResNet_V2和CNN实现交通标志识别相关推荐

  1. 深度学习100例 | 第3天:交通标志识别 - PyTorch实现

    文章目录 一.导入数据 1. 获取类别名 2. 数据可视化 3. 加载数据文件 4. 划分数据 二.自建模型 三.模型训练 1. 优化器与损失函数 2. 模型的训练 四.结果分析 大家好,我是K同学啊 ...

  2. 基于引导图像滤波的交通标志识别改进框架

    摘要 在雾霾.下雨.光照弱等光照条件下,由于漏检或定位不正确,交通标志识别的精度不是很高.本文提出了一种基于Faster R-CNN和YOLOv5的交通标志识别(TSR)算法.道路标志是从驾驶员的角度 ...

  3. 基于BP 网络分类器的交通标志识别

    基于BP 网络分类器的交通标志识别 摘要:针对中国全部 3 大类 116 个交通标志,即禁令标志.指示标志.警告标志,用 BP 网络实现分类功能. 实验中使用了 3 种测试集,即加高斯噪声.水平扭曲和 ...

  4. 基于MATLAB的SVM的交通标志识别

    基于MATLAB的SVM的交通标志识别 摘要:本文针对三种不同的交通标识(直行.右拐和直行左拐)给出了一种基于SVM识别方法.该方法首先在分析训练集交通标识图片特点的基础上,提取它们的PHOG特征向量 ...

  5. 【深度学习】手把手教你使用CNN进行交通标志识别(已开源)

    在本文中,使用Python编程语言和库Keras和OpenCV建立CNN模型,成功地对交通标志分类器进行分类,准确率达96%.开发了一款交通标志识别应用程序,该应用程序具有图片识别和网络摄像头实时识别 ...

  6. 手把手教你使用CNN进行交通标志识别(已开源)

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 在本文中,使用Python编程语言和库Keras和OpenCV建立 ...

  7. 【交通标志识别】基于matlab GUI BP神经网络交通标志识别系统(含语音报警)【含Matlab源码 2240期】

    ⛄一.BP神经网络交通标志识别简介 道路交通标志用以禁止.警告.指示和限制道路使用者有秩序地使用道路, 保障出行安全.若能自动识别道路交通标志, 则将极大减少道路交通事故的发生.但是由于道路交通错综复 ...

  8. python2.7交通标志识别图_(四)深度学习初探:基于LeNet-5卷积神经网络的交通标志识别...

    1.项目任务 在常见深度学习模型的基础上(本文为LeNet-5),调整模型结构和参数,使用Tensorflow进行部署.利用公开的德国交通标志数据集进行训练,得到模型,并利用该模型对新的图片进行预测. ...

  9. 深度学习之基于opencv和CNN实现人脸识别

    这个项目在之前人工智能课设上做过,但是当时是划水用的别人的.最近自己实现了一下,基本功能可以实现,但是效果并不是很好.容易出现错误识别,或者更改了背景之后识别效果变差的现象.个人以为是数据选取的问题, ...

最新文章

  1. JavaScript面向对象编程笔记
  2. 3_1 StrategyMode.cpp 策略模式
  3. mysql 导入dbm文件_DBM数据导入到mysql数据库方法
  4. okhttp post json 数据_使用python抓取App数据
  5. python3.5安装scrapy_win7+Python3.5下scrapy的安装方法
  6. 一文弄懂“分布式锁”
  7. 如何在macOS Big Sur的Voice Memos中使用增强录音和智能文件夹?
  8. 编程中常见的安全算法
  9. android系统源代码单独编译应用程序
  10. Java并发编程实践-总结
  11. Dubbo入门详细教程
  12. 基于python对doi号通过sci-hub进行pubmed中的文献下载
  13. 我的世界血量显示的服务器,Minecraft|世纪之都|服务器mod:拔刀剑 工业2 高级太阳能 血量显示 Nei 聊天泡泡等...
  14. python抓取网易云音乐评论_Python 爬取网易云音乐评论
  15. STM32:SPI驱动ADXL345
  16. 支付宝周期扣款Java逻辑代码
  17. 给生命一段独处的时光
  18. 批处理 检测U盘插入并自动备份文件
  19. 宾得常用镜头群[转自东河寒梅]
  20. SQL注入的原理、过程及如何防范

热门文章

  1. swift5主线程延迟操作的几种写法
  2. delphi edit里面的文字如何添加下划线_标题设计如何处理更吸引人?来看设计高手的实用技巧...
  3. java冒泡排序算法代码降序_冒泡排序(起泡排序)算法及其C语言实现
  4. java8 例外网站_Java8兰巴达斯和例外
  5. 获取某个输入框的字符长度_收藏,最全的字符串函数方法,总有你用到的~
  6. 一些Setup Factory 教程的链接
  7. Invocation failed Unexpected end of file from server java.lang.RuntimeException: Invocation failed U
  8. IDEA git修改远程仓库地址
  9. Apache Spark 的设计与实现(job逻辑执行图)
  10. TortoiseSVN与VisualSVN Server搭建SVN版本控制系统【转】