这里基于PyTorch框架,实现通过Faster RCNN算法检测图像中的小麦麦穗。当然,用YOLO算法也同样能够完成。本文最终实现的效果如下:

麦穗检测示例

一、数据下载

数据集名:Global Wheat Head Dataset

下载地址:www.kaggle.com/c/global-wheat-detection

更多深度学习数据集:https://www.cvmart.net/dataSets

相关论文:Global Wheat Head Detection (GWHD) Dataset: A Large and Diverse Dataset of High-Resolution RGB-Labelled Images to Develop and Benchmark Wheat Head Detection Methods

数据描述:全球麦穗数据集由来自7个国家的9个研究机构领导,东京大学、国家农业、营养和环境研究所、Arvalis、ETHZ、萨斯喀彻温大学、昆士兰大学、南京农业大学和洛桑研究所。包括全球粮食安全研究所、DigitAg、Kubota和Hiphen在内的许多机构都加入了这些机构的行列,致力于精确的小麦麦穗检测。

数据集贡献机构

数据集为室外小麦植物图像,包括来自全球各地不同平台采集的4698张RGB图像,标记了193,634个小麦麦穗,1024×1024像素,每张图像含有20~70个麦穗。2020年通过Kaggle举办了相关比赛,并在2021年更新了数据集。该数据集可以用于麦穗检测,评估穗数和大小。研究成果有助于准确估计不同品种小麦麦穗的密度和大小。

数据集示例

二、代码实战

2.1 导入所需要的包

# 导入所需要的包
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch  import torch.nn as nn
import albumentations as A   # pip install albumentations==1.1.0
from albumentations.pytorch import ToTensorV2
import torchvision
from torchvision import datasets,transforms
from tqdm import tqdm
import cv2
from torch.utils.data import Dataset,DataLoader
import torch.optim as optim
from PIL import Image
import os
import torch.nn.functional as F
import ast

2.2 参数配置

# 定义参数
LR = 1e-4
SPLIT = 0.2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 4
EPOCHS = 2
DATAPATH = '../global-wheat-detection'

2.3 读取数据

# 读取 train.csv文件
df = pd.read_csv(DATAPATH + '/train.csv')
df.bbox = df.bbox.apply(ast.literal_eval)   # # 将string of list 转成list数据  #  # 利用groupby 将同一个image_id的数据进行聚合,方式为list进行,并且用reset_index直接转变成dataframe
df = df.groupby("image_id")["bbox"].apply(list).reset_index(name="bboxes")

2.4 划分数据

# # 划分数据集
def train_test_split(dataFrame,split):  len_tot = len(dataFrame)  val_len = int(split*len_tot)  train_len = len_tot-val_len  train_data,val_data = dataFrame.iloc[:train_len][:],dataFrame.iloc[train_len:][:]  return train_data,val_data  len(df)  train_data_df,val_data_df = train_test_split(df,SPLIT)  # 划分 train val 8:2
len(train_data_df), len(val_data_df)  # 查看数据
train_data_df

2.5 构建Dataset类

# 定义WheatDataset 返回 图片,标签
class WheatDataset(Dataset):  def __init__(self,data,root_dir,transform=None,train=True):  self.data = data  self.root_dir = root_dir  self.image_names = self.data.image_id.values  self.bboxes = self.data.bboxes.values  self.transform = transform  self.isTrain = train  def __len__(self):  return len(self.data)  def __getitem__(self,index):
#         print(self.image_names)
#         print(self.bboxes)  img_path = os.path.join(self.root_dir,self.image_names[index]+".jpg")  # 拼接路径  image = cv2.imread(img_path, cv2.IMREAD_COLOR)   # 读取图片  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB).astype(np.float32)  # BGR2RGB  image /= 255.0    # 归一化  bboxes = torch.tensor(self.bboxes[index],dtype=torch.float64)
#         print(bboxes)  """  As per the docs of torchvision  we need bboxes in format (xmin,ymin,xmax,ymax)  Currently we have them in format (xmin,ymin,width,height)  """  bboxes[:,2] = bboxes[:,0]+bboxes[:,2]   # 格式转换 (xmin,ymin,width,height)-----> (xmin,ymin,xmax,ymax)  bboxes[:,3] = bboxes[:,1]+bboxes[:,3]
#         print(image.size,type(image))  """  we need to return image and a target dictionary  target:  boxes,labels,image_id,area,iscrowd  """  area = (bboxes[:,3]-bboxes[:,1])*(bboxes[:,2]-bboxes[:,0])   # 计算面积  area = torch.as_tensor(area,dtype=torch.float32)  # there is only one class  labels = torch.ones((len(bboxes),),dtype=torch.int64)   # 标签  # suppose all instances are not crowded  iscrowd = torch.zeros((len(bboxes),),dtype=torch.int64)  target = {}   # target是个字典 里面 包括 boxes,labels,image_id,area,iscrowd  target['boxes'] = bboxes  target['labels']= labels  target['image_id'] = torch.tensor([index])  target["area"] = area  target['iscrowd'] = iscrowd  if self.transform is not None:  sample = {  'image': image,  'bboxes': target['boxes'],  'labels': labels  }  sample = self.transform(**sample)  image = sample['image']  # 沿着一个新维度对输入张量序列进行连接。 序列中所有的张量都应该为相同形状,
#             把多个2维的张量凑成一个3维的张量;多个3维的凑成一个4维的张量…以此类推,也就是在增加新的维度进行堆叠  target['boxes'] = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)  return image,target

