GAN

  • 6 COGAN(耦合生成对抗网络,1个模型2个用途)
  • 7 LSGAN(最小二乘GAN,均方差替换交叉熵)
    • 7.1 训练思路
  • 8 CycleGAN(风格转换)
    • 8.1 训练思路
  • 9 SRGAN(图像分辨率提升GAN)
    • 9.1 生成网络
    • 9.2 判别网络
    • 9.3 训练思路
    • 9.4 几种尺度
    • 9.5 3个文件代码
      • 第一个文件进行图像数据的读取,
      • 第二个文件,进行模型的训练
      • 第3个文件进行预测
  • 10 总结
  • 参考资料

6 COGAN(耦合生成对抗网络,1个模型2个用途)

COGAN是一种耦合生成式对抗网络,其内部具有一定的耦合,可以对同一个输入有不同的输出。

其具体实现方式就是:
1、建立两个生成模型,两个判别模型。
2、两个生成模型的特征提取部分有一定的重合,在最后生成图片的部分分开,以生成不同类型的图片。
3、两个判别模型的特征提取部分有一定的重合,在最后判别真伪的部分分开,以判别不同类型的图片。

核心思想是权重共享,生成两种不同分割的图片,一个网络两种用途
相当于一个网络实现了两个网络的功能
COGAN的训练思路分为如下几个步骤:
1、创建两个风格不同的数据集。
2、随机生成batch_size个N维向量,利用两个不同的生成模型生成图片。
3、利用两个判别模型分别对两个不同的生成模型的生成图片进行判别、对两个风格不同的数据集进行随机选取并进行判别。
4、根据两个判别模型的结果与1对比,对两个生成模型进行训练。

