在飞桨平台做图像分类

文章目录

  • 在飞桨平台做图像分类
  • 前言
  • 制作数据集
    • 下载数据集
    • 飞桨数据集
    • 制作飞桨数据集
  • 数据集的加载
  • 完整代码

前言

计划是在寒假时用在飞桨平台上做动物,水果的分类。

制作数据集

代码在文章最后

下载数据集

飞桨有内置数据集和自定义数据集,这里主要是写如何制作自定义数据集。我这里用到的数据集就是第十六届智能车视觉AI组组委会提供的数据集:这里放上百度网盘链接:
动物水果数据集l
提取码:lasl
只需要下载动物水果即可。

飞桨数据集

飞桨有 map-style 的 paddle.io.Dataset 基类 和 iterable-style 的 paddle.io.IterableDataset 基类 ,来完成数据集定义。此外,针对一些特殊的场景,飞桨框架也提供了 paddle.io.TensorDataset 基类,可以直接处理 Tensor 数据为 dataset,一键完成数据集的定义。这里用的是基于paddle.io.Dataset的,也是官方更为推荐使用的API.使用 paddle.io.Dataset,最后会返回一个 map-style 的 Dataset 类。可以用于后续的数据增强、数据加载等。
使用 paddle.io.Dataset 只需要按格式完成以下四步即可。

步骤一:继承paddle.io.IterableDataset类
步骤二:实现构造函数,定义数据读取方式,划分训练和测试数据集
步骤三:实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据(训练数据,对应的标签)
步骤四:实现__len__方法,返回数据集总数目
在下面将会逐步实现这些步骤

制作飞桨数据集

  1. 继承paddle.io.IterableDataset类
import os
import randomimport numpy as np
import cv2 as cv
import matplotlib.pyplot as plt
import paddle
class myreader(paddle.io.Dataset):if __name__ == '__main__':pass
  1. 实现构造函数,定义数据读取方式,划分训练和测试数据集
    将数据集按照数字索引的方式命名,要处理的文件路径是这样的:

    0~9是分别是5个动物和5个水果大类,里面分别有各类png图片100张左右。
    用os.path.join()方法来拼接路径,遍历得到每一张图片的路径。
    然后用os.path.splitext()来判段文件是否是.png后缀的,这个方法会将文件名和文件后缀分开,os.path.splitext(文件路径)[-1]所读取到的就是后缀。
    然后用cv.imread来读取图片,输入参数有两个,一个是路径,一个是读取的色彩选择,可以选择bgr彩图还是灰度图,默认bgr彩图,注意这个顺序是bgr,而不是常见的rgb,如果读取完后直接用plt来显示的话,图像的色彩就会出错。
    所以我们用cv.resize()来对图像的大小以及色彩进行修改,cv.resize(img, (64, 64))[…, (2, 1, 0)],这是指将图像放缩为64×64大小的rgb图片。(这里修改图像大小的步骤不能省去,否则数组形状不一,后面会报错或警告)
    我们需要将整组图片分为训练集以及测试集,这两个集合不能有重叠部分。这里用train_ratio作为参数来指定训练集的占比,然后用train_ratio*10与图片索引相比来决定是训练集还是测试集,这样的话对于相同的一组图片以及相同的train_ratio来说其训练集和测试集完全没有重叠
    具体代码如下:
 """file_path:大文件夹路径train_ratio:训练集的占比 0~1index_num:指大文件夹下每个小分类文件的总数mode: 训练集还是测试集"""def __init__(self,file_path,train_ratio,index_num,mode):super(myreader, self).__init__()# self.all_data=[]# self.all_lable=[]self.mode = modeself.train_data = []self.train_lable = []self.test_data = []self.test_lable = []self.file_path = file_pathself.train_ratio = train_ratiodataindex = 0for i in range(index_num):file_path = os.path.join(self.file_path,'%d' % i)for j in os.listdir(file_path):dataindex = dataindex % 10if(os.path.splitext(j)[-1] == '.png'):img = cv.imread(os.path.join(file_path, j))img = cv.resize(img, (64, 64))[..., (2, 1, 0)]if mode == 'train':if dataindex < train_ratio*10:self.train_data.append(img/255)self.train_lable.append(i)else:if dataindex >= train_ratio*10:self.test_data.append(img)self.test_lable.append(i)dataindex += 1
  1. 实现__getitem__方法,定义指定index时如何获取数据,并返回单条数据
 def __getitem__(self,index):if self.mode == 'train':npData = np.asarray(self.train_data,dtype='float32')[index]npLable = np.asarray(self.train_lable,dtype='int64')[index]return npData,npLableelse:npData = np.asarray(self.test_data,dtype='float32')[index]npLable = np.asarray(self.test_lable,dtype='int64')[index]return npData, npLable
  1. 实现__len__方法,返回数据集总数目
    注意len方法所对应的值要和getitem的值相等,就是说如果要用getitem取训练集的话,相对的len方法也要返回训练集的数目。
    def __len__(self):if self.mode == 'train':return len(self.train_data)else:return len(self.test_data)

