本文参考文献:Progressive Image Deraining Networks: A Better and Simpler Baseline Dongwei Ren1, Wangmeng Zuo2, Qinghua Hu1, Pengfei Zhu1, and Deyu Meng31College of Computing and Intelligence, Tianjin University, Tianjin, China 2School of Computer Science and Technology, Harbin Institute of Technology, Harbin, China 3Xi’an Jiaotong University, Xi’an, China

论文下载网址:[1901.09221] Progressive Image Deraining Networks: A Better and Simpler Baseline (arxiv.org)https://arxiv.org/abs/1901.09221

论文作者提供的Github实验源码:https://github.com/csdwren/PReNet.

关于论文:本文参考论文的主要贡献是提出了一种简单易实现且有较好效果的去雨网络架构PreNet,虽然其所用技术不是最先进的,但却有着十分优秀的去雨效果,因此作者认为这是一种可供众多研究者学习和实验对比使用的优良的基准模型。而在我看来,正是由于模型的以上这些特点,PreNet也十分适合深度学习去雨入门研究者进行学习和实现。

接下来将仔细说明一种十分简便的实现方法以及部分实验原理。

实验环境

首先介绍实现的实验环境。新手建议首选在以下网址注册账号(注册需科学上网,后续使用不必)利用线上环境写Pytorch代码,并将模型放网站提供的免费云算力服务器上训练。可以省去新手配置本地cuda环境的烦恼。

Kaggle: Your Home for Data Sciencehttps://www.kaggle.com/

点击左侧 Code

再点击 New Notebook 即可开启线上编程环境

线上编程环境和Jupyter类似,进入编程界面后可以在上方菜单栏设置界面外观选项(可以选择添加行号以便于查看代码)。右侧可选择加速器,我推荐使用GPU P100。

代码编写及调试完成后点击右上角 Save Version 来讲模型放在云GPU上训练(注意保存的版本无法手动删除,因此一定要确认代码调试无误后再点击 Save Version,以免版本太多造成的混乱)。

以上即是对编程环境的基本介绍,下面详细介绍实现步骤。

具体实现

图像去雨任务和图像分类任务的处理流程相似,都是:数据处理 --> 模型构建 --> 训练 --> 记录训练信息及模型保存。接下来将结合代码详细介绍。

数据处理:

'''
Dataset for Training.
'''
class MyTrainDataset(Dataset):def __init__(self, input_path, label_path):self.input_path = input_pathself.input_files = os.listdir(input_path)self.label_path = label_pathself.label_files = os.listdir(label_path)self.transforms = transforms.Compose([transforms.CenterCrop([64, 64]), transforms.ToTensor(),])def __len__(self):return len(self.input_files)def __getitem__(self, index):label_image_path = os.path.join(self.label_path, self.label_files[index])label_image = Image.open(label_image_path).convert('RGB')'''Ensure input and label are in couple.'''temp = self.label_files[index][:-4]self.input_files[index] = temp + 'x2.png'input_image_path = os.path.join(self.input_path, self.input_files[index])input_image = Image.open(input_image_path).convert('RGB')input = self.transforms(input_image)label = self.transforms(label_image)return input, label'''
Dataset for testing.
'''
class MyValidDataset(Dataset):def __init__(self, input_path, label_path):self.input_path = input_pathself.input_files = os.listdir(input_path)self.label_path = label_pathself.label_files = os.listdir(label_path)self.transforms = transforms.Compose([transforms.CenterCrop([64, 64]), transforms.ToTensor(),])def __len__(self):return len(self.input_files)def __getitem__(self, index):label_image_path = os.path.join(self.label_path, self.label_files[index])label_image = Image.open(label_image_path).convert('RGB')temp = self.label_files[index][:-4]self.input_files[index] = temp + 'x2.png'input_image_path = os.path.join(self.input_path, self.input_files[index])input_image = Image.open(input_image_path).convert('RGB')input = self.transforms(input_image)label = self.transforms(label_image)return input, label

上面的代码分为两个部分:分别是训练集和测试集的Dataset类的重写。这是自定义Pytorch数据集处理方式的比较方便的处理方式。由于训练集和测试集的处理方式一致,这里仅对训练集的处理方式进行介绍。

