文章目录

  • 1.数据集准备
  • 2.pytorch Dataset 处理图片数据
  • 3.网络模型设计
  • 4.模型的训练与测试

1.数据集准备

本例采用了pytorch教程提供的蜜蜂、蚂蚁二分类数据集(点击可直接下载)。该数据集的文件夹结构如下图所示。这里面有些黑白的照片,我把它们删掉了,因为黑白照片的通道数是1,会造成Tensor的维度不一致。可以看出数据集分为训练集和测试集,训练集用于训练模型,测试集用于测试模型的泛化能力。在训练集和测试集下又包含了"ants"和"bees"两个文件夹,这两个文件夹的名称即图片的标签,在加载数据的时候需要用到这一点。有了数据,我们就想办法把这些数据处理成pytorch框架下的Dataset需要的格式。

2.pytorch Dataset 处理图片数据

pytorch为我们处理数据提供了一个模板,这个模板就是Dataset,我们在处理数据时继承这个类。在处理数据时要注意以下几点:

  1. 可以用PIL的Image加载图片,但要将图片处理成tensor,而且tensor的维度要一致。这是因为nn模型的输入都是tensor格式,而且要求一个batchsize的tensor维度是一样的。实现上述可能可以使用torchvision的transforms。由于我用的CPU训练模型,所以对图片压缩的比较厉害,全压缩成3*32!请添加图片描述
    *32的图片了。
  2. "ants"和"bees"两个文件夹的名称就是图片的标签,但是__getitem__的返回值应该是一个值。在这里"ants"标签返回0,"bees"标签返回1。
  3. 看数据的预处理对不对,可以用一段代码测试一下,将数据加载到DataLoader,然后循环取出数据,并把这些数据及其标签打印出来,或者记录到tensorboard上去,看每一次迭代返回的数据是否和自己预想的一样。

下面是代码,保存在dataProcess.py文件中。

rom torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import os
from torchvision import transforms
from torch.utils.tensorboard import SummaryWriterclass MyData(Dataset):# 把图片所在的文件夹路径分成两个部分,一部分是根目录,一部分是标签目录,这是因为标签目录的名称我们需要用到def __init__(self, root_dir, label_dir):self.root_dir = root_dirself.label_dir = label_dir# 图片所在的文件夹路径由根目录和标签目录组成self.path = os.path.join(self.root_dir, self.label_dir)# 获取文件夹下所有图片的名称self.img_names = os.listdir(self.path)def __getitem__(self, idx):img_name = self.img_names[idx]img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)img = Image.open(img_item_path)# 将图片处理成Tensor格式,并将维度设置成32*32的# 图片的维度可能不一致,这里一定要用resize统一一下,否则会出错trans = transforms.Compose([transforms.ToTensor(),transforms.Resize((32, 32))])img_tensor = trans(img)# 根据标签目录的名称来确定图片是哪一类,如果是"ants",标签设置为0,如果是"bees",标签设置为1# 这个地方要注意,我们在计算loss的时候用交叉熵nn.CrossEntropyLoss()# 交叉熵的输入有两个,一个是模型的输出outputs,一个是标签targets,注意targets是一维tensor# 例如batchsize如果是2,ants的targets的应该[0,0],而不是[[0][0]]# 因此label要返回0,而不是[0]label = 0 if self.label_dir == "ants" else 1return img_tensor,  labeldef __len__(self):return len(self.img_names)# 用下面这段代码测试一下加载数据有没有问题
if __name__ == "__main__":# 注意hymenoptera_data和代码在同一级目录root_dir = "hymenoptera_data/train"ants_label = "ants"bees_label = "bees"# 蚂蚁数据集ants_dataset = MyData(root_dir, ants_label)# 蜜蜂数据集bees_dataset = MyData(root_dir, bees_label)# 蚂蚁数据集和蜜蜂数据集合并train_dataset = ants_dataset + bees_dataset# 利用dataLoader加载数据集train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# tensorboard的writerwriter = SummaryWriter("logs")for step, train_data in enumerate(train_dataloader):imgs, targets = train_data# 每迭代一次就把一个batch的图片记录到tensorboardwriter.add_images("test", imgs, step)# 每迭代一次就把一个batch的图片标签打印出来print(targets)writer.close()

在测试时tensorboard记录的信息在logs文件夹,在terminal输入tensorboard --logdir=logs启动tensorboard,将tensorboard给出的网址输入到网页,可以看到每一个batch的图片。下图展示了第一个batch的图片。可以看到,取出了64张图片,和batchsize=64是对应的。另外可以看到,把图片压缩成32*32后,确实很模糊了,人眼都很难看出哪个是蚂蚁,哪个是蜜蜂。

