上一期为大家说明了什么是极大似然法,以及如何使用极大似然法搭建生成模型,本期将为大家介绍第一个显式生成模型完全可见置信网络FVBN。

作者&编辑 | 小米粥

1 完全可见置信网络

在完全可见置信网络中,不存在不可观察的潜在变量,观察变量的概率被链式法则从维度上进行分解,对于 n 维观察变量x ,其概率表达式为:

自回归网络是最简单的完全可见置信网络,其中每一个维度的观察变量都构成概率模型的一个节点,而这些所有的节点{x1,x2,...,xn}共同构成一个完全有向图,即图中任意两个节点都存在连接关系,如图所示。

在自回归网络中,因为已经有了随机变量的链式分解关系,那么核心问题便成为如何表达条件概率p(xi|xi-1,xx-2,...,x1) 。最简单的模型是线性自回归网络,即每个条件概率均被定义为线性模型,对实数值数据使用线性回归模型(例如定义 p(xi|xi-1,xx-2,...,x1)= w1x1+w2x2+...+wi-1xi-1 ,对二值数据使用逻辑回归,而对离散数据使用softmax回归,其计算过程如下图。

但线性模型容量有限,拟合函数的能力不足。在神经自回归网络中,使用神经网络代替线性模型,它可以任意增加容量,理论上可以拟合任意联合分布。神经自回归网络还使用了特征重用的技巧,神经网络从观察变量 xi 学习到的隐藏抽象特征 hi 不仅在计算p(xi+1|xi,xi-1,...,x1)时使用,也会在计算p(xi+2|xi+1,xi,...,x1)时进行重用,其计算图如下所示,并且该模型不需要将每个条件概率的计算都分别使用不同神经网络表示,可以将所有神经网络整合为一个,因此只要设计成抽象特征hi只依赖于x1,x2,...,xi即可。而目前的神经自回归密度估计器是神经自回归网络中最具有代表性的方案,它是在神经自回归网络中引入了参数共享的方案,即从观察变量xi到任意隐藏抽象特征 hi+1,hi+2,... 的权值参数是共享的,使用了特征重用、参数共享等深度学习技巧的神经自回归密度估计器具有非常优秀的性能。

PixelRNN和PixelCNN也属于完全可见置信网络,从名字可以看出,这两个模型一般用于图像的生成。它们将图像x的概率p(x)按照像素分解为 n 个条件概率的乘积,其中n为图像的像素点个数,即在每一个像素点上定义了一个条件概率用以表达像素之间的依赖关系,该条件概率分别使用RNN或者CNN进行学习。为了将输出离散化,通常将RNN或CNN的最后一层设置为softmax层,用以表示其输出不同像素值的概率。在PixelRNN中,一般定义从左上角开始沿着右方和下方依次生成每一个像素点,如下图所示。这样,对数似然的表达式便可以得到,训练模型时只需要将其极大化即可。

PixelRNN在其感受野内可能具有无边界的依赖范围,因为待求位置的像素值依赖之前所有已知像素点的像素值,这将需要大量的计算代价,PixelCNN使用标准卷积层来捕获有界的感受野,其训练速度要快于PixelRNN。在PixelCNN中,每个位置的像素值仅与其周围已知像素点的值有关,如下图所示。灰色部分为已知像素,而白色部分为未知像素,计算黑色位置的像素值时,需要把方框区域内的所有灰色像素值传递给CNN,由CNN最后的softmax输出层来表达表在黑色位置取不同像素值的概率,这里可以使用由0和1构成的掩模矩阵将方框区域内的白色位置像素抹掉。PixelRNN和PixelCNN此后仍有非常多改进模型,但由于它是逐个像素点地生成图片,具有串行性,故在实际应用中效率难以保证,这也是FVBN模型的通病。

2 pixelCNN 代码

接下来我们将提供一份完整的pixelCNN的代码讲解,其中训练集为mnist数据集。

首先读取相关python库,设置训练参数:

# 读取相关库

import time

import torch

import torch.nn.functional as F

from torch import nn, optim, cudafrom torch.utils

import datafrom torchvision import datasets, transforms, utils

# 设置训练参数

train_batch_size = 256

generation_batch_size = 48

epoch_number = 25feature_dim = 64

# 是否使用GPU

if torch.cuda.is_available():

device = torch.device('cuda:0')

else:

device = torch.device('cpu')

然后定义二维掩膜卷积,所谓掩膜即使卷积中心的右方和下方的权值为0,如下图所示为3x3掩膜卷积核(A型):

定义二维掩膜卷积核,其中有A与B两种类型,区别之处在于中心位置是否被卷积计算:

class MaskedConv2d(nn.Conv2d):

def __init__(self, mask_type, *args, **kwargs):

super(MaskedConv2d, self).__init__(*args, **kwargs)

assert mask_type in {'A', 'B'}

self.register_buffer('mask', self.weight.data.clone())

bs, o_feature_dim, kH, kW = self.weight.size()

self.mask.fill_(1)

self.mask[:, :, kH // 2, kW // 2 + (mask_type == 'B'):] = 0

self.mask[:, :, kH // 2 + 1:] = 0

def forward(self, x):

self.weight.data *= self.mask

return super(MaskedConv2d, self).forward(x)

我们的pixelCNN网络为多层掩膜卷积的堆叠,即:

network = nn.Sequential(

MaskedConv2d('A',1,feature_dim,7,1,3, bias=False),nn.BatchNorm2d(feature_dim),nn.ReLU(True),

MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True),

MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True),

MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True),

MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True),

MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True),

MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True),

MaskedConv2d('B', feature_dim, feature_dim, 7, 1, 3, bias=False), nn.BatchNorm2d(feature_dim), nn.ReLU(True),    nn.Conv2d(feature_dim, 256, 1))

network.to(device)

接着设置dataloader和优化器:

train_data = data.DataLoader(datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),                     batch_size=train_batch_size, shuffle=True, num_workers=1, pin_memory=True)

test_data = data.DataLoader(datasets.MNIST('data', train=False, download=True, transform=transforms.ToTensor()),                     batch_size=train_batch_size, shuffle=False, num_workers=1, pin_memory=True)

optimizer = optim.Adam(network.parameters())

开始训练网络,并在每一轮epoch后进行测试和生成样本

if __name__ == "__main__":

for epoch in range(epoch_number):

# 训练

cuda.synchronize()

network.train(True)

for input_image, _ in train_data:

time_tr = time.time()

input_image = input_image.to(device)

output_image = network(input_image)

target = (input_image.data[:, 0] * 255).long().to(device)

loss = F.cross_entropy(output_image, target)

optimizer.zero_grad()

loss.backward()

optimizer.step()

print("train: {} epoch, loss: {}, cost time: {}".format(epoch, loss.item(), time.time() - time_tr))        cuda.synchronize()

# 测试

with torch.no_grad():

cuda.synchronize()

time_te = time.time()

network.train(False)

for input_image, _ in test_data:                                              input_image = input_image.to(device)

target = (input_image.data[:, 0] * 255).long().to(device)

loss = F.cross_entropy(network(input_image), target)

cuda.synchronize()

time_te = time.time() - time_te

print("test: {} epoch, loss: {}, cost time: {}".format(epoch, loss.item(), time_te))

# 生成样本

with torch.no_grad():

image = torch.Tensor(generation_batch_size, 1, 28, 28).to(device)

image.fill_(0)

network.train(False)

for i in range(28):

for j in range(28):

out = network(image)

probs = F.softmax(out[:, :, i, j]).data

image[:, :, i, j] = torch.multinomial(probs, 1).float() / 255.

utils.save_image(image, 'generation-image_{:02d}.png'.format(epoch), nrow=12, padding=0)

[1] Oord A V D , Kalchbrenner N , Kavukcuoglu K . Pixel Recurrent Neural Networks[J]. 2016.

