轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur
作者丨科技猛兽
编辑丨极市平台
清华大学自动化系智能计算实验室团队开源基于 PyTorch 的视频 (图片) 去模糊框架 SimDeblur。
基于 PyTorch 的视频 (图片) 去模糊框架 SimDeblur
它的特点是:
- 全面: 涵盖经典的视频 (图像) 去模糊算法,如 MSCNN, SRN, DeblurGAN, EDVR, 等等。
- 高效: 支持 DDP 多机多卡训练。
- 轻量: 便于拓展,易上手,让更多的人能更快地上手使用。
- 专注: 使我们在实现自己的新模型时只需要关注一个文件或很少的几个文件。
Github link:
ljzycmd/SimDeblurgithub.com
目录
1 为什么要做这个开源框架?
1.1 怎么总是这几个baseline?
1.2 同一个baseline,在不同论文中的质量差别很大
1.3 同一个baseline,同一个数据集实验结果可比吗?
1.4 低质量的代码开源2 SimDeblur: 基于PyTorch的视频 (图片) 去模糊框架
2.1 已实现模型
2.2 使用方法
2.3 代码解读3 作者团队信息
1 为什么要做这个开源框架?
在深度学习领域,有几个问题我觉得很有必要提一下:
1.1 怎么总是这几个baseline?
比如说
在检测领域,baseline一般都有:
在分割领域,baseline一般都有:
在Vision Transformer领域,baseline一般都有:
在超分领域,baseline一般都有:
大家都不比较那些“最好”的baseline,而是去比较很 Popular 的baseline。
这就像买显卡时,
1060说:我比960好。
1080说:我比960好。
2080Ti说:我比960好。
有很多自称达到了 SOTA 的模型,涨到了比较高的性能,但是很难考证。所以后续研究者在选择比较对象的时候就会选择一些性能相对较低的,但是代码高质量开源的论文去比较。原因有2点:
- 这些论文因为代码高质量开源,所以引用量高,大家都知道且信服,比较 Popular。
- 这些论文性能相对低一点,和他们比较显得自己提出的方法厉害一点,也就更容易发论文。
这样做的好处是有百花齐放百家争鸣的感觉。但坏处是有的真正好的 baseline 模型被忽略掉了,导致了劣币驱逐良币。
如果今天你问一个你所在领域的专家,随便挑一个人,你问他:
" 我们这个任务目前最好的模型是哪个?"
他一定也很难回答。
你可能会问了:
" 这有啥难的?我直接把最新的论文都找出来,看看这个任务里面,谁超过baseline最多,谁提升的幅度最大,谁不就是最好的吗?"
这就引出了第2个问题:
1.2 同一个baseline,在不同论文中的质量差别很大
这句话的意思是说:同一个baseline模型,相同的任务,不同论文中给出的结果性能是不同的。 为什么呢?
这是因为:很多研究者对baseline的复现,其实并没有做到“全心全意”。换句话说,对baseline参数的调整其实带有相当大的随意性,对baseline的调整不会下过多的功夫,导致得到的baseline的性能没有达到其可以达到的最佳状态。
在这种情况下,如果你想比较2个自称达到了SOTA的模型的性能,因为它们对比的baseline的性能有差距,所以假设它们都相对baseline涨了3个点,但其实它们的性能是有差别的,所以就不具备很好的可比性。可能甲把baseline调得非常好,另一个乙把baseline没有调得很好,那么乙的提升就不具备很高的可信度。
你可能又会问了:
" 那我就直接找出baseline论文中给出的它在某个数据集上的性能,直接使用它的结果不就好了吗?"
这就引出了第3个问题:
- 1.3 同一个baseline,同一个数据集实验结果可比吗?
即使baseline在用一个数据集上,其实验结果也是不可比的。这是因为实验中的很多其他变量无法得到相同的控制。比如在数据预处理环节,每篇论文所列的baseline方法是否做到了完全一致?再比如在超参数的设置上,每篇论文所列的baseline方法是否做到了完全相同?
我们看下面的2张图,图1是DeiT模型的超参数设置 (DeiT是一种用于分类任务的视觉Transformer模型),图2是不同超参数设置下的模型性能对比。我们可以看到,相同的模型在相同的数据集下面,性能还是有差别的。所以这些看似不起眼的设定,其实是对模型的性能有着相对重大的影响,而这些却不会出现在引用DeiT的论文里面。所以你可能会看到:相同的模型在相同的数据集下面,结果又是会出现很大的差异。假设我们有8个超参数,每个超参数只有2种选择,那么不同的组合就多达282^{8}28种。
图1:DeiT模型的超参数设置
图2:DeiT模型不同超参数设置下的模型性能对比
总之这里想说的就是:很难保证 A 和 B 两篇论文的一切实验设置都是相同的。这就导致即使我们找到了 A 和 B 两篇在相同的模型在相同的数据集下面进行的实验,它们的结果也不是那么的可比。
你可能又会问了:
" 那很多论文都提供了开源代码,我直接下载下来在自己的任务上跑跑不就行了吗?"
这就引出了第4个问题:
1.4 低质量的代码开源
目前一篇顶会论文开源代码的最低要求是:能复现论文中所列的实验结果。但遗憾的是,许多开源代码根本无法达到这个要求。对于有些达到了这个要求的代码,它们的可重用性也非常差,想把它移植到你自己的实验环境下也十分地困难。我之前遇到过很多种奇葩的开源代码,这里随便举一个例子 (具体的论文就不说了。。)。比如它做 NAS 的论文,开源的代码里面没有 NAS 搜索的代码,只有模型的 model.py,那这样的开源代码就缺乏了最核心的 NAS 算法的开源,就是无意义的。那遇到这样的情况可能一周过去了,你还是无法复现出原论文的结果,这时候开组会时:
导师:你这周干了啥?
你:复现某某某论文失败了。
导师:这代码不是开源了吗,怎么还是复现不出来,你有没有认真做实验?
你:。。。。。。(委屈脸)
这种情况其实是很普遍且很不合理的情况,真的不是你的能力不行,而是目前领域中广泛存在的问题,Are we really making progress?所以在目前领域文章看似百花齐放的前提下,其实隐藏着一个潜在的,使领域停滞不前的问题。
这里我在举一个良性的例子。
比如2020年是视觉Transformer爆火的一年,从20年下半年开始一直持续到21年,Transformer模型被应用在了视觉的各个领域,想详细了解的童鞋们可以参考:
科技猛兽:Vision Transformer 超详细解读 (原理分析+代码解读) (目录)zhuanlan.zhihu.com
但是,在2020年爆火的Vision Transformer背后,其实是有一个重要的依托,就是**Ross Wightman大佬创建的timm库**。PyTorchImageModels,简称timm,包含很多种PyTorch的视觉模型,是一个巨大的PyTorch代码集合,包括了一系列:
- image models
- layers
- utilities
- optimizers
- schedulers
- data-loaders / augmentations
- training / validation scripts
旨在将各种SOTA模型整合在一起,并具有复现ImageNet训练结果的能力,详细的介绍如下:
科技猛兽:视觉Transformer优秀开源工作:timm库vision transformer代码解读zhuanlan.zhihu.com
许多Vision Transformer,包含高引的DeiT,CaiT等,其实都是基于timm库来实现的。所以这给了我们启发:我们需要一个benchmark平台,包含多种模型,使得它们在同一条件下得到公平的评测,这也是我们开发这一框架的初衷。
在设计这个框架时,我们的思想是:
- 首先它应该轻量,易上手,让更多的人能更快地上手使用。
- 其次它应该高效,使使用者专注于模型的实现,对于训练和评估的过程尽量少关心。
- 其次它应该灵活,适配不同的数据输入格式和实验设定。
- 最后就是专注,使我们在实现新模型时只需要关注一个文件。
2 SimDeblur: 基于PyTorch的视频 (图片) 去模糊框架
2.1 已实现模型
(粗体表示已经实现的模型,其他是待实现的模型)
Single Image Deblurring
- MSCNN [Paper, Project]
- SRN [Paper, Project]
Video Deblurring
- DBN [Paper, Project]
- STRCNN [paper]
- DBLRNet [Paper]
- EDVR [Paper, Project]
- STFAN [Paper, Project]
- IFIRNN [Paper]
- CDVD-TSP [Paper, Project]
- ESTRNN [Paper, Project]
Benchmarks
- GoPro [Paper, Data]
- DVD [Paper, Data]
- REDS [Paper, Data]
2.2 使用方法
1) 安装依赖
Python 3 (Conda is recommended)
Pytorch 1.5.1 (with GPU)
CUDA 10.2+
Clone the repositry or download the zip file:
git clone https://github.com/ljzycmd/SimDeblur.git
Install SimDeblur:
# create a pytorch env
conda create -n simdeblur python=3.7
conda activate simdeblur
# install the packages
cd SimDeblur
bash Install.sh
2) 使用默认的 trainer 来搭建一个训练进程,如下所示:
from simdeblur.config import build_config, merge_args
from simdeblur.engine.parse_arguments import parse_arguments
from simdeblur.engine.trainer import Trainerargs = parse_arguments()cfg = build_config(args.config_file)
cfg = merge_args(cfg, args)
cfg.args = argstrainer = Trainer(cfg)
trainer.train()
3) 单卡训练:
CUDA_VISIBLE_DEVICES=0 bash ./tools/train.sh ./config/dbn/dbn_dvd.yaml 1
4) 多卡训练:
CUDA_VISIBLE_DEVICES=0,1,2,3 bash ./tools/train.sh ./config/dbn/dbn_dvd.yaml 4
train.sh:
CONFIG=$1
GPUS=$2
PORT=${PORT:=10086}
# PORT=10086
# single gpu training
if [ GPUS == 1 ]
then
echo start single GPU training
python train.py $CONFIG --gpus=$GPUS
else
echo start distributed training
# distributed training
PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \train.py $CONFIG --gpus=$GPUS
fi
5) 也可以直接通过 SimDeblur 中的函数构建各种模块:
build the a dataset:
from easydict import EasyDict as edict
from simdeblur.dataset import build_datasetdataset = build_dataset(edict({"name": "DVD","mode": "train","sampling": "n_c","overlapping": True,"interval": 1,"root_gt": "./dataset/DVD/quantitative_datasets","num_frames": 5,"augmentation": {"RandomCrop": {"size": [256, 256] },"RandomHorizontalFlip": {"p": 0.5 },"RandomVerticalFlip": {"p": 0.5 },"RandomRotation90": {"p": 0.5 },}
}))print(dataset[0])
build the model:
from simdeblur.model import build_backbonemodel = build_backbone({"name": "DBN","num_frames": 5,"in_channels": 3,"inner_channels": 64
})x = torch.randn(1, 5, 3, 256, 256)
out = model(x)
build the loss:
from simdeblur.model import build_losscriterion = build_loss({"name": "MSELoss",
})
x = torch.randn(2, 3, 256, 256)
y = torch.randn(2, 3, 256, 256)
print(criterion(x, y))
2.3 代码解读:
1 框架架构:
/configs
→ /dblrnet: dblrnet配置文件
→ /dbn: dbn配置文件
→ /edvr: edvr配置文件
→ /…/datasets: 数据集位置
/docs
/simdeblur
→ __init__.py→ /config
→ → __init__.py
→ → build.py:读取配置信息的一些函数
→ → default_config.py:默认配置信息→ /dataset
→ → __init__.py
→ → build.py:创建数据集的接口
→ → augment.py:数据增强的函数
→ → dvd.py
→ → gopro.py
→ → red.py→ /engine
→ → __init__.py
→ → parse_arguments.py
→ → trainer.py:主要的训练代码
→ → hook.py→ /model
→ → __init__.py
→ → build.py:创建模型的接口
→ → /backbone:各种 backbone 具体实现
→ → →/dblrnet:dblrnet 具体实现
→ → →/dbn:dbn 具体实现
→ → →/edvr:edvr 具体实现
→ → →/ifirnn:ifirnn 具体实现
→ → →/stfan:stfan 具体实现
→ → →/strcnn:strcnn 具体实现
→ → /layer:各种 layer 具体实现
→ → →__init__.py
→ → →non_local.py:non_local block 具体实现
→ → →res_block.py:残差块具体实现
→ → →vgg.py:VGG 块具体实现
→ → /loss:各种损失函数具体实现
→ → →__init__.py
→ → →loss.py
→ → →perceptual_loss.py
→ → /meta_arch→ /scheduler: 优化器和学习率 scheduler 函数
→ /utils: 打印日志的相关函数
/tools: 生成demo的一些工具函数,以及启动文件 train.sh
/utils: 其它涉及到的一些工具函数
/requirements.txt: 运行需要的依赖库
setup.py: 上传 PYPI 需要的文件
test.py: 模型测试的接口文件,需要传入.yaml格式的配置文件
train.py: 模型训练的接口文件,需要传入.yaml格式的配置文件
2 train.py:
import torchfrom simdeblur.config import build_config, merge_args
from simdeblur.engine.parse_arguments import parse_arguments
from simdeblur.engine.trainer import Trainerdef main():args = parse_arguments()cfg = build_config(args.config_file)cfg = merge_args(cfg, args)cfg.args = argstrainer = Trainer(cfg)trainer.train()if __name__ == "__main__":main()
build_config:根据配置文件 (.yaml) 得到配置信息cfg (字典)。
merge_args:融合命令行参数。
得到包含了所有配置信息的变量 cfg,传入Trainer类。
3 Trainer 类介绍:
(a) 定义 Trainer 类属性:
from simdeblur.dataset import build_dataset
from simdeblur.scheduler import build_optimizer, build_lr_scheduler
from simdeblur.model import build_backbone, build_meta_arch, build_loss
from simdeblur.utils.logger import LogBuffer, SimpleMetricPrinter, TensorboardWriter
from simdeblur.utils.metrics import calculate_psnr, calculate_ssim
from simdeblur.utils import dist_utilsfrom simdeblur.engine import hookslogging.basicConfig(format='%(asctime)s - %(levelname)s - SimDeblur: %(message)s',level=logging.INFO)
logging.info("******* A simple deblurring framework ********")class Trainer:def __init__(self, cfg):"""Argscfg(edict): the config file, which contains arguments form comand line"""self.cfg = copy.deepcopy(cfg)# initialize the distributed trainingif cfg.args.gpus > 1:dist_utils.init_distributed(cfg)# create the working dirsself.current_work_dir = os.path.join(cfg.work_dir, cfg.name)if not os.path.exists(self.current_work_dir):os.makedirs(self.current_work_dir, exist_ok=True)self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# self.device = torch.device("cpu")# default loggerlogger = logging.getLogger("simdeblur")logger.setLevel(logging.INFO)logger.addHandler(logging.FileHandler(os.path.join(self.current_work_dir, self.cfg.name.split("_")[0] + ".json")))# construct the modulesself.model = self.build_model(cfg).to(self.device)self.criterion = build_loss(cfg.loss).to(self.device)self.train_dataloader, self.train_sampler = self.build_dataloder(cfg, mode="train")self.val_datalocaer, _ = self.build_dataloder(cfg, mode="val")self.optimizer = self.build_optimizer(cfg, self.model)self.lr_scheduler = self.build_lr_scheduler(cfg, self.optimizer)# trainer hooksself._hooks = self.build_hooks()# some induces when trainingself.epochs = 0self.iters = 0self.batch_idx = 0 self.start_epoch = 0self.start_iter = 0self.total_train_epochs = self.cfg.schedule.epochsself.total_train_iters = self.total_train_epochs * len(self.train_dataloader)# resume or load the ckpt as init-weightsif self.cfg.resume_from != "None":self.resume_or_load_ckpt(ckpt_path=self.cfg.resume_from)# log bufffer(dict to save) self.log_buffer = LogBuffer()
(b) 每个 epoch 开始前 shuffle the dataloader when dist training:
def before_epoch(self):for h in self._hooks:h.before_epoch(self)# shuffle the data when dist training ...if self.train_sampler:self.train_sampler.set_epoch(self.epochs)
(c) 每个 iteration 开始前 shuffle the dataloader when dist training:
def before_epoch(self):for h in self._hooks:h.before_epoch(self)# shuffle the data when dist training ...if self.train_sampler:self.train_sampler.set_epoch(self.epochs)
(d) 准备输入信息:
def preprocess(self, batch_data):"""prepare for input"""return batch_data["input_frames"].to(self.device)
(e) 模型输出的后处理:
def postprocess(self):"""post process for model outputs"""# When the outputs is a img tensorif isinstance(self.outputs, torch.Tensor) and self.outputs.dim() == 5:self.outputs = self.outputs.flatten(0, 1)
(f) 计算损失:
def calculate_loss(self, batch_data, model_outputs):"""calculate the loss"""gt_frames = batch_data["gt_frames"].to(self.device).flatten(0, 1)if model_outputs.dim() == 5:model_outputs = model_outputs.flatten(0, 1) # (b*n, c, h, w)return self.criterion(gt_frames, model_outputs)
(g) 优化器更新参数:
def update_params(self):"""update paramspipline: zero_grad, backward and update grad"""self.optimizer.zero_grad()self.loss.backward()self.optimizer.step()
(h) 每个 iteration 或者 epoch 结束以后,使用 hook 干一些事情,比如:lr_scheduler 更新,calculate metrics,保存日志等等,具体可以查看 /simdeblur/engine.hook.py 文件。
def after_iter(self):for h in self._hooks:h.after_iter(self)def after_epoch(self):for h in self._hooks:h.after_epoch(self)
(i) 根据以上工具函数写训练函数 train():
def train(self, **kwargs):self.model.train()self.before_train()logger = logging.getLogger("simdeblur")logger.info("Starting training...")for self.epochs in range(self.start_epoch, self.cfg.schedule.epochs):# shuffle the dataloader when dist training: dist_data_loader.set_epoch(epoch)self.before_epoch()for self.batch_idx, self.batch_data in enumerate(self.train_dataloader):self.before_iter()input_frames = self.preprocess(self.batch_data)self.outputs = self.model(input_frames)self.postprocess()self.loss = self.calculate_loss(self.batch_data, self.outputs)self.update_params()self.iters += 1self.after_iter()if self.epochs % self.cfg.schedule.val_epochs == 0:self.val()self.after_epoch()
before_epoch(), after_epoch(), before_iter(), after_iter() 这四个函数都是通过 hook 来定义每个 epoch 之前或之后,每个 iteration 之前或之后要做的事情,具体可以查看 /simdeblur/engine.hook.py 文件。
3 作者团队信息
曹铭登:
清华大学自动化系19级硕士,目前实习于腾讯 AI Lab。
邮箱:mingdengcao@gmail.com
王家豪:
清华大学自动化系19级硕士,目前实习于北京华为诺亚方舟实验室。
邮箱:wang-jh19@mails.tsinghua.edu.cn
智能计算实验室信息:
https://sites.google.com/view/iigroup-thusites.google.com
学术合作 or 沟通交流欢迎私信联系~
cite as:
@Article{wang2021simdeblur,author = {Mingdeng Cao, Jiahao Wang},title = {清华智能计算实验室团队开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur},journal = {https://zhuanlan.zhihu.com/},howpublished = {\url{https://github.com/ljzycmd/SimDeblur}},year = {2021},url= {https://zhuanlan.zhihu.com/p/368312516/},
}
轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur相关推荐
- Fel轻量高效的表达式计算引擎
Fel是轻量级的高效的表达式计算引擎 Fel在源自于企业项目,设计目标是为了满足不断变化的功能需求和性能需求. Fel是开放的,引擎执行中的多个模块都可以扩展或替换.Fel的执行主要是通过函数实现,运 ...
- 港中文开源基于PyTorch的多任务人脸识别框架
点击我爱计算机视觉标星,更快获取CVML新技术 今天跟大家分享一款新晋开源的出自香港中文大学MMLab实验室的人脸识别库,其最大特点是支持人脸多任务训练,方便使用PyTorch进行人脸识别的训练.评估 ...
- 31款轻量高效的开源 JavaScript 插件和库
31款轻量高效的开源 JavaScript 插件和库 目前有很多网站设计师和开发者喜欢使用由[url=http://www.kubiji.cn/forum-id261.html]JavaScript[ ...
- 达摩院智能计算实验室负责人周靖人 入选IEEE Fellow
今日记者获悉,被国际学术科技界认定为权威荣誉的IEEE Fellow,近日又添阿里面孔:达摩院智能计算实验室.大数据智能计算和搜索推荐平台负责人周靖人入选. 周靖人所在的阿里巴巴达摩院,虽然刚成立一年 ...
- 权威荣誉,达摩院智能计算实验室负责人周靖人入选IEEE Fellow
记者刚刚获悉,达摩院智能计算实验室.大数据智能计算和搜索推荐平台负责人周靖人入选IEEE Fellow. IEEE Fellow ,即国际电子电气工程师学会会士,为协会最高等级会员,是该组织授予的最高 ...
- 博士申请 | 香港理工大学智能计算实验室招收机器学习方向全奖博士/RA/博后
合适的工作难找?最新的招聘信息也不知道? AI 求职为大家精选人工智能领域最新鲜的招聘信息,助你先人一步投递,快人一步入职! 香港理工大学 香港理工大学 (The Hong Kong Polytech ...
- CYQ.Data 轻量数据层之路 使用篇-MAction 数据查询 视频 D (二十一)
2019独角兽企业重金招聘Python工程师标准>>> 说明: 本次录制主要为使用篇:CYQ.Data 轻量数据层之路 使用篇二曲 MAction 数据查询(十三) 的附加视频教 ...
- 工业物联网IIoT环境下,为边缘计算提供的基于区块链的机械学习安全框架
A Blockchain-Based Machine Learning Framework for Edge Services in IIoT 工业物联网导论 新兴的工业互联网技术与智能生产 问题分析 ...
- 轻量高效的开源JavaScript插件和库 【转】
图片 布局 轮播图 弹出层 音频视频 编辑器 字符串 表单 存储 动画 时间 其它 加载器 构建工具 测试 包管理器 CDN 图片 baguetteBox.js- 是一个简单易用的响应式图像灯箱效果脚 ...
最新文章
- 中国电信天翼Live究竟胜算几何?
- 两年JAVA程序员的面试总结
- 使用脚本创建查找修改销毁游戏对象
- ecplice中class.forname一直报错_Python怎么把文件内容读取出来,怎么把内容写入文件中
- Django死活不跳转的问题
- 做演员是圆梦 做生意学会面对现实
- python控制其它软件_从另一个脚本控制python脚本
- cordova 实现网页缓存_如何解决ionic,cordova混合开发的app缓存大的问题
- python调用TensorFlow时报错:FutureWarning: Passing (type, 1) or ‘1type‘ as a synonym of type is deprecated
- ubuntu 常用指令
- 容器技术Docker K8s 46 Serverless Kubernetes(ASK)详解-场景应用
- Matlab 检测直线并求解直线方程
- Matlab分析dac模拟信号,[滤波器在音频DAC测试中的应用] 音频滤波器
- css压缩有啥好处呢?
- 再记公式弱爆了!用 ChatGPT 将 Excel 工作效率提高 10 倍
- 重磅!全球Top 1000计算机科学家h指数公布:中国53位学者上榜
- 余弦定理对比文本相似度实现查重
- Integer、new Integer()和int的区分与比较
- 2017南京师范大学计算机学院录取名单,关于公布南京师范大学2017年硕士研究生复试成绩及录取名单的通知...
- 接口测试的流程和步骤,主要测试哪些方面,测试工具,测试用例,以及测试框架
热门文章
- BZOJ.2707.[SDOI2012]走迷宫(期望 Tarjan 高斯消元)
- 台式机 双显卡切换实战
- Selenium2Library关键字(1)
- sql server中创建链接服务器图解教程
- 和quot;分别是什么?
- html怎么循环输出_for 循环疑难点
- ps怎么缩放图层大小_【无机纳米材料科研制图——Photoshop 0404】PS排列扫描透射电子显微镜图TEM/STEM...
- ic卡复制软件_使用MCT复制IC卡0扇区的方法(适用于NFC手机复制或模拟门禁卡)...
- 为什么我的论文没人引用?
- mysql udf禁用_Mysql数据库UDF的安全问题利用