Unsupervised Deep Homography - Pytorch实现

  • 前言
    • 使用说明
    • 代码实现

前言

Unsupervised Deep Homography: A Fast and Robust Homography Estimation
Model

Ty Nguyen, Steven W. Chen, Shreyas S. Shivakumar, Camillo J. Taylor, Vijay
Kumar
这篇论文的Pytorch实现,代码地址 unsupervisedDeepHomography-pytorch. 喜欢的朋友给个⭐哦

  • 2021.4.4更新,新增TensorBoard可视化和一些度量指标,快来下载使用吧

使用说明

进入code/文件夹
主要有三个py文件,分别是:
dataset.py: 实现了torch加载合成数据集的Dataset类
homography_model.py: 无监督单应性模型的实现
homography_CNN_synthetic.py: 训练与测试过程

一. 准备合成数据集
下载COCO2014数据集,分别设置训练集和测试集的路径RAW_DATA_PATH和TEST_RAW_DATA_PATH

python utils/gen_synthetic_data.py --mode train

大概要经过几个小时生成100.000个合成数据样本。

python utils/gen_synthetic_data.py --mode test

运行完成后生成的合成数据集文件列表如下:

二、训练模型

python homography_CNN_synthetic.py --mode train

三、测试模型
下载预训练模型并存放在models/synthetic_models文件夹下

链接:https://pan.baidu.com/s/102ilb5HJGydpeHtYelx_Xw    提取码:boq9
python homography_CNN_synthetic.py --mode test

运行结果:

results

代码实现

下面对无监督单应性模型的代码实现进行简单分析:

homography_model.py

import numpy as np
from utils.torch_spatial_transformer import transformer
import torch
from torch import nn
import torch.nn.functional as Fclass ConvBlock(nn.Module):def __init__(self, inchannels, outchannels, batch_norm=False, pool=True):super(ConvBlock, self).__init__()layers = []layers.append(nn.Conv2d(inchannels, outchannels, kernel_size=3, padding=1))layers.append(nn.ReLU(inplace=True))if batch_norm:layers.append(nn.BatchNorm2d(outchannels))layers.append(nn.Conv2d(outchannels, outchannels, kernel_size=3, padding=1))layers.append(nn.ReLU(inplace=True))if batch_norm:layers.append(nn.BatchNorm2d(outchannels))if pool:layers.append(nn.MaxPool2d(2, 2))self.layers = nn.Sequential(*layers)def forward(self, x):return self.layers(x)class HomographyModel(nn.Module):def __init__(self, batch_norm=False):super(HomographyModel, self).__init__()self.feature = nn.Sequential(ConvBlock(2, 64, batch_norm),ConvBlock(64, 64, batch_norm),ConvBlock(64, 128, batch_norm),ConvBlock(128, 128, batch_norm, pool=False),)self.fc = nn.Sequential(nn.Dropout(0.5),nn.Linear(128 * 16 * 16, 1024),nn.ReLU(inplace=True),nn.Dropout(0.5),nn.Linear(1024, 8))for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)elif isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def forward(self, I1_aug, I2_aug, I_aug, h4p, patch_indices):batch_size, _, img_h, img_w = I_aug.size()_, _, patch_size, patch_size = I1_aug.size()y_t = torch.arange(0, batch_size * img_w * img_h,img_w * img_h)batch_indices_tensor = y_t.unsqueeze(1).expand(y_t.shape[0], patch_size * patch_size).reshape(-1)M_tensor = torch.tensor([[img_w / 2.0, 0., img_w / 2.0],[0., img_h / 2.0, img_h / 2.0],[0., 0., 1.]])if torch.cuda.is_available():M_tensor = M_tensor.cuda()batch_indices_tensor = batch_indices_tensor.cuda()M_tile = M_tensor.unsqueeze(0).expand(batch_size, M_tensor.shape[-2], M_tensor.shape[-1])# Inverse of MM_tensor_inv = torch.inverse(M_tensor)M_tile_inv = M_tensor_inv.unsqueeze(0).expand(batch_size, M_tensor_inv.shape[-2],M_tensor_inv.shape[-1])pred_h4p = self.build_model(I1_aug, I2_aug)H_mat = self.solve_DLT(h4p, pred_h4p).squeeze(1)pred_I2 = self.transform(patch_size, M_tile_inv, H_mat, M_tile,I_aug, patch_indices, batch_indices_tensor)l1_loss = F.l1_loss(pred_I2, I2_aug)out_dict = {}out_dict.update(l1_loss=l1_loss, pred_h4p=pred_h4p)return out_dictdef build_model(self, I1_aug, I2_aug):model_input = torch.cat([I1_aug, I2_aug], dim=1)x = self.feature(model_input)x = x.view(x.size(0), -1)x = self.fc(x)return xdef solve_DLT(self, src_p, off_set):...return Hdef transform(self, patch_size, M_tile_inv, H_mat, M_tile, I, patch_indices, batch_indices_tensor):# Transform H_mat since we scale image indices in transformerbatch_size, num_channels, img_h, img_w = I.size()# if torch.cuda.is_available():#     M_tile_inv = M_tile_inv.cuda()H_mat = torch.matmul(torch.matmul(M_tile_inv, H_mat), M_tile)# Transform image 1 (large image) to image 2out_size = (img_h, img_w)warped_images, _ = transformer(I, H_mat, out_size)# Extract the warped patch from warped_images by flatting the whole batch before using indices# Note that input I  is  3 channels so we reduce to graywarped_gray_images = torch.mean(warped_images, dim=3)warped_images_flat = torch.reshape(warped_gray_images, [-1])patch_indices_flat = torch.reshape(patch_indices, [-1])pixel_indices = patch_indices_flat.long() + batch_indices_tensorpred_I2_flat = torch.gather(warped_images_flat, 0, pixel_indices)pred_I2 = torch.reshape(pred_I2_flat, [batch_size, patch_size, patch_size, 1])return pred_I2.permute(0, 3, 1, 2)