from __future__ import print_function, division
import scipyfrom tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, GlobalAveragePooling2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import os
import matplotlib.pyplot as plt
import numpy as npclass COGAN():def __init__(self):# 输入shapeself.img_rows = 28self.img_cols = 28self.channels = 1self.img_shape = (self.img_rows, self.img_cols, self.channels)# 分十类self.num_classes = 10self.latent_dim = 100# adam优化器optimizer = Adam(0.0002, 0.5)# 生成两个判别器self.d1, self.d2 = self.build_discriminators()self.d1.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])self.d2.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])# 建立生成器self.g1, self.g2 = self.build_generators()z = Input(shape=(self.latent_dim,))img1 = self.g1(z)img2 = self.g2(z)self.d1.trainable = Falseself.d2.trainable = Falsevalid1 = self.d1(img1)valid2 = self.d2(img2)self.combined = Model(z, [valid1, valid2])self.combined.compile(loss=['binary_crossentropy', 'binary_crossentropy'],optimizer=optimizer)def build_generators(self):# 共享权值部分noise = Input(shape=(self.latent_dim,))x = Dense(32 * 7 * 7, activation="relu", input_dim=self.latent_dim)(noise)x = Reshape((7, 7, 32))(x)x = Conv2D(64, kernel_size=3, padding="same")(x)x = BatchNormalization(momentum=0.8)(x)x = Activation("relu")(x)x = UpSampling2D()(x)x = Conv2D(128, kernel_size=3, padding="same")(x)x = BatchNormalization(momentum=0.8)(x)x = Activation("relu")(x)x = UpSampling2D()(x)x = Conv2D(128, kernel_size=3, padding="same")(x)x = BatchNormalization(momentum=0.8)(x)feature_repr = Activation("relu")(x)model = Model(noise, feature_repr)noise = Input(shape=(self.latent_dim,))feature_repr = model(noise)# 生成模型1g1 = Conv2D(64, kernel_size=1, padding="same")(feature_repr)g1 = BatchNormalization(momentum=0.8)(g1)g1 = Activation("relu")(g1)g1 = Conv2D(64, kernel_size=3, padding="same")(g1)g1 = BatchNormalization(momentum=0.8)(g1)g1 = Activation("relu")(g1)g1 = Conv2D(64, kernel_size=1, padding="same")(g1)g1 = BatchNormalization(momentum=0.8)(g1)g1 = Activation("relu")(g1)g1 = Conv2D(self.channels, kernel_size=1, padding="same")(g1)img1 = Activation("tanh")(g1)# 生成模型2g2 = Conv2D(64, kernel_size=1, padding="same")(feature_repr)g2 = BatchNormalization(momentum=0.8)(g2)g2 = Activation("relu")(g2)g2 = Conv2D(64, kernel_size=3, padding="same")(g2)g2 = BatchNormalization(momentum=0.8)(g2)g2 = Activation("relu")(g2)g2 = Conv2D(64, kernel_size=1, padding="same")(g2)g2 = BatchNormalization(momentum=0.8)(g2)g2 = Activation("relu")(g2)g2 = Conv2D(self.channels, kernel_size=1, padding="same")(g2)img2 = Activation("tanh")(g2)return Model(noise, img1), Model(noise, img2)def build_discriminators(self):# 共享权值部分img = Input(shape=self.img_shape)x = Conv2D(64, kernel_size=3, strides=2, padding="same")(img)x = BatchNormalization(momentum=0.8)(x)x = Activation("relu")(x)x = Conv2D(128, kernel_size=3, strides=2, padding="same")(x)x = BatchNormalization(momentum=0.8)(x)x = Activation("relu")(x)x = Conv2D(64, kernel_size=3, strides=2, padding="same")(x)x = BatchNormalization(momentum=0.8)(x)x = GlobalAveragePooling2D()(x)feature_repr = Activation("relu")(x)model = Model(img, feature_repr)img1 = Input(shape=self.img_shape)img2 = Input(shape=self.img_shape)img1_embedding = model(img1)img2_embedding = model(img2)# 生成评价模型1validity1 = Dense(1, activation='sigmoid')(img1_embedding)# 生成评价模型2validity2 = Dense(1, activation='sigmoid')(img2_embedding)return Model(img1, validity1), Model(img2, validity2)def train(self, epochs, batch_size=128, sample_interval=50):(X_train, _), (_, _) = mnist.load_data()X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)X1 = X_train[:int(X_train.shape[0] / 2)]X2 = X_train[int(X_train.shape[0] / 2):]X2 = scipy.ndimage.interpolation.rotate(X2, 90, axes=(1, 2))valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# ---------------------- ##  训练评价者# ---------------------- #idx = np.random.randint(0, X1.shape[0], batch_size)imgs1 = X1[idx]imgs2 = X2[idx]noise = np.random.normal(0, 1, (batch_size, 100))gen_imgs1 = self.g1.predict(noise)gen_imgs2 = self.g2.predict(noise)d1_loss_real = self.d1.train_on_batch(imgs1, valid)d2_loss_real = self.d2.train_on_batch(imgs2, valid)d1_loss_fake = self.d1.train_on_batch(gen_imgs1, fake)d2_loss_fake = self.d2.train_on_batch(gen_imgs2, fake)d1_loss = 0.5 * np.add(d1_loss_real, d1_loss_fake)d2_loss = 0.5 * np.add(d2_loss_real, d2_loss_fake)# ------------------ ##  训练生成模型# ------------------ #g_loss = self.combined.train_on_batch(noise, [valid, valid])print("%d [D1 loss: %f, acc.: %.2f%%] [D2 loss: %f, acc.: %.2f%%] [G loss: %f]" \% (epoch, d1_loss[0], 100 * d1_loss[1], d2_loss[0], 100 * d2_loss[1], g_loss[0]))if epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 4, 4noise = np.random.normal(0, 1, (r * int(c / 2), 100))gen_imgs1 = self.g1.predict(noise)gen_imgs2 = self.g2.predict(noise)gen_imgs = np.concatenate([gen_imgs1, gen_imgs2])gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')axs[i, j].axis('off')cnt += 1fig.savefig("images/mnist_%d.png" % epoch)plt.close()if __name__ == '__main__':if not os.path.exists("./images"):os.makedirs("./images")gan = COGAN()gan.train(epochs=30000, batch_size=256, sample_interval=200)

训练结果很差,基本判断原因是因为学习率太大,生成网络的损失值一直下不去

7 LSGAN(最小二乘GAN,均方差替换交叉熵)

LSGAN是一种最小二乘GAN。

其主要特点为将loss函数的计算方式由交叉熵更改为均方差。

无论是判别模型的训练,还是生成模型的训练,都需要将交叉熵更改为均方差。

