注意:使用tensorflow2.3以上的版本

参考网址:https://blog.csdn.net/ECHOSON/article/details/117600329

通过爬取百度的图片,准备数据集,会保存在data文件中

import requests
import re
import osheaders = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.125 Safari/537.36'}
name = input('请输入要爬取的图片类别:')
num = 0
num_1 = 0
num_2 = 0
x = input('请输入要爬取的图片数量?(1等于60张图片,2等于120张图片):')
list_1 = []
for i in range(int(x)):name_1 = os.getcwd()name_2 = os.path.join(name_1, 'data/' + name)url = 'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word=' + name + '&pn=' + str(i * 30)res = requests.get(url, headers=headers)htlm_1 = res.content.decode()a = re.findall('"objURL":"(.*?)",', htlm_1)if not os.path.exists(name_2):os.makedirs(name_2)for b in a:try:b_1 = re.findall('https:(.*?)&', b)b_2 = ''.join(b_1)if b_2 not in list_1:num = num + 1img = requests.get(b)f = open(os.path.join(name_1, 'data/' + name, name + str(num) + '.jpg'), 'ab')print('---------正在下载第' + str(num) + '张图片----------')f.write(img.content)f.close()list_1.append(b_2)elif b_2 in list_1:num_1 = num_1 + 1continueexcept Exception as e:print('---------第' + str(num) + '张图片无法下载----------')num_2 = num_2 + 1continueprint('下载完成,总共下载{}张,成功下载:{}张,重复下载:{}张,下载失败:{}张'.format(num + num_1 + num_2, num, num_1, num_2))

划分数据集

# 作者: 宋老狗
import os
import random
from shutil import copy2def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.0, test_scale=0.2):'''读取源数据文件夹,生成划分好的文件夹,分为trian、val、test三个文件夹进行:param src_data_folder: 源文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/src_data:param target_data_folder: 目标文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data:param train_scale: 训练集比例:param val_scale: 验证集比例:param test_scale: 测试集比例:return:'''print("开始数据集划分")class_names = os.listdir(src_data_folder)# 在目标目录下创建文件夹split_names = ['train', 'val', 'test']for split_name in split_names:split_path = os.path.join(target_data_folder, split_name)if os.path.isdir(split_path):passelse:os.mkdir(split_path)# 然后在split_path的目录下创建类别文件夹for class_name in class_names:class_split_path = os.path.join(split_path, class_name)if os.path.isdir(class_split_path):passelse:os.mkdir(class_split_path)# 按照比例划分数据集,并进行数据图片的复制# 首先进行分类遍历for class_name in class_names:current_class_data_path = os.path.join(src_data_folder, class_name)current_all_data = os.listdir(current_class_data_path)current_data_length = len(current_all_data)current_data_index_list = list(range(current_data_length))random.shuffle(current_data_index_list)train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)train_stop_flag = current_data_length * train_scaleval_stop_flag = current_data_length * (train_scale + val_scale)current_idx = 0train_num = 0val_num = 0test_num = 0for i in current_data_index_list:src_img_path = os.path.join(current_class_data_path, current_all_data[i])if current_idx <= train_stop_flag:copy2(src_img_path, train_folder)# print("{}复制到了{}".format(src_img_path, train_folder))train_num = train_num + 1elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):copy2(src_img_path, val_folder)# print("{}复制到了{}".format(src_img_path, val_folder))val_num = val_num + 1else:copy2(src_img_path, test_folder)# print("{}复制到了{}".format(src_img_path, test_folder))test_num = test_num + 1current_idx = current_idx + 1print("*********************************{}*************************************".format(class_name))print("{}类按照{}:{}:{}的比例划分完成,一共{}张图片".format(class_name, train_scale, val_scale, test_scale, current_data_length))print("训练集{}:{}张".format(train_folder, train_num))print("验证集{}:{}张".format(val_folder, val_num))print("测试集{}:{}张".format(test_folder, test_num))if __name__ == '__main__':src_data_folder = "D:/python-project/untitled/data/"  # todo 原始数据集目录target_data_folder = "D:/python-project/untitled/split_data/"  # todo 数据集分割之后存放的目录data_set_split(src_data_folder, target_data_folder)

训练数据集