2.6 数据增强

# 训练与验证数据增强,利用albumentations  随机翻转转换,随机图片处理
# 对象检测的增强与正常增强不同,因为在这里需要确保 bbox 在转换后仍然正确与对象对齐
train_transform = A.Compose([  A.Flip(0.5),  ToTensorV2(p=1.0)
],bbox_params = {'format':"pascal_voc",'label_fields': ['labels']})
val_transform = A.Compose([  ToTensorV2(p=1.0)
],bbox_params = {'format':"pascal_voc","label_fields":['labels']})
`### 2.7 数据整理`"""
collate_fn默认是对数据(图片)通过torch.stack()进行简单的拼接。对于分类网络来说,默认方法是可以的(因为传入的就是数据的图片),
但是对于目标检测来说,train_dataset返回的是一个tuple,即(image, target)。
如果我们还是采用默认的合并方法,那么就会出错。
所以我们需要自定义一个方法,即collate_fn=train_dataset.collate_fn
"""
def collate_fn(batch):  return tuple(zip(*batch))

2.8 创建数据加载器

# 创建数据加载器  train_data = WheatDataset(train_data_df,DATAPATH+"/train",transform=train_transform)
valid_data = WheatDataset(val_data_df,DATAPATH+"/train",transform=val_transform)

2.9 查看数据

# 查看一个训练集中的数据
image,target = train_data.__getitem__(0)
plt.imshow(image.numpy().transpose(1,2,0))
print(image.shape)  

训练集示例

2.10 定义模型

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor  model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes = 2
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features,num_classes)

2.11 定义Averager类

# 这一个类来保存对应的loss
class Averager:  def __init__(self):  self.current_total = 0.0  self.iterations = 0.0  def send(self, value):  self.current_total += value  self.iterations += 1  @property  def value(self):  if self.iterations == 0:  return 0  else:  return 1.0 * self.current_total / self.iterations  def reset(self):  self.current_total = 0.0  self.iterations = 0.0

2.12 构建训练和测试 dataloader

# 构建训练和测试 dataloader
train_dataloader = DataLoader(train_data,batch_size=BATCH_SIZE,shuffle=True,collate_fn=collate_fn)
val_dataloader = DataLoader(valid_data,batch_size=BATCH_SIZE,shuffle=False,collate_fn=collate_fn)

2.13 定义模型参数

