本代码是pytorch版本的ssd实现,来源amdegroot/ssd.pytorch

一、PhotometricDistort

class PhotometricDistort(object):def __init__(self):#定义6个操作self.pd = [RandomContrast(),ConvertColor(transform='HSV'),RandomSaturation(),RandomHue(),ConvertColor(current='HSV', transform='BGR'),RandomContrast()]self.rand_brightness = RandomBrightness()self.rand_light_noise = RandomLightingNoise()def __call__(self, image, boxes, labels):im = image.copy()im, boxes, labels = self.rand_brightness(im, boxes, labels)if random.randint(2):distort = Compose(self.pd[:-1]) #最先做RandomContrastelse:distort = Compose(self.pd[1:])  #最后做RandomContrastim, boxes, labels = distort(im, boxes, labels)return self.rand_light_noise(im, boxes, labels)

RandomBrightness(随机改变亮度):

在原有图片像素上加一个实数(实数的范围在[-32,32])

其中:random.randint(2):在0和1之间随机产生一个数,random.uniform(x, y) :将随机生成一个实数,它在 [x,y] 范围

class RandomBrightness(object):def __init__(self, delta=32):#默认delta=32,delta的范围要在0-255之间assert delta >= 0.0assert delta <= 255.0self.delta = deltadef __call__(self, image, boxes=None, labels=None):if random.randint(2):delta = random.uniform(-self.delta, self.delta)image += deltareturn image, boxes, labels

RandomContrast(随机改变对比度):

在原图像素上乘一个系数(系数的范围在[0.5,1.5])

class RandomContrast(object):def __init__(self, lower=0.5, upper=1.5):self.lower = lowerself.upper = upperassert self.upper >= self.lower, "contrast upper must be >= lower."assert self.lower >= 0, "contrast lower must be non-negative."# expects float imagedef __call__(self, image, boxes=None, labels=None):if random.randint(2):alpha = random.uniform(self.lower, self.upper)image *= alphareturn image, boxes, labels

ConvertColor(变换颜色空间):

变换颜色空间,若当前为BGR则变换到HSV,若当前为HSV变换到BGR

其中,cv2.cvtColor函数功能是变换空间

class ConvertColor(object):def __init__(self, current='BGR', transform='HSV'):self.transform = transform  #要变换到HSVself.current = current      #当前默认BGRdef __call__(self, image, boxes=None, labels=None):if self.current == 'BGR' and self.transform == 'HSV':image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)elif self.current == 'HSV' and self.transform == 'BGR':image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)else:raise NotImplementedErrorreturn image, boxes, labels

RandomSaturation(随机改变饱和度):

在HSV空间的S维度上乘一个系数(系数在范围[0.5,1.5]中随机得到一个实数)

class RandomSaturation(object):def __init__(self, lower=0.5, upper=1.5):self.lower = lowerself.upper = upperassert self.upper >= self.lower, "contrast upper must be >= lower."assert self.lower >= 0, "contrast lower must be non-negative."def __call__(self, image, boxes=None, labels=None):if random.randint(2):image[:, :, 1] *= random.uniform(self.lower, self.upper)return image, boxes, labels

RandomHue(随机改变色调):

在HSV空间的H维度随机加一个实数(实数的范围[-18.0,18.0])

class RandomHue(object):def __init__(self, delta=18.0):assert delta >= 0.0 and delta <= 360.0self.delta = deltadef __call__(self, image, boxes=None, labels=None):if random.randint(2):image[:, :, 0] += random.uniform(-self.delta, self.delta)image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0   #大于360的值减360image[:, :, 0][image[:, :, 0] < 0.0] += 360.0     #小于0的值加上360return image, boxes, labels

RandomLightingNoise(随机变换通道):

设置了6中变换方式,随机选择一种,将BGR三个通道顺序改变

class RandomLightingNoise(object):def __init__(self):self.perms = ((0, 1, 2), (0, 2, 1),(1, 0, 2), (1, 2, 0),(2, 0, 1), (2, 1, 0))def __call__(self, image, boxes=None, labels=None):if random.randint(2):swap = self.perms[random.randint(len(self.perms))]shuffle = SwapChannels(swap)  # shuffle channelsimage = shuffle(image)return image, boxes, labels

