文章目录

  • 摘要
  • 数据增强Cutout和Mixup
  • 项目结构
  • 导入项目使用的库
  • 设置全局参数
  • 图像预处理与增强
  • 读取数据
  • 设置模型
  • 定义训练和验证函数
  • 测试

摘要

本例提取了植物幼苗数据集中的部分数据做数据集,数据集共有12种类别,演示如何使用pytorch版本的MobileNetV1图像分类模型实现分类任务。

通过本文你和学到:

1、如何自定义MobileNetV1模型。

2、如何自定义数据集加载方式?

3、如何使用Cutout数据增强?

4、如何使用Mixup数据增强。

5、如何实现训练和验证。

6、预测的两种写法。

MobileNetV1的论文翻译:【第26篇】MobileNets:用于移动视觉应用的高效卷积神经网络_AI浩-CSDN博客

MobileNetV1解析:

mobileNetV1网络解析,以及实现(pytorch)_AI浩-CSDN博客

Keras版本:

MobileNet实战:tensorflow2.X版本,MobileNetV1图像分类任务(大数据集)_AI浩-CSDN博客

数据增强Cutout和Mixup

为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:

pip install torchtoolbox

Cutout实现,在transforms中。

from torchtoolbox.transform import Cutout# 数据预处理transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

Mixup实现,在train方法中。需要导入包:from torchtoolbox.tools import mixup_data, mixup_criterion

    for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)data, labels_a, labels_b, lam = mixup_data(data, target, alpha)optimizer.zero_grad()output = model(data)loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)loss.backward()optimizer.step()print_loss = loss.data.item()

项目结构

MobileNetV1_demo
├─data
│  └─train
│      ├─Black-grass
│      ├─Charlock
│      ├─Cleavers
│      ├─Common Chickweed
│      ├─Common wheat
│      ├─Fat Hen
│      ├─Loose Silky-bent
│      ├─Maize
│      ├─Scentless Mayweed
│      ├─Shepherds Purse
│      ├─Small-flowered Cranesbill
│      └─Sugar beet
├─dataset
│  └─dataset.py
└─models
│    └─mobilenetV1.py
├─train.py
├─test1.py
└─test.py

导入项目使用的库

import torch.optim as optim
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transformsfrom dataset.dataset import SeedlingData
from torch.autograd import Variable
from Model.mobilenetv1 import MobileNetV1
from torchtoolbox.tools import mixup_data, mixup_criterion
from torchtoolbox.transform import Cutout

设置全局参数

设置学习率、BatchSize、epoch等参数,判断环境中是否存在GPU,如果没有则使用CPU。建议使用GPU,CPU太慢了。

# 设置全局参数
modellr = 1e-4
BATCH_SIZE = 16
EPOCHS = 300
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

图像预处理与增强

数据处理比较简单,加入了Cutout、做了Resize和归一化。

# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

读取数据

将数据集解压后放到data文件夹下面,如图:

然后我们在dataset文件夹下面新建 init.py和dataset.py,在datasets.py文件夹写入下面的代码:

# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
from sklearn.model_selection import train_test_splitLabels = {'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3,'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8,'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}class SeedlingData (data.Dataset):def __init__(self, root, transforms=None, train=True, test=False):"""主要目标: 获取所有图片的地址,并根据训练,验证,测试划分数据"""self.test = testself.transforms = transformsif self.test:imgs = [os.path.join(root, img) for img in os.listdir(root)]self.imgs = imgselse:imgs_labels = [os.path.join(root, img) for img in os.listdir(root)]imgs = []for imglable in imgs_labels:for imgname in os.listdir(imglable):imgpath = os.path.join(imglable, imgname)imgs.append(imgpath)trainval_files, val_files = train_test_split(imgs, test_size=0.3, random_state=42)if train:self.imgs = trainval_fileselse:self.imgs = val_filesdef __getitem__(self, index):"""一次返回一张图片的数据"""img_path = self.imgs[index]img_path=img_path.replace("\\",'/')if self.test:label = -1else:labelname = img_path.split('/')[-2]label = Labels[labelname]data = Image.open(img_path).convert('RGB')data = self.transforms(data)return data, labeldef __len__(self):return len(self.imgs)

说一下代码的核心逻辑:

第一步 建立字典,定义类别对应的ID,用数字代替类别。

第二步 在__init__里面编写获取图片路径的方法。测试集只有一层路径直接读取,训练集在train文件夹下面是类别文件夹,先获取到类别,再获取到具体的图片路径。然后使用sklearn中切分数据集的方法,按照7:3的比例切分训练集和验证集。

第三步 在__getitem__方法中定义读取单个图片和类别的方法,由于图像中有位深度32位的,所以我在读取图像的时候做了转换。

然后我们在train.py调用SeedlingData读取数据 ,记着导入刚才写的dataset.py(from dataset.dataset import SeedlingData)

dataset_train = SeedlingData('data/train', transforms=transform, train=True)
dataset_test = SeedlingData("data/train", transforms=transform_test, train=False)
# 读取数据
print(dataset_train.imgs)# 导入数据
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

设置模型

  • 设置loss函数为nn.CrossEntropyLoss()。
  • 设置模型为MobileNetV1,num_classes设置为12。
  • 优化器设置为adam。
  • 学习率调整策略选择为余弦退火。
# 实例化模型并且移动到GPU
criterion = nn.CrossEntropyLoss()
model_ft = MobileNetV1(num_classes=12)
model_ft.to(DEVICE)
# 选择简单暴力的Adam优化器,学习率调低
optimizer = optim.Adam(model_ft.parameters(), lr=modellr)
cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer,T_max=20,eta_min=1e-9)

