一、序言

使用百度飞浆提供的paddle框架实现蝴蝶分类,环境:paddle 2.0.2,opencv 4.5.4.58,pycharm编译器。

目录结构:

  • Butterfly20里有20个文件夹,分别代表20种蝴蝶种类,每个文件夹内有多个同种类的蝴蝶照片
  • Butterfly20_test里有200张蝴蝶照片用于测试训练好的网络
  • visualdl_log里存放训练好的网络,使用log文件格式
  • species.txt里存放20种类别的名称和序号
  • train_set和validation_set在运行时随机分配

二、准备数据

随机查看一个蝴蝶图片及其类别

data_path= '.\Butterfly20\*\*.jpg'
but_files =glob.glob(data_path) #获取Butterfly20中的所有图片地址print('图片数据为',len(but_files))#随机显示一个样品的图片
index=random.choice(but_files)  # 随机获取一个图片
print(index)  # 查看地址name=index.split('\\')[-2]  # 获取标签,得到的是训练集中随机蝴蝶的类别
img = Image.open(index)  # 打开图片
img = cv2.imread(index)  # 图片处理
print(img.shape)  # 输出图片形状(441,600,3)
img = img[:,:,::-1] # 三通道,-1表示从右往左切片,opencv输入为BGR,故从右往左切片为RGB三通道
print(f'该样本标签为:{name}')
cv2.imshow("ran_img",img)
cv2.waitKey(0)

测试输出为:

写一个Reader类,其中定义三个函数,分别为初始化、处理图像、计算长度,使用Reader类加载训练集与数据集

### 查看数据类型
data_list = [] #用个列表保存每个样本的读取路径、标签
# 由于属种名称本身是字符串,而输入模型的是数字。需要构造一个字典,把某个数字代表该属种名称。键是属种名称,值是整数。
label_list=[]
with open("E:/Pycharm/workspace/OpenCV/butterfly/species.txt") as f:for line in f:a,b = line.strip("\n").split(" ") #a为1-20的序号,b为每个种类的namelabel_list.append([b, int(a)-1]) #将20种txt种的类别加入label_list数组种
label_dic = dict(label_list) #dict创建一个字典,字典中有20种蝴蝶类型butterfly_path = './Butterfly20/'
#若项目目录内已经有train_set与validation_set两个数据集,则删除,之后重新创建这两个数据集
if(os.path.exists('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt')):  # 判断有误文件os.remove('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt')  # 删除文件
if(os.path.exists('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt')):os.remove('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt')for i in os.listdir(butterfly_path): #得到Butterfly20里的所有文件夹if i not in '.DS_Store': #DB_Store里是20种蝴蝶类型的名字for j in os.listdir(os.path.join(butterfly_path, i)): #路径拼接,拼接后为./Butterfly20/20种名字,j从这个路径里提取序号.jpgdata_list.append(f'{os.path.join(butterfly_path, i, j)}\t{label_dic[i]}\n') #前一个大括号是每个图片具体路径,后一个是其种类的序号random.shuffle(data_list)  # 乱序
print(data_list[0]) #打印随机选出的第一个图片以及其属于的种类号
data_len = len(data_list)
count = 0for data in data_list:if count <= data_len*0.8:with open('E:/Pycharm/workspace/OpenCV/butterfly/train_set.txt', 'a')as f: # 80%写入训练集f.write(data)count += 1else:with open('E:/Pycharm/workspace/OpenCV/butterfly/validation_set.txt', 'a')as tf:  # 20%写入验证集tf.write(data)count += 1# 自定义数据读取器
class Reader(Dataset):def __init__(self, mode='train_set'):"""初始化函数"""self.data = []with open(f'{mode}_set.txt') as f: #train_set或validation_setfor line in f.readlines():info = line.strip().split('\t') #strip函数去掉首部等于参数值的字符,无参数表示删掉换行符if len(info) > 0:self.data.append([info[0].strip(), info[1].strip()])def __getitem__(self, index): #将图片转换为(224,224)像素大小"""读取图片,对图片进行归一化处理,返回图片和 标签"""image_file, label = self.data[index]  # 获取数据img = Image.open(image_file)  # 读取图片img = img.convert('RGB')img = img.resize((224, 224), Image.ANTIALIAS)  # 图片大小样式归一化img = np.array(img).astype('float32')  # 转换成数组类型浮点型32位img = img.transpose((2, 0, 1))  # 读出来的图像是rgb,rgb,rbg..., 转置为 rrr...,ggg...,bbb...img = img / 255.0  # 数据缩放到0-1的范围return img, np.array(label, dtype='int64')def __len__(self):"""获取样本总数"""return len(self.data)#调用Reader类,其中三个函数都会走
# 训练的数据提供器
train_dataset = Reader(mode='train')
# 测试的数据提供器
eval_dataset = Reader(mode='validation')# 查看训练和测试数据的大小
print('train大小:', train_dataset.__len__())
print('eval大小:', eval_dataset.__len__())# 随机查看图片数据、大小及标签
for data, label in eval_dataset:print(data)print(np.array(data).shape) #(3,224,224)print(label)break #只循环一次即可

