点击下方卡片,关注“CVer”公众号

AI/CV重磅干货,第一时间送达

作者:Aryansh Omray,微软数据科学工程师,Medium技术博主

机器学习领域的一个基本问题就是如何学习复杂数据的表征是机器学习。这项任务的重要性在于,现存的大量非结构化和无标签的数据,只有通过无监督式学习才能理解。密度估计、异常检测、文本总结、数据聚类、生物信息学、DNA建模等各方面的应用均需要完成这项任务。多年来,研究人员发明了许多方法来学习大型数据集的概率分布,包括生成对抗网络(GAN)、变分自编码器(VAE)和Normalizing Flow等。本文即向大家介绍Normalizing Flow这一为了克服GAN和VAE的不足而提出的方法。

Glow模型的输出样例

https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf

GAN和VAE的能力本已十分惊人,它们都能通过简单的推理方法学习十分复杂的数据分布。然而,GAN和VAE都缺乏对概率分布的精确评估和推理,这往往导致VAE中的模糊结果质量不高,GAN训练也面临着如模式崩溃和后置崩溃等挑战。因此,Normalizing Flow应运而生,试图通过使用可逆函数来解决目前GAN和VAE存在的许多问题。

Normalizing Flow

简单地说,Normalizing Flow就是一系列的可逆函数,或者说这些函数的解析逆是可以计算的。例如,f(x)=x+2是一个可逆函数,因为每个输入都有且仅有一个唯一的输出,并且反之亦然,而f(x)=x²则不是一个可逆函数。这样的函数也被称为双射函数。

图源作者

从上图可以看出,Normalizing Flow可以将复杂的数据点(如MNIST中的图像)转化为简单的高斯分布,反之亦然。和GAN非常不一样的地方是,GAN输入的是一个随机向量,而输出的是一个图像,基于流(Flow)的模型则是将数据点转化为简单分布。在上图的MNIST一例中,我们从高斯分布中抽取随机样本,均可重新获得其对应的MNIST图像。

基于流的模型使用负对数可能性损失函数进行训练,其中p(z)是概率函数。下面的损失函数就是使用统计学中的变量变化公式得到的。

https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf

Normalizing Flow的优势

与GAN和VAE相比,Normalizing Flow具有各种优势,包括:

  • Normalizing Flow模型不需要在输出中放入噪声,因此可以有更强大的局部方差模型(local variance model);

  • 与GAN相比,基于流的模型训练过程非常稳定,GAN则需要仔细调整生成器和判别器的超参数;

  • 与GAN和VAE相比,Normalizing Flow更容易收敛。

Normalizing Flow的不足

虽然基于流的模型有其优势,但它们也有一些缺点:

  • 基于流的模型在密度估计等任务上的表现不尽如人意;

  • 基于流的模型要求保留变换的体积(volume preservation over transformations),这往往会产生非常高维的潜在空间,通常会导致解释性变差;

  • 基于流的模型产生的样本通常没有GAN和VAE的好。

为了更好地理解Normalizing Flow,我们以Glow架构为例进行解释。Glow是OpenAI在2018年提出的一个基于流的模型。下图展示了Glow的架构。

Glow的架构

https://papers.nips.cc/paper/2018/file/d139db6a236200b21cc7f752979132d0-Paper.pdf

Glow架构由多个表层(superficial layers)组合而成。首先我们来看看Glow模型的多尺度框架。Glow模型由一系列的重复层(命名为尺度)组成。每个尺度包括一个挤压函数和一个流步骤,每个流步骤包含ActNorm、1x1 Convolution和Coupling Layer,流步骤后是分割函数。分割函数在通道维度上将输入分成两个相等的部分。其中一半进入之后的层,另一半则进入损失函数。分割是为了减少梯度消失的影响,梯度消失会在模型以端到端方式(end-to-end)训练时出现。

