配合视频一起食用这篇教程效果更佳:手把手教你用tensorflow2训练自己的数据集

tensorflow2.x版本对小白非常友好,2.x的api中对keras进行了合并,大家只需要安装tensorflow就可以使用里面封装好的keras,利用keras可以快速地加载数据集和构建模型,下面我们直接来看以下通过tensorflow2.3训练自己的分类数据集吧。

注:本文主要针对图片形式的数据集构建分类模型,文本数据、目标检测等任务暂不涉及。

本文使用到的代码已在码云上开源,请大家自行下载,star不迷路:

vegetables_tf2.3: 基于tensorflow2.3开发的水果蔬菜识别系统 (gitee.com)

另外我这边整理了一些物体分类的数据集,大家根据需要下载:

计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_dejavu的博客-CSDN博客

数据集收集

数据集收集主要有3种方式,一种是使用某些机构或者组织开源出来的数据集,另一种是自己通过拍照或者爬虫的方式来自行获取数据集,还有一种是热心网友自己采集整理之后的数据集,下面的csdn链接中我给出了一些我整理的数据集,大家可以根据自己的需要下载使用。

计算机视觉数据集清单-附赠tensorflow模型训练和使用教程_dejavu的博客-CSDN博客

开源数据集

开源的分类数据集一般质量相对较好,数据集的所有者在发布前对数据集做了整理和清洗,直接使用开源的数据集可以帮助我们节省大量的时间,比较有名的有mnist数据集、cifar数据集等,另外大家可以在一些网站中寻找数据集,比如下列的几个网站:

和鲸社区 - Heywhale.com

UCI Machine Learning Repository

CSDN - 专业开发者社区

另外你也可以直接在搜索引擎中输入关键字来寻找数据集,比如你想要寻找垃圾分类的数据集,你可以在搜索栏中输入垃圾 分类 数据集等关键字来直接查找,一般会有热心的网友给出数据集的链接,下载即可。

自行采集数据集

如果找不到相应的开源数据集,你也可以通过自己采集的方式来获取数据集,比如你可以通过拍照的方式来搜集你自己所需的数据集,或者是通过爬虫的方式来搜集数据集,这里有段爬虫爬取百度图片的代码,大家直接执行,输入自己想要爬取的图片名称和图片数量,即可爬取相应的图片,代码如下:

import requests
import re
import osheaders = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.125 Safari/537.36'}
name = input('请输入要爬取的图片类别:')
num = 0
num_1 = 0
num_2 = 0
x = input('请输入要爬取的图片数量?(1等于60张图片,2等于120张图片):')
list_1 = []
for i in range(int(x)):name_1 = os.getcwd()name_2 = os.path.join(name_1, 'data/' + name)url = 'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word=' + name + '&pn=' + str(i * 30)res = requests.get(url, headers=headers)htlm_1 = res.content.decode()a = re.findall('"objURL":"(.*?)",', htlm_1)if not os.path.exists(name_2):os.makedirs(name_2)for b in a:try:b_1 = re.findall('https:(.*?)&', b)b_2 = ''.join(b_1)if b_2 not in list_1:num = num + 1img = requests.get(b)f = open(os.path.join(name_1, 'data/' + name, name + str(num) + '.jpg'), 'ab')print('---------正在下载第' + str(num) + '张图片----------')f.write(img.content)f.close()list_1.append(b_2)elif b_2 in list_1:num_1 = num_1 + 1continueexcept Exception as e:print('---------第' + str(num) + '张图片无法下载----------')num_2 = num_2 + 1continueprint('下载完成,总共下载{}张,成功下载:{}张,重复下载:{}张,下载失败:{}张'.format(num + num_1 + num_2, num, num_1, num_2))

数据集整理

放置到相应的子文件夹

数据集收集完成之后,我们还需要对数据集进行整理,如果是爬虫爬取的图片可能会有一些质量比较差的图片,那么整理之前还需要进行数据的清洗,删除质量不好的图片,数据集整理其实很简单,我们只需要将数据集进行归类即可,即相同类别的图片放在一个文件夹下,比如下面的这个数据集,白菜的文件夹下放的全是白菜的图片,土豆的文件夹下则放的全是土豆的图片。

划分训练集和测试集

注:如果是使用的开源数据集,开源数据集可能已经进行了数据集的划分,直接使用即可,不需要再次进行划分,比如这里是我下载到的农作物病虫害的数据集,已经分别提供了训练集、测试集和验证集,就不需要再次进行数据集的划分。

为了方便我们进行数据集的加载,我们还需要将图片划分为训练集和测试集,如果需要的话你还需要划分出验证集,验证集在一般的任务中是可选的,因为是自己收集的数据集的话,数据量比较少,如果再划分验证集的话可能会导致训练量不够,这里我写了一段数据集划分的代码逻辑,大家输入原始的数据集位置和划分之后的数据集位置,指定数据集划分的比例,即可完成数据集的划分。

