前言

CycleGAN是在今年三月底放在arxiv(论文地址CycleGAN)的一篇文章,文章名为Learning to Discover Cross-Domain Relations with Generative Adversarial Networks,同一时期还有两篇非常类似的DualGAN(论文地址:DualGAN)和DiscoGAN(论文地址:DiscoGAN),简单来说,它们的功能就是:自动将某一类图片转换成另外一类图片。不同于GAN和CGAN(上节已经介绍过),CycleGAN不需要配对的训练图像。当然了配对图像也完全可以,不过大多时候配对图像比较难获取。



配对图像

未配对的图像

CycleGAN能做什么?

CycleGAN可以完成GAN和CGAN的工作,如上述配对图像所示,可以从一个特定的场景模式图生成另外一个场景模式图,这两张场景模式中的物体完全相同。除此之外,CycleGAN还可以完成从一个模式到另外一个模式的转换,转换的过程中,物体发生了改变,比如下面的图像中从猫到狗,从男人到女人。


CycleGAN算法原理

如下图所示CycleGAN其实是由两个判别器(DxDxD_ {x}和DyDyD_ {y})和两个生成器(G和F)组成,但是为什么要连两个生成器和两个判别器呢?论文中说,是为了避免所有的X都被映射到同一个Y,比如所有男人的图像都映射到范冰冰的图像上,这显然不合理,所以为了避免这种情况,论文采用了两个生成器的方式,既能满足X->Y的映射,又能满足Y->X的映射,这一点其实就是变分自编码器VAE的思想,是为了适应不同输入图像产生不同输出图像。那么下面的四个公式也很清楚了,(1)是判别器Y对X->Y的映射G的损失,判别器X对Y->X映射的损失也非常类似(2)是两个生成器的循环损失,这里其实是L1L1L_ {1}损失(3)是总损失(4)是对总损失进行优化,先优化D然后优化G和F,这一点和GAN类似



(1)

(2)

(3)

(4)

源代码

训练源代码

import tensorflow as tf
from model import CycleGAN
from reader import Reader
from datetime import datetime
import os
import logging
from utils import ImagePoolFLAGS = tf.flags.FLAGStf.flags.DEFINE_integer('batch_size', 1, 'batch size, default: 1')
tf.flags.DEFINE_integer('image_size', 128, 'image size, default: 256')
tf.flags.DEFINE_bool('use_lsgan', True,'use lsgan (mean squared error) or cross entropy loss, default: True')
tf.flags.DEFINE_string('norm', 'instance','[instance, batch] use instance norm or batch norm, default: instance')
tf.flags.DEFINE_integer('lambda1', 10.0,'weight for forward cycle loss (X->Y->X), default: 10.0')
tf.flags.DEFINE_integer('lambda2', 10.0,'weight for backward cycle loss (Y->X->Y), default: 10.0')
tf.flags.DEFINE_float('learning_rate', 2e-4,'initial learning rate for Adam, default: 0.0002')
tf.flags.DEFINE_float('beta1', 0.5,'momentum term of Adam, default: 0.5')
tf.flags.DEFINE_float('pool_size', 50,'size of image buffer that stores previously generated images, default: 50')
tf.flags.DEFINE_integer('ngf', 64,'number of gen filters in first conv layer, default: 64')tf.flags.DEFINE_string('X', 'tfrecords/apple.tfrecords','X tfrecords file for training, default: tfrecords/apple.tfrecords')
tf.flags.DEFINE_string('Y', 'tfrecords/orange.tfrecords','Y tfrecords file for training, default: tfrecords/orange.tfrecords')
tf.flags.DEFINE_string('load_model', None,'folder of saved model that you wish to continue training (e.g. 20170602-1936), default: None')def train():if FLAGS.load_model is not None:checkpoints_dir = "checkpoints/" + FLAGS.load_modelelse:current_time = datetime.now().strftime("%Y%m%d-%H%M")checkpoints_dir = "checkpoints/{}".format(current_time)try:os.makedirs(checkpoints_dir)except os.error:passgraph = tf.Graph()with graph.as_default():cycle_gan = CycleGAN(X_train_file=FLAGS.X,Y_train_file=FLAGS.Y,batch_size=FLAGS.batch_size,image_size=FLAGS.image_size,use_lsgan=FLAGS.use_lsgan,norm=FLAGS.norm,lambda1=FLAGS.lambda1,lambda2=FLAGS.lambda1,learning_rate=FLAGS.learning_rate,beta1=FLAGS.beta1,ngf=FLAGS.ngf)G_loss, D_Y_loss, F_loss, D_X_loss, fake_y, fake_x = cycle_gan.model()optimizers = cycle_gan.optimize(G_loss, D_Y_loss, F_loss, D_X_loss)summary_op = tf.summary.merge_all()train_writer = tf.summary.FileWriter(checkpoints_dir, graph)saver = tf.train.Saver()with tf.Session(graph=graph) as sess:if FLAGS.load_model is not None:checkpoint = tf.train.get_checkpoint_state(checkpoints_dir)meta_graph_path = checkpoint.model_checkpoint_path + ".meta"restore = tf.train.import_meta_graph(meta_graph_path)restore.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))step = int(meta_graph_path.split("-")[2].split(".")[0])else:sess.run(tf.global_variables_initializer())step = 0coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)try:fake_Y_pool = ImagePool(FLAGS.pool_size)fake_X_pool = ImagePool(FLAGS.pool_size)while not coord.should_stop():# get previously generated imagesfake_y_val, fake_x_val = sess.run([fake_y, fake_x])# train_, G_loss_val, D_Y_loss_val, F_loss_val, D_X_loss_val, summary = (sess.run([optimizers, G_loss, D_Y_loss, F_loss, D_X_loss, summary_op],feed_dict={cycle_gan.fake_y: fake_Y_pool.query(fake_y_val),cycle_gan.fake_x: fake_X_pool.query(fake_x_val)}))if step % 100 == 0:train_writer.add_summary(summary, step)train_writer.flush()if step % 100 == 0:logging.info('-----------Step %d:-------------' % step)logging.info('  G_loss   : {}'.format(G_loss_val))logging.info('  D_Y_loss : {}'.format(D_Y_loss_val))logging.info('  F_loss   : {}'.format(F_loss_val))logging.info('  D_X_loss : {}'.format(D_X_loss_val))if step % 1000 == 0:save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)logging.info("Model saved in file: %s" % save_path)step += 1except KeyboardInterrupt:logging.info('Interrupted')coord.request_stop()except Exception as e:coord.request_stop(e)finally:save_path = saver.save(sess, checkpoints_dir + "/model.ckpt", global_step=step)logging.info("Model saved in file: %s" % save_path)# When done, ask the threads to stop.coord.request_stop()coord.join(threads)def main(unused_argv):train()if __name__ == '__main__':logging.basicConfig(level=logging.INFO)tf.app.run()

