虚拟试穿

简介:本文梳理虚拟试穿算法框架结构,展示模特虚拟试穿上衣的效果,细说设计流程的详细步骤,提供相应的数据资源。

  1. 算法仓库:https://github.com/beauthy/DeepFashion_Try_On
    github上不了,就访问:码云:虚拟试穿上衣测试:https://gitee.com/rpr/try-on_parse.git
  2. 链接:https://pan.baidu.com/s/1nKUevnIMcGjaitVwIb7SRg
    提取码:59wk
  3. 测试用模型,鼓励大家根据网络训练自己的模型。
    具体效果看下图or视频,想测试可以参见模型资源下载,之后Load和测试。

上述模型资源包括算法所设及的全部网络模型:latest_net_U.pth,latest_net_G1.pth,latest_net_G2.pth,latest_net_G.pth

测试效果:如图

测试效果:如视频

计算机视觉神经网络虚拟试穿测试

虚拟试穿-上衣

  • 虚拟试穿
  • 前言
  • 一、Try_On算法里面有什么?
  • 二、梳理步骤
    • 1.环境
    • 2.读入数据
    • 3. 输入配置
    • 4. 数据集处理详细
    • 5.模型结构
  • 总结
  • 参考文献

前言

本文将梳理算法实现过程原理。


提示:本文内容仅供学术研究与参考。

一、Try_On算法里面有什么?

0.环境; 1. 数据读取; 2. 数据模型:U-Net,G-Net; 3.损失函数; 4.调试常见的bug。

二、梳理步骤

1.环境

代码如下(示例):



以上只等下次优化成requirements.txt再传上来。

2.读入数据

模型输入数据需要哪些呢?

测试数据集长什么样?

数据直观内容分析,我把模型需要的输入放一起,展示如下:

实际上,pose_关键点数据,和label_分割数据,是img_模特数据得到的(怎么生成关键点数据和人物分割数据的详细解读和代码,我再开一篇博客放上来);edge_数据就是待穿衣服color生成的。mask掩码数据是根据需要随机生成的。所以,完整的项目,的输入只需要模特和服装款式即可,也就是说可以实现给个人物和一件衣服就给实现换装。
再看,

看具体情况:通过photoshop的拾色器可以直观看到数据的值如下(label是灰度图):

背景的亮度L:0;面部的亮度L:9,左胳膊的亮度L:10,右胳膊的亮度L:8,上衣衣服位置的L:2。
把肢体图像分割出精确部分,用不同的亮度表示,到时候换衣服就有边界了。学姿势,纹理和褶皱等也有边界。
注意此处的L并不是该位置的像素值,只是亮度值。像素值可以用代码打印出来看。一下给出不同块儿的像素值。

# 具体划分区域Segmentation Label
0 -> Background
1 -> Hair
4 -> Upclothes
5 -> Left-shoe
6 -> Right-shoe
7 -> Noise
8 -> Pants
9 -> Left_leg
10 -> Right_leg
11 -> Left_arm
12 -> Face
13 -> Right_arm

注:名字带mask的三张掩码图,黑色区域亮度0,白色区域亮度为100,它们没有实际意义,可用于增加噪声,让模型稳定性好一些(我是这样理解的,因为训练的时候中间结果也有损失函数的backforword).

ok,以上就是输入数据,理清了没?
下面分析怎么使用的,以及后面模型是怎么组合的。

3. 输入配置

怎么读取数据集,生成模型输入需要的数据?项目写了一个配置options,专用于对数据集目录信息,训练测试信息和超参数进行配置的文件。

test.py文件中:

 opt = TrainOptions().parse()

ctrl+鼠标左键点击TestOptions,找到opt对象的具体内容:

class TestOptions(BaseOptions):def initialize(self):BaseOptions.initialize(self)......

ctrl+鼠标左键点击BaseOptions

class BaseOptions():def __init__(self):self.parser = argparse.ArgumentParser()self.initialized = Falsedef initialize(self):.....

TestOptions类的initialize函数系重写,但还是调用了BaseOptions.initialize(self)的,所以BaseOptions.initialize的数据也包含了的。根据训练和测试所需要的数据不同,控制生成数据集和其他超参数。

4. 数据集处理详细

生成数据迭代器,同其他pytorch自制数据集相差不大。重点文件时aligned_dataset.py

具体一起来看看:
test.py文件中:

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

CreateDataLoader是封装数据迭代器的类:
点进去看一眼

