一、infoGAN原理简介

普通的GAN存在无约束、不可控、噪声信号z很难解释等问题。InfoGAN 主要特点是对GAN进行了一些改动,成功地让网络学到了可解释的特征,网络训练完成之后,我们可以通过设定输入生成器的隐含编码来控制生成数据的特征。InfoGAN的基本结构为:

其中,真实数据Real_data只是用来跟生成的Fake_data混合在一起进行真假判断,并根据判断的结果更新生成器和判别器,从而使生成的数据与真实数据接近。生成数据既要参与真假判断,还需要和隐变量C_vector求互信息,并根据互信息更新生成器和判别器,从而使得生成图像中保留了更多隐变量C_vector的信息。

二、matlab代码实战

clear all; close all; clc;
%% Info Generative Adversarial Network
%% Load Data
load('mnistAll.mat')
trainX = preprocess(mnist.train_images);
trainY = mnist.train_labels;
testX = preprocess(mnist.test_images);
testY = mnist.test_labels;
%% Settings
args.maxepochs = 50; args.c_weight = 0.5; args.z_dim = 62;
args.batch_size = 16; args.image_size = [28,28,1];
args.lrD = 0.0002; args.lrG = 0.001; args.beta1 = 0.5;
args.beta2 = 0.999; args.cc_dim = 1; args.dc_dim = 10;
args.sample_size = 100;
%% Weights, Biases, Offsets and Scales
% Generator
paramsGen.FCW1 = dlarray(...initializeGaussian([1024,args.z_dim+args.cc_dim+args.dc_dim]));
paramsGen.FCb1 = dlarray(zeros(1024,1,'single'));
paramsGen.BNo1 = dlarray(zeros(1024,1,'single'));
paramsGen.BNs1 = dlarray(ones(1024,1,'single'));paramsGen.FCW2 = dlarray(initializeGaussian([128*7*7,1024]));
paramsGen.FCb2 = dlarray(zeros(128*7*7,1,'single'));
paramsGen.BNo2 = dlarray(zeros(128*7*7,1,'single'));
paramsGen.BNs2 = dlarray(ones(128*7*7,1,'single'));paramsGen.TCW1 = dlarray(initializeGaussian([4,4,64,128]));
paramsGen.TCb1 = dlarray(zeros(64,1,'single'));
paramsGen.BNo3 = dlarray(zeros(64,1,'single'));
paramsGen.BNs3 = dlarray(ones(64,1,'single'));
paramsGen.TCW2 = dlarray(initializeGaussian([4,4,1,64]));
paramsGen.TCb2 = dlarray(zeros(1,1,'single'));%% Progress Plot
function progressplot(args,paramsGen,stGen)
fixednoise = zeros(args.z_dim,args.sample_size);
tmp = zeros(args.cc_dim,args.sample_size);
for i = 1:10tmp(1,(i-1)*10+1:i*10) = linspace(-2,2,10);
end
cc = tmp;tmp = zeros(args.dc_dim,args.sample_size);
for i = 1:10tmp(i,(i-1)*10+1:i*10) = 1;
end
dc = tmp;fake_data = gpuArray(dlarray(cat(1,fixednoise,cc,dc),'CB'));
fake_images = extractdata(Generator(fake_data,paramsGen,stGen));fig = gcf;
if ~isempty(fig.Children)delete(fig.Children)
endI = imtile(fake_images);
I = rescale(I);
imagesc(I)
title("Generated Images")drawnow;end
%% Report Progress
function [d_loss,g_loss] = reportprogress(x,z,paramsDis,...paramsGen,args,stDis,stGen)
fake_images = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,args,stDis);
d_output_fake = Discriminator(fake_images,paramsDis,args,stDis);% Loss due to true or not
d_loss_a = -mean(log(d_output_real(1,:))+log(1-d_output_fake(1,:)));
g_loss_a = -mean(log(d_output_fake(1,:)));% cc loss
output_cc = d_output_fake(2,:);
d_loss_cc = mean((output_cc/0.5).^2);% softmax classification loss
output_dc = d_output_fake(3:end,:);
d_loss_dc = -(mean(sum(z(args.z_dim+args.cc_dim+1:end,:).*output_dc,1))+...mean(sum(z(args.z_dim+args.cc_dim+1:end,:).*z(args.z_dim+args.cc_dim+1:end,:),1)));% Discriminator Loss
d_loss = d_loss_a+args.c_weight*d_loss_cc+d_loss_dc;
% Generator Loss
g_loss = g_loss_a+args.c_weight*d_loss_cc+d_loss_dc;
end
%% Model Gradients
function [GradDis,GradGen,stDis,stGen] = modelGradients(x,z,paramsDis,...paramsGen,args,stDis,stGen)
[fake_images,stGen] = Generator(z,paramsGen,stGen);
d_output_real = Discriminator(x,paramsDis,args,stDis);
[d_output_fake,stDis] = Discriminator(fake_images,paramsDis,args,stDis);% Loss due to true or not
d_loss_a = -mean(log(d_output_real(1,:))+log(1-d_output_fake(1,:)));
g_loss_a = -mean(log(d_output_fake(1,:)));% cc loss
output_cc = d_output_fake(2,:);
d_loss_cc = mean((output_cc/0.5).^2);% softmax classification loss
output_dc = d_output_fake(3:end,:);
d_loss_dc = -(mean(sum(z(args.z_dim+args.cc_dim+1:end,:).*output_dc,1))+...mean(sum(z(args.z_dim+args.cc_dim+1:end,:).*z(args.z_dim+args.cc_dim+1:end,:),1)));% Discriminator Loss
d_loss = d_loss_a+args.c_weight*d_loss_cc+d_loss_dc;
% Generator Loss
g_loss = g_loss_a+args.c_weight*d_loss_cc+d_loss_dc;% For each network, calculate the gradients with respect to the loss.
GradGen = dlgradient(g_loss,paramsGen,'RetainData',true);
GradDis = dlgradient(d_loss,paramsDis);end