三、构建网络

使用paddle框架构造神经网络,选用resnet152网络用于图像分类,最后分为20类

import paddle.nn.functional as F
#定义模型
class MyNet(paddle.nn.Layer):def __init__(self):super(MyNet,self).__init__()self.layer=paddle.vision.models.resnet152(pretrained=True) #152层的resnet模型,预训练模型只需要设定模型参数pretained=Trueself.dropout=paddle.nn.Dropout(p=0.5) #Dropout值设为0.5,self.fc1 = paddle.nn.Linear(1000, 512) #fc为全连接层,与模型训练后为1000个输出,要最后分20类self.fc2 = paddle.nn.Linear(512, 20) #两个全连接层实现1000-20#网络的前向计算过程def forward(self,x):x=self.layer(x) #resnet152模型x=self.dropout(x) #值为0.5的Dropoutx=self.fc1(x) #第一个全连接层x=F.relu(x) #使用relu函数激活x=self.fc2(x) #第二个全连接层得到20个分类特征return x

resnet网络结构如下:



















四、训练网络

用构建好的resnet152网络进行训练

model = paddle.Model(MyNet())
model.summary((1, 3, 224, 224)) #输出各层参数input_define = paddle.static.InputSpec(shape=[-1,3,224,224], dtype="float32", name="img")
label_define = paddle.static.InputSpec(shape=[-1,1], dtype="int64", name="label")#实例化网络对象并定义优化器等训练逻辑
model = MyNet()
model = paddle.Model(model,inputs=input_define,labels=label_define) #用Paddle.Model()对模型进行封装
optimizer = paddle.optimizer.Adam(learning_rate=0.0001, parameters=model.parameters())
#上述优化器中的学习率(learning_rate)参数很重要。要是训练过程中得到的准确率呈震荡状态,忽大忽小,可以试试进一步把学习率调低。model.prepare(optimizer=optimizer, #指定优化器loss=paddle.nn.CrossEntropyLoss(), #指定损失函数metrics=paddle.metric.Accuracy()) #指定评估方法callback = paddle.callbacks.VisualDL(log_dir='./visualdl_log')model.fit(train_data=train_dataset,     #训练数据集eval_data=eval_dataset,         #测试数据集batch_size=64,                  #一个批次的样本数量epochs=100,                      #迭代轮次save_dir="./visualdl_log", #把模型参数、优化器参数保存至自定义的文件夹save_freq=20,                    #设定每隔多少个epoch保存模型参数及优化器参数log_freq=100,                     #打印日志的频率verbose=1,                        # 日志展示模式shuffle=True,                     # 是否打乱数据集顺序callbacks=callback                # 回调函数使用)result = model.evaluate(eval_dataset, verbose=1)
print(result)model.save('E:/Pycharm/workspace/OpenCV/butterfly/butterfly_model')  # 保存模型

五、预测图片

随机使用一张图片,通过训练好的网络进行预测蝴蝶的种类,该蝴蝶属于第15类

def load_image(file): #加载测试图片并处理图片# 打开图片im = Image.open(file)# 将图片调整为跟训练数据一样的大小im = im.convert('RGB')im = im.resize((224, 224), Image.ANTIALIAS)# 建立图片矩阵 类型为float32im = np.array(im).astype(np.float32)# 矩阵转置im = im.transpose((2, 0, 1))# 将像素值从[0-255]转换为[0-1]im = im / 255.0# print(im)im = np.expand_dims(im, axis=0)# 保持和之前输入image维度一致print('im_shape的维度:', im.shape)return imfrom PIL import Image
# site = 255  # 读取图片位置
model_state_dict = paddle.load('E:/Pycharm/workspace/OpenCV/butterfly/butterfly_model.pdparams')  # 读取模型
model = MyNet()  # 实例化模型
model.set_state_dict(model_state_dict) #浅拷贝,读取模型
model.eval() #不进行BN与dropout,使用所有全职计算img = load_image(index)print(paddle.to_tensor(img).shape)
# print(paddle.reshape(paddle.to_tensor(img), (1, 3, 224, 224)))
ceshi = model(paddle.reshape(paddle.to_tensor(img), (1, 3, 224, 224)))  # 测试
print('预测的结果为:', list(label_dic.keys())[np.argmax(ceshi.numpy())])  # 获取值
with open("./work/result.txt", "w") as f:for r in result:f.write("{}\n".format(r))
Image.open(index)  # 显示图片