数据集的加载

当我们定义了数据集后,就需要加载数据集。我们可以通过 paddle.io.DataLoader 完成数据的加载。

train_dataset = myreader(r'picture',0.7,2,'train')train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)

完整代码

import os
import numpy as np
import cv2 as cv
# import matplotlib.pyplot as plt
import paddleclass myreader(paddle.io.Dataset):"""继承paddle.io.Dataset类""""""file_path:大文件夹路径train_ratio:训练集的占比 0~1index_num:指大文件夹下每个小分类文件的总数mode: 训练集还是测试集"""def __init__(self,file_path,train_ratio,index_num,mode):super(myreader, self).__init__()# self.all_data=[]# self.all_lable=[]self.mode = modeself.train_data = []self.train_lable = []self.test_data = []self.test_lable = []self.file_path = file_pathself.train_ratio = train_ratiodataindex = 0for i in range(index_num):file_path = os.path.join(self.file_path,'%d' % i)for j in os.listdir(file_path):dataindex = dataindex % 10if(os.path.splitext(j)[-1] == '.png'):img = cv.imread(os.path.join(file_path, j))img = cv.resize(img, (64, 64))[..., (2, 1, 0)]if mode == 'train':if dataindex < train_ratio*10:self.train_data.append(img/255)self.train_lable.append(i)else:if dataindex >= train_ratio*10:self.test_data.append(img)self.test_lable.append(i)dataindex += 1def __getitem__(self,index):if self.mode == 'train':npData = np.asarray(self.train_data,dtype='float32')[index]npLable = np.asarray(self.train_lable,dtype='int64')[index]return npData,npLableelse:npData = np.asarray(self.test_data,dtype='float32')[index]npLable = np.asarray(self.test_lable,dtype='int64')[index]return npData, npLabledef __len__(self):if self.mode == 'train':return len(self.train_data)else:return len(self.test_data)if __name__ == '__main__':import matplotlib.pyplot as plttrain_dataset = myreader(r'animal_fruit',0.7,1,'train')train_loader = paddle.io.DataLoader(train_dataset, batch_size=64, shuffle=True)test_dataset = myreader(r'animal_fruit',0.7,1,'test')test_loader = paddle.io.DataLoader(test_dataset, batch_size=64, shuffle=True)for batch_id, data in enumerate(train_loader()):x_data = data[0]y_data = data[1]print(x_data.numpy().dtype)print(y_data.numpy().shape)# for batch_id, data in enumerate(test_loader()):#     x_data = data[0]#     y_data = data[1]#     print(x_data.numpy().shape)#     print(y_data.numpy().shape)print(train_dataset[1][0])l = np.array(train_dataset[0][0])plt.figure(figsize=(2,2))plt.imshow(l, cmap=plt.cm.binary)


测试:

from work import myImageReader
import paddle
import matplotlib.pyplot as plt
import numpy as np
train_dataset = myImageReader.myreader(r'animal',0.7,1,'train')
# print(train_dataset[1][0])
l = np.array(train_dataset[0][0])
plt.figure(figsize=(2,2))
plt.imshow(l, cmap=plt.cm.binary)