在普通GAN的基础上替换损失函数

7.1 训练思路

LSGAN的训练思路分为如下几个步骤:
1、随机选取batch_size个真实的图片。
2、随机生成batch_size个N维向量,传入到Generator中生成batch_size个虚假图片。
3、真实图片的label为1,虚假图片的label为0,将真实图片和虚假图片当作训练集传入到Discriminator中进行训练,训练的loss使用均方差。
4、将虚假图片的Discriminator预测结果与1的对比作为loss对Generator进行训练(与1对比的意思是,如果Discriminator将虚假图片判断为1,说明这个生成的图片很“真实”),这个loss同样使用均方差。

from __future__ import print_function, divisionfrom tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import sys
import osclass LSGAN():def __init__(self):self.img_rows = 28self.img_cols = 28self.channels = 1self.img_shape = (self.img_rows, self.img_cols, self.channels)self.latent_dim = 100optimizer = Adam(0.0002, 0.5)self.discriminator = self.build_discriminator()self.discriminator.compile(loss='mse',optimizer=optimizer,metrics=['accuracy'])self.generator = self.build_generator()z = Input(shape=(self.latent_dim,))img = self.generator(z)self.discriminator.trainable = Falsevalid = self.discriminator(img)self.combined = Model(z, valid)self.combined.compile(loss='mse', optimizer=optimizer)def build_generator(self):# --------------------------------- ##   生成器,输入一串随机数字# --------------------------------- #model = Sequential()model.add(Dense(256, input_dim=self.latent_dim))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(1024))model.add(LeakyReLU(alpha=0.2))model.add(BatchNormalization(momentum=0.8))model.add(Dense(np.prod(self.img_shape), activation='tanh'))model.add(Reshape(self.img_shape))noise = Input(shape=(self.latent_dim,))img = model(noise)return Model(noise, img)def build_discriminator(self):# ----------------------------------- ##   评价器,对输入进来的图片进行评价# ----------------------------------- #model = Sequential()# 输入一张图片model.add(Flatten(input_shape=self.img_shape))model.add(Dense(512))model.add(LeakyReLU(alpha=0.2))model.add(Dense(256))model.add(LeakyReLU(alpha=0.2))# 判断真伪model.add(Dense(1))img = Input(shape=self.img_shape)validity = model(img)return Model(img, validity)def train(self, epochs, batch_size=128, sample_interval=50):(X_train, _), (_, _) = mnist.load_data()X_train = (X_train.astype(np.float32) - 127.5) / 127.5X_train = np.expand_dims(X_train, axis=3)valid = np.ones((batch_size, 1))fake = np.zeros((batch_size, 1))for epoch in range(epochs):# --------------------------- ##   随机选取batch_size个图片#   对discriminator进行训练# --------------------------- #idx = np.random.randint(0, X_train.shape[0], batch_size)imgs = X_train[idx]noise = np.random.normal(0, 1, (batch_size, self.latent_dim))gen_imgs = self.generator.predict(noise)d_loss_real = self.discriminator.train_on_batch(imgs, valid)d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# --------------------------- ##  训练generator# --------------------------- #noise = np.random.normal(0, 1, (batch_size, self.latent_dim))g_loss = self.combined.train_on_batch(noise, valid)print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))if epoch % sample_interval == 0:self.sample_images(epoch)def sample_images(self, epoch):r, c = 5, 5noise = np.random.normal(0, 1, (r * c, self.latent_dim))gen_imgs = self.generator.predict(noise)gen_imgs = 0.5 * gen_imgs + 0.5fig, axs = plt.subplots(r, c)cnt = 0for i in range(r):for j in range(c):axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')axs[i,j].axis('off')cnt += 1fig.savefig("images/%d.png" % epoch)plt.close()if __name__ == '__main__':if not os.path.exists("./images"):os.makedirs("./images")gan = LSGAN()gan.train(epochs=30000, batch_size=512, sample_interval=200)

8 CycleGAN(风格转换)

CycleGAN是一种完成图像到图像的转换的一种GAN。

图像到图像的转换是一类视觉和图形问题,其目标是获得输入图像和输出图像之间的映射。

但是,对于许多任务,配对的训练数据将不可用。

CycleGAN提出了一种在没有成对例子的情况下学习将图像从源域X转换到目标域Y的方法。


