使用tensorflow 2.5.0搭建wgan网络:

代码:

import argparse
from inspect import classify_class_attrs
import tensorflow as tf
import tensorflow.keras as K
import numpy as np
import cv2 as cv
import os
import time#采用静态图的形式,可关闭急切模式
# To run in static-graph mode, eager execution can be disabled:
# tf.compat.v1.disable_eager_execution()

parser = argparse.ArgumentParser()
parser.add_argument('--data', required=True, type=str)
parser.add_argument('--type', default='gp')
args = parser.parse_args()


class datasets:
    """Configuration holder + input pipeline for the cats/dogs WGAN training set.

    Scans `datasetsPath` for files whose names contain 'cat'/'dog' and builds
    a tf.data pipeline of 64x64 RGB cat images scaled to [0, 1].
    """

    def __init__(self, datasetsPath: str, type='gp'):
        self.dataPath = datasetsPath
        self.type = type                  # 'gp' (WGAN-GP) or 'div' (WGAN-div)
        self.noise_dim = 128              # generator latent size
        self.dim = 64                     # output image side length
        self.epochs = 400
        self.batch_size = 16
        self.data_num = 12500             # images per epoch used for step count
        self.learning_rate = 2e-4
        self.save_step = 300              # checkpoint interval (steps)
        self.n_critic = 1                 # discriminator updates per step
        self.n_generate = 3               # generator updates per step
        self.save_path = 'generateImage/'
        self.model_path = 'checkpoints/'
        self.classifyImages()
        self.buildTrainData()

    def classifyImages(self):
        """Split the directory listing into cat/dog file paths (shuffled, fixed seed)."""
        imageList = os.listdir(self.dataPath)
        np.random.seed(10101)
        np.random.shuffle(imageList)
        self.catImages = []
        self.dogImages = []
        for index in imageList:
            if 'cat' in index:
                self.catImages.append(os.path.join(self.dataPath, index))
            elif 'dog' in index:
                self.dogImages.append(os.path.join(self.dataPath, index))

    def load_image(self, imagePath: tf.Tensor):
        """Read one JPEG file, resize to (dim, dim) and scale to [0, 1] float32."""
        # FIX: original had a duplicated assignment `img = img = tf.io.read_file(...)`.
        img = tf.io.read_file(imagePath)
        img = tf.image.decode_jpeg(img)   # images are stored as JPEG
        img = tf.image.resize(img, (self.dim, self.dim)) / 255.0
        # img = tf.reshape(img, [self.dim, self.dim, 3])
        img = tf.cast(img, tf.float32)
        return img

    def buildTrainData(self):
        '''Build the training dataset and an iterator over it.

        You can inspect the dataset with `take`; for example:
            img = traindata.ds_train.take(3)
            print(np.shape(np.array(list(img.as_numpy_iterator()))))  # (3, 16, 64, 64, 3)
            for img in traindata.ds_train.take(3):
                image = np.array(img[0] * 255, np.uint8)
                cv.imshow("asf", image)
                cv.waitKey(0)
        '''
        # FIX: `.cache()` must come right after the (deterministic) map and
        # before `.repeat()/.shuffle()`; the original cached at the very end,
        # after repeat(1000), which would try to cache 1000 epochs of data.
        self.ds_train = tf.data.Dataset.from_tensor_slices(self.catImages) \
            .map(self.load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE) \
            .cache() \
            .repeat(1000) \
            .shuffle(buffer_size=500) \
            .batch(self.batch_size) \
            .prefetch(tf.data.experimental.AUTOTUNE)
        self.itertor_train = iter(self.ds_train)