无监督单应性模型的输入有I1_aug, I2_aug, I_aug, h4p, patch_indicesI1_augI2_aug是从一对变换图像IAI^AIA和IBI^BIB的相同位置裁剪出的patch对,I_aug是用来变换的图像IAI^AIA,h4ppatch_indices表示PA的四个顶点坐标以及PA在图像IAI^AIA上的索引。

整个无监督单应性模型主要由三个部分组成:
第一部分是一个VGG风格的回归网络,即self.build_model()。该部分的输入是patch对I1_augI2_aug,输出是这两个patch对之间的位移量pred_h4p

pred_h4p = self.build_model(I1_aug, I2_aug)

第二部分是TensorDLT层,即self.solve_DLT()。C4ptA\mathbf{C}_{4 p t}^{A}C4ptA​(h4p)是PA的四个顶点的坐标,加上H~4pt\mathbf{\tilde{H}}_{4 p t}H~4pt​(pred_h4p)就得到了对应的C~4ptB\tilde{\mathbf{C}}_{4 p t}^{B}C~4ptB​。通过直接线性变换(DLT)方法可以从C4ptA\mathbf{C}_{4 p t}^{A}C4ptA​和C~4ptB\tilde{\mathbf{C}}_{4 p t}^{B}C~4ptB​中估计单应性矩阵的9个参数。

H_mat = self.solve_DLT(h4p, pred_h4p).squeeze(1)


第三部分是空间变换层,即self.transform()。该部分对图像IAI^AIA的像素坐标xi\mathscr{\mathbf{x}_{i}}xi​应用Tensor DLT层的输出单应性矩阵H~\mathbf{\tilde{H}}H~,得到变换后的图像IA(H(xi))I^{A}\left(\mathscr{H}\left(\mathbf{x}_{i}\right)\right)IA(H(xi​))。

warped_images, _ = transformer(I, H_mat, out_size)

注意这里得到的warped_images是IAI^AIA变换后的大图,因此再通过索引patch_indices得到P~B\tilde{\mathbf{P}}^\mathbf{B}P~B(pred_I2)。

最后,使用L1损失作为损失函数:

l1_loss = F.l1_loss(pred_I2, I2_aug)


完整代码见 unsupervisedDeepHomography-pytorch. 喜欢的朋友给个⭐哦

