Overview

Let's first take a look at the model's final result.

Computer vision is developing at a breakneck pace, and "making old photos move" has amazed everyone with how far the field has come. How is this effect actually implemented? This post walks you through it step by step. Before reading on, some background knowledge is recommended:

1. The Taylor expansion and the Jacobian

2. Image up- and down-sampling

3. Backward optical flow

This series is split into several parts. This installment covers the keypoint detection model, whose job is to locate keypoints, decomposing an object's overall motion into a static background plus the motion of the keypoints.

Model structure

This post is a concrete PyTorch implementation of the first order motion model.

Import the required PyTorch libraries:

from torch import nn
import torch
import torch.nn.functional as F
from imageio import imread
import numpy as np
from torchvision import models

1. Model configuration

dataset_params:
  root_dir: data/vox-png
  frame_shape: [256, 256, 3]
  id_sampling: True
  pairs_list: data/vox256.csv
  augmentation_params:
    flip_param:
      horizontal_flip: True
      time_flip: True
    jitter_param:
      brightness: 0.1
      contrast: 0.1
      saturation: 0.1
      hue: 0.1

model_params:
  common_params:
    num_kp: 10
    num_channels: 3
    estimate_jacobian: True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25
    num_blocks: 5
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    num_bottleneck_blocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 64
      max_features: 1024
      num_blocks: 5
      scale_factor: 0.25
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True

train_params:
  num_epochs: 100
  num_repeats: 75
  epoch_milestones: [60, 90]
  lr_generator: 2.0e-4
  lr_discriminator: 2.0e-4
  lr_kp_detector: 2.0e-4
  batch_size: 40
  scales: [1, 0.5, 0.25, 0.125]
  checkpoint_freq: 50
  transform_params:
    sigma_affine: 0.05
    sigma_tps: 0.005
    points_tps: 5
  loss_weights:
    generator_gan: 0
    discriminator_gan: 1
    feature_matching: [10, 10, 10, 10]
    perceptual: [10, 10, 10, 10, 10]
    equivariance_value: 10
    equivariance_jacobian: 10

reconstruction_params:
  num_videos: 1000
  format: '.mp4'

animate_params:
  num_pairs: 50
  format: '.mp4'
  normalization_params:
    adapt_movement_scale: False
    use_relative_movement: True
    use_relative_jacobian: True

visualizer_params:
  kp_size: 5
  draw_border: True
  colormap: 'gist_rainbow'

Save the configuration above as a YAML file, then load it with the following code:

import yaml
with open("conf/vox-256.yaml") as f:config = yaml.load(f,yaml.FullLoader)

2. The keypoint extraction model (keypoint detector)

2.1 Image preprocessing: image scaling

AntiAlias Interpolation: the anti-aliased interpolation algorithm
In the original paper the default image size is 256×256, but for the sake of speed during model debugging, training is done on 64×64 images; this module is used to downscale the images accordingly.
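
For example, with scale = 0.25 the constructor below computes sigma = (1/0.25 − 1)/2 = 1.5 and kernel_size = 2·round(1.5·4) + 1 = 13, so the image is blurred with a 13×13 Gaussian before every 4th pixel is kept, reducing 256×256 to 64×64 without aliasing artifacts.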

2.1.1 AntiAliasInterpolation2d source code

class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.
    """
    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        sigma = (1 / scale - 1) / 2
        kernel_size = 2 * round(sigma * 4) + 1
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        kernel_size = [kernel_size, kernel_size]
        sigma = [sigma, sigma]
        # The gaussian kernel is the product of the gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))

        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale
        inv_scale = 1 / scale
        self.int_inv_scale = int(inv_scale)

    def forward(self, input):
        if self.scale == 1.0:
            return input

        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]

        return out

2.1.2 AntiAliasInterpolation2d test code

import matplotlib.pyplot as plt
%matplotlib notebook

predictor = AntiAliasInterpolation2d(3, 0.25)
imgdata = imread('test.jpg') / 255                          # HWC image scaled to [0, 1]
imagedata = torch.unsqueeze(torch.tensor(imgdata, dtype=torch.float32), 0)
imagedata = imagedata.permute([0, 3, 1, 2])                 # NHWC -> NCHW
x = outdata = predictor(imagedata)                          # keep `x` around for section 2.3.1
figure, ax = plt.subplots(1, 2)
ax[0].imshow(imgdata)
ax[1].imshow(outdata.permute([0, 2, 3, 1])[0])

The output looks like this:

2.2 Keypoint feature-map extraction: the Hourglass model

  • This model is used to extract keypoints from the source and driving images; its structure is shown in the figure below
  • Note that although it is called an Hourglass model, it differs from the hourglass networks in the literature; its structure is closer to a variant of U-Net
  • The model produces the raw keypoint information used by the detector

2.2.1 Model code

class UpBlock2d(nn.Module):
    """
    Upsampling block for use in decoder.
    """
    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(UpBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        # self.norm = BatchNorm2d(out_features, affine=True)  # synchronized version used in the original repo
        self.norm = torch.nn.BatchNorm2d(out_features, affine=True)

    def forward(self, x):
        out = F.interpolate(x, scale_factor=2)
        out = self.conv(out)
        out = self.norm(out)
        out = F.relu(out)
        return out


class DownBlock2d(nn.Module):
    """
    Downsampling block for use in encoder.
    """
    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
        super(DownBlock2d, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
                              padding=padding, groups=groups)
        # self.norm = BatchNorm2d(out_features, affine=True)  # synchronized version used in the original repo
        self.norm = torch.nn.BatchNorm2d(out_features, affine=True)
        self.pool = nn.AvgPool2d(kernel_size=(2, 2))

    def forward(self, x):
        out = self.conv(x)
        out = self.norm(out)
        out = F.relu(out)
        out = self.pool(out)
        return out


class Encoder(nn.Module):
    """
    Hourglass Encoder
    """
    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Encoder, self).__init__()
        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
                                           min(max_features, block_expansion * (2 ** (i + 1))),
                                           kernel_size=3, padding=1))
        self.down_blocks = nn.ModuleList(down_blocks)

    def forward(self, x):
        outs = [x]
        for down_block in self.down_blocks:
            outs.append(down_block(outs[-1]))
        return outs


class Decoder(nn.Module):
    """
    Hourglass Decoder
    """
    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Decoder, self).__init__()
        up_blocks = []
        for i in range(num_blocks)[::-1]:
            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
            out_filters = min(max_features, block_expansion * (2 ** i))
            up_blocks.append(UpBlock2d(in_filters, out_filters, kernel_size=3, padding=1))
        self.up_blocks = nn.ModuleList(up_blocks)
        self.out_filters = block_expansion + in_features

    def forward(self, x):
        out = x.pop()
        for up_block in self.up_blocks:
            out = up_block(out)
            skip = x.pop()
            out = torch.cat([out, skip], dim=1)
        return out


class Hourglass(nn.Module):
    """
    Hourglass architecture.
    """
    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
        super(Hourglass, self).__init__()
        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
        self.out_filters = self.decoder.out_filters

    def forward(self, x):
        return self.decoder(self.encoder(x))
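
To verify the channel bookkeeping, here is a minimal shape check, assuming the kp_detector_params from the config above (block_expansion=32, num_blocks=5, max_features=1024) and a 3-channel 64×64 input:

hg = Hourglass(block_expansion=32, in_features=3, num_blocks=5, max_features=1024)
print(hg.out_filters)                    # 35 = block_expansion + in_features
dummy = torch.randn(1, 3, 64, 64)
print(hg(dummy).shape)                   # torch.Size([1, 35, 64, 64])

The decoder concatenates each upsampled map with the corresponding encoder output (U-Net style), which is why the final channel count is block_expansion + in_features rather than just block_expansion.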

2.2.2 Visualizing the model with TensorBoard

In interactive mode you can expand the hourglass node further to inspect the concrete structure and confirm that the model matches the design.

Note

The BatchNorm2d used in the original paper is not the one built into PyTorch: under distributed training, each device's batch-norm statistics are computed only over the data running on that device rather than over the whole batch, so the third-party Synchronized-BatchNorm-PyTorch implementation is used instead. Its code is available at https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
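
In the original repository the swap looks roughly like this (assuming the package is vendored under the name sync_batchnorm, as in the linked repo); this is what the commented-out BatchNorm2d lines in the code above refer to:

from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d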

2.3 Extracting the 10 keypoints

  • The hourglass is the key model for obtaining keypoint feature maps; through it we can extract feature maps for all keypoints. Here we assume we want the 10 most important keypoints
  • A single convolutional layer maps the hourglass output to 10 keypoint channels, and a softmax over the spatial dimensions normalizes each channel into a confidence map

2.3.1 Key code and test code

# `predictor` is assumed here to be the Hourglass feature extractor from 2.2
# (e.g. Hourglass(block_expansion=32, in_features=3, num_blocks=5, max_features=1024)),
# and `x` the downsampled image from 2.1.2
kp = nn.Conv2d(in_channels=predictor.out_filters, out_channels=10, kernel_size=(7, 7), padding=0)
feature_map = predictor(x)
prediction = kp(feature_map)
final_shape = prediction.shape
heatmap = prediction.view(final_shape[0], final_shape[1], -1)   # flatten the spatial plane
heatmap = F.softmax(heatmap / 0.1, dim=2)                       # temperature = 0.1, per the config
heatmap = heatmap.view(*final_shape)
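
For reference, this is how the shapes evolve with the 64×64 input from 2.1.2 and the hourglass configured as in 2.2 (channel counts depend on those assumptions):

print(feature_map.shape)   # torch.Size([1, 35, 64, 64])
print(prediction.shape)    # torch.Size([1, 10, 58, 58]): a 7x7 conv with no padding shrinks 64 to 58
print(heatmap.shape)       # torch.Size([1, 10, 58, 58]), each channel now sums to 1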

2.3.2 Building a normalized local coordinate grid (R) to obtain Gaussian confidence maps

  • The function creates a mesh grid for the specified height and width
  • The grid values are uniformly distributed between -1 and 1
  • First, generate the grid coordinates with arange
  • Then, rescale the coordinates so they are uniformly spread from -1 to +1
  • Finally, tile the points to fill out the full grid of coordinates

Creating the normalized coordinate grid

def make_coordinate_grid(spatial_size, type):
    """
    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
    """
    h, w = spatial_size
    x = torch.arange(w).type(type)
    y = torch.arange(h).type(type)

    x = (2 * (x / (w - 1)) - 1)
    y = (2 * (y / (h - 1)) - 1)

    yy = y.view(-1, 1).repeat(1, w)
    xx = x.view(1, -1).repeat(h, 1)

    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
    return meshed


def gaussian2kp(heatmap):
    """
    Extract the mean from a heatmap
    """
    shape = heatmap.shape
    heatmap = heatmap.unsqueeze(-1)
    grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
    value = (heatmap * grid).sum(dim=(2, 3))
    kp = {'value': value}
    return kp
out = gaussian2kp(heatmap)

The output is as follows:

{'value': tensor([[[-0.4011, -0.4888],
         [-0.5129,  0.3301],
         [ 0.0890, -0.1418],
         [-0.0375,  0.1512],
         [-0.0831, -0.0603],
         [-0.4330,  0.4204],
         [ 0.0383,  0.0883],
         [-0.2365,  0.4034],
         [-0.1921,  0.3863],
         [-0.4409, -0.3465]]], grad_fn=<SumBackward1>)}
  • How this works:
  1. This function performs keypoint localization: a 10-channel convolution transforms the hourglass output (the feature map) into position information for 10 keypoints, each represented as a 58×58 map
  2. Taking an argmax over each 58×58 map would yield the keypoint's coordinates, but argmax is not differentiable, so the soft-argmax is used instead: a softmax over the flattened height×width plane turns the map into a probability distribution, and the expectation of the grid coordinates under that distribution gives differentiable coordinates (see the sketch below)
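
Condensed into one function, the idea looks like this. This is a minimal sketch of the soft-argmax performed by gaussian2kp above, not code from the original repo; soft_argmax_2d is a hypothetical helper name:

def soft_argmax_2d(heatmap, temperature=0.1):
    # Softmax over the flattened plane turns the map into a distribution;
    # the expectation of the normalized grid coordinates is then differentiable.
    h, w = heatmap.shape[-2:]
    prob = F.softmax(heatmap.reshape(-1, h * w) / temperature, dim=-1).reshape(-1, h, w, 1)
    grid = make_coordinate_grid((h, w), torch.float32)   # [h, w, 2], values in [-1, 1]
    return (prob * grid).sum(dim=(1, 2))                 # [N, 2] -> one (x, y) per map

Calling soft_argmax_2d(kp_test_data) on the fabricated feature point below reproduces the same (x, y) as the step-by-step computation that follows.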

The following code should help build intuition:

  1. Fabricate a feature point
  2. Apply softmax over the spatial (height × width) plane
  3. Visualize the result
# Fabricate a feature point: a 2x2 block of ones inside a 58x58 map
kp_test_data = torch.ones([2, 2])
kp_test_data = nn.functional.pad(kp_test_data, [46, 10, 20, 36])  # pad order: left, right, top, bottom
figure = plt.figure(figsize=(4, 4))
plt.imshow(kp_test_data, cmap='gray')

# Inspect the confidence-map distribution as a 3D surface
x = y = np.arange(0, 58)
X, Y = np.meshgrid(x, y)
figure = plt.figure(figsize=(5, 4))
ax3d = figure.add_subplot(projection='3d')
ax3d.plot_surface(X, Y, kp_test_data.numpy(), rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'))
ax3d.set_xlabel('x')
ax3d.set_ylabel('y')
ax3d.set_zlabel('z')
plt.show()

  • The output after the softmax
  • Note that softmax normalizes along a single dimension, so to apply it over the whole plane the two spatial dimensions must first be flattened into one; otherwise the softmax would act only along rows or columns
  • We therefore convert the data into PyTorch's standard 4-D layout [batch, channel, height, width]
  • Note that scaling the raw data by a factor of 10 (i.e., dividing by the temperature of 0.1) is important for observing the result; see the quick check after the plot below
# Flatten the plane, apply the temperature softmax, then restore a 4-D shape
rowdata = kp_test_data.view(-1).unsqueeze(0).unsqueeze(0)
softmax_rawdata = F.softmax(rowdata / 0.1, dim=2).view(1, 1, 58, 58)

x = y = np.arange(0, 58)
X, Y = np.meshgrid(x, y)
figure = plt.figure(figsize=(5, 4))
ax3d = figure.add_subplot(projection='3d')
ax3d.plot_surface(X, Y, softmax_rawdata.view([58, 58]).numpy(), rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'))
ax3d.set_xlabel('x')
ax3d.set_ylabel('y')
ax3d.set_zlabel('z')
plt.show()
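
A quick numeric check of why the ×10 matters; without the temperature the distribution is nearly flat (the peak values follow from exp(1) at the four ones vs. 3360 background cells at exp(0), compared with exp(10) after scaling):

flat = F.softmax(rowdata, dim=2)
sharp = F.softmax(rowdata / 0.1, dim=2)
print(flat.max().item(), sharp.max().item())   # ~0.0008 vs. ~0.24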

  • Observe the effect of applying the coordinate grid
  • The keypoint coordinate system runs from -1 to 1, and the two resulting values give the approximate coordinates of the keypoint; see the code below
grid = make_coordinate_grid([58, 58], torch.float32)
heatmap1 = softmax_rawdata.view([58, 58, 1])
# The [58, 58, 2] product must be permuted (not view-ed) into [2, 58, 58];
# a raw view would scramble the x and y channels together
landmark = (heatmap1 * grid).permute([2, 0, 1])
value = landmark.sum(dim=(1, 2))
print(value)

With the corrected indexing this prints approximately

tensor([ 0.61, -0.27])

The x coordinate is clearly to the right and the y coordinate is above center, which matches the relative position of the fabricated point in the original image.
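
Putting the pieces of this chapter together, here is a minimal end-to-end sketch of the keypoint detector, a simplification of the repo's KPDetector (anti-aliased downscale → hourglass → 7×7 conv → soft-argmax), with module parameters taken from the config above:

down = AntiAliasInterpolation2d(3, 0.25)
hg = Hourglass(block_expansion=32, in_features=3, num_blocks=5, max_features=1024)
kp_conv = nn.Conv2d(hg.out_filters, 10, kernel_size=(7, 7), padding=0)

img = torch.randn(1, 3, 256, 256)        # stand-in for a real frame
fmap = hg(down(img))                     # [1, 35, 64, 64]
pred = kp_conv(fmap)                     # [1, 10, 58, 58]
shape = pred.shape
hm = F.softmax(pred.view(shape[0], shape[1], -1) / 0.1, dim=2).view(*shape)
print(gaussian2kp(hm)['value'].shape)    # torch.Size([1, 10, 2]): 10 (x, y) keypoints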

2.4 Obtaining the Jacobian matrix of the transformation