下载:

git clone https://github.com/yunjey/StarGAN.git
1
cd StarGAN/
1
下载celebA训练数据:

bash download.sh
1
训练:

python main.py --mode='train' --dataset='CelebA' --c_dim=5 --image_size=128 \
                 --sample_path='stargan_celebA/samples' --log_path='stargan_celebA/logs' \
                 --model_save_path='stargan_celebA/models' --result_path='stargan_celebA/results'
1
2
3
代码分析
生成网络
第一个卷积层,输入为图像和label的串联,3表示图像为3通道,c_dim为label的维度,

layers = []
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.InstanceNorm2d(conv_dim, affine=True))
layers.append(nn.ReLU(inplace=True))
1
2
3
4
2个卷积层,stride=2,即下采样,

# Down-Sampling
curr_dim = conv_dim
for i in range(2):
    layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
    layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True))
    layers.append(nn.ReLU(inplace=True))
    curr_dim = curr_dim * 2
1
2
3
4
5
6
7
残差层,

# Bottleneck
for i in range(repeat_num):
    layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
1
2
3
残差网络结构,

class ResidualBlock(nn.Module):
    """Residual Block."""
    def __init__(self, dim_in, dim_out):
        super(ResidualBlock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=True))

def forward(self, x):
        return x + self.main(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
上采样,

# Up-Sampling
for i in range(2):
    layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
    layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True))
    layers.append(nn.ReLU(inplace=True))
    curr_dim = curr_dim // 2
1
2
3
4
5
6
最后一层,得到输出维度为3,

layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.Tanh())
1
2
self.main = nn.Sequential(*layers)
1
对于输入图像x,label向量c,串联如下,

def forward(self, x, c):
    # replicate spatially and concatenate domain information
    c = c.unsqueeze(2).unsqueeze(3)
    c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3))
    x = torch.cat([x, c], dim=1)
    return self.main(x)
1
2
3
4
5
6
判别网络
判别网络输入为图像,用于判别输入图像真假,已经输入图像的类别,

class Discriminator(nn.Module):
    """Discriminator. PatchGAN."""
    def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
        super(Discriminator, self).__init__()

layers = []
        layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))

curr_dim = conv_dim
        for i in range(1, repeat_num):
            layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2

k_size = int(image_size / np.power(2, repeat_num))
        self.main = nn.Sequential(*layers)
        self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False)

def forward(self, x):
        h = self.main(x)
        out_real = self.conv1(h)
        out_aux = self.conv2(h)
        return out_real.squeeze(), out_aux.squeeze()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
conv1输出维度为1,即判别输入的真假,conv2输出维度为c_dim,即判别输入图像的label.

训练数据,损失函数,参数更新,
输入包括

real_x,real_c,fake_c

fake_c为随机生成的,

# Generat fake labels randomly (target domain labels)
rand_idx = torch.randperm(real_label.size(0))
fake_label = real_label[rand_idx]
if self.dataset == 'CelebA':
                    real_c = real_label.clone()
                    fake_c = fake_label.clone()
1
2
3
4
5
6
训练判别网络
将真实图像输入判别网络,

# Compute loss with real images
out_src, out_cls = self.D(real_x)
d_loss_real = - torch.mean(out_src)
1
2
3
判别网络的输入为真实图像,输出out_cls为真实图像对应的标签的概率,则可以计算交叉损失熵,

if self.dataset == 'CelebA':
    d_loss_cls = F.binary_cross_entropy_with_logits(
        out_cls, real_label, size_average=False) / real_x.size(0)
1
2
3
将真实图像输入real_x和假的标签fake_c输入生成网络,得到生成图像fake_x,

fake_x = self.G(real_x, fake_c)
1
将生成图像输入判别网络,

fake_x = Variable(fake_x.data)
out_src, out_cls = self.D(fake_x)
d_loss_fake = torch.mean(out_src)
1
2
3
总的损失函数为,

# Backward + Optimize
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
1
2
根据d_loss更新判别网络参数,

