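This post is a translation of the web page Train Variational Autoencoder (VAE) to Generate Images, which implements image generation with a variational autoencoder: it trains on the MNIST handwritten digits and generates similar images. This post covers the helper functions from that page. For the main part, see "MATLAB实现自编码器(四)——变分自编码器实现图像生成Train Variational Autoencoder (VAE) to Generate Images" (part 4 of this series).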

processImagesMNIST

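The first two functions process the MNIST data set: one handles the images and the other the labels, so that both match the network's input requirements.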

function X = processImagesMNIST(filename)
% The MNIST processing functions extract the data from the downloaded IDX
% files into MATLAB arrays. The processImagesMNIST function performs these
% operations: Check if the file can be opened correctly. Obtain the magic
% number by reading the first four bytes. The magic number is 2051 for
% image data, and 2049 for label data. Read the next 3 sets of 4 bytes,
% which return the number of images, the number of rows, and the number of
% columns. Read the image data. Reshape the array and swap the first two
% dimensions, because MATLAB fills arrays in column-major order while the
% file stores each image row by row. Ensure the pixel values are in the
% range [0,1] by dividing them all by 255, and convert the 3-D array to a
% 4-D dlarray object. Close the file.

[fileID,errmsg] = fopen(filename,'r','b');
if fileID < 0
    error(errmsg);
end

magicNum = fread(fileID,1,'int32',0,'b');
if magicNum == 2051
    fprintf('\nRead MNIST image data...\n')
end

numImages = fread(fileID,1,'int32',0,'b');
fprintf('Number of images in the dataset: %6d ...\n',numImages);
numRows = fread(fileID,1,'int32',0,'b');
numCols = fread(fileID,1,'int32',0,'b');

X = fread(fileID,inf,'unsigned char');

X = reshape(X,numCols,numRows,numImages);
X = permute(X,[2 1 3]);
X = X./255;
X = reshape(X, [28,28,1,size(X,3)]);
X = dlarray(X, 'SSCB');

fclose(fileID);
end
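
The header layout and the reshape/permute step can be checked on a toy file. The following is a minimal sketch, not part of the original page: it writes a hypothetical IDX file holding two 2-by-3 images (the file name, image size, and pixel values are all made up for illustration) and reads it back with the same big-endian fread calls used above. It needs only base MATLAB, so it stops before the dlarray conversion.

```matlab
% Sketch: round-trip a tiny, made-up IDX image file (2 images of 2x3 pixels).
fname = [tempname '.idx'];                                  % hypothetical temp file
imgs  = uint8(cat(3, [1 2 3; 4 5 6], [7 8 9; 10 11 12]));   % rows x cols x images

% Write the header (magic number, numImages, numRows, numCols) as big-endian
% int32 values, then the pixels image by image, row by row.
fid = fopen(fname,'w','b');
fwrite(fid, [2051 2 2 3], 'int32', 0, 'b');
fwrite(fid, permute(imgs,[2 1 3]), 'uint8');   % transpose so rows are contiguous
fclose(fid);

% Read it back the same way processImagesMNIST does.
fid    = fopen(fname,'r','b');
header = fread(fid, 4, 'int32', 0, 'b');       % [magic; numImages; numRows; numCols]
raw    = fread(fid, inf, 'uint8');
fclose(fid);
delete(fname);

X = reshape(raw, header(4), header(3), header(2));   % numCols x numRows x numImages
X = permute(X, [2 1 3]);                             % numRows x numCols x numImages
disp(X(:,:,1))
```

Without the permute call, each image would come back transposed, which is why the function swaps the first two dimensions after reshaping.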

processLabelsMNIST

Processes the labels so that they match the network's input requirements.

function Y = processLabelsMNIST(filename)
% The processLabelsMNIST function operates similarly to the
% processImagesMNIST function. After opening the file and reading the magic
% number, it reads the labels and returns a categorical array containing
% their values.

[fileID,errmsg] = fopen(filename,'r','b');

if fileID < 0
    error(errmsg);
end

magicNum = fread(fileID,1,'int32',0,'b');
if magicNum == 2049
    fprintf('\nRead MNIST label data...\n')
end

numItems = fread(fileID,1,'int32',0,'b');
fprintf('Number of labels in the dataset: %6d ...\n',numItems);

Y = fread(fileID,inf,'unsigned char');

Y = categorical(Y);

fclose(fileID);
end
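
Because the labels come back as a categorical array, the visualization helpers further down compare them against categorical(c) and index a color map with double(YTest). A minimal sketch of that behavior, using a made-up label vector rather than real MNIST data:

```matlab
% Sketch with made-up labels: how a categorical label array behaves.
Y = categorical([0 1 1 9 0]');     % categories are '0', '1', '9'

% Find every sample of one class, as iRandomIdxOfClass does.
idx = find(Y == categorical(1));

% double() returns the category index (1-based), not the digit itself.
codes = double(Y);
disp(idx')
disp(codes')
```

With the full MNIST label set all ten digits are present, so the category indices run from 1 to 10 and line up with the rows of a 10-color map.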

Model Gradients Function

The modelGradients function takes the encoder and decoder dlnetwork objects and a mini-batch of input data x, and returns the gradients of the loss with respect to the learnable parameters in the networks. The function performs three operations:

  • Obtain the encodings by calling the sampling function on the mini-batch of images that passes through the encoder network.
  • Obtain the loss by passing the encodings through the decoder network and calling the ELBOloss function.
  • Compute the gradients of the loss with respect to the learnable parameters of both networks by calling the dlgradient function.

function [infGrad, genGrad] = modelGradients(encoderNet, decoderNet, x)
[z, zMean, zLogvar] = sampling(encoderNet, x);
xPred = sigmoid(forward(decoderNet, z));
loss = ELBOloss(x, xPred, zMean, zLogvar);
[genGrad, infGrad] = dlgradient(loss, decoderNet.Learnables, ...
    encoderNet.Learnables);
end

Sampling and Loss Functions

The sampling function obtains encodings from input images. It first passes a mini-batch of images through the encoder network and splits the output, of size (2 × latentDim) × miniBatchSize, into a matrix of means and a matrix of variances (held as log-variances in zLogvar), each of size latentDim × miniBatchSize. It then uses these matrices to implement the reparameterization trick and compute the encoding. Finally, it converts the encoding to a dlarray object in SSCB format.
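
Written out, the reparameterization trick used here looks as follows. This is a sketch in notation that does not appear on the original page: μ stands for zMean, log σ² for zLogvar, ε for epsilon, and σ for sigma in the code.

```latex
\begin{aligned}
\begin{bmatrix}\mu \\ \log\sigma^{2}\end{bmatrix} &= \operatorname{encoder}(x), \\
\epsilon &\sim \mathcal{N}(0, I), \\
\sigma &= \exp\!\left(\tfrac{1}{2}\log\sigma^{2}\right), \\
z &= \mu + \sigma \odot \epsilon .
\end{aligned}
```

All of the randomness sits in ε, so z is a deterministic function of μ and log σ² and gradients can flow back into the encoder.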

function [zSampled, zMean, zLogvar] = sampling(encoderNet, x)
compressed = forward(encoderNet, x);
d = size(compressed,1)/2;
zMean = compressed(1:d,:);
zLogvar = compressed(1+d:end,:);

sz = size(zMean);
epsilon = randn(sz);
sigma = exp(.5 * zLogvar);
z = epsilon .* sigma + zMean;
z = reshape(z, [1,1,sz]);
zSampled = dlarray(z, 'SSCB');
end

ELBOloss

The ELBOloss function takes the input x, its reconstruction xPred, and the mean and log-variance encodings returned by the sampling function, and uses them to compute the ELBO loss as the sum of a squared-error reconstruction term and a KL-divergence term, averaged over the mini-batch.
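
As a formula, the quantity the code below computes can be sketched like this. The notation is an assumption for illustration and does not appear on the original page: B is the mini-batch size, d is latentDim, i runs over the pixels of one image, x̂ is xPred, μ is zMean, and log σ² is zLogvar, so that exp(zLogvar) = σ².

```latex
\mathcal{L}
= \frac{1}{B}\sum_{b=1}^{B}\left[
    \frac{1}{2}\sum_{i}\left(\hat{x}_{b,i} - x_{b,i}\right)^{2}
    \;-\; \frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_{b,j}^{2} - \mu_{b,j}^{2} - \sigma_{b,j}^{2}\right)
  \right]
```

The second term is the KL divergence between a diagonal Gaussian with mean μ and variance σ² and a standard normal prior. It is zero exactly when μ = 0 and σ² = 1.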

function elbo = ELBOloss(x, xPred, zMean, zLogvar)
squares = 0.5*(xPred-x).^2;
reconstructionLoss  = sum(squares, [1,2,3]);

KL = -.5 * sum(1 + zLogvar - zMean.^2 - exp(zLogvar), 1);

elbo = mean(reconstructionLoss + KL);
end

Visualization Functions

The visualizeReconstruction function randomly chooses two images for each digit of the MNIST data set, passes them through the VAE, and plots each reconstruction side by side with its original input. Note that to plot the information contained inside a dlarray object, you need to extract it first using the extractdata and gather functions.

function visualizeReconstruction(XTest,YTest, encoderNet, decoderNet)
f = figure;
figure(f)
title("Example ground truth image vs. reconstructed image")
for i = 1:2
    for c=0:9
        idx = iRandomIdxOfClass(YTest,c);
        X = XTest(:,:,:,idx);

        [z, ~, ~] = sampling(encoderNet, X);
        XPred = sigmoid(forward(decoderNet, z));

        X = gather(extractdata(X));
        XPred = gather(extractdata(XPred));

        comparison = [X, ones(size(X,1),1), XPred];
        subplot(4,5,(i-1)*10+c+1), imshow(comparison,[]),
    end
end
end

function idx = iRandomIdxOfClass(T,c)
idx = T == categorical(c);
idx = find(idx);
idx = idx(randi(numel(idx),1));
end

visualizeLatentSpace

The visualizeLatentSpace function visualizes the latent space defined by the mean and the variance matrices that form the output of the encoder network, and locates the clusters formed by the latent space representations of each digit.

The function starts by extracting the mean and the variance matrices from the dlarray objects. Because transposing a matrix with channel/batch dimensions (C and B) is not possible, the function calls stripdims before transposing the matrices. Then, it carries out a principal component analysis (PCA) on both matrices. To visualize the latent space in two dimensions, the function keeps the first two principal components and plots them against each other. Finally, the function colors the digit classes so that you can observe clusters.

function visualizeLatentSpace(XTest, YTest, encoderNet)
[~, zMean, zLogvar] = sampling(encoderNet, XTest);

zMean = stripdims(zMean)';
zMean = gather(extractdata(zMean));

zLogvar = stripdims(zLogvar)';
zLogvar = gather(extractdata(zLogvar));

[~,scoreMean] = pca(zMean);
[~,scoreLogvar] = pca(zLogvar);

c = parula(10);
f1 = figure;
figure(f1)
title("Latent space")

ah = subplot(1,2,1);
scatter(scoreMean(:,2),scoreMean(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
axis equal
xlabel("Z_m_u(2)")
ylabel("Z_m_u(1)")
cb = colorbar; cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);

ah = subplot(1,2,2);
scatter(scoreLogvar(:,2),scoreLogvar(:,1),[],c(double(YTest),:));
ah.YDir = 'reverse';
xlabel("Z_v_a_r(2)")
ylabel("Z_v_a_r(1)")
cb = colorbar;  cb.Ticks = 0:(1/9):1; cb.TickLabels = string(0:9);
axis equal
end

generate

The generate function tests the generative capabilities of the VAE. It initializes a dlarray object containing 25 randomly generated encodings, passes them through the decoder network, and plots the outputs.

function generate(decoderNet, latentDim)
randomNoise = dlarray(randn(1,1,latentDim,25),'SSCB');
generatedImage = sigmoid(predict(decoderNet, randomNoise));
generatedImage = extractdata(generatedImage);

f3 = figure;
figure(f3)
imshow(imtile(generatedImage, "ThumbnailSize", [100,100]))
title("Generated samples of digits")
drawnow
end