# -*- coding: utf-8 -*-
# @Time    : 2021/6/17 20:29
# @Author  : dejahu
# @Email   : 1148392984@qq.com
# @File    : train_cnn.py
# @Software: PyCharm
# @Brief   : cnn模型训练代码,训练的代码会保存在models目录下,折线图会保存在results目录下import tensorflow as tf
import matplotlib.pyplot as plt
from time import *# 数据集加载函数,指明数据集的位置并统一处理为imgheight*imgwidth的大小,同时设置batch
def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):# 加载训练集train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,label_mode='categorical',seed=123,image_size=(img_height, img_width),batch_size=batch_size)# 加载测试集val_ds = tf.keras.preprocessing.image_dataset_from_directory(test_data_dir,label_mode='categorical',seed=123,image_size=(img_height, img_width),batch_size=batch_size)class_names = train_ds.class_names# 返回处理之后的训练集、验证集和类名return train_ds, val_ds, class_names# 构建CNN模型
def model_load(IMG_SHAPE=(224, 224, 3), class_num=12):# 搭建模型model = tf.keras.models.Sequential([# 对模型做归一化的处理,将0-255之间的数字统一处理到0到1之间tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=IMG_SHAPE),# 卷积层,该卷积层的输出为32个通道,卷积核的大小是3*3,激活函数为relutf.keras.layers.Conv2D(32, (3, 3), activation='relu'),# 添加池化层,池化的kernel大小是2*2tf.keras.layers.MaxPooling2D(2, 2),# Add another convolution# 卷积层,输出为64个通道,卷积核大小为3*3,激活函数为relutf.keras.layers.Conv2D(64, (3, 3), activation='relu'),# 池化层,最大池化,对2*2的区域进行池化操作tf.keras.layers.MaxPooling2D(2, 2),# 将二维的输出转化为一维tf.keras.layers.Flatten(),# The same 128 dense layers, and 10 output layers as in the pre-convolution example:tf.keras.layers.Dense(128, activation='relu'),# 通过softmax函数将模型输出为类名长度的神经元上,激活函数采用softmax对应概率值tf.keras.layers.Dense(class_num, activation='softmax')])# 输出模型信息model.summary()# 指明模型的训练参数,优化器为sgd优化器,损失函数为交叉熵损失函数model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])# 返回模型return model# 展示训练过程的曲线
def show_loss_acc(history):# 从history中提取模型训练集和验证集准确率信息和误差信息acc = history.history['accuracy']val_acc = history.history['val_accuracy']loss = history.history['loss']val_loss = history.history['val_loss']# 按照上下结构将图画输出plt.figure(figsize=(8, 8))plt.subplot(2, 1, 1)plt.plot(acc, label='Training Accuracy')plt.plot(val_acc, label='Validation Accuracy')plt.legend(loc='lower right')plt.ylabel('Accuracy')plt.ylim([min(plt.ylim()), 1])plt.title('Training and Validation Accuracy')plt.subplot(2, 1, 2)plt.plot(loss, label='Training Loss')plt.plot(val_loss, label='Validation Loss')plt.legend(loc='upper right')plt.ylabel('Cross Entropy')plt.title('Training and Validation Loss')plt.xlabel('epoch')plt.savefig('D:/python-project/untitled/split_data/results_cnn.png', dpi=100)def train(epochs):# 开始训练,记录开始时间begin_time = time()# todo 加载数据集, 修改为你的数据集的路径train_ds, val_ds, class_names = data_load("D:/python-project/untitled/split_data/train","D:/python-project/untitled/split_data/test", 224, 224, 16)print(class_names)# 加载模型model = model_load(class_num=len(class_names))# 指明训练的轮数epoch,开始训练history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)# todo 保存模型, 修改为你要保存的模型的名称model.save("models/cnn_fv.h5")# 记录结束时间end_time = time()run_time = end_time - begin_timeprint('该循环程序运行时间:', run_time, "s")  # 该循环程序运行时间: 1.4201874732# 绘制模型训练过程图show_loss_acc(history)if __name__ == '__main__':train(epochs=15)

通过界面调用模型对新图片进行识别

