1.SGAN简介   

       半监督学习只为训练数据集的一小部分提供类别标签。通过内部数据中的隐藏结构,半监督学习从标注数据点的小子集中归纳,以有效对从未见过的新样本进行分类。

       要使半监督学习有效,标签数据和无标签数据必须来自相同分布。

       半监督生成对抗网络是一种生成对抗网络,其鉴定器是多分类器,不止区分真假两个类,而是学会区分N+1类,其中N是训练数据集中的类数,生成器 生成的伪样本为一个类。

SGAN主要关心的是鉴别器。训练过程的目标是使该网络成为仅使用一部分标签数据的半监督分类器,其准确率接近全监督的分类器(就是其训练数据集中的每个样本都有标签)。

1.2架构图

 生成器将随机噪音转换为伪样本;鉴别器输入有标签的真实图像(x,y),无标签的真实图像(x)和生成器生成的伪图像(x*)。先用sigmoid区分真伪,然后使用softmax区分类别。

2.代码实现

2.1导入声明

%matplotlib inlineimport matplotlib.pyplot as plt
import numpy as npfrom keras import backend as Kfrom keras.datasets import mnist
from keras.layers import (Activation, BatchNormalization, Concatenate, Dense,Dropout, Flatten, Input, Lambda, Reshape)
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.models import Model, Sequential
from keras.optimizer_v2 import adam as Adam
from tensorflow.keras.utils import to_categorical

2.2模型输入维度

img_rows = 28
img_cols = 28
channels = 1# 输入图像维度
img_shape = (img_rows, img_cols, channels)# 噪声向量的大小,用作生成器的输入
z_dim = 100# 数据集中类别的数量
num_classes = 10

2.3数据集

此处使用的是MNIST数据集,里面包含50000张含有标签的图片,但是我们只取其中一部分用于训练,其他的都是假设其没有标签。

class Dataset:def __init__(self, num_labeled):# 训练中使用的有标签图像的数量self.num_labeled = num_labeled# 加载MINST数据集(self.x_train, self.y_train), (self.x_test,self.y_test) = mnist.load_data()def preprocess_imgs(x):# 灰度像素值从[0,255]缩放到[-1,1]x = (x.astype(np.float32) - 127.5) / 127.5# 将图像尺寸扩展到宽*高*通道数x = np.expand_dims(x, axis=3)return xdef preprocess_labels(y):return y.reshape(-1, 1)# 训练数据self.x_train = preprocess_imgs(self.x_train)self.y_train = preprocess_labels(self.y_train)# 测试数据self.x_test = preprocess_imgs(self.x_test)self.y_test = preprocess_labels(self.y_test)def batch_labeled(self, batch_size):#获取随机批量有标签图像及其标签idx = np.random.randint(0, self.num_labeled, batch_size)imgs = self.x_train[idx]labels = self.y_train[idx]return imgs, labelsdef batch_unlabeled(self, batch_size):# 获取随机批量的无标签图像idx = np.random.randint(self.num_labeled, self.x_train.shape[0],batch_size)imgs = self.x_train[idx]return imgsdef training_set(self):x_train = self.x_train[range(self.num_labeled)]y_train = self.y_train[range(self.num_labeled)]return x_train, y_traindef test_set(self):return self.x_test, self.y_test# 要使用的有标签样本的数量
num_labeled = 100dataset = Dataset(num_labeled)

2.4生成器

此处的生成器网络和DCGAN网络相同,只用转置卷积将随机噪声向量转换为28*28*1的图像

def build_generator(z_dim):model = Sequential()# 通过一个全连接层改变输入为一个7*7*256的张量model.add(Dense(256 * 7 * 7, input_dim=z_dim))model.add(Reshape((7, 7, 256)))# 转置卷积层,张量从7*7*256变为14*14*128model.add(Conv2DTranspose(128, kernel_size=3, strides=2, padding='same'))# 批归一化model.add(BatchNormalization())# Leaky ReLU 的激活函数model.add(LeakyReLU(alpha=0.01))# 转置卷积层,张量从14*14*128变为14*14*64model.add(Conv2DTranspose(64, kernel_size=3, strides=1, padding='same'))# 拟归一化model.add(BatchNormalization())# Leaky ReLU 的激活函数model.add(LeakyReLU(alpha=0.01))#转置卷积层,张量从14*14*64变为28*28*1model.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'))# 带有tanh激活函数的输出层model.add(Activation('tanh'))return model

2.5鉴别器

有着双重目标:

区别真实样本和伪样本。(使用sigmoid函数)

对于真实样本,还要对其标签进行分类。(使用softmax函数)

2.5.1核心鉴别网络

添加一个dropout,通过在训练过程中随机丢弃神经元来防止过拟合。

