算法原理

具体算法原理,在这里不再阐述,可以参考:

https://blog.csdn.net/stdcoutzyx/article/details/53872121
https://ask.julyedu.com/question/7681

基于keras的原代码

from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.optimizers import SGD
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import mathdef generator_model():model = Sequential()model.add(Dense(input_dim=100, output_dim=1024))model.add(Activation('tanh'))model.add(Dense(128*7*7))model.add(BatchNormalization())model.add(Activation('tanh'))model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))model.add(UpSampling2D(size=(2, 2)))model.add(Conv2D(64, (5, 5), padding='same'))model.add(Activation('tanh'))model.add(UpSampling2D(size=(2, 2)))model.add(Conv2D(1, (5, 5), padding='same'))model.add(Activation('tanh'))return modeldef discriminator_model():model = Sequential()model.add(Conv2D(64, (5, 5),padding='same',input_shape=(28, 28, 1)))model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Conv2D(128, (5, 5)))model.add(Activation('tanh'))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Flatten())model.add(Dense(1024))model.add(Activation('tanh'))model.add(Dense(1))model.add(Activation('sigmoid'))return modeldef generator_containing_discriminator(g, d):model = Sequential()model.add(g)d.trainable = Falsemodel.add(d)return modeldef combine_images(generated_images):num = generated_images.shape[0]width = int(math.sqrt(num))height = int(math.ceil(float(num)/width))shape = generated_images.shape[1:3]image = np.zeros((height*shape[0], width*shape[1]),dtype=generated_images.dtype)for index, img in enumerate(generated_images):i = int(index/width)j = index % widthimage[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = \img[:, :, 0]return imagedef train(BATCH_SIZE):(X_train, y_train), (X_test, y_test) = mnist.load_data()X_train = (X_train.astype(np.float32) - 127.5)/127.5X_train = X_train[:, :, :, None]X_test = X_test[:, :, :, None]# X_train = X_train.reshape((X_train.shape, 1) + X_train.shape[1:])d = discriminator_model()g = generator_model()d_on_g = generator_containing_discriminator(g, d)#使用判断器来监督生成器的训练d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)g.compile(loss='binary_crossentropy', optimizer="SGD")#只有生成器d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)#编译生成器d.trainable = Trued.compile(loss='binary_crossentropy', optimizer=d_optim)#只有判断器for epoch in range(100):print("Epoch is", epoch)print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))for index in range(int(X_train.shape[0]/BATCH_SIZE)):#每轮有多少个step,数据/batch_sizenoise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))#生成100维度的噪声image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]generated_images = g.predict(noise, verbose=0)#生成器生成的图像if index % 20 == 0:image = combine_images(generated_images)#这个是把20张图像拼接起来,组成一个大图像image = image*127.5+127.5#从-1~1还原成0~255Image.fromarray(image.astype(np.uint8)).save(str(epoch)+"_"+str(index)+".png")#保存图像X = np.concatenate((image_batch, generated_images))y = [1] * BATCH_SIZE + [0] * BATCH_SIZEd_loss = d.train_on_batch(X, y)#训练判断器print("batch %d d_loss : %f" % (index, d_loss))noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))d.trainable = False#把判断器冻结g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE)#固定判断器参数,训练生成器d.trainable = Trueprint("batch %d g_loss : %f" % (index, g_loss))if index % 10 == 9:#每更新十次,保存模型g.save_weights('generator', True)d.save_weights('discriminator', True)#加载训练好的模型进行预测
def generate(BATCH_SIZE, nice=False):g = generator_model()g.compile(loss='binary_crossentropy', optimizer="SGD")g.load_weights('generator')#生成比较好的模型if nice:d = discriminator_model()d.compile(loss='binary_crossentropy', optimizer="SGD")d.load_weights('discriminator')noise = np.random.uniform(-1, 1, (BATCH_SIZE*20, 100))generated_images = g.predict(noise, verbose=1)d_pret = d.predict(generated_images, verbose=1)index = np.arange(0, BATCH_SIZE*20)index.resize((BATCH_SIZE*20, 1))pre_with_index = list(np.append(d_pret, index, axis=1))pre_with_index.sort(key=lambda x: x[0], reverse=True)nice_images = np.zeros((BATCH_SIZE,) + generated_images.shape[1:3], dtype=np.float32)nice_images = nice_images[:, :, :, None]for i in range(BATCH_SIZE):idx = int(pre_with_index[i][1])nice_images[i, :, :, 0] = generated_images[idx, :, :, 0]image = combine_images(nice_images)#生成一般的模型else:noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))generated_images = g.predict(noise, verbose=1)image = combine_images(generated_images)image = image*127.5+127.5Image.fromarray(image.astype(np.uint8)).save("generated_image.png")def get_args():parser = argparse.ArgumentParser()parser.add_argument("--mode", type=str,default='train')parser.add_argument("--batch_size", type=int, default=128)parser.add_argument("--nice", dest="nice", action="store_true")parser.set_defaults(nice=False)args = parser.parse_args()return argsif __name__ == "__main__":args = get_args()if args.mode == "train":train(BATCH_SIZE=args.batch_size)elif args.mode == "generate":generate(BATCH_SIZE=args.batch_size, nice=args.nice)

实验结果

迭代的次数还不够,生成的效果不算太好,若有兴趣,可以多增加迭代次数,会生成更好的图像。


keras实现DCGAN生成mnist原代码相关推荐

  1. 生死看淡,不服就GAN(五)----用DCGAN生成MNIST手写体

    搭建DCGAN网络 #*************************************** 生死看淡,不服就GAN ************************************* ...

  2. Pytorch 使用DCGAN生成动漫人物头像 入门级实战教程

    有关DCGAN实战的小例子之前已经更新过一篇,感兴趣的朋友可以点击查看 Pytorch 使用DCGAN生成MNIST手写数字 入门级教程 有关DCGAN的相关原理:DCGAN论文解读-----DCGA ...

  3. 搭建简单GAN生成MNIST手写体

    Keras搭建GAN生成MNIST手写体 GAN简介 生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来复杂分布上无监督学习最具前 ...

  4. DCGAN生成彩色头像

    最近是一段特殊时期,因为发热,虽然未到过疫区或接触过疑似患者,但仍被要求在家隔离,本人非计算机专业,业余爱好编程,最感兴趣的是深度神经网络,天天看大佬们发各种博文或帖子,一直再学习,刚好利用在家隔离的 ...

  5. 好像还挺好玩的GAN2——Keras搭建DCGAN利用深度卷积神经网络实现图片生成

    好像还挺好玩的GAN2--Keras搭建DCGAN利用深度卷积神经网络实现图片生成 注意事项 学习前言 什么是DCGAN 神经网络构建 1.Generator 2.Discriminator 训练思路 ...

  6. DCGAN生成cifar10, cifar100, mnist, fashion_mnist,STL10,Anime图片(pytorch)

    代码下载地址下载地址https://www.lanzouw.com/ipl8Yo37qxihttps://www.lanzouw.com/ipl8Yo37qxi Anime数据请在Anime Face ...

  7. R语言使用keras包实现卷积自动编码器模型(Convolutional Autoencoder)、加载keras自带的mnist数据集、训练中动态生成每个epoch后模型训练的loss曲线

    R语言使用keras包实现卷积自动编码器模型(Convolutional Autoencoder).加载keras自带的mnist数据集.训练中动态生成每个epoch后模型训练的loss曲线 目录

  8. vs2019 利用Pytorch和TensorFlow分别实现DCGAN生成动漫头像

    这是针对于博客vs2019安装和使用教程(详细)的DCGAN生成动漫头像项目新建示例 目录 一.DCGAN架构及原理 二.项目结构 1.TensorFlow 2.Pytorch 三.数据集下载(两种方 ...

  9. DCGAN生成动漫头像【学习】

    DCGAN生成动漫头像 在假期看了李宏毅老师的GAN的介绍,看到了课后题DCGAN生成动漫头像的作业,实现一下.记录学习过程. 参考的文章: [Keras] 基于GAN自动生成动漫头像 因为使用的是t ...

  10. 体验DCGAN生成漫画头像

    DCGAN是一种深度学习模型,针对复杂分布的无监督学习有很好的发展前景. 这种模型实现的方法是通过两个模块(生成器和判别器)互相为难对方来提升模型的准确率.通俗来说就是老师和学生的关系,老师给学生出题 ...

最新文章

  1. java 老年代回收_Java垃圾回收之老年代垃圾收集器
  2. react-native for android windows开发环境搭建详细记录
  3. 破解入门(六)-----实战“内存镜像法”脱壳
  4. Struts2中使用OGNL表达式投影(过滤)集合
  5. Java静态域与静态方法
  6. c语言常考的程序,复试C语言常考趣味程序方案.doc
  7. eclipse中文乱码解决_Stata中文乱码顽疾解决方法-一行命令
  8. js类似matlab_JavaScript与MATLAB的计算性能差异对比研究
  9. word2vec原理知识铺垫
  10. 训练集、测试集loss容易出现的问题总结
  11. YouCompleteMe自动补全的安装配置与使用
  12. Shiro 权限验证原理
  13. 轻量级音乐服务器LMS
  14. 高可用的分布式Hadoop大数据平台搭建,超详细,附代码。
  15. 查询高校名【Python习题】(保姆级图文+实现代码)
  16. 微服务学习第四十七节 Nacos一致性协议:Distro协议
  17. JS中flag使用场景之一
  18. 麻将牌型说明最全版(图文介绍)
  19. Android错误:unexpected text found in layout file
  20. 史上最全的Altium Designer 20安装教程

热门文章

  1. js右下角广告[兼容]
  2. 使用UMDH进行内心泄露分析
  3. ExtJS4.2学习(13)基于表格的扩展插件---rowEditing
  4. Android TouchEvent事件传递机制
  5. ORA-12514: TNS:listener does not currently know of service …
  6. Guava - Collections - Immutable collections
  7. 滚轮事件的防冒泡、阻止默认行为
  8. IE 8 HTML Parsing Error:Unable to modify the parent container element before the child element is...
  9. Protobuf3详细介绍
  10. lodop指定打印机打印_2020年打印机推荐选购,看这篇就够了