# -*- coding: utf-8 -*-
# @Time    : 2021/6/17 20:29
# @Author  : dejahu
# @Email   : 1148392984@qq.com
# @File    : window.py
# @Software: PyCharm
# @Brief   : 图形化界面import tensorflow as tf
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import sys
import cv2
from PIL import Image
import numpy as np
import shutilclass MainWindow(QTabWidget):# 初始化def __init__(self):super().__init__()self.setWindowIcon(QIcon('images/logo.png'))self.setWindowTitle('果蔬识别系统')  # todo 修改系统名称# 模型初始化self.model = tf.keras.models.load_model("models/cnn_fv.h5")  # todo 修改模型名称self.to_predict_name = "images/tim9.jpeg"  # todo 修改初始图片,这个图片要放在images目录下#'土豆', '圣女果', '大白菜', '大葱', '梨', '胡萝卜', '芒果', '苹果', '西红柿', '韭菜', '香蕉', '黄瓜']  # todo 修改类名,这个数组在模型训练的开始会输出self.class_names = ['苹果','香蕉']self.resize(900, 700)self.initUI()# 界面初始化,设置界面布局def initUI(self):main_widget = QWidget()main_layout = QHBoxLayout()font = QFont('楷体', 15)# 主页面,设置组件并在组件放在布局上left_widget = QWidget()left_layout = QVBoxLayout()img_title = QLabel("样本")img_title.setFont(font)img_title.setAlignment(Qt.AlignCenter)self.img_label = QLabel()img_init = cv2.imread(self.to_predict_name)h, w, c = img_init.shapescale = 400 / himg_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)cv2.imwrite("images/show.png", img_show)img_init = cv2.resize(img_init, (224, 224))cv2.imwrite('images/target.png', img_init)self.img_label.setPixmap(QPixmap("images/show.png"))left_layout.addWidget(img_title)left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)left_widget.setLayout(left_layout)right_widget = QWidget()right_layout = QVBoxLayout()btn_change = QPushButton(" 上传图片 ")btn_change.clicked.connect(self.change_img)btn_change.setFont(font)btn_predict = QPushButton(" 开始识别 ")btn_predict.setFont(font)btn_predict.clicked.connect(self.predict_img)label_result = QLabel(' 果蔬名称 ')self.result = QLabel("等待识别")label_result.setFont(QFont('楷体', 16))self.result.setFont(QFont('楷体', 24))right_layout.addStretch()right_layout.addWidget(label_result, 0, Qt.AlignCenter)right_layout.addStretch()right_layout.addWidget(self.result, 0, Qt.AlignCenter)right_layout.addStretch()right_layout.addStretch()right_layout.addWidget(btn_change)right_layout.addWidget(btn_predict)right_layout.addStretch()right_widget.setLayout(right_layout)main_layout.addWidget(left_widget)main_layout.addWidget(right_widget)main_widget.setLayout(main_layout)# 关于页面,设置组件并把组件放在布局上about_widget = QWidget()about_layout = QVBoxLayout()about_title = QLabel('欢迎使用果蔬识别系统')  # todo 修改欢迎词语about_title.setFont(QFont('楷体', 18))about_title.setAlignment(Qt.AlignCenter)about_img = QLabel()about_img.setPixmap(QPixmap('images/bj.jpg'))about_img.setAlignment(Qt.AlignCenter)label_super = QLabel("作者:dejahu")  # todo 更换作者信息label_super.setFont(QFont('楷体', 12))# label_super.setOpenExternalLinks(True)label_super.setAlignment(Qt.AlignRight)about_layout.addWidget(about_title)about_layout.addStretch()about_layout.addWidget(about_img)about_layout.addStretch()about_layout.addWidget(label_super)about_widget.setLayout(about_layout)# 添加注释self.addTab(main_widget, '主页')self.addTab(about_widget, '关于')self.setTabIcon(0, QIcon('images/主页面.png'))self.setTabIcon(1, QIcon('images/关于.png'))# 上传并显示图片def change_img(self):openfile_name = QFileDialog.getOpenFileName(self, 'chose files', '','Image files(*.jpg *.png *jpeg)')  # 打开文件选择框选择文件img_name = openfile_name[0]  # 获取图片名称if img_name == '':passelse:target_image_name = "images/tmp_up." + img_name.split(".")[-1]  # 将图片移动到当前目录shutil.copy(img_name, target_image_name)self.to_predict_name = target_image_nameimg_init = cv2.imread(self.to_predict_name)  # 打开图片h, w, c = img_init.shapescale = 400 / himg_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)  # 将图片的大小统一调整到400的高,方便界面显示cv2.imwrite("images/show.png", img_show)img_init = cv2.resize(img_init, (224, 224))  # 将图片大小调整到224*224用于模型推理cv2.imwrite('images/target.png', img_init)self.img_label.setPixmap(QPixmap("images/show.png"))self.result.setText("等待识别")# 预测图片def predict_img(self):img = Image.open('images/target.png')  # 读取图片img = np.asarray(img)  # 将图片转化为numpy的数组outputs = self.model.predict(img.reshape(1, 224, 224, 3))  # 将图片输入模型得到结果result_index = int(np.argmax(outputs))result = self.class_names[result_index]  # 获得对应的水果名称self.result.setText(result)  # 在界面上做显示# 界面关闭事件,询问用户是否关闭def closeEvent(self, event):reply = QMessageBox.question(self,'退出',"是否要退出程序?",QMessageBox.Yes | QMessageBox.No,QMessageBox.No)if reply == QMessageBox.Yes:self.close()event.accept()else:event.ignore()if __name__ == "__main__":app = QApplication(sys.argv)x = MainWindow()x.show()sys.exit(app.exec_())