def build_discriminator_net(img_shape):model = Sequential()# 卷积层,张量从28*28*1变成14*14*32model.add(Conv2D(32,kernel_size=3,strides=2,input_shape=img_shape,padding='same'))# Leaky ReLU 的激活函数model.add(LeakyReLU(alpha=0.01))# 卷积层,张量从14*14*32到7*7*64model.add(Conv2D(64,kernel_size=3,strides=2,input_shape=img_shape,padding='same'))# 批归一化model.add(BatchNormalization())# Leaky ReLU 激活函数model.add(LeakyReLU(alpha=0.01))# 卷积层,张量从7*7*64变成3*3128model.add(Conv2D(128,kernel_size=3,strides=2,input_shape=img_shape,padding='same'))# 批归一化model.add(BatchNormalization())# Leaky ReLU 激活函数model.add(LeakyReLU(alpha=0.01))# Droupout正则model.add(Dropout(0.5))# 将张量展平model.add(Flatten())# 于unm_classes神经元完全连接的层model.add(Dense(num_classes))return model

2.5.2有监督的多分类鉴定器 

def build_discriminator_supervised(discriminator_net):model = Sequential()model.add(discriminator_net)# Softmax 激活函数, 输出真实类别的预测概率分布model.add(Activation('softmax'))return model

2.5.3无监督的二分类鉴定器

def build_discriminator_unsupervised(discriminator_net):model = Sequential()model.add(discriminator_net)def predict(x):# 将真实类别的分布转换为二元真——假概率prediction = 1.0 - (1.0 /(K.sum(K.exp(x), axis=-1, keepdims=True) + 1.0))return predictionmodel.add(Lambda(predict))return model

2.6构建模型

def build_gan(generator, discriminator):model = Sequential()# 结合生成器和鉴定器模型model.add(generator)model.add(discriminator)return model

鉴定器 


# 核心鉴定器网络:在有监督和无监督中共享
discriminator_net = build_discriminator_net(img_shape)#构建有监督鉴定器
discriminator_supervised = build_discriminator_supervised(discriminator_net)
discriminator_supervised.compile(loss='categorical_crossentropy',metrics=['accuracy'],optimizer=Adam.Adam())# 构建无监督鉴定器
discriminator_unsupervised = build_discriminator_unsupervised(discriminator_net)
discriminator_unsupervised.compile(loss='binary_crossentropy',optimizer=Adam.Adam())
# 构建生成器
generator = build_generator(z_dim)# 生成器训练时保持鉴定器参数不变
discriminator_unsupervised.trainable = False# 构建固定参数的鉴别器,以训练生成器
# 鉴别器使用无监督版本
gan = build_gan(generator, discriminator_unsupervised)
gan.compile(loss='binary_crossentropy', optimizer=Adam.Adam())

2.7训练

supervised_losses = []
iteration_checkpoints = []def train(iterations, batch_size, sample_interval):# 真实图像的标签:全为1real = np.ones((batch_size, 1))# 伪图像的标签:全为0fake = np.zeros((batch_size, 1))for iteration in range(iterations):# -------------------------#  训练鉴定器# -------------------------# 获得标签样本imgs, labels = dataset.batch_labeled(batch_size)# 独热编码标签labels = to_categorical(labels, num_classes=num_classes)# 获得无标签样本imgs_unlabeled = dataset.batch_unlabeled(batch_size)# 生成一批伪图像z = np.random.normal(0, 1, (batch_size, z_dim))gen_imgs = generator.predict(z)# 训练有标签的真实样本d_loss_supervised, accuracy = discriminator_supervised.train_on_batch(imgs, labels)#训练无标签的真实样本d_loss_real = discriminator_unsupervised.train_on_batch(imgs_unlabeled, real)# 训练伪样本d_loss_fake = discriminator_unsupervised.train_on_batch(gen_imgs, fake)d_loss_unsupervised = 0.5 * np.add(d_loss_real, d_loss_fake)# ---------------------#  训练生成器# ---------------------# 生成一批次假图像z = np.random.normal(0, 1, (batch_size, z_dim))gen_imgs = generator.predict(z)# 训练生成器g_loss = gan.train_on_batch(z, np.ones((batch_size, 1)))if (iteration + 1) % sample_interval == 0:# 保存鉴别器的有监督分类损失,以便绘制损失曲线supervised_losses.append(d_loss_supervised)iteration_checkpoints.append(iteration + 1)# 输出训练过程print("%d [D loss supervised: %.4f, acc.: %.2f%%] [D loss unsupervised: %.4f] [G loss: %f]"% (iteration + 1, d_loss_supervised, 100 * accuracy,d_loss_unsupervised, g_loss))

 2.8. 训练模型

# 设置超参数
iterations = 8000
batch_size = 32
sample_interval = 800train(iterations, batch_size, sample_interval)

2.9画出鉴定器的监督损失