#-------------------------------------------------------------------
#-------------------------------------------------------------------
#                            define resBlock                       |
#-------------------------------------------------------------------
class resBlock(K.layers.Layer):
    """Residual block with optional spatial resampling.

    resampling='up'   : UpSampling2D + 3x3 conv on the main path (used by the generator)
    resampling='down' : strided 3x3 conv on the main path (used by the discriminator)
    The shortcut is resampled the same way with 1x1 convs so shapes match.
    """

    def __init__(self, num_filters, resampling, strides=2):
        super().__init__()
        self.num_filters = num_filters
        self.resampling = resampling
        self.strides = strides

    def build(self, input_shape):
        self.upsampl1 = K.layers.UpSampling2D()
        self.covC1 = K.layers.Conv2D(self.num_filters, kernel_size=(3, 3), strides=1,
                                     padding='same', kernel_initializer='he_normal')
        self.covC1_1 = K.layers.Conv2D(self.num_filters, kernel_size=(3, 3), strides=self.strides,
                                       padding='same', kernel_initializer='he_normal')
        self.BN = K.layers.BatchNormalization()
        self.relu = K.layers.Activation('relu')
        self.LeakyRelu = K.layers.LeakyReLU()
        self.cov1 = K.layers.Conv2D(self.num_filters, kernel_size=(3, 3), strides=1,
                                    padding='same', kernel_initializer='he_normal')
        self.BN1 = K.layers.BatchNormalization()
        self.cov2 = K.layers.Conv2D(self.num_filters, kernel_size=(3, 3), strides=1,
                                    padding='same', kernel_initializer='he_normal')
        self.BN2 = K.layers.BatchNormalization()
        self.upsampl2 = K.layers.UpSampling2D()
        self.covC2 = K.layers.Conv2D(self.num_filters, kernel_size=(1, 1), strides=1,
                                     padding='same', kernel_initializer='he_normal')
        self.covC2_1 = K.layers.Conv2D(self.num_filters, kernel_size=(1, 1), strides=self.strides,
                                       padding='same', kernel_initializer='he_normal')
        self.BN3 = K.layers.BatchNormalization()
        self.add = K.layers.Add()
        return super().build(input_shape)

    def call(self, inputs):
        X_shortcut = inputs
        # FIX: the original compared strings with `is` ("self.resampling is 'up'"),
        # which relies on CPython string interning and raises SyntaxWarning; use ==.
        if self.resampling == 'up':
            x = self.upsampl1(inputs)
            x = self.covC1(x)
            # a Conv2DTranspose could be used here instead of UpSampling2D + Conv2D
        elif self.resampling == 'down':
            x = self.covC1_1(inputs)
        else:
            # FIX: the original fell through with `x` unbound for any other value.
            raise ValueError("resampling must be 'up' or 'down'")
        x = self.BN(x)
        x = self.LeakyRelu(x)       # original experimented with relu; LeakyReLU kept
        x = self.cov1(x)
        x = self.BN1(x)
        x = self.LeakyRelu(x)
        x = self.cov2(x)
        x = self.BN2(x)
        # resample the shortcut so it can be added to the main path
        if self.resampling == 'up':
            X_shortcut = self.upsampl2(X_shortcut)
            X_shortcut = self.covC2(X_shortcut)
        elif self.resampling == 'down':
            X_shortcut = self.covC2_1(X_shortcut)
        X_shortcut = self.BN3(X_shortcut)
        X_add = self.add([x, X_shortcut])
        X_add = self.LeakyRelu(X_add)
        return X_add


class generator(K.Model):
    """Maps a (batch, noise_dim) latent vector to a (batch, 64, 64, 3) tanh image."""

    def __init__(self, resampling='up'):
        super(generator, self).__init__()
        self.resampling = resampling

    def build(self, input_shape):
        self.dense_1 = K.layers.Dense(512 * 4 * 4)
        self.reshape_1 = K.layers.Reshape((4, 4, 512))
        self.BN = K.layers.BatchNormalization()
        self.relu = K.layers.Activation('relu')
        self.resblock_1 = resBlock(num_filters=512, resampling=self.resampling)
        self.resblock_2 = resBlock(num_filters=256, resampling=self.resampling)
        self.resblock_3 = resBlock(num_filters=128, resampling=self.resampling)
        self.resblock_4 = resBlock(num_filters=64, resampling=self.resampling)
        self.conv2d = K.layers.Conv2D(3, kernel_size=(3, 3), strides=1,
                                      padding='same', kernel_initializer='he_normal')
        self.tanh = K.layers.Activation('tanh')
        return super().build(input_shape)

    def call(self, inputs):
        g = self.dense_1(inputs)
        g = self.reshape_1(g)       # 4*4*512
        g = self.BN(g)
        g = self.relu(g)
        g = self.resblock_1(g)      # 8*8*512
        g = self.resblock_2(g)      # 16*16*256
        g = self.resblock_3(g)      # 32*32*128
        g = self.resblock_4(g)      # 64*64*64
        g = self.conv2d(g)          # 64*64*3
        # NOTE(review): tanh outputs [-1, 1] while the real images are scaled to
        # [0, 1] in datasets.load_image — ranges look mismatched, confirm intent.
        g_out = self.tanh(g)
        return g_out