结果展示

无监督式GAN(infoGAN) matlab实战相关推荐

  1. 从易到难,针对复杂问题的无监督式问题分解方法

    论文标题: Unsupervised Question Decomposition for Question Answering 论文作者: Ethan Perez (FAIR,NYU), Patri ...

  2. 最全机器学习种类讲解:监督、无监督、在线和批量学习都讲明白了

    导读:现有的机器学习系统种类繁多,根据以下内容将它们进行分类有助于我们理解: 是否在人类监督下训练(监督式学习.无监督式学习.半监督式学习和强化学习) 是否可以动态地进行增量学习(在线学习和批量学习) ...

  3. Unsupervised Degradation Representation Learning for Blind Super-Resolution(基于无监督退化表示学习的盲超分辨率处理)

    文章目录 Abstract(摘要) 1. Introduction 2. Related Work 2.1. Single Image Super-Resolution 2.2. Contrastiv ...

  4. 长文解读|Progress in Neurobiology:监督式机器学习在神经科学中的应用

    ​<本文同步发布于"脑之说"微信公众号,欢迎搜索关注~~> 这些年来,人们投入了相当多的热情在机器学习(Machine Learning)领域中,毕竟它让电脑的表现在某 ...

  5. 主流监督式机器学习分类算法

    主流监督式机器学习分类算法的性能比较 姓名:欧阳qq     学号:169559 摘要:机器学习是未来解决工程实践问题的一个重要思路.本文采比较了目前监督式学习中几种主流的分类算法(决策树.SVM.贝 ...

  6. 机器学习(监督,无监督,强化学习及线性回归)

    [监督式学习] 监督式学习算法包括一个目标变量(也就是因变量)和用来预测目标变量的预测变量(相当于自变量).通过这些变量,我们可以搭建一个模型,从而对于一个自变量,我们可以得到对应的因变量.重复训练这 ...

  7. ICLR 2021|基于GAN的二维图像无监督三维形状重建

    2D GAN知道3D形状吗?基于GAN的二维图像无监督三维形状重建 论文.代码地址:在公众号「计算机视觉工坊」,后台回复「二维图像GAN」,即可直接下载. 摘要: 自然图像是三维物体在二维图像平面上的 ...

  8. 港中文周博磊团队:无监督条件下GAN潜在语义识别指南

    点击上方"机器学习与生成对抗网络",关注"星标" 获取有趣.好玩的前沿干货! 作者:Yujun Shen.Bolei Zhou   机器之心编译 参与:蛋酱.魔 ...

  9. BigBiGAN问世,“GAN父”都说酷的无监督表示学习模型有多优秀?

    作者 | Jeff Donahue.Karen Simonyan 译者 | Lucy.一一 出品 | AI开发者大本营(ID:rgznai100) 众所周知,对抗训练生成模型(GAN)在图像生成领域获 ...

最新文章

  1. 重磅图书——PHP MySQL开发新圣经
  2. 内存管理模拟程序c语言,C语言 内存管理详解
  3. java的reflection
  4. 成本中心、作业中心、工作中心的区别
  5. linux下Bash编程until语句及格式化硬盘分区等编写脚本(十)
  6. 结对开发Ⅴ——循环一维数组求和最大的子数组
  7. java 停止kettle转换_通过java运行Kettle转换
  8. Android系统性能优化(73)---总结
  9. UVa 116 (多段图的最短路) Unidirectional TSP
  10. JBoss Tomcat 对 JSP 的泛型支持
  11. 新产品Wyn Enterprise 详解,立即预约公开课
  12. basys2数码管共阳还是共阴_如何判断PLC使用接近开关是PNP还是NPN?
  13. 使用TF卡烧录Jetson NX开发板
  14. 看完吴恩达(Andrew Ng)机器学习视频的感受
  15. java 框架_java三大主流框架是什么
  16. H5 小程序直播教程,一看就会!
  17. 利用HomeKit、智汀家庭云,让不同生态智能家居实现互联互通
  18. 小ck活动机器人包包_古力娜扎空降“小ck”线下门店,手上的包包亮了,仙气又便宜!...
  19. error: (-215:Assertion failed) npoints = 0 (depth == CV_32F || depth == CV_32F || depth ==CV_32S
  20. 使用java进行本地文件全盘搜索

热门文章

  1. 【MFC/C++操作Excel】Excel篇 (OLE/COM)
  2. 线性系统与非线性系统、定常系统和时变系统、连续系统和离散系统、单输入单输出系统与多输入多输出系统(自动控制原理)
  3. markdown模板笔记
  4. CentOS7系统编码
  5. 实践API钩子拦截DLL库调用
  6. python项目七:自建公告板
  7. 怎么用python算单价和总价_怎样用EXCEL表格自动算出数量单价总金额
  8. Java与本息总额计算
  9. Linux中使用iOStream头文件,在Linux中使用gcc链接iostream.h
  10. linux下达梦数据库导出dmp,DM7 达梦数据库 物理备份还原之 备份管理 操作手册