Image Matting 图像抠图技术与深度学习抠图
Image Matting: 图像抠图技术是指从静态图像或者视频序列中抽取感兴趣目标的过程,在ps和视频编辑中有重要的应用。
1.Image Matting
Matting 技术可以表示为下面的图,与语义分割不同,它可以针对感兴趣前景物体进行细节处理、包括细微的毛发和透明的物体等。
其公式可以表示为前景、掩膜与背景三者间的关系(如果αp\alpha_pαp在0~1范围内,公式表示matting/composition问题,需要考虑透明度;如果αp\alpha_pαp为二值化的0/1,则称为了分割问题。):
matting技术主要包括了交互式抠图、幕布抠图和基于学习的抠图(参考:What’s the Role of Image Matting in Image Segmentation?):
2.Deep Image Matting
而随着深度学习技术的发展研究人员也提出了基于深度学习方法解决图像抠图问题的方案,英伟达的Deep Image Matting就是其中的代表(下图中skip层没有画出来*):
研究人员利用低层级和高层级的信息,利用原图和trimap作为输入得到粗糙的alpha通道map,随后利用一个小的卷积来优化得到更精细的matting结果,同时得到mask的alpha损失。其中还用前景瑜背景进行合成与原图进行比较得到图像组合损失。损失函数的定义如下:
一个是预测蒙版alpha与基准的损失:
Lαi=(api−agi)+ϵ2L_\alpha^i = \sqrt{(a_p^i-a_g^i)+\epsilon^2}Lαi=(api−agi)+ϵ2
另一个是预测出matting 与基准前景、基准背景合成的图像与输入之间的损失:
Lαi=(cpi−cgi)+ϵ2L_\alpha^i = \sqrt{(c_p^i-c_g^i)+\epsilon^2}Lαi=(cpi−cgi)+ϵ2
训练中的数据集来自于http://alphamatting.com/和研究人员自己提出的Composition-1k数据集。
这种方法的结果与其他相比如下:
2.2 代码实现
对于DeepImageMatting中,这里有一份基于Keras的实现:
首先构建了编码器和解码器架构,以及对应的优化模块。可以参照前文的图像理解,编解码部分各包含了五组操作模块。。
# copy from:https://github.com/foamliu/Deep-Image-Matting/blob/master/model.py
def build_encoder_decoder():# Encoder#--------------------编码器部分----------------------------### 根据架构图,编码与解码器各有五个操作组,编码器包括卷积和最大池化,解码器包括卷积和解卷积上采样input_tensor = Input(shape=(320, 320, 4))x = ZeroPadding2D((1, 1))(input_tensor)x = Conv2D(64, (3, 3), activation='relu', name='conv1_1')(x)x = ZeroPadding2D((1, 1))(x)x = Conv2D(64, (3, 3), activation='relu', name='conv1_2')(x)orig_1 = x # 用于做skip-layerx = MaxPooling2D((2, 2), strides=(2, 2))(x)# >>首先两层卷积加一个池化层x = ZeroPadding2D((1, 1))(x)x = Conv2D(128, (3, 3), activation='relu', name='conv2_1')(x)x = ZeroPadding2D((1, 1))(x)x = Conv2D(128, (3, 3), activation='relu', name='conv2_2')(x)orig_2 = x # 用于做skip-layerx = MaxPooling2D((2, 2), strides=(2, 2))(x)# >>第二次两层卷积加一个池化层x = ZeroPadding2D((1, 1))(x)x = Conv2D(256, (3, 3), activation='relu', name='conv3_1')(x)x = ZeroPadding2D((1, 1))(x)x = Conv2D(256, (3, 3), activation='relu', name='conv3_2')(x)x = ZeroPadding2D((1, 1))(x)x = Conv2D(256, (3, 3), activation='relu', name='conv3_3')(x)orig_3 = x # 用于做skip-layerx = MaxPooling2D((2, 2), strides=(2, 2))(x)# >>第一个三层卷积加一个池化层x = ZeroPadding2D((1, 1))(x)x = Conv2D(512, (3, 3), activation='relu', name='conv4_1')(x)x = ZeroPadding2D((1, 1))(x)x = Conv2D(512, (3, 3), activation='relu', name='conv4_2')(x)x = ZeroPadding2D((1, 1))(x)x = Conv2D(512, (3, 3), activation='relu', name='conv4_3')(x)orig_4 = x # 用于做skip-layerx = MaxPooling2D((2, 2), strides=(2, 2))(x)# >>第二个三层卷积加一个池化层x = ZeroPadding2D((1, 1))(x)x = Conv2D(512, (3, 3), activation='relu', name='conv5_1')(x)x = ZeroPadding2D((1, 1))(x)x = Conv2D(512, (3, 3), activation='relu', name='conv5_2')(x)x = ZeroPadding2D((1, 1))(x)x = Conv2D(512, (3, 3), activation='relu', name='conv5_3')(x)orig_5 = x # 用于做skip-layerx = MaxPooling2D((2, 2), strides=(2, 2))(x)# >>第三个三层卷积加一个池化层# 解码器输出编码后的特征图#--------------------解码器部分----------------------------### Decoder# x = Conv2D(4096, (7, 7), activation='relu', padding='valid', name='conv6')(x)# x = BatchNormalization()(x) #细化编码部分,一维编码,没有使用# x = UpSampling2D(size=(7, 7))(x)x = Conv2D(512, (1, 1), activation='relu', padding='same', name='deconv6', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)x = UpSampling2D(size=(2, 2))(x)the_shape = K.int_shape(orig_5)shape = (1, the_shape[1], the_shape[2], the_shape[3])origReshaped = Reshape(shape)(orig_5) # 跳接层# print('origReshaped.shape: ' + str(K.int_shape(origReshaped)))xReshaped = Reshape(shape)(x)# print('xReshaped.shape: ' + str(K.int_shape(xReshaped)))together = Concatenate(axis=1)([origReshaped, xReshaped])# print('together.shape: ' + str(K.int_shape(together)))x = Unpooling()(together)# >>卷积上采样,512个核 Deconv6x = Conv2D(512, (5, 5), activation='relu', padding='same', name='deconv5', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)x = UpSampling2D(size=(2, 2))(x)the_shape = K.int_shape(orig_4)shape = (1, the_shape[1], the_shape[2], the_shape[3])origReshaped = Reshape(shape)(orig_4) # 跳接层xReshaped = Reshape(shape)(x)together = Concatenate(axis=1)([origReshaped, xReshaped])x = Unpooling()(together)# >>卷积上采样,512个5*5核 Deconv5x = Conv2D(256, (5, 5), activation='relu', padding='same', name='deconv4', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)x = UpSampling2D(size=(2, 2))(x)the_shape = K.int_shape(orig_3)shape = (1, the_shape[1], the_shape[2], the_shape[3])origReshaped = Reshape(shape)(orig_3) # 跳接层xReshaped = Reshape(shape)(x)together = Concatenate(axis=1)([origReshaped, xReshaped])x = Unpooling()(together)# >>卷积上采样,256个5*5核 Deconv4x = Conv2D(128, (5, 5), activation='relu', padding='same', name='deconv3', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)x = UpSampling2D(size=(2, 2))(x)the_shape = K.int_shape(orig_2)shape = (1, the_shape[1], the_shape[2], the_shape[3])origReshaped = Reshape(shape)(orig_2) # 跳接层xReshaped = Reshape(shape)(x)together = Concatenate(axis=1)([origReshaped, xReshaped])x = Unpooling()(together)# >>卷积上采样,128个5*5核 Deconv3x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv2', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)x = UpSampling2D(size=(2, 2))(x)the_shape = K.int_shape(orig_1)shape = (1, the_shape[1], the_shape[2], the_shape[3])origReshaped = Reshape(shape)(orig_1) # 跳接层xReshaped = Reshape(shape)(x)together = Concatenate(axis=1)([origReshaped, xReshaped])x = Unpooling()(together)# >>卷积上采样,64个5*5核 Deconv2x = Conv2D(64, (5, 5), activation='relu', padding='same', name='deconv1', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)# >>卷积上采样,64个5*5核 Deconv1x = Conv2D(1, (5, 5), activation='sigmoid', padding='same', name='pred', kernel_initializer='he_normal',bias_initializer='zeros')(x)# >>得到最终的输出Raw Alpha Pred 输出model = Model(inputs=input_tensor, outputs=x)return model
def build_refinement(encoder_decoder):input_tensor = encoder_decoder.inputinput = Lambda(lambda i: i[:, :, :, 0:3])(input_tensor)# 包含了输入的RGB图和编码器输入的粗糙mattex = Concatenate(axis=3)([input, encoder_decoder.output]) x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)x = Conv2D(64, (3, 3), activation='relu', padding='same', kernel_initializer='he_normal',bias_initializer='zeros')(x)x = BatchNormalization()(x)# (Covn+Relu)*三次重复x = Conv2D(1, (3, 3), activation='sigmoid', padding='same', name='refinement_pred', kernel_initializer='he_normal',bias_initializer='zeros')(x)# Covn输入精炼的refine aplha的残差model = Model(inputs=input_tensor, outputs=x)return model
# from:https://github.com/foamliu/Deep-Image-Matting/blob/master/utils.py
# 蒙版预测损失
def alpha_prediction_loss(y_true, y_pred):mask = y_true[:, :, :, 1]diff = y_pred[:, :, :, 0] - y_true[:, :, :, 0]diff = diff * masknum_pixels = K.sum(mask)return K.sum(K.sqrt(K.square(diff) + epsilon_sqr)) / (num_pixels + epsilon)
# 图像重建损失
def compositional_loss(y_true, y_pred):mask = y_true[:, :, :, 1]mask = K.reshape(mask, (-1, img_rows, img_cols, 1))image = y_true[:, :, :, 2:5]fg = y_true[:, :, :, 5:8]bg = y_true[:, :, :, 8:11]c_g = imagec_p = y_pred * fg + (1.0 - y_pred) * bgdiff = c_p - c_gdiff = diff * masknum_pixels = K.sum(mask)return K.sum(K.sqrt(K.square(diff) + epsilon_sqr)) / (num_pixels + epsilon)
#总损失
def overall_loss(y_true, y_pred):w_l = 0.5return w_l * alpha_prediction_loss(y_true, y_pred) + (1 - w_l) * compositional_loss(y_true, y_pred)
# copy from https://github.com/foamliu/Deep-Image-Matting/blob/master/train.py
#----- 省略各种初始化内容 ------## Load our model, added support for Multi-GPUsnum_gpu = len(get_available_gpus())if num_gpu >= 2:with tf.device("/cpu:0"):model = build_encoder_decoder() #构建编码器model = build_refinement(model) #构建精炼模块if pretrained_path is not None:model.load_weights(pretrained_path)else:migrate_model(model)final = multi_gpu_model(model, gpus=num_gpu)decoder_target = tf.placeholder(dtype='float32', shape=(None, None, None, None))final.compile(optimizer='nadam', loss=overall_loss, target_tensors=[decoder_target]) #编译,加入模型、损失等等参数# Final callbackscallbacks = [tensor_board, model_checkpoint, early_stop, reduce_lr] # 优化模型,训练过程final.fit_generator(train_gen(),steps_per_epoch=num_train_samples // batch_size,validation_data=valid_gen(),validation_steps=num_valid_samples // batch_size,epochs=epochs,verbose=1,callbacks=callbacks,use_multiprocessing=True,workers=2)
3.AlphaGAN 自然图像抠图
4.基于深度学习的其他抠图方法
A Late Fusion CNN for Digital Matting,基于融合的Matting方法_CVPR2019,这一方法中使用了两个解码器分支分别预测前景和背景,而后在经过融合得到了更为细腻的alpha通道结果。两个解码器使得模型的容量更大,更容易训练出好结果。同时由于分类器得到了前景背景的中间结果,这种方法还可以提供Trimap。
模型中包含了一个编码器和两个解码器(分割部分),以及一个用于融合的全卷积网络(融合部分)。最后输出的matting是预测出前景和背景的概率图和融合权重map的线性融合结果。
对于任何一张图像来说,可以认为是前景图F、背景图B和对应的蒙版α三者通过下面的公式合成的:
Ip=αpFp+(1−αp)BpI_p = \alpha_pF_p+(1-\alpha_p)B_pIp=αpFp+(1−αp)Bp
一般方法首先学习前景、背景;而后计算每个像素属于前背景的概率;最后得到alpha通道。在本文的方法中,研究人员通过融合输出的前景背景来得到最终的蒙版通道,则alphaalphaalpha的计算公式(也就是融合网络部分)可以被写为:
αp=βpFpˉ+(1−βp)(1−Bpˉ)\alpha_p = \beta_p\bar{F_p}+(1-\beta_p)(1-\bar{B_p})αp=βpFpˉ+(1−βp)(1−Bpˉ)
如果从优化的视角来看,对上述的方程右边对β进行差分,则可以得到下面的式子:
Bpˉ+Fpˉ=1\bar{B_p}+\bar{F_p}=1Bpˉ+Fpˉ=1
融合网络在前景背景预测精确的时候可以聚焦于学习前景和背景间的转变区域的学习、另一方面可以在F+B≠1的渐变区域通过精心设计的损失函数来提供更有效的梯度训练融合网络。
整个网络的数据集主要分为三个部分:基于VOC的25张背景图像和成测试集;基于228 human images with high-quality alpha mattes combined with another 211 human foreground objects from the DIM dataset [39];利用COCO作为背景与人物前景构建合成图像作为训练集;Composition-1k testing dataset in [39]。(from:part4)
下表显示了与其他方法结果的比较:
一些得到的结果:
github:https://github.com/yunkezhang/FusionMatting
Learning-based Sampling for Natural Image Matting,基于前背景颜色层估计的matting方法_cvpr2019,这篇文章通过估计前景和背景颜色层来作为网络先验,最后实现透明度估计。颜色层估计更适合于神经网络,同时颜色的有效性减少了合成公式中的未知量,提高透明度预测的有效性。
模型主要的流程如下图所示,背景可以视为被前景遮挡住的不透明材质,具有连续的纹理和结构;首先基于连续结构估计出背景图像,随后利用背景估计得到前景图像,最后利用前景和背景作为输入来估计最终的matting。
web:http://people.inf.ethz.ch/aksoyy/samplenet/
Natural Image Matting using Deep Convolutional Neural Networks结合局域和非局域的方法来重建高质量的alpha通道_ECCV2016,结合局域和非局域的matting方法来得到全局和局部的高质量alpha。需要融合的两种方法优劣如下:
得到的CNN 如下图所示,输入包含了RGB,和两种形式的alphamap,输出为得到优化后的蒙版:
高分辨率抠图 HDMatt: High-Resolution Deep Image Matting,利用Trimap引导的方法实现长程依赖的获取,提升了基于片元抠图方法的连续性和信息依赖性。
ICCV2019, TPAMI, IndexNet ,在特征图条件下动态预测每个局域区域的索引值,预测出的index引导编码过程中的下采样和解码过程中的上采样。
code
AAAI20 提出的GCAMatting,基于引导上下文注意力的抠图方法。(From SJTU,Yaoyi Li, Hongtao Lu∗)
5.抠图的相关知识教程
5.1 布朗大学的计算摄影和图像操作
5.2 CMU的计算机摄影学
Gradient-domain image processing:matting在这里
Image Matting 图像抠图技术与深度学习抠图相关推荐
- 第 11 章 基于小波技术进行图像融合--MATLAB人工智能深度学习模块
matlab实现基于小波技术进行图像融合–人工智能深度学习模块 该案例相对简单.实现程序 % MAINFORM MATLAB code for MainForm.fig % MAINFORM, by ...
- 深度学习图像融合_基于深度学习的图像超分辨率最新进展与趋势【附PDF】
因PDF资源在微信公众号关注公众号:人工智能前沿讲习回复"超分辨"获取文章PDF 1.主题简介 图像超分辨率是计算机视觉和图像处理领域一个非常重要的研究问题,在医疗图像分析.生物特 ...
- 基于病害区域图像的植物病害识别深度学习(创新点好理解)
Deep Learning for Plant Disease Identification from Disease Region Images 1.摘要解读 [目的]提出了一种利用病理分割的病害区 ...
- 我作为bertelsmann技术和深度学习纳米学位毕业生的经验
One of the responsible things to do when a year is ending is to reflect on it. What accomplishments ...
- 单帧图像超分辨率与深度学习
单帧图像超分辨率技术涉及到增加小图像的大小,同时尽可能地防止其质量下降.这一技术有着广泛用途,包括卫星和航天图像分析.医疗图像处理.压缩图像/视频增强及其他应用.我们将在本文借助三个深度学习模型解决这 ...
- 图像图片处理_深度学习
20220811 图像数据增强:尺寸减小剪裁,水平翻转 中间和四个角分别裁剪, 水平翻转之后,再裁剪 rgb通道做协方差矩阵,减少受光照的影响的程度? imagenet 20220714 202206 ...
- docker 训练深度学习_利用RGB图像训练MultiModality的深度学习模型进行图像分割
▼更多精彩推荐,请关注我们▼ Dragonfly软件的一个特色功能就是可以让用户自己方便快速地训练深度学习的模型,实现图像分割等工作的智能完成.关于Dragonfly里面深度学习工具和智能分割向导工具 ...
- 基于图像的小麦真菌病害深度学习识别(数据+平台)
摘要 由病原真菌引起的谷物病害会显著降低作物产量.许多文化都与他们接触.这种疾病很难大规模控制;因此,相关的方法之一是农田监测,这有助于在早期发现病害,并采取措施防止其传播.基于数字图像分析的疾病识别 ...
- 【图像修复】基于深度学习的图像修复算法的MATLAB仿真
1.软件版本 matlab2021a 2.本算法理论知识 在许多领域,人们对图像质量的要求都很高,如医学图像领域.卫星遥感领域等.随着信息时代的快速发展,低分辨率图像已经难以满足特定场景的需要.因此, ...
- Deep Learning for Image and Point Cloud Fusion in Autonomous Driving: A Review(自动驾驶图像点云融合深度学习综述)论文笔记
原文链接:https://arxiv.org/pdf/2004.05224.pdf II.深度学习的简要回顾 B.点云深度学习 本文将点云深度学习方法分为5类,即基于体素.基于2D视图.基于点.基于图 ...
最新文章
- Spring Boot 2.x基础教程:使用 Thymeleaf开发Web页面
- (三)ajax请求不同源之websocket跨域
- Jquery 获取 radio选中值
- 绝对定位多个字居中显示的css
- 投递简历得不到回复,并不是你的简历不好,可能是这个原因
- 【手机】Windows Mobile手机软件安装卸载方法
- nlp基础—6.EM算法
- 一致性Hash与负载均衡
- Tensorflow实现fashion-mnist数据集的图片识别项目代码
- oracle数据导出工具sqluldr2安装及使用
- 自定义httpSession
- Python项目:爬取IT互联网高薪热门职位数据并进行可视化分析
- PX4飞控学习与开发(六)-利用 VScode 修改源码
- c语言函数初始化,c语言初始化输入和输出函数
- 反病毒工具-WinDBG
- 跑通CHPDet模型
- 各大EMM厂商功能比较 第三部分 Network Gateway比较
- python商务图表_Excel 数据之美:科学图表与商业图表的绘制(全彩)
- Java如何创建支付接口
- 判断点是否在视景体内的参考资料