本文中主要采用自动编码器(Auto Encoder),生成对抗网络(Generative Adversarial Networks )的深度学习方法来对图像进行修复,主要由数据预处理、模型构建,模型训练和模型测试等部分组成。

一、自定义数据集

利用celeba数据生成数据集: 训练集[3003, 218, 178, 3],测试集[1001, 218, 178, 3],数据处理过程中为方便参数设定,在数据处理过程中改变了图片的尺寸大小[b, 218, 181, 3]
数据处理代码datasets.py如下:

import tensorflow as tf
import glob
import random
import csv
import os
import numpy as npos.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'"""file_path为我们之前获得图片数据的目录,filename为我们要加载的csv文件,pic_dict为我们获取的图片字典"""
def load_csv(file_path, filename, pic_dict):"""如果file_path目录下不存在filename文件,则创建filename文件"""if not os.path.exists(os.path.join(file_path, filename)):images = []"""遍历字典里所有的元素"""for name in pic_dict.keys():"""将该路径下图片的路径写到images列表中"""images += glob.glob(os.path.join(file_path, name, '*.jpg'))with open(os.path.join(file_path, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images:"""遍历images列表, 读取图片路径存入csv文件"""writer.writerow([img])"""打开前面写入的文件,读取图片路径添加到imgs列表"""imgs = []with open(os.path.join(file_path, filename)) as f:reader = csv.reader(f)for row in reader:img = rowimgs.append(img)return imgsdef load_datasets(file_path, mode='train'):"""创建图片字典"""pic_dict = {}"""遍历file_path路径下的文件夹"""for name in sorted(os.listdir(os.path.join(file_path))):"""跳过file_path目录下不是文件夹的文件"""if not os.path.isdir(os.path.join(file_path, name)):continue"""name为file_path目录下文件夹的名字"""pic_dict[name] = len(pic_dict.keys())"""调用load_csv方法,返回值images为储存图片的目录的列表"""images = load_csv(file_path, 'images.csv', pic_dict)"""我们将前60%取为训练集,后20%取为验证集,最后20%取为测试集,并返回"""if mode == 'train':images = images[:int(0.6 * len(images))]elif mode == 'val':images = images[int(0.6 * len(images)):int(0.8 * len(images))]else:images = images[int(0.8 * len(images)):]return  images"""将列表类型转化为tensor类型"""
def get_tensor(x):ims = []print(x)for i in x:"""读取路径下的图片"""p = tf.io.read_file(i)"""对图片进行解码,RGB,3通道"""p = tf.image.decode_jpeg(p, channels=3)"""修改图片大小"""p = tf.image.resize(p, [192, 224])# p = tf.image.resize(p, [181, 218])ims.append(p)"""将List类型转换为tensor类型,并返回"""ims = tf.convert_to_tensor(ims)return ims

二、自动编码器(Auto Encoder)实现

1.加载数据集

"""数据预处理,将3通道0-255的像素值转换为0-1,简化计算"""
def preprocess(x):x = tf.cast(x, dtype=tf.float32) / 255.return x"""加载数据集"""
images_train = load_datasets(root_img, mode='train')#训练集
images_test = load_datasets(root_img, mode='test')#测试集x_train= get_tensor(images_train)#把列表转化为张量,以便运算
x_test= get_tensor(images_test)
print(x_train.shape, x_test.shape)# (x_train, _), (x_test, _) = tf.keras.datasets.cifar10.load_data
db_train = tf.data.Dataset.from_tensor_slices((x_train))#切片操作
db_train = db_train.shuffle(100).map(preprocess).batch(20)#以100为单位进行打乱,每次处理20张图片db_test = tf.data.Dataset.from_tensor_slices((x_test))
db_test = db_test.map(preprocess).batch(20)#测试不用打乱,也不用预处理

2.保存生成图片到文件夹

"""保存25张图片,尺寸为224*192"""
def save_images(imgs, name):new_im = Image.new('RGB', (1120, 960))index = 0for i in range(0, 1120, 224):for j in range(0, 960, 192):im = x_concat[index]im = Image.fromarray(im, mode='RGB')new_im.paste(im, (i, j))index += 1new_im.save(name)

3.定义自编码函数及实现其传播过程

"""定义自动编码器函数"""
class AE(keras.Model):def __init__(self):super(AE, self).__init__()"""定义编码函数""""""(b, 224, 192, 3 ) => (b, 112, 96, 8)"""self.conv1 = layers.Conv2D(8, kernel_size=[3, 3], padding='same', activation= tf.nn.relu)self.down_pool1 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same')"""(b, 112, 96, 8) => (b, 56, 48, 16)"""self.conv2 = layers.Conv2D(16, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)self.down_pool2 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same')"""(b, 56, 48, 16) => (b, 28, 24, 32)"""self.conv3 = layers.Conv2D(32, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)self.down_pool3 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same')"""(b, 28, 24, 32) => (b, 14, 12, 64)"""self.conv4 = layers.Conv2D(64, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)self.down_pool4 = layers.MaxPool2D(pool_size=[2, 2], strides=2, padding='same')"""定义解码函数""""""(b, 14, 12, 64) => (b, 28, 24, 32)"""self.transpose_conv1 = layers.Conv2D(32, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)self.up_pool1 = layers.UpSampling2D(size=[2, 2])"""(b, 28, 24, 32) => (b, 56, 48, 16)"""self.transpose_conv2 = layers.Conv2D(16, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)self.up_pool2 = layers.UpSampling2D(size=[2, 2])"""(b, 56, 48, 16) => (b, 112, 96, 8)"""self.transpose_conv3 = layers.Conv2D(8, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)self.up_pool3 = layers.UpSampling2D(size=[2, 2])"""(b, 112, 96, 8) => (b, 224, 192, 8)"""self.transpose_conv4 = layers.Conv2D(8, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)self.up_pool4 = layers.UpSampling2D(size=[2, 2])"""(b, 224, 192, 8) => (b, 208, 176, 3)"""self.transpose_conv5 = layers.Conv2D(3, kernel_size=[3, 3], padding='same', activation=tf.nn.relu)"""定义编码函数过程"""def encoder(self, x):x = self.conv1(x)x = self.down_pool1(x)x = self.conv2(x)x = self.down_pool2(x)x = self.conv3(x)x = self.down_pool3(x)x = self.conv4(x)x = self.down_pool4(x)return x"""定义解码函数过程"""def decoder(self, x):x = self.transpose_conv1(x)x = self.up_pool1(x)x = self.transpose_conv2(x)x = self.up_pool2(x)x = self.transpose_conv3(x)x = self.up_pool3(x)x = self.transpose_conv4(x)x = self.up_pool4(x)x = self.transpose_conv5(x)return x"""编码和解码"""def call(self, inputs, training =None):encode_num = self.encoder(inputs)decode_num = self.decoder(encode_num)return decode_num
"""定义模型对象,build输入的属性"""
model = AE()
model.build(input_shape=(None, 218, 178, 3))#由于参数设置原因,该ae模型输出的图片尺寸与输入存在偏差
model.summary()

4.训练及预测

"""开始训练"""
for epoch in range(100):for step, x in enumerate(db_train):with tf.GradientTape() as tape:#自动更新权值和偏置x_restruct = tf.nn.sigmoid(model(x))loss = tf.reduce_mean(tf.square(x - x_restruct))#最小二乘法,梯度下降# loss = tf.reduce_mean(tf.losses.categorical_crossentropy(x, x_restruct, from_logits=True))grads = tape.gradient(loss, model.trainable_variables)#计算梯度tf.optimizers.Adam(lr).apply_gradients(zip(grads,model.trainable_variables))#利用梯度进行参数优化if step% 1 == 0:print(epoch, step, float(loss))"""边训练边测试"""x = next(iter(db_test))x_pred =model(x)"""显示图片"""x_concat = tf.concat([x, x_pred], axis=0)x_concat = (x_concat.numpy() * 255.).astype(np.uint8)save_images(x_concat, 'D:\Files\digital_image_inpainting_paper\Restruct_image1\Rec_epoch_%d.png'%epoch)

三、生成对抗网络(GAN)实现

1.定义生成器

'''把100维的噪声通过卷积变成我们想要的张量'''
class Generator(keras.Model):def __init__(self):super(Generator, self).__init__()"""[b, 100] => [b, 224, 192, 3]"""self.fc = layers.Dense(4*5*512)self.de_conv1 = layers.Conv2DTranspose(256, kernel_size=[3, 3], strides=[3, 3], padding='same')self.bn1 = layers.BatchNormalization()self.de_conv2 = layers.Conv2DTranspose(128, kernel_size=[3, 3], strides=[3, 1], padding='same')self.bn2 = layers.BatchNormalization()self.de_conv3 = layers.Conv2DTranspose(64, kernel_size=[3, 3], strides=[3, 3], padding='same')self.bn3 = layers.BatchNormalization()self.de_conv4 = layers.Conv2DTranspose(32, kernel_size=[3, 3], strides=[2, 2], padding='same')self.bn4 = layers.BatchNormalization()self.de_conv5 = layers.Conv2DTranspose(3, kernel_size=[3, 3], strides=[1, 2], padding='valid')def call(self, inputs, training= None):x = self.fc(inputs)x = tf.reshape(x, [-1, 4, 5, 512])x = tf.nn.leaky_relu(x)x = tf.nn.leaky_relu(self.bn1(self.de_conv1(x), training=training))x = tf.nn.leaky_relu(self.bn2(self.de_conv2(x), training=training))x = tf.nn.leaky_relu(self.bn3(self.de_conv3(x), training=training))x = tf.nn.leaky_relu(self.bn4(self.de_conv4(x), training=training))x = self.de_conv5(x)x = tf.tanh(x)return x

2.定义判别器

"""定义判别器, 实质是分类的作用,输出为0~1的概率值"""
class Discriminator(keras.Model):#分类器def __init__(self):super(Discriminator, self).__init__()"""利用卷积进行降维处理,逐渐增加卷积核个数""""""[b, 224, 192, 3] => [b, 1]"""self.conv1 = layers.Conv2D(64, kernel_size=[5, 5], strides=3, padding='valid' )self.bn1 = layers.BatchNormalization()self.conv2 = layers.Conv2D(128, kernel_size=[5, 5], strides=3, padding='valid')self.bn2 = layers.BatchNormalization()self.conv3 = layers.Conv2D(256, kernel_size=[5, 5], strides=3, padding='valid')self.bn3 = layers.BatchNormalization()"""[b, w, h, c]""""""tf.Flatten函数把多维转化为一维,常用于卷积层到全连接层的过渡"""self.flatten = layers.Flatten()#类似于reshape"""定义最后的全连接层神经元个数为1"""self.fc = layers.Dense(1)"""定义传播过程"""def call(self, inputs, training= None):x = tf.nn.leaky_relu(self.bn1(self.conv1(inputs), training= training))x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))x = self.flatten(x)logits = self.fc(x)return logits

3. 定义生成器和判别器损失函数

"""采用sigmoid交叉熵损失函数进行参数更新"""
def celoss_ones(logits, smooth=0.0):return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,labels=tf.ones_like(logits)*(1.0 - smooth)))def celoss_zeros(logits, smooth=0.0):return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,labels=tf.zeros_like(logits)*(1.0 - smooth)))def d_loss_fn(generator, discriminator, input_noise, real_image, is_trainig):fake_image = generator(input_noise, is_trainig)d_real_logits = discriminator(real_image, is_trainig)d_fake_logits = discriminator(fake_image, is_trainig)d_loss_real = celoss_ones(d_real_logits, smooth=0.1)d_loss_fake = celoss_zeros(d_fake_logits, smooth=0.0)loss = d_loss_real + d_loss_fakereturn lossdef g_loss_fn(generator, discriminator, input_noise, is_trainig):fake_image = generator(input_noise, is_trainig)d_fake_logits = discriminator(fake_image, is_trainig)loss = celoss_ones(d_fake_logits, smooth=0.1)return loss