[2] 伊恩·古德费洛, 约书亚·本吉奥, 亚伦·库维尔. 深度学习

总结

本期带大家学习了第一种显式生成模型完全可见置信网络,并对其中的自回归网络和pixelRNN,pixelCNN做了讲解,并讲解了一份完整的pixelCNN代码。下一期我们将对第二个显式模型流模型进行讲解。

个人知乎,欢迎关注

GAN群

有三AI建立了一个GAN群,便于有志者相互交流。感兴趣的同学也可以微信搜索xiaozhouguo94,备注“加入有三-GAN群”。

更多GAN的学习

知识星球是有三AI的付费内容社区,里面包超过100种经典GAN模型的解读,了解详细请阅读以下文章:

【杂谈】有三AI知识星球指导手册出炉!和公众号相比又有哪些内容?

有三AI秋季划GAN学习小组,可长期跟随有三学习GAN相关的内容,并获得及时指导,了解详细请阅读以下文章:

【杂谈】如何让2020年秋招CV项目能力更加硬核,可深入学习有三秋季划4大领域32个方向

转载文章请后台联系

侵权必究

往期精选

  • 【GAN优化】GAN优化专栏上线,首谈生成模型与GAN基础

  • 【GAN的优化】从KL和JS散度到fGAN

  • 【GAN优化】详解对偶与WGAN

  • 【GAN优化】详解SNGAN(频谱归一化GAN)

  • 【GAN优化】一览IPM框架下的各种GAN

  • 【GAN优化】GAN优化专栏栏主小米粥自述,脚踏实地,莫问前程

  • 【GAN优化】GAN训练的几个问题

  • 【GAN优化】GAN训练的小技巧

  • 【GAN优化】从动力学视角看GAN是一种什么感觉?

  • 【GAN优化】小批量判别器如何解决模式崩溃问题

  • 【GAN优化】长文综述解读如何定量评价生成对抗网络(GAN)

  • 【技术综述】有三说GANs(上)

  • 【模型解读】历数GAN的5大基本结构

  • 【百战GAN】如何使用GAN拯救你的低分辨率老照片

  • 【百战GAN】二次元宅们,给自己做一个专属动漫头像可好!

  • 【百战GAN】羡慕别人的美妆?那就用GAN复制粘贴过来

  • 【百战GAN】GAN也可以拿来做图像分割,看起来效果还不错?

  • 【百战GAN】新手如何开始你的第一个生成对抗网络(GAN)任务

  • 【百战GAN】自动增强图像对比度和颜色美感,GAN如何做?

  • 【直播回放】80分钟剖析GAN如何从各个方向提升图像的质量

  • 【直播回放】60分钟剖析GAN如何用于人脸的各种算法