class discriminator(K.Model):
    """Critic: maps a (batch, 64, 64, 3) image to an unbounded scalar score."""

    def __init__(self, resampling='down'):
        super(discriminator, self).__init__()
        self.resampling = resampling

    def build(self, input_shape):
        self.conv2d = K.layers.Conv2D(64, kernel_size=(3, 3), padding='same',
                                      strides=1, kernel_initializer='he_normal')
        self.BN = K.layers.BatchNormalization()
        self.relu = K.layers.Activation('relu')
        self.LeakyRelu = K.layers.LeakyReLU()
        self.resblock_1 = resBlock(num_filters=128, resampling=self.resampling)
        self.resblock_2 = resBlock(num_filters=256, resampling=self.resampling)
        self.resblock_3 = resBlock(num_filters=512, resampling=self.resampling)
        self.resblock_4 = resBlock(num_filters=512, resampling=self.resampling)
        self.averagePool2d = K.layers.GlobalAveragePooling2D()
        self.dense = K.layers.Dense(1)
        return super().build(input_shape)

    def call(self, inputs):
        d = self.conv2d(inputs)
        d = self.BN(d)
        d = self.LeakyRelu(d)       # 64*64*64
        d = self.resblock_1(d)      # 32*32*128
        d = self.resblock_2(d)      # 16*16*256
        d = self.resblock_3(d)      # 8*8*512
        d = self.resblock_4(d)      # 4*4*512
        # GlobalAveragePooling replaces a full connection layer;
        # a Dense layer could be used here to test the network instead.
        d = self.averagePool2d(d)
        d_out = self.dense(d)
        return d_out


