自定义卷积网络完成分类。图像预处理(直方图均衡化增加对比度)。

使用数据:德国交通信号识别,其中train/test dataset的Images and annotations及test dataset的Extended annotations including class ids

实验结果

数据及代码组织结构:

训练过程与结果:

代码

"""
@file: tranfficSignRec.py
@time: 2018/10/26
"""
import pandas as pd
import numpy as np
from skimage import io, color, exposure, transform
import glob
import h5py
from keras.models import Sequential, model_from_json
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D
from keras.layers.pooling import MaxPooling2D
from keras.optimizers import SGD
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from matplotlib import pyplot as plt
import os
from pathlib import PurePath
import warningswarnings.filterwarnings('ignore')  # 忽略警告NUM_CLASSES = 43  # 43种交通标志
IMG_SIZE = 48  # 图像大小归一化为48batch_size = 32 # 训练的参数
nb_epoch = 10
lr = 0.01# 图像直方图均衡化(调整对比度)、取中心、resize
def preprocess_img(img):hsv = color.rgb2hsv(img)hsv[:, :, 2] = exposure.equalize_hist(hsv[:, :, 2])img = color.hsv2rgb(hsv)ms = min(img.shape[:2])xx = (img.shape[0] - ms) // 2yy = (img.shape[1] - ms) // 2img = img[xx:xx + ms, yy:yy + ms, :]img = transform.resize(img, (IMG_SIZE, IMG_SIZE))return img# 根据路径(图片上层目录)得到标签
def get_class(img_path):return int(PurePath(img_path).parts[- 2])def readfile():# 读取所有图片、标签(onehot),存放至h5py文件try:with h5py.File('X.h5') as hf:X, Y = hf['imgs'][:], hf['labels'][:]print("Loaded images from X.h5")except BaseException:print("Error in reading X.h5. Processing all images...")root_dir = r'../data/GTSRB/Final_Training/Images'imgs = []labels = []all_img_paths = glob.glob(os.path.join(root_dir,'*/*.ppm'))  # 提取所有ppm文件完整路径np.random.shuffle(all_img_paths)  # 打散for img_path in all_img_paths:try:img = preprocess_img(io.imread(img_path))label = get_class(img_path)imgs.append(img)labels.append(label)if len(imgs) % 1000 == 0:print("Processed %d/%d" %(len(imgs), len(all_img_paths)))except BaseException:print('missed', img_path)passX = np.array(imgs, dtype='float32')# labels数组转onehotY = np.eye(len(labels), NUM_CLASSES, dtype=np.uint8)[labels]# 可以加速载入与处理with h5py.File('X.h5', 'w') as hf:hf.create_dataset('imgs', data=X)hf.create_dataset('labels', data=Y)return X, Ydef cnn_model():model = Sequential()model.add(Conv2D(32,(3,3),padding='same',activation='relu',input_shape=(IMG_SIZE,IMG_SIZE,3)))model.add(Conv2D(32, (3, 3), activation='relu'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(0.2))model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(0.2))model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))model.add(Conv2D(128, (3, 3), padding='same', activation='relu'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Dropout(0.2))model.add(Flatten())model.add(Dense(512, activation='relu'))model.add(Dropout(0.5))model.add(Dense(NUM_CLASSES, activation='softmax'))sgd = SGD(lr=lr, decay=1e-6, momentum=0.9, nesterov=True)model.compile(loss='categorical_crossentropy',optimizer=sgd,metrics=['accuracy'])return model# 每10个epoch学习率递减0.1倍
def lr_schedule(epoch):return lr * (0.1 ** int(epoch / 10))def train():model = cnn_model()print(model.summary())X, Y = readfile()history = model.fit(X,Y,batch_size=batch_size,epochs=nb_epoch,validation_split=0.2,shuffle=True,verbose=2,callbacks=[LearningRateScheduler(lr_schedule),# ModelCheckpoint('model.h5',save_best_only=True)])# 可视化训练曲线(train,val)plt.figure(figsize=(8, 4))plt.subplot(1, 2, 1)plot_train_history(history, 'loss', 'val_loss')plt.subplot(1, 2, 2)plot_train_history(history, 'acc', 'val_acc')plt.show()return modeldef plot_train_history(history, train_metrics, val_metrics):plt.plot(history.history.get(train_metrics), '-o')plt.plot(history.history.get(val_metrics), '-o')plt.ylabel(train_metrics)plt.xlabel('Epochs')plt.legend(['train', 'validation'])# 在测试集上测试
def test(model):test = pd.read_csv('../data/GTSRB/GT-final_test.csv',sep=';')X_test = []y_test = []for file_name, class_id in zip(list(test['Filename']), list(test['ClassId'])):img_path = os.path.join('../data/GTSRB/Final_Test/Images/', file_name)X_test.append(preprocess_img(io.imread(img_path)))y_test.append(class_id)X_test = np.array(X_test)y_test = np.array(y_test)print("X_test.shape: ", X_test.shape)print("y_test.shape: ", y_test.shape)y_pred = model.predict_classes(X_test) # 返回预测值acc = np.sum(y_pred == y_test) / np.size(y_pred)print("Test accuracy = {} ".format(acc))if __name__ == '__main__':model=train()test(model)

参考:https://github.com/erhwenkuo/deep-learning-with-keras-notebooks

keras交通信号识别(分类)相关推荐

  1. 无人驾驶汽车系统入门(十一)——深度前馈网络,深度学习的正则化,交通信号识别

    无人驾驶汽车系统入门(十一)--深度前馈网络,深度学习的正则化,交通信号识别 在第九篇博客中我们介绍了神经网络,它是一种机器学习方法,基于经验风险最小化策略,凭借这神经网络的拟合任意函数的能力,我们可 ...

  2. 无人驾驶汽车系统入门:深度前馈网络,深度学习的正则化,交通信号识别

    作者 | 申泽邦(Adam Shan) 兰州大学在读硕士研究生,主攻无人驾驶,深度学习:兰大未来计算研究院无人车团队骨干,在改自己的无人车,参加过很多无人车Hackathon,喜欢极限编程. 在前几十 ...

  3. 今天,送你一份交通行业最全数据集(共享单车、自动驾驶、网约出租车、交通信号识别)

    近几年来共享单车.自动驾驶等交通行业发展得如荼如火,小编也一直有意识地收集相关数据集,经过长时间的积累和沉淀,已经拥有将近300G的交通数据,内容涵盖国内外"自动驾驶"." ...

  4. 2022-2028全球城市轨道交通信号系统市场专题研究及投资评估报告

    城市轨道交通信号系统通常由列车运行自动控制系统(ATC)和车辆段信号控制系统两大部分组成,用于列车联锁.进路控制.列车间隔控制.调度指挥.信息管理.设备工况监测及维护管理等方面,由此构成一个高效综合自 ...

  5. 基于MATLAB的SVM的交通标志识别

    基于MATLAB的SVM的交通标志识别 摘要:本文针对三种不同的交通标识(直行.右拐和直行左拐)给出了一种基于SVM识别方法.该方法首先在分析训练集交通标识图片特点的基础上,提取它们的PHOG特征向量 ...

  6. 交通信号标志识别软件(Python+YOLOv5深度学习模型+清新界面)

    摘要:交通信号标志识别软件用于交通信号标志的检测和识别,利用机器视觉和深度学习智能识别交通标志并可视化记录,以辅助无人驾驶等.本文详细介绍交通信号标志识别软件,在介绍算法原理的同时,给出Python的 ...

  7. 基于cnn的短文本分类_基于时频分布和CNN的信号调制识别分类方法

    文章来源:IET Radar, Sonar & Navigation, 2018, Vol. 12, Iss. 2, pp. 244-249. 作者:Juan Zhang1, Yong Li2 ...

  8. keras系列︱人脸表情分类与识别:opencv人脸检测+Keras情绪分类(四)

    人脸识别热门,表情识别更加.但是表情识别很难,因为人脸的微表情很多,本节介绍一种比较粗线条的表情分类与识别的办法. Keras系列: 1.keras系列︱Sequential与Model模型.kera ...

  9. matlab的交通灯信号识别,交通灯识别系统.docx

    摘要:近些年来随着城市化建设的迅速加快,机动车数量增加迅速普及到人们的生活中,机动车辆的行驶安全已经成为全世界关注的热点.然而由于道路状况的复杂性与交通信息的多样性使得驾驶员在行驶的过程中注意力不容易 ...

最新文章

  1. Android布局琐碎(原)
  2. linux udp套接字编程获取报文源地址和源端口(二)
  3. 数据结构C#版笔记--堆栈(Stack)
  4. java登录界面命令_Java命令行界面(第18部分):JCLAP
  5. c语言中判断一个字符串是否包含另一个字符串
  6. word交叉引用插入文献后更新域之后编号未更新
  7. 数据恢复-SQL被注入攻击程序的应对策略(ORA-16703)
  8. bzoj5017 [Snoi2017]炸弹
  9. C++socket编程(五):5.2 tcp编程总结
  10. 驱动精灵恶意投放后门程序 云控劫持流量、诱导推广
  11. python学习笔记(六):if语句之处理数据
  12. 03_美国医疗保健行业的数据介绍
  13. Xib中设置view的BorderColor 及 ShadowColor
  14. 开年工作重点:帮助同事找到工作的价值
  15. Java并发机制的底层实现原理--volatile
  16. 二阶魔方还原 - 4步2公式
  17. 包装exp是什么意思_药瓶说明中EXP是什么意思?
  18. unix/linux 系统 进程资源限制参数
  19. 绘一幅人人出彩的教育画卷
  20. iphone 代码片段2

热门文章

  1. Spring Cloud Alibaba入门实践(五)-远程调用Feign
  2. ie浏览器点击F12没反应
  3. python能用来制作游戏吗_python 做游戏开发怎么样?
  4. c语言读取midi文件举例子,c# – 使用NAudio从MIDI文件中读取音符
  5. java计算机毕业设计科普网站源码+mysql数据库+系统+lw文档+部署
  6. 深入了解JVM之垃圾回收(二)
  7. 阿里10W字JAVA面试手册(面试题+简历攻略)
  8. flash迷宫游戏教程
  9. vue-awesome-swiper 传参控制滑动位置 滚动位置 slideTo 备注防止后期忘记
  10. unity开发抽奖系统