备注:下载github程序并运行,本人是小白一枚。这也是本人第一次运行GAN程序。在这里仅仅将自己的代码实操过程分享出来。附件的tf.subtract()和tf.multiply()函数相对于github有所改动,不然版本配不上会出错。这里附上:github源代码链接

#运行run.py文件
run run.py

运行结果,耗时约一小时,CPU Windows7

step 0: loss 21.247
step 10: loss 22.076
step 20: loss 19.516
step 30: loss 6.678
step 40: loss 5.416
step 50: loss 8.300
step 60: loss 5.144
step 70: loss 6.920
step 80: loss 5.351
step 90: loss 6.231
step 100: loss 5.182
step 110: loss 4.267
step 120: loss 7.220
step 130: loss 6.011
step 140: loss 6.295
step 150: loss 4.976·····step 99870: loss 0.013
step 99880: loss 0.011
step 99890: loss 0.010
step 99900: loss 0.014
step 99910: loss 0.011
step 99920: loss 0.009
step 99930: loss 0.014
step 99940: loss 0.011
step 99950: loss 0.014
step 99960: loss 0.006
step 99970: loss 0.003
step 99980: loss 0.017
step 99990: loss 0.004<Figure size 640x480 with 1 Axes>
#显示生成的最后一张图片
visualize.visualize(embed, x_test)


附件