class wgan:
    """WGAN trainer supporting the 'gp' (gradient penalty) and 'div' variants."""

    def __init__(self, datasets: datasets):
        self.traindata = datasets

    def build(self):
        """Instantiate generator, critic, the combined model and both optimizers."""
        assert self.traindata.type in ('gp', 'div'), \
            f'please confirm the type is {self.traindata.type}'
        #------------------------------
        # define the generate model   *
        #------------------------------
        self.generate_model = generator()
        #--------------------------------
        # define the discriminator model*
        #--------------------------------
        self.discriminator_model = discriminator()
        #--------------------------------
        #      combine the model        *
        #--------------------------------
        z_noise = K.layers.Input(shape=(self.traindata.noise_dim,))
        # score of a generated image; used only for saving/loading weights
        fake_img = self.generate_model(z_noise)
        fake_score = self.discriminator_model(fake_img)
        self.combineModel = K.Model(z_noise, fake_score)
        #--------------------------------
        #        optimizer              *
        #--------------------------------
        self.discriminator_optimizer = K.optimizers.Adam(self.traindata.learning_rate, 0.5)
        self.generator_optimizer = K.optimizers.Adam(self.traindata.learning_rate, 0.5)
        self.generate_model.summary()
        self.discriminator_model.summary()
        self.combineModel.summary()

    def train_discriminator(self, z_noise, train_data, u_niform):
        """One critic update; returns the total critic loss (Wasserstein + penalty)."""
        k = 2   # WGAN-div penalty coefficient
        p = 6   # WGAN-div penalty power
        u = u_niform
        with tf.GradientTape() as tape, \
                tf.GradientTape() as d_tape:
            D_fake_img = self.generate_model(z_noise)
            D_fake_score = self.discriminator_model(D_fake_img)
            D_real_score = self.discriminator_model(train_data)
            # interpolated sample between real and fake
            x_ = (1. - u) * train_data + u * D_fake_img
            # FIX: x_ is a computed tensor, not a tf.Variable, so the tape does
            # not track it by default and tape.gradient(..., [x_]) returns None.
            tape.watch(x_)
            #-------------------------------------------------------------------
            #                            wgan div loss function                |
            #                               n_critic = 1                       |
            #                          arxiv.org/pdf/1712.01026.pdf            |
            #-------------------------------------------------------------------
            if self.traindata.type == 'div':
                gradients = tape.gradient(self.discriminator_model(x_), [x_])[0]
                grad_norm = K.backend.sqrt(K.backend.sum(gradients ** 2, axis=[1, 2, 3])) ** p
                grad_penalty = k * K.backend.mean(grad_norm)
                discriminator_loss = K.backend.mean(D_real_score - D_fake_score)
            #-------------------------------------------------------------------
            #                            wgan gp  loss function                |
            #                               n_critic = 5                       |
            #                          arxiv.org/pdf/1704.00028.pdf            |
            #-------------------------------------------------------------------
            elif self.traindata.type == 'gp':
                gradients = tape.gradient(self.discriminator_model(x_), [x_])[0]
                grad_norm = K.backend.sqrt(K.backend.sum(gradients ** 2, axis=[1, 2, 3]))
                grad_norm = K.backend.square(1 - grad_norm)
                grad_penalty = 10 * K.backend.mean(grad_norm)
                discriminator_loss = K.backend.mean(D_fake_score - D_real_score)
            discriminator_loss_all = grad_penalty + discriminator_loss
        # compute the outer gradient after recording has stopped
        gradients_d = d_tape.gradient(discriminator_loss_all,
                                      self.discriminator_model.trainable_variables)
        self.discriminator_optimizer.apply_gradients(
            zip(gradients_d, self.discriminator_model.trainable_variables))
        return discriminator_loss_all

    def train_generator(self, z_noise):
        """One generator update; returns the generator loss."""
        with tf.GradientTape() as g_tape:
            G_fake_img = self.generate_model(z_noise)
            G_fake_score = self.discriminator_model(G_fake_img)
            if self.traindata.type == 'div':
                generate_loss = K.backend.mean(G_fake_score)
            if self.traindata.type == 'gp':
                generate_loss = -K.backend.mean(G_fake_score)  # minimize this value
        gradients_g = g_tape.gradient(generate_loss,
                                      self.generate_model.trainable_variables)
        self.generator_optimizer.apply_gradients(
            zip(gradients_g, self.generate_model.trainable_variables))
        return generate_loss

    def train(self):
        """Training loop: resume from checkpoints, alternate critic/generator
        updates, and periodically log, checkpoint and dump a sample image."""
        if os.path.exists(os.path.join(self.traindata.model_path, 'gan.weights')):
            self.combineModel.load_weights(
                os.path.join(self.traindata.model_path, 'gan.weights'))
            if os.path.exists(os.path.join(self.traindata.model_path, 'history.npy')):
                history = np.load(os.path.join(self.traindata.model_path, './history.npy'),
                                  allow_pickle=True).tolist()
                # the last saved entry holds the step to resume from
                last_iter = int(history[-1][0])
                print('Find the npy file, the last save iter:%d' % (last_iter))
            else:
                history = []
                last_iter = -1
        else:
            print('There is no .npy file, creating a new file---------')
            history = []
            last_iter = -1
        total_steps = int(self.traindata.epochs * self.traindata.data_num
                          / self.traindata.batch_size + 1)
        for step in range(last_iter + 1, total_steps):
            try:
                start_time = time.time()
                train_data = self.traindata.itertor_train.get_next()
                z_noise = np.random.normal(
                    size=self.traindata.batch_size * self.traindata.noise_dim) \
                    .reshape([self.traindata.batch_size, self.traindata.noise_dim])
                u_niform = np.random.uniform(
                    low=0.0, high=1.0, size=(self.traindata.batch_size, 1, 1, 1))
                # training the model
                for i in range(self.traindata.n_critic):
                    discriminator_loss_all = self.train_discriminator(
                        z_noise, train_data, u_niform)
                for i in range(self.traindata.n_generate):
                    generate_loss = self.train_generator(z_noise)
                duration = time.time() - start_time
                #-----------------------------------------
                #            print the loss              |
                #-----------------------------------------
                if step % 5 == 0:
                    tf.print("The step is %s,d_loss:%s,g_loss:%s, "
                             % (step, discriminator_loss_all, generate_loss), end=' ')
                    tf.print('%.2f s/step' % (duration))
                #-----------------------------------------
                #       plot the train history           |
                #-----------------------------------------
                if step % 5 == 0:
                    history.append([step, discriminator_loss_all, generate_loss])
                #-----------------------------------------
                #       save the model_weights           |
                #-----------------------------------------
                if step % self.traindata.save_step == 0 and step != 0:
                    # save the train steps
                    np.save(os.path.join(self.traindata.model_path, './history.npy'),
                            history)
                    self.combineModel.save_weights(
                        os.path.join(self.traindata.model_path, 'gan.weights'))
                #-----------------------------------------
                #       save the image of generate       |
                #-----------------------------------------
                if step % 50 == 0 and step != 0:
                    noise_test = np.random.normal(size=[1, self.traindata.noise_dim])
                    noise_test = np.array(noise_test, np.float32)
                    fake_image = self.generate_model(noise_test)
                    # Restore the image:
                    # 1. after multiplying by 255 it must be cast to uint8;
                    # 2. it could also stay as float32 in [0, 1] and be shown directly.
                    # NOTE(review): the generator ends in tanh ([-1, 1]); *255 on
                    # negative values wraps under uint8 — confirm the intended range.
                    arr_img = np.array([fake_image], np.float32).reshape(
                        [self.traindata.dim, self.traindata.dim, 3]) * 255
                    arr_img = np.array(arr_img, np.uint8)
                    # images were saved with PIL (RGB), so convert to BGR for cv
                    arr_img = cv.cvtColor(arr_img, cv.COLOR_RGB2BGR)
                    cv.imwrite(self.traindata.save_path + str(step) + '.jpg', arr_img)
            except tf.errors.OutOfRangeError:
                tf.print("the iter is out of range\n")