【生成模型】解读显式生成模型之完全可见置信网络FVBN相关推荐

  1. (转)c++模版:包含模型、显式实例化、分离模型

    c++模版:包含模型.显式实例化.分离模型 大多数c和c++程序员会这样的组织他们的非模板代码:类和其他类型放在头文件中,对于全局变量和(非内联)函数,只有声明放在头文件中,定义则位于.cpp文件中, ...

  2. C++模版:包含模型、显式实例化、分离模型(详解)

    C++模版:包含模型.显式实例化.分离模型 函数和类类型声明和定义的实质 非模板类类型的分文件定义 test.cpp: #include "test.h" #include < ...

  3. NLP-生成模型-2016-生成式摘要模型:Seq2Seq+Attention+Copy【Pointer网络的Copy机制解决Decoder端的OOV问题】【抽取式+生成式】【第一个生成式摘要模型】

    <原始论文:Abstractive Text Summarization Using Sequence-to-Sequence RNNs and Beyond> Seq2Seq(BiGRU ...

  4. 目标检测模型设计准则 | YOLOv7参考的ELAN模型解读,YOLO系列模型思想的设计源头

    转载: https://mp.weixin.qq.com/s/5SjQvRqRct6ClpE2eEcdkw 设计高效.高质量的表达性网络架构一直是深度学习领域最重要的研究课题.当今的大多数网络设计策略 ...

  5. sequelize模型关联_Node.js Sequelize 模型(表)之间的关联及关系模型的操作

    Sequelize模型之间存在关联关系,这些关系代表了数据库中对应表之间的主/外键关系.基于模型关系可以实现关联表之间的连接查询.更新.删除等操作.本文将通过一个示例,介绍模型的定义,创建模型关联关系 ...

  6. 【Django入门】——模型管理器对象、模型管理器类和模型类

    文章目录 一.模型管理器对象 1. 自定义模型管理器对象 2. 自定义模型管理器类 3. 自定义模型管理器类应用 3.1 重写框架的方法 3.2 封装自定义方法 4. 模型管理器对象的`model`属 ...

  7. VTK和ParaView中引入了显式结构化网格表达地质网格

    Introducing Explicit Structured Grids in VTK and ParaView - Kitware Blog 1.简介 新版本的vtk引入了适用于油藏角点网格模型的 ...

  8. 最新3D GAN可生成三维几何数据了!模型速度提升7倍,英伟达斯坦福出品

    明敏 发自 凹非寺 量子位 报道 | 公众号 QbitAI 2D图片变3D,还能给出3D几何数据? 英伟达和斯坦福大学联合推出的这个GAN,真是刷新了3D GAN的新高度. 而且生成画质也更高,视角随 ...

  9. PSGAN——姿态稳健型可感知空间式生成对抗网络论文详细解读与整理

    PSGAN--姿态稳健型可感知空间式生成对抗网络论文详细解读与整理 1.摘要 2.什么是PSGAN? 3.主要贡献 4.整体模块 5.目标函数 6.实验结果--部分化妆和插值化妆 7.定量比较 8.参 ...

最新文章

  1. win10服务器权限修改时间,win10系统修改时间显示没权限的解决方案
  2. Swift类与结构、存储属性、计算属性、函数与方法、附属脚本等
  3. Nagios监控linux服务器
  4. swift1.2语言函数和闭包函数介绍
  5. python爬取慕课视频-Python爬虫抓取技术的门道
  6. 如何设计一门语言(一)——什么是坑(a)
  7. jsp mysql连接池 回收_mysql连接池连接JSP
  8. 今天分享的案例是关于某电商店铺的年终销售业绩
  9. linux典型压缩包操作 tar打包、压缩与解压
  10. 红蓝宝书1000题 新日本语能力考试N1文字.词汇.文法 练习+详解
  11. Tokenized的设计哲学(三)
  12. Linux——shell脚本的基础篇(变量定义、变量种类、变量操作)
  13. 埃氏筛_四种形式 ( pass even , bool , char , differently judge )
  14. 面试经典-你为什么觉得自己能够在这个职位上取得成就?
  15. OLED(1)与LDC区别
  16. python 自动执行 apdl_【转载】利用VB生成APDL文件 和Python文件的方法
  17. C语言家谱管理程序,C语言二叉树家谱管理系统.doc
  18. Domino NotesV11开放下载啦!
  19. 银河麒麟搭建nodejs环境
  20. mysql table plugin_MySQL 启动报错Table 'mysql.plugin' doesn't exis(转载)

热门文章

  1. 信息系统项目管理知识--项目管理一般知识
  2. java中Date与String的相互转化
  3. 一天搞定CSS: overflow--14
  4. dropout层_DNN,CNN和RNN的12种主要dropout方法的数学和视觉解释
  5. python画切片图_python|Python图片常用操作-索引与切片
  6. (Java)注解和反射
  7. 回溯算法 | 追忆那些年曾难倒我们的八皇后问题
  8. 把所有的谎言献给你β(找规律数学题)
  9. Spring之Bean的配置(一)
  10. Java 动态代理介绍及用法