#run.py""" Siamese implementation using Tensorflow with MNIST example.
This siamese network embeds a 28x28 image (a point in 784D)
into a point in 2D.By Youngwook Paul Kwon (young at berkeley.edu)
"""from __future__ import absolute_import
from __future__ import division
from __future__ import print_function#import system things
from tensorflow.examples.tutorials.mnist import input_data # for data
import tensorflow as tf
import numpy as np
import os#import helpers
import inference
import visualize# prepare data and tf.session
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
sess = tf.InteractiveSession()# setup siamese network
siamese = inference.siamese();
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(siamese.loss)
saver = tf.train.Saver()
tf.initialize_all_variables().run()# if you just want to load a previously trainmodel?
new = True
model_ckpt = 'model.ckpt'
if os.path.isfile(model_ckpt):input_var = Nonewhile input_var not in ['yes', 'no']:input_var = raw_input("We found model.ckpt file. Do you want to load it [yes/no]?")if input_var == 'yes':new = False# start training
if new:for step in range(100000):batch_x1, batch_y1 = mnist.train.next_batch(128)batch_x2, batch_y2 = mnist.train.next_batch(128)batch_y = (batch_y1 == batch_y2).astype('float')_, loss_v = sess.run([train_step, siamese.loss], feed_dict={siamese.x1: batch_x1, siamese.x2: batch_x2, siamese.y_: batch_y})if np.isnan(loss_v):print('Model diverged with loss = NaN')quit()if step % 10 == 0:print ('step %d: loss %.3f' % (step, loss_v))if step % 1000 == 0 and step > 0:saver.save(sess, 'model.ckpt')embed = siamese.o1.eval({siamese.x1: mnist.test.images})embed.tofile('embed.txt')
else:saver.restore(sess, 'model.ckpt')# visualize result
x_test = mnist.test.images.reshape([-1, 28, 28])
visualize.visualize(embed, x_test)
#inference.py
import tensorflow as tf class siamese:# Create modeldef __init__(self):self.x1 = tf.placeholder(tf.float32, [None, 784])self.x2 = tf.placeholder(tf.float32, [None, 784])with tf.variable_scope("siamese") as scope:self.o1 = self.network(self.x1)scope.reuse_variables()self.o2 = self.network(self.x2)# Create lossself.y_ = tf.placeholder(tf.float32, [None])self.loss = self.loss_with_spring()def network(self, x):weights = []fc1 = self.fc_layer(x, 1024, "fc1")ac1 = tf.nn.relu(fc1)fc2 = self.fc_layer(ac1, 1024, "fc2")ac2 = tf.nn.relu(fc2)fc3 = self.fc_layer(ac2, 2, "fc3")return fc3def fc_layer(self, bottom, n_weight, name):assert len(bottom.get_shape()) == 2n_prev_weight = bottom.get_shape()[1]initer = tf.truncated_normal_initializer(stddev=0.01)W = tf.get_variable(name+'W', dtype=tf.float32, shape=[n_prev_weight, n_weight], initializer=initer)b = tf.get_variable(name+'b', dtype=tf.float32, initializer=tf.constant(0.01, shape=[n_weight], dtype=tf.float32))fc = tf.nn.bias_add(tf.matmul(bottom, W), b)return fcdef loss_with_spring(self):margin = 5.0labels_t = self.y_labels_f = tf.subtract(1.0, self.y_, name="1-yi")          # labels_ = !labels;eucd2 = tf.pow(tf.subtract(self.o1, self.o2), 2)eucd2 = tf.reduce_sum(eucd2, 1)eucd = tf.sqrt(eucd2+1e-6, name="eucd")C = tf.constant(margin, name="C")# yi*||CNN(p1i)-CNN(p2i)||^2 + (1-yi)*max(0, C-||CNN(p1i)-CNN(p2i)||^2)pos = tf.multiply(labels_t, eucd2, name="yi_x_eucd2")# neg = tf.mul(labels_f, tf.sub(0.0,eucd2), name="yi_x_eucd2")# neg = tf.mul(labels_f, tf.maximum(0.0, tf.sub(C,eucd2)), name="Nyi_x_C-eucd_xx_2")neg = tf.multiply(labels_f, tf.pow(tf.maximum(tf.subtract(C, eucd), 0), 2), name="Nyi_x_C-eucd_xx_2")losses = tf.add(pos, neg, name="losses")loss = tf.reduce_mean(losses, name="loss")return lossdef loss_with_step(self):margin = 5.0labels_t = self.y_labels_f = tf.subtract(1.0, self.y_, name="1-yi")          # labels_ = !labels;eucd2 = tf.pow(tf.subtract(self.o1, self.o2), 2)eucd2 = tf.reduce_sum(eucd2, 1)eucd = tf.sqrt(eucd2+1e-6, name="eucd")C = tf.constant(margin, name="C")pos = tf.multiply(labels_t, eucd, name="y_x_eucd")neg = tf.multiply(labels_f, tf.maximum(0.0, tf.subtract(C, eucd)), name="Ny_C-eucd")losses = tf.add(pos, neg, name="losses")loss = tf.reduce_mean(losses, name="loss")return loss
#visualize.py
from tensorflow.examples.tutorials.mnist import input_dataimport numpy as np
import matplotlib.pyplot as plt
from matplotlib import offsetboxdef visualize(embed, x_test):# two ways of visualization: scale to fit [0,1] scale# feat = embed - np.min(embed, 0)# feat /= np.max(feat, 0)# two ways of visualization: leave with original scalefeat = embedax_min = np.min(embed,0)ax_max = np.max(embed,0)ax_dist_sq = np.sum((ax_max-ax_min)**2)plt.figure()ax = plt.subplot(111)shown_images = np.array([[1., 1.]])for i in range(feat.shape[0]):dist = np.sum((feat[i] - shown_images)**2, 1)if np.min(dist) < 3e-4*ax_dist_sq:   # don't show points that are too closecontinueshown_images = np.r_[shown_images, [feat[i]]]imagebox = offsetbox.AnnotationBbox(offsetbox.OffsetImage(x_test[i], zoom=0.6, cmap=plt.cm.gray_r),xy=feat[i], frameon=False)ax.add_artist(imagebox)plt.axis([ax_min[0], ax_max[0], ax_min[1], ax_max[1]])# plt.xticks([]), plt.yticks([])plt.title('Embedding from the last layer of the network')plt.show()if __name__ == "__main__":mnist = input_data.read_data_sets('MNIST_data', one_hot=False)x_test = mnist.test.imagesx_test = x_test.reshape([-1, 28, 28])embed = np.fromfile('embed.txt', dtype=np.float32)embed = embed.reshape([-1, 2])visualize(embed, x_test)

GAN代码实操(github代码实操)相关推荐

  1. github代码推送

    1.git简介  git是一个集版本控制,内容管理,工作管理一身的系统,可以通过它对一个大项目进行多人开发,并对不同终端的代码块进行整合. 2.git图解   git可以分为四个工作区分别是 工作区( ...

  2. MOOC TensorFlow入门实操课程代码回顾总结(三)

    欢迎来到TensorFlow入门实操课程的学习 MOOC TensorFlow入门实操课程代码回顾总结(一) MOOC TensorFlow入门实操课程代码回顾总结(二) 注: 用于表示python代 ...

  3. Bootstrap4+MySQL前后端综合实训-Day10-AM【实训汇报-下午返校、项目代码(7个包+7个Html页面)】

    [Bootstrap4前端框架+MySQL数据库]前后端综合实训[10天课程 博客汇总表 详细笔记][附:实训所有代码] 目录 实训汇报 数据库--所有SQL语句 工程文件展示 代码 ①package ...

  4. Asp代码转换java代码器_asp下实现对HTML代码进行转换的函数

    asp下实现对HTML代码进行转换的函数 更新时间:2007年08月08日 12:08:49   作者: '****************************** '函数:HTMLEncode( ...

  5. GitHub代码一键转VS Code:只需+1s

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转自 | 机器之心 导读 被微软收购后的 GitHub,正在变 ...

  6. 巧断梯度:单个loss实现GAN模型(附开源代码)

    作者丨苏剑林 单位丨广州火焰信息科技有限公司 研究方向丨NLP,神经网络 个人主页丨kexue.fm 我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了.但是 GAN 不一 ...

  7. 使用pycharm将自己项目代码上传github(保姆教程)

    1.梳理一下Git.github和gitee这三个之间的关系: 1.1.Github 首先从我们最熟悉的github来说,他其实是一个代码托管平台,我们可以在他的里面新建很多的仓库,有强迫症的我理解就 ...

  8. 真「祖传代码」!你的 GitHub 代码已打包运往北极,传给 1000 年后人类

    晓查 发自 凹非寺  量子位 报道 | 公众号 QbitAI 程序员们,激动的消息来了! GitHub刚刚公布了一组照片,你的代码上周已经被打包运往北极保存.只要你2月2日以前贡献过的开源代码,现在都 ...

  9. ubuntu18.04.4 中 下载 github 代码 并创建 python 虚拟环境virtualenv

    文章目录 ubuntu18.04.4 中 下载 github 代码 并创建 python 虚拟环境virtualenv 1 安装virtualenv和virtualenvwrapper 2 githu ...

  10. yolov3网络结构图_目标检测——YOLO V3简介及代码注释(附github代码——已跑通)...

    GitHub: liuyuemaicha/PyTorch-YOLOv3​github.com 注:该代码fork自eriklindernoren/PyTorch-YOLOv3,该代码相比master分 ...

最新文章

  1. python发送消息到微信_通过python登录微信发送消息
  2. 怎么看mysql有没有安装成功_MySQL 安装看这一篇就够了
  3. fedora在此处打开终端
  4. docker4dotnet #1 – 前世今生 amp; 世界你好
  5. 字符串匹配(多模式匹配篇)
  6. java中break和return的区别_java 中return和break的区别
  7. 响应式网站与自适应网站比较
  8. 【数据结构和算法笔记】AOE网和关键路径
  9. C语言之枚举的定义以及测试
  10. 十大经典策略之一 - Dual Thrust策略(期货)
  11. Adaptable DL with nGraph™ Compiler and ONNX*
  12. windows下使用curl命令
  13. Gitter有趣的软件安装界面
  14. LQR:Linear Quadratic Regulator 线性二次型调节器
  15. MySQL-存储引擎-索引-锁-集群
  16. 经纬恒润受邀出席2021世界智能网联汽车大会
  17. Linux常见日志文件和常用命令
  18. Altium中PCB导入二维码
  19. 快速入门 | 篇十三:正运动技术运动控制器ZDevelop 编程软件的使用
  20. 【君思智慧园区】智慧园区规划思路分析

热门文章

  1. java正态分布的概率密度函数_正态分布概率密度函数
  2. 利用PPT表格对图片进行矩形分割
  3. 25个常用Matplotlib图的Python代码
  4. 怎么用transmac制作mac安装盘|transmac制作苹果系统启动U盘方法
  5. 【WiFi】WiFi信道(2.4G、5G及5G DFS)及国家码和电话代码和时区对应表
  6. sem与seo的区别与联系
  7. oracle 基础知识(十四)----索引扫描
  8. 小米8鸿蒙系统,小米手机刷鸿蒙系统
  9. 查找、下载芯片手册推荐网址
  10. [Groovy]Groovy with Ant Task