def CreateDataLoader(opt):from data.custom_dataset_data_loader import CustomDatasetDataLoaderdata_loader = CustomDatasetDataLoader()print(data_loader.name())data_loader.initialize(opt)return data_loader

具体内容在CustomDatasetDataLoader

class CustomDatasetDataLoader(BaseDataLoader):def name(self):return 'CustomDatasetDataLoader'def initialize(self, opt):BaseDataLoader.initialize(self, opt)self.dataset = CreateDataset(opt)self.dataloader = torch.utils.data.DataLoader(self.dataset,batch_size=opt.batchSize,shuffle=not opt.serial_batches,num_workers=int(opt.nThreads))def load_data(self):return self.dataloaderdef __len__(self):return min(len(self.dataset), self.opt.max_dataset_size)

发现还封装了一层,数据是:self.dataset = CreateDataset(opt):
发现还有函数封装,

def CreateDataset(opt):dataset = Nonefrom data.aligned_dataset import AlignedDatasetdataset = AlignedDataset()print("dataset [%s] was created" % (dataset.name()))dataset.initialize(opt)return dataset

注意dataset.initialize(opt),封装数据过程中动不动都在初始化。
继续AlignedDataset,ctrl+鼠标左键点击AlignedDataset

class AlignedDataset(BaseDataset):def initialize(self, opt):......def __getitem__(self, index):......

找到,def __getitem__(self, index):,看到它是不是很眼熟了,就是pytorch生成batch数据可迭代数据。这个类继承于BaseDataset,父类有transform等方法:

class BaseDataset(data.Dataset):def __init__(self):super(BaseDataset, self).__init__()def name(self):return 'BaseDataset'def initialize(self, opt):passdef get_params(opt, size):w, h = sizenew_h = hnew_w = wif opt.resize_or_crop == 'resize_and_crop':new_h = new_w = opt.loadSize            elif opt.resize_or_crop == 'scale_width_and_crop':new_w = opt.loadSizenew_h = opt.loadSize * h // wx = random.randint(0, np.maximum(0, new_w - opt.fineSize))y = random.randint(0, np.maximum(0, new_h - opt.fineSize))#flip = random.random() > 0.5flip = 0return {'crop_pos': (x, y), 'flip': flip}def get_transform(opt, params, method=Image.BICUBIC, normalize=True):transform_list = []if 'resize' in opt.resize_or_crop:osize = [opt.loadSize, opt.loadSize]transform_list.append(transforms.Scale(osize, method))   elif 'scale_width' in opt.resize_or_crop:transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))osize = [256,192]transform_list.append(transforms.Scale(osize, method))  if 'crop' in opt.resize_or_crop:transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))if opt.resize_or_crop == 'none':base = float(2 ** opt.n_downsample_global)if opt.netG == 'local':base *= (2 ** opt.n_local_enhancers)transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))if opt.isTrain and not opt.no_flip:transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))transform_list += [transforms.ToTensor()]if normalize:transform_list += [transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]return transforms.Compose(transform_list)def normalize():    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))def __make_power_2(img, base, method=Image.BICUBIC):ow, oh = img.size        h = int(round(oh / base) * base)w = int(round(ow / base) * base)if (h == oh) and (w == ow):return imgreturn img.resize((w, h), method)def __scale_width(img, target_width, method=Image.BICUBIC):ow, oh = img.sizeif (ow == target_width):return img    w = target_widthh = int(target_width * oh / ow)    return img.resize((w, h), method)def __crop(img, pos, size):ow, oh = img.sizex1, y1 = postw = th = sizeif (ow > tw or oh > th):        return img.crop((x1, y1, x1 + tw, y1 + th))return imgdef __flip(img, flip):if flip:return img.transpose(Image.FLIP_LEFT_RIGHT)return img

既然找到数据了,就看看AlignedDataset的数据生成方式吧。
首先是初始化,

def initialize(self, opt):......

