关于理论部分我看的是b站“霹雳吧啦Wz”的SSD理论讲解,作为入门小白表示能听懂,需要的同学可以自行观看

目录

1.训练环境

2.训练步骤


1.训练环境

我的环境是win11+anaconda+python3.6.13+pytorch1.10.2+cuda11.6

2.训练步骤

(1)下载SSD源码

可到github进行下载

GitHub - amdegroot/ssd.pytorch: A PyTorch Implementation of Single Shot MultiBox Detector

(2)下载模型文件

VGG16_reducedfc.pth预训练模型下载地址:https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth

将下载的模型文件放置于ssd源码目录中  wights/vgg16_reducedfc.pth

(3)数据集准备

与大多数训练模型一样,ssd支持的训练格式为VOC和coco,这里采用voc2007作为演示,制作自己的数据集以及labimg的使用可自行观看yolo数据集标注软件安装+使用流程_道人兄的博客-CSDN博客_yolo数据集标注工具

voc2007的具体下载方式我也不多赘述,网络上百度也有,或者直接看我之前写的也有提到使用Faster—RCNN训练数据集流程(学习记录)_道人兄的博客-CSDN博客

将下载后的voc2007数据集放置于./data/VOCdevkit/中

然后到ssd.pytorch-master/data/中的voc0712.py进行修改其中的VOC_ROOT = osp.join(HOME, "data/VOCdevkit/"),他这里的HOME老是读取我的C盘位置,所以一直报错,我直接把数据集的绝对路径写上去了就没报错

将 voc0712.py文件中VOCDetection类的__init__函数,将image_sets修改为[('2007', 'train'), ('2007', 'val'),('2007','test')],修改后的结果如下。

def __init__(self, root,image_sets=[('2007', 'train'), ('2007', 'val'),('2007','test')],transform=None, target_transform=VOCAnnotationTransform(),dataset_name='VOC0712'):

其中如果是训练自己的数据集,记得修改voc0712.py文件中的VOC_CLASSES 变量。例如,将VOC_CLASSES修改为person类,注意如果只有一类则需要加方括号,修改后的结果如下。

