↑↑↑关注后"星标"Datawhale

每日干货 & 每月组队学习,不错过

Datawhale干货

作者:陈信达,上海科技大学,Datawhale成员

1.起源:GAN

结构与原理

在介绍DeblurGANv2之前,我们需要大概了解一下GAN,GAN最初的应用是图片生成,即根据训练集生成图片,如生成手写数字图像、人脸图像、动物图像等等,其主要结构如下:

我们先由上图的左下方开始,假设现在只有一个样本,即batch size为1,则Random noise是一个由服从标准正态分布的随机数组成的向量。首先,我们将Random noise输入Generator,最原始GAN的Generator是一个多层感知机,其输入是一个向量,输出也是一个向量,然后我们将输出的向量reshape成一个矩阵,这个矩阵就是一张图片(一个矩阵是因为MNIST手写数据集中的图片是单通道的灰度图,如果想生成彩色图像就reshape成三个矩阵),即与上图的“8”对应。我们称Generator生成的图像为fake image,训练集中的图片为real image。

上图中的Distriminator为判别器,它是一个二分类的多层感知机,输出只有一个数,由于多层感知机只接受向量为其输入,我们将一张图片由矩阵展开为向量后再输入Discriminator,经过一系列运算后输出一个0~1之间的数,这个数越接近于0,代表着判别器认为这张图片是fake image;反之,假如输出的数越接近于1,则判别器认为这张图片是real image。为了方便,我们将Generator简称为G,Distriminator简称为D。

总而言之,G的目的是让自己生成的fake image尽可能欺骗D,而D的任务是尽可能辨别出fake image和real image,二者不停博弈。最终理想情况下,G生成的数据与真实数据非常接近,而D无论输入fake image还是real image都输出0.5。

损失函数

GAN的损失函数是Binary cross entropy loss,简称为BCELoss,其主要利用了极大似然的思想,实际上就是二分类对应的交叉熵损失函数。公式如下:

其中是样本数,是第个样本的真实值,是第个样本的预测值。对于第个样本来说,由于取值只能是0或1,此时只看第个样本,所以。当时,,而的取值范围为0~1,故当时,=0,当时,,我们的目标是使的值越小越好,即当越接近0时,的值越小。反之,当时,,越接近1时,的值越小。总之,当越接近于时,的值越小。

那么BCELoss和GAN有什么关系呢?

我们将GAN的Loss分为和,即生成器的损失和判别器的损失。

  • 对于生成器来说,它希望自己生成的图片能骗过判别器,即希望D(fake)越接近1越好,D(fake)就是G生成的图片输入D后的输出值,D(fake)接近于1意味着G生成的图片可以以假乱真来欺骗判别器,所以GLoss的公式如下所示:

当越接近1,越小,意味着生成器骗过了判别器;

  • 对于判别器来说,它的损失分为两部分,首先,它不希望自己被fake image欺骗,即与相反,这里用表示:

当越接近1,越小,意味着判别器分辨出了fake image;

其次,判别器做出判断必须有依据,所以它需要知道真实图片是什么样的才能正确地辨别假图片,这里用表示:

当越接近1,越小,意味着判别器辨别出了real image。

其实就是这两个损失值的平均值:

优化器

介绍完GAN的损失函数后,我们还剩下最后一个问题:怎么使损失函数的值越来越小?

这里就需要说一下优化器(Optimizer),优化器就是使损失函数值越来越小的工具,常用的优化器有SGD、NAG、RMSProp、Adagrad、Adam和Adam的一些变种,其中最常用的是Adam。

最终结果

由上图我们可以清楚地看出来,随着训练轮数增加,G生成的fake image越来越接近手写数字。

目前GAN有很多应用,每个应用对应的论文和Pytorch代码可以参考下面的链接,其中也有GAN的代码,大家可以根据代码进一步理解GAN:https://github.com/eriklindernoren/PyTorch-GAN

2.图像去模糊算法:DeblurGANv2

数据集