# 作者: 宋老狗
import os
import random
from shutil import copy2def data_set_split(src_data_folder, target_data_folder, train_scale=0.8, val_scale=0.0, test_scale=0.2):'''读取源数据文件夹,生成划分好的文件夹,分为trian、val、test三个文件夹进行:param src_data_folder: 源文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/src_data:param target_data_folder: 目标文件夹 E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data:param train_scale: 训练集比例:param val_scale: 验证集比例:param test_scale: 测试集比例:return:'''print("开始数据集划分")class_names = os.listdir(src_data_folder)# 在目标目录下创建文件夹split_names = ['train', 'val', 'test']for split_name in split_names:split_path = os.path.join(target_data_folder, split_name)if os.path.isdir(split_path):passelse:os.mkdir(split_path)# 然后在split_path的目录下创建类别文件夹for class_name in class_names:class_split_path = os.path.join(split_path, class_name)if os.path.isdir(class_split_path):passelse:os.mkdir(class_split_path)# 按照比例划分数据集,并进行数据图片的复制# 首先进行分类遍历for class_name in class_names:current_class_data_path = os.path.join(src_data_folder, class_name)current_all_data = os.listdir(current_class_data_path)current_data_length = len(current_all_data)current_data_index_list = list(range(current_data_length))random.shuffle(current_data_index_list)train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)train_stop_flag = current_data_length * train_scaleval_stop_flag = current_data_length * (train_scale + val_scale)current_idx = 0train_num = 0val_num = 0test_num = 0for i in current_data_index_list:src_img_path = os.path.join(current_class_data_path, current_all_data[i])if current_idx <= train_stop_flag:copy2(src_img_path, train_folder)# print("{}复制到了{}".format(src_img_path, train_folder))train_num = train_num + 1elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):copy2(src_img_path, val_folder)# print("{}复制到了{}".format(src_img_path, val_folder))val_num = val_num + 1else:copy2(src_img_path, test_folder)# print("{}复制到了{}".format(src_img_path, test_folder))test_num = test_num + 1current_idx = current_idx + 1print("*********************************{}*************************************".format(class_name))print("{}类按照{}:{}:{}的比例划分完成,一共{}张图片".format(class_name, train_scale, val_scale, test_scale, current_data_length))print("训练集{}:{}张".format(train_folder, train_num))print("验证集{}:{}张".format(val_folder, val_num))print("测试集{}:{}张".format(test_folder, test_num))if __name__ == '__main__':src_data_folder = "C:/Users/Scm97/Desktop/dejahu/data"  # todo 原始数据集目录target_data_folder = "C:/Users/Scm97/Desktop/dejahu/split_data"  # todo 数据集分割之后存放的目录data_set_split(src_data_folder, target_data_folder)

注:路径中最好不要出现中文

数据集划分之后,记住训练集和测试集的位置,接下来,我们就可以开始训练我们的模型了。

下面以花卉识别,我给大家演示一下,data是演示目录,目录下存放的是5个子文文件夹,对应5种花卉,每个子文件夹下存放了相应的花卉图片,split_data是新建的空文件夹,用于存放分割之后的数据集,这时候只需要修改代码种的两处即可。

代码默认训练集占80%,测试集占20%,修改完成之后右键直接执行即可。

执行之后你就可以得到划分好的数据集

这个时候记住训练集和测试集的目录,开始大干一场吧。

测试集目录为:C:/Users/Scm97/Desktop/dejahu/split_data/train

训练集目录为:C:/Users/Scm97/Desktop/dejahu/split_data/test

环境搭建

本次教程需要大家实现配置好python的环境,我们需要使用到anaconda和pycharm,不熟悉环境配置的同学可以看我得这篇博客,我在这里就不再进行赘述了。

如何在pycharm中配置anaconda的虚拟环境_dejavu的博客-CSDN博客

训练模型

模型训练的代码种,以cnn模型的训练为例,train_cnn.py是训练cnn模型的代码,只需要修改三处即可,如下所示

train_mobilnet.py是训练mobilenet模型的代码,训练的模型将会保存在models目录下,这里也是只需修改三处即可。

注:代码最后一行的epochs指的是跑的训练的轮数,这里默认是30,大家可以根据自己的需要增加或减少训练的轮数

修改之后直接运行即可,等代码跑完后模型就会保存在models目录下

另外,在results目录下你可以找到模型训练的过程图

模型训练的过程中会输出数据集的类名,这里记录一下,在后面的模型使用中会用到。

测试模型

模型的测试的代码为test_model.py,也是只需要改动几处代码即可完成测试

改动如下:

测试的基本流程是:加载数据、加载模型、测试、保存结果

测试之后在命令行中会输出每个模型的准确率,并且会在results目录下生成相应的热力图


热力图中对应了每个类别的准确率,如下所示,是mobilenet测试的热力图。

使用模型

模型的时候中,我们通过Pyqt5来构建图形化界面,用户可以上传图片,并在系统中调用我们训练好的模型进行图片类别的预测。

window.py代码中修改四处即可完成基本功能

启动看看吧!

快去试试你自己的数据集吧!