预测结果:

基于paddlepaddle构建resnet神经网络的蝴蝶分类相关推荐

  1. 基于PaddlePaddle构建ResNet18残差神经网络的食物图片分类问题

    基于PaddlePaddle构建ResNet18残差神经网络的食物图片分类问题 Introduction 本项目是在李宏毅机器学习课程的作业3进行的工作,任务是手动搭建一个CNN模型进行食物图片分类( ...

  2. 基于2维卷积神经网络的心电图分类

    在这里给大家分享一篇关于用深度学习进行心电图识别的论文,原文地址https://arxiv.org/abs/1804.06812,我翻译成了中文以便大家快速学习,中间难免有疏忽遗漏的地方,请大家谅解. ...

  3. 基于Pytorch全连接神经网络实现多分类

    (一)计算机视觉工具包的介绍 为了方便开发者应用,PyTorch专门开发了一个视觉工具包torchvision,主要包含以下三个部分: 1.models models提供了深度学习中各种经典的神经网络 ...

  4. 【人工智能 卷积神经网络】基础练习:基于torch构建卷积神经网络,测试集正确率达 百分之99

    声明:仅学习使用~ 这是一个关于卷积神经网络CNN的基础练习,也算是一个回顾.包含分解步骤,内容整合 以及最后的整体输出. 目录 一.步骤分解 1.0 系统环境.主要模块版本 1.1 相关模块的导入 ...

  5. 【项目实战】Python基于librosa和人工神经网络实现语音识别分类模型(ANN算法)项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 语音识别发展到现在作为人机交互的重要接口已经在很多方面改变了我们 ...

  6. python构建bp神经网络_鸢尾花分类(一个隐藏层)__1.数据集

    IDE:jupyter 目前我知道的数据集来源有两个,一个是csv数据集文件另一个是从sklearn.datasets导入 1.1 csv格式的数据集(下载地址已上传到博客园----数据集.rar) ...

  7. 基于pytorch搭建ResNet神经网络用于花类识别

  8. Python基于PyTorch实现BP神经网络ANN分类模型项目实战

    说明:这是一个机器学习实战项目(附带数据+代码+文档+视频讲解),如需数据+代码+文档+视频讲解可以直接到文章最后获取. 1.项目背景 在人工神经网络的发展历史上,感知机(Multilayer Per ...

  9. 【Pytorch(七)】基于 PyTorch 实现残差神经网络 ResNet

    基于 PyTorch 实现残差神经网络 ResNet 文章目录 基于 PyTorch 实现残差神经网络 ResNet 0. 概述 1. 数据集介绍 1.1 数据集准备 1.2 分析分类难度:CIFAR ...

最新文章

  1. Docker的使用(五:Docker中的网络与数据管理)
  2. Nature综述:多年冻土的微生物组
  3. Cisco ××× 完全配置指南-连载-IPSec
  4. css3 -webkit-filter
  5. MediaPlayer使用方法简单介绍
  6. 工作三年的一点感想(展望篇)
  7. 由partition看窗口函数
  8. lambda函数if_Python3中lambda表达式与函数式编程讲解
  9. Swift 5新特性详解:ABI 稳定终于来了!
  10. html5文本与段落简介,美化html段落文本 Ⅰ
  11. SOA架构设计的案例分析
  12. python 调用海康sdk_Qt调用海康SDK实现摄像头视频播放
  13. com.android.dx.cf.iface.ParseException
  14. creat是什么意思中文翻译_CREAT是什么意思中文翻译
  15. 后端开发工程师不懂这些就危险了
  16. flutter APP自动更新
  17. 前端移动端页面与手机尺寸和分辨率的关系
  18. 学习【菜鸟教程】【C++ 类 对象】【内联函数】(例子简单,评论难懂)
  19. 微信开发者工具-真机调试
  20. 当青春走到尽头你会想念你自己吗

热门文章

  1. 2021年中国综艺赞助情况回顾及未来发展趋势:品牌更乐于与成熟的综N代合作,未来合作方式更多元化[图]
  2. C++ using declaration
  3. MySQL8使用with recursive实现递归
  4. 数据结构与算法-Prim算法解析与解决修路最小生成树问题
  5. 外卖小哥莫名成10家公司监事 企业登记存监管漏洞
  6. [内存管理]内存池pool库
  7. c/c++进阶之爱恨交织的临时对象: 二、天使与魔鬼
  8. 菜鸟窝-仿京东淘宝项目学习笔记(二)ToolBar的基本使用
  9. 目前计算机技术已经得到了全面的发展,计算机网络技术对人的全面发展的影响.doc...
  10. jzoj3424. 【NOIP2013模拟】粉刷匠