这样的结构与我们所学过的语义分割的形式非常类似,因此需要先进行下采样后再进行上采样!

8.1 训练思路

CycleGAN的训练思路分为如下几个步骤:
1、创建两个生成模型,一个用于从图片风格A转换成图片风格B,一个用于从图片风格B转换成图片风格A。
2、创建两个判别模型,分别用于风格A图片的真伪判断和风格B图片的真伪判断。
3、判别模型的训练所用的损失函数与LSGAN相同,通过判断是否正确进行训练。
4、生成模型的训练需要满足下面六个准则:

a、从图片风格A转换成图片风格B的假图像需要成功欺骗判断模型B;
b、从图片风格B转换成图片风格A的假图像需要成功欺骗判断模型A;
c、从图片风格A转换成图片风格B的假图像可以通过生成模型BA成功转换成图片A;
d、从图片风格B转换成图片风格A的假图像可以通过生成模型AB成功转换成图片B;
e、真实图片A通过生成模型BA,不会发生变化。
f、真实图片B通过生成模型AB,不会发生变化。
其中c、d准则是为了让生成器找到最需要修改的地方,比如 斑马转黄马就只要改变马的颜色就可以欺骗判断模型,风格A的图片经过生成模型AB只需要转化 斑马 即可。
其中e、f准则是为了让 两种生成模型可以区分两种图片风格,生成模型AB只对风格A的图片进行处理,生成模型BA只对风格B的图片进行处理。

这个代码有些问题,自己用tensorflow复现的时候有一些问题直接看下面的博客吧
(8条消息) 好像还挺好玩的GAN7——CycleGAN实现图像风格转换_Bubbliiiing的学习小课堂-CSDN博客_gan风格转换
https://blog.csdn.net/weixin_44791964/article/details/103780922

在这里插入代码片

9 SRGAN(图像分辨率提升GAN)

SRGAN出自论文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

其主要的功能就是输入一张低分辨率图片,生成高分辨率图片。

文章提到,普通的超分辨率模型训练网络时只用到了均方差作为损失函数,虽然能够获得很高的峰值信噪比,但是恢复出来的图像通常会丢失高频细节。

SRGAN利用感知损失(perceptual loss)和对抗损失(adversarial loss)来提升恢复出的图片的真实感。

其中感知损失是利用卷积神经网络提取出的特征,通过比较生成图片经过卷积神经网络后的特征和目标图片经过卷积神经网络后的特征的差别,使生成图片和目标图片在语义和风格上更相似

对抗损失由GAN提供,根据图像是否可以欺骗过判别网络进行训练。

9.1 生成网络


此图从左至右来看,我们可以知道:
SRGAN的生成网络由三个部分组成。
1、低分辨率图像进入后会经过一个卷积+RELU函数
2、然后经过B个残差网络结构,每个残差网络内部包含两个卷积+标准化+RELU,还有一个残差边。
3、然后进入上采样部分,将长宽进行放大,两次上采样后,变为原来的4倍,实现提高分辨率。

前两部分用于特征提取,第三部分用于提高分辨率。

9.2 判别网络


SRGAN的判别网络由不断重复的 卷积+LeakyRELU和标准化 组成。

9.3 训练思路

1、对判别模型进行训练
将真实的高分辨率图像和虚假的高分辨率图像传入判别模型中。
将真实的高分辨率图像的判别结果与1对比得到loss。
将虚假的高分辨率图像的判别结果与0对比得到loss。
利用得到的loss进行训练。
两个loss
2、对生成模型进行训练
两个loss
将低分辨率图像传入生成模型,得到高分辨率图像,利用该高分辨率图像获得判别结果与1进行对比得到loss。
将真实的高分辨率图像和虚假的高分辨率图像传入VGG网络,获得两个图像的特征,通过这两个图像的特征进行比较获得loss。

3、数据集只需要高分辨率图片就行了,通过直接降维得到低分辨率图片

9.4 几种尺度

(1)512512,这是原始高分辨率图像
(2)128
128,这是低分辨率图像
(3)512512,这是低分辨率图像生成的高分辨率图像
(4)32
32,这是vgg19得到的特征进行对比

9.5 3个文件代码

第一个文件进行图像数据的读取,