测试源代码

"""Translate an image to another image
An example of command-line usage is:
python export_graph.py --model pretrained/apple2orange.pb \--input input_sample.jpg \--output output_sample.jpg \--image_size 256
"""import tensorflow as tf
import os
from model import CycleGAN
import utilsFLAGS = tf.flags.FLAGStf.flags.DEFINE_string('model', 'model/apple2orange.pb', 'model path (.pb)')
tf.flags.DEFINE_string('input', 'samples/real_apple2orange_4.jpg', 'input image path (.jpg)')
tf.flags.DEFINE_string('output', 'output/output_sample3.jpg', 'output image path (.jpg)')
tf.flags.DEFINE_integer('image_size', '256', 'image size, default: 256')def inference():graph = tf.Graph()with graph.as_default():with tf.gfile.FastGFile(FLAGS.input, 'rb') as f:image_data = f.read()input_image = tf.image.decode_jpeg(image_data, channels=3)input_image = tf.image.resize_images(input_image, size=(FLAGS.image_size, FLAGS.image_size))input_image = utils.convert2float(input_image)input_image.set_shape([FLAGS.image_size, FLAGS.image_size, 3])with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:graph_def = tf.GraphDef()graph_def.ParseFromString(model_file.read())[output_image] = tf.import_graph_def(graph_def,input_map={'input_image': input_image},return_elements=['output_image:0'],name='output')with tf.Session(graph=graph) as sess:generated = output_image.eval()with open(FLAGS.output, 'wb') as f:f.write(generated)def main(unused_argv):inference()if __name__ == '__main__':tf.app.run()

实验结果

在这里是以相同物体不同模式下的数据集做训练(由于没有找到不同物体不同模式下的数据,当然你也可以自己做),从苹果到橘子的训练,测试结果如下:

从上图可以看出,苹果的颜色已经改成橘色,效果得到了体现。
源代码链接:CycleGAN source code