深度学习应用2——果蔬分类应用相关推荐

  1. 吴恩达深度学习第二周+二分类应用+猫图片识别

    由于最近在看吴恩达老师深度学习的课程,在第二周有一个关于猫图片识别的习题,下面将自己的一些体会和代码分享. 有关数据集的下载可以自行百度. 下载好数据集之后会发现是一个.h5的文件.所以我们首先导入 ...

  2. 基于深度学习的文本分类应用!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:罗美君,算法工程师,Datawhale优秀学习者 在基于机器学习的 ...

  3. 【NLP】基于深度学习的文本分类应用

    作者:罗美君,算法工程师,Datawhale优秀学习者 在基于机器学习的文本分类中,我们介绍了几种常见的文本表示方法:One-hot.Bags of Words.N-gram.TF-IDF.这些方法存 ...

  4. 专治“炼丹侠”各种不服:1分钟就能搞个AI应用 | 最新开源深度学习框架工具套件TinyMS问世...

    贾浩楠 金磊 发自 凹非寺 量子位 报道 | 公众号 QbitAI "炼丹侠"们苦当前深度学习框架久矣. 本来,AI框架的初衷是简化.加速和优化开发流程.但是轮子这么多,有从学界走 ...

  5. 【深度学习】详解集成学习的投票和Stacking机制

    [深度学习]详解集成学习的投票和Stacking机制 文章目录 1 基础原理1.1 硬投票1.2 软投票 2 pytorch综合多个弱分类器,投票机制,进行手写数字分类(boosting) 3 Sta ...

  6. 资深算法工程师万宫玺:Java工程师转型AI的秘密法宝——深度学习框架Deeplearning4j | 分享总结

    资深算法工程师万宫玺:Java工程师转型AI的秘密法宝--深度学习框架Deeplearning4j | 分享总结 本文作者:杨文 2018-01-02 11:03 导语:一文读懂深度学习框架Deepl ...

  7. 超详细!使用OpenCV深度学习模块在图像分类下的应用实践

    专注计算机视觉前沿资讯和技术干货 微信公众号:极市平台 官网:https://www.cvmart.net/ 极市导读:本文来自6月份出版的新书<OpenCV深度学习应用与性能优化实践>, ...

  8. Andrew Ng深度学习课程笔记

    摘要: 本文对Andrew Ng深度学习课程进行了大体的介绍与总结,共包括21个课程. 我最近在Coursera上完成了Andrew Ng导师关于新深度学习的所有课程.Ng在解释术语和概念方面做得非常 ...

  9. 南京的学员看过来 | NVIDIA DLI深度学习入门培训

    NVIDIA 深度学习学院:聚焦于人工智能和深度学习,致力于在全世界范围内提供支持,解决最具挑战性的问题.通过线上自主学习,或者线 下由 NVIDIA 认证的讲师来教授的培训课程,我们可以帮助开发者. ...

最新文章

  1. java课程心得_javaweb课程心得体会(三)
  2. 搭建一个免费的,无限流量的Blog----github Pages和Jekyll入门
  3. 入职3个月的Java程序员面临转正,挑战大厂重燃激情!
  4. 【LeetCode 剑指offer刷题】树题16:Kth Smallest Element in a BST
  5. Python爬取房天下租房信息实战
  6. MySQL(11)数据库实现高可用架构之MMM
  7. java 计算信度,11.5.2 评分者信度实例分析
  8. Mysql笔记——DQL
  9. SCPPO(二十六):测算过程中问题的解决总结
  10. 未在本地计算机上注册“OraOLEDB.Oracle”提供程序
  11. vue学习笔记-10-常用特性之表单操作
  12. Visio画图允许两条线交叉的操作
  13. 网页打不开显示php探针,phpinfo被禁用,可用php探针
  14. pdf关键字搜索盖章,长关键字定位
  15. cpi计算机性能指标,将CPU时间=(CPI指令总数).ppt
  16. 华为路由器AAA配置与管理
  17. 初始化云硬盘切换云主机挂载验证lvm跨主机可读
  18. mybatis(二)xml配置文件详细说明
  19. ansible防火墙firewalld设置
  20. Java虚拟机内存的堆区(heap),栈区(stack)和静态区(static/method)

热门文章

  1. 信息学奥赛一本通超详细题解,动画图文题解
  2. android安卓智能穿戴仿苹果手表界面的源码效果
  3. cd4066典型应用电路图(双向模拟开关/电子开关/音响发声电路图详解) - 全文
  4. php+html配合方式小结
  5. kettle任务在Linux服务器上定时调度
  6. sleuth原理详解
  7. ardupilot 上实现ADRC内环角速度控制
  8. 解决谷歌卸载后安装无反应问题
  9. libevent_Rector模式
  10. 关于户外旋转LED显示屏的核心技术