4.GAN模型训练及预测

加载数据集、保存生成图片的函数与前面类似

def main():tf.random.set_seed(22)np.random.seed(22)"""定义随机数种子"""generator = Generator()generator.build(input_shape=(batch_size, z_dim))generator.summary()discriminator = Discriminator()discriminator.build(input_shape=(batch_size, 218, 181, 3))discriminator.summary()"""定义生成器和判别器优化器"""d_optimizer = keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)g_optimizer = keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)for epoch in range(epochs):"""生成服从均匀分布并在—1~+1的随机噪声"""batch_z = tf.random.uniform(shape=[batch_size, z_dim], minval=-1., maxval=1.)"""加载数据"""batch_x = next(db_iter)batch_x = tf.reshape(batch_x, shape=[-1, 218, 181, 3])"""将图片像素转化为-1~+1"""batch_x = batch_x * 2.0 - 1.0"""利用梯度对判别器参数进行自动更新"""with tf.GradientTape() as tape:d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)grads = tape.gradient(d_loss, discriminator.trainable_variables)d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))"""利用梯度对生成器参数进行自动更新"""with tf.GradientTape() as tape:g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)grads = tape.gradient(g_loss, generator.trainable_variables)g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))"""参数每更新100次便进行数据预测"""if epoch % 100 == 0:print(epoch, 'd_loss:', float(d_loss), 'g_loss:', float(g_loss))val_z = np.random.uniform(-1, 1, size=(val_size, z_dim))fake_image = generator(val_z, training=False)x_concat = fake_imagex_concat = (((x_concat.numpy()+1) / 2) * 255.).astype(np.uint8)save_images(x_concat, 'D:\Files\digital_image_inpainting_paper\Restruct_image2\Rec_epoch_%d.png' % (epoch/100))if __name__ == '__main__':main()