图像去模糊的数据集通常由许多组图像组成,每组图像就是一张清晰图像和与之对应的模糊图像。然而,其数据集的制作并不容易,目前常用的方法有两种,第一种是用高帧数的摄像机拍摄视频,从视频中找到连续帧中的模糊图片和清晰图片作为一组数据;第二种方法是用已知或随机生成的运动模糊核对清晰图片进行模糊操作,生成对应的一组数据。albumentations是Python中常用的数据扩增库,可以对图片进行旋转、缩放、裁剪等操作,我们也可以使用albumentations给图像增加运动模糊,具体操作如下:

首先安装albumentations库,在cmd或虚拟环境中输入:

python -m pip install albumentations

为了给图像添加运动模糊,我们需要用matplotlib库来读取、显示和保存图片。

import albumentations as A
from matplotlib import pyplot as plt
# 读取和显示原图
img = plt.imread('./images/ywxd.jpg')
plt.imshow(img)
plt.axis('off')
plt.show()

albumentations添加运动模糊操作如下,其中blur_limit是卷积核大小的范围,这里卷积核大小在150到180之间,卷积核越大,模糊效果越明显;p是进行运动模糊操作概率。

aug = A.MotionBlur(blur_limit=(50, 80), p=1.0)
aug_img = aug(image=img)['image']
plt.imshow(aug_img)
plt.axis('off')
plt.show()

如果想查看对应的模糊核,我们可以对aug这个实例调用get_params方法,这里为了大家观看方便,我使用的是3*3的卷积核。

aug = A.MotionBlur(blur_limit=(3, 3), p=1.0)
aug.get_params()
{'kernel': array([[0.        , 0.        , 0.33333334],[0.33333334, 0.33333334, 0.        ],[0.        , 0.        , 0.        ]], dtype=float32)}

我使用的数据集是DeblurGANv1的数据集,链接:https://gas.graviti.cn/dataset/datawhale/BlurredSharp

模糊图片:

清晰图片:

网络结构

DeblurGANv2的思路与GAN大致相同,区别之处在于其对GAN做了大量优化,我们先来看Generator的结构:

观察上图可以发现,G主要有两个改变:

  • 输入用模糊的图片替代了GAN中的随机向量

  • 网络结构引入了目标检测中的FPN结构,融合了多尺度的特征

另外,在特征提取部分作者提供了三种网络主干:MobileNetv2、inceptionresnetv2和densenet121,经过作者实验得出,inceptionresnetv2的效果最好,但模型较大,而MobilNetv2在不降低太大效果的基础上大大减少了网络参数,网络主干在上图中对应部分如下所示:

最后,将fpn的输出与原图进行按元素相加操作得到最终输出。

DeblurGANv2的判别器由全局和局部两部分组成,全局判别器输入的是整张图片,局部判别器输入的是随机裁剪后的图片,将输入图片经过一系列卷积操作后输出一个数,这个数代表判别器认为其为real image的概率,判别器的结构如下所示:

损失函数

DeblurGANv2与GAN差别最大的部分就是它的损失函数,我们首先看看D的loss:

D的目的是为了辨别图片的真假,所以D(fake)越小,D(real)越大时,代表D能很好地判断图片的真假,故对于D来说,越小越好

为了防止过拟合,后面还会加上一个L2惩罚项:

G的loss较D复杂很多,它由和组成,其实就是一个perceptual loss,它其实就是将real image和fake image分别输入vgg19,将输出的特征图做MSELoss(均方误差),而作者在perceptual loss的基础上又做了一些改变,公式可以总结为下式:

由公式可以很容易推断,的作用就是让G生成的图片和原图尽可能相似来达到去模糊的目的。

对于来说,其可以总结为下面公式:

由于G的目的是尽可能以假乱真骗过D,所以和越接近于1越好,即越小越好。

最后,G的loss如下所示:

作者给出的lambda为0.001,可以看出作者更注重生成图像与原图的相似性。

3.代码实践

训练自己的数据集

(目前仅支持gpu训练!)

github项目地址:https://github.com/VITA-Group/DeblurGANv2

数据地址:https://gas.graviti.cn/dataset/datawhale/BlurredSharp

首先将数据文件夹和项目文件夹按照下面结构放置:

安装python环境,在cmd中输入:

conda create -n deblur python=3.9
conda activate deblur
python -m pip install -r requirements.txt