二、Expand(随机扩张图片)

将原有图片的高和宽乘以一个ratio系数,将原有图片放在扩张后图片的右下角,其他位置像素值使用均值填充,相应的bbox也进行移动

class Expand(object):def __init__(self, mean):self.mean = meandef __call__(self, image, boxes, labels):if random.randint(2):                           #随机是否进行操作return image, boxes, labelsheight, width, depth = image.shaperatio = random.uniform(1, 4)   #在[1,4]随机一个实数left = random.uniform(0, width*ratio - width)   #设置放置原图的min_x坐标top = random.uniform(0, height*ratio - height)  #设置放置原图的min_y坐标expand_image = np.zeros((int(height*ratio), int(width*ratio), depth),dtype=image.dtype)                          #初始化expand图片expand_image[:, :, :] = self.mean               #使用均值填充expand的三个通道expand_image[int(top):int(top + height),int(left):int(left + width)] = image   #将原图放在expand图像中image = expand_imageboxes = boxes.copy()                            #处理变换后的框boxes[:, :2] += (int(left), int(top))boxes[:, 2:] += (int(left), int(top))return image, boxes, labels

三、RandomSampleCrop(随机剪裁)

在图像上随机剪裁矩形区域,裁剪区域一定要包含bbox的中心点,将原始图bbox转换到剪裁区域的bbox

