点击上方“小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

在本文中,使用Python编程语言和库Keras和OpenCV建立CNN模型,成功地对交通标志分类器进行分类,准确率达96%。开发了一款交通标志识别应用程序,该应用程序具有图片识别和网络摄像头实时识别两种工作方式。

本文的GitHub:https://github.com/Daulettulegenov/TSR_CNN

提供一个开源的交通标志的数据集,希望能够帮助到各位小伙伴:http://www.nlpr.ia.ac.cn/pal/trafficdata/recognition.html

近年来,计算机视觉是现代技术发展的一个方向。这个方向的主要任务是对照片或摄像机中的物体进行分类。在通常的问题中,使用基于案例的机器学习方法来解决。本文介绍了利用机器学习算法进行计算机视觉在交通标志识别中的应用。路标是一种外形固定的扁平人造物体。道路标志识别算法应用于两个实际问题。第一个任务是控制自动驾驶汽车。无人驾驶车辆控制系统的一个关键组成部分是物体识别。识别的对象主要是行人、其他车辆、交通灯和路标。第二个使用交通标志识别的任务是基于安装在汽车上的DVRs的数据自动绘制地图。接下来将详细介绍如果搭建能够识别交通标志的CNN网络。

导入必要的库

# data analysis and wrangling
import numpy as np
import pandas as pd
import os
import random# visualization
import matplotlib.pyplot as plt
from PIL import Image
# machine learning
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from keras.utils.np_utils import to_categorical
from keras.layers import Dropout, Flatten
from keras.layers.convolutional import Conv2D, MaxPooling2D
import cv2
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import ImageDataGenerator

加载数据

Python Pandas包帮助我们处理数据集。我们首先将训练和测试数据集获取到Pandas DataFrames中。我们还将这些数据集组合起来,在两个数据集上一起运行某些操作。

# Importing of the Images
count = 0
images = []
classNo = []
myList = os.listdir(path)
print("Total Classes Detected:",len(myList))
noOfClasses=len(myList)
print("Importing Classes.....")
for x in range (0,len(myList)):myPicList = os.listdir(path+"/"+str(count))for y in myPicList:curImg = cv2.imread(path+"/"+str(count)+"/"+y)curImg = cv2.resize(curImg, (30, 30))images.append(curImg)classNo.append(count)print(count, end =" ")count +=1
print(" ")
images = np.array(images)
classNo = np.array(classNo)

为了对已实现的系统进行适当的训练和评估,我们将数据集分为3组。数据集分割:20%测试集,20%验证数据集,剩余的数据用作训练数据集。

# Split Data
X_train, X_test, y_train, y_test = train_test_split(images, classNo, test_size=testRatio)
X_train, X_validation, y_train, y_validation = train_test_split(X_train, y_train, test_size=validationRatio)

该数据集包含34799张图像,由43种类型的路标组成。这些包括基本的道路标志,如限速、停车标志、让路、优先道路、“禁止进入”、“行人”等。

# DISPLAY SOME SAMPLES IMAGES OF ALL THE CLASSES
num_of_samples = []
cols = 5
num_classes = noOfClasses
fig, axs = plt.subplots(nrows=num_classes, ncols=cols, figsize=(5, 300))
fig.tight_layout()
for i in range(cols):for j,row in data.iterrows():x_selected = X_train[y_train == j]axs[j][i].imshow(x_selected[random.randint(0, len(x_selected)- 1), :, :], cmap=plt.get_cmap("gray"))axs[j][i].axis("off")if i == 2:axs[j][i].set_title(str(j)+ "-"+row["Name"])num_of_samples.append(len(x_selected))

# DISPLAY A BAR CHART SHOWING NO OF SAMPLES FOR EACH CATEGORY
print(num_of_samples)
plt.figure(figsize=(12, 4))
plt.bar(range(0, num_classes), num_of_samples)
plt.title("Distribution of the training dataset")
plt.xlabel("Class number")
plt.ylabel("Number of images")
plt.show()

数据集中的类之间存在显著的不平衡。有些类的图像少于200张,而其他类的图像超过1000张。这意味着我们的模型可能偏向于过度代表的类别,特别是当它对自己的预测不自信时。为了解决这个问题,我们使用了现有的图像转换技术。

为了更好的分类,数据集中的所有图像都被转换为灰度图像