losses = np.array(supervised_losses)plt.figure(figsize=(15, 5))
plt.plot(iteration_checkpoints, losses, label="Discriminator loss")plt.xticks(iteration_checkpoints, rotation=90)plt.title("Discriminator – Supervised Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.legend()

2.10.模型训练和测试准确率

x, y = dataset.training_set()
y = to_categorical(y, num_classes=num_classes)# 在测试集上计算分类准确率
_, accuracy = discriminator_supervised.evaluate(x, y)
print("Training Accuracy: %.2f%%" % (100 * accuracy))

此处跑出来是100%(SGAN在训练只有100个有标签的样本用于有监督训练-可能记住了训练数据集)

x, y = dataset.test_set()
y = to_categorical(y, num_classes=num_classes)# 计算在测试集上的精度
_, accuracy = discriminator_supervised.evaluate(x, y)
print("Test Accuracy: %.2f%%" % (100 * accuracy))

达到了89%

1.4半监督生成对抗网络(SGAN)相关推荐

  1. 基于生成对抗网络的医学数据域适应研究

    点击上方蓝字关注我们 基于生成对抗网络的医学数据域适应研究 于胡飞, 温景熙, 辛江, 唐艳 中南大学计算机学院,湖南 长沙 410083   摘要:在医疗影像辅助诊断研究中,研究者通常使用不同医院( ...

  2. 实战生成对抗网络[1]:简介

    引言 2016年3月,AlphaGO横空出世,击败人类顶尖职业棋手,引爆了人工智能热潮.之后AlphaGO Master和AlphaGO Zero更是无情的碾压人类棋手,人们终于认识到,人类迎来了可怕 ...

  3. DeepDGA:基于生成对抗网络的DGA生成与检测

    基于DeepDGA: Adversarially-Tuned Domain Generation and Detection 复现(Python):GitHub地址 研究背景 由DGA引发的一系列- ...

  4. 生成对抗网络gan原理_生成对抗网络(GAN)的半监督学习

    前言 如果您曾经听说过或研究过深度学习,那么您可能就知道MNIST, SVHN, ImageNet, PascalVoc或者其他数据集.这些数据集都有一个共同点: 它们由成千上万个有标签的数据组成. ...

  5. 飞浆论文复现:用于图像到图像翻译的具有自适应层实例化的非监督的生成对抗网络

    Unsupervised generative attentional networks with adaptive layer-instance normalization for image-to ...

  6. 深度学习中的生成对抗网络GAN

    转载:一文看尽深度学习中的生成对抗网络 | CVHub带你看一看GANs架构发展的8年 (qq.com) 导读 生成对抗网络 (Generative Adversarial Networks, GAN ...

  7. ICCV2017 | 一文详解GAN之父Ian Goodfellow 演讲《生成对抗网络的原理与应用》(附完整PPT)

    当地时间 10月 22 日到10月29日,两年一度的计算机视觉国际顶级会议 International Conference on Computer Vision(ICCV 2017)在意大利威尼斯开 ...

  8. 生成对抗网络(GAN)的理论与应用完整入门介绍

    本文包含以下内容: 1.为什么生成模型值得研究 2.生成模型的分类 3.GAN相对于其他生成模型相比有什么优势 4.GAN基本模型 5.改进的GANs 6.GAN有哪些应用 7.GAN的前沿研究 一. ...

  9. 生成对抗网络(GAN)应用于图像分类

    近年来,深度学习技术被广泛应用于各类数据处理任务中,比如图像.语音和文本.而生成对抗网络(GAN)和强化学习(RL)已经成为了深度学习框架中的两颗"明珠".强化学习主要用于决策问题 ...

最新文章

  1. 后勤问题怎么办。。。(求刊登)
  2. 21个高质量的Swift开源iOS App
  3. 扫掠两条引导线_如何巧用引导线,把摄影水平再提升一个档次?
  4. chrome ui源码剖析-Accelerator(快捷键)
  5. maven中packaging的三个属性值pom、jar、war
  6. [转载] java(三)对象的序列化与static、final关键字
  7. JAVA字节流(读写文件)
  8. java模拟浏览器htmlunit,Java版本的浏览器HtmlUnit入门示例
  9. python网络爬虫系列教程——python中pyquery库应用全解
  10. php系统变量有哪些,php预定义系统变量
  11. VC知识库的5周年精华珍藏版光盘
  12. DolphinScheduler大数据调度系统
  13. coreos mysql_Fedora CoreOS 介绍
  14. python 数学公式_pythonp_word03公式编辑器,空心方阵公式,高中数学必修一公式,销售利润率公式,高一物理必修1公式...
  15. max如何渲染多张图片
  16. 怎样查看谁发的qq坦白说
  17. 发布Flv合并器的.net版
  18. 基于anchor-free的目标检测算法CenterNet研究
  19. 图像修复序列——BSCB模型
  20. 手机QQ iOS版默认不显示iPhone在线

热门文章

  1. CNCF案例研究:Uber
  2. 图像相似性匹配 快速算法
  3. http状态码说明(全)
  4. C# webbrowser专题
  5. UML系列文章(31)体系结构建模---部署图
  6. 世界卫生组织与腾讯加深合作 新冠肺炎AI自查助手全球开源
  7. js 获取指定范围随机数
  8. 如何安装JavaFX
  9. Python 5行代码生成个性二维码,要不要试一下?
  10. PIM-DM的扩散-剪枝