到这里还记得输入数据都有哪些吗?
待穿服装:color,轮廓edge;
模特:img,pose关键点,模特分割数据label;
掩码:两个黑背景掩码,一个白背景的掩码。

 dir_C = '_color'self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C)self.C_paths = sorted(make_dataset(self.dir_C))self.CR_paths = make_dataset(self.dir_C)dir_E = '_edge'self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E)self.E_paths = sorted(make_dataset(self.dir_E))self.ER_paths = make_dataset(self.dir_E)dir_B =  '_img'self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)self.B_paths = sorted(make_dataset(self.dir_B))self.BR_paths = sorted(make_dataset(self.dir_B))#  pose的关键点名称和img模特命名相差不大:pose_name = B_path.replace('.jpg', '_keypoints.json').replace('test_img', 'test_pose')dir_A = '_label'self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)self.A_paths = sorted(make_dataset(self.dir_A))self.AR_paths = make_dataset(self.dir_A)

发现初始化就是把大家的地址放到明面上了,有的还排好了序,def __getitem__(self, index):函数取数据就方便多了。
下面两个函数,就是具体生成路径字典或列表的:

def make_dataset(dir):images = []assert os.path.isdir(dir), '%s is not a valid directory' % dirf = dir.split('/')[-1].split('_')[-1]print(dir, f)dirs = os.listdir(dir)for img in dirs:path = os.path.join(dir, img)# print(path)images.append(path)return imagesdef build_index(self, dirs):for k, dir in enumerate(dirs):name = dir.split('/')[-1]name = name.split('-')[0]# print(name)for k, d in enumerate(dirs[max(k - 20, 0):k + 20]):if name in d:if name not in self.diction.keys():self.diction[name] = []self.diction[name].append(d)else:self.diction[name].append(d)

到这里了,就最后看一眼,生成的数据长什么样吧!

def __getitem__(self, index):......if self.opt.isTrain:input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor,'path': A_path, 'path_ref': AR_path,'edge': E_tensor, 'color': C_tensor, 'mask': M_tensor, 'colormask': MC_tensor,'pose': P_tensor, 'name': name}else:input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}return input_dict

对,就是他input_dict,返回的字典。

注意:原作者共享的代码中,有这些后缀的文件夹,如图,我并没有相应后缀名文件,就配置训练模式,来测试数据了。

5.模型结构


继续看test.py文件,到模型板块

model = create_model(opt)

训练和测试使用的输入数据有区别的,训练的时候,除了将带穿衣服和模特的数据输入以外,还需要将穿好的结果输入,在最后面模型输出相比较,得出损失函数。

def create_model(opt):if opt.model == 'pix2pixHD':from .pix2pixHD_model import Pix2PixHDModel, InferenceModelif opt.isTrain:model = Pix2PixHDModel()else:model = InferenceModel()model.initialize(opt)if opt.verbose:print("model [%s] was created" % (model.name()))if opt.isTrain and len(opt.gpu_ids):model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)return model

测试model = InferenceModel()类实例也是继承的Pix2PixHDModel,重写了前向传播函数forward.

class InferenceModel(Pix2PixHDModel):def forward(self, inp):label = inpreturn self.inference(label)

我们直接去看Pix2PixHDModel:

pix2pixHD可以实现高分辨率图像生成和图片的语义编辑。
对于一个生成对抗网络(GAN),学习的关键就是理解生成器、判别器和损失函数这三部分。
pix2pixHD的生成器和判别器都是多尺度的,损失函数由GAN loss、Feature matching loss和Content loss组成。

class Pix2PixHDModel(BaseModel):def name(self):return 'Pix2PixHDModel'def initialize(self, opt):BaseModel.initialize(self, opt)......with torch.no_grad():self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()self.G1 = networks.define_Refine(37, 14, self.gpu_ids).eval()self.G2 = networks.define_Refine(19 + 18, 1, self.gpu_ids).eval()self.G = networks.define_Refine(24, 3, self.gpu_ids).eval()
......def forward(self,label,pre_clothes_mask,img_fore,clothes_mask,clothes,all_clothes_label,real_image,pose,mask):......

代码贴太多了,容易看不清重点。我就只贴些关键的,帮助理清楚整个网络的脉络。
Pix2PixHDModel继承了BaseModelBaseModel类有初始化函数,save_network函数,load_network函数。就没什么可以看的了。

from . import networks

networks.py文件中写了多种网络的具体结构,用类封装:

pix2pixHD模型类,初始化的时候,会对需要用到的网络进行初始化:

然后,在forward函数里面,是组合网络使用的方法和顺序,以及哪些地方需要计算损失来约束网络。

    def forward(self,label,pre_clothes_mask,img_fore,clothes_mask,clothes,all_clothes_label,real_image,pose,mask):# Encode Inputs#ipdb.set_trace()input_label,masked_label,all_clothes_label= self.encode_input(label,clothes_mask,all_clothes_label)#ipdb.set_trace()arm1_mask=torch.FloatTensor((label.cpu().numpy()==11).astype(np.float)).cuda()arm2_mask=torch.FloatTensor((label.cpu().numpy()==13).astype(np.float)).cuda()pre_clothes_mask=torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()clothes=clothes*pre_clothes_mask......

forward函数的输入数据是:label, pre_clothes_mask, img_fore, clothes_mask, clothes, all_clothes_label, real_image, pose, mask.
我们输入的数据:input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}做了一点预处理的(加入高斯噪声,wash the label):

data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()......for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):epoch_start_time = time.time()if epoch != start_epoch:epoch_iter = epoch_iter % dataset_sizefor i, data in enumerate(dataset, start=epoch_iter):mask_clothes = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int))mask_fore = torch.FloatTensor((data['label'].cpu().numpy() > 0).astype(np.int))img_fore = data['image'] * mask_foreimg_fore_wc = img_fore * mask_foreall_clothes_label = changearm(data['label'])############## 模型向前传播 ######################losses, fake_image, real_image, input_label, L1_loss, style_loss, clothes_mask, CE_loss, rgb, alpha = model(Variable(data['label'].cuda()), Variable(data['edge'].cuda()), Variable(img_fore.cuda()),Variable(mask_clothes.cuda()), Variable(data['color'].cuda()), Variable(all_clothes_label.cuda()), Variable(data['image'].cuda()),Variable(data['pose'].cuda()), Variable(data['image'].cuda()), Variable(mask_fore.cuda()))

data里就是一次迭代获取dataset里的一个batch的数据。每个data都是input_dict样子,是字典。
mask_clothes = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int)),将datalabel对应的数据取出来处理(得到mask_clothes为区域分割数据label同尺寸的掩码图像,里面的值,原图等于4的是1,其他全0,就是把衣服区域取出来);label是什么还记得吗?看下图:

看:mask_fore = torch.FloatTensor((data['label'].cpu().numpy() > 0).astype(np.int))(得到mask_fore 为区域分割数据label同尺寸的掩码图像,里面的值,原图大于1的是1,其他全0)
img_fore = data['image'] * mask_fore他们相乘,就是在抠图,去掉背景得到img_fore
all_clothes_label = changearm(data['label']):调用changearm函数(变胳膊区域):

    def changearm(old_label):label = old_labelarm1 = torch.FloatTensor((data['label'].cpu().numpy() == 11).astype(np.int))arm2 = torch.FloatTensor((data['label'].cpu().numpy() == 13).astype(np.int))noise = torch.FloatTensor((data['label'].cpu().numpy() == 7).astype(np.int))label = label * (1 - arm1) + arm1 * 4label = label * (1 - arm2) + arm2 * 4label = label * (1 - noise) + noise * 4return label

changearm函数将模特区域分割数据label的左右胳膊取出来( == 11, == 13),把整幅噪声也找出来,将他们的值变成4.就像下图,把手也和衣服化为一个区域了。

############## Forward Pass ######################
losses, fake_image, real_image, input_label, L1_loss, style_loss, clothes_mask, CE_loss, rgb, alpha = model(Variable(data['label'].cuda()), Variable(data['edge'].cuda()), Variable(img_fore.cuda()),Variable(mask_clothes.cuda()), Variable(data['color'].cuda()), Variable(all_clothes_label.cuda()), Variable(data['image'].cuda()),Variable(data['pose'].cuda()), Variable(data['image'].cuda()), Variable(mask_fore.cuda()))

所以,输入的量有,label–模特区域分割数据;edge–衣服轮廓;img_fore–去背景模特图像;mask_clothes–模特正穿着的衣服的掩码区域,color–衣服;all_clothes_label–模特区域分割label将胳膊手融入衣服的区域分割图像;image–模特;pose–模特关键点;mask_fore–模特区域;
对比看一看,模型中形参叫什么:

def forward(self,label,pre_clothes_mask,img_fore,clothes_mask,clothes,all_clothes_label,real_image,pose,grid,mask):

那就进入pix2pixHD模型的forward函数:

# Encode Inputs
input_label, masked_label, all_clothes_label = self.encode_input(label, clothes_mask, all_clothes_label)