# PREPROCESSING THE IMAGES
def grayscale(img):img = cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)return imgdef equalize(img):img =cv2.equalizeHist(img)return imgdef preprocessing(img):img = grayscale(img)     # CONVERT TO GRAYSCALEimg = equalize(img)      # STANDARDIZE THE LIGHTING IN AN IMAGEimg = img/255            # TO NORMALIZE VALUES BETWEEN 0 AND 1 INSTEAD OF 0 TO 255return imgX_train=np.array(list(map(preprocessing,X_train)))  # TO IRETATE AND PREPROCESS ALL IMAGES
X_validation=np.array(list(map(preprocessing,X_validation)))
X_test=np.array(list(map(preprocessing,X_test)))

数据增强是对原始数据集进行增强的一种方法。数据越多,结果越高,这是机器学习的基本规律。

#AUGMENTATAION OF IMAGES: TO MAKEIT MORE GENERIC
dataGen= ImageDataGenerator(width_shift_range=0.1,   # 0.1 = 10%     IF MORE THAN 1 E.G 10 THEN IT REFFERS TO NO. OF  PIXELS EG 10 PIXELSheight_shift_range=0.1,zoom_range=0.2,  # 0.2 MEANS CAN GO FROM 0.8 TO 1.2shear_range=0.1,  # MAGNITUDE OF SHEAR ANGLErotation_range=10)  # DEGREES
dataGen.fit(X_train)
batches= dataGen.flow(X_train,y_train,batch_size=20)  # REQUESTING DATA GENRATOR TO GENERATE IMAGES  BATCH SIZE = NO. OF IMAGES CREAED EACH TIME ITS CALLED
X_batch,y_batch = next(batches)

热编码用于我们的分类值y_train、y_test、y_validation。

y_train = to_categorical(y_train,noOfClasses)
y_validation = to_categorical(y_validation,noOfClasses)
y_test = to_categorical(y_test,noOfClasses)

使用Keras库创建一个神经网络。下面是创建模型结构的代码:

def myModel():model = Sequential()model.add(Conv2D(filters=32, kernel_size=(5,5), activation='relu', input_shape=X_train.shape[1:]))model.add(Conv2D(filters=32, kernel_size=(5,5), activation='relu'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(rate=0.25))model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(rate=0.25))model.add(Flatten())model.add(Dense(256, activation='relu'))model.add(Dropout(rate=0.5))model.add(Dense(43, activation='softmax'))model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])return model
# TRAIN
model = myModel()
print(model.summary())
history = model.fit(X_train, y_train, batch_size=batch_size_val, epochs=epochs_val, validation_data=(X_validation,y_validation))

上面的代码使用了6个卷积层和1个全连接层。首先,在模型中添加带有32个滤波器的卷积层。接下来,我们添加一个带有64个过滤器的卷积层。在每一层的后面,增加一个窗口大小为2 × 2的最大拉层。还添加了系数为0.25和0.5的Dropout层,以便网络不会再训练。在最后几行中,我们添加了一个稠密的稠密层,该层使用softmax激活函数在43个类中执行分类。

在最后一个epoch结束时,我们得到以下值:loss = 0.0523;准确度= 0.9832;Val_loss = 0.0200;Val_accuracy = 0.9943,这个结果看起来非常好。之后绘制我们的训练过程

#PLOT
plt.figure(1)
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.legend(['training','validation'])
plt.title('loss')
plt.xlabel('epoch')
plt.figure(2)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training','validation'])
plt.title('Acurracy')
plt.xlabel('epoch')
plt.show()
score =model.evaluate(X_test,y_test,verbose=0)
print('Test Score:',score[0])
print('Test Accuracy:',score[1])

#testing accuracy on test dataset
from sklearn.metrics import accuracy_scorey_test = pd.read_csv('Test.csv')
labels = y_test["ClassId"].values
imgs = y_test["Path"].values
data=[]
for img in imgs:image = Image.open(img)image = image.resize((30,30))data.append(np.array(image))
X_test=np.array(data)
X_test=np.array(list(map(preprocessing,X_test)))
predict_x=model.predict(X_test)
pred=np.argmax(predict_x,axis=1)
print(accuracy_score(labels, pred))

我们在测试数据集中测试了构建的模型,得到了96%的准确性。

使用内置函数model_name.save(),我们可以保存一个模型以供以后使用。该功能将模型保存在本地的.p文件中,这样我们就不必一遍又一遍地重新训练模型而浪费大量的时间。

model.save("CNN_model_3.h5")

接下来给大家看一些识别的结果

好消息!

小白学视觉知识星球

开始面向外开放啦