修改config文件夹中的配置文件config.yaml:

project: deblur_gan
experiment_desc: fpntrain:files_a: &FILES_A ./dataset/train/blurred/*.png  files_b: &FILES_B ./dataset/train/sharp/*.png  size: &SIZE 256 crop: random  preload: &PRELOAD falsepreload_size: &PRELOAD_SIZE 0bounds: [0, .9]scope: geometriccorrupt: &CORRUPT- name: cutoutprob: 0.5num_holes: 3max_h_size: 25max_w_size: 25- name: jpegquality_lower: 70quality_upper: 90- name: motion_blur- name: median_blur- name: gamma- name: rgb_shift- name: hsv_shift- name: sharpenval:files_a: &FILE_A ./dataset/val/blurred/*.pngfiles_b: &FILE_B ./dataset/val/sharp/*.pngsize: *SIZEscope: geometriccrop: centerpreload: *PRELOADpreload_size: *PRELOAD_SIZEbounds: [.9, 1]corrupt: *CORRUPTphase: train
warmup_num: 3
model:g_name: resnetblocks: 9d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scaled_layers: 3content_loss: perceptualadv_lambda: 0.001disc_loss: wgan-gplearn_residual: Truenorm_layer: instancedropout: Truenum_epochs: 200
train_batches_per_epoch: 1000
val_batches_per_epoch: 100
batch_size: 1
image_size: [256, 256]optimizer:name: adamlr: 0.0001
scheduler:name: linearstart_epoch: 50min_lr: 0.0000001

如果是windows系统需要删除train.py第180行

然后在cmd中cd到项目路径并输入:

python train.py

训练结果可以在tensorboard中可视化出来:

验证集ssim(结构相似性):

验证集GLoss:

验证集PSNR(峰值信噪比):

测试(CPU、GPU均可)

  1. GPU

将测试图片以test.png保存到DeblurGANv2-master文件夹下,在CMD中输入:

python predict.py test.png

运行成功后结果submit文件夹中,predict.py中的模型文件默认为best_fpn.h5,大家也可以在DeblurGANv2的github中下载作者训练好的模型文件,保存在项目文件夹后将predict.py文件中的第93行改为想要用的模型文件即可,如将'best_fpn.h5'改为'fpn_inception.h5',但是需要将config.yaml中model对应的g_name改为相应模型,如想使用'fpn_mobilenet.h5',就将'fpn_inception'改为'fpn_mobilenet'

  1. CPU

将predict.py文件中第21行、22和65行改为下面代码即可

model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu'))['model'])
self.model = model
inputs = [img]

运行后就可以得到下面效果:

DeblurGAN的应用:优化YOLOv5性能

由上图可以看出,图片去模糊不仅可以提高YOLOv5的检测置信度,还可以使检测更准确。以Mobilenetv2为backbone的DeblurGANv2能达到图片实时去模糊的要求,进而可以使用到视频质量增强等方向。

线上训练

如果我们不想把数据集下载到本地的话可以考虑格物钛(Graviti)的线上训练功能,在原项目的基础上改几行代码即可。

首先我们打开项目文件夹中的dataset.py文件,在第一行导入tensorbay和PIL(如果没有安装tensorbay需要先pip install):

from tensorbay import GAS
from tensorbay.dataset import Dataset as TensorBayDataset
from PIL import Image

我们主要修改的是PairedDatasetOnline类还有_read_img函数,为了保留原来的类,我们新建一个类,将下面代码复制粘贴到dataset.py文件中即可(记得将ACCESS_KEY改为自己空间的 Graviti AccessKey):

class PairedDatasetOnline(Dataset):def __init__(self,files_a: Tuple[str],files_b: Tuple[str],transform_fn: Callable,normalize_fn: Callable,corrupt_fn: Optional[Callable] = None,preload: bool = True,preload_size: Optional[int] = 0,verbose=True):assert len(files_a) == len(files_b)self.preload = preloadself.data_a = files_aself.data_b = files_bself.verbose = verboseself.corrupt_fn = corrupt_fnself.transform_fn = transform_fnself.normalize_fn = normalize_fnlogger.info(f'Dataset has been created with {len(self.data_a)} samples')if preload:preload_fn = partial(self._bulk_preload, preload_size=preload_size)if files_a == files_b:self.data_a = self.data_b = preload_fn(self.data_a)else:self.data_a, self.data_b = map(preload_fn, (self.data_a, self.data_b))self.preload = Truedef _bulk_preload(self, data: Iterable[str], preload_size: int):jobs = [delayed(self._preload)(x, preload_size=preload_size) for x in data]jobs = tqdm(jobs, desc='preloading images', disable=not self.verbose)return Parallel(n_jobs=cpu_count(), backend='threading')(jobs)@staticmethoddef _preload(x: str, preload_size: int):img = _read_img(x)if preload_size:h, w, *_ = img.shapeh_scale = preload_size / hw_scale = preload_size / wscale = max(h_scale, w_scale)img = cv2.resize(img, fx=scale, fy=scale, dsize=None)assert min(img.shape[:2]) >= preload_size, f'weird img shape: {img.shape}'return imgdef _preprocess(self, img, res):def transpose(x):return np.transpose(x, (2, 0, 1))return map(transpose, self.normalize_fn(img, res))def __len__(self):return len(self.data_a)def __getitem__(self, idx):a, b = self.data_a[idx], self.data_b[idx]if not self.preload:a, b = map(_read_img, (a, b))a, b = self.transform_fn(a, b)if self.corrupt_fn is not None:a = self.corrupt_fn(a)a, b = self._preprocess(a, b)return {'a': a, 'b': b}@staticmethoddef from_config(config):config = deepcopy(config)# files_a, files_b = map(lambda x: sorted(glob(config[x], recursive=True)), ('files_a', 'files_b'))segment_name = 'train' if 'train' in config['files_a'] else 'val'ACCESS_KEY = "yours"gas = GAS(ACCESS_KEY)dataset = TensorBayDataset("BlurredSharp", gas)segment = dataset[segment_name]files_a = [i for i in segment if 'blurred' == i.path.split('/')[2]]files_b = [i for i in segment if 'sharp' == i.path.split('/')[2]]transform_fn = aug.get_transforms(size=config['size'], scope=config['scope'], crop=config['crop'])normalize_fn = aug.get_normalize()corrupt_fn = aug.get_corrupt_function(config['corrupt'])# ToDo: add more hash functionsverbose = config.get('verbose', True)return PairedDatasetOnline(files_a=files_a,files_b=files_b,preload=config['preload'],preload_size=config['preload_size'],corrupt_fn=corrupt_fn,normalize_fn=normalize_fn,transform_fn=transform_fn,verbose=verbose)

再将_read_img改为:

def _read_img(x):with x.open() as fp:img = cv2.cvtColor(np.asarray(Image.open(fp)), cv2.COLOR_RGB2BGR)if img is None:logger.warning(f'Can not read image {x} with OpenCV, switching to scikit-image')img = imread(x)[:, :, ::-1]return img

最后一步将train.py第184行的datasets = map(PairedDataset.from_config, datasets)改为datasets = map(PairedDatasetOnline.from_config, datasets)即可。

更多干货

YOLO算法最全综述/YOLOv5做目标检测/YOLOv5识别表情

↓↓↓“阅读原文”获取数据集

图像去模糊算法代码实践!相关推荐

  1. 【深度学习】图像去模糊算法代码实践!

    作者:陈信达,上海科技大学,Datawhale成员 1.起源:GAN 结构与原理 在介绍DeblurGANv2之前,我们需要大概了解一下GAN,GAN最初的应用是图片生成,即根据训练集生成图片,如生成 ...

  2. 图像去模糊算法 deblur

    图像去模糊算法 循序渐进 附完整代码 https://www.cnblogs.com/cpuimage/p/9735150.html (后面要对比smartdeblur,deblur-gan) 关于图 ...

  3. 更好也更快!最先进的图像去模糊算法DeblurGAN-v2

    点击我爱计算机视觉标星,更快获取CVML新技术 以GAN为代表的生成模型正在视觉造假的路上越来越成熟,狗变猫.白马变斑马.实景变素描是GAN用于高级图像生成的例证. 能否将GAN应用于低级的图像处理呢 ...

  4. 泊松图像融合算法代码实现_部分多曝光图像融合算法(含少数通用图像融合算法)代码下载链接...

    最近弄了一篇多曝光图像融合的论文,顺便搜集了一些多曝光图像融合算法.为了方便大家,特将链接放在本文里.另外我制作和收集了100对多曝光图像,并用下述方法生成了2100张融合图像,后续将放出. 一.传统 ...

  5. 图像去模糊算法 循序渐进 附完整代码

    关于图像模糊算法的实现, 我相信大多数学习图像算法的朋友都很熟悉. 例如常见的毛玻璃效果,高斯模糊等等. 而图像模糊最简单的实现就是 在一定区域 对像素做平均值计算. 术语描述,卷积. 1.认识卷积 ...

  6. matlab实现鬼波信号压制算法  代码实践--第二篇 频率-波数域鬼波压制

    第二篇  matlab实现频率-波数域鬼波压制方法 本篇用来介绍频率-波数域鬼波压制的实现思路和压制效果 算法实现思路见2.3节,除了文中代码外,需配置鬼波压制算法工具包(https://downlo ...

  7. matlab实现鬼波信号压制算法  代码实践--第三篇 拉东域鬼波压制

    第三篇 matlab实现 拉东域鬼波压制方法 本篇用来介绍拉东域鬼波压制的实现思路和压制效果 该方法可用于水平拖缆.斜缆数据的鬼波压制 算法实现思路见3.3节,除了文中代码外,需配置鬼波压制算法工具包 ...

  8. 图像融合论文及代码网址整理总结(3)——题目中未加区分的图像融合算法

    写在前面的话: 之前写过两篇博文,针对性地整理汇总了多聚焦图像融合和红外与可见光图像融合的算法.之所以这样分类,是基于论文的标题和内容中明确指出了所处理的图像类型. 而本篇博文主要整理那些论文标题中未 ...

  9. 【图像融合学习笔记004】图像融合论文及代码网址整理总结(3)——题目中未加区分的图像融合算法

    写在前面的话: 之前写过两篇博文,针对性地整理汇总了多聚焦图像融合和红外与可见光图像融合的算法.之所以这样分类,是基于论文的标题和内容中明确指出了所处理的图像类型. 而本篇博文主要整理那些论文标题中未 ...

最新文章

  1. 计算机基础知识在线作业,福建师范大学《计算机应用基础》在线作业一答案.docx...
  2. Linux dmesg 命令学习
  3. day12_oracle hint——SQL优化过程中常见Oracle中HINT的30个用法
  4. 2019年第十届蓝桥杯国赛B组试题E-路径计数-dfs(坑题)
  5. C语言博客作业--函数
  6. Oracle 索引相关
  7. 算法总结系列之八:复读机的故事 - 散列表.NET应用的研究(下集)
  8. RHEL7忘记密码:修改root密码
  9. hibernate Day1 案例代码
  10. Android 属性动画(一)
  11. 程序员别再迷茫,赚钱,方法比你想得更多
  12. [e袋购APP]高校物业管理的特点
  13. 黄金矿工swf游戏修改
  14. linux实用技巧:ubuntu18.04安装配置ibus中文输入法
  15. 基于视频/摄像头的简单行为动作识别模型的训练步骤
  16. 计算机网络中网关、路由器、网桥、交换机概念详解
  17. 计算机组装与维护考试题a卷,计算机组装与维护考试题A卷.docx
  18. PostgreSQL的upsert功能(on conflict do)的用法
  19. 第九层(1):初识STL
  20. Android记事本

热门文章

  1. 语音识别—前端录音传给后台语音识别
  2. RealPlay在Fedora13上的安装
  3. 通过playwright获取某平台直播弹幕
  4. XNA游戏:横竖屏设置
  5. python14基础末异常文件
  6. 公司暂停社保,个人自主续交
  7. anaconda版本查看,老版本下载。
  8. Eclipse中outline的意思、和各种图标的含义
  9. 2020-10-22从np.random.normal()到正态分布的拟合
  10. 【C语言】全面解析数据在内存中的存储