并生成低分辨率图像,默认高分辨率128128,低分辨率3232,但是在本实验中,传递的参数是高分辨率512512,低分辨率128128

import scipy
from glob import glob
import numpy as np
import matplotlib.pyplot as pltclass DataLoader():def __init__(self, dataset_name, img_res=(128, 128)):self.dataset_name = dataset_nameself.img_res = img_resdef load_data(self, batch_size=1, is_testing=False):data_type = "train" if not is_testing else "test"path = glob('./datasets/%s/train/*' % (self.dataset_name))batch_images = np.random.choice(path, size=batch_size)imgs_hr = []imgs_lr = []for img_path in batch_images:img = self.imread(img_path)h, w = self.img_reslow_h, low_w = int(h / 4), int(w / 4)img_hr = scipy.misc.imresize(img, self.img_res)   # 高分辨率是128*128img_lr = scipy.misc.imresize(img, (low_h, low_w))  # 低分辨率是32*32# If training => do random flipif not is_testing and np.random.random() < 0.5:img_hr = np.fliplr(img_hr)img_lr = np.fliplr(img_lr)imgs_hr.append(img_hr)imgs_lr.append(img_lr)imgs_hr = np.array(imgs_hr) / 127.5 - 1.imgs_lr = np.array(imgs_lr) / 127.5 - 1.return imgs_hr, imgs_lrdef imread(self, path):return scipy.misc.imread(path, mode='RGB').astype(np.float)

第二个文件,进行模型的训练

遇到scipy 的版本问题报错,根据报错信息显示,直接改版本就行了
conda install scipy==1.2.1 就可以