对输入数据编码处理:

    def encode_input(self, label_map, clothes_mask, all_clothes_label):size = label_map.size()oneHot_size = (size[0], 14, size[2], size[3])input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)masked_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()masked_label = masked_label.scatter_(1, (label_map * (1 - clothes_mask)).data.long().cuda(), 1.0)c_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()c_label = c_label.scatter_(1, all_clothes_label.data.long().cuda(), 1.0)input_label = Variable(input_label)return input_label, masked_label, c_label

手动实现one_hot 时,关于scatter_()函数: scatter_()函数有三个参数 scatter_(dim, index, src)

  1. dim指的是在哪个维度进行索引
  2. index指的是:用来进行索引的tensor
  3. src指scatter的源元素,可以是一个标量也可以是一个张量。

一句话解释上面的scatter:
input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
既input_label.scatter_(dim, index, src)将src中数据根据index中的索引按照dim的方向填进input_label中。

继续net_G1(conditional GAN):

 G1_in = torch.cat([pre_clothes_mask, clothes, all_clothes_label, pose, self.gen_noise(shape)], dim=1)arm_label = self.G1.refine(G1_in)arm_label = self.sigmoid(arm_label)CE_loss = self.cross_entropy2d(arm_label, (label * (1 - clothes_mask)).transpose(0, 1)[0].long()) * 10


直观了吧,看网络G1只有一个输出arm_label。在训练中,就是模特换好新衣服后的分割图,与网络输出做损失反向传输。(模特原来是长袖,后面要穿短袖,自然胳膊是重点。脖子似乎没有人关心高领和低领的问题,待改进)。测试时直接生成要的穿color款衣服的模特分割图。

armlabel_map = generate_discrete_label(arm_label.detach(), 14, False)
dis_label = generate_discrete_label(arm_label.detach(), 14)

生成离散标签函数的输入是G1网络的输出结果arm_label

def generate_discrete_label(inputs, label_nc, onehot=True, encode=True):pred_batch = []size = inputs.size()for input in inputs:input = input.view(1, label_nc, size[2], size[3])pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)pred_batch.append(pred)pred_batch = np.array(pred_batch)pred_batch = torch.from_numpy(pred_batch)label_map = []for p in pred_batch:p = p.view(1, 256, 192)label_map.append(p)label_map = torch.stack(label_map, 0)if not onehot:return label_map.float().cuda()size = label_map.size()oneHot_size = (size[0], label_nc, size[2], size[3])input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)return input_label

上面要是看不明白出入输出变化,可以把结果或者结果的shape打印出来,对比着看。

继续net_G2:

G2_in = torch.cat([pre_clothes_mask, clothes, dis_label, pose, self.gen_noise(shape)], 1)
fake_cl = self.G2.refine(G2_in)
fake_cl = self.sigmoid(fake_cl)
CE_loss += self.BCE(fake_cl, clothes_mask) * 10


G2的输入,是G1的输出+Pose+color+edge+noise组合输入,输出为模特穿上新衣后衣服的轮廓。训练的时候,是模特穿上新的衣服数据的衣服轮廓与G2的输出做损失,反向传播的。测试时,G2输出模特换上新衣的轮廓数据,此时,还没有图案和纹理的变化。
损失函数BCE参考:https://blog.csdn.net/qq_22210253/article/details/85222093

继续:

fake_cl_dis = torch.FloatTensor((fake_cl.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
fake_cl_dis = morpho(fake_cl_dis, 1, True)
def morpho(mask, iter, bigger=True):kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))new = []for i in range(len(mask)):tem = mask[i].cpu().detach().numpy().squeeze().reshape(256, 192, 1) * 255tem = tem.astype(np.uint8)if bigger:tem = cv2.dilate(tem, kernel, iterations=iter)else:tem = cv2.erode(tem, kernel, iterations=iter)tem = tem.astype(np.float64)tem = tem.reshape(1, 256, 192)new.append(tem.astype(np.float64) / 255.0)new = np.stack(new)new = torch.FloatTensor(new).cuda()return new

detach(): 神经网络的训练有时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,torch.tensor.detach()和torch.tensor.detach_()函数来切断一些分支的反向传播。

cv2.getStructuringElement( ) 返回指定形状和尺寸的结构元素。
函数的第一个参数表示内核的形状,有三种形状可以选择。
矩形:MORPH_RECT;
交叉形:MORPH_CROSS;
椭圆形:MORPH_ELLIPSE;
第二和第三个参数分别是内核的尺寸以及锚点的位置。一般在调用erode以及dilate函数之前,先定义一个Mat类型的变量来获得getStructuringElement函数的返回值: 对于锚点的位置,有默认值Point(-1,-1),表示锚点位于中心点。element形状唯一依赖锚点位置,其他情况下,锚点只是影响了形态学运算结果的偏移。

