论文:Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation

非官方实现代码:https://github.com/qq995431104/Copy-Paste-for-Semantic-Segmentation

目录

一、前言

二、思路及代码


一、前言

前些天分享了一篇谷歌的数据增强论文,解读在这:https://blog.csdn.net/oYeZhou/article/details/111307717。

可能由于方法比较简单,官方没有开源代码,于是,我自己尝试在语义分割数据集上进行了实现,代码见GitHub。

先看下实现的效果:

原图:

使用复制-粘贴方法增强后:

将合成后的annotation和image叠加可视化出来:

二、思路及代码

从上面的可视化结果,可以看出,我们需要两组样本:一组image+annotation为源图,一组image+annotation为主图,我们的目的是将源图及其标注信息叠加到主图及其标注信息上;同时,需要对源图的信息做随机水平翻转、大尺度抖动/随机缩放的操作。

思路如下:

  1. 随机选取源图像(用于提取目标)、主图像(用于将所提取的目前粘贴在其之上);
  2. 分别进行随机水平翻转;
  3. 根据参数设置,对进行大尺度抖动(Large Scale Jittering,LSJ),或者仅对进行随机尺度缩放;
  4. 及其对应的分别使用公式进行合成,生成合成的图像及其对应mask;
  5. 保存图像及mask,其中,mask转为8位调色板模式保存;

具体实现的代码如下(需要你的数据集为VOC格式,如果是coco格式,需要先将coco数据集的mask提取出来,可以参考这篇博客):

