MXNet网络模型(四)对抗神经网络

  • 概述
  • 原理
  • MXNet代码

概述

GAN 神经网络(2014)

  • 《Generative Adversarial Network》
  • https://arxiv.org/pdf/1406.2661.pdf



GAN网络也是一直在进步

原理

GAN训练可以分为5步

  1. 样本图片x输入分类器
  2. 随机种子z经由生成器生成模拟图片x*
  3. 模拟图片x*输入分类器
  4. 反馈给分类器
  5. 反馈给生成器

MXNet代码

导入库

import time
import gzip
import numpy as np
import matplotlib.pyplot as plt![请添加图片描述](https://img-blog.csdnimg.cn/153a1da1e6d8438681cddca1fae72021.bmp)import mxnet as mx

环境和批大小

batch_size = 10
device = mx.cpu()

读取训练数据

def load_dataset():transform = mx.gluon.data.vision.transforms.ToTensor()train_img = [ transform(img).asnumpy() for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=True )]train_lbl = [ np.array(lbl)            for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=True )]eval_img  = [ transform(img).asnumpy() for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=False)]eval_lbl  = [ np.array(lbl)            for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=False)]return train_img, train_lbl, eval_img, eval_lbltrain_img, train_lbl, eval_img, eval_lbl = load_dataset()train_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(train_img, train_lbl),batch_size=batch_size,shuffle=True
)
eval_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(eval_img, eval_lbl),batch_size=batch_size,shuffle=False
)

预览训练数据

idxs = (25, 47, 74, 88, 92)
for i in range(5):plt.subplot(1, 5, i + 1)idx = idxs[i]plt.xticks([])plt.yticks([])img = train_img[idx][0].astype( np.float32 )plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()


分类器

class Discriminator():def __init__(self):self.loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()self.metric = mx.metric.Loss()self.net = mx.gluon.nn.HybridSequential()# 第一层self.net.add(mx.gluon.nn.Flatten(),mx.gluon.nn.Dense(units=200),mx.gluon.nn.LeakyReLU(alpha=0.02),mx.gluon.nn.LayerNorm())# 第二层self.net.add(mx.gluon.nn.Dense(units=1),)self.net.initialize( init=mx.init.Xavier(rnd_type='gaussian'), ctx=device )self.trainer = mx.gluon.Trainer(params=self.net.collect_params(),optimizer=mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True))self.net.summary(mx.ndarray.zeros(shape=(50, 1, 28, 28), dtype=np.float32, ctx=device))discriminator = Discriminator()
--------------------------------------------------------------------------------Layer (type)                                Output Shape         Param #
================================================================================Input                             (50, 1, 28, 28)               0Flatten-1                                   (50, 784)               0Dense-2                                   (50, 200)          157000LeakyReLU-3                                   (50, 200)               0LayerNorm-4                                   (50, 200)             400Dense-5                                     (50, 1)             201
================================================================================
Parameters in forward computation graph, duplicate includedTotal params: 157601Trainable params: 157601Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 157601
--------------------------------------------------------------------------------

生成器

class Generator():def __init__(self):#self.loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()self.metric = mx.metric.Loss()self.net = mx.gluon.nn.HybridSequential()# 第一层self.net.add(mx.gluon.nn.Dense(units=200),mx.gluon.nn.LeakyReLU(alpha=0.02),mx.gluon.nn.LayerNorm(),)# 第二层self.net.add(mx.gluon.nn.Dense(units=784),mx.gluon.nn.Activation(activation='sigmoid'),mx.gluon.nn.HybridLambda(lambda F, x: F.reshape(x, shape=(0, -1, 28, 28))))self.net.initialize( init=mx.init.Xavier(rnd_type='gaussian'), ctx=device )self.trainer = mx.gluon.Trainer(params=self.net.collect_params(),optimizer=mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True))self.net.summary(mx.ndarray.zeros(shape=(50, 100), dtype=np.float32, ctx=device))generator = Generator()

生成种子

def make_seed(batach, num, device=None):if device is None:device = mx.cpu()return mx.ndarray.normal(loc=0, scale=1, shape=(batach, num), ctx=device)

训练