手把手教你使用CNN进行交通标志识别(已开源)相关推荐

  1. 【深度学习】手把手教你使用CNN进行交通标志识别(已开源)

    在本文中,使用Python编程语言和库Keras和OpenCV建立CNN模型,成功地对交通标志分类器进行分类,准确率达96%.开发了一款交通标志识别应用程序,该应用程序具有图片识别和网络摄像头实时识别 ...

  2. 深度学习之基于Inception_ResNet_V2和CNN实现交通标志识别

    这次的结果是没有想到的,利用官方的Inception_ResNet_V2模型识别效果差到爆,应该是博主自己的问题,但是不知道哪儿出错了. 本次实验分别基于自己搭建的Inception_ResNet_V ...

  3. python交通标志识别_YOLOv3目标检测实战:交通标志识别

    在无人驾驶中,交通标志识别是一项重要的任务.本项目以美国交通标志数据集LISA为训练对象,采用YOLOv3目标检测方法实现实时交通标志识别. 具体项目过程包括包括:安装Darknet.下载LISA交通 ...

  4. 基于引导图像滤波的交通标志识别改进框架

    摘要 在雾霾.下雨.光照弱等光照条件下,由于漏检或定位不正确,交通标志识别的精度不是很高.本文提出了一种基于Faster R-CNN和YOLOv5的交通标志识别(TSR)算法.道路标志是从驾驶员的角度 ...

  5. opencv交通标志识别_教你从零开始做一个基于深度学习的交通标志识别系统

    教你从零开始做一个基于深度学习的交通标志识别系统 基于Yolo v3的交通标志识别系统及源码 自动驾驶之--交通标志识别 在本文章你可以学习到如何训练自己采集的数据集,生成模型,并用yolo v3算法 ...

  6. 基于深度学习的大规模交通标志识别(附6GB交通标志数据集)

    01 1.文章信息 <Deep Learning for Large-Scale Traffic-Sign Detection and Recognition>. 国外学者2020年发在I ...

  7. 深度学习交通标志识别项目

    主要内容 在本文中,使用Python编程语言和库Keras和OpenCV建立CNN模型,成功地对交通标志分类器进行分类,准确率达96%.开发了一款交通标志识别应用程序,该应用程序具有图片识别和网络摄像 ...

  8. Python交通标志识别基于卷积神经网络的保姆级教程(Tensorflow)

    项目介绍 TensorFlow2.X 搭建卷积神经网络(CNN),实现交通标志识别.搭建的卷积神经网络是类似VGG的结构(卷积层与池化层反复堆叠,然后经过全连接层,最后用softmax映射为每个类别的 ...

  9. 交通标志识别论文综述

    交通标志识别是计算机视觉领域的一个研究热点.主要研究方向是使用机器学习和图像处理技术来识别交通标志. 近年来,随着深度学习技术的发展,交通标志识别的研究取得了显著进展.许多研究人员提出了基于卷积神经网 ...

最新文章

  1. 定时器:SetTimer
  2. matlab中怎样画出散点图,将这些散点连接成线
  3. 企业架构(三)——联邦企业架构框架(FEAF)
  4. mysql 定义XML字段_MyBatis之基于XML的属性与列名映射
  5. Http协议--Get和Post区别
  6. Intel超线程技术 Hyper-Threading Technology (1) - 引言与历史
  7. 健康评测 php,8款超好用的健康APP测评推荐!
  8. access查询设计sol视图_Access删除索引
  9. 用Bat脚本写一个无限弹窗代码
  10. html5个人简历代码模板,个人简历HTML模板
  11. STM32单片机全自动锂电池容量电量检测放电电流电池电压ACS712
  12. fudanNLP-使用
  13. 【Python3】简易爬虫实现船舶的MMSI的获取
  14. Elasticsearch安全认证
  15. 祝萍:后疫情时代,医美运营既要走心也要反套路
  16. CentOS7.9安装Nextcloud+ocDownloader+aria2使用Nextcloud网盘做离线下载服务器
  17. 2-SAT问题,一个神奇的东西
  18. CVE-2020-1472: NetLogon特权提升漏洞通告
  19. 计算机命令vty是什么意思,华为交换机基础命令中user interface 0和user-interface vty 0的区别...
  20. 下载电影、软件、工具的利器--讯雷

热门文章

  1. python正则匹配括号内任意字符,python 正则匹配 获取括号内字符
  2. 并发编程(七)好用的线程池ThreadPoolExecutor
  3. 超级计算机国产cpu,为何国产超级计算机已经领先全世界了,而国产cpu却依然落后?...
  4. Blockathon2018(上海)顺利结束,9个项目打开区块链落地新思路
  5. Linux进程与计划任务
  6. 面试突击:什么是粘包和半包?怎么解决?
  7. 安卓培训机构排名!这篇文章可以满足你80%日常工作!跳槽薪资翻倍
  8. SpringBoot时区配置
  9. 学习java开发培训
  10. c# picturebox控件的使用方法介绍