BIGGAN代码以及训练参数,超级清晰版

  • 模型代码:
  • 如何使用

超级清晰版

模型代码:

import torch.nn as nn
from torch.nn.utils import spectral_norm
import torch.nn.functional as F
import torch
import random
import numpy as np# 定义归一化函数
def batchnorm_2d(in_features, eps=1e-4, momentum=0.1, affine=True):return nn.BatchNorm2d(in_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=True)# sn全连接层(这样的目的是为了更加稳定)
def snlinear(in_features, out_features, bias=True):return spectral_norm(nn.Linear(in_features=in_features, out_features=out_features, bias=bias), eps=1e-6)# sn卷积层(这样的目的是为了更加稳定)
def snconv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):return spectral_norm(nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=kernel_size,stride=stride,padding=padding,dilation=dilation,groups=groups,bias=bias),eps=1e-6)
# sn编码(这样的目的是为了更加稳定)
def sn_embedding(num_embeddings, embedding_dim):return spectral_norm(nn.Embedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim), eps=1e-6)
# 归一化层?
class ConditionalBatchNorm2d(nn.Module):def __init__(self, in_features=148, out_features=384):super().__init__()self.in_features = in_featuresself.bn = batchnorm_2d(out_features, eps=1e-4, momentum=0.1, affine=False)self.gain = snlinear(in_features=in_features, out_features=out_features, bias=False)self.bias = snlinear(in_features=in_features, out_features=out_features, bias=False)def forward(self, x, y):gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)bias = self.bias(y).view(y.size(0), -1, 1, 1)out = self.bn(x)return out * gain + bias# 自我注意机制?
class SelfAttention(nn.Module):def __init__(self, in_channels, is_generator):super(SelfAttention, self).__init__()self.in_channels = in_channelsif is_generator:self.conv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_attn = snconv2d(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1,stride=1, padding=0, bias=False)else:self.conv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels // 8, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels // 2, kernel_size=1,stride=1, padding=0, bias=False)self.conv1x1_attn = snconv2d(in_channels=in_channels // 2, out_channels=in_channels, kernel_size=1,stride=1, padding=0, bias=False)self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)self.softmax = nn.Softmax(dim=-1)self.sigma = nn.Parameter(torch.zeros(1), requires_grad=True)def forward(self, x):_, ch, h, w = x.size()# Theta paththeta = self.conv1x1_theta(x)theta = theta.view(-1, ch // 8, h * w)# Phi pathphi = self.conv1x1_phi(x)phi = self.maxpool(phi)phi = phi.view(-1, ch // 8, h * w // 4)# Attn mapattn = torch.bmm(theta.permute(0, 2, 1), phi)attn = self.softmax(attn)# g pathg = self.conv1x1_g(x)g = self.maxpool(g)g = g.view(-1, ch // 2, h * w // 4)# Attn_gattn_g = torch.bmm(g, attn.permute(0, 2, 1))attn_g = attn_g.view(-1, ch // 2, h, w)attn_g = self.conv1x1_attn(attn_g)return x + self.sigma * attn_g# 一个块
class GenBlock(nn.Module):def __init__(self):super(GenBlock, self).__init__()in_features = 148self.bn1 = ConditionalBatchNorm2d(in_features=in_features)self.bn2 = ConditionalBatchNorm2d(in_features=in_features)self.activation = nn.ReLU(inplace=True)self.conv2d0 = snconv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))self.conv2d1 = snconv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.conv2d2 = snconv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))def forward(self, x, affine):x0 = xx = self.bn1(x, affine)x = self.activation(x)x = F.interpolate(x, scale_factor=2, mode="nearest")x = self.conv2d1(x)x = self.bn2(x, affine)x = self.activation(x)x = self.conv2d2(x)x0 = F.interpolate(x0, scale_factor=2, mode="nearest")x0 = self.conv2d0(x0)out = x + x0return out# 生成网络 输入inputs.shape = tensor.size[(batch_size, 80)]
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.linear0 = snlinear(in_features=20, out_features=6144, bias=True)self.shared = nn.Embedding(10, 128)# 主要块self.blocks = nn.ModuleList()self.blocks.append(nn.ModuleList([GenBlock()])) # (0): ModuleListself.blocks.append(nn.ModuleList([GenBlock()])) # (1): ModuleListself.blocks.append(nn.ModuleList([SelfAttention(in_channels=384, is_generator=True)])) # (2): ModuleListself.blocks.append(nn.ModuleList([GenBlock()])) # (3): ModuleListself.bn4 = batchnorm_2d(in_features=384)self.activation = nn.ReLU(inplace=True)self.conv2d5 = snconv2d(384, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))self.tanh = nn.Tanh()# 关于cifar10的默认参数self.bottom = 4def forward(self, z, label=None):affine_list = []z0 = zzs = torch.split(z, 20, 1)z = zs[0]shared_label = self.shared(label)affine_list.append(shared_label)affines = [torch.cat(affine_list + [item], 1) for item in zs[1:]]act = self.linear0(z)act = act.view(-1, 384, self.bottom, self.bottom)counter = 0for index, blocklist in enumerate(self.blocks):for block in blocklist:if isinstance(block, SelfAttention):act = block(act)else:act = block(act, affines[counter])counter += 1act = self.bn4(act)act = self.activation(act)act = self.conv2d5(act)out = self.tanh(act)return out