下面这个图展示了第一个batch所有图片的标签,0表示蚂蚁,1表示蜜蜂,仔细看一下图片和标签应该是对应的。

3.网络模型设计

我们把图片处理成3*32*32的tensor了,用如下图所示的卷积神经网络模型。第一层卷积网络采用5*5的卷积核,stride=1,pading=2。第一层卷积的代码是:nn.Conv2d(3, 32, 5, 1, 2),第一个参数3是输入的通道数,第二个参数32是输出的通道数,第三个参数5是卷积核的大小,第四个参数1是stride,第五个参数2是padding。

输出高H,和宽度W计算公式如下所示(注意dilation默认为0)。

Hout=⌊Hin+2×padding[0]−dilation[0]×(kernel_size[0]−1)−1stride[0]+1⌋H_{out} = \left\lfloor\frac{H_{in} + 2 \times \text{padding}[0] - \text{dilation}[0] \times (\text{kernel\_size}[0] - 1) - 1}{\text{stride}[0]} + 1\right\rfloor Hout​=⌊stride[0]Hin​+2×padding[0]−dilation[0]×(kernel_size[0]−1)−1​+1⌋
Wout=⌊Win+2×padding[1]−dilation[1]×(kernel_size[1]−1)−1stride[1]+1⌋W_{out} = \left\lfloor\frac{W_{in} + 2 \times \text{padding}[1] - \text{dilation}[1] \times (\text{kernel\_size}[1] - 1) - 1}{\text{stride}[1]} + 1\right\rfloor Wout​=⌊stride[1]Win​+2×padding[1]−dilation[1]×(kernel_size[1]−1)−1​+1⌋
因此,通过第一层卷积后,高度H为,
Hout=32+2×2−1×(5−1)−11+1=32H_{out}=\frac{32+2 \times 2 -1\times(5-1)-1}{1}+1=32Hout​=132+2×2−1×(5−1)−1​+1=32
同理宽度W也为32。所以输出的大小就32*32*32。接下来,再用一个max-Pooling进行一次池化,池化核的大小是2*2。该池化层的代码是nn.MaxPool2d(2)。池化输出高H,和宽度W计算公式和卷积计算方式一摸一样。在默认的情况下,stride和池化和的大小一样,pading=0,dilation=0。所以第一次池化后,输出的高度H为,
Hout=32+2×0−1×(2−1)−12+1=16H_{out}=\frac{32+2 \times 0 -1\times(2-1)-1}{2}+1=16Hout​=232+2×0−1×(2−1)−1​+1=16
同理,输出的宽度H为16。因此,输出的维度是32*16*16。
后面的输出维度计算方式同上,不再罗嗦了。然后再通过两次卷积和两次池化,后面的输出维度计算方式同上,不再罗嗦了,最终得到一个维度为64*4*4的特征。在做分类之前,首先要把这个三维Tensor拉直成一维Tensor,代码是nn.Flatten()。拉直之后的一维Tensor大小就是64×4×4=102464\times4\times4=102464×4×4=1024。最后通过一个全连接层完成分类任务,全连接层的输入大小是1024,输出的大小是类别的个数,即2,代码是nn.Linear(64 * 4 * 4, 2)。

当完成所有模型的构建后,可以用一段代码来测试一下模型是否有误。例如这里模型的输入在[3,32,32]Tensor的基础上,还需要再增加一维batchsize,所以输入的维度应该是[batchsize,3,32,32]。我们可以生成一个这样维度的数据,例如假设batchsize=3,可以这样生成一个输入:x = torch.ones((3, 3, 32, 32))。然后把x送给模型,看模型是否能正常输出,输出的维度是否是我们预期的。我们还可以借助于Tensorboard来将模型可视化,通过界面把模型展开,看是否正确。
下面是所有的代码,保存在model.py文件中。

from torch import nn
import torch
from torch.utils.tensorboard import SummaryWriterclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 2))def forward(self, x):x = self.model(x)return x# 这段代码测试model是否正确
if __name__ == "__main__":my_model = MyModel()x = torch.ones((3, 3, 32, 32))y = my_model(x)print(y.shape)# 利用tensorboard可视化模型writer = SummaryWriter("graph_logs")writer.add_graph(my_model, x)writer.close()

模型测试代码打印的输出维度是[3,2],3是batchsize,2是全连接层最后的输出维度,和类别的个数是一致的。利用Tensorboard将模型可视化后,如下图所示,还可以进一步展开。

