MXNet网络模型(四)GAN神经网络
MXNet网络模型(四)对抗神经网络
- 概述
- 原理
- MXNet代码
概述
GAN 神经网络(2014)
- 《Generative Adversarial Network》
- https://arxiv.org/pdf/1406.2661.pdf
GAN网络也是一直在进步
原理
GAN训练可以分为5步
- 样本图片x输入分类器
- 随机种子z经由生成器生成模拟图片x*
- 模拟图片x*输入分类器
- 反馈给分类器
- 反馈给生成器
MXNet代码
导入库
import time
import gzip
import numpy as np
import matplotlib.pyplot as plt![请添加图片描述](https://img-blog.csdnimg.cn/153a1da1e6d8438681cddca1fae72021.bmp)import mxnet as mx
环境和批大小
batch_size = 10
device = mx.cpu()
读取训练数据
def load_dataset():transform = mx.gluon.data.vision.transforms.ToTensor()train_img = [ transform(img).asnumpy() for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=True )]train_lbl = [ np.array(lbl) for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=True )]eval_img = [ transform(img).asnumpy() for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=False)]eval_lbl = [ np.array(lbl) for img, lbl in mx.gluon.data.vision.datasets.MNIST(root='mnist', train=False)]return train_img, train_lbl, eval_img, eval_lbltrain_img, train_lbl, eval_img, eval_lbl = load_dataset()train_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(train_img, train_lbl),batch_size=batch_size,shuffle=True
)
eval_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(eval_img, eval_lbl),batch_size=batch_size,shuffle=False
)
预览训练数据
idxs = (25, 47, 74, 88, 92)
for i in range(5):plt.subplot(1, 5, i + 1)idx = idxs[i]plt.xticks([])plt.yticks([])img = train_img[idx][0].astype( np.float32 )plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()
分类器
class Discriminator():def __init__(self):self.loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()self.metric = mx.metric.Loss()self.net = mx.gluon.nn.HybridSequential()# 第一层self.net.add(mx.gluon.nn.Flatten(),mx.gluon.nn.Dense(units=200),mx.gluon.nn.LeakyReLU(alpha=0.02),mx.gluon.nn.LayerNorm())# 第二层self.net.add(mx.gluon.nn.Dense(units=1),)self.net.initialize( init=mx.init.Xavier(rnd_type='gaussian'), ctx=device )self.trainer = mx.gluon.Trainer(params=self.net.collect_params(),optimizer=mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True))self.net.summary(mx.ndarray.zeros(shape=(50, 1, 28, 28), dtype=np.float32, ctx=device))discriminator = Discriminator()
--------------------------------------------------------------------------------Layer (type) Output Shape Param #
================================================================================Input (50, 1, 28, 28) 0Flatten-1 (50, 784) 0Dense-2 (50, 200) 157000LeakyReLU-3 (50, 200) 0LayerNorm-4 (50, 200) 400Dense-5 (50, 1) 201
================================================================================
Parameters in forward computation graph, duplicate includedTotal params: 157601Trainable params: 157601Non-trainable params: 0
Shared params in forward computation graph: 0
Unique parameters in model: 157601
--------------------------------------------------------------------------------
生成器
class Generator():def __init__(self):#self.loss_fn = mx.gluon.loss.SigmoidBinaryCrossEntropyLoss()self.metric = mx.metric.Loss()self.net = mx.gluon.nn.HybridSequential()# 第一层self.net.add(mx.gluon.nn.Dense(units=200),mx.gluon.nn.LeakyReLU(alpha=0.02),mx.gluon.nn.LayerNorm(),)# 第二层self.net.add(mx.gluon.nn.Dense(units=784),mx.gluon.nn.Activation(activation='sigmoid'),mx.gluon.nn.HybridLambda(lambda F, x: F.reshape(x, shape=(0, -1, 28, 28))))self.net.initialize( init=mx.init.Xavier(rnd_type='gaussian'), ctx=device )self.trainer = mx.gluon.Trainer(params=self.net.collect_params(),optimizer=mx.optimizer.Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-08, lazy_update=True))self.net.summary(mx.ndarray.zeros(shape=(50, 100), dtype=np.float32, ctx=device))generator = Generator()
生成种子
def make_seed(batach, num, device=None):if device is None:device = mx.cpu()return mx.ndarray.normal(loc=0, scale=1, shape=(batach, num), ctx=device)
训练
for epoch in range(120):discriminator.metric.reset(); discriminator.metric.reset(); tic = time.time()for datas, _ in train_data:# 批大小actually_batch_size = datas.shape[0]# CPU 移到 GPUdatas = mx.gluon.utils.split_and_load( datas, [device] )# 生成种子seeds = [make_seed(data.shape[0], 100, device) for data in datas]# 训练鉴别器for data, seed in zip(datas, seeds):length = data.shape[0]lbl_real = mx.ndarray.ones(shape=(length,1), ctx=device)lbl_fake = mx.ndarray.zeros(shape=(length,1), ctx=device)with mx.autograd.record():a = discriminator.loss_fn(discriminator.net(data), lbl_real)b = discriminator.loss_fn(discriminator.net(generator.net(seed).detach()), lbl_fake)d_loss = a + bd_loss.backward()discriminator.metric.update(_, preds=d_loss)discriminator.trainer.step(actually_batch_size)# 训练生成器for seed in seeds:length = seed.shape[0]lbl_real = mx.ndarray.ones(shape=(length,1), ctx=device)with mx.autograd.record():img = generator.net(seed)g_loss = discriminator.loss_fn(discriminator.net(img), lbl_real)g_loss.backward()generator.metric.update(_, preds=g_loss)generator.trainer.step(actually_batch_size)print("Epoch {:>2d}: cost:{:.1f}s d_loss:{:.3f} g_loss:{:.3f}".format(epoch, time.time()-tic, discriminator.metric.get()[1], generator.metric.get()[1]))for i in range(10):plt.subplot(1, 10, i + 1)plt.xticks([]); plt.yticks([])seed = make_seed(length, 100, device)img = generator.net(seed).asnumpy()plt.imshow(img[0][0], interpolation='none', cmap='Blues')plt.show()
Epoch 0: cost:43.4s d_loss:0.372 g_loss:4.094
Epoch 1: cost:40.2s d_loss:0.602 g_loss:3.242
Epoch 2: cost:40.5s d_loss:0.661 g_loss:2.887
Epoch 3: cost:40.2s d_loss:0.701 g_loss:2.688
Epoch 4: cost:40.3s d_loss:0.707 g_loss:2.563
Epoch 5: cost:39.2s d_loss:0.708 g_loss:2.480
Epoch 6: cost:52.1s d_loss:0.703 g_loss:2.422
Epoch 7: cost:44.2s d_loss:0.691 g_loss:2.384
Epoch 8: cost:53.6s d_loss:0.675 g_loss:2.359
Epoch 9: cost:38.2s d_loss:0.659 g_loss:2.344
Epoch 10: cost:38.2s d_loss:0.645 g_loss:2.333
Epoch 11: cost:38.2s d_loss:0.640 g_loss:2.325
Epoch 12: cost:38.1s d_loss:0.625 g_loss:2.319
Epoch 13: cost:38.3s d_loss:0.619 g_loss:2.313
Epoch 14: cost:38.3s d_loss:0.610 g_loss:2.308
Epoch 15: cost:38.2s d_loss:0.606 g_loss:2.304
Epoch 16: cost:38.2s d_loss:0.598 g_loss:2.301
Epoch 17: cost:38.3s d_loss:0.590 g_loss:2.298
Epoch 18: cost:38.2s d_loss:0.589 g_loss:2.296
Epoch 19: cost:38.5s d_loss:0.585 g_loss:2.294
Epoch 20: cost:38.3s d_loss:0.585 g_loss:2.292
训练20轮输出如图
损失值评价
最理想的分离器损失值为
entropy = -1 x ln(0.5) = 0.693
训练早期(前6轮),分类器损失值迅速从0.372上升到0.708。同时生成器损失值也在迅速下降。此时生成器轻微领先。
随后通过继续学习(7-26轮),分类器和生成器的损失值都在下降,双方不断进步。
之后开始(27轮),分类器损失继续下降,生成器损失上升,说明分类器已经抛开生成器。
MXNet网络模型(四)GAN神经网络相关推荐
- 人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练
人工智能深度学习框架MXNet实战:深度神经网络的交通标志识别训练 MXNet 是一个轻量级.可移植.灵活的分布式深度学习框架,2017 年 1 月 23 日,该项目进入 Apache 基金会,成为 ...
- 人工神经网络模型定义,人工神经网络基本框架
人工神经网络评价法 人工神经元是人工神经网络的基本处理单元,而人工智能的一个重要组成部分又是人工神经网络.人工神经网络是模拟生物神经元系统的数学模型,接受信息主要是通过神经元来进行的. 首先,人工神经 ...
- 「紫禁之巅」四大图神经网络架构
近年来,人们对深度学习方法在图数据上的扩展越来越感兴趣.在深度学习的成功推动下,研究人员借鉴了卷积网络.循环网络和深度自动编码器的思想,定义和设计了用于处理图数据的神经网络结构.图神经网络的火热使得各 ...
- 深度学习(3)之经典神经网络模型整理:神经网络、CNN、RNN、LSTM
本文章总结以下经典的神经网络模型整理,大体讲下模型结构及原理- 如果想深入了解模型架构及pytorch实现,可参考我的Pytorch总结专栏 -> 划重点!!!Pytorch总结文章之目录归纳 ...
- gan神经网络_神经联觉:当艺术遇见GAN
gan神经网络 Neural Synesthesia is an AI art project that aims to create new and unique audiovisual exper ...
- 写给人类的机器学习 四、神经网络和深度学习
四.神经网络和深度学习 原文:Machine Learning for Humans, Part 4: Neural Networks & Deep Learning 作者:Vishal Ma ...
- 【机器学习算法面试题】四.深度神经网络中激活函数有哪些?
欢迎订阅本专栏:<机器学习算法面试题> 订阅地址:https://blog.csdn.net/m0_38068876/category_11810806.html [机器学习算法面试题]一 ...
- matlab光谱实验,实验四Matlab神经网络及应用于近红外光谱的汽油辛烷值预测
. 实验四Matlab神经网络以及应用于汽油辛烷值预测 一.实验目的 1. 掌握MATLAB创建BP神经网络并应用于拟合非线性函数 2. 掌握MATLAB创建REF神经网络并应用于拟合非线性函数 3. ...
- 人工神经网络模型有哪些,神经网络分类四种模型
有哪些深度神经网络模型 目前经常使用的深度神经网络模型主要有卷积神经网络(CNN).递归神经网络(RNN).深信度网络(DBN).深度自动编码器(AutoEncoder)和生成对抗网络(GAN)等. ...
- 从神经元到CNN、RNN、GAN…神经网络看本文绝对够了
在深度学习十分火热的今天,不时会涌现出各种新型的人工神经网络,想要实时了解这些新型神经网络的架构还真是不容易.光是知道各式各样的神经网络模型缩写(如:DCIGN.BiLSTM.DCGAN--还有哪些? ...
最新文章
- LeetCode Python题解(二)----排序
- RecyclerView + SnapHelper实现炫酷ViewPager效果
- 在微型计算机中pci指的是一种,2010新疆维吾尔自治区计算机等级考试二级理论考试试题及答案...
- 红旗桌面版本最新运用要领和结果解答100例-3
- php100并发cpu告警,多线程并发导致CPU100%的一种原因和解决办法
- 【Linux网络编程学习】使用socket实现简单服务器——多进程多线程版本
- Java中装箱与拆箱
- 蚂蚁森林:国庆节前组织网友去阿拉善等三地参与秋季验收
- RDS SQL Server死锁(Deadlock)系列之四利用Service Broker事件通知捕获死锁
- KMP算法的C++实现
- python画图颜色代码rgb_python – matplotlib 3D散点图,其标记颜色对应于RGB值
- 百会与Zoho达成战略合作,向中国用户推出在线办公套件!
- 黑客帝国中比较酷炫的代码雨的实现
- Mac 上使用 zmodem 发送和接收堡垒机文件
- JAVA SpringBoot接科大讯飞TTS语音合成保姆式教程附源代码
- Linux下格式化sd卡和重新分区
- 游戏服务器稳定ping值,网友玩游戏时Ping值超过了2亿!
- linux 系统速度慢,Linux运维人员你知道Linux系统运行速度太慢的原因吗?
- WORD文档中插入图片(1)
- 剪头发啦,实在不宜出门