VOC_CLASSES = [('person')

如果训练自己的数据集,还需要修改config.py文件中的voc字典变量。将其中的num_classes修改为2(以person为例)(背景类+你训练集的种类个数),第一次调试时可以将max_iter调小至1000,修改后的结果如下。

voc = {'num_classes': 2,'lr_steps': (80000, 100000, 120000),'max_iter': 1000,'feature_maps': [38, 19, 10, 5, 3, 1],'min_dim': 300,'steps': [8, 16, 32, 64, 100, 300],'min_sizes': [30, 60, 111, 162, 213, 264],'max_sizes': [60, 111, 162, 213, 264, 315],'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],'variance': [0.1, 0.2],'clip': True,'name': 'VOC',
}

最后一步,把coco_labels.txt放在ssd.pytorch-master/data/coco/目录下,也可以通过修改coco.py文件中的COCO_ROOT = osp.join(HOME, 'data/coco/')来指定存放路径。

(4)修改源码

①修改ssd.py文件中SSD类的__init__函数和forward函数,修改后的结果如下。

if phase == 'test':self.softmax = nn.Softmax(dim=-1)self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)
修改为:
if phase == 'test':self.softmax = nn.Softmax()self.detect = Detect()
if self.phase == "test":output = self.detect(loc.view(loc.size(0), -1, 4),                   # loc predsself.softmax(conf.view(conf.size(0), -1,self.num_classes)),                # conf predsself.priors.type(type(x.data))                  # default boxes)
修改为:
if self.phase == "test":output = self.detect.apply(21, 0, 200, 0.01, 0.45,loc.view(loc.size(0), -1, 4),                   # loc predsself.softmax(conf.view(-1,21)),                 # conf predsself.priors.type(type(x.data))                  # default boxes)

②修改train.py中187至189行代码,原因是.data[0]写法适用于低版本Pytorch,否则会出现IndexError:invalid index of a 0-dim tensor...错误,修改后的结果如下。

loc_loss += loss_l.item()
conf_loss += loss_c.item()if iteration % 10 == 0:print('timer: %.4f sec.' % (t1 - t0))print('iter ' + repr(iteration) + ' || Loss: %.4f ||' % (loss.item()), end=' ')

③交换layers/modules/multibox_loss.py中97行和98代码位置,否则会出现IndexError: The shape of the mask [14, 8732] at index 0does...错误,修改后的结果如下。

loss_c = loss_c.view(num, -1)
loss_c[pos] = 0  # filter out pos boxes for now

④根据自己的需要对train.py中预训练模型、batch_size、学习率、模型名字和模型保存的次数等参数进行修改。建议学习率修改为1e-4(原因是原版使用1e-3可能会出现loss为nan情况),第一次调试时可以修改为每迭代100次保存,方便调试。

# 加载模型初始参数
parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training With Pytorch')
train_set = parser.add_mutually_exclusive_group()
# 默认加载VOC数据集
parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],type=str, help='VOC or COCO')
# 设置VOC数据集根路径
parser.add_argument('--dataset_root', default=VOC_ROOT,help='Dataset root directory path')
# 设置预训练模型vgg16_reducedfc.pth
parser.add_argument('--basenet', default='vgg16_reducedfc.pth',help='Pretrained base model')
# 设置批大小,根据自己显卡能力设置,默认为32,此处我改为16
parser.add_argument('--batch_size', default=16, type=int,help='Batch size for training')
# 是否恢复中断的训练,默认不恢复
parser.add_argument('--resume', default=None, type=str,help='Checkpoint state_dict file to resume training from')
# 恢复训练iter数,默认从第0次迭代开始
parser.add_argument('--start_iter', default=0, type=int,help='Resume training at this iter')
# 数据加载线程数,根据自己CPU个数设置,默认为4
parser.add_argument('--num_workers', default=4, type=int,help='Number of workers used in dataloading')
# 是否使用CUDA加速训练,默认开启,如果没有GPU,可改成False直接用CPU训练
parser.add_argument('--cuda', default=True, type=str2bool,help='Use CUDA to train model')
# 学习率,默认0.001
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,help='initial learning rate')
# 最佳动量值,默认0.9(动量是梯度下降法中一种常用的加速技术,用于加速梯度下降,减少收敛耗时)
parser.add_argument('--momentum', default=0.9, type=float,help='Momentum value for optim')
# 权重衰减,即正则化项前面的系数,用于防止过拟合;SGD,即mini-batch梯度下降
parser.add_argument('--weight_decay', default=1e-4, type=float,help='Weight decay for SGD')
# gamma更新,默认值0.1
parser.add_argument('--gamma', default=0.1, type=float,help='Gamma update for SGD')
# 使用visdom将训练过程loss图像可视化
parser.add_argument('--visdom', default=False, type=str2bool,help='Use visdom for loss visualization')
# 权重保存位置,默认存在weights/下
parser.add_argument('--save_folder', default='weights/',help='Directory for saving checkpoint models')
args = parser.parse_args()
if iteration != 0 and iteration % 100 == 0:print('Saving state, iter:', iteration)torch.save(ssd_net.state_dict(), 'weights/ssd300_VOC_' + repr(iteration) + '.pth')

⑤因为pytorch1.9以上版本在这份源代码中并不适用,一旦运行cuda方面会报错如下:

RuntimeError: Expected a ‘cuda‘ device type for generator but found ‘cpu‘

参考github上的解决方法,有两种方法可成功运行:

第一种是重装pytorch1.8版本,就可以正常运行,但我觉得太麻烦了

第二种是修改源码:

在位于 anaconda 或任何地方的文件“site-packages/torch/utils/data/sampler.py”中。

[修改第 116 行]:generator = torch.Generator()
改成generator = torch.Generator(device='cuda')
[修改第 126 行]:yield from torch.randperm(n, generator=generator).tolist()
改成yield from torch.randperm(n, generator=generator, device='cuda').tolist()

在train.py文件中,data.DataLoader处进行添加generator

data_loader = data.DataLoader(dataset, args.batch_size,num_workers=args.num_workers,shuffle=True, collate_fn=detection_collate,pin_memory=True, generator=torch.Generator(device='cuda'))

(5)运行train.py,如下图

参考资料:

SSD训练自己的数据集(pytorch版)_Kellenn的博客-CSDN博客_ssd训练自己的数据集pytorch

【目标检测实战】Pytorch—SSD模型训练(VOC数据集) - 知乎 (zhihu.com)

2.1SSD算法理论_哔哩哔哩_bilibili