手把手教你用tensorflow2.3训练自己的分类数据集相关推荐

  1. 手把手教你用YOLOv5算法训练数据和检测目标(不会你捶我)

    前言 本人从一个小白,一路走来,已能够熟练使用YOLOv5算法来帮助自己解决一些问题,早就想分析一下自己的学习心得,一直没有时间,最近工作暂时告一段落,今天抽空写点东西,一是为自己积累一些学习笔记,二 ...

  2. bert 是单标签还是多标签 的分类_搞定NLP领域的“变形金刚”!手把手教你用BERT进行多标签文本分类...

    大数据文摘出品 来源:medium 编译:李雷.睡不着的iris.Aileen 过去的一年,深度神经网络的应用开启了自然语言处理的新时代.预训练模型在研究领域的应用已经令许多NLP项目的最新成果产生了 ...

  3. 手把手教你进行安全帽的佩戴检测(附数据集+代码演示+实验结果)

    目录: 一.数据集和代码的准备 二.训练过程 三.结果演示 那让我们开始吧! 一.数据集和代码的准备 数据集:链接:https://pan.baidu.com/s/1tN7g26s8DRgAKrn6F ...

  4. 手把手教你洞悉 PyTorch 模型训练过程,彻底掌握 PyTorch 项目实战!(文末重金招聘导师)...

    (文末重金招募导师) 在CVPR 2020会议接收中,PyTorch 使用了405次,TensorFlow 使用了102次,PyTorch使用数是TensorFlow的近4倍. 自2019年开始,越来 ...

  5. 手把手教你在百度aistuio训练人工智能模型

    我在上文中介绍了百度的在线AI模型创建平台aistudio(博客地址:https://blog.csdn.net/BEYONDMA/article/details/101762116),我们知道深度学 ...

  6. 超详细!手把手教你使用YOLOX进行物体检测(附数据集)

    点击下方卡片,关注3D视觉工坊公众号 3D视觉干货第一时间送达 作者:王浩,3D视觉开发者社区签约作者,毕业于北京航空航天大学,人工智能领域优质创作者,CSDN博客认证专家. 编辑:3D视觉开发者社区 ...

  7. 【超详细】手把手教你使用YOLOX进行物体检测(附数据集)

    作者:王浩 毕业于北京航空航天大学,人工智能领域优质创作者 编辑:3D视觉开发者社区 ✨如果觉得文章内容不错,别忘了三连支持下哦

  8. TensorFlow2 手把手教你实现自定义层

    TensorFlow2 手把手教你实现自定义层 概述 Sequential Model & Layer 案例 数据集介绍 完整代码 概述 通过自定义网络, 我们可以自己创建网络并和现有的网络串 ...

  9. TensorFlow2 手把手教你避开梯度消失和梯度爆炸

    TensorFlow2 手把手教你避开梯度消失和梯度爆炸 梯度消失 & 梯度爆炸 梯度消失 梯度爆炸 张量限幅 tf.clip_by_value tf.clip_by_norm mnist 展 ...

  10. 手把手教你训练一个秒杀科比的投篮AI,不服来练 | 附开源代码

    原作:Abe Haskins 安妮 编译整理 量子位 出品 | 公众号 QbitAI 在这篇教程中,谷歌工程师Abe Haskins用简洁易懂的语言,教你用Unity3D和TensorFlow生产一只 ...

最新文章

  1. 转载 http://blog.csdn.net/dengta_snowwhite/article/details/6418384
  2. github push时候报错解决方法
  3. python 制作gif-利用Python如何制作好玩的GIF动图详解
  4. JSTL标签之核心标签
  5. 2020年高考西工大附中成绩分析
  6. flutter 局部状态和全局状态区别_Flutter状态管理
  7. jmeter 非gui 模式跑jmx
  8. 浅析文件传输协议 (ftp) 的工作原理
  9. VMware中ubuntu虚拟机与windows的端口映射,共享一个IP地址
  10. 华硕v4000fj笔记本怎么样_所有已开箱笔记本的目录汇总 20200812
  11. 使用计算机的硬件及参数,硬件参数怎么看?如何选配电脑硬件?
  12. hdu 6108 小C的倍数问题
  13. php 工资 2018,2018年PHP程序员的进阶之路
  14. 计算机组成原理10——建立数据通路
  15. 群晖6.1安装php3.6_教程分享 --- jun大神 VMWare虚拟机安装黑群晖 (DSM6.1)
  16. 技术主管和技术总监的区别_技术主管–责任圈
  17. PyAlgoTrade框架研究
  18. 如何用阿里云云盘快照恢复部分数据
  19. Tinymce组件cdn失效解决办法
  20. 如何从Mixamo下载人物模型的动画

热门文章

  1. html表单变灰,excel菜单灰色 excel工具栏突然变灰了 怎么办
  2. 微信公众平台后台接入简明指南
  3. 一周信创舆情观察(5.6~5.9)
  4. editplus编辑器使用-快速开始(editplus通过sftp协议远程编辑文件)
  5. 第K顺序统计量的求解
  6. linux b类地址设24位掩码,子网掩码的设置方法和作用
  7. 位掩码(BitMask)——介绍与使用
  8. 梯度消失和爆炸原因以及解决方法
  9. 密码学基础(数学理论)
  10. Java的8 大基本类型的包装类和美女选妃案例的两种写法