首先我们需要明白,为什么我们特别地需要重写__init__, __length__, __getitem__ 这三个Dataset()类的方法,因为后续处理中封装用的DataLoader类需要调用Dataset对象的这三个函数来获取数据集的相关信息,这个关系可以理解为:DataLoader类负责将数据切分为很多个批次(batch)以分批次进行训练,而Dataset负责记录数据整体信息处理每一批次中的每一对标签和输入数据的内容。换句话说,Dataset类只负责记录整体数据信息处理一对标签和输入数据对,而DataLoader将Dataset的处理方式循环地应用到整个数据集上。因此,对于不同的数据集我们要重写Dataset类的这三个函数以改变DataLoader处理数据的方式。

明白了以上内容,我们就好理解这三个函数的运作方式了:__init__和 __length__负责记录数据集的一些基本信息,__length__的内容一定是返回输入数据项的长度(不可更改),而__init__用于初始化你需要用到的一些基本变量(可高度自定义),这些变量将在__getitem__中被调用。__getitem__就负责处理每一对数据对的匹配输出,其关键是一定要保证最后 return 的 input 和 label 是成对的。

以下是创建Dataset和DataLoader对象的过程:

'''
Path of Dataset.
'''
input_path = "../input/jrdr-deraining-dataset/JRDR/rain_data_train_Heavy/rain/X2"
label_path = "../input/jrdr-deraining-dataset/JRDR/rain_data_train_Heavy/norain"
valid_input_path = '../input/jrdr-deraining-dataset/JRDR/rain_data_test_Heavy/rain/X2'
valid_label_path = '../input/jrdr-deraining-dataset/JRDR/rain_data_test_Heavy/norain''''
Prepare DataLoaders.Attension:'pin_numbers=True' can accelorate CUDA computing.
'''
dataset_train = MyTrainDataset(input_path, label_path)
dataset_valid = MyValidDataset(valid_input_path, valid_label_path)
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, pin_memory=True)
valid_loader = DataLoader(dataset_valid, batch_size=batch_size, shuffle=True, pin_memory=True)

需要注意的是我使用的是Kaggle网站上提供的线上数据集,可自行搜索添加:

你可以选择使用Heavy数据集训练,也可以使用Light,分别对应人工合成的大雨和小雨数据集。我推荐使用Heavy,使用Heavy训练出来的模型对于真实下雨场景的去雨效果更明显。

模型构建:

# 网络架构class PReNet_r(nn.Module):def __init__(self, recurrent_iter=6, use_GPU=True):super(PReNet_r, self).__init__()self.iteration = recurrent_iterself.use_GPU = use_GPUself.conv0 = nn.Sequential(nn.Conv2d(6, 32, 3, 1, 1),nn.ReLU())self.res_conv1 = nn.Sequential(nn.Conv2d(32, 32, 3, 1, 1),nn.ReLU(),nn.Conv2d(32, 32, 3, 1, 1),nn.ReLU())self.conv_i = nn.Sequential(nn.Conv2d(32 + 32, 32, 3, 1, 1),nn.Sigmoid())self.conv_f = nn.Sequential(nn.Conv2d(32 + 32, 32, 3, 1, 1),nn.Sigmoid())self.conv_g = nn.Sequential(nn.Conv2d(32 + 32, 32, 3, 1, 1),nn.Tanh())self.conv_o = nn.Sequential(nn.Conv2d(32 + 32, 32, 3, 1, 1),nn.Sigmoid())self.conv = nn.Sequential(nn.Conv2d(32, 3, 3, 1, 1),)def forward(self, input):batch_size, row, col = input.size(0), input.size(2), input.size(3)#mask = Variable(torch.ones(batch_size, 3, row, col)).cuda()x = inputh = Variable(torch.zeros(batch_size, 32, row, col))c = Variable(torch.zeros(batch_size, 32, row, col))if self.use_GPU:h = h.cuda()c = c.cuda()x_list = []for i in range(self.iteration):x = torch.cat((input, x), 1)x = self.conv0(x)x = torch.cat((x, h), 1)i = self.conv_i(x)f = self.conv_f(x)g = self.conv_g(x)o = self.conv_o(x)c = f * c + i * gh = o * torch.tanh(c)x = hfor j in range(5):resx = xx = F.relu(self.res_conv1(x) + resx)x = self.conv(x)x = input + xx_list.append(x)return x, x_list