for epoch in range(120):discriminator.metric.reset(); discriminator.metric.reset(); tic = time.time()for datas, _ in train_data:# 批大小actually_batch_size = datas.shape[0]# CPU 移到 GPUdatas  = mx.gluon.utils.split_and_load( datas,  [device] )# 生成种子seeds = [make_seed(data.shape[0], 100, device) for data in datas]# 训练鉴别器for data, seed in zip(datas, seeds):length = data.shape[0]lbl_real = mx.ndarray.ones(shape=(length,1), ctx=device)lbl_fake = mx.ndarray.zeros(shape=(length,1), ctx=device)with mx.autograd.record():a = discriminator.loss_fn(discriminator.net(data), lbl_real)b = discriminator.loss_fn(discriminator.net(generator.net(seed).detach()), lbl_fake)d_loss = a + bd_loss.backward()discriminator.metric.update(_, preds=d_loss)discriminator.trainer.step(actually_batch_size)# 训练生成器for seed in seeds:length = seed.shape[0]lbl_real = mx.ndarray.ones(shape=(length,1), ctx=device)with mx.autograd.record():img = generator.net(seed)g_loss = discriminator.loss_fn(discriminator.net(img), lbl_real)g_loss.backward()generator.metric.update(_, preds=g_loss)generator.trainer.step(actually_batch_size)print("Epoch {:>2d}: cost:{:.1f}s d_loss:{:.3f} g_loss:{:.3f}".format(epoch, time.time()-tic, discriminator.metric.get()[1], generator.metric.get()[1]))for i in range(10):plt.subplot(1, 10, i + 1)plt.xticks([]); plt.yticks([])seed = make_seed(length, 100, device)img = generator.net(seed).asnumpy()plt.imshow(img[0][0], interpolation='none', cmap='Blues')plt.show()

Epoch 0: cost:43.4s d_loss:0.372 g_loss:4.094
Epoch 1: cost:40.2s d_loss:0.602 g_loss:3.242
Epoch 2: cost:40.5s d_loss:0.661 g_loss:2.887
Epoch 3: cost:40.2s d_loss:0.701 g_loss:2.688
Epoch 4: cost:40.3s d_loss:0.707 g_loss:2.563
Epoch 5: cost:39.2s d_loss:0.708 g_loss:2.480
Epoch 6: cost:52.1s d_loss:0.703 g_loss:2.422
Epoch 7: cost:44.2s d_loss:0.691 g_loss:2.384
Epoch 8: cost:53.6s d_loss:0.675 g_loss:2.359
Epoch 9: cost:38.2s d_loss:0.659 g_loss:2.344
Epoch 10: cost:38.2s d_loss:0.645 g_loss:2.333
Epoch 11: cost:38.2s d_loss:0.640 g_loss:2.325
Epoch 12: cost:38.1s d_loss:0.625 g_loss:2.319
Epoch 13: cost:38.3s d_loss:0.619 g_loss:2.313
Epoch 14: cost:38.3s d_loss:0.610 g_loss:2.308
Epoch 15: cost:38.2s d_loss:0.606 g_loss:2.304
Epoch 16: cost:38.2s d_loss:0.598 g_loss:2.301
Epoch 17: cost:38.3s d_loss:0.590 g_loss:2.298
Epoch 18: cost:38.2s d_loss:0.589 g_loss:2.296
Epoch 19: cost:38.5s d_loss:0.585 g_loss:2.294
Epoch 20: cost:38.3s d_loss:0.585 g_loss:2.292

训练20轮输出如图

损失值评价

最理想的分离器损失值为
entropy = -1 x ln(0.5) = 0.693

训练早期(前6轮),分类器损失值迅速从0.372上升到0.708。同时生成器损失值也在迅速下降。此时生成器轻微领先。

随后通过继续学习(7-26轮),分类器和生成器的损失值都在下降,双方不断进步。

之后开始(27轮),分类器损失继续下降,生成器损失上升,说明分类器已经抛开生成器。