4.模型的训练与测试

模型的训练与测试就不细讲了,和其他模型训练的套路一样的,基本思路可以看我的第一篇pytorch入门文章。下面直接给出代码。

from model import *
from dataProcess import *
import matplotlib.pyplot as plt
import time# 加载训练数据
train_root_dir = "hymenoptera_data/train"
train_ants_label = "ants"
train_bees_label = "bees"
train_ants_dataset = MyData(train_root_dir, train_ants_label)
train_bees_dataset = MyData(train_root_dir, train_bees_label)
train_dataset = train_ants_dataset + train_bees_dataset
train_data_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
train_data_len = len(train_dataset)
# 加载测试数据
test_root_dir = "hymenoptera_data/val"
test_ants_label = "ants"
test_bees_label = "bees"
test_ants_dataset = MyData(test_root_dir, test_ants_label)
test_bees_dataset = MyData(test_root_dir, test_bees_label)
test_dataset = test_ants_dataset + test_bees_dataset
test_data_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=True)
test_data_len = len(test_dataset)
print(f"训练集长度:{train_data_len}")
print(f"测试集长度:{test_data_len}")
# 创建网络模型
my_model = MyModel()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 5e-3
optimizer = torch.optim.SGD(my_model.parameters(), lr=learning_rate)
# Adam 参数betas=(0.9, 0.99)
# optimizer = torch.optim.Adam(my_model.parameters(), lr=learning_rate, betas=(0.9, 0.99))
# 总共的训练步数
total_train_step = 0
# 总共的测试步数
total_test_step = 0
step = 0
epoch = 500writer = SummaryWriter("logs")
train_loss_his = []
train_totalaccuracy_his = []
test_totalloss_his = []
test_totalaccuracy_his = []
start_time = time.time()
my_model.train()
for i in range(epoch):print(f"-------第{i}轮训练开始-------")train_total_accuracy = 0for data in train_data_loader:imgs, targets = datawriter.add_images("tarin_data", imgs, total_train_step)output = my_model(imgs)loss = loss_fn(output, targets)train_accuracy = (output.argmax(1) == targets).sum()train_total_accuracy = train_total_accuracy + train_accuracyoptimizer.zero_grad()loss.backward()optimizer.step()total_train_step = total_train_step + 1train_loss_his.append(loss)writer.add_scalar("train_loss", loss.item(), total_train_step)train_total_accuracy = train_total_accuracy / train_data_lenprint(f"训练集上的准确率:{train_total_accuracy}")train_totalaccuracy_his.append(train_total_accuracy)# 测试开始total_test_loss = 0my_model.eval()test_total_accuracy = 0with torch.no_grad():for data in test_data_loader:imgs, targets = dataoutput = my_model(imgs)loss = loss_fn(output, targets)total_test_loss = total_test_loss + losstest_accuracy = (output.argmax(1) == targets).sum()test_total_accuracy = test_total_accuracy + test_accuracytest_total_accuracy = test_total_accuracy / test_data_lenprint(f"测试集上的准确率:{test_total_accuracy}")print(f"测试集上的loss:{total_test_loss}")test_totalloss_his.append(total_test_loss)test_totalaccuracy_his.append(test_total_accuracy)writer.add_scalar("test_loss", total_test_loss.item(), i)
end_time = time.time()
total_train_time = end_time-start_time
print(f'训练时间: {total_train_time}秒')
writer.close()
plt.plot(train_loss_his, label='Train Loss')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.show()
plt.plot(test_totalloss_his, label='Test Loss')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.show()plt.plot(train_totalaccuracy_his, label='Train accuracy')
plt.plot(test_totalaccuracy_his, label='Test accuracy')
plt.legend(loc='best')
plt.xlabel('Steps')
plt.show()

通过上述代码,训练得到的结果如下图所示,

结果虽然不是很好,但是我觉得已经很不多了,在测试集上的准确率差不多达到0.7了。为了节省计算资源,我把图片压缩成32*32,连我们人眼都很难分辨出哪个是蚂蚁,哪个是蜜蜂。另外,我这个模型是完全从0开始训练的,隔壁在预训练模型的基础上进行训练得到的效果好像没好多少。。。