定义训练和验证函数

# 定义训练过程
alpha=0.2
def train(model, device, train_loader, optimizer, epoch):model.train()sum_loss = 0total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)data, labels_a, labels_b, lam = mixup_data(data, target, alpha)optimizer.zero_grad()output = model(data)loss = mixup_criterion(criterion, output, labels_a, labels_b, lam)loss.backward()optimizer.step()lr = optimizer.state_dict()['param_groups'][0]['lr']print_loss = loss.data.item()sum_loss += print_lossif (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item(),lr))ave_loss = sum_loss / len(train_loader)print('epoch:{},loss:{}'.format(epoch, ave_loss))ACC=0
# 验证过程
def val(model, device, test_loader):global ACCmodel.eval()test_loss = 0correct = 0total_num = len(test_loader.dataset)print(total_num, len(test_loader))with torch.no_grad():for data, target in test_loader:data, target = Variable(data).to(device), Variable(target).to(device)output = model(data)loss = criterion(output, target)_, pred = torch.max(output.data, 1)correct += torch.sum(pred == target)print_loss = loss.data.item()test_loss += print_losscorrect = correct.data.item()acc = correct / total_numavgloss = test_loss / len(test_loader)print('\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(avgloss, correct, len(test_loader.dataset), 100 * acc))if acc > ACC:torch.save(model_ft, 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')ACC = acc# 训练for epoch in range(1, EPOCHS + 1):train(model_ft, DEVICE, train_loader, optimizer, epoch)cosine_schedule.step()val(model_ft, DEVICE, test_loader)

运行结果:

测试

我介绍两种常用的测试方式,第一种是通用的,通过自己手动加载数据集然后做预测,具体操作如下:

测试集存放的目录如下图:

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

第三步 加载model,并将模型放在DEVICE里,

第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。

import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import os
classes = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed','Common wheat','Fat Hen', 'Loose Silky-bent','Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)path='data/test/'
testList=os.listdir(path)
for file in testList:img=Image.open(path+file)img=transform_test(img)img.unsqueeze_(0)img = Variable(img).to(DEVICE)out=model(img)# Predict_, pred = torch.max(out.data, 1)print('Image Name:{},predict:{}'.format(file,classes[pred.data.item()]))

运行结果:

第二种 使用自定义的Dataset读取图片

import torch.utils.data.distributed
import torchvision.transforms as transforms
from dataset.dataset import SeedlingData
from torch.autograd import Variableclasses = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed','Common wheat','Fat Hen', 'Loose Silky-bent','Maize','Scentless Mayweed','Shepherds Purse','Small-flowered Cranesbill','Sugar beet')
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth")
model.eval()
model.to(DEVICE)dataset_test =SeedlingData('data/test/', transform_test,test=True)
print(len(dataset_test))
# 对应文件夹的labelfor index in range(len(dataset_test)):item = dataset_test[index]img, label = itemimg.unsqueeze_(0)data = Variable(img).to(DEVICE)output = model(data)_, pred = torch.max(output.data, 1)print('Image Name:{},predict:{}'.format(dataset_test.imgs[index], classes[pred.data.item()]))index += 1

运行结果:


完整代码:
https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/78852856

