BIGGAN代码以及训练参数,超级清晰版(CIFAR10数据集生成)
BIGGAN代码以及训练参数,超级清晰版
- 模型代码:
- 如何使用
超级清晰版
模型代码:
import torch.nn as nn
from torch.nn.utils import spectral_norm
import torch.nn.functional as F
import torch
import random
import numpy as np# 定义归一化函数
def batchnorm_2d(in_features, eps=1e-4, momentum=0.1, affine=True):return nn.BatchNorm2d(in_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=True)# sn全连接层(这样的目的是为了更加稳定)
def snlinear(in_features, out_features, bias=True):return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias), eps=1e-6)# sn卷积层(这样的目的是为了更加稳定)
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):return spectral_norm(nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias),eps=1e-6)
# sn编码(这样的目的是为了更加稳定)
def sn_embedding(num_embeddings, embedding_dim):return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim), eps=1e-6)
# 归一化层?
class ConditionalBatchNorm2d(nn.Module):def __init__(self, in_features=148, out_features=384):super().__init__()self.in_features = in_featuresself.bn = batchnorm_2d(out_features, eps=1e-4, momentum=0.1, affine=False)self.gain = snlinear(in_features=in_features, out_features=out_features, bias=False)self.bias = snlinear(in_features=in_features, out_features=out_features, bias=False)def forward(self, x, y):gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)bias = self.bias(y).view(y.size(0), -1, 1, 1)out = self.bn(x)return out * gain + bias# 自我注意机制?
class SelfAttention(nn.Module):def __init__(self, in_channels, is_generator):super(SelfAttention, self).__init__()self.in_channels = in_channelsif is_generator:self.conv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_attn = snconv2d(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1,stride=1, padding=0, bias=False)else:self.conv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_attn = snconv2d(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1,stride=1, padding=0, bias=False)self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)self.softmax = nn.Softmax(dim=-1)self.sigma = nn.Parameter(torch.zeros(1), requires_grad=True)def forward(self, x):_, ch, h, w = x.size()# Theta paththeta = self.conv1x1_theta(x)theta = theta.view(-1, ch // 8, h * w)# Phi pathphi = self.conv1x1_phi(x)phi = self.maxpool(phi)phi = phi.view(-1, ch // 8, h * w // 4)# Attn mapattn = torch.bmm(theta.permute(0, 2, 1), phi)attn = self.softmax(attn)# g pathg = self.conv1x1_g(x)g = self.maxpool(g)g = g.view(-1, ch // 2, h * w // 4)# Attn_gattn_g = torch.bmm(g, attn.permute(0, 2, 1))attn_g = attn_g.view(-1, ch // 2, h, w)attn_g = self.conv1x1_attn(attn_g)return x + self.sigma * attn_g# 一个块
class GenBlock(nn.Module):def __init__(self):super(GenBlock, self).__init__()in_features = 148self.bn1 = ConditionalBatchNorm2d(in_features=in_features)self.bn2 = ConditionalBatchNorm2d(in_features=in_features)self.activation = nn.ReLU(inplace=True)self.conv2d0 = snconv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))self.conv2d1 = snconv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.conv2d2 = snconv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))def forward(self, x, affine):x0 = xx = self.bn1(x, affine)x = self.activation(x)x = F.interpolate(x, scale_factor=2, mode="nearest")x = self.conv2d1(x)x = self.bn2(x, affine)x = self.activation(x)x = self.conv2d2(x)x0 = F.interpolate(x0, scale_factor=2, mode="nearest")x0 = self.conv2d0(x0)out = x + x0return out# 生成网络 输入inputs.shape = tensor.size[(batch_size, 80)]
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear0 = snlinear(in_features=20, out_features=6144, bias=True)self.shared = nn.Embedding(10, 128)# 主要块self.blocks = nn.ModuleList()self.blocks.append(nn.ModuleList([GenBlock()])) # (0): ModuleListself.blocks.append(nn.ModuleList([GenBlock()])) # (1): ModuleListself.blocks.append(nn.ModuleList([SelfAttention(in_channels=384, is_generator=True)])) # (2): ModuleListself.blocks.append(nn.ModuleList([GenBlock()])) # (3): ModuleListself.bn4 = batchnorm_2d(in_features=384)self.activation = nn.ReLU(inplace=True)self.conv2d5 = snconv2d(384, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.tanh = nn.Tanh()# 关于cifar10的默认参数self.bottom = 4def forward(self, z, label=None):affine_list = []z0 = zzs = torch.split(z, 20, 1)z = zs[0]shared_label = self.shared(label)affine_list.append(shared_label)affines = [torch.cat(affine_list + [item], 1) for item in zs[1:]]act = self.linear0(z)act = act.view(-1, 384, self.bottom, self.bottom)counter = 0for index, blocklist in enumerate(self.blocks):for block in blocklist:if isinstance(block, SelfAttention):act = block(act)else:act = block(act, affines[counter])counter += 1act = self.bn4(act)act = self.activation(act)act = self.conv2d5(act)out = self.tanh(act)return out
如何使用
戳
BIGGAN代码以及训练参数,超级清晰版(CIFAR10数据集生成)相关推荐
- CIFAR-10模型训练python版cifar10数据集
在之前的博客中已经对CIFAR-10做了整体的解析,但是目前从tensorflow/models/tree/master/tutorials/image/cifar10中下载下来,运行cifar10_ ...
- YOLOV7开源代码讲解--训练参数解释
目录 训练参数说明: --weights: -- cfg: --data: --hpy: --epoch: --batch_size: --img-size: --rect: --resume: -- ...
- MediBang Paint Pro超级精简版/超精简/懂你版
原创:梦幻软件天堂 作者:心灵代码 MediBang Paint Pro 超级精简版/超精简/懂你版 由 梦幻软件天堂 发布.MediBang Paint Pro 超精简版 是一个是由日本 med ...
- 网页版的svn怎样同步代码_学会使用Hdlbits网页版Verilog代码仿真验证平台
大家推荐一款网页版的 Verilog代码编辑仿真验证平台,这个平台是国外的一家开源FPGA学习网站,通过 "https://hdlbits.01xz.net/wiki/Main_Page&q ...
- 手把手Maven搭建SpringMVC+Spring+MyBatis框架(超级详细版)
手把手Maven搭建SpringMVC+Spring+MyBatis框架(超级详细版) SSM(Spring+SpringMVC+Mybatis),目前较为主流的企业级架构方案.标准的MVC设计模式, ...
- BigGAN代码解读(gpt3.5的帮助)——谱正则化部分
BigGAN代码解读(gpt4.0的帮助)--谱正则化部分 作者个人记录学习 BigGAN中使用谱归一化对训练过程进行优化,在github中的代码中,使用了自己编写的谱归一化对卷积层.线性层以及Emb ...
- Cg教程_可编程实时图形权威指南(扫描清晰版)+部分unity shader 知识
Cg教程_可编程实时图形权威指南(扫描清晰版) .pdf (34.5 MB, 下载次数: 239) Shader Model(在 3D 图形领域常被简称SM)就是"优化渲染引擎模式&qu ...
- supervisord部署使用超级详细版
supervisord部署使用超级详细版 一. 安装 pip 命令(安装python 环境) 因为 supervisord本身是基于Python开发的,所以在使用时需要先安装Python 的运行环境 ...
- GitHub上YOLOv5开源代码的训练数据定义
GitHub上YOLOv5开源代码的训练数据定义 代码地址:https://github.com/ultralytics/YOLOv5 训练数据定义地址:https://github.com/ultr ...
最新文章
- Android渲染机制和丢帧分析
- 查找数组中第二个最小元素
- JavaScript——判断undefined解决方案
- vector-空间增长
- 使用HTML5的Canvas画布来剪裁用户头像
- IDEA——修改idea64.exe.vmoptions文件解决coding卡顿问题
- 20170505思考点--编写案例时是以功能为主还是业务为主要
- MySQL : mysql连接报 Communications link failure
- 模式分享 公众号_微信公众号+()模式营销!公众号还可以这样玩?
- qt读取txt文件内容
- ubunntu安装php7.0_乌班图Ubuntu 16.04下安装PHP 7过程详解
- Oracle批量修改字段长度
- 扩展欧几里得算法求逆元c语言,利用扩展欧几里得算法编程求逆元
- 对齐函数:ALIGN()
- 美团CAT客户端集成
- 巧用google实现快速搜索
- 谈谈基于深度相机的三维重建
- 从键盘读入10个的整数,判断正数和负数的个数
- 主成分分析应用之主成分回归
- 偏态分布学习笔记(期望,中位数,众数)