if __name__ == '__main__':
    traindata = datasets(args.data, args.type)
    mygan = wgan(traindata)
    mygan.build()
    mygan.train()

参考:

使用tensorflow 1.13.1 keras 搭建wgan-gp:
https://blog.csdn.net/qq_42995327/article/details/111463011

tensorflow 2.5.0 ( keras )搭建wgan-gp 和 div相关推荐

  1. 掌声送给TensorFlow 2.0!用Keras搭建一个CNN | 入门教程

    作者 | Himanshu Rawlani 译者 | Monanfei,责编 | 琥珀 出品 | AI科技大本营(id:rgznai100) 2019 年 3 月 6 日,谷歌在 TensorFlow ...

  2. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  3. python机器学习:搭建tensorflow环境,下载Keras库并在python中成功完成导入。pycharm的相关配置。

    安装过程较长,请大家耐心阅读,其中有一些自己在安装过程中出现过的一些问题,在此一同分享给大家 一.下载Anaconda: 1.首先下载安装Anaconda,可以去官网下载 https://www.an ...

  4. 30行代码就可以实现看图识字!python使用tensorflow.keras搭建简单神经网络

    文章目录 搭建过程 1. 引入必需的库 2. 引入数据集 3. 搭建神经网络层 4. 编译神经网络模型 5. 训练模型 效果测试 大概几个月前,神经网络.人工智能等概念在我心里仍高不可攀,直到自己亲身 ...

  5. 【记录】本科毕设:基于树莓派的智能小车设计(使用Tensorflow + Keras 搭建CNN卷积神经网络 使用端到端的学习方法训练CNN)

    0 申明 这是本人2020年的本科毕业设计,内容多为毕设论文和答辩内容中挑选.最初的灵感来自于早前看过的一些项目(抱歉时间久远,只记录了这一个,见下),才让我萌生了做个机电(小车动力与驱动)和控制(树 ...

  6. Ubuntu16.04.4 + 双 NAVDA TitanX + CUDA9.0 + cudnn7.05 + TensorFlow 1.8(1.5.0) + Keras

    一.安装 Ubunt16.04.4 二.安装显卡驱动 二*.遇到问题 若驱动安装失败,不能进入系统,采用如下两种方法解决. 无法进入桌面的问题 三.安装 CUDA 四.安装 cudnn 五.在 bas ...

  7. Ubuntu16.04下CUDA 9.0 + cuDNN v7.0 + tensorflow 1.6.0(GPU)环境搭建

    由于自己攒了个主机,第一次安装GPU版本的tensorflow,mark一下. 说明一下,本篇上接<Ubuntu16.04LTS下搭建强化学习环境gym.tensorflow>这篇文章,只 ...

  8. TensorFlow 1.11.0正式版发布了,强力支持Keras

    学习栗 发自 凹非寺  量子位 报道 | 公众号 QbitAI 在rc0,rc1,rc2排队出场之后,TensorFlow 1.11.0的正式版上线了. 相比从前,新版本对Keras的支持力度更强了. ...

  9. ERROR: Cannot install keras==2.2.0 and tensorflow==1.14.0 because these package versions have confli

    ERROR: Cannot install keras2.2.0 and tensorflow1.14.0 because these package versions have conflictin ...

  10. 宝可梦 图片识别python_使用Tensorflow从0开始搭建精灵宝可梦的检测APP

    使用Tensorflow从0开始搭建精灵宝可梦的检测APP 本文为本人原创,转载请注明来源链接 环境要求 Tensorflow1.12.0 cuda 9.0 python3.6.10 Android ...

最新文章

  1. android上传本地图片到服务器上,Android使用post方式上传图片到服务器的方法
  2. Codeforces #449 div2 C题
  3. golang中字符串内置函数整理
  4. Spring Boot 系列(八)@ControllerAdvice 拦截异常并统一处理
  5. android控件ems,Android登录等待效果
  6. 从零开始入门 K8s | 有状态应用编排 - StatefulSet
  7. 1.9 编程基础之二分查找 13:整数去重 python
  8. 暑期训练日志----2018.8.24
  9. 微信分享接口 略缩图 php
  10. dorado 7 使用总结
  11. NIOS 2 软核中EPCS配置芯片的存储操作
  12. 评价微型计算机有哪些主要性能指标,计算机性能指标有哪些
  13. matlab中e如何输入,Matlab中表达e的操作方法介绍
  14. proposal中文翻译_proposal是什么意思_ proposal的翻译_音标_读音_用法_例句_爱词霸在线词典...
  15. 多任务:分层特征融合网络 NDDR-CNN
  16. vue+ElementUI 实现管理端照片墙(或广告位)效果
  17. 计算机回收站设置大小,电脑回收站无法调整容量的大小怎么办?
  18. 腾讯-算法工程师电话面试
  19. linux redis退出命令行,linux的redis启动关闭命令
  20. Rhapsody freeMaker 将任意HL7转XML

热门文章

  1. android studio2.4,Android Studio 2.4 Preview 7 发布
  2. javax.servlet.http.HttpServletResponse.setContentLengthLong(J)V,maven项目报错!!无法访问webapp下的文件,完美解决方案
  3. mysql hash分区 子分区_mysql分区管理 - hash分区
  4. 【转】VirtualDOM与diff(Vue实现).MarkDown
  5. 信用体系,生态之魂!——保险科技生态建设
  6. 最新的Scrum中文指南及更新
  7. centos 7 install VirtualBox
  8. 基于 React.js + redux + bootstrap 的 RubyChina 示例
  9. CABasicAnimation添加动画离开屏幕就动画停止的问题
  10. 多路复用输入/输出 ---- select