如下图所示,挤压函数(squeeze function)通过横向重塑张量,将大小为[c, h, w]的输入张量转换为大小为[4c, h/2, w/2]的张量。此外,在测试阶段可以采用重塑函数,将输入的[4c, h/2, w/2]重塑为大小为[c, h, w]的张量。

https://arxiv.org/pdf/1605.08803.pdf

其他层,如ActNorm、1x1 Convolution和Affine Coupling层,可以从下表理解。该表展示了每层的功能(包括正向和反向)。

https://arxiv.org/pdf/1605.08803.pdf


实现

在了解了Normalizing Flow和Glow模型的基础知识后,我们将介绍如何使用PyTorch实现该模型,并在MNIST数据集上进行训练。

Glow模型

首先,我们将使用PyTorch和nflows实现Glow架构。为了节省时间,我们使用nflows包含所有层的实现。

import torch
import torch.nn as nn
import torch.nn.functional as F
from nflows import transforms
import numpy as np
from torchvision.transforms.functional import resize
from nflows.transforms.base import Transformclass Net(nn.Module):def __init__(self, in_channel, out_channels):super().__init__()self.net = nn.Sequential(nn.Conv2d(in_channel, 64, 3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(64, 64, 1),nn.ReLU(inplace=True),ZeroConv2d(64, out_channels),)def forward(self, inp, context=None):return self.net(inp)def getGlowStep(num_channels, crop_size, i):mask = [1] * num_channelsif i % 2 == 0:mask[::2] = [-1] * (len(mask[::2]))else:mask[1::2] = [-1] * (len(mask[1::2]))def getNet(in_channel, out_channels):return Net(in_channel, out_channels)return transforms.CompositeTransform([transforms.ActNorm(num_channels),transforms.OneByOneConvolution(num_channels),transforms.coupling.AffineCouplingTransform(mask, getNet)])def getGlowScale(num_channels, num_flow, crop_size):z = [getGlowStep(num_channels, crop_size, i) for i in range(num_flow)]return transforms.CompositeTransform([transforms.SqueezeTransform(),*z])def getGLOW():num_channels = 1 * 4num_flow = 32num_scale = 3crop_size = 28 // 2transform = transforms.MultiscaleCompositeTransform(num_scale)for i in range(num_scale):next_input = transform.add_transform(getGlowScale(num_channels, num_flow, crop_size),[num_channels, crop_size, crop_size])num_channels *= 2crop_size //= 2return transformGlow_model = getGLOW()

我们可以用各种数据集来训练Glow模型,如MNIST、CIFAR-10、ImageNet等。本文为了演示方便,使用的是MNIST数据集。

像MNIST(https://gas.graviti.cn/dataset/data-decorators/MNIST)这样的数据集可以很容易地从格物钛开放数据集平台(https://gas.graviti.cn/open-datasets)获取,该平台包含了机器学习中所有常用的开放数据集,如分类、密度估计、物体检测和基于文本的分类数据集等。

要访问数据集,我们只需要在格物钛的平台上创建账户,就可以直接fork想要的数据集,可以直接下载或者使用格物钛提供的pipeline导入数据集。基本的代码和相关文档可在TensorBay的支持网页上获得(graviti.cn/tensorBay)。

结合格物钛TensorBay的Python SDK,我们可以很方便地导入MNIST数据集到PyTorch中:

from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transformsfrom tensorbay import GAS
from tensorbay.dataset import Dataset as TensorBayDatasetclass MNISTSegment(Dataset):def __init__(self, gas, segment_name, transform):super().__init__()self.dataset = TensorBayDataset("MNIST", gas)self.segment = self.dataset[segment_name]self.category_to_index = self.dataset.catalog.classification.get_category_to_index()self.transform = transformdef __len__(self):return len(self.segment)def __getitem__(self, idx):data = self.segment[idx]with data.open() as fp:image_tensor = self.transform(Image.open(fp))return image_tensor, self.category_to_index[data.label.classification.category]

模型训练

模型训练可以通过下面的代码简单开始。该代码使用格物钛TensorBay提供的Pipeline创建数据加载器,其中的ACCESS_KEY可以在TensorBay的账户设置中获得。

from nflows.distributions import normalACCESS_KEY = "Accesskey-*****"
EPOCH = 100to_tensor = transforms.ToTensor()
normalization = transforms.Normalize(mean=[0.485], std=[0.229])
my_transforms = transforms.Compose([to_tensor, normalization])train_segment = MNISTSegment(GAS(ACCESS_KEY), segment_name="train", transform=my_transforms)
train_dataloader = DataLoader(train_segment, batch_size=4, shuffle=True, num_workers=4)optimizer = torch.optim.Adam(Glow_model.parameters(), 1e-3)for epoch in range(EPOCH):for index, (image, label) in enumerate(train_dataloader):if index == 0:image_size = image.shaape[2]channels = image.shape[1]image = image.cuda()output, logabsdet = Glow_model._transform(image)shape = output.shape[1:]log_z = normal.StandardNormal(shape=shape).log_prob(output)loss = log_z + logabsdetloss = -loss.mean()/(image_size * image_size * channels)optimizer.zero_grad()loss.backward()optimizer.step()print(f"Epoch:{epoch+1}/{EPOCH} Loss:{loss}")

上面代码用的是MNIST数据集,要想使用其他数据集我们可以直接替换该数据集的数据加载器。

样例生成

模型训练完成之后,我们可以通过下面的代码来生成样例:

samples = Glow_model.sample(25)
display(samples)

使用nflows库之后,我们只需要用一行代码就可以生成样例,而display函数则能在一个网格中显示生成的样本。

用MNIST训练模型之后生成的样例

结语

本文向大家介绍了Normalizing Flow的基本知识,并与GAN和VAE进行了比较,同时向大家展示了Glow模型的基本工作方式。我们还讲解了如何简单实现Glow模型,并使用MNIST数据集进行训练。在格物钛公开数据集平台的帮助下,数据集访问变得十分便捷。

关于「格物钛」

格物钛定位为面向机器学习的数据平台,帮助AI开发者解决日益增长的非结构化数据难题。借助非结构化数据管理平台TensorBay和开源数据集社区Open Datasets,机器学习团队和个人可进行数据管理、查询、协同、可视化和版本控制等高效操作,降低高质量数据获取、存储和处理成本,加速AI开发和产品创新。

Open Datasets  ????

格物钛|公开数据集

graviti.cn/open-datasets

订阅号:格物钛  ????

微信号|Graviti_2019

微博|格物钛

https://www.graviti.cn/

点击阅读原文 / 访问格物钛官网

超越GAN?OpenAI提出可逆生成模型Glow!图像生成太逼真相关推荐

  1. Azure OpenAI 官方指南03|DALL-E 的图像生成功能与安全过滤机制

    2021年1月,OpenAI 推出 DALL-E.这是 GPT 模型在图像生成方面的人工智能应用.其名称来源于著名画家.艺术家萨尔瓦多 • 达利(Dalí)和机器人总动员(Wall-E).DALL-E ...

  2. ICCV 2019 | Lifelong GAN:基于持续学习的条件图像生成模型

    作者丨薛洁婷 学校丨北京交通大学硕士生 研究方向丨图像翻译 研究动机 持续学习是指希望模型能和人一样能基于过去的先验知识来快速准确的解决当前任务,然而对于人类而言与生俱来的能力对于模型来说却宛如大海捞 ...

  3. 从DDPM到GLIDE:基于扩散模型的图像生成算法进展

    前几天,OpenAI在Arxiv上挂出来了他们最新最强的文本-图像生成GLIDE [1],如头图所示,GLIDE能生成非常真实的结果.GLIDE并非基于对抗生成网络或是VQ-VAE类模型所设计,而是采 ...

  4. 有钱可以多任性?OpenAI提出人肉模型训练,文本摘要全面超越人类表现!

    文 | 小轶 背景 三个多月前,OpenAI的GPT-3在NLP界掀起轩然大波.就在上周,视金钱如粪土的OpenAI团队又在文本摘要方面推出了最新力作,全方位超越人类表现.其亮点在于:以人类偏好替代自 ...

  5. 换脸效果媲美GAN!一文解析OpenAI最新流生成模型「Glow」

    在碎片化阅读充斥眼球的时代,越来越少的人会去关注每篇论文背后的探索和思考. 在这个栏目里,你会快速 get 每篇精选论文的亮点和痛点,时刻紧跟 AI 前沿成果. 点击本文底部的「阅读原文」即刻加入社区 ...

  6. 生成模型VAE、GAN和基于流的模型详细对比

    在Ian Goodfellow和其他研究人员在一篇论文中介绍生成对抗网络两年后,Yann LeCun称对抗训练是"过去十年里ML最有趣的想法".尽管GANs很有趣,也很有前途,但它 ...

  7. 【GAN优化外篇】详解生成模型VAE的数学原理

    最近在学习生成模型的相关知识,这篇文章将介绍一下变分自编码器(Variational Auto-encoder),本文只介绍一些粗浅内容,不会涉及比较深刻的问题. 作者&编辑 | 小米粥 1. ...

  8. GAN、扩散模型应有尽有,CMU出品的生成模型专属搜索引擎Modelverse来了

    关注公众号,发现CV技术之美 本文转载自机器之心. 生成模型也有自己的搜索引擎啦! 从前几年出尽风头的 GAN 到今年独占鳌头的 Stable Diffusion,预训练生成模型一直风头不减,相关论文 ...

  9. GAN生成对抗网络之生成模型

    朋友们,如需转载请标明出处:http://blog.csdn.net/jiangjunshow 什么是生成模型? 在开始讲生成对抗网络之前,我们先看一下什么是生成模型.在概率统计理论中,生成模型是指能 ...

最新文章

  1. php 字符串数组转数组对象_php怎么将数组转成对象?
  2. 一键数据分析自动化特征工程!
  3. (Easy) Height Checker LeetCode
  4. 计算机科学千分之一高校,69所计算机科学进入ESI世界前1%的高校及排名,西电全国第四!...
  5. css3属性box-sizing:border-box 用法解析 击败边框:带border的百分比布局
  6. 普通IO口驱动LCD
  7. html js css如何关联_html+css +js 选项卡
  8. 刚发现的2011年最给力的春联
  9. 使用机器学习算法在 .NET Core 中运行的 100% C# 开源 AI 聊天机器人平台构建器...
  10. java.lang 源码剖析_java.lang.Void类源码解析
  11. [Voice communications] 音量的控制
  12. winform button设计(一)
  13. java jsm_JSM 基础
  14. matplotlib学习日记(八)----完善统计图
  15. Mac上双系统中Windows无法使用苹果鼠标键盘,如何处理
  16. Unite Beijing 2018 | 精彩游戏案例议题曝光
  17. video-js rtmp直播、this .el_vjs_getproperty问题、多个rtmp播放、可用rtmp地址
  18. OpenGL 驱动 与 扩展的关系
  19. 2022年汽车驾驶员(中级)理论题库模拟考试平台操作
  20. cordova应用兼容iOS11+和iPhoneX刘海屏

热门文章

  1. 树莓派我的世界服务器优化,用树莓派做我的世界服务端.md
  2. Principle教程:如何制作卡片左右划动动效
  3. c#里的bindingsource和bindingnavigator
  4. location 简介
  5. 批处理bat脚本if else条件语句
  6. layer tips的用法
  7. Xcode13在M1 Apple CPU上配置
  8. 微信#errcode:45047,errmsg:out of response count limit rid
  9. 原生android省电模式,Android 9.0 (P版本) MTK平台原生的省电模式自动开启
  10. layUi数据表格显示parsererror