四.实验效果

由于时间有限,不能较为全面的掌握深度学习算法,所以代码中参数设置不佳;同时图片尺寸和受实验条件限制,不能对模型很好地训练, 导致实验效果不理想。
AE模型图片修复图如图所示:

最后一列为修复图片,原图片为第一列图片,
GAN模型参数量巨大,模型复杂,训练输出还是一些噪点,在此不作展示

参考:
CSDN: https://blog.csdn.net/sq_damowang/article/details/103291640
课程:深度学习与Tensorflow2入门实战

基于深度学习的图像修复相关推荐

  1. 【图像修复】基于深度学习的图像修复算法的MATLAB仿真

    1.软件版本 matlab2021a 2.本算法理论知识 在许多领域,人们对图像质量的要求都很高,如医学图像领域.卫星遥感领域等.随着信息时代的快速发展,低分辨率图像已经难以满足特定场景的需要.因此, ...

  2. 震撼!英伟达用深度学习做图像修复,毫无ps痕迹

    在计算机视觉研究领域,NVIDIA常常让人眼前一亮. 比如"用Progressive Growing的方式训练 GAN,生成超逼真高清图像","用条件 GAN 进行 20 ...

  3. 毕业设计之 - 基于深度学的图像修复 图像补全

    1 前言 Hi,大家好,这里是丹成学长,今天向大家介绍 基于深度学的图像修复 图像补全 大家可用于 毕业设计 2 什么是图像内容填充修复 内容识别填充(译注: Content-aware fill , ...

  4. 基于深度学习的图像语义编辑

    深度学习在图像分类.物体检测.图像分割等计算机视觉问题上都取得了很大的进展,被认为可以提取图像高层语义特征.基于此,衍生出了很多有意思的图像应用. 为了提升本文的可读性,我们先来看几个效果图. 图1. ...

  5. 论文总结:基于深度学习的图像风格迁移研究

    基于深度学习的图像风格迁移研究 前言 图像风格迁移方法 基于图像迭代的图像风格迁移方法 基于模型迭代的图像风格迁移方法 卷积神经网络 生成对抗网络 CycleGAN 前言 什么是深度学习? 深度学习是 ...

  6. 读“基于深度学习的图像风格迁移研究综述”有感

    前言 关于传统非参数的图像风格迁移方法和现如今基于深度学习的图像风格迁移方法. 基于深度学习的图像风格迁移方法:基于图像迭代和模型迭代的两种方法的优缺点. 基于深度学习的图像风格迁移方法的存在问题及其 ...

  7. 深度学习图像融合_基于深度学习的图像超分辨率最新进展与趋势【附PDF】

    因PDF资源在微信公众号关注公众号:人工智能前沿讲习回复"超分辨"获取文章PDF 1.主题简介 图像超分辨率是计算机视觉和图像处理领域一个非常重要的研究问题,在医疗图像分析.生物特 ...

  8. 学习笔记之——基于深度学习的图像超分辨率重建

    最近开展图像超分辨率( Image Super Resolution)方面的研究,做了一些列的调研,并结合本人的理解总结成本博文~(本博文仅用于本人的学习笔记,不做商业用途) 本博文涉及的paper已 ...

  9. 基于深度学习的图像超分辨率方法 总结

    基于深度学习的SR方法 懒得总结,就从一篇综述中选取了一部分基于深度学习的图像超分辨率方法. 原文:基于深度学习的图像超分辨率复原研究进展 作者:孙旭 李晓光 李嘉锋 卓力 北京工业大学信号与信息处理 ...

  10. 基于深度学习的图像超分辨率重建

    最近开展图像超分辨率( Image Super Resolution)方面的研究,做了一些列的调研,并结合本人的理解总结成本博文~(本博文仅用于本人的学习笔记,不做商业用途) 本博文涉及的paper已 ...

最新文章

  1. Python 3 利用 Dlib 实现人脸检测和剪切
  2. 单例-双重检查锁定与延迟初始化
  3. mysql 表 类型_mysql表类型
  4. 【jQuery学习】—jQuery操作元素位置
  5. jQuery form表单的serialize()参数和其他参数 如何一起传给后端
  6. 【Hoxton.SR1版本】Spring Cloud Gateway之如何进行限流
  7. Redfish Data model (红鱼的资料模型)
  8. Lonlife-ACM 1010 - Alarm(找规律+素数打表)
  9. 复杂适应系统和swarm简介
  10. excel 隔行插入和错位
  11. 广告推广是什么意思?利用文章推广的方法做广告推广技巧总结
  12. 网络安全漏洞管理十大度量指标
  13. 钢筋计数VOC数据集
  14. 段码液晶显示屏液交期有多长?
  15. Numerical Optimization Ch10. Least-Squares Problems
  16. android画直角坐标系,用Android画个五角星
  17. 英语语法总结_01 五种基本句型
  18. 《Android源码设计模式解析与实战》读书笔记(二十一)
  19. 批量处理千万模型,3D开发必备接口程序!老子云新版API,正式上线!
  20. uniapp项目怎么连接手机真机调试

热门文章

  1. Skills | word批量修改图片为统一大小
  2. html创建表格没有网格线,excel里面的电子表格没有了网格线如何解决?
  3. CentOS 7.9命令行配置有线网卡
  4. Python 乌龟吃鱼问题求解
  5. 文件指针以及文件的打开与关闭
  6. 【算法笔记题解】《算法笔记知识点记录》第二章——快速入门4[结构体、输入输出、复杂度和黑盒测试]
  7. c语言绝对值函数作用,C语言实现abs和fabs绝对值
  8. 贪心科技机器学习训练营(六)
  9. Ubuntu安装Microsoft Windows Fonts微软字体库
  10. ODL之VTN详解-如何提供虚拟2层网络-port-map