Generative Image Inpainting with Contextual Attention

今天介绍CVPR 2018Generative Image Inpainting with Contextual Attention

paper: https://arxiv.org/abs/1801.07892, demo http://jiahuiyu.com/deepfill

github:https://github.com/JiahuiYu/generative_inpainting

先看效果:

上述是作者修复的结果,我自己训练后修复的如下:

这里生成了两个不同情况的图,因为使用了两个不同的pre-train Model

下面介绍如何使用:

  1. Requirements:

    • Install python3.
    • Install tensorflow (tested on Release 1.3.0, 1.4.0, 1.5.0, 1.6.0, 1.7.0).
    • Install tensorflow toolkit neuralgym (run pip install git+https://github.com/JiahuiYu/neuralgym).
  2. Training:
    • Prepare training images filelist and shuffle it (example).
    • Modify inpaint.yml to set DATA_FLIST, LOG_DIR, IMG_SHAPES and other parameters.
    • Run python3 train.py.

这里重点介绍如何准备自己的训练集,直接写了个python脚本自动处理即可。gen_flist.py自动将源数据集划分为训练集和验证集。并生成项目需要的格式。

# 将原数据集分为training ,validation  by gavin
import os
import randomimport argparse#划分验证集训练集
_NUM_TEST = 20000parser = argparse.ArgumentParser()
parser.add_argument('--folder_path', default='/home/gavin/Dataset/celeba', type=str,help='The folder path')
parser.add_argument('--train_filename', default='./data/celeba/train_shuffled.flist', type=str,help='The train filename.')
parser.add_argument('--validation_filename', default='./data/celeba/validation_static_view.flist', type=str,help='The validation filename.')def _get_filenames(dataset_dir):photo_filenames = []image_list = os.listdir(dataset_dir)photo_filenames = [os.path.join(dataset_dir, _) for _ in image_list]return photo_filenamesif __name__ == "__main__":args = parser.parse_args()data_dir = args.folder_path# get all file namesphoto_filenames = _get_filenames(data_dir)print("size of celeba is %d" % (len(photo_filenames)))# 切分数据为测试训练集random.seed(0)random.shuffle(photo_filenames)training_file_names = photo_filenames[_NUM_TEST:]validation_file_names = photo_filenames[:_NUM_TEST]print("training file size:",len(training_file_names))print("validation file size:", len(validation_file_names))# make output file if not existedif not os.path.exists(args.train_filename):os.mknod(args.train_filename)if not os.path.exists(args.validation_filename):os.mknod(args.validation_filename)# write to filefo = open(args.train_filename, "w")fo.write("\n".join(training_file_names))fo.close()fo = open(args.validation_filename, "w")fo.write("\n".join(validation_file_names))fo.close()# print processprint("Written file is: ", args.train_filename)

最终生成的格式如下图:

  1. Resume training:

    • Modify MODEL_RESTORE flag in inpaint.yml. E.g., MODEL_RESTORE: 20180115220926508503_places2_model.
    • Run python3 train.py.
  2. Testing:
    • Run python test.py --image examples/input.png --mask examples/mask.png --output examples/output.png --checkpoint model_logs/your_model_dir.

大概就是以上操作,后面贴上我实际训练和测试的脚本。

配置文件

其中inpaint.yml中要注意的是,在恢复训练模型的时候,MODEL_RESTORE的值:

多GPU模式训练

如果使用多个GPU训练,需要改三处地方,分别是inpaint.yml中两处,如下

# training
NUM_GPUS: 2
GPU_ID: [0,1]  # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3]

分别指定将gpu使用的个数及各自的id,第三处,也是最重要而且特别容易忽略的,在train.py中修改这里

# train generator with primary trainer ,MultiGPUTrainer. for multi gpu,and add num_gpus=config.NUM_GPUStrainer = ng.train.Trainer(optimizer=g_optimizer,var_list=g_vars,max_iters=config.MAX_ITERS,graph_def=multigpu_graph_def,grads_summary=config.GRADS_SUMMARY,gradient_processor=gradient_processor,graph_def_kwargs={'model': model, 'data': data, 'config': config, 'loss_type': 'g'},spe=config.TRAIN_SPE,log_dir=log_prefix,)'''trainer = ng.train.MultiGPUTrainer(optimizer=g_optimizer,var_list=g_vars,max_iters=config.MAX_ITERS,graph_def=multigpu_graph_def,grads_summary=config.GRADS_SUMMARY,gradient_processor=gradient_processor,graph_def_kwargs={'model': model, 'data': data, 'config': config, 'loss_type': 'g'},spe=config.TRAIN_SPE,log_dir=log_prefix,num_gpus = config.NUM_GPUS,)'''

即有两种调用方式,一种单GPU跑,一种多GPU模式,而多GPU模式下需要加上参数

num_gpus = config.NUM_GPUS,

脚本:

# training
python3 train.py# Resume training:
Modify MODEL_RESTORE flag in inpaint.yml. E.g., MODEL_RESTORE: 20180115220926508503_places2_model.
Run python3 train.py.#Testing:python3 test.py --image examples/input.png --mask examples/mask.png --output examples/output.png --checkpoint model_logs/your_model_dir.python3 test.py --image examples/celeba/celebahr_patches_164787_input.png --mask examples/center_mask_256.png
--output examples/output_celeba.png --checkpoint_dir model_logs/celebA_model/snap-60000# for any other image,you can generate mask and masked image first ,then predict1. python3 generate_mask.py --img ./examples/celeba/000035.jpg --HEIGHT 64 --WIDTH 642. python3 test.py --image ./data/mask_img/masked/000035.jpg --mask ./data/mask_img/mask/000035.jpg \
--output examples/output_000035.png --checkpoint_dir model_logs/celebA_model/snap-90000

测试

实际测试过程中,对于任一张图,需要输入mask,和input,这里需要我们自己生成,为了便于随机生成mask,我写了如下代码,可以随机生成规则及不规则的mask

'''
利用opencv随机给图像生成带mask区域的图
author:gavin
'''# import itertools
# import matplotlib
# import matplotlib.pyplot as plt
from copy import deepcopy
from random import randint
import numpy as np
import cv2
import os
import sys
import tensorflow as tfimport argparseparser = argparse.ArgumentParser()
parser.add_argument('--img', default='./examples/celeba/000042.jpg', type=str,help='The input img for single image ')parser.add_argument('--input_dirimg', default='./data/mask_img/src_img/', type=str,help='The input folder path for multi-images')
parser.add_argument('--output_dirmask', default='./data/mask_img/mask/', type=str,help='The output file path of mask.')
parser.add_argument('--output_dirmasked', default='./data/mask_img/masked/', type=str,help='The output file path of masked.')
parser.add_argument('--MAX_MASK_NUMS', default='16', type=int,help='max numbers of masks')parser.add_argument('--MAX_DELTA_HEIGHT', default='32', type=int,help='max height of delta')
parser.add_argument('--MAX_DELTA_WIDTH', default='32', type=int,help='max width of delta')parser.add_argument('--HEIGHT', default='128', type=int,help='max height of delta')
parser.add_argument('--WIDTH', default='128', type=int,help='max width of delta')parser.add_argument('--IMG_SHAPES', type=eval, default=(256, 256, 3))# 随机生成不规则掩膜
def random_mask(height, width, config,channels=3):"""Generates a random irregular mask with lines, circles and elipses"""img = np.zeros((height, width, channels), np.uint8)# Set size scalesize = int((width + height) * 0.02)if width < 64 or height < 64:raise Exception("Width and Height of mask must be at least 64!")# Draw random linesfor _ in range(randint(1, config.MAX_MASK_NUMS)):x1, x2 = randint(1, width), randint(1, width)y1, y2 = randint(1, height), randint(1, height)thickness = randint(3, size)cv2.line(img, (x1, y1), (x2, y2), (1, 1, 1), thickness)# Draw random circlesfor _ in range(randint(1, config.MAX_MASK_NUMS)):x1, y1 = randint(1, width), randint(1, height)radius = randint(3, size)cv2.circle(img, (x1, y1), radius, (1, 1, 1), -1)# Draw random ellipsesfor _ in range(randint(1, config.MAX_MASK_NUMS)):x1, y1 = randint(1, width), randint(1, height)s1, s2 = randint(1, width), randint(1, height)a1, a2, a3 = randint(3, 180), randint(3, 180), randint(3, 180)thickness = randint(3, size)cv2.ellipse(img, (x1, y1), (s1, s2), a1, a2, a3, (1, 1, 1), thickness)return 1 - img'''
# this for test
#  %matplotlib inline   ==> plt.show()
# Plot the results
_, axes = plt.subplots(5, 5, figsize=(20, 20))
axes = list(itertools.chain.from_iterable(axes))for i in range(len(axes)):# Generate imageimg = random_mask(500, 500)# Plot image on axisaxes[i].imshow(img * 255)plt.show()'''def random_bbox(config):"""Generate a random tlhw with configuration.Args:config: Config should have configuration including IMG_SHAPES,VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.Returns:tuple: (top, left, height, width)"""img_shape = config.IMG_SHAPESimg_height = img_shape[0]img_width = img_shape[1]maxt = img_height  - config.HEIGHTmaxl = img_width  - config.WIDTHt = tf.random_uniform([], minval=0, maxval=maxt, dtype=tf.int32)l = tf.random_uniform([], minval=0, maxval=maxl, dtype=tf.int32)h = tf.constant(config.HEIGHT)w = tf.constant(config.WIDTH)return (t, l, h, w)def bbox2mask(bbox, config, name='mask'):"""Generate mask tensor from bbox.Args:bbox: configuration tuple, (top, left, height, width)config: Config should have configuration including IMG_SHAPES,MAX_DELTA_HEIGHT, MAX_DELTA_WIDTH.Returns:tf.Tensor: output with shape [1, H, W, 1]"""def npmask(bbox, height, width, delta_h, delta_w):mask = np.zeros((1, height, width, 1), np.float32)h = np.random.randint(delta_h//2+1)w = np.random.randint(delta_w//2+1)mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.return maskwith tf.variable_scope(name), tf.device('/cpu:0'):img_shape = config.IMG_SHAPESheight = img_shape[0]width = img_shape[1]mask = tf.py_func(npmask,[bbox, height, width,config.MAX_DELTA_HEIGHT, config.MAX_DELTA_WIDTH],tf.float32, stateful=False)mask.set_shape([1] + [height, width] + [1])return mask# 对于矩形mask随机生成
def random_mask_rect(img_path,config,bsave=True):# Load imageimg_data = cv2.imread(img_path)#img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)'''# generate mask, 1 represents masked pointbbox = random_bbox(config)mask = bbox2mask(bbox, config, name='mask_c')img_pos = img_data / 127.5 - 1.masked_img = img_pos * (1. - mask)'''# 创建矩形区域,填充白色255img_shape = config.IMG_SHAPESimg_height = img_shape[0]img_width = img_shape[1]image = cv2.resize(img_data, (img_width, img_height))rectangle = np.zeros(image.shape[0:2], dtype=np.uint8)maxt = img_height - config.HEIGHTmaxl = img_width - config.WIDTHh = config.HEIGHTw = config.WIDTHx = randint(0, maxt - 1)y = randint(0, maxl - 1)mask = cv2.rectangle(rectangle,(x, y), (x+w, y+h) , 255, -1)  # 修改这里 (78, 30), (98, 46)masked_img = deepcopy(image)masked_img[mask == 255] = 255print("shape of mask:",mask.shape)print("shape of masked_img:",masked_img.shape)if bsave:save_name_mask = os.path.join(config.output_dirmask, img_path.split('/')[-1])cv2.imwrite(save_name_mask,mask)save_name_masked = os.path.join(config.output_dirmasked, img_path.split('/')[-1])cv2.imwrite(save_name_masked, masked_img)return masked_img,maskdef get_path(config):if not os.path.exists(config.input_dirimg):os.mkdir(config.input_dirimg)if not os.path.exists(config.output_dirmask):os.mkdir(config.output_dirmask)if not os.path.exists(config.output_dirmasked):os.mkdir(config.output_dirmasked)# 给单个图像生成带mask区域的图
def load_mask(img_path,config,bsave=False):# Load imageimg = cv2.imread(img_path)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)shape = img.shapeprint("Shape of image is: ",shape)# Load maskmask = random_mask(shape[0], shape[1],config)# Image + maskmasked_img = deepcopy(img)masked_img[mask == 0] = 255mask = mask * 255if bsave:save_name_mask = os.path.join(config.output_dirmask, img_path.split('/')[-1])cv2.imwrite(save_name_mask,mask)save_name_masked = os.path.join(config.output_dirmasked, img_path.split('/')[-1])cv2.imwrite(save_name_masked, masked_img)return masked_img,mask# 批量生成带mask区域的图像
def img2maskedImg(dataset_dir):files = []image_list = os.listdir(dataset_dir)files = [os.path.join(dataset_dir, _) for _ in image_list]length = len(files)for index,jpg in enumerate(files):try:sys.stdout.write('\r>>Converting image %d/%d ' % (index,length))sys.stdout.flush()load_mask(jpg,config,True)# 将已经转换的图片移动到指定位置#shutil.move(png, output_dirHR)except IOError as e:print('could not read:',jpg)print('error:',e)print('skip it\n')sys.stdout.write('Convert Over!\n')sys.stdout.flush()# python3 generate_mask.py --img ./examples/celeba/000042.jpg --HEIGHT 64 --WIDTH 64if __name__ == '__main__':config = parser.parse_args()get_path(config)# 单张图像生成mask#img = './data/test.jpg'#masked_img,mask = load_mask(img,config,True)# 批量图像处理==>圆形,椭圆,直线#img2maskedImg(config.input_dirimg)# 矩形特殊处理 处理同样shape的图片(256,256,3) fix me#img = './examples/celeba/000042.jpg'img = config.imgmasked_img, mask = random_mask_rect(img,config)'''# Show side by side_, axes = plt.subplots(1, 3, figsize=(20, 5))axes[0].imshow(img)axes[1].imshow(mask*255)axes[2].imshow(masked_img)plt.show()'''

效果:

mask,masked,output

图像修复实例解析(二)相关推荐

  1. 图像修复 图像补全_图像修复简介

    图像修复 图像补全 In practical applications, images are often corroded by noise. These noises are dust or wa ...

  2. 结构感知图像修复:ICCV2019论文解析

    结构感知图像修复:ICCV2019论文解析 StructureFlow: Image Inpainting via Structure-aware Appearance Flow 论文链接: http ...

  3. Zxing和QR CODE 生成与解析二维码实例(普通篇)

    首先下载对应的jar包,本实例用的是Zxing2.2jar 下载地址:http://download.csdn.net/detail/gao36951/8161861 Zxing是Google提供的关 ...

  4. OpenCV通过填充修复损坏的图像的实例(附完整代码)

    OpenCV通过填充修复损坏的图像的实例 OpenCV通过填充修复损坏的图像的实例 OpenCV通过填充修复损坏的图像的实例 #include "opencv2/imgcodecs.hpp& ...

  5. UG/NX二次开发Siemens官方NXOPEN实例解析—1.6 BlockStyler/SelectionExample

    列文章目录 UG/NX二次开发Siemens官方NXOPEN实例解析-1.1 BlockStyler/ColoredBlock UG/NX二次开发Siemens官方NXOPEN实例解析-1.2 Blo ...

  6. OpenCVSharp(C# OpenCV)图像去水印实例(二) 去除水印并保留文本原始色彩

    导读 具体介绍与实现步骤请参考下面文章: 实战 | 文本图片去水印--同时保持文本原始色彩(附源码)_Color Space的博客-CSDN博客点击下方卡片,关注"OpenCV与AI深度学习 ...

  7. UG/NX二次开发Siemens官方NXOPEN实例解析—2.8 DrawingCycle(图纸打印)

    列文章目录 UG/NX二次开发Siemens官方NXOPEN实例解析-2.1 AssemblyViewer(树列表) UG/NX二次开发Siemens官方NXOPEN实例解析-2.2 Selectio ...

  8. UG/NX二次开发Siemens官方NXOPEN实例解析—1.8 BlockStyler/UDB_CreateCylinder

    列文章目录 UG/NX二次开发Siemens官方NXOPEN实例解析-1.1 BlockStyler/ColoredBlock UG/NX二次开发Siemens官方NXOPEN实例解析-1.2 Blo ...

  9. UG/NX二次开发Siemens官方NXOPEN实例解析—1.2 BlockStyler/EditExpression

    列文章目录 UG/NX二次开发Siemens官方NXOPEN实例解析-1.1 BlockStyler/ColoredBlock UG/NX二次开发Siemens官方NXOPEN实例解析-1.2 Blo ...

最新文章

  1. 初学Java——选择
  2. 多线程1.学习资料2.面试题3.知识点
  3. WARN ServletController:171 - Can't find the the request for xxxx's Observer
  4. js正则--验证6-12位至少包含数字、小写字母和大些字母中至少两种字符,
  5. 移动平台MOBA发热与帧率优化
  6. api过滤器_了解播放过滤器API
  7. 使用Istio进行多集群部署管理:单控制平面 Gateway 连接拓扑
  8. linux下使用dd命令制作ubuntu的u盘启动,Ubuntu使用dd命令制作U盘系统启动盘
  9. python爬取学籍_仝卓学籍造假微博道歉,用Python抓取微博的评论看看群众都说什么...
  10. 为什么大厂都在用 GO 语言?读透 GO 语言的切片
  11. [Toolkit]最新Silverlight Toolkit中的DragDrop支持
  12. 办暂住证,郁闷,极度不爽.
  13. 一台空调的容量是多少_大型中央空调工程节能改造方案如何选择
  14. ATmega328p使用硬件SPI与模拟SPI驱动74HC595,protues仿真
  15. RestTemplate 超时值
  16. 码云上最棒的Java管理后台系统
  17. 当你看不清自己的时候,读一些句子会有启发
  18. Unusual Sequences
  19. java中paras是什么意思_paras的使用方法是什么
  20. 2016“百度之星”-测试赛

热门文章

  1. 在Ubuntu 20.04上面搭建嵌入式开发环境
  2. 【4 于博士Cadence SPB15.7 快速入门视频】建立不规则SOIC封装NE5532
  3. 如何把操作系统迁移到新电脑/硬盘
  4. 2072-歌手大奖赛
  5. 《APP开发》APP规范实例-详细的UI设计方法
  6. ubuntu1404 安装 ppsspp
  7. 「BTC之城」的奇幻漂流
  8. 基于SSM的JSP MYSQL汽车租赁系统的汽车出租管理系统-mysqljava汽车出租管理系统租车管理系统
  9. 四年级计算机课教学安排,四年级计算机教学的计划
  10. 首先下载安装data.table包_首次揭秘“超级签”与企业包行业内幕