如何使用

BIGGAN代码以及训练参数,超级清晰版(CIFAR10数据集生成)相关推荐

  1. CIFAR-10模型训练python版cifar10数据集

    在之前的博客中已经对CIFAR-10做了整体的解析,但是目前从tensorflow/models/tree/master/tutorials/image/cifar10中下载下来,运行cifar10_ ...

  2. YOLOV7开源代码讲解--训练参数解释

    目录 训练参数说明: --weights: -- cfg: --data: --hpy: --epoch: --batch_size: --img-size: --rect: --resume: -- ...

  3. MediBang Paint Pro超级精简版/超精简/懂你版

    原创:梦幻软件天堂  作者:心灵代码 MediBang Paint Pro 超级精简版/超精简/懂你版 由 梦幻软件天堂 发布.MediBang Paint Pro  超精简版 是一个是由日本 med ...

  4. 网页版的svn怎样同步代码_学会使用Hdlbits网页版Verilog代码仿真验证平台

    大家推荐一款网页版的 Verilog代码编辑仿真验证平台,这个平台是国外的一家开源FPGA学习网站,通过 "https://hdlbits.01xz.net/wiki/Main_Page&q ...

  5. 手把手Maven搭建SpringMVC+Spring+MyBatis框架(超级详细版)

    手把手Maven搭建SpringMVC+Spring+MyBatis框架(超级详细版) SSM(Spring+SpringMVC+Mybatis),目前较为主流的企业级架构方案.标准的MVC设计模式, ...

  6. BigGAN代码解读(gpt3.5的帮助)——谱正则化部分

    BigGAN代码解读(gpt4.0的帮助)--谱正则化部分 作者个人记录学习 BigGAN中使用谱归一化对训练过程进行优化,在github中的代码中,使用了自己编写的谱归一化对卷积层.线性层以及Emb ...

  7. Cg教程_可编程实时图形权威指南(扫描清晰版)+部分unity shader 知识

      Cg教程_可编程实时图形权威指南(扫描清晰版) .pdf (34.5 MB, 下载次数: 239) Shader Model(在 3D 图形领域常被简称SM)就是"优化渲染引擎模式&qu ...

  8. supervisord部署使用超级详细版

    supervisord部署使用超级详细版 一. 安装 pip 命令(安装python 环境) 因为 supervisord本身是基于Python开发的,所以在使用时需要先安装Python 的运行环境 ...

  9. GitHub上YOLOv5开源代码的训练数据定义

    GitHub上YOLOv5开源代码的训练数据定义 代码地址:https://github.com/ultralytics/YOLOv5 训练数据定义地址:https://github.com/ultr ...

最新文章

  1. Android渲染机制和丢帧分析
  2. 查找数组中第二个最小元素
  3. JavaScript——判断undefined解决方案
  4. vector-空间增长
  5. 使用HTML5的Canvas画布来剪裁用户头像
  6. IDEA——修改idea64.exe.vmoptions文件解决coding卡顿问题
  7. 20170505思考点--编写案例时是以功能为主还是业务为主要
  8. MySQL : mysql连接报 Communications link failure
  9. 模式分享 公众号_微信公众号+()模式营销!公众号还可以这样玩?
  10. qt读取txt文件内容
  11. ubunntu安装php7.0_乌班图Ubuntu 16.04下安装PHP 7过程详解
  12. Oracle批量修改字段长度
  13. 扩展欧几里得算法求逆元c语言,利用扩展欧几里得算法编程求逆元
  14. 对齐函数:ALIGN()
  15. 美团CAT客户端集成
  16. 巧用google实现快速搜索
  17. 谈谈基于深度相机的三维重建
  18. 从键盘读入10个的整数,判断正数和负数的个数
  19. 主成分分析应用之主成分回归
  20. 偏态分布学习笔记(期望,中位数,众数)

热门文章

  1. 【DDR3 控制器设计】系列博客汇总篇(附直达链接)
  2. 我的世界是一款自由度非常高的游戏,你玩过吗?
  3. VC2015编译旧工程找不到头文件stdio.h
  4. Oracle EBS R12 IE兼容Java插件(多版本)相关设置
  5. MathJax常用符号
  6. android 微信支付键盘,Android 高仿微信支付键盘
  7. MSN空间日志发布项灰色解决方法
  8. oracle中常用关键字,oracle常用函数及关键字笔记
  9. 2021最后一次Java面试,java工程师职业生涯规划
  10. ubuntu18.04 安装JLinkOB驱动以及问题解决