class RandomSampleCrop(object):def __init__(self):self.sample_options = (# using entire original input imageNone,# sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9(0.1, None),(0.3, None),(0.7, None),(0.9, None),# randomly sample a patch(None, None),)def __call__(self, image, boxes=None, labels=None):height, width, _ = image.shapewhile True:# randomly choose a modemode = random.choice(self.sample_options)if mode is None:return image, boxes, labelsmin_iou, max_iou = modeif min_iou is None:min_iou = float('-inf')if max_iou is None:max_iou = float('inf')# max trails (50)for _ in range(50):current_image = imagew = random.uniform(0.3 * width, width)  #裁剪的w范围[0.3*width, width]h = random.uniform(0.3 * height, height)#裁剪的h范围[0.3*height, height]# aspect ratio constraint b/t .5 & 2,如果长宽比不在[0.5,2]之间就重新尝试if h / w < 0.5 or h / w > 2:continueleft = random.uniform(width - w)        #裁剪图像的min_xtop = random.uniform(height - h)        #裁剪图像的max_x# 得到裁剪图像的[min_x,min_y,max_x,max_y]rect = np.array([int(left), int(top), int(left+w), int(top+h)])# 将裁剪图像与gt的框计算IoUoverlap = jaccard_numpy(boxes, rect)# is min and max overlap constraint satisfied? if not try againif overlap.min() < min_iou and max_iou < overlap.max():continue# 从原图中剪裁新图像current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], :]# 计算gt的bbox框的中心centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0# 检查剪裁图像的min_x, min_y要分别小于bbox的中心x, ym1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])# 检查剪裁图像的max_x, max_y要分别大于bbox的中心x, ym2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])# 上述两条要求都要为Truemask = m1 * m2# 如果由不满足True的情况,就重新尝试if not mask.any():continue# 初始化当前bboxcurrent_boxes = boxes[mask, :].copy()# 获得当前各框标签current_labels = labels[mask]# 取当前各框的min_x和min_ycurrent_boxes[:, :2] = np.maximum(current_boxes[:, :2], rect[:2])# 调整bbox中min_x, min_y位置current_boxes[:, :2] -= rect[:2]# 取当前各框的max_x和max_ycurrent_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], rect[2:])# 调整bbox中max_x, max_y位置current_boxes[:, 2:] -= rect[:2]return current_image, current_boxes, current_labels

pytorch版本SSD代码分析(2)——数据增强相关推荐

  1. Bringing Old Photos Back to Life模型代码分析1(数据载入部分)

    (1)Bringing Old Photos Back to Life原理和测试 (2) Bringing Old Photos Back to Life模型代码分析1(数据载入部分) Bringin ...

  2. 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度...

    「@Author:Runsen」 上次基于CIFAR-10 数据集,使用PyTorch构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. imp ...

  3. 【Pytorch】nvidia-dali——一种加速数据增强的方法

    目的 问题: 当我们使用pytorch训练小模型或者使用较大batch size的时候会发现GPU利用率很低,训练周期比较长.其原因之一是在dataloader加载数据之后在cpu上做一些数据增强的操 ...

  4. 【小白学习PyTorch教程】八、使用图像数据增强手段,提升CIFAR-10 数据集精确度

    @Author:Runsen 上次基于CIFAR-10 数据集,使用PyTorch ​​构建图像分类模型的精确度是60%,对于如何提升精确度,方法就是常见的transforms图像数据增强手段. im ...

  5. BBAug: 一个用于PyTorch的物体检测包围框数据增强包

    本文转载自AI公园. 作者:Harpal Sahota 编译:ronghuaiyang 导读 实现了Google Research,Brain Team中的增强策略. 像许多神经网络模型一样,目标检测 ...

  6. Aspose最版本aspose-words:jdk17:23.4 版本,代码分析心得

    aspose 为收费软件,以下仅仅用于学习技术,请勿做任何商业用途,如果需要请到官网购买正版! 官网地址:On Premise, Cloud & App Based Microsoft Wor ...

  7. 从零开始学Pytorch(十五)之数据增强

    图像增广 在深度卷积神经网络里我们提到过,大规模数据集是成功应用深度神经网络的前提.图像增广(image augmentation)技术通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从 ...

  8. 深度学习超分辨率数据处理代码(包含数据增强,随机裁剪,最终保存为h5文件)

    import argparse import glob import h5py import numpy as np import PIL.Image as pil_image from utils ...

  9. python ssd目标检测_目标检测算法之SSD的数据增强策略

    前言 这篇文章是对前面<目标检测算法之SSD代码解析>,推文地址如下:点这里的补充.主要介绍SSD的数据增强策略,把这篇文章和代码解析的文章放在一起学最好不过啦.本节解析的仍然是上篇SSD ...

最新文章

  1. UI培训教程分享:UI设计的分类有哪些?
  2. 如何教计算机认识手写数字(上)
  3. Android自动化测试之monkeyrunner基本要素(七)
  4. 滴滴产品总监:如何合理设计弹窗以保证流畅的用户体验?
  5. Overlapped I/O模型深入分析[转]
  6. NeurIPS 2021 Transformer部署难?北大华为诺亚提出Vision Transformer的后训练量化方法...
  7. 1-2 软件构造的质量目标
  8. Spring Cloud学习笔记---Spring Cloud Sleuth--新建两个互相调用的服务测试zipkin
  9. 大数据hadoop入门 总结图
  10. 5步绘制软件开发流程图
  11. [Luogu P4630] [BZOJ 5463] [APIO2018] Duathlon 铁人两项
  12. 盗心贼歌曲用计算机多少数字,抖音上常见背景音乐歌词盗心的贼是那首歌?
  13. 怎么用计算机进行气象预报,行测言语理解与表达片段阅读:1、中央气象台进行天气预报,先用计算机解出描述天气演变的方程组...
  14. 国内的微软更新服务器地址,windows update 服务器
  15. 单片机温度传感器c语言编码,温度传感器代码解析Ⅱ
  16. 视频加水印怎么加?简单的方法
  17. Python操作excel基础
  18. SMART PLC和V90伺服实现外部脉冲位置控制
  19. PHP的一些常用算法
  20. 个人笔记:数据库——数据库如何进行备份?

热门文章

  1. CentOS 7 安装惠普打印机驱动
  2. navicat添加外键_navicat怎么建立外键
  3. keystone -- An unexpected error prevented the server from fulfilling your request. 错误
  4. VMware Tools 其实很重要
  5. python实现数字签名2
  6. 洛谷P2199-最后的迷宫(BFS)
  7. 如何让IE8默认启动InPrivate浏览模式
  8. 基因组生物信息学实验(三):基因组模拟测序(1)
  9. 小程序租赁服务器,小程序服务器租赁
  10. java hdfs文件_使用Java访问HDFS中的文件