python model如何获取分类错误的数据_使用CNN和Keras进行95%准确度的交通标志识别的Python项目
Python项目–交通标志识别
您一定已经听说过自动驾驶汽车,乘客可以在其中完全依靠汽车行驶。但是要实现5级自动驾驶,车辆必须了解并遵守所有交通规则。
在人工智能和技术进步的世界中,许多研究人员和大公司,例如特斯拉,优步,谷歌,奔驰,丰田,福特,奥迪等,都在研究自动驾驶汽车和无人驾驶汽车。因此,为了实现该技术的准确性,车辆应该能够分辨交通标志并做出相应的决策。
什么是交通标志识别?
有几种不同类型的交通标志,例如限速,禁止进入,交通信号,左转或右转,行人横穿,重型车辆无法通过等。交通标志分类是识别交通标志属于哪个类别的过程。
交通标志识别–关于Python项目
在这个Python项目示例中,我们将构建一个深度神经网络模型,该模型可以将图像中出现的交通标志分类为不同的类别。通过这种模型,我们能够分辨和分析交通标志,这对于所有自动驾驶汽车都是非常重要的任务。
Python项目的数据集
对于此项目,我们使用Kaggle上可用的公共数据集:
(https://www.kaggle.com/meowmeowmeowmeowmeow/gtsrb-german-traffic-sign)
数据集包含50,000多种不同交通标志的图像。它进一步分为43个不同的类别。数据集变化很大,有些类别的图像很多,而有些类别的图像很少。数据集的大小约为300 MB。数据集有一个train文件夹,其中包含每个类中的图像;还有一个test文件夹,我们将使用它来测试模型。
先决条件
该项目需要具备Keras,Matplotlib,Scikit-learn,Pandas,PIL和图像分类的先验知识。
要安装用于此Python数据科学项目的必要库,请在终端中输入以下命令:
pip install tensorflow keras sklearn matplotlib pandas pillow
生成Python项目的步骤
创建一个Python脚本文件,并在项目文件夹中将其命名为traffic_signs.py。
我们分四个步骤讨论了我们建立交通标志分类模型的方法:
- 探索数据集
- 建立CNN模型
- 训练并验证模型
- 用测试数据集测试模型
步骤1:探索数据集
我们的“train”文件夹包含43个文件夹,每个文件夹代表一个不同的类别。文件夹的范围是0到42。在OS模块的帮助下,我们遍历所有类,并将图像及其各自的标签附加到数据和标签列表中。
PIL库用于将图像内容打开到数组中。
最后,我们将所有图像及其标签存储到列表(数据和标签)中。
我们需要将列表转换为numpy数组,以便提供给模型。
数据的形状为(39209、30、30、3),这意味着有39,209张图像的尺寸为30×30像素,最后的3意味着数据包含彩色图像(RGB值)。
通过sklearn库,我们使用train_test_split()方法来拆分训练和测试数据。
从keras.utils库中,我们使用to_categorical方法将y_train和t_test中存在的标签转换为独热编码。
步骤2:建立CNN模型
为了将图像分类为各自的类别,我们将建立一个CNN模型(卷积神经网络)。CNN最适合用于图像分类。
我们模型的架构是:
- 2个Conv2D层(过滤器= 32,kernel_size =(5,5),激活=“ relu”)
- MaxPool2D层(pool_size =(2,2))
- Dropout层(速率= 0.25)
- 2个Conv2D层(过滤器= 64,kernel_size =(3,3),激活=“ relu”)
- MaxPool2D层(pool_size =(2,2))
- Dropout层(速率= 0.25)
- 将层压平,将层压缩成一维
- 密集的全连接层(256个节点,激活=“ relu”)
- Dropout层(速率= 0.5)
- 密集层(43个节点,激活=“ softmax”)
我们使用性能良好的Adam优化器编译模型,并且损失为“ categorical_crossentropy”,因为我们有多个要分类的类。
步骤3:训练和验证模型
构建模型架构后,我们然后使用model.fit()训练模型。我尝试使用32和64的批处理量。我们的模型在64批量下表现更好。并且在15个周期之后,准确性是稳定的。
我们的模型在训练数据集上的准确率达到95%。使用matplotlib,我们可以绘制图形以确保准确性和损失。
绘制精度图:
精度和损失图
步骤4:使用测试数据集测试我们的模型
我们的数据集包含一个测试文件夹,并且在test.csv文件中,我们具有与图像路径及其各自的类标签有关的详细信息。我们使用pandas提取图像路径和标签。然后要预测模型,我们必须将图像尺寸调整为30×30像素,并创建一个包含所有图像数据的numpy数组。从sklearn.metrics中,我们导入了precision_score并观察了我们的模型如何预测实际标签。在此模型中,我们达到了95%的精度。
最后,我们将使用Keras model.save()函数保存我们训练过的模型。
model.save('traffic_classifier.h5')
完整的源代码:
import numpy as np import pandas as pd import matplotlib.pyplot as pltimport cv2import tensorflow as tffrom PIL import Imageimport osfrom sklearn.model_selection import train_test_splitfrom keras.utils import to_categoricalfrom keras.models import Sequential, load_modelfrom keras.layers import Conv2D, MaxPool2D, Dense, Flatten, Dropoutdata = []labels = []classes = 43cur_path = os.getcwd()#Retrieving the images and their labels for i in range(classes): path = os.path.join(cur_path,'train',str(i)) images = os.listdir(path) for a in images: try: image = Image.open(path + ''+ a) image = image.resize((30,30)) image = np.array(image) #sim = Image.fromarray(image) data.append(image) labels.append(i) except: print("Error loading image")#Converting lists into numpy arraysdata = np.array(data)labels = np.array(labels)print(data.shape, labels.shape)#Splitting training and testing datasetX_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)print(X_train.shape, X_test.shape, y_train.shape, y_test.shape)#Converting the labels into one hot encodingy_train = to_categorical(y_train, 43)y_test = to_categorical(y_test, 43)#Building the modelmodel = Sequential()model.add(Conv2D(filters=32, kernel_size=(5,5), activation='relu', input_shape=X_train.shape[1:]))model.add(Conv2D(filters=32, kernel_size=(5,5), activation='relu'))model.add(MaxPool2D(pool_size=(2, 2)))model.add(Dropout(rate=0.25))model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))model.add(Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))model.add(MaxPool2D(pool_size=(2, 2)))model.add(Dropout(rate=0.25))model.add(Flatten())model.add(Dense(256, activation='relu'))model.add(Dropout(rate=0.5))model.add(Dense(43, activation='softmax'))#Compilation of the modelmodel.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])epochs = 15history = model.fit(X_train, y_train, batch_size=32, epochs=epochs, validation_data=(X_test, y_test))model.save("my_model.h5")#plotting graphs for accuracy plt.figure(0)plt.plot(history.history['acc'], label='training accuracy')plt.plot(history.history['val_acc'], label='val accuracy')plt.title('Accuracy')plt.xlabel('epochs')plt.ylabel('accuracy')plt.legend()plt.show()plt.figure(1)plt.plot(history.history['loss'], label='training loss')plt.plot(history.history['val_loss'], label='val loss')plt.title('Loss')plt.xlabel('epochs')plt.ylabel('loss')plt.legend()plt.show()#testing accuracy on test datasetfrom sklearn.metrics import accuracy_scorey_test = pd.read_csv('Test.csv')labels = y_test["ClassId"].valuesimgs = y_test["Path"].valuesdata=[]for img in imgs: image = Image.open(img) image = image.resize((30,30)) data.append(np.array(image))X_test=np.array(data)pred = model.predict_classes(X_test) #Accuracy with the test datafrom sklearn.metrics import accuracy_scoreprint(accuracy_score(labels, pred))model.save('traffic_classifier1.h5')
交通标志分类器GUI
现在,我们将使用Tkinter为我们的交通标志分类器构建图形用户界面。Tkinter是标准python库中的GUI工具包。在项目文件夹中创建一个新文件,然后复制以下代码。将其另存为gui.py,您可以通过在命令行中键入python gui.py来运行代码。
在此文件中,我们首先使用Keras加载了经过训练的模型'traffic_classifier.h5'。然后,我们构建用于上传图像的GUI,并使用一个按钮进行分类,然后调用classify()函数。classify()函数将图像转换为形状的尺寸(1、30、30、3)。这是因为要预测交通标志,我们必须提供构建模型时使用的相同尺寸。然后我们预测类,model.predict_classes(image)向我们返回一个介于(0-42)之间的数字,该数字表示它所属的类。我们使用字典来获取有关该类的信息。这是gui.py文件的代码。
import tkinter as tkfrom tkinter import filedialogfrom tkinter import *from PIL import ImageTk, Imageimport numpy#load the trained model to classify signfrom keras.models import load_modelmodel = load_model('traffic_classifier.h5')#dictionary to label all traffic signs class.classes = { 1:'Speed limit (20km/h)', 2:'Speed limit (30km/h)', 3:'Speed limit (50km/h)', 4:'Speed limit (60km/h)', 5:'Speed limit (70km/h)', 6:'Speed limit (80km/h)', 7:'End of speed limit (80km/h)', 8:'Speed limit (100km/h)', 9:'Speed limit (120km/h)', 10:'No passing', 11:'No passing veh over 3.5 tons', 12:'Right-of-way at intersection', 13:'Priority road', 14:'Yield', 15:'Stop', 16:'No vehicles', 17:'Veh > 3.5 tons prohibited', 18:'No entry', 19:'General caution', 20:'Dangerous curve left', 21:'Dangerous curve right', 22:'Double curve', 23:'Bumpy road', 24:'Slippery road', 25:'Road narrows on the right', 26:'Road work', 27:'Traffic signals', 28:'Pedestrians', 29:'Children crossing', 30:'Bicycles crossing', 31:'Beware of ice/snow', 32:'Wild animals crossing', 33:'End speed + passing limits', 34:'Turn right ahead', 35:'Turn left ahead', 36:'Ahead only', 37:'Go straight or right', 38:'Go straight or left', 39:'Keep right', 40:'Keep left', 41:'Roundabout mandatory', 42:'End of no passing', 43:'End no passing veh > 3.5 tons' }#initialise GUItop=tk.Tk()top.geometry('800x600')top.title('Traffic sign classification')top.configure(background='#CDCDCD')label=Label(top,background='#CDCDCD', font=('arial',15,'bold'))sign_image = Label(top)def classify(file_path): global label_packed image = Image.open(file_path) image = image.resize((30,30)) image = numpy.expand_dims(image, axis=0) image = numpy.array(image) pred = model.predict_classes([image])[0] sign = classes[pred+1] print(sign) label.configure(foreground='#011638', text=sign) def show_classify_button(file_path): classify_b=Button(top,text="Classify Image",command=lambda: classify(file_path),padx=10,pady=5) classify_b.configure(background='#364156', foreground='white',font=('arial',10,'bold')) classify_b.place(relx=0.79,rely=0.46)def upload_image(): try: file_path=filedialog.askopenfilename() uploaded=Image.open(file_path) uploaded.thumbnail(((top.winfo_width()/2.25),(top.winfo_height()/2.25))) im=ImageTk.PhotoImage(uploaded) sign_image.configure(image=im) sign_image.image=im label.configure(text='') show_classify_button(file_path) except: pass upload=Button(top,text="Upload an image",command=upload_image,padx=10,pady=5)upload.configure(background='#364156', foreground='white',font=('arial',10,'bold'))upload.pack(side=BOTTOM,pady=50)sign_image.pack(side=BOTTOM,expand=True)label.pack(side=BOTTOM,expand=True)heading = Label(top, text="Know Your Traffic Sign",pady=20, font=('arial',20,'bold'))heading.configure(background='#CDCDCD',foreground='#364156')heading.pack()top.mainloop()
输出:
运行后界面
选择图片后的分析
摘要
在这个Python项目中,我们已经成功地以95%的准确度对交通标志分类器进行了分类,并且还可视化了我们的准确度和损失随时间的变化,这对于简单的CNN模型来说是相当不错的。
python model如何获取分类错误的数据_使用CNN和Keras进行95%准确度的交通标志识别的Python项目相关推荐
- 【随记】Python:前端表格获取到的填写数据插入到数据库表格中数据类型问题
Python:前端表格获取到的填写数据插入到数据库表格中数据类型问题 背景 问题再现 结论 背景 用户在前端界面的表格中填写数据,通过 text() 获取到的数据插入到数据库表中,该过程涉及到了数据类 ...
- 一篇文章教会你利用Python网络爬虫获取分类图片
点击上方"IT共享之家",进行关注 回复"资料"可获赠Python学习福利 [一.项目背景] 博海拾贝是一支互联网从业者在线教育的团队,扎根于中国教育行业以及互 ...
- python带你获取视频及弹幕数据~知识点满满(含完整源代码)
前言 嗨喽!大家好呀,这里是魔王~** 模块安装问题: 如果安装python第三方模块: win + R 输入 cmd 点击确定, 输入安装命令 pip install 模块名 (pip instal ...
- 【Python】爬虫获取微博热搜数据,response中文显示“\u7814\u7a76\u8bc1\u5b9e\u”
问题描述 在爬虫获取微博热搜数据的时候,response中文出现了不便于理解的字段,截取如下: ......[{"title_sub":"\u7814\u7a76\u8b ...
- 关于python变量使用下列说法中错误的是_关于Python内存管理,下列说法错误的是_学小易找答案...
[单选题]Python 编程中用代码缩进表示逻辑递进关系,通常用几个空格 [判断题]决定系数(英语:coefficient of determination,记为R2或r2)在统计学中用于度量因变量的 ...
- 关于python内存管理下列说法中错误的是_2.关于Python内存管理,下列说法错误的是_学小易找答案...
[单选题]6.下列表达式的值为True的是 [单选题]2.关于Python内存管理,下列说法错误的是 [单选题]890.具有很强的储存功能,存储空间比较大的配送中心所属的类型是( ) [多选题]109 ...
- sql 获取两个月内数据_如何在3个月的时间内自学成为数据分析师?
从一名0基础的用户运营自学成为数据分析师,我花了大半年的时间,但是抛开工作时间,系统性的学习只花了3个月. 这篇文章会从学习资源和学习路径两个方面分享我的自学经验,希望能对大家有所帮助. 先来说说有哪 ...
- 机智云获取树莓派传来的数据_哪些数据对云来说太冒险了?
机智云获取树莓派传来的数据 在这个由四部分组成的系列文章中,我们一直在研究每个组织在将操作迁移到云时(特别是在混合多云环境中)应避免的陷阱. 在第一部分中 ,我们介绍了混合云和多云的基本定义以及我们的 ...
- 利用Python网络爬虫获取分类图片,简单处理反爬教学
本文的文字及图片来源于网络,仅供学习.交流使用,不具有任何商业用途,版权归原作者所有,如有问题请及时联系我们以作处理 本文章来自腾讯云 作者:Python进阶者 想要学习Python?有问题得不到第一 ...
最新文章
- iOS架构-静态库.framework(引用第三方SDK、开源库、资源包)(9)
- 但是尚未从池中获取连接_SQLServer超时时间已到,但是尚未从池中获取连接
- ubuntu pip更新_Cubietruck开发板折腾002:安装Python管理工具pip
- MSU发布2018年视频压缩评比报告
- IOS Swift5.5的通知写法
- React-引领未来的用户界面开发框架-读书笔记(七)
- ubuntu14.04上网问题
- CTF工具-seccomp-tools
- HDU 1476 Sudoku Killer
- Flutter AnimatedBuilder 的基本使用
- SpringBoot 实现Session共享
- 作战手册-2011-12-18
- pytest.mark.parametrize()基本用法
- java.net.SocketException: Software caused connection abort: socket write error
- oppoJava面试!java开发视频聊天
- 每年考证时间表(绝对有用)
- Typora设置图片背景
- 设置iSCSI的发起程序(客户端)(三)
- 54、消防控制室的设置要求
- android 换肤 字体颜色,Android换肤