在飞桨平台做图像分类-1 制作基于飞桨的数据集|CSDN创作打卡相关推荐

  1. 在飞桨平台做图像分类-2 完成模型组网并训练|CSDN创作打卡

    在飞桨平台做图像分类 文章目录 在飞桨平台做图像分类 前言 导入数据集 Sequential组网 训练模型 模型评估及预测 后记 前言 计划是在寒假时用在飞桨平台上做动物,水果的分类. 前面已经完成了 ...

  2. 赛桨PaddleScience v1.0 Beta:基于飞桨核心框架的科学计算通用求解器

    近年来,关于AI for Science的主题被广泛讨论,重点领域包含使用AI方法加速设计并发现新材料,助力高能物理及天文领域的新问题探索,以及加速智慧工业实时设备数据与模型的"数字孪生&q ...

  3. 《用Python编程来做数学题|CSDN创作打卡》

    目录 一.Python编程:将8个苹果分成四组,每组至少一个苹果,有多少种方案? 二.求最大公约数和最小公倍数的函数 一.Python编程:将8个苹果分成四组,每组至少一个苹果,有多少种方案? 在数学 ...

  4. html5 制作 蝴蝶飞动的动态图片,fireworks制作蝴蝶飞gif动画

    最近看到好多朋友在画:随着一只上下飞舞的小蝴蝶,漂亮的书法字就一笔一划地写出来了.那么我们也来看看吧.下面学习啦小编给大家整理了关于fireworks制作蝴蝶飞gif动画的方法,希望你们喜欢. fir ...

  5. AI Studio 飞桨 零基础入门深度学习笔记4-飞桨开源深度学习平台介绍

    AI Studio 飞桨 零基础入门深度学习笔记4-飞桨开源深度学习平台介绍 深度学习框架 深度学习框架优势 深度学习框架设计思路 飞桨开源深度学习平台 飞桨开源深度学习平台全景 框架和全流程工具 模 ...

  6. 揭晓飞桨平台提速秘诀:INT8量化加速实现“事半功倍”

    为帮助广大企业和开发者更加便捷和快速地创建深度学习应用,百度飞桨正不断地提升平台的训练和推理能力,并与英特尔紧密合作,在至强® 平台集成的AI加速能力的支持下,以 INT8 量化方案,在不影响预测准确 ...

  7. 基于飞桨图像分类套件PaddleClas的柠檬分类竞赛实战

    前情提要   通过之前教程中的学习,相信大家对于如何搭建一个分类网络已经清晰了.那么我们不禁会想,有没有更快速的尝试模型及技巧的方法呢?因为我们在上一次课程中使用的代码都需要自己进行开发,自己写需要很 ...

  8. 基于飞桨复现图像分类模型TNT,实现肺炎CT分类

    本项目介绍了TNT图像分类模型,讲述了如何使用飞桨一步步构建TNT模型网络结构,并尝试在新冠肺炎CT数据集上进行分类.由于作者水平有限,若有不当之处欢迎批评指正. TNT模型介绍 TNT模型全称是Tr ...

  9. 深度学习入门实践学习——手写数字识别(百度飞桨平台)——上篇

    一.项目平台 百度飞桨 二.项目框架 1.数据处理: 2.模型设计:网络结构,损失函数: 3.训练配置:优化器,资源配置: 4.训练过程: 5.保存加载. 三.手写数字识别任务 1.构建神经网络流程: ...

  10. 英特尔计算引擎、阿里大规模图形神经网络平台、百度飞桨平台、索尼音乐生成AI套件......重量级深度学习工业产品亮相NeurIPS 2019行业展览会!

    NeurIPS 2019的正式会议将于加拿大/温哥华时间的12月9日早上8点开始.会议前一天将会举办为期一整天的行业展览会(可能是赞助商太多了--) 当别人为明天的正式会议捉急准备时,小助手已经在展览 ...

最新文章

  1. ios 中的关联对象
  2. CVE-2014-6271 漏洞告警
  3. 无敌简单快速的文件服务器sgfs
  4. mysql union null_mysql – 删除SQL中的SQL JOIN和UNION操作符中的NULL值
  5. C语言实现动态顺序表
  6. linux vim tag,Vim基础知识之ctags 及 Taglist 插件
  7. php-fpm的安装与测试
  8. java arraydeque poll,Java ArrayDeque pollLast()方法
  9. Ubuntu系统下载工具的推荐
  10. SVN 小乌龟(TortoiseSVN)本地文件更新报错Another process is blocking the working copy database 解决方法
  11. 影响科学圈的那些计算机代码
  12. 前端英文首字母转大写
  13. 金融学习之四——插值法求远期国债收益率
  14. 【微服务】(十)—— 统一网关Gateway
  15. <视觉SLAM十四讲> 李群与李代数
  16. Resco MobileForms Toolkit 2010的破解
  17. 【Spire.Doc】合并 Word 文档,将多个文档合并为一个
  18. 剑指 Offer 51-60
  19. 有机化学php,有机化学原理
  20. 雷军:《我十年的程序员生涯》系列之二(我赚的第一桶金)

热门文章

  1. RestSharp解决Encoding乱码问题
  2. 苹果录制屏幕在哪设置_屏幕录像专家如何录全屏 屏幕录像专家全屏录制设置方法...
  3. Latex入门——使用vscode实时编辑latex文档
  4. 软件测试——白盒测试
  5. 软件测试试题,软件评测师考试
  6. Unity UGUI源码解析
  7. 自学结构体(小甲鱼c语言)
  8. 一款简单好用的动画/游戏制作软件|源码编辑器|编程猫南宁体验中心
  9. FFmpeg入门详解之71:获取ffmpeg转码的实时进度
  10. Excel从省份证中提取信息