# Backward + Optimize
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
1
2
3
4
5
计算梯度惩罚因子alpha,根据alpha结合real_x,fake_x,输入判别网络,计算梯度,得到梯度损失函数,

# Compute gradient penalty
alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
out, out_cls = self.D(interpolated)

grad = torch.autograd.grad(outputs=out,
                           inputs=interpolated,
                           grad_outputs=torch.ones(out.size()).cuda(),
                           retain_graph=True,
                           create_graph=True,
                           only_inputs=True)[0]

grad = grad.view(grad.size(0), -1)
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = torch.mean((grad_l2norm - 1)**2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
根据梯度损失函数d_loss_gp优化判别网路,

# Backward + Optimize
d_loss = self.lambda_gp * d_loss_gp
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
1
2
3
4
5
训练生成网络
生成网络的作用是,输入original域的图可以生成目标域的图像,输入为目标域的图像,生成original域的图像.

将原图像输入生成网络,得到生成图像fake_x,同时将fake_x图像输入生成网络,希望生成的图像与真实图像尽量相似,

# Original-to-target and target-to-original domain
fake_x = self.G(real_x, fake_c)
rec_x = self.G(fake_x, real_c)
# Compute losses
g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
1
2
3
4
5
将fake_x输入判别网路,

out_src, out_cls = self.D(fake_x)
g_loss_fake = - torch.mean(out_src)
1
2
计算损失函数,

g_loss_fake = - torch.mean(out_src)
1
对于fake_x,对应的label为fake_label,将fake_x输入判别网络,判别网络预测label概率为out_cls,因此可以计算交叉损失熵,

g_loss_cls = F.binary_cross_entropy_with_logits(
    out_cls, fake_label, size_average=False) / fake_x.size(0)
1
2
生成网络参数更新,

# Backward + Optimize
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
self.reset_grad()
g_loss.backward()
self.g_optimizer.step()
1
2
3
4
5
训练数据处理
以celebA数据为例,下载后的数据包括label文件,和图像.

文件的第一行为图像的总数,为202599.

第二行为数据处理的类别,包括40种,

5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young

第三行及之后的每行为,图像名,已经对应的40种类别的label,label值为1或-1,之后提取为值1为1,-1为0.

000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1

list_attr_celeba.txt文件提取函数为,

def preprocess(self):
    attrs = self.lines[1].split()
    for i, attr in enumerate(attrs):
        self.attr2idx[attr] = i
        self.idx2attr[i] = attr

self.selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
    self.train_filenames = []
    self.train_labels = []
    self.test_filenames = []
    self.test_labels = []

lines = self.lines[2:]#the image and labels
    random.shuffle(lines)   # random shuffling
    for i, line in enumerate(lines):

splits = line.split()
        filename = splits[0]#image name
        values = splits[1:]# labels

label = []
        for idx, value in enumerate(values):
            attr = self.idx2attr[idx]# there are 40 classes,find the idx th class name

if attr in self.selected_attrs:#check if the attr in the selected classes
                if value == '1':#if the ckss label is 1 then label equal 2,otherwise,0
                    label.append(1)
                else:
                    label.append(0)

if (i+1) < 2000:
            self.test_filenames.append(filename)
            self.test_labels.append(label)
        else:
            self.train_filenames.append(filename)
            self.train_labels.append(label)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
self.selected_attrs表示我们训练选用的任务类别集合.最后得到图像名数组self.train_filenames,及其对应的label数组 self.train_labels.

之后采用from torch.utils.data import DataLoader加载训练数据,

data_loader = DataLoader(dataset=dataset,
                         batch_size=batch_size,
                         shuffle=shuffle)
1
2
3
fixed_x = []
real_c = []
for i, (images, labels) in enumerate(self.data_loader):
    fixed_x.append(images)
    real_c.append(labels)
    if i == 3:
        break
1
2
3
4
5
6
7
读取后的图像数组为fixed_x,lable为real_c.图像为(bath_size,c_dim,imagesize,imagesize),label为(batch_size,len(self.selected_attrs)).

得到固定的输入图像数组,label,labelist,用于sample.

# Fixed inputs and target domain labels for debugging
fixed_x = torch.cat(fixed_x, dim=0)#4*batch_szie,(64,3,128,128)
fixed_x = self.to_var(fixed_x, volatile=True)
real_c = torch.cat(real_c, dim=0)

if self.dataset == 'CelebA':
    fixed_c_list = self.make_celeb_labels(real_c)
1
2
3
4
5
6
7
labellist生成函数为,

def make_celeb_labels(self, real_c):
    """Generate domain labels for CelebA for debugging/testing.

if dataset == 'CelebA':
        return single and multiple attribute changes
    elif dataset == 'Both':
        return single attribute changes
    """
    y = [torch.FloatTensor([1, 0, 0]),  # black hair
         torch.FloatTensor([0, 1, 0]),  # blond hair
         torch.FloatTensor([0, 0, 1])]  # brown hair

fixed_c_list = []

# single attribute transfer
    for i in range(self.c_dim):
        fixed_c = real_c.clone()
        for c in fixed_c:
            if i < 3:
                c[:3] = y[i]
            else:
                c[i] = 0 if c[i] == 1 else 1   # opposite value
        fixed_c_list.append(self.to_var(fixed_c, volatile=True))

# multi-attribute transfer (H+G, H+A, G+A, H+G+A)
    if self.dataset == 'CelebA':
        for i in range(4):
            fixed_c = real_c.clone()
            for c in fixed_c:
                if i in [0, 1, 3]:   # Hair color to brown
                    c[:3] = y[2] 
                if i in [0, 2, 3]:   # Gender
                    c[3] = 0 if c[3] == 1 else 1
                if i in [1, 2, 3]:   # Aged
                    c[4] = 0 if c[4] == 1 else 1
            fixed_c_list.append(self.to_var(fixed_c, volatile=True))
    return fixed_c_list
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
fixed_c_list长度为c_dim+4=5+4=9,

训练的时候,fake_label为随机产生0-batch_size的索引,并由索引,从real_label取值,

# Start training
start_time = time.time()
for e in range(start, self.num_epochs):
    for i, (real_x, real_label) in enumerate(self.data_loader):

# Generat fake labels randomly (target domain labels)
        rand_idx = torch.randperm(real_label.size(0))
        fake_label = real_label[rand_idx]

if self.dataset == 'CelebA':
            real_c = real_label.clone()
            fake_c = fake_label.clone()
        else:
            real_c = self.one_hot(real_label, self.c_dim)
            fake_c = self.one_hot(fake_label, self.c_dim)

# Convert tensor to variable
        real_x = self.to_var(real_x)#(16,3,128,128)
        real_c = self.to_var(real_c) #(16,5)          # input for the generator
        fake_c = self.to_var(fake_c)#(16,5)
        real_label = self.to_var(real_label)   # this is same as real_c if dataset == 'CelebA'
        fake_label = self.to_var(fake_label)
--------------------- 
作者:imperfect00 
来源:CSDN 
原文:https://blog.csdn.net/u011961856/article/details/78697863 
版权声明:本文为博主原创文章,转载请附上博文链接!

starGAN原理代码分析相关推荐

  1. 头条巨量快手广点通等平台APPAPI回传事件注册激活-转化联调-API对接原理代码分析和功能实现

    2022年最新的头条巨量快手广点通等各推广平台APP&API回传事件-转化联调注册激活-API对接原理代码分析和功能实现! 在商户推广管理当中,经常碰到 需要将用户行为数据回传给推广平台.做转 ...

  2. vu2响应式原理 代码分析

    随着vue3的发布和运行,vue2 的知识也不能忘却.温习一下vue2的原理用代码来解说. 我们熟悉vue的都知道,在vue2 中是不能直接监测到数据的新增和删除的.所以也有一定的方法,我们要知道这些 ...

  3. 深度学习原理—代码分析线性分类与神经网络分类的区别

    https://www.toutiao.com/a6687727778487337476/ 利用sklearn.dataset随机产生数据,随机生成两类数据,用不同的颜色展示出来,如下图: 产生的随机 ...

  4. vu3响应式原理 代码分析

    上个文章我们说了vue2的原理,看这里. 现在简单说一下vue3的原理. vue3 建议大家还是多多看看官网,毕竟语法都变了,虽然兼容vue2,但是最好按照官网说的取用vue3 .不然会有一系列的报错 ...

  5. 手写内存池以及原理代码分析【C语言】

    内存池是对堆进行管理 当进程执行时,操作系统会分出0~4G的虚拟内存空间给进程,程序员可以自行管理(分配.释放)的部分就是mmap映射区.heap堆区,而内存池管理的部分就是用户进程的堆区. 为什么要 ...

  6. OpenStack 虚拟机冷/热迁移的实现原理与代码分析

    目录 文章目录 目录 前文列表 冷迁移代码分析(基于 Newton) Nova 冷迁移实现原理 热迁移代码分析 Nova 热迁移实现原理 向 libvirtd 发出 Live Migration 指令 ...

  7. 对dpdk的rte_ring实现原理和代码分析

    对dpdk的rte_ring实现原理和代码分析 前言 dpdk的rte_ring是借鉴了linux内核的kfifo实现原理,这里统称为无锁环形缓冲队列. 环形缓冲区通常有一个读指针和一个写指针.读指针 ...

  8. TrueCrypt 6.2a原理及代码分析

    TrueCrypt 6.2a原理及代码分析 3 comments 25th Apr 10 rafa 1 项目物理布局 Project     |____ Boot /* MBR部分的代码 */     ...

  9. 免费的Lucene 原理与代码分析完整版下载

    Lucene是一个基于Java的高效的全文检索库. 那么什么是全文检索,为什么需要全文检索? 目前人们生活中出现的数据总的来说分为两类:结构化数据和非结构化数据.很容易理解,结构化数据是有固定格式和结 ...

最新文章

  1. 通过anaconda2安装python2.7和安装pytorch
  2. java正则表达式课程_通过此免费课程学习正则表达式
  3. c语言链表创建递归,递归创建二叉树c语言实现+详细解释
  4. pageContext.findAttribute()与pageContext.getAttribute()的区别
  5. python箱线图_Python 箱线图 plt.boxplot() 参数详解
  6. boost线程之类成员函数
  7. centos安装nginx步骤
  8. 【搬砖】【Python数据分析】Pycharm中plot绘图不能显示出来
  9. centos php 开启socket,CentOS 配置PHP支持socket扩展
  10. Kaggle 发布首份数据科学从业报告 | 不及美国同行1/3,中国数据科学家平均年薪约3万美元
  11. mysql 装载dump文件_mysql命令、mysqldump命令找不到解决
  12. PYTHON_SPLIT
  13. 几款好用的录屏软件推荐
  14. treetable怎么带参数_treeTable的使用(ajax异步获取数据,动态渲染treeTable)
  15. sim868 建立tcp链接时的步骤所对应hex码
  16. 2048游戏规则及玩法技巧攻略
  17. map拼接URL参数
  18. 记一个 Harvester SNAT 案例
  19. 链表的概念以及它的作用
  20. mock_httpserver

热门文章

  1. 夏日php登录系统源码,夏日PHP企业管理系统 v0.1
  2. 计算机不能启动 如何排除故障,开工发现电脑无法开机 如何排查故障?
  3. android电视视频app下载,央视频APP智能电视版下载-央视频电视版客户端 1.9.0.53139 安卓版-玩友游戏网...
  4. centos7 nginx安装_手把手教你PHP(一) Centos7上的LEMP配置
  5. common pool2 mysql_连接池Commons Pool2的使用
  6. 用法 stl_PoEdu培训第四课-C++之STL
  7. 中介者模式 调停者 Mediator 行为型 设计模式(二十一)
  8. 如何成为一名Android架构师,乃至高级架构师,文末有路线图
  9. Idea--Tomcat配置中的On Upate Action 与 On Frame Deactivation
  10. 只改一个值!马上加快宽带上网速度