Contents

  • Preface
  • 1. config.py
  • 2. datalist.py
  • 3. common.py
  • 4. model.py
  • 5. model_common.py
  • 6. train.py
  • Summary

Preface

This is the Onion Peel Network algorithm, which I found on GitHub. The developer only released the demo part, so I tried to implement the training part myself, and as it stands it does show some ability to complete missing regions. It is not very mature yet, but I am publishing it anyway for everyone to look at. As usual, I kept the code strictly in the style of my earlier posts, to make it easier to study and to optimize later.

1. config.py

import argparse

parser = argparse.ArgumentParser(description="Onion Peel Network")
parser.add_argument('--project_name', type=str, default="video completion by Onion Peel Network", help='project name')
parser.add_argument("--use_cuda", type=bool, default=True, help="whether to use cuda")
parser.add_argument("--seed", type=int, default=123, help="random seed")
parser.add_argument("--resume", type=bool, default=True, help="whether to load the model from pretrained weights")
parser.add_argument("--pretrained_weight", type=str, default='OPN.pth', help="path to the pretrained weights")
parser.add_argument("--lr", type=float, default=0.0001, help="learning rate")
parser.add_argument("--weight_decay", type=float, default=1e-4, help="weight decay coefficient")
parser.add_argument("--momentum", type=float, default=0.5, help="momentum coefficient")
parser.add_argument("--epoch", type=int, default=10, help="number of training epochs")
parser.add_argument("--train_batch_size", type=int, default=1, help="training batch_size")
parser.add_argument("--test_batch_size", type=int, default=1, help="test batch_size")
parser.add_argument("--save", type=bool, default=True, help="save result images")

This is the project's configuration file, the same as in my earlier posts.
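One caveat worth flagging here (my own note, not from the original config): argparse's type=bool does not do what it looks like. Any non-empty string, including "False", parses as True, so flags like --use_cuda cannot actually be switched off from the command line. A minimal sketch of the usual workaround:

import argparse

def str2bool(v):
    # accept the common spellings of true/false on the command line
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected')

parser = argparse.ArgumentParser()
parser.add_argument("--use_cuda", type=str2bool, default=True)
args = parser.parse_args(['--use_cuda', 'False'])
print(args.use_cuda)  # False, as intended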

2. datalist.py

import os
import random

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset

T, H, W = 5, 240, 424


class Dataset(Dataset):  # note: this name shadows the imported torch Dataset; kept as in the original
    def __init__(self, type='train'):
        self.type = type

    def __len__(self):
        return len(os.listdir('Image_inputs/W')) // 2

    def __getitem__(self, index):
        if index >= len(self) - 5:
            index = index - 5
        print(index)
        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)
        for i in range(5):
            # rgb
            img_file = os.path.join('Image_inputs', 'W', '{:04d}.jpg'.format(index + i))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            # mask
            mask_file = os.path.join('Image_inputs', 'W', '{:04d}.png'.format(index + i))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))  # cv2.dilate: grow the hole slightly
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # dists: for every hole pixel, the L2 distance to the nearest valid pixel
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2, maskSize=5)
        # convert the arrays to tensors, channel first: (C, T, H, W)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
        # remove holes: cut the holes out of the frames and fill them with the
        # ImageNet mean [0.485, 0.456, 0.406] (the original had a typo, 0.4585)
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        # valid area: the pixels that keep their original values, used later for the loss
        valids = 1 - holes
        return frames, valids, dists


class Generator(object):
    def __init__(self, batch_size=1):
        self.batch_size = batch_size
        self.images = os.listdir("Image_inputs")

    def generator(self):
        # while True:
        dir = self.images[random.choice(range(len(self.images)))]
        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)
        label = np.empty((T, H, W, 3), dtype=np.float32)
        for i in range(5):
            # rgb
            img_file = os.path.join('Image_inputs', dir, 'gt_{:1d}.jpg'.format(i))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            label[i] = raw_frame
            # mask
            mask_file = os.path.join('Image_inputs', dir, 'mask_{:1d}.png'.format(i))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # dists
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2, maskSize=5)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
        label = torch.from_numpy(np.transpose(label, (3, 0, 1, 2)).copy()).float()
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        valids = 1 - holes
        frames = frames.unsqueeze(0)
        dists = dists.unsqueeze(0)
        valids = valids.unsqueeze(0)
        label = label.unsqueeze(0)
        yield frames, valids, dists, label


class Dataset2(Dataset):
    def __init__(self):
        self.images = os.listdir("image")

    def __len__(self):
        return 1

    def __getitem__(self, index):
        dir = self.images[random.choice(range(len(self.images)))]
        frames = np.empty((T, H, W, 3), dtype=np.float32)
        holes = np.empty((T, H, W, 1), dtype=np.float32)
        dists = np.empty((T, H, W, 1), dtype=np.float32)
        label = np.empty((T, H, W, 3), dtype=np.float32)
        for i in range(5):
            # rgb: pick a random ground-truth frame from the chosen directory
            img_file = os.path.join('image', dir, 'gt_{:04d}.jpg'.format(
                random.choice(range(len(os.listdir(os.path.join('image', dir)))))))
            raw_frame = np.array(Image.open(img_file).convert('RGB')) / 255.
            raw_frame = cv2.resize(raw_frame, dsize=(W, H), interpolation=cv2.INTER_CUBIC)
            frames[i] = raw_frame
            label[i] = raw_frame
            # mask: pick one of the 52 masks at random
            mask_file = os.path.join('mask', 'mask_{:04d}.png'.format(random.choice(range(0, 52))))
            raw_mask = np.array(Image.open(mask_file).convert('P'), dtype=np.uint8)
            raw_mask = (raw_mask > 0.5).astype(np.uint8)
            raw_mask = cv2.resize(raw_mask, dsize=(W, H), interpolation=cv2.INTER_NEAREST)
            raw_mask = cv2.dilate(raw_mask, cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)))
            holes[i, :, :, 0] = raw_mask.astype(np.float32)
            # dists
            dists[i, :, :, 0] = cv2.distanceTransform(raw_mask, cv2.DIST_L2, maskSize=5)
        frames = torch.from_numpy(np.transpose(frames, (3, 0, 1, 2)).copy()).float()
        holes = torch.from_numpy(np.transpose(holes, (3, 0, 1, 2)).copy()).float()
        dists = torch.from_numpy(np.transpose(dists, (3, 0, 1, 2)).copy()).float()
        label = torch.from_numpy(np.transpose(label, (3, 0, 1, 2)).copy()).float()
        frames = frames * (1 - holes) + holes * torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        valids = 1 - holes
        return frames, valids, dists, label

This is the data part. Since the requirements changed, I modified it slightly. What is actually used is the Generator part; the Dataset classes before it are failed attempts. A small sketch of the mask pipeline follows.
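To make the dilate + distanceTransform steps above concrete, here is a small self-contained sketch on a synthetic mask (my own toy example, not project data):

import cv2
import numpy as np

mask = np.zeros((240, 424), dtype=np.uint8)
mask[100:140, 200:260] = 1  # a 40x60 rectangular hole

kernel = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
mask = cv2.dilate(mask, kernel)  # grow the hole by one pixel, as in the loader

# for every hole pixel: L2 distance to the nearest non-hole pixel
dist = cv2.distanceTransform(mask, cv2.DIST_L2, maskSize=5)
print(dist.max())  # depth of the hole; about 21 px here

The dist map is what drives the onion peeling: each read() call in model.py fills the ring where 0 < dist <= thickness and subtracts thickness from the rest.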

3. common.py

import torch
import torch.nn as nn


def get_features(img, model, layers=None):
    '''collect feature maps from the chosen VGG layers'''
    if layers is None:
        layers = {
            '0': 'conv1_1',
            '5': 'conv2_1',
            '10': 'conv3_1',
            '19': 'conv4_1',
            '21': 'conv4_2',  # content layer
            '28': 'conv5_1'
        }
    features = {}
    x = img
    for name, layer in model._modules.items():
        x = layer(x)
        if name in layers:
            features[layers[name]] = x
    return features


def gram_matrix(tensor):
    '''compute the Gram matrix'''
    _, d, h, w = tensor.size()  # the first dim is the batch_size
    tensor = tensor.view(d, h * w)
    gram = torch.mm(tensor, tensor.t())
    return gram


'''
TV loss is a common regularization term (a regularizer, used together with other
losses to constrain noise). Differences between neighboring pixels can be reduced
to some extent by lowering the TV loss. A tiny bit of noise in the image can have
a large effect on the restored result, because many restoration algorithms amplify
noise, so we add a regularization term to the objective to keep the image smooth.
'''


class TVLoss(nn.Module):
    def __init__(self, TVLoss_weight=1):
        super(TVLoss, self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self._tensor_size(x[:, :, 1:, :])  # how many differences were taken
        count_w = self._tensor_size(x[:, :, :, 1:])
        # x[:,:,1:,:] - x[:,:,:h_x-1,:] shifts the image against itself by one pixel:
        # the first copy starts at pixel 1, the second at pixel 0 and stops one short,
        # so the difference is between each pixel and its neighbor below.
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    def _tensor_size(self, t):
        return t.size()[1] * t.size()[2] * t.size()[3]


class L1_Loss(nn.Module):
    def __init__(self):
        super(L1_Loss, self).__init__()

    def forward(self, x, y, pv):
        loss = 0
        if pv.ndim > 4:
            pv = pv.squeeze(dim=0)
            for i in range(pv.size(1)):
                loss += torch.sum(torch.abs(x - y) * pv[:, i, :, :])
            return loss
        else:
            # y.shape is indexed, not called (the original had y.shape(0), a bug)
            loss = torch.sum(torch.abs(x - y) * pv) / (y.shape[0] * y.shape[1] * y.shape[2])
            return loss


class L1_Lossv2(nn.Module):
    def __init__(self):
        super(L1_Lossv2, self).__init__()

    def forward(self, x, y, pv):
        loss = 0
        if pv.ndim > 4:
            for i in range(pv.size()[1]):
                temp = pv[:, :, i]  # note: indexes dim 2 while ranging over dim 1, as in the original
                loss = loss + torch.sum(torch.abs(x.flatten() - y.flatten()) * temp.flatten())
            return loss
        else:
            loss = torch.sum(torch.abs(x.flatten() - y.flatten()) * pv.flatten())
            return loss


def L1(x, y, mask):
    res = torch.abs(x - y)
    res = res * mask
    return torch.sum(res) / (y.shape[0] * y.shape[1] * y.shape[2])


def ll1(x, y):
    # signed sum of differences, kept as in the original
    return torch.sum(x - y)

This file is mostly losses I wrote myself. The paper uses quite a few of them and related blog posts are scarce, so I spent a lot of time testing; being able to write your own losses really matters. Top-conference papers keep adding new tricks, no longer limited to network architecture design that you can just read off. Honestly it is getting hard to keep up; weren't the days when a plain CNN ruled everything simpler?
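As a quick sanity check of the masked L1 above (my own toy numbers, not from the post), the loss should count only the masked region and normalize by batch x channels x height:

import torch

x = torch.zeros(1, 3, 8, 8)  # prediction
y = torch.ones(1, 3, 8, 8)   # target
mask = torch.zeros(1, 1, 8, 8)
mask[..., :4, :] = 1         # only the top half counts

# same computation as L1() above: 96 masked unit-errors / (1 * 3 * 8) = 4.0
res = torch.abs(x - y) * mask
print(torch.sum(res) / (y.shape[0] * y.shape[1] * y.shape[2]))  # tensor(4.)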

4. model.py

from __future__ import division
# general libs
import math
import sys

sys.path.insert(0, '.')
from .common import *

sys.path.insert(0, '../utils/')
from utils.helpers import *


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv12 = GatedConv2d(5, 64, kernel_size=5, stride=2, padding=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 2
        self.conv2 = GatedConv2d(64, 64, kernel_size=3, stride=1, padding=1,
                                 activation=nn.LeakyReLU(negative_slope=0.2))  # 2
        self.conv23 = GatedConv2d(64, 128, kernel_size=3, stride=2, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3a = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3b = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=2, dilation=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3c = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=4, dilation=4,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3d = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=8, dilation=8,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.key3 = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)  # 4
        self.val3 = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1, activation=None)  # 4

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, in_f, in_v, in_h):
        # normalize the frames
        f = (in_f - self.mean) / self.std
        x = torch.cat([f, in_v, in_h], dim=1)
        x = self.conv12(x)
        x = self.conv2(x)
        x = self.conv23(x)
        x = self.conv3a(x)
        x = self.conv3b(x)
        x = self.conv3c(x)
        x = self.conv3d(x)
        k = self.key3(x)
        v = self.val3(x)
        return k, v


# the asymmetric attention block
class MaskedRead(nn.Module):
    def __init__(self):
        super(MaskedRead, self).__init__()

    def forward(self, qkey, qval, qmask, mkey, mval, mmask):
        '''
        read for *mask area* of query from *mask area* of memory
        '''
        B, Dk, _, H, W = mkey.size()
        _, Dv, _, _, _ = mval.size()
        # key: b,dk,t,h,w
        # value: b,dv,t,h,w
        # mask: b,1,t,h,w
        for b in range(B):
            # exceptions
            if qmask[b, 0].sum() == 0 or mmask[b, 0].sum() == 0:
                # no query or memory pixels -> skip the read
                continue
            qk_b = qkey[b, :, qmask[b, 0]]  # dk, Nq
            mv_b = mval[b, :, mmask[b, 0]]  # dv, Nm
            mk_b = mkey[b, :, mmask[b, 0]]  # dk, Nm; mkey (1,128,4,60,106), mmask (1,1,5,60,106)
            p = torch.mm(torch.transpose(mk_b, 0, 1), qk_b)  # Nm, Nq
            p = p / math.sqrt(Dk)  # scaled dot-product attention normalization
            p = torch.softmax(p, dim=0)
            read = torch.mm(mv_b, p)  # dv, Nq
            qval[b, :, qmask[b, 0]] = qval[b, :, qmask[b, 0]] + read  # dv, Nq
        return qval


class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.conv3d = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=8, dilation=8,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3c = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=4, dilation=4,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3b = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=2, dilation=2,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv3a = GatedConv2d(128, 128, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 4
        self.conv32 = GatedConv2d(128, 64, kernel_size=3, stride=1, padding=1,
                                  activation=nn.LeakyReLU(negative_slope=0.2))  # 2
        self.conv2 = GatedConv2d(64, 64, kernel_size=3, stride=1, padding=1,
                                 activation=nn.LeakyReLU(negative_slope=0.2))  # 2
        self.conv21 = GatedConv2d(64, 3, kernel_size=5, stride=1, padding=2, activation=None)  # 1

        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        x = self.conv3d(x)
        x = self.conv3c(x)
        x = self.conv3b(x)
        x = self.conv3a(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # 2
        x = self.conv32(x)
        x = self.conv2(x)
        x = F.interpolate(x, scale_factor=2, mode='nearest')  # 2
        x = self.conv21(x)
        p = (x * self.std) + self.mean  # de-normalize back to [0, 1] RGB
        return p


class OPN(nn.Module):
    def __init__(self, mode='Train', CPU_memory=False, thickness=8):
        super(OPN, self).__init__()
        self.Encoder = Encoder()
        self.MaskedRead = MaskedRead()
        self.Decoder = Decoder()
        self.thickness = thickness
        self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('mean3d', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1, 1))

    def memorize(self, frames, valids, dists):
        '''
        encode every frame of the *valid* area into key:value;
        done once as initialization
        '''
        # pad everything to a size divisible by 4
        (frames, valids, dists), pad = pad_divide_by([frames, valids, dists], 4,
                                                     (frames.size()[3], frames.size()[4]))
        # make hole
        holes = (dists > 0).float()
        frames = (1 - holes) * frames + holes * self.mean3d
        batch_size, _, num_frames, height, width = frames.size()
        # num_frames: the number of images, 5 here
        # encoding...
        key_ = []
        val_ = []
        for t in range(num_frames):
            key, val = self.Encoder(frames[:, :, t], valids[:, :, t], holes[:, :, t])
            key_.append(key)
            val_.append(val)
        keys = torch.stack(key_, dim=2)
        vals = torch.stack(val_, dim=2)
        hols = (F_upsample3d(holes, size=(int(height / 4), int(width / 4)),
                             mode='bilinear', align_corners=False) > 0)
        return keys, vals, hols

    def read(self, mkey, mval, mhol, frame, valid, dist):
        '''
        assume a single-frame query:
        1) encode the current status of the frame -> query
        2) read from the memories (computed by calling 'memorize')
        3) decode the read feature
        4) compute the loss on the peel area
        '''
        thickness = self.thickness
        # padding
        (frame, valid, dist), pad = pad_divide_by([frame, valid, dist], 4,
                                                  (frame.size()[2], frame.size()[3]))
        batch_size, _, height, width = frame.size()
        # make hole and peel..
        hole = (dist > 0).float()
        peel = hole * (dist <= thickness).float()
        # update dist
        next_dist = torch.clamp(dist - thickness, 0, 9999)
        # get the 1/4-scale peel mask
        peel3 = (F.interpolate(peel, size=(int(height / 4), int(width / 4)),
                               mode='bilinear', align_corners=False) >= 0.5)
        # update frame
        frame = (1 - hole) * frame + hole * self.mean
        # reading and decoding...
        qkey, qval = self.Encoder(frame, valid, hole)
        qpel = peel3
        # read: the asymmetric attention block
        read = self.MaskedRead(qkey, qval, qpel, mkey, mval, ~mhol)
        # decode
        pred = self.Decoder(read)
        comp = (1 - peel) * frame + peel * pred  # fill the peel area
        if pad[2] + pad[3] > 0:
            comp = comp[:, :, pad[2]:-pad[3], :]
            next_dist = next_dist[:, :, pad[2]:-pad[3], :]
        if pad[0] + pad[1] > 0:
            comp = comp[:, :, :, pad[0]:-pad[1]]
            next_dist = next_dist[:, :, :, pad[0]:-pad[1]]
        # keep the color channels from overflowing
        comp = torch.clamp(comp, 0, 1)
        return comp, next_dist, peel

    def forward(self, *args, **kwargs):
        if len(args) == 3:
            return self.memorize(*args)
        else:
            return self.read(*args, **kwargs)

This is the part described in the paper. Frankly, the technique is impressive and not something an ordinary person could write from scratch, so I will say it plainly: this part is copied wholesale from the official code.
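To show how the two entry points of OPN fit together, here is a dry-run sketch with dummy tensors (the shapes are my reading of the code above, under the assumption that model.py and its helpers import cleanly; not from the original repo):

import torch

model = OPN().eval()
B, T, H, W = 1, 5, 240, 424
frames = torch.rand(B, 3, T, H, W)
valids = torch.ones(B, 1, T, H, W)
dists = torch.zeros(B, 1, T, H, W)  # an all-zero dist map means nothing to inpaint

with torch.no_grad():
    # 3 positional args -> forward() dispatches to memorize():
    # every frame is encoded into key/value memory at 1/4 resolution
    mkey, mval, mhol = model(frames, valids, dists)  # keys/vals: (1, 128, 5, 60, 106)
    # 6 args -> read(): peel one 'thickness'-pixel ring of the hole in frame 0
    comp, next_dist, peel = model(mkey, mval, mhol,
                                  frames[:, :, 0], valids[:, :, 0], dists[:, :, 0])
    # in practice, read() is looped, feeding comp/next_dist back in, until
    # next_dist.sum() == 0; that is exactly what train.py below does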

5. model_common.py

from __future__ import division
# general libs
import sys

import torch.nn.functional as F

sys.path.insert(0, '../utils/')
from utils.helpers import *

##########################################
############   Generic   #################
##########################################


def pad_divide_by(in_list, d, in_size):
    out_list = []
    h, w = in_size
    if h % d > 0:
        new_h = h + d - h % d
    else:
        new_h = h
    if w % d > 0:
        new_w = w + d - w % d
    else:
        new_w = w
    lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
    lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
    pad_array = (int(lw), int(uw), int(lh), int(uh))
    for inp in in_list:
        out_list.append(F.pad(inp, pad_array))
    return out_list, pad_array


class ConvGRU(nn.Module):
    def __init__(self, mdim, kernel_size=3, padding=1):
        super(ConvGRU, self).__init__()
        self.convIH = nn.Conv2d(mdim, 3 * mdim, kernel_size=kernel_size, padding=padding)
        self.convHH = nn.Conv2d(mdim, 3 * mdim, kernel_size=kernel_size, padding=padding)

    def forward(self, input, hidden_tm1):
        if hidden_tm1 is None:
            hidden_tm1 = torch.zeros_like(input)
        gi = self.convIH(input)
        gh = self.convHH(hidden_tm1)
        i_r, i_i, i_n = torch.chunk(gi, 3, dim=1)
        h_r, h_i, h_n = torch.chunk(gh, 3, dim=1)
        resetgate = torch.sigmoid(i_r + h_r)  # reset
        inputgate = torch.sigmoid(i_i + h_i)  # update
        newgate = torch.tanh(i_n + resetgate * h_n)
        # hidden_t = inputgate * hidden_tm1 + (1 - inputgate) * newgate
        hidden_t = newgate + inputgate * (hidden_tm1 - newgate)
        return hidden_t


def F_upsample3d(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    num_frames = x.size()[2]
    up_s = []
    for f in range(num_frames):
        up = F.interpolate(x[:, :, f], size=size, scale_factor=scale_factor,
                           mode=mode, align_corners=align_corners)
        up_s.append(up)
    ups = torch.stack(up_s, dim=2)
    return ups


def F_upsample(x, size=None, scale_factor=None, mode='nearest', align_corners=None):
    if x.dim() == 5:  # 3d
        return F_upsample3d(x, size=size, scale_factor=scale_factor, mode=mode,
                            align_corners=align_corners)
    else:
        return F.interpolate(x, size=size, scale_factor=scale_factor, mode=mode,
                             align_corners=align_corners)


class GatedConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, activation=None):
        super().__init__()
        self.input_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                    stride, padding, dilation, groups, bias)
        self.gating_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                     stride, padding, dilation, groups, bias)
        init_He(self)
        self.activation = activation

    def forward(self, input):
        # O = act(Feature) * sig(Gating)
        feature = self.input_conv(input)
        if self.activation:
            feature = self.activation(feature)
        gating = torch.sigmoid(self.gating_conv(input))
        return feature * gating

These are the model's common building blocks, also copied from the official code.
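pad_divide_by is the helper that makes the encoder's two stride-2 convolutions safe for arbitrary input sizes. A toy run with an awkward size (my own example, assuming the file above is on the path):

import torch

x = torch.rand(1, 3, 241, 423)  # H and W not divisible by 4
(out,), pad = pad_divide_by([x], 4, (241, 423))
print(out.shape)  # torch.Size([1, 3, 244, 424])
print(pad)        # (0, 1, 1, 2) -> (left, right, top, bottom)
# OPN.read() keeps this pad tuple so it can crop the completed frame back afterwards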

6. train.py

import os
import sys
from tqdm import tqdm
import torch.backends.cudnn as cudnn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import models

from common import *
from datalist import Dataset2
from models.OPN import OPN
from utils.helpers import *

sys.path.append('utils/')
sys.path.append('models/')

style_weights = {
    'conv1_1': 1,
    'conv2_1': 0.8,
    'conv3_1': 0.5,
    'conv4_1': 0.3,
    'conv5_1': 0.1,
}

from config import parser


class train(object):
    def __init__(self):
        self.args = parser.parse_args()
        print(f"-----------{self.args.project_name}-----------")
        use_cuda = self.args.use_cuda and torch.cuda.is_available()
        if use_cuda:
            torch.cuda.manual_seed(self.args.seed)
        else:
            torch.manual_seed(self.args.seed)
        self.device = torch.device("cuda" if use_cuda else "cpu")
        kwargs = {'num_workers': 0, 'pin_memory': True} if use_cuda else {}

        '''build the DataLoaders'''
        # ToDo: the dataset needs to be rebuilt
        print("Create Dataloader")
        self.train_loader = DataLoader(Dataset2(), batch_size=1, shuffle=True, **kwargs)
        self.test_loader = DataLoader(Dataset2(), batch_size=1, shuffle=True, **kwargs)

        '''define the model'''
        print("Create Model")
        self.model = OPN().to(self.device)
        # self.model = nn.DataParallel(OPN())
        if use_cuda:
            # self.model = self.model.cuda()
            cudnn.benchmark = True

        '''load pretrained weights as needed'''
        # a pretrained VGG16 is used for the perceptual (content/style) losses
        self.vgg = models.vgg16(pretrained=True).to(self.device).features
        for i in self.vgg.parameters():
            i.requires_grad = False
        try:
            if self.args.resume and self.args.pretrained_weight:
                self.model.load_state_dict(torch.load(os.path.join('OPN.pth')), strict=False)
                print("model loaded")
        except:
            print("failed to load the model")

        '''cuda acceleration'''
        if use_cuda:
            # self.model = nn.DataParallel(self.model, device_ids=range(torch.cuda.device_count()))
            cudnn.benchmark = True

        '''
        build the loss functions,
        choose the optimizer,
        choose the learning-rate schedule
        '''
        print("Establish the loss, optimizer and learning_rate function")
        self.loss_tv = TVLoss()
        self.loss_l1 = L1_Loss()
        # there are also a style loss and a content loss below
        # self.optimizer = optim.SGD(
        #     params=self.model.parameters(),
        #     lr=self.args.lr,
        #     weight_decay=self.args.weight_decay,
        #     momentum=0.5
        # )
        self.optimizer = optim.Adam(
            params=self.model.parameters(),
            lr=0.001,
            betas=(0.9, 0.999),
            eps=1e-8,  # keeps the denominator away from 0
            weight_decay=0
        )
        # self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=5, eta_min=1e-5)

        '''start training'''
        print("Start training")
        for epoch in tqdm(range(1, self.args.epoch + 1)):
            self.train(epoch)
            if epoch % 20 == 0:  # note: with the default epoch=10 this never fires
                self.test(epoch)
            torch.cuda.empty_cache()
        print("finish model training")

    def train(self, epoch):
        self.model.train()
        for data in self.train_loader:
            self.content_loss = 0
            self.style_loss = 0
            midx = list(range(0, 5))
            # frames: the corrupted images; valids: the usable pixel area; dists: the area to fill
            frames, valids, dists, label = data
            frames, valids, dists, label = (frames.to(self.device), valids.to(self.device),
                                            dists.to(self.device), label.to(self.device))
            # every frame has been encoded; key and val have shape (1,128,5,60,106), hol is (1,1,5,60,106)
            mkey, mval, mhol = self.model(frames[:, :, midx], valids[:, :, midx], dists[:, :, midx])
            allloss = 0
            for f in range(5):
                loss = 0
                # for each frame, take the other 4 frames as references
                ridx = [i for i in range(len(midx)) if i != f]
                fkey, fval, fhol = mkey[:, :, ridx], mval[:, :, ridx], mhol[:, :, ridx]
                # image completion
                for r in range(5):
                    if r == 0:
                        # take the main frame
                        comp = frames[:, :, f]
                        dist = dists[:, :, f]
                    # comp is the corrupted image, completed layer by layer;
                    # valids marks the area without missing information;
                    # dist marks the missing area.
                    # Guided by dist, the image is repaired 8 pixels of depth at a
                    # time; comp builds on frame, so only the peel area enters the loss.
                    comp, dist, peel = self.model(fkey, fval, fhol, comp, valids[:, :, f], dist)
                    # at every step, minimize the L1 distance to the GT in both pixel
                    # space and deep feature space
                    loss += 100 * L1(comp, label[:, :, f], peel)
                    # loss += L1(comp, label[:, :, f], valids[:, :, f])
                    loss += 0.2 * self.loss_l1(comp, label[:, :, f], valids[:, :, midx])
                    # loss += 100 * ll1(comp, frames[:, :, f])
                    # content loss
                    content_features = get_features(frames[:, :, f], self.vgg)
                    target_features = get_features(comp, self.vgg)
                    self.content_loss = torch.mean(torch.abs(
                        target_features['conv4_2'] - content_features['conv4_2']))
                    loss = loss + 0.05 * self.content_loss
                    # style loss (note: the style reference is taken from the ground
                    # truth here; the original took it from comp, which compares comp
                    # with itself and is always zero)
                    style_features = get_features(label[:, :, f], self.vgg)
                    style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}
                    # add the Gram-matrix loss of every chosen layer
                    for layer in style_weights:
                        target_feature = target_features[layer]
                        target_gram = gram_matrix(target_feature)
                        _, d, h, w = target_feature.shape
                        style_gram = style_grams[layer]
                        layer_style_loss = style_weights[layer] * torch.mean(
                            torch.abs(target_gram - style_gram))
                        self.style_loss += layer_style_loss / (d * h * w)
                    loss = loss + 120 * self.style_loss
                    # tv loss
                    loss += 0.01 * self.loss_tv(comp)
                allloss += loss
            self.optimizer.zero_grad()
            allloss.backward()
            self.optimizer.step()
            # self.scheduler.step()
            # print("epoch{}".format(epoch) + "  loss:{}".format(loss.cpu()))

    def test(self, epoch):
        self.model.eval()
        for frames, valids, dists, _ in self.test_loader:
            midx = list(range(0, 5))
            frames, valids, dists = frames.to(self.device), valids.to(self.device), dists.to(self.device)
            with torch.no_grad():
                # encode all 5 frames first
                mkey, mval, mhol = self.model(frames[:, :, midx], valids[:, :, midx], dists[:, :, midx])
            # for each frame, take the other 4 frames as references
            for f in range(5):
                ridx = [i for i in range(len(midx)) if i != f]
                fkey, fval, fhol = mkey[:, :, ridx], mval[:, :, ridx], mhol[:, :, ridx]
                # image completion
                for r in range(999):
                    if r == 0:
                        comp = frames[:, :, f]
                        dist = dists[:, :, f]
                    with torch.no_grad():
                        comp, dist, peel = self.model(fkey, fval, fhol, comp, valids[:, :, f], dist)
                        comp, dist = comp.detach(), dist.detach()
                    # once the hole is fully filled, save the image and move on to the next frame
                    if torch.sum(dist).item() == 0:
                        break
                if self.args.save:
                    # visualize..
                    est = (comp[0].permute(1, 2, 0).detach().cpu().numpy() * 255.).astype(np.uint8)
                    true = (frames[0, :, f].permute(1, 2, 0).detach().cpu().numpy() * 255.).astype(np.uint8)  # h,w,3
                    mask = (dists[0, 0, f].detach().cpu().numpy() > 0).astype(np.uint8)  # h,w
                    ov_true = overlay_davis(true, mask, colors=[[0, 0, 0], [100, 100, 0]], cscale=2, alpha=0.4)
                    canvas = np.concatenate([ov_true, est], axis=0)
                    save_path = os.path.join('Results')
                    if not os.path.exists(save_path):
                        os.makedirs(save_path)
                    canvas = Image.fromarray(canvas)
                    canvas.save(os.path.join(save_path, 'res_{}_{}.jpg'.format(epoch, f)))
            # print("epoch{}".format(epoch) + " test finished")


if __name__ == "__main__":
    train()

This is the train script I wrote myself, borrowing heavily from the demo code. In my own quick test, the images the training produces are not great yet, but the layering and the distribution of color blocks are starting to feel right.
As the authors note in the paper, hardware like a V100 is not something ordinary households have; these days even a 2080 feels inadequate. I could only run this because I happen to have access to the equipment.
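One gap worth pointing out: the loop above never writes the trained weights to disk, so nothing survives the run. A minimal addition (my own sketch, not in the original script) inside the epoch loop in __init__ would be:

# after self.train(epoch), inside the epoch loop:
torch.save(self.model.state_dict(), 'OPN_epoch{:03d}.pth'.format(epoch))
# the saved file can then be reloaded through the existing --resume /
# --pretrained_weight path via load_state_dict(..., strict=False)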


Summary

The results above come from only half an hour of training, whereas the paper uses a learning-rate scheduler with batch_size=5 and trains for four days. Even so, the output shows the model can roughly estimate what the contour of the missing region should look like. I think this is worth pursuing further; a full reproduction should be achievable.