我直接照搬的论文源码的网络架构,简单来说,该网络就是结合了LSTM和递归残差网络的处理方式。你暂时可以不用理解,直接用就行。如果想要进行深入了解的话可以查看本文开头处提供的原文链接或本站搜索论文翻译。

训练:

'''
Define optimizer and Loss Function.
'''
optimizer = optim.RAdam(net.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=epoch)loss_f = SSIM()

首先初始化优化器和损失函数,采用RAdam优化器(Adam优化器的基础上增加了warm-up的功能)并使用CosineAnnealingLR(余弦退火算法)让学习率随训练轮数呈余弦变化,以优化训练结果。

SSIM损失函数使用的是论文作者提供的源码:

# SSIM损失函数实现def gaussian(window_size, sigma):gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])return gauss/gauss.sum()def create_window(window_size, channel):_1D_window = gaussian(window_size, 1.5).unsqueeze(1)_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())return windowdef _ssim(img1, img2, window, window_size, channel, size_average = True):mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)mu1_sq = mu1.pow(2)mu2_sq = mu2.pow(2)mu1_mu2 = mu1*mu2sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sqsigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sqsigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2C1 = 0.01**2C2 = 0.03**2ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))if size_average:return ssim_map.mean()else:return ssim_map.mean(1).mean(1).mean(1)class SSIM(torch.nn.Module):def __init__(self, window_size = 11, size_average = True):super(SSIM, self).__init__()self.window_size = window_sizeself.size_average = size_averageself.channel = 1self.window = create_window(window_size, self.channel)def forward(self, img1, img2):(_, channel, _, _) = img1.size()if channel == self.channel and self.window.data.type() == img1.data.type():window = self.windowelse:window = create_window(self.window_size, channel)if img1.is_cuda:window = window.cuda(img1.get_device())window = window.type_as(img1)self.window = windowself.channel = channelreturn _ssim(img1, img2, window, self.window_size, channel, self.size_average)def ssim(img1, img2, window_size = 11, size_average = True):(_, channel, _, _) = img1.size()window = create_window(window_size, channel)if img1.is_cuda:window = window.cuda(img1.get_device())window = window.type_as(img1)return _ssim(img1, img2, window, window_size, channel, size_average)

SSIM是一种评估两幅图像相似度的算法,具体原理此处不再详述,你只需要记住其值越大两张图像相似度越高,值为一则两张图象完全一样。因此我们在后续训练时需要取SSIM的负值。

下面是循环训练的代码:

'''
START Training ...
'''
for i in range(epoch):
# ---------------Train----------------net.train()train_losses = []'''tqdm is a toolkit for progress bar.'''for batch in tqdm(train_loader):inputs, labels = batchoutputs, _ = net(inputs.to(device))loss = loss_f(labels.to(device), outputs)loss = -lossoptimizer.zero_grad()loss.backward()'''Avoid grad to be too BIG.'''grad_norm = nn.utils.clip_grad_norm_(net.parameters(), max_norm=10)optimizer.step()'''Attension:We need set 'loss.item()' to turn Tensor into Numpy, or plt will not work.'''train_losses.append(loss.item())train_loss = sum(train_losses) / len(train_losses)Loss_list.append(train_loss)print(f"[ Train | {i + 1:03d}/{epoch:03d} ] SSIM_loss = {train_loss:.5f}")scheduler.step()for param_group in optimizer.param_groups:learning_rate_list.append(param_group["lr"])print('learning rate %f' % param_group["lr"])# -------------Validation-------------'''Validation is a step to ensure training process is working.You can also exploit Validation to see if your net work is overfitting.Firstly, you should set model.eval(), to ensure parameters not training.'''net.eval()valid_losses = []for batch in tqdm(valid_loader):inputs, labels = batch'''Cancel gradient decent.'''with torch.no_grad():outputs, _ = net(inputs.to(device))loss = loss_f(labels.to(device), outputs)loss = -lossvalid_losses.append(loss.item())valid_loss = sum(valid_losses) / len(valid_losses)Valid_Loss_list.append(valid_loss)print(f"[ Valid | {i + 1:03d}/{epoch:03d} ] SSIM_loss = {valid_loss:.5f}")break_point = i + 1'''Update Logs and save the best model.Patience is also checked.'''if valid_loss < best_valid_loss:print(f"[ Valid | {i + 1:03d}/{epoch:03d} ] SSIM_loss = {valid_loss:.5f} -> best")else:print(f"[ Valid | {i + 1:03d}/{epoch:03d} ] SSIM_loss = {valid_loss:.5f}")if valid_loss < best_valid_loss:print(f'Best model found at epoch {i+1}, saving model')torch.save(net.state_dict(), f'model_best.ckpt')best_valid_loss = valid_lossstale = 0else:stale += 1if stale > patience:print(f'No improvement {patience} consecutive epochs, early stopping.')break