# 定义模型, 优化器,损失, 迭代,以及 学习率
train_loss = []
# val_loss = []
model = model.to(DEVICE)
params =[p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(params,lr=LR)
loss_hist = Averager()
itr = 1
lr_scheduler=None  loss_hist = Averager()
itr = 1

2.14 模型训练

if __name__ == '__main__':  for epoch in range(EPOCHS):  loss_hist.reset()  for images, targets in train_dataloader:  # print(images)  # print(targets)  # for image in images:  #     print(image.dtype)  # torch.float32  # for t in targets:  #     for k, v in t.items():  #         print(k ,v.dtype)  images = list(image.to(DEVICE) for image in images)  targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]  loss_dict = model(images, targets)  # for loss in loss_dict.values():  #     print(loss.dtype)  losses = sum(loss for loss in loss_dict.values())  loss_value = losses.item()  loss_hist.send(loss_value)  optimizer.zero_grad()  losses.backward()  optimizer.step()  if itr % 50 == 0:  print(f"Iteration #{itr} loss: {loss_value}")  itr += 1  # update the learning rate  if lr_scheduler is not None:  lr_scheduler.step()  print(f"Epoch #{epoch} loss: {loss_hist.value}")

2.15 模型保存

# 模型保存
torch.save(model.state_dict(), 'fasterrcnn_resnet50_fpn.pth')

训练好的模型                                                whaosoft aiot http://143ai.com

2.16 加载模型进行预测

images, targets = next(iter(val_dataloader))
images = list(img.to(DEVICE) for img in images)
# print(images[0].shape)
targets = [{k: v.to(DEVICE) for k, v in t.items()} for t in targets]
boxes = targets[1]['boxes'].cpu().numpy().astype(np.int32)
sample = images[1].permute(1, 2, 0).cpu().numpy()  model.eval()
cpu_device = torch.device("cpu")
# print(images[0].shape)  outputs = model(images)
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
# print(outputs[1]['boxes'].detach().numpy().astype(np.int32))  pred_boxes = outputs[1]['boxes'].detach().numpy().astype(np.int32)  fig, ax = plt.subplots(1, 1, figsize=(16, 8))  for b, box in zip(boxes, pred_boxes):  # 绘制预测边框 红色表示  cv2.rectangle(sample,  (box[0], box[1]),  (box[2], box[3]),  (220, 0, 0), 3)  # 绘制实际边框  绿色表示  cv2.rectangle(sample,  (b[0], b[1]),  (b[2], b[3]),  (0, 220, 0), 3)  ax.set_axis_off()
ax.imshow(sample)
plt.show()  

检测结果

对比预测框与实际框,可以看出模型能够很好的预测出麦穗。可以尝试测试不同的麦穗图片,来进行测试查看效果。

PyTorch~Faster RCNNの小麦麦穗检测相关推荐

  1. 基于Faster RCNN的医学图像检测(肺结节检测)

    Faster-R-CNN算法由两大模块组成:1.PRN候选框提取模块 2.Fast R-CNN检测模块.其中,RPN是全卷积神经网络,用于提取候选框:Fast R-CNN基于RPN提取的proposa ...

  2. 计算机视觉与深度学习 | 基于Faster R-CNN的目标检测(深度学习Matlab代码)

    ===================================================== github:https://github.com/MichaelBeechan CSDN: ...

  3. 【论文解读】Faster R-CNN 实时目标检测

    前言 Faster R-CNN 的亮点是使用RPN来提取候选框:RPN全称是Region Proposal Network,也可理解为区域生成网络,或区域候选网络:它是用来提取候选框的.RPN特点是耗 ...

  4. 面试真题总结:Faster Rcnn,目标检测,卷积,梯度消失,Adam算法

    目标检测可以分为两大类,分别是什么,他们的优缺点是什么呢? 答案:目标检测算法分为单阶段和双阶段两大类.单阶段目标验测算法(one-stage),代表算法有 yolo 系列,SSD 系列:直接对图像进 ...

  5. iCAN使用faster r-cnn得到目标检测结果文件为空

    问题在于图片文件夹后少了/,添加上/后解决 -/tf-faster-rcnn/tools/Object_Detector.py --img_dir /home/featurize/Data/exima ...

  6. 目标检测算法Faster R-CNN简介

    在博文https://blog.csdn.net/fengbingchun/article/details/87091740 中对Fast R-CNN进行了简单介绍,这里在Fast R-CNN的基础上 ...

  7. 【目标检测】Faster RCNN算法详解

    转载自:http://blog.csdn.net/shenxiaolu1984/article/details/51152614 Ren, Shaoqing, et al. "Faster ...

  8. 人工智能:物体检测之Faster RCNN模型

    人工智能:物体检测之Faster RCNN模型 物体检测 Faster RCNN模型 简介 卷积层 RPN Roi Pooling Classifier 物体检测 什么是物体检测 物体检测应用场景 物 ...

  9. 深度学习之目标检测:R-CNN、Fast R-CNN、Faster R-CNN

    object detection 就是在给定的图片中精确找到物体所在位置,并标注出物体的类别.object detection 要解决的问题就是物体在哪里,是什么这整个流程的问题.然而,这个问题不是容 ...

最新文章

  1. hadoop面试记录(一)
  2. 笔记-高项案例题-2017年下-整体管理-变更管理
  3. hibernate映射之多对多双向
  4. 桌面虚拟化之运维支持
  5. invalid currency could not be saved in AG3
  6. Logtail 混合模式:使用插件处理文件日志
  7. Python 全国考级二级
  8. SAP License:糟糕的用户比任何系统问题都要危险
  9. 解析Java的JNI编程中的对象引用与内存泄漏问题
  10. Redis主从自动failover
  11. AI算法 - 粒子滤波
  12. 学计算机用苹果本,新手小白用苹果电脑搞科研,学会这些才不至于尴尬!
  13. linux oa系统搭建,企业Linux系统部署OA系统上线实例
  14. 单片机C语言关键字之extern
  15. HDS USP系列存储
  16. 将Planet卫星影像数据添加到QGIS, ArcGIS Pro 或 ArcGIS 10.X方法,以ArcGIS Pro为例。
  17. 【示波器】基于FPGA的数字示波器设计实现
  18. php立方体相册源码,纯CSS实现3D的代码(正方体、动态立体图片册、平面的星空)...
  19. 如何修复无法打开的Excel文件,三大原因三大方法为你解决
  20. 轻松禁用WinRAR设置

热门文章

  1. 让看代码成为一种享受! 使用Carbon生成漂亮的代码图片
  2. C陷阱和缺陷(C Traps and Pitfalls)-读书笔记
  3. lisp正负调换_OpenSees五问(1)
  4. OpenSees开发(二)源码分析——平面桁架静力有限元分析实例
  5. 【论文代码复现2】Clustered sampling based on sample size
  6. Ackerman函数的实现算法
  7. Ubuntu使用Foxit Reader + GoldenDict实现PDF划译
  8. 2023东北师范大学计算机考研信息汇总
  9. Portal(博图)软件的应用及程序简介
  10. 网络爬虫反反爬小技巧(六)奇淫技巧