from __future__ import print_function, division
import scipy#from keras.datasets import mnist
#from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from tensorflow.keras.layers import PReLU, LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.applications import VGG19
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import osimport tensorflow.keras.backend as Kclass SRGAN():def __init__(self):# 低分辨率图的shapeself.channels = 3self.lr_height = 128self.lr_width = 128self.lr_shape = (self.lr_height, self.lr_width, self.channels)# 高分辨率图的shapeself.hr_height = self.lr_height*4self.hr_width = self.lr_width*4self.hr_shape = (self.hr_height, self.hr_width, self.channels)# 16个残差卷积块self.n_residual_blocks = 16# 优化器optimizer = Adam(0.0002, 0.5)# 创建VGG模型,该模型用于提取特征self.vgg = self.build_vgg()self.vgg.trainable = False# 数据集self.dataset_name = 'DIV2K_train_HR'self.data_loader = DataLoader(dataset_name=self.dataset_name,img_res=(self.hr_height, self.hr_width))# patch是干什么的patch = int(self.hr_height / 2**4)self.disc_patch = (patch, patch, 1)# 建立判别模型self.discriminator = self.build_discriminator()self.discriminator.compile(loss='binary_crossentropy',optimizer=optimizer,metrics=['accuracy'])self.discriminator.summary()# 建立生成模型self.generator = self.build_generator()self.generator.summary()# 将生成模型和判别模型结合。训练生成模型的时候不训练判别模型。img_lr = Input(shape=self.lr_shape)fake_hr = self.generator(img_lr)fake_features = self.vgg(fake_hr)self.discriminator.trainable = Falsevalidity = self.discriminator(fake_hr)self.combined = Model(img_lr, [validity, fake_features])self.combined.compile(loss=['binary_crossentropy', 'mse'],loss_weights=[5e-1, 1],optimizer=optimizer)def build_vgg(self):# 建立VGG模型,只使用第9层的特征vgg = VGG19(weights="imagenet")vgg.outputs = [vgg.layers[9].output]img = Input(shape=self.hr_shape)#img_features = vgg(img)#return Model(img, img_features)return Model(vgg.input, vgg.outputs)def build_generator(self):def residual_block(layer_input, filters):d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)d = BatchNormalization(momentum=0.8)(d)d = Activation('relu')(d)d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)d = BatchNormalization(momentum=0.8)(d)d = Add()([d, layer_input])return ddef deconv2d(layer_input):u = UpSampling2D(size=2)(layer_input)u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)u = Activation('relu')(u)return u# 128,128,3img_lr = Input(shape=self.lr_shape)# 第一部分,低分辨率图像进入后会经过一个卷积+RELU函数c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)c1 = Activation('relu')(c1)# 第二部分,经过16个残差网络结构,每个残差网络内部包含两个卷积+标准化+RELU,还有一个残差边。r = residual_block(c1, 64)for _ in range(self.n_residual_blocks - 1):r = residual_block(r, 64)# 第三部分,上采样部分,将长宽进行放大,两次上采样后,变为原来的4倍,实现提高分辨率。c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)c2 = BatchNormalization(momentum=0.8)(c2)c2 = Add()([c2, c1])u1 = deconv2d(c2)u2 = deconv2d(u1)gen_hr = Conv2D(self.channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)return Model(img_lr, gen_hr)def build_discriminator(self):def d_block(layer_input, filters, strides=1, bn=True):"""Discriminator layer"""d = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)d = LeakyReLU(alpha=0.2)(d)if bn:d = BatchNormalization(momentum=0.8)(d)return d# 由一堆的卷积+LeakyReLU+BatchNor构成d0 = Input(shape=self.hr_shape)d1 = d_block(d0, 64, bn=False)d2 = d_block(d1, 64, strides=2)d3 = d_block(d2, 128)d4 = d_block(d3, 128, strides=2)d5 = d_block(d4, 256)d6 = d_block(d5, 256, strides=2)d7 = d_block(d6, 512)d8 = d_block(d7, 512, strides=2)d9 = Dense(64*16)(d8)d10 = LeakyReLU(alpha=0.2)(d9)validity = Dense(1, activation='sigmoid')(d10)return Model(d0, validity)def scheduler(self,models,epoch):# 学习率下降if epoch % 20000 == 0 and epoch != 0:for model in models:lr = K.get_value(model.optimizer.lr)K.set_value(model.optimizer.lr, lr * 0.5)print("lr changed to {}".format(lr * 0.5))def train(self, epochs ,init_epoch=0, batch_size=1, sample_interval=50):start_time = datetime.datetime.now()if init_epoch!= 0:self.generator.load_weights("weights/%s/gen_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)self.discriminator.load_weights("weights/%s/dis_epoch%d.h5" % (self.dataset_name, init_epoch),skip_mismatch=True)for epoch in range(init_epoch,epochs):self.scheduler([self.combined,self.discriminator],epoch)# ---------------------- ##  训练判别网络# ---------------------- #imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)fake_hr = self.generator.predict(imgs_lr)valid = np.ones((batch_size,) + self.disc_patch)    # 这里的self.disc_patch应该是特征图fake = np.zeros((batch_size,) + self.disc_patch)# imgs_hr和fake_hr是对应的真实x,valid和fake是真值y,就是为了让得到valid和faked_loss_real = self.discriminator.train_on_batch(imgs_hr, valid)d_loss_fake = self.discriminator.train_on_batch(fake_hr, fake)d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------- ##  训练生成网络# ---------------------- #imgs_hr, imgs_lr = self.data_loader.load_data(batch_size)valid = np.ones((batch_size,) + self.disc_patch)image_features = self.vgg.predict(imgs_hr)g_loss = self.combined.train_on_batch(imgs_lr, [valid, image_features])print(d_loss,g_loss)elapsed_time = datetime.datetime.now() - start_timeprint ("[Epoch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, feature loss: %05f] time: %s " \% ( epoch, epochs,d_loss[0], 100*d_loss[1],g_loss[1],g_loss[2],elapsed_time))if epoch % sample_interval == 0:self.sample_images(epoch)# 保存if epoch % 500 == 0 and epoch != init_epoch:os.makedirs('weights/%s' % self.dataset_name, exist_ok=True)self.generator.save_weights("weights/%s/gen_epoch%d.h5" % (self.dataset_name, epoch))self.discriminator.save_weights("weights/%s/dis_epoch%d.h5" % (self.dataset_name, epoch))def sample_images(self, epoch):os.makedirs('images/%s' % self.dataset_name, exist_ok=True)r, c = 2, 2imgs_hr, imgs_lr = self.data_loader.load_data(batch_size=2, is_testing=True)fake_hr = self.generator.predict(imgs_lr)imgs_lr = 0.5 * imgs_lr + 0.5fake_hr = 0.5 * fake_hr + 0.5imgs_hr = 0.5 * imgs_hr + 0.5titles = ['Generated', 'Original']fig, axs = plt.subplots(r, c)cnt = 0for row in range(r):for col, image in enumerate([fake_hr, imgs_hr]):axs[row, col].imshow(image[row])axs[row, col].set_title(titles[col])axs[row, col].axis('off')cnt += 1fig.savefig("images/%s/%d.png" % (self.dataset_name, epoch))plt.close()for i in range(r):fig = plt.figure()plt.imshow(imgs_lr[i])fig.savefig('images/%s/%d_lowres%d.png' % (self.dataset_name, epoch, i))plt.close()if __name__ == '__main__':gan = SRGAN()gan.train(epochs=60000,init_epoch = 0, batch_size=1, sample_interval=50)