MXNet网络模型(四)GAN神经网络相关推荐

  1. 人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练

    人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练 MXNet 是一个轻量级.可移植.灵活的分布式深度学习框架,2017 年 1 月 23 日,该项目进入 Apache 基金会,成为 ...

  2. 人工神经网络模型定义,人工神经网络基本框架

    人工神经网络评价法 人工神经元是人工神经网络的基本处理单元,而人工智能的一个重要组成部分又是人工神经网络.人工神经网络是模拟生物神经元系统的数学模型,接受信息主要是通过神经元来进行的. 首先,人工神经 ...

  3. 「紫禁之巅」四大图神经网络架构

    近年来,人们对深度学习方法在图数据上的扩展越来越感兴趣.在深度学习的成功推动下,研究人员借鉴了卷积网络.循环网络和深度自动编码器的思想,定义和设计了用于处理图数据的神经网络结构.图神经网络的火热使得各 ...

  4. 深度学习(3)之经典神经网络模型整理:神经网络、CNN、RNN、LSTM

    本文章总结以下经典的神经网络模型整理,大体讲下模型结构及原理- 如果想深入了解模型架构及pytorch实现,可参考我的Pytorch总结专栏 -> 划重点!!!Pytorch总结文章之目录归纳 ...

  5. gan神经网络_神经联觉:当艺术遇见GAN

    gan神经网络 Neural Synesthesia is an AI art project that aims to create new and unique audiovisual exper ...

  6. 写给人类的机器学习 四、神经网络和深度学习

    四.神经网络和深度学习 原文:Machine Learning for Humans, Part 4: Neural Networks & Deep Learning 作者:Vishal Ma ...

  7. 【机器学习算法面试题】四.深度神经网络中激活函数有哪些?

    欢迎订阅本专栏:<机器学习算法面试题> 订阅地址:https://blog.csdn.net/m0_38068876/category_11810806.html [机器学习算法面试题]一 ...

  8. matlab光谱实验,实验四Matlab神经网络及应用于近红外光谱的汽油辛烷值预测

    . 实验四Matlab神经网络以及应用于汽油辛烷值预测 一.实验目的 1. 掌握MATLAB创建BP神经网络并应用于拟合非线性函数 2. 掌握MATLAB创建REF神经网络并应用于拟合非线性函数 3. ...

  9. 人工神经网络模型有哪些,神经网络分类四种模型

    有哪些深度神经网络模型 目前经常使用的深度神经网络模型主要有卷积神经网络(CNN).递归神经网络(RNN).深信度网络(DBN).深度自动编码器(AutoEncoder)和生成对抗网络(GAN)等. ...

  10. 从神经元到CNN、RNN、GAN…神经网络看本文绝对够了

    在深度学习十分火热的今天,不时会涌现出各种新型的人工神经网络,想要实时了解这些新型神经网络的架构还真是不容易.光是知道各式各样的神经网络模型缩写(如:DCIGN.BiLSTM.DCGAN--还有哪些? ...

最新文章

  1. LeetCode Python题解(二)----排序
  2. RecyclerView + SnapHelper实现炫酷ViewPager效果
  3. 在微型计算机中pci指的是一种,2010新疆维吾尔自治区计算机等级考试二级理论考试试题及答案...
  4. 红旗桌面版本最新运用要领和结果解答100例-3
  5. php100并发cpu告警,多线程并发导致CPU100%的一种原因和解决办法
  6. 【Linux网络编程学习】使用socket实现简单服务器——多进程多线程版本
  7. Java中装箱与拆箱
  8. 蚂蚁森林:国庆节前组织网友去阿拉善等三地参与秋季验收
  9. RDS SQL Server死锁(Deadlock)系列之四利用Service Broker事件通知捕获死锁
  10. KMP算法的C++实现
  11. python画图颜色代码rgb_python – matplotlib 3D散点图,其标记颜色对应于RGB值
  12. 百会与Zoho达成战略合作,向中国用户推出在线办公套件!
  13. 黑客帝国中比较酷炫的代码雨的实现
  14. Mac 上使用 zmodem 发送和接收堡垒机文件
  15. JAVA SpringBoot接科大讯飞TTS语音合成保姆式教程附源代码
  16. Linux下格式化sd卡和重新分区
  17. 游戏服务器稳定ping值,网友玩游戏时Ping值超过了2亿!
  18. linux 系统速度慢,Linux运维人员你知道Linux系统运行速度太慢的原因吗?
  19. WORD文档中插入图片(1)
  20. 剪头发啦,实在不宜出门

热门文章

  1. 趣图:说一说你不知道的世界
  2. ARC120F-Wine Thief(非F2)——序列化环
  3. 基于51单片机的恒温加热系统--main.c文件
  4. 2021年高光谱图像文献追踪_ISPRS_V.180_10
  5. submit事件监听问题
  6. 使用Kotlin语言两年后,我有话要说
  7. 以Skyline问题来看hard问题在面试的时候如何解决?
  8. 手把手教你Windows操作系统添加Virtio驱动
  9. J2EE配置文件加密
  10. php实现阳历阴历互转的方法