SSD训练数据集流程(学习记录)相关推荐

  1. Cadence Allegro 设计流程学习记录

    Cadence Allegro 设计流程学习记录 前提摘要 软件设计版本: 电路仿真软件:NI Multisim 14.0,TINA-TI. 原理图设计:Design Entry CIS 16.6. ...

  2. DeeplabV3+训练数据集流程(学习记录)

    我所学习的内容来自于b站up主Bubbliiiing的课程,感兴趣也可以去看看 目录 一.源码准备 二.训练步骤 我训练的配置环境 win11+cuda11.3+pytorch1.10.2+pytho ...

  3. Android Camera 流程学习记录(五)—— Camera.takePicture() 流程解析

    简介 在前面的几篇笔记中,我已经把 Camera 控制流的部分梳理得比较清楚了.在 Camera 流程中,还有一个重要的部分,即数据流. Camera API 1 中,数据流主要是通过函数回调的方式, ...

  4. BIM概述及应用流程学习记录

    文章目录 前言 一.BIM是什么? BIM Revit : buiding information modeling 二.BIM概述及应用流程 1.什么是BIM 2.BIM如何来实施 实施的步骤: 3 ...

  5. 【深度学习】深度学习模型训练全流程!

    Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集.模型训练.模型加载和模型调参四个部分对深度学习中模型训练的全流程进行讲解. 一个成熟合格的深度学习训练流 ...

  6. 加载tf模型 正确率很低_深度学习模型训练全流程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集. ...

  7. 深度学习模型训练全流程!

    ↑↑↑关注后"星标"Datawhale 每日干货 & 每月组队学习,不错过 Datawhale干货 作者:黄星源.奉现,Datawhale优秀学习者 本文从构建数据验证集. ...

  8. 基于深度学习的瓶子检测软件(UI界面+YOLOv5+训练数据集)

    摘要:基于深度学习的瓶子检测软件用于自动化瓶子检测与识别,对于各种场景下的塑料瓶.玻璃瓶等进行检测并计数,辅助计算机瓶子生产回收等工序.本文详细介绍深度学习的瓶子检测软件,在介绍算法原理的同时,给出P ...

  9. Pytorch学习记录-torchtext和Pytorch的实例( 使用神经网络训练Seq2Seq代码)

    Pytorch学习记录-torchtext和Pytorch的实例1 0. PyTorch Seq2Seq项目介绍 1. 使用神经网络训练Seq2Seq 1.1 简介,对论文中公式的解读 1.2 数据预 ...

最新文章

  1. JS基础 -- 枚举对象中的属性
  2. python中的time库安装步骤-Python中time模块的使用
  3. 多语言制作工具(2013-01-24更新,支持VS2005、2008、2010、2012)(已开源)
  4. APUE读书笔记-15进程内部通信-05FIFOs
  5. Linux虚拟化KVM-Qemu分析(五)之内存虚拟化
  6. amazon php 空间,如何将PHP图像资源放入Amazon Web Services?
  7. CISCO安全 ×××技术
  8. luogu 1337
  9. mysql中表示金钱的类型
  10. 小程序方法-小程序获取上一页的数据修改上一个页面的数据
  11. Java实习日记(5)
  12. 苹果开发者账号变更公司名称
  13. 基于WebService实现设备状态监控Demo(含源码)
  14. 把数字翻译成中文的计算机,数字翻译成中文,把数字翻译成中文
  15. mongoose视频教程
  16. input框不允许输入负数
  17. 有人负责,才有质量:写给在集市中迷失的一代
  18. LinuxC学习保姆级教程(李慧芹课程笔记)
  19. 淘宝的商品管理是怎样的?
  20. match、search、findall用法区别

热门文章

  1. HTML+CSS大作业网站设计——英雄联盟LOL(4页) HTML+CSS+JavaScript web期末网站设计大作业
  2. 前端 php 需要都学吗,php需要学哪些
  3. Web报表开发:ireport
  4. 无线点菜服务器,ipad电子菜谱点菜
  5. 实用小工具-----python3 pdf2docx轻松搞定pdf转word
  6. 如何设置无线打印-惠普打印机
  7. 跑分cpu_【新机】A14芯片最新跑分成绩曝光:3GHz主频,CPU/GPU提升20%丨特斯拉又双叒降价了...
  8. 十道经典面试算法真题详解
  9. 14_从零构建微信小程序项目_数据交互_json-server详解
  10. 凯利公式自动计算表_工程成本测算难?全自动计算汇总表格,套入公式一键即可出结果...