"""
Unofficial implementation of Copy-Paste for semantic segmentation
"""from PIL import Image
import imgviz
import cv2
import argparse
import os
import numpy as np
import tqdmdef save_colored_mask(mask, save_path):lbl_pil = Image.fromarray(mask.astype(np.uint8), mode="P")colormap = imgviz.label_colormap()lbl_pil.putpalette(colormap.flatten())lbl_pil.save(save_path)def random_flip_horizontal(mask, img, p=0.5):if np.random.random() < p:img = img[:, ::-1, :]mask = mask[:, ::-1]return mask, imgdef img_add(img_src, img_main, mask_src):if len(img_main.shape) == 3:h, w, c = img_main.shapeelif len(img_main.shape) == 2:h, w = img_main.shapemask = np.asarray(mask_src, dtype=np.uint8)sub_img01 = cv2.add(img_src, np.zeros(np.shape(img_src), dtype=np.uint8), mask=mask)mask_02 = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)mask_02 = np.asarray(mask_02, dtype=np.uint8)sub_img02 = cv2.add(img_main, np.zeros(np.shape(img_main), dtype=np.uint8),mask=mask_02)img_main = img_main - sub_img02 + cv2.resize(sub_img01, (img_main.shape[1], img_main.shape[0]),interpolation=cv2.INTER_NEAREST)return img_maindef rescale_src(mask_src, img_src, h, w):if len(mask_src.shape) == 3:h_src, w_src, c = mask_src.shapeelif len(mask_src.shape) == 2:h_src, w_src = mask_src.shapemax_reshape_ratio = min(h / h_src, w / w_src)rescale_ratio = np.random.uniform(0.2, max_reshape_ratio)# reshape src img and maskrescale_h, rescale_w = int(h_src * rescale_ratio), int(w_src * rescale_ratio)mask_src = cv2.resize(mask_src, (rescale_w, rescale_h),interpolation=cv2.INTER_NEAREST)# mask_src = mask_src.resize((rescale_w, rescale_h), Image.NEAREST)img_src = cv2.resize(img_src, (rescale_w, rescale_h),interpolation=cv2.INTER_LINEAR)# set paste coordpy = int(np.random.random() * (h - rescale_h))px = int(np.random.random() * (w - rescale_w))# paste src img and mask to a zeros backgroundimg_pad = np.zeros((h, w, 3), dtype=np.uint8)mask_pad = np.zeros((h, w), dtype=np.uint8)img_pad[py:int(py + h_src * rescale_ratio), px:int(px + w_src * rescale_ratio), :] = img_srcmask_pad[py:int(py + h_src * rescale_ratio), px:int(px + w_src * rescale_ratio)] = mask_srcreturn mask_pad, img_paddef Large_Scale_Jittering(mask, img, min_scale=0.1, max_scale=2.0):rescale_ratio = np.random.uniform(min_scale, max_scale)h, w, _ = img.shape# rescaleh_new, w_new = int(h * rescale_ratio), int(w * rescale_ratio)img = cv2.resize(img, (w_new, h_new), interpolation=cv2.INTER_LINEAR)mask = cv2.resize(mask, (w_new, h_new), interpolation=cv2.INTER_NEAREST)# mask = mask.resize((w_new, h_new), Image.NEAREST)# crop or paddingx, y = int(np.random.uniform(0, abs(w_new - w))), int(np.random.uniform(0, abs(h_new - h)))if rescale_ratio <= 1.0:  # paddingimg_pad = np.ones((h, w, 3), dtype=np.uint8) * 168mask_pad = np.zeros((h, w), dtype=np.uint8)img_pad[y:y+h_new, x:x+w_new, :] = imgmask_pad[y:y+h_new, x:x+w_new] = maskreturn mask_pad, img_padelse:  # cropimg_crop = img[y:y+h, x:x+w, :]mask_crop = mask[y:y+h, x:x+w]return mask_crop, img_cropdef copy_paste(mask_src, img_src, mask_main, img_main):mask_src, img_src = random_flip_horizontal(mask_src, img_src)mask_main, img_main = random_flip_horizontal(mask_main, img_main)# LSJ, Large_Scale_Jitteringif args.lsj:mask_src, img_src = Large_Scale_Jittering(mask_src, img_src)mask_main, img_main = Large_Scale_Jittering(mask_main, img_main)else:# rescale mask_src/img_src to less than mask_main/img_main's sizeh, w, _ = img_main.shapemask_src, img_src = rescale_src(mask_src, img_src, h, w)img = img_add(img_src, img_main, mask_src)mask = img_add(mask_src, mask_main, mask_src)return mask, imgdef main(args):# input pathsegclass = os.path.join(args.input_dir, 'SegmentationClass')JPEGs = os.path.join(args.input_dir, 'JPEGImages')# create output pathos.makedirs(args.output_dir, exist_ok=True)os.makedirs(os.path.join(args.output_dir, 'SegmentationClass'), exist_ok=True)os.makedirs(os.path.join(args.output_dir, 'JPEGImages'), exist_ok=True)masks_path = os.listdir(segclass)tbar = tqdm.tqdm(masks_path, ncols=100)for mask_path in tbar:# get source mask and imgmask_src = np.asarray(Image.open(os.path.join(segclass, mask_path)), dtype=np.uint8)img_src = cv2.imread(os.path.join(JPEGs, mask_path.replace('.png', '.jpg')))# random choice main mask/imgmask_main_path = np.random.choice(masks_path)mask_main = np.asarray(Image.open(os.path.join(segclass, mask_main_path)), dtype=np.uint8)img_main = cv2.imread(os.path.join(JPEGs, mask_main_path.replace('.png', '.jpg')))# Copy-Paste data augmentationmask, img = copy_paste(mask_src, img_src, mask_main, img_main)mask_filename = "copy_paste_" + mask_pathimg_filename = mask_filename.replace('.png', '.jpg')save_colored_mask(mask, os.path.join(args.output_dir, 'SegmentationClass', mask_filename))cv2.imwrite(os.path.join(args.output_dir, 'JPEGImages', img_filename), img)def get_args():parser = argparse.ArgumentParser()parser.add_argument("--input_dir", default="../dataset/VOCdevkit2012/VOC2012", type=str,help="input annotated directory")parser.add_argument("--output_dir", default="../dataset/VOCdevkit2012/VOC2012_copy_paste", type=str,help="output dataset directory")parser.add_argument("--lsj", default=True, type=bool, help="if use Large Scale Jittering")return parser.parse_args()if __name__ == '__main__':args = get_args()main(args)