见注释即可。其中break_point用于记录训练结束的epoch值,stale用于记录模型未进步所持续的训练轮数,patience是预设的模型未进步所持续轮数的最大值。

记录训练信息及模型保存:

部分内容是从上面的代码段截取的,由于实例较分散,我只在下面说明了其中比较典型的几个:

    Loss_list.append(train_loss)  # 用于后续绘制Loss曲线print(f"[ Train | {i + 1:03d}/{epoch:03d} ] SSIM_loss = {train_loss:.5f}")

打印日志信息

        print(f'Best model found at epoch {i+1}, saving model')torch.save(net.state_dict(), f'model_best.ckpt')

保存模型(关于所保存文件的后缀.ckpt 和 .pth的区别此处不详述,可自行搜索)

'''
Use plt to draw Loss curves.
'''
plt.figure(dpi=500)plt.subplot(211)
x = range(break_point)
y = Loss_list
plt.plot(x, y, 'ro-', label='Train Loss')
plt.plot(range(break_point), Valid_Loss_list, 'bs-', label='Valid Loss')
plt.ylabel('Loss')
plt.xlabel('epochs')plt.subplot(212)
plt.plot(x, learning_rate_list, 'ro-', label='Learning rate')
plt.ylabel('Learning rate')
plt.xlabel('epochs')plt.legend()
plt.show()

利用matplotlib库绘制训练过程中重要参数变化曲线。

实验结果:

先附上我的完整代码链接(注意前面提供的代码不完整,一些细节部分被我省略了,直接复制粘贴无法运行):

Kaggle平台:PreNet | Kagglehttps://www.kaggle.com/code/leeding123/prenet        Gihub仓库(欢迎点星):Derain_platform/prenet.ipynb at f3249f6ee4f14055bf30c53239141bccecdcb0f2 · DLee0102/Derain_platform · GitHubContribute to DLee0102/Derain_platform development by creating an account on GitHub.https://github.com/DLee0102/Derain_platform/blob/f3249f6ee4f14055bf30c53239141bccecdcb0f2/prenet.ipynb        我在Heavy训练集上的训练结果:

注:Loss曲线图中红色为训练集Loss蓝色为测试集Loss

合成数据集上的去雨效果:

真实数据集上的去雨效果:

深度学习:图像去雨网络实现Pytorch (二)一个简单实用的基准模型(PreNet)实现相关推荐

  1. 深度学习核心技术精讲100篇(二十)-如何通过树模型实现梯度提升树(GBDT)+LR,随机森林(RandomForest) +LR

    前言 在讲如何通过树模型做特征工程之前,首先让我们回顾一下一个机器学习(除去深度学习项目部分)项目的大致流程: 从业务场景中抽象出问题--分类问题,回归问题,还是聚类问题等, 接下来是数据获取,数据清 ...

  2. 深度学习 GAN生成对抗网络-1010格式数据生成简单案例

    一.前言 本文不花费大量的篇幅来推导数学公式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络. 二.GAN概念 生成对抗网络(Generative Adversarial Networks ...

  3. hikey970学习-012 hikey970上移植图像去雨深度神经网络算法

    摘要:本文在hikey970嵌入式设备上运行图像去雨深度神经网络算法,验证hikey970运行神经网络模型的性能与效果. 一.图像去雨简介     图像去雨处理指的是对于一张雨中的图片,去除画面中的雨 ...

  4. 图像去雨算法(基于卷积网络)

    图像去雨算法文章: https://pdfs.semanticscholar.org/bf10/3b3ea90f0d032d1d73dbb83ae41731ee006f.pdf 相应的代码和论文 ht ...

  5. 深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别

    深度学习 LSTM长短期记忆网络原理与Pytorch手写数字识别 一.前言 二.网络结构 三.可解释性 四.记忆主线 五.遗忘门 六.输入门 七.输出门 八.手写数字识别实战 8.1 引入依赖库 8. ...

  6. 【每周CV论文】深度学习图像降噪应该从阅读哪些文章开始

    欢迎来到<每周CV论文推荐>.在这个专栏里,还是本着有三AI一贯的原则,专注于让大家能够系统性完成学习,所以我们推荐的文章也必定是同一主题的. 图像降噪是图像处理领域中非常传统和经典的问题 ...

  7. Transformer 杀疯了,图像去雨、人脸幻构、风格迁移、语义分割等通通上分

    前段时间 Transformer 已席卷计算机视觉领域,并获得大量好评,如『基于Swin-Transformer』.『美团提出具有「位置编码」的Transformer,性能优于ViT和DeiT』.『L ...

  8. 【三维重建】【深度学习】windows10下NeRF代码Pytorch实现

    [三维重建][深度学习]windows10下NeRF代码Pytorch实现 提示:最近开始在[三维重建]方面进行研究,记录相关知识点,分享学习中遇到的问题已经解决的方法. 文章目录 [三维重建][深度 ...

  9. 【Pytorch深度学习实践】B站up刘二大人之BasicCNN Advanced CNN -代码理解与实现(9/9)

    这是刘二大人系列课程笔记的 最后一个笔记了,介绍的是 BasicCNN 和 AdvancedCNN ,我做图像,所以后面的RNN我可能暂时不会花时间去了解了: 写在前面: 本节把基础个高级CNN放在一 ...

最新文章

  1. jBPM3.12用户指南中文翻译----第一章 绪论
  2. TCP和UDP的最完整的区别
  3. python实现计算器代码 博客园_python实现计算器
  4. OpenGL编程轻松入门(一)
  5. 计网 --- 域名服务系统:因特网的目录服务
  6. 实现织梦dedecms百度主动推送(实时)网页抓取
  7. PHP处理图片导致内存溢出
  8. ORACLE异常处理及函数
  9. 【剑指 offer】(十九)—— 二叉树镜像
  10. XSS注入,js脚本注入后台
  11. (软件项目管理)项目会议纪要模板
  12. mdx词典包_不会用医学词典?停姐手把手教你啊(内附海量医学词典词库资源下载)...
  13. 图书馆大数据分析系统能帮助图书馆完成哪些事,意义有多大?
  14. python2在线编译器_C/C++/Python在线编译器
  15. android平板电脑怎么才能连接,安卓平板电脑如何连接网线?
  16. Buying Hay 购买干草 完全背包
  17. CentOS 7 安装 libtool 2.4.6
  18. dnf锁定计算机,DNF老电脑福音 教你如何提升游戏流畅度
  19. OpenCV3 和 Qt5 计算机视觉 学习笔记 - 图像转换
  20. servlet 3.0 slf4j +log4j2 始终无法写入log文件的解决

热门文章

  1. (10.1.6)极简主义
  2. 设计的萌芽阶段_第一章 设计的萌芽阶段 第一节 设计概念的产生
  3. DM数据库安装过程--基于NeoKylin操作系统详解
  4. laravel从入门到精通之 时区设置
  5. 8.4|Crypto Tech Night 第十六期:深入探索DeFi的“代码世界”
  6. Stanford Alpaca (羊驼):ChatGPT 学术版开源实现
  7. 如何成为一个更好的程序员,或者说是学习者?给你七个建议!
  8. idea永久激活教程:smirk:
  9. 计算机中英文字录入教案,文字录入教案.doc
  10. at89c2051 定时器用法 c语言编程资料,用AT89C2051单片机制作的数字电容表.doc