cv2.erode()腐蚀:将前景物体变小,理解成将图像断开裂缝变大(在图片上画上黑色印记,印记越来越大)
dst = cv.erode(src, kernel[, dst[, anchor[, iterations[, borderType[, borderValue]]]]])

cv2.dilate()膨胀:将前景物体变大,理解成将图像断开裂缝变小(在图片上画上黑色印记,印记越来越小)
dst = cv2.dilate(src, kernel[, dst[, anchor[, iterations[, borderType[, borderValue]]]]])

numpy.stack(arrays, axis=0)
沿着新轴连接数组的序列。
axis参数指定新轴在结果尺寸中的索引。例如,如果axis=0,它将是第一个维度,如果axis=-1,它将是最后一个维度。
参数: 数组:array_like的序列每个数组必须具有相同的形状。axis:int,可选输入数组沿其堆叠的结果数组中的轴。
返回: 堆叠:ndarray堆叠数组比输入数组多一个维。

new_arm1_mask = torch.FloatTensor((armlabel_map.cpu().numpy() == 11).astype(np.float)).cuda()
new_arm2_mask = torch.FloatTensor((armlabel_map.cpu().numpy() == 13).astype(np.float)).cuda()
fake_cl_dis = fake_cl_dis * (1 - new_arm1_mask) * (1 - new_arm2_mask)
fake_cl_dis *= mask_forearm1_occ = clothes_mask * new_arm1_mask
arm2_occ = clothes_mask * new_arm2_mask
bigger_arm1_occ = morpho(arm1_occ, 10)
bigger_arm2_occ = morpho(arm2_occ, 10)
arm1_full = arm1_occ + (1 - clothes_mask) * arm1_mask
arm2_full = arm2_occ + (1 - clothes_mask) * arm2_mask
armlabel_map *= (1 - new_arm1_mask)
armlabel_map *= (1 - new_arm2_mask)
armlabel_map = armlabel_map * (1 - arm1_full) + arm1_full * 11
armlabel_map = armlabel_map * (1 - arm2_full) + arm2_full * 13
armlabel_map *= (1 - fake_cl_dis)
dis_label = encode(armlabel_map, armlabel_map.shape)
fake_c, warped, warped_mask, warped_grid = self.Unet(clothes, fake_cl_dis, pre_clothes_mask, grid)
mask = fake_c[:, 3, :, :]
mask = self.sigmoid(mask) * fake_cl_dis
fake_c = self.tanh(fake_c[:, 0:3, :, :])
fake_c = fake_c * (1 - mask) + mask * warped
skin_color = self.ger_average_color((arm1_mask + arm2_mask - arm2_mask * arm1_mask),(arm1_mask + arm2_mask - arm2_mask * arm1_mask) * real_image)
occlude = (1 - bigger_arm1_occ * (arm2_mask + arm1_mask + clothes_mask)) * (1 - bigger_arm2_occ * (arm2_mask + arm1_mask + clothes_mask))
img_hole_hand = img_fore * (1 - clothes_mask) * occlude * (1 - fake_cl_dis)
self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()

前面的G1,G2的网络我没有展开。后面会专门分析网络里面的组成,和输入输出等细节。在这里,我们溯源一下这个Unet:

def define_UnetMask(input_nc, gpu_ids=[]):netG = UnetMask(input_nc, output_nc=4)netG.cuda(gpu_ids[0])netG.apply(weights_init)return netG

Unet来源于UnetMask:

class UnetMask(nn.Module):def __init__(self, input_nc, output_nc=3):super(UnetMask, self).__init__()self.stn = STNNet()nl = nn.InstanceNorm2dself.conv1 = nn.Sequential(*[nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU()])self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))......def forward(self, input, refer, mask, grid):input, warped_mask, rx, ry, cx, cy, grid = self.stn(input, torch.cat([mask, refer, input], 1), mask, grid)# print(input.shape)conv1 = self.conv1(torch.cat([refer.detach(), input.detach()], 1))......conv9 = self.conv9(torch.cat([conv1, up9], 1))return conv9, input, warped_mask, grid

UnetMask有一个特殊的网络层STNNet:

class STNNet(nn.Module):def __init__(self):super(STNNet, self).__init__()range = 0.9r1 = ranger2 = rangegrid_size_h = 5grid_size_w = 5assert r1 < 1 and r2 < 1  # if >= 1, arctanh will cause error in BoundedGridLocNettarget_control_points = torch.Tensor(list(itertools.product(np.arange(-r1, r1 + 0.00001, 2.0 * r1 / (grid_size_h - 1)),np.arange(-r2, r2 + 0.00001, 2.0 * r2 / (grid_size_w - 1)),)))Y, X = target_control_points.split(1, dim=1)target_control_points = torch.cat([X, Y], dim=1)self.target_control_points = target_control_points# self.get_row(target_control_points,5)GridLocNet = {'unbounded_stn': UnBoundedGridLocNet,'bounded_stn': BoundedGridLocNet,}['bounded_stn']self.loc_net = GridLocNet(grid_size_h, grid_size_w, target_control_points)self.tps = TPSGridGen(256, 192, target_control_points)def get_row(self, coor, num):for j in range(num):sum = 0buffer = 0flag = Falsemax = -1for i in range(num - 1):differ = (coor[j * num + i + 1, :] - coor[j * num + i, :]) ** 2if not flag:second_dif = 0flag = Trueelse:second_dif = torch.abs(differ - buffer)buffer = differsum += second_difprint(sum / num)def get_col(self, coor, num):for i in range(num):sum = 0buffer = 0flag = Falsemax = -1for j in range(num - 1):differ = (coor[(j + 1) * num + i, :] - coor[j * num + i, :]) ** 2if not flag:second_dif = 0flag = Trueelse:second_dif = torch.abs(differ - buffer)buffer = differsum += second_difprint(sum)def forward(self, x, reference, mask, grid_pic):batch_size = x.size(0)source_control_points, rx, ry, cx, cy = self.loc_net(reference)source_control_points = (source_control_points)# print('control points',source_control_points.shape)source_coordinate = self.tps(source_control_points)grid = source_coordinate.view(batch_size, 256, 192, 2)# print('grid size',grid.shape)transformed_x = grid_sample(x, grid, canvas=0)warped_mask = grid_sample(mask, grid, canvas=0)warped_gpic = grid_sample(grid_pic, grid, canvas=0)return transformed_x, warped_mask, rx, ry, cx, cy, warped_gpic


U_net不仅含有简单的神经网络层,还有STN网络层(spatial transform network,空间变换网络)。

前面还完成了step3的过程:

以上,G1,G2,Unet,step3完成后,才是G3网络。G3的输入:

G_in = torch.cat([img_hole_hand, dis_label, fake_c, skin_color, self.gen_noise(shape)], 1)
fake_image = self.G.refine(G_in.detach())
fake_image = self.tanh(fake_image)


返回所有输出结果:

return [self.loss_filter(loss_G_GAN, 0, loss_G_VGG, loss_D_real, loss_D_fake), fake_image,clothes, arm_label, L1_loss, style_loss, fake_cl, CE_loss, real_image, warped_grid]

到这里,算是解释完了。


总结


总结,流程图最左侧三个灰蓝色模型,为基本输入color(服装),和img(模特)的预处理。本项目中,已经提供了它们三个的输出(未提供相关模型,不是重点),作为输入。
所以整个网络的关键输入就是:edge,color;pose,img,label.
整个网络就是G1+G2+Unet+G3构成;中间数据输入做了些些掩码和校正外,没有其他结构了。
小伙伴们,是不是弄清楚?

本文主要是自己巩固一下学习内容,有朋友询问有没有简化输入的方法,能不能不提供解析数据(区域分割)和关键点之类的,只输入模特和服装就可以看到试穿效果的呢,答案:当然有。免解析虚拟试穿参见此博客,虽然测试效果还不好,但简化输入也有了解决方案不是。

谨以此文与大家共勉!如果你觉得对你有用,请给个点赞

参考文献

感谢原作者:
[1] Yang, Han and Zhang, Ruimao and Guo, Xiaobao and Liu, Wei and Zuo, Wangmeng and Luo, Ping.Towards Photo-Realistic Virtual Try-On by Adaptively Generating-Preserving Image Content,IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),June,2020.