CycleGAN算法原理(附源代码,可直接运行)相关推荐

  1. python协同过滤可以预测吗_基于用户的协同过滤推荐算法原理-附python代码实现...

    在推荐系统众多方法中,基于用户的协同过滤推荐算法是最早诞生的,原理也较为简单.该算法1992年提出并用于邮件过滤系统,两年后1994年被 GroupLens 用于新闻过滤.一直到2000年,该算法都是 ...

  2. Python实用脚本/算法集合, 附源代码下载

    学习编程.学习Python最好的方式就是练习,哪怕是新手,只要不断地敲代码输出,肯定会有神效. Python的练手项目很多,特别是Github上,建议不管新手.老司机都去看看. 这里推荐给大家两个Gi ...

  3. 张正友相机标定算法原理与源代码(OpenCV+C++)

    摄像机的标定问题是机器视觉领域的入门问题,可以分为传统的摄像机定标方法和摄像机自定标方法.定标的方法有很多中常见的有:Tsai(传统)和张正友(介于传统和自定标)等, 摄像机成像模型和四个坐标系(通用 ...

  4. SHA224和SHA256哈希算法原理及实现(附源码)

    相关文章: SHA224和SHA256哈希算法原理及实现(附源码) 国密SM3哈希算法原理及实现(附源码) SHA1哈希算法原理及实现(附源码) MD5哈希算法原理及实现(附源码) MD4哈希算法原理 ...

  5. SHA3系列(KECCAK)哈希算法原理及实现(附源码)

    相关文章: (本文持续更新中) SHA3系列(KECCAK)哈希算法原理及实现(附源码) SHA512系列哈希算法原理及实现(附源码) SHA224和SHA256哈希算法原理及实现(附源码) 国密SM ...

  6. P2P之UDP穿透NAT的原理与实现(附源代码)(转)

    转自:http://www.ppcn.net/n1306c2.aspx 作者:shootingstars | 日期:2004-05-25 | 字体:大 中 小 P2P 之 UDP穿透NAT的原理与实现 ...

  7. 基于聚类的推荐算法笔记——以豆瓣电影为例(二)(附源代码)

    基于聚类的推荐算法笔记--以豆瓣电影为例(二)(附源代码) 第一章 聚类算法介绍 基于聚类的推荐算法笔记一 第二章 数据介绍 基于聚类的推荐算法笔记二 第三章 实现推荐算法 基于聚类的推荐算法笔记三 ...

  8. OpenCV4中有哪些视频背景/前景分割(背景建模/前景提取)算法的类,它们各自的算法原理、特点是什么,并附示例代码

    关于OpenCV4中有哪些视频背景/前景分割(背景建模/前景提取)算法的类,汇总如下: 上面的汇总不仅显示了OpenCV4中有哪些视频背景/前景分割(背景建模/前景提取)算法的类,还显示了它们的继承. ...

  9. 国密SM3密码杂凑算法原理及实现(附源码)

    相关文章: 国密SM3哈希算法原理及实现(附源码) SHA1哈希算法原理及实现(附源码) MD5哈希算法原理及实现(附源码) MD4哈希算法原理及实现(附源码) MD2哈希算法原理及实现(附源码) M ...

  10. SHA512系列哈希算法原理及实现(附源码)

    相关文章: SHA512系列哈希算法原理及实现(附源码) SHA224和SHA256哈希算法原理及实现(附源码) 国密SM3哈希算法原理及实现(附源码) SHA1哈希算法原理及实现(附源码) MD5哈 ...

最新文章

  1. android实现底部弹出菜单,Android实现底部缓慢弹出菜单
  2. 如何表示数组所有数都不等于一个数_每日算法系列【LeetCode 330】按要求补齐数组...
  3. 了解 SharePoint 2010 开发中的关键设计决定
  4. UIActionSheet与UIAlertView
  5. pytorch学习笔记(6):GPU和如何保存加载模型
  6. redis 缓存 淘汰
  7. linux查看文件元数据,Linux切换目录、查看目录下的文件、文件类型介绍和查看文件的元数据信息...
  8. 深入理解 GraphQL
  9. css3做各种角度三角形
  10. 空间相关分析与SDM
  11. 110报警声+单片机c语言,单片机实现消防车报警声音的设计
  12. 批处理|测试工具|局域网速度测试/网络上传与下载速度测试
  13. 高级计量经济学及stata应用_推荐使用的计量经济学教材
  14. 软件测试基础:白盒测试方法:用基本路径覆盖法设计测试用例,基本路径覆盖测试技术设计测试用例的步骤
  15. Excel两列的数据合并成一列
  16. 保险丝的作用,参数及选型应用,你真的懂了吗——电子元器件篇
  17. AdaBoost公式简单版本的推导
  18. 计算机安全在医学中的重要性,试议计算机技术在医学中的意义
  19. PHP在线客服系统源码+H5+APP+搭建教程实例
  20. 3.5 视频服务器-RTSP实现(框架搭建)

热门文章

  1. iOS开发中常用的宏
  2. Redis学习手册(主从复制)
  3. 跑马灯效果 例子 写法 利用闭包
  4. HDU2066--一个人的旅行(Dijkstra)
  5. AS3 CookBook学习整理(二)
  6. 修改Code Blocks默认代码格式
  7. C语言随机读写数据文件(一)
  8. 服务注册中心---服务发现nacos
  9. EasyUI form ajax submit到MVC后,在IE下提示下载内容的解决办法
  10. jQuery获取iframe中页面的高度