MobileNetV1实战:使用MobileNetV1实现植物幼苗分类相关推荐

  1. Kaggle图像识别竞赛 Plant Seedlings Classification(植物幼苗分类)具体实现

    目录 0. 前言 1. 总体设计 2. import部分 3. 具体实现步骤 一.数据预处理 (一)均衡化 (二)提取图片中叶子(绿色)的部分 二.提取特征 (一)SIFT提取关键点 (二)BOW(B ...

  2. MobileNetV3 实战:植物幼苗分类(pytorch)

    文章目录 摘要 mobilenetv3简介 数据增强Cutout和Mixup 项目结构 导入项目使用的库 设置全局参数 图像预处理与增强 读取数据 设置模型 定义训练和验证函数 测试 摘要 本例提取了 ...

  3. 深度学习图像分类:Kaggle植物幼苗分类(Plant Seedlings Classification)完整代码

    学习并参考了这个链接. 提前准备:anaconda3(Jupyter):tensorflow:cv2:数据集(官网可以下,点开我给的超链接) 注意:tensorflow要求python是3.5/3.6 ...

  4. ConvMAE实战:使用ConvMAE实现对植物幼苗的分类(非官方)(一)

    ConvMAE实战 摘要 安装包 1.安装timm 数据增强Cutout和Mixup 项目结构 计算mean和std 生成数据集 摘要 本文通过对植物幼苗分类的实际例子来感受一下ConvMAE模型的效 ...

  5. SENet实战详解:使用SE-ReSNet50实现对植物幼苗的分类

    摘要 1.SENet概述 ​ Squeeze-and-Excitation Networks(简称 SENet)是 Momenta 胡杰团队(WMW)提出的新的网络结构,利用SENet,一举取得最后一 ...

  6. 深度学习图像分类:植物幼苗图像分类入门(Plant Seedlings Classification)

    前言:深度学习考试期末的题目,植物幼苗分类,可以帮助农业领域的进步. 题目介绍:kaggle原题:可以下载数据集,查看一些参与者的思路等. 易用的深度学习框架Keras简介及使用 部分图片如下: 思路 ...

  7. matlab幼苗识别,基于MATLAB的植物幼苗识别

    基于MATLAB的植物幼苗识别(论文11000字,外文翻译) 摘要:杂草种类繁多,严重影响了农作物的生产与产量,使用图像处理技术识别区分杂草和作物幼苗已成为一种最科学最有效的方法.通过提取植物图像的有 ...

  8. R语言创建自定义颜色(分类变量与颜色形成稳定映射)实战:设置因子变量(分类变量)到可视化颜色的稳定映射

    R语言创建自定义颜色(分类变量与颜色形成稳定映射)实战:设置因子变量(分类变量)到可视化颜色的稳定映射 目录

  9. Kaggle竞赛方案分享:如何分辨杂草和植物幼苗(转)

    任务概览 你能分清楚杂草和庄稼苗吗? 如果能高效识别杂草,就能有效地提高粮食产量,更好地管理环境.Aarhus University Signal Processing和University of S ...

最新文章

  1. 生成pojo mysql_通过数据库表反向生成pojo类
  2. Ribbon-3使用配置文件自定义Ribbon Client
  3. Mac下安装LNMP(Nginx+PHP5.6)环境
  4. [2019CSP多校联赛普及组第五周] 调度CPU (贪心)
  5. 调用阿里云接口实现短信消息的发送源码——CSDN博客
  6. 选择排序 自带时间复杂度分析
  7. C++vector基础容器3.0
  8. 为什么要自定义ClassLoader进行类加载
  9. 【Python】Django CSRF问题
  10. 使用python爬取行政区划
  11. 什么转换器能将excel转换成pdf
  12. VMware - 虚拟机系统中无法使用键盘
  13. 计算机考研复试之操作系统
  14. 蜂巢(已更名为网易云计算基础服务)计费系统架构升级之路
  15. win10应用商店无法打开重新 加载
  16. MIPI -- mipi_CSI-2_specification_v2-1-er01.pdf
  17. 聊聊iClient for Leaflet坐标转换问题
  18. 电商数据仓库项目总结
  19. 小学计算机课程目录五年级,小学五年级信息技术课件
  20. html5标签不区分大小写对错,html5 不区分大小写、标记结束符及属性是否加引号?...

热门文章

  1. 曲率滤波的理论基础和应用
  2. 风控标签体系的使用与介绍
  3. 基于ssm快递取件及上门服务微信小程序
  4. Wdf框架之WdfObject状态机(3)-前篇
  5. 数据结构:删除顺序表中小于min和大于max的数(不需要从大到小排列依然可以)
  6. 转 新入职程序员心得
  7. 为什么《王者荣耀》的音乐让人过耳不忘? 天美讲述游戏音频设计背后的故事
  8. 笔试题-搜狐手机网Python开发工程师
  9. 机器学习实战-python3勘误
  10. 网站常用的五种布局方案