Unsupervised Deep Homography - Pytorch实现相关推荐

  1. Unsupervised Deep Image Stitching:首个无监督图像拼接框架(TIP2021)

    作者丨廖康@知乎 来源丨https://zhuanlan.zhihu.com/p/386863945 编辑丨3D视觉工坊 一.写在前面 图像拼接(Image Stitching)可以说是计算机视觉领域 ...

  2. UDT(【CVPR2019】Unsupervised Deep Tracking无监督目标跟踪)

    UDT是中科大.腾讯AI lab和上交的研究者提出的无监督目标跟踪算法.仔细阅读过这篇文章之后,写下一篇paper reading加深印象. 论文标题:Unsupervised Deep Tracki ...

  3. 【论文笔记】Unsupervised Deep Embedding for Clustering Analysis(DEC)

    [论文笔记]Unsupervised Deep Embedding for Clustering Analysis(DEC) 文章题目:Unsupervised Deep Embedding for ...

  4. 论文笔记:SESF-Fuse: an unsupervised deep model for multi-focus image fusion (2021)

    SESF-Fuse: an unsupervised deep model for multi-focus image fusion [引用格式]:Boyuan Ma et al. "SES ...

  5. Unsupervised Deep Anomaly Detection for Multi-Sensor Time-Series Signals-TKDE-A类-

    a25-2021-TKDE(A类)-无监督-Unsupervised Deep Anomaly Detection for Multi-Sensor Time-Series Signals-精度-基于 ...

  6. 【论文精读】Unsupervised Deep Image Stitching: Reconstructing Stitched Features to Images(无监督的深度图像拼接)

    论文下载链接 文章目录 前言 摘要 一.介绍 二.相关工作 2.1 基于特征的图像拼接 2.2 基于学习的图像拼接 2.3深度单应方法 ==>研究动机 三.无监督图像拼接 Ⅰ.无监督图像对齐 Ⅱ ...

  7. 【论文精读】Learning Edge-Preserved Image Stitching from Large-Baseline Deep Homography

    文章目录 一.论文翻译 题目:从大基线深度单应性学习边缘保留图像拼接 0摘要 1引言 2相关工作 A传统图像拼接 B深度单应方案 C深度图像拼接 3方法 3.1大基线深度单应 3.2边缘保持变形网络 ...

  8. OOD : A Self-supervised Framework for Unsupervised Deep Outlier Detection e3笔记

    然而,最近的研究(如[69].[70])表明,像素级重建方法往往过分强调低层次图像细节,而这些细节对人类感知的兴趣非常有限. 相比之下,高层图像结构的语义被忽略,但它们实际上是基于DNN的OD的关键. ...

  9. 单应性Homography梳理,概念解释,传统方法,深度学习方法

    Homography 这篇博客比较清晰准确的介绍了关于刚性变换,仿射变换,透视投影变换的理解 单应性变换 的 条件和表示 用 [无镜头畸变] 的相机从不同位置拍摄 [同一平面物体] 的图像之间存在单应 ...

  10. 图像配准:从SIFT到深度学习

    图像配准(Image Registration)是计算机视觉中的基本步骤.在本文中,我们首先介绍基于OpenCV的方法,然后介绍深度学习的方法. 什么是图像配准 图像配准就是找到一幅图像像素到另一幅图 ...

最新文章

  1. 单片机怎么学?新手怎么快速学会单片机?
  2. 软件开发过程中的思维方式 -- 如何分析问题
  3. Java集合源码解析之ArrayList
  4. android otg读取索尼相机usb_索尼新概念!即将上市全画幅无反相机α7C先睹为快
  5. 【9.28作业】论XX信息系统建设项目的范围管理
  6. 【转】C语言浮点数运算
  7. python redis模块常用_python redis 模块
  8. win7 正在锁定计算机 卡住,win7系统安装卡在正在启动windows界面的解决方法
  9. 使用nssm管理tomcat服务操作步骤
  10. java 地心坐标系(ECEF)和WGS-84坐标系(WGS84)互转
  11. matlab符号运算报错,matlab符号运算符
  12. java鼠标乱跑_光标乱跑怎么办 光标乱跑解决方法【图文】
  13. 曙光服务器如何重新设置u盘启动_u盘装曙光服务器 曙光服务器进bios设置u盘启动...
  14. 如何使用 哑节点(dummy node),高效解决问题
  15. Required request body is missing 报错解决
  16. 计算机金钱符号怎么打,€欧元符号怎么打出来?各种输入欧元的货币符号方法!...
  17. 轮循与连接-- 细雪之舞
  18. Linux需要学什么
  19. 苹果手机中病毒显示无服务器,iPhone手机真的不会“中毒”?出现这3个状况就要小心了...
  20. [转]告别写计划的烦恼!一页纸四步打造出一份牛逼的商业计划

热门文章

  1. <PCI-E> PCI-E的 x1/x4/x8/x16 四种插槽区别
  2. TIA Openness开发入门(1)
  3. php解决时间超过2038年
  4. DHCP与DHCP中继模式下获取IP地址
  5. 教你如何测试U盘读写速度?
  6. 既能被2又能被5整C语言,2012年国研究生统一考试心理学专业试题与答案
  7. 学计算机跨考航天航空,北京航空航天大学计算机考研辅导班:跨考考研经验
  8. 系列课程 ElasticSearch 之第 8 篇 —— SpringBoot 整合 ElasticSearch 做查询(分页查询)
  9. 自媒体学习教程 新手怎么开始学习自媒体
  10. 点击click触发两次事件解决办法