虚拟试穿--测试上衣代码详解相关推荐

  1. 【CV】Pytorch一小时入门教程-代码详解

    目录 一.关键部分代码分解 1.定义网络 2.损失函数(代价函数) 3.更新权值 二.训练完整的分类器 1.数据处理 2. 训练模型(代码详解) CPU训练 GPU训练 CPU版本与GPU版本代码区别 ...

  2. yii mysql 事务处理_Yii2中事务的使用实例代码详解

    前言 一般我们做业务逻辑,都不会仅仅关联一个数据表,所以,会面临事务问题. 数据库事务(Database Transaction) ,是指作为单个逻辑工作单元执行的一系列操作,要么完全地执行,要么完全 ...

  3. 代码详解|tensorflow实现 聊天AI--PigPig养成记(1)

    Chapter1.代码详解 完整代码github链接,Untitled.ipynb文件内. [里面的测试是还没训练完的时候测试的,今晚会更新训练完成后的测试结果] 修复了网上一些代码的bug,解决了由 ...

  4. DeepLearning tutorial(1)Softmax回归原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43157801 DeepLearning tutorial(1)Softmax回归原理简介 ...

  5. DeepLearning tutorial(3)MLP多层感知机原理简介+代码详解

    FROM:http://blog.csdn.net/u012162613/article/details/43221829 @author:wepon @blog:http://blog.csdn.n ...

  6. DeepLearning tutorial(4)CNN卷积神经网络原理简介+代码详解

    FROM: http://blog.csdn.net/u012162613/article/details/43225445 DeepLearning tutorial(4)CNN卷积神经网络原理简介 ...

  7. java编程数据溢出问题_Java数据溢出代码详解

    Java数据溢出代码详解 发布时间:2020-10-05 15:08:31 来源:脚本之家 阅读:103 作者:Pony小马 java是一门相对安全的语言,那么数据溢出时它是如何处理的呢? 看一段代码 ...

  8. java 文件下载详解_Java 从网上下载文件的几种方式实例代码详解

    废话不多说了,直接给大家贴代码了,具体代码如下所示: package com.github.pandafang.tool; import java.io.BufferedOutputStream; i ...

  9. socket 获取回传信息_Luat系列官方教程5:Socket代码详解

    文章篇幅较长,代码部分建议横屏查看,或在PC端打开本文链接.文末依然为爱学习的你准备了专属福利~ TCP和UDP除了在Lua代码声明时有一些不同,其他地方完全一样,所以下面的代码将以TCP长连接的数据 ...

最新文章

  1. 「AI初识境」近20年深度学习在图像领域的重要进展节点
  2. url index.php 怎么去掉,url怎么去掉index.php
  3. java new 多线程_Java多线程:Java多线程执行框架
  4. python 四足机器人运动学_撸了个四足机器人
  5. linux shell脚本链接操作符,Shell脚本中的操作符
  6. Android源码学习之如何使用eclipse+NDK
  7. C++ 一定要使用strcpy_s()函数 等来操作方法c_str()返回的指针
  8. HDU2153 仙人球的残影【数学计算+水题】
  9. OpenGL.Vertex Array Object (VAO).
  10. Drool实战系列(二)之eclipse安装drools插件
  11. w ndows7旗舰版网卡驱动,windows7万能网卡驱动官方下载
  12. adb官方最新下载链接和常用操作
  13. 22 个最常用的Python包
  14. Emmagee--APP性能测试工具的基本使用
  15. 队列与栈的原理及特点
  16. 海康威视2022内推 内推码
  17. 2021赣网杯web和misc部分wp
  18. Android Studio 4.0 新建项目gradle依赖base sdk以后报错 ‘assets/cfg/*‘ collided 的解决办法
  19. Automader 使用教程 - 01 你好,左右抽
  20. html和cssb笔记

热门文章

  1. JavaScript 执行机制
  2. Android面试问答题
  3. 全新MXone Pro自适应苹果CMSV10影视模板/亲测
  4. “带薪摸鱼”偷刷阿里老哥的面经宝典,三次挑战字节,终斩offer,修成正果!
  5. NS3 Tutorial 中文版:第四章 概念概述
  6. 做设计想要轻松接单 你要懂这些
  7. 一个赛马问题 25匹马5个赛道,每个赛道每次只能跑一匹马,问需要跑几次,能求出跑得最快的三匹马。...
  8. E2E测试---Cypress 使用
  9. 《Python数据分析与挖掘实战》第10章(下)——DNN2 筛选得“候选洗浴事件”3 构建模型
  10. netmask的作用