我的实践:通过蚂蚁、蜜蜂二分类问题了解如何基于Pytorch构建分类模型相关推荐

  1. Pytorch之模型微调(Finetune)——用Resnet18进行蚂蚁蜜蜂二分类为例

    Pytorch之模型微调(Finetune)--手写数字集为例 文章目录 Pytorch之模型微调(Finetune)--手写数字集为例 前言 一.Transfer Learning and Mode ...

  2. R语言基于glmnet构建分类模型并可视化特征系数(coefficient)以及L1正则化系数(lambda)实战

    R语言基于glmnet构建分类模型并可视化特征系数(coefficient)以及L1正则化系数(lambda)实战 # 导入测试数据集 data(BinomialExample) x <- Bi ...

  3. java按顺序售票方法_java_Java代码实践12306售票算法(二),周五闲来无事,基于上一篇关 - phpStudy...

    Java代码实践12306售票算法(二) 周五闲来无事,基于上一篇关于浅析12306售票算法(java版)理论,进行了java编码实践供各位读者参考(以下为相关代码的简单描述) 1.订票工具类 1.1 ...

  4. 基于PyTorch的LSTM模型的IMBD情感分类遇到的问题

    今天想学LSTM的情感分类,结果碰到了一系列问题,耽误了很多时间.特此记录! 一.项目来源 lesson53-情感分类实战 B站视频 二.碰到的问题 1.报错AttributeError: modul ...

  5. 「深度学习一遍过」必修17:基于Pytorch细粒度分类实战

    本专栏用于记录关于深度学习的笔记,不光方便自己复习与查阅,同时也希望能给您解决一些关于深度学习的相关问题,并提供一些微不足道的人工神经网络模型设计思路. 专栏地址:「深度学习一遍过」必修篇 目录 1 ...

  6. 基于pytorch构建双向LSTM(Bi-LSTM)文本情感分类实例(使用glove词向量)

    学长给的代码,感觉结构清晰,还是蛮不错的,想以后就照着这样的结构走好了,记录一下. 首先配置环境 matplotlib==3.4.2 numpy==1.20.3 pandas==1.3.0 sklea ...

  7. 【项目实战课】基于Pytorch的EfficientNet血红细胞分类竞赛实战

    欢迎大家来到我们的项目实战课,本期内容是<基于Pytorch的EfficientNet血红细胞分类竞赛实战>.所谓项目课,就是以简单的原理回顾+详细的项目实战的模式,针对具体的某一个主题, ...

  8. 用Flair(PyTorch构建的NLP开发包)进行文本分类

    Flair是一个基于PyTorch构建的NLP开发包,它在解决命名实体识别(NER).语句标注(POS).文本分类等NLP问题时达到了当前的顶尖水准.本文将介绍如何使用Flair构建定制的文本分类器. ...

  9. 初学者之蚂蚁蜜蜂分类报错记录

    听了bilibil小土堆的课,拿蚂蚁蜜蜂数据集练手,记录下自己犯的错~ 1.trans = transform.Compose() # 这段报错了-- #图片转为tensor 修改维度 trans = ...

最新文章

  1. python图形设置_python学习笔记——基本图形绘制
  2. boost::gil模块数字扩展中的 convolve_rows() 和 convolve_cols() 示例
  3. ERROR 1205 (HY000): Lock wait timeout exceeded; try restarting transaction
  4. 测试过程中常用的linux命令之【查找指定的文件内容】
  5. [渝粤教育] 广东-国家-开放大学 21秋期末考试中国近现代史纲要(A)10881k1
  6. krsort_PHP krsort()函数与示例
  7. termcap-1.3.1的configure.in文件逐行分析
  8. js作用域与作用域链
  9. (33)Gulp构建脚本文件
  10. 使用Flex4画图形
  11. linux 网络协议栈
  12. IMU噪声参数辨识-艾伦方差
  13. 如何在ppt中打开html,如何在ppt中直接打开网页
  14. SpringBoot之DispatcherServlet详解及源码解析
  15. Verilog语言程序框架
  16. 热点综述 | 纵向微生物组研究的统计方法总结
  17. python+selenium高级教程
  18. 2021最新《python爬虫从0-1》5.正则表达式讲解
  19. python中的break、continue和pass
  20. 公司员工培训管理系统的开发研究(J2EE)

热门文章

  1. CAD更改没有的字体,打开时如何选择字体
  2. Word文档使用方法
  3. 计算机一级考试正文首字下沉2行,各段落首字下沉两行、距正文0.1cm。
  4. 小程序进阶-图表库uchart
  5. win to go 给移动硬盘装双系统
  6. 给视频配音的这两种方法你知道吗?
  7. Java实现邮件找回密码功能
  8. ORACLE-SQL调优
  9. html夸女生的代码,抖音上夸女孩子漂亮的语句
  10. C语言中的malloc与free函数