代码复现:Copy-Paste 数据增强for 语义分割相关推荐

  1. CV Code | 本周新出计算机视觉开源代码汇总(含目标跟踪、语义分割、姿态跟踪、少样本学习等)...

    点击我爱计算机视觉标星,更快获取CVML新技术 刚刚过去的一周出现了很多很实用.有意思.很神奇的CV代码. 比如大家期待的SiamRPN++算法,官方终于要开源了. 阿里MNN成为移动端网络部署的新选 ...

  2. CenterNet:Objects as Points论文学习笔记+代码复现(demo+训练数据)【检测部分】

    目录 1.关键部分Heatmap了解 2.Centernet论文细节: 3.尝试复现CneterNet--INSTALL.md安装: 4.尝试复现CneterNet--跑跑demo.py: 5.尝试复 ...

  3. 总结 62 种在深度学习中的数据增强方式

    数据增强 数据增强通常是依赖从现有数据生成新的数据样本来人为地增加数据量的过程 这包括对数据进行不同方向的扰动处理 或使用深度学习模型在原始数据的潜在空间(latent space)中生成新数据点从而 ...

  4. 语义分割数据增强python代码

                                                           语义分割数据增强python-pytorch代码-语义分割github项目 0. 先放gi ...

  5. 【YOLOV5-6.x讲解】数据增强方式介绍+代码实现

    主干目录: [YOLOV5-6.x 版本讲解]整体项目代码注释导航现在YOLOV5已经更新到6.X版本,现在网上很多还停留在5.X的源码注释上,因此特开一贴传承开源精神!5.X版本的可以看其他大佬的帖 ...

  6. 只讲关键点之兼容100+种关键点检测数据增强方法

    点击上方"3D视觉工坊",选择"星标" 干货第一时间送达 作者丨DefTruth 编辑丨极市平台 本文介绍了已有的几种关键点检测数据增强的方法,将其的优缺点进行 ...

  7. 复现Detectron2-blendmask之冰墩墩雪容融自定义数据集语义分割

    第一节--Detectron2-BlendMask论文综述 1-1 Detectron2-BlendMask论文摘要 实例分割是计算机视觉中非常基础的任务.近来,全卷积实例分割方法得到了更多的注意力, ...

  8. python ssd目标检测_目标检测算法之SSD的数据增强策略

    前言 这篇文章是对前面<目标检测算法之SSD代码解析>,推文地址如下:点这里的补充.主要介绍SSD的数据增强策略,把这篇文章和代码解析的文章放在一起学最好不过啦.本节解析的仍然是上篇SSD ...

  9. python批量实现图像数据增强(扩增)

    什么是数据扩增? 数据扩增是对数据进行扩充的方法的总称.数据扩增可以增加训练集的样本,可以有效缓解模型过拟合的情况,也可以给模型带来的更强的泛化能力. 通常在进行数据扩增操作的时候应该保持图像原本的标 ...

  10. 目前缺少用于语义分割的 3D LiDAR 数据吗?关于三维点云数据集和方法的调查

    目前缺少用于语义分割的 3D LiDAR 数据吗?关于三维点云数据集和方法的调查 原文 Are We Hungry for 3D LiDAR Data for Semantic Segmentatio ...

最新文章

  1. 集合 (一) ----- 集合的基本概念与Collection集合详解
  2. 新手福音,机器学习工具Sklearn 中文文档 0.19版(最新)
  3. Educoder Matplotlib和Seaborn 三维图 第一关绘制三维图
  4. 利用paramiko模块实现堡垒机+审计功能
  5. 找东西背后的概率问题——From《思考的乐趣 Martix67数学笔记》
  6. Java实现智能对话机器人自动聊天+语音秒回
  7. 利用python进行身份证号码大全_身份证号码设置显示格式,我用了最笨的办法,你有什么好办法吗?...
  8. 怎样将PDF转成表格?超赞的两种PDF转Excel方法
  9. js获取ip本机地址的方法
  10. 医疗信息管理系统(HIS)——>业务介绍
  11. 像素测量工具_像素大厨PxCook for Mac(自动标注工具)中文免费版
  12. 冉宝的leetcode笔记--每日一题 8月1日
  13. #bzoj1526#分梨子(乱搞)
  14. java一道多线程题,子线程循环10次,主线程接着循环100次,如此循环50次的问题
  15. excel自动翻译-excel一键自动翻译免费
  16. python timer详解_python线程定时器Timer实现原理解析
  17. 南大科院Java工程实训
  18. 分支定界算法在中学排课问题中的应用
  19. java 取10位时间戳_java里Date 10位时间戳(Timestamp) String 相互转换
  20. 利用EXCEL获取字段的拼音首字母

热门文章

  1. 人物志 | 技术十年:美团第一位前端工程师潘魏增
  2. 接口implement
  3. 面经(5) 2020/4/5 Java研发实习生 蚂蚁金服
  4. napa与matlab,纳帕谷产区Napa Valley|酒斛网 - 与数十万葡萄酒爱好者一起发现美酒,分享微醺的乐趣...
  5. actived生命周期_初探 Vue 生命周期和钩子函数
  6. pe下找不到ssd硬盘_进入PE系统之后找不到固态硬盘
  7. 个人网页(项目)源码解析「HTML+CSS+JS」
  8. 文献调研(一):基于集成学习和能耗模式分类的办公楼小时能耗预测
  9. 学习日记——Quartus工程创建与运行
  10. Oracle get、start、edit、spool命令,临时变量、已定义变量