第3个文件进行预测

加载模型和权重进行预测

#from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from tensorflow.keras.layers import PReLU, LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from srgan import SRGAN
from PIL import Image
import numpy as np
import os
import tensorflow as tf
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"def build_generator():def residual_block(layer_input, filters):d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)d = Activation('relu')(d)d = BatchNormalization(momentum=0.8)(d)d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)d = BatchNormalization(momentum=0.8)(d)d = Add()([d, layer_input])return ddef deconv2d(layer_input):u = UpSampling2D(size=2)(layer_input)u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)u = Activation('relu')(u)return uimg_lr = Input(shape=[None,None,3])# 第一部分,低分辨率图像进入后会经过一个卷积+RELU函数c1 = Conv2D(64, kernel_size=9, strides=1, padding='same')(img_lr)c1 = Activation('relu')(c1)# 第二部分,经过16个残差网络结构,每个残差网络内部包含两个卷积+标准化+RELU,还有一个残差边。r = residual_block(c1, 64)for _ in range(15):r = residual_block(r, 64)# 第三部分,上采样部分,将长宽进行放大,两次上采样后,变为原来的4倍,实现提高分辨率。c2 = Conv2D(64, kernel_size=3, strides=1, padding='same')(r)c2 = BatchNormalization(momentum=0.8)(c2)c2 = Add()([c2, c1])u1 = deconv2d(c2)u2 = deconv2d(u1)gen_hr = Conv2D(3, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)return Model(img_lr, gen_hr)model = build_generator()
model.load_weights(r"weights\DIV2K_train_HR\gen_epoch6000.h5")
before_image = Image.open(r"./images/before.png")new_image = Image.new('RGB', before_image.size, (128,128,128))
new_image.paste(before_image)new_image = np.array(new_image)/127.5 - 1
print("图像大小:",new_image.shape)
new_image = np.expand_dims(new_image,axis=0)
fake = (model.predict(new_image)*0.5 + 0.5)*255fake = Image.fromarray(np.uint8(fake[0]))fake.save("out.png")
fake.show()

10 总结

GAN的基本思想:
(1)构建一个普通的生成模型,比如vgg16等或者unet等
(2)构建一个普通的二分类的判别模型;
(3)先训练一个batch的判别模型,再训练一个batch的生成模型;
(4)训练判别模型就直接根据前向传播,真是图像和生成的假图像各占一半得到判别模型的损失值,训练判别模型。
(5)训练生成模型,需要用到判别模型的结果在这次生成的图像的判别结果,所以要构建从生成到判别的一条流程的模型,得到生成模型的损失值,训练生成模型。
上面的8个模型看懂一个就行

参考资料

【1】Keras 搭建自己的GAN生成对抗网络平台(Bubbliiiing 深度学习 教程)_哔哩哔哩_bilibili
https://www.bilibili.com/video/BV13J41187Fo?from=search&seid=14309111542489072351&spm_id_from=333.337.0.0

【2】下面的链接包含的研究很多
生成对抗网络的生成样本能否提高预测模型准确率? - 知乎
https://www.zhihu.com/question/372133109

【3】(8条消息) 好像还挺好玩的GAN_Bubbliiiing的学习小课堂-CSDN博客
https://blog.csdn.net/weixin_44791964/category_9625179.html

DLbest系列1——GAN生成对抗网络2相关推荐

  1. [Pytorch系列-72]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型训练CycleGAN模型

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  2. [Pytorch系列-69]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - test.py代码详解

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:[Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleG ...

  3. [Pytorch系列-66]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - 使用预训练模型测试pix2pix模型

    作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客 本文网址:https://blog.csdn.net/HiWangWenBing/article/detai ...

  4. 54_pytorch GAN(生成对抗网络)、Gan代码示例、WGAN代码示例

    1.54.GAN(生成对抗网络) 1.54.1.什么是GAN 2014 年,Ian Goodfellow 和他在蒙特利尔大学的同事发表了一篇震撼学界的论文.没错,我说的就是<Generative ...

  5. 深度学习(九) GAN 生成对抗网络 理论部分

    GAN 生成对抗网络 理论部分 前言 一.Pixel RNN 1.图片的生成模型 2.Pixel RNN 3.Pixel CNN 二.VAE(Variational Autoencoder) 1.VA ...

  6. 深度学习 GAN生成对抗网络-1010格式数据生成简单案例

    一.前言 本文不花费大量的篇幅来推导数学公式,而是使用一个非常简单的案例来帮助我们了解GAN生成对抗网络. 二.GAN概念 生成对抗网络(Generative Adversarial Networks ...

  7. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 下

    文章目录 1 测试鉴别器 2 建立生成器 3 测试生成器 4 训练生成器 5 使用生成器 6 内存查看 上一节,我们已经建立好了模型所必需的鉴别器类与Dataset类. 使用PyTorch构建GAN生 ...

  8. 使用PyTorch构建GAN生成对抗网络源码(详细步骤讲解+注释版)02 人脸识别 上

    文章目录 1 数据集描述 2 GPU设置 3 设置Dataset类 4 设置辨别器类 5 辅助函数与辅助类 1 数据集描述 此项目使用的是著名的celebA(CelebFaces Attribute) ...

  9. GAN (生成对抗网络) 手写数字图片生成

    GAN (生成对抗网络) 手写数字图片生成 文章目录 GAN (生成对抗网络) 手写数字图片生成 Discriminator Network Generator Network 简单版本的生成对抗网络 ...

最新文章

  1. 拟合一条曲线_数据预测与曲线拟合
  2. Linux免设置路由端口映射,2014/04/01 演示中设置linux路由器、端口过滤的使用、路由设置...
  3. matlab中安装libsvm时No supported compiler or SDK was found问题
  4. 计算机制作培训通知知识点,计算机学习计划(通用3篇)
  5. 自动驾驶1-6: 推动决策和行动Driving Decisions and Actions
  6. 图像取证:由色差发现数字篡改痕迹
  7. MangoDB的基本操作
  8. 小暑海报文案|小暑海报设计图片素材
  9. 新计算机c盘太小,电脑C盘太小,F盘太大,怎么重新调整分区容量?
  10. 游戏资讯平台APP项目计划书
  11. 在mcreator里创建你的第一个模组
  12. 浅谈javascript的原型和原型链(新手懵懂想学会原型链?看这篇文章就足够啦!!!)
  13. 图:国行HTC 8X修改市场区域
  14. 解决 MUI QQ登陆功能报错“该应用非官方正版应用,请到......100044”
  15. AJAX基础入门实例教程(含代码)
  16. 天才小毒妃 第879章 不许欺负伤残人氏
  17. 火出圈的ChatGPT,如何让安全检测更智能
  18. 一键申报税务,就找小帮软件机器人...一键申报
  19. 关于安装Adobe Illustrator AI CC 2017中遇到的问题总结
  20. iOS设备中WiFi、蓝牙和飞行模式的开启与关闭

热门文章

  1. HTML 动态夜空特效
  2. 平面设计中的插画设计技巧
  3. 《大明王朝》赵贞吉的拉扯
  4. 用PS怎样把一个字体居中整个图片
  5. A05-没有奥维vip,导入元素个数受限制,想导入三区三线图斑怎么办?
  6. 声声慢 - 程序人生
  7. 翟佳:高可用、强一致、低延迟——BookKeeper的存储实现
  8. StreamNative 联合创始人翟佳出席QCon北京峰会并发表演讲
  9. scuba 报表_是否想了解JavaScript的for循环? 这个动画的SCUBA潜水员可以提供帮助!...
  10. Won a Air Purifier in DD lucky draw