starGAN原理代码分析
下载:
git clone https://github.com/yunjey/StarGAN.git
1
cd StarGAN/
1
下载celebA训练数据:
bash download.sh
1
训练:
python main.py --mode='train' --dataset='CelebA' --c_dim=5 --image_size=128 \
--sample_path='stargan_celebA/samples' --log_path='stargan_celebA/logs' \
--model_save_path='stargan_celebA/models' --result_path='stargan_celebA/results'
1
2
3
代码分析
生成网络
第一个卷积层,输入为图像和label的串联,3表示图像为3通道,c_dim为label的维度,
layers = []
layers.append(nn.Conv2d(3+c_dim, conv_dim, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.InstanceNorm2d(conv_dim, affine=True))
layers.append(nn.ReLU(inplace=True))
1
2
3
4
2个卷积层,stride=2,即下采样,
# Down-Sampling
curr_dim = conv_dim
for i in range(2):
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1, bias=False))
layers.append(nn.InstanceNorm2d(curr_dim*2, affine=True))
layers.append(nn.ReLU(inplace=True))
curr_dim = curr_dim * 2
1
2
3
4
5
6
7
残差层,
# Bottleneck
for i in range(repeat_num):
layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim))
1
2
3
残差网络结构,
class ResidualBlock(nn.Module):
"""Residual Block."""
def __init__(self, dim_in, dim_out):
super(ResidualBlock, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm2d(dim_out, affine=True),
nn.ReLU(inplace=True),
nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
nn.InstanceNorm2d(dim_out, affine=True))
def forward(self, x):
return x + self.main(x)
1
2
3
4
5
6
7
8
9
10
11
12
13
上采样,
# Up-Sampling
for i in range(2):
layers.append(nn.ConvTranspose2d(curr_dim, curr_dim//2, kernel_size=4, stride=2, padding=1, bias=False))
layers.append(nn.InstanceNorm2d(curr_dim//2, affine=True))
layers.append(nn.ReLU(inplace=True))
curr_dim = curr_dim // 2
1
2
3
4
5
6
最后一层,得到输出维度为3,
layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
layers.append(nn.Tanh())
1
2
self.main = nn.Sequential(*layers)
1
对于输入图像x,label向量c,串联如下,
def forward(self, x, c):
# replicate spatially and concatenate domain information
c = c.unsqueeze(2).unsqueeze(3)
c = c.expand(c.size(0), c.size(1), x.size(2), x.size(3))
x = torch.cat([x, c], dim=1)
return self.main(x)
1
2
3
4
5
6
判别网络
判别网络输入为图像,用于判别输入图像真假,已经输入图像的类别,
class Discriminator(nn.Module):
"""Discriminator. PatchGAN."""
def __init__(self, image_size=128, conv_dim=64, c_dim=5, repeat_num=6):
super(Discriminator, self).__init__()
layers = []
layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = conv_dim
for i in range(1, repeat_num):
layers.append(nn.Conv2d(curr_dim, curr_dim*2, kernel_size=4, stride=2, padding=1))
layers.append(nn.LeakyReLU(0.01, inplace=True))
curr_dim = curr_dim * 2
k_size = int(image_size / np.power(2, repeat_num))
self.main = nn.Sequential(*layers)
self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=3, stride=1, padding=1, bias=False)
self.conv2 = nn.Conv2d(curr_dim, c_dim, kernel_size=k_size, bias=False)
def forward(self, x):
h = self.main(x)
out_real = self.conv1(h)
out_aux = self.conv2(h)
return out_real.squeeze(), out_aux.squeeze()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
conv1输出维度为1,即判别输入的真假,conv2输出维度为c_dim,即判别输入图像的label.
训练数据,损失函数,参数更新,
输入包括
real_x,real_c,fake_c
fake_c为随机生成的,
# Generat fake labels randomly (target domain labels)
rand_idx = torch.randperm(real_label.size(0))
fake_label = real_label[rand_idx]
if self.dataset == 'CelebA':
real_c = real_label.clone()
fake_c = fake_label.clone()
1
2
3
4
5
6
训练判别网络
将真实图像输入判别网络,
# Compute loss with real images
out_src, out_cls = self.D(real_x)
d_loss_real = - torch.mean(out_src)
1
2
3
判别网络的输入为真实图像,输出out_cls为真实图像对应的标签的概率,则可以计算交叉损失熵,
if self.dataset == 'CelebA':
d_loss_cls = F.binary_cross_entropy_with_logits(
out_cls, real_label, size_average=False) / real_x.size(0)
1
2
3
将真实图像输入real_x和假的标签fake_c输入生成网络,得到生成图像fake_x,
fake_x = self.G(real_x, fake_c)
1
将生成图像输入判别网络,
fake_x = Variable(fake_x.data)
out_src, out_cls = self.D(fake_x)
d_loss_fake = torch.mean(out_src)
1
2
3
总的损失函数为,
# Backward + Optimize
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
1
2
根据d_loss更新判别网络参数,
# Backward + Optimize
d_loss = d_loss_real + d_loss_fake + self.lambda_cls * d_loss_cls
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
1
2
3
4
5
计算梯度惩罚因子alpha,根据alpha结合real_x,fake_x,输入判别网络,计算梯度,得到梯度损失函数,
# Compute gradient penalty
alpha = torch.rand(real_x.size(0), 1, 1, 1).cuda().expand_as(real_x)
interpolated = Variable(alpha * real_x.data + (1 - alpha) * fake_x.data, requires_grad=True)
out, out_cls = self.D(interpolated)
grad = torch.autograd.grad(outputs=out,
inputs=interpolated,
grad_outputs=torch.ones(out.size()).cuda(),
retain_graph=True,
create_graph=True,
only_inputs=True)[0]
grad = grad.view(grad.size(0), -1)
grad_l2norm = torch.sqrt(torch.sum(grad ** 2, dim=1))
d_loss_gp = torch.mean((grad_l2norm - 1)**2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
根据梯度损失函数d_loss_gp优化判别网路,
# Backward + Optimize
d_loss = self.lambda_gp * d_loss_gp
self.reset_grad()
d_loss.backward()
self.d_optimizer.step()
1
2
3
4
5
训练生成网络
生成网络的作用是,输入original域的图可以生成目标域的图像,输入为目标域的图像,生成original域的图像.
将原图像输入生成网络,得到生成图像fake_x,同时将fake_x图像输入生成网络,希望生成的图像与真实图像尽量相似,
# Original-to-target and target-to-original domain
fake_x = self.G(real_x, fake_c)
rec_x = self.G(fake_x, real_c)
# Compute losses
g_loss_rec = torch.mean(torch.abs(real_x - rec_x))
1
2
3
4
5
将fake_x输入判别网路,
out_src, out_cls = self.D(fake_x)
g_loss_fake = - torch.mean(out_src)
1
2
计算损失函数,
g_loss_fake = - torch.mean(out_src)
1
对于fake_x,对应的label为fake_label,将fake_x输入判别网络,判别网络预测label概率为out_cls,因此可以计算交叉损失熵,
g_loss_cls = F.binary_cross_entropy_with_logits(
out_cls, fake_label, size_average=False) / fake_x.size(0)
1
2
生成网络参数更新,
# Backward + Optimize
g_loss = g_loss_fake + self.lambda_rec * g_loss_rec + self.lambda_cls * g_loss_cls
self.reset_grad()
g_loss.backward()
self.g_optimizer.step()
1
2
3
4
5
训练数据处理
以celebA数据为例,下载后的数据包括label文件,和图像.
文件的第一行为图像的总数,为202599.
第二行为数据处理的类别,包括40种,
5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard Oval_Face Pale_Skin Pointy_Nose Receding_Hairline Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair Wearing_Earrings Wearing_Hat Wearing_Lipstick Wearing_Necklace Wearing_Necktie Young
第三行及之后的每行为,图像名,已经对应的40种类别的label,label值为1或-1,之后提取为值1为1,-1为0.
000001.jpg -1 1 1 -1 -1 -1 -1 -1 -1 -1 -1 1 -1 -1 -1 -1 -1 -1 1 1 -1 1 -1 -1 1 -1 -1 1 -1 -1 -1 1 1 -1 1 -1 1 -1 -1 1
list_attr_celeba.txt文件提取函数为,
def preprocess(self):
attrs = self.lines[1].split()
for i, attr in enumerate(attrs):
self.attr2idx[attr] = i
self.idx2attr[i] = attr
self.selected_attrs = ['Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Male', 'Young']
self.train_filenames = []
self.train_labels = []
self.test_filenames = []
self.test_labels = []
lines = self.lines[2:]#the image and labels
random.shuffle(lines) # random shuffling
for i, line in enumerate(lines):
splits = line.split()
filename = splits[0]#image name
values = splits[1:]# labels
label = []
for idx, value in enumerate(values):
attr = self.idx2attr[idx]# there are 40 classes,find the idx th class name
if attr in self.selected_attrs:#check if the attr in the selected classes
if value == '1':#if the ckss label is 1 then label equal 2,otherwise,0
label.append(1)
else:
label.append(0)
if (i+1) < 2000:
self.test_filenames.append(filename)
self.test_labels.append(label)
else:
self.train_filenames.append(filename)
self.train_labels.append(label)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
self.selected_attrs表示我们训练选用的任务类别集合.最后得到图像名数组self.train_filenames,及其对应的label数组 self.train_labels.
之后采用from torch.utils.data import DataLoader加载训练数据,
data_loader = DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=shuffle)
1
2
3
fixed_x = []
real_c = []
for i, (images, labels) in enumerate(self.data_loader):
fixed_x.append(images)
real_c.append(labels)
if i == 3:
break
1
2
3
4
5
6
7
读取后的图像数组为fixed_x,lable为real_c.图像为(bath_size,c_dim,imagesize,imagesize),label为(batch_size,len(self.selected_attrs)).
得到固定的输入图像数组,label,labelist,用于sample.
# Fixed inputs and target domain labels for debugging
fixed_x = torch.cat(fixed_x, dim=0)#4*batch_szie,(64,3,128,128)
fixed_x = self.to_var(fixed_x, volatile=True)
real_c = torch.cat(real_c, dim=0)
if self.dataset == 'CelebA':
fixed_c_list = self.make_celeb_labels(real_c)
1
2
3
4
5
6
7
labellist生成函数为,
def make_celeb_labels(self, real_c):
"""Generate domain labels for CelebA for debugging/testing.
if dataset == 'CelebA':
return single and multiple attribute changes
elif dataset == 'Both':
return single attribute changes
"""
y = [torch.FloatTensor([1, 0, 0]), # black hair
torch.FloatTensor([0, 1, 0]), # blond hair
torch.FloatTensor([0, 0, 1])] # brown hair
fixed_c_list = []
# single attribute transfer
for i in range(self.c_dim):
fixed_c = real_c.clone()
for c in fixed_c:
if i < 3:
c[:3] = y[i]
else:
c[i] = 0 if c[i] == 1 else 1 # opposite value
fixed_c_list.append(self.to_var(fixed_c, volatile=True))
# multi-attribute transfer (H+G, H+A, G+A, H+G+A)
if self.dataset == 'CelebA':
for i in range(4):
fixed_c = real_c.clone()
for c in fixed_c:
if i in [0, 1, 3]: # Hair color to brown
c[:3] = y[2]
if i in [0, 2, 3]: # Gender
c[3] = 0 if c[3] == 1 else 1
if i in [1, 2, 3]: # Aged
c[4] = 0 if c[4] == 1 else 1
fixed_c_list.append(self.to_var(fixed_c, volatile=True))
return fixed_c_list
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
fixed_c_list长度为c_dim+4=5+4=9,
训练的时候,fake_label为随机产生0-batch_size的索引,并由索引,从real_label取值,
# Start training
start_time = time.time()
for e in range(start, self.num_epochs):
for i, (real_x, real_label) in enumerate(self.data_loader):
# Generat fake labels randomly (target domain labels)
rand_idx = torch.randperm(real_label.size(0))
fake_label = real_label[rand_idx]
if self.dataset == 'CelebA':
real_c = real_label.clone()
fake_c = fake_label.clone()
else:
real_c = self.one_hot(real_label, self.c_dim)
fake_c = self.one_hot(fake_label, self.c_dim)
# Convert tensor to variable
real_x = self.to_var(real_x)#(16,3,128,128)
real_c = self.to_var(real_c) #(16,5) # input for the generator
fake_c = self.to_var(fake_c)#(16,5)
real_label = self.to_var(real_label) # this is same as real_c if dataset == 'CelebA'
fake_label = self.to_var(fake_label)
---------------------
作者:imperfect00
来源:CSDN
原文:https://blog.csdn.net/u011961856/article/details/78697863
版权声明:本文为博主原创文章,转载请附上博文链接!
starGAN原理代码分析相关推荐
- 头条巨量快手广点通等平台APPAPI回传事件注册激活-转化联调-API对接原理代码分析和功能实现
2022年最新的头条巨量快手广点通等各推广平台APP&API回传事件-转化联调注册激活-API对接原理代码分析和功能实现! 在商户推广管理当中,经常碰到 需要将用户行为数据回传给推广平台.做转 ...
- vu2响应式原理 代码分析
随着vue3的发布和运行,vue2 的知识也不能忘却.温习一下vue2的原理用代码来解说. 我们熟悉vue的都知道,在vue2 中是不能直接监测到数据的新增和删除的.所以也有一定的方法,我们要知道这些 ...
- 深度学习原理—代码分析线性分类与神经网络分类的区别
https://www.toutiao.com/a6687727778487337476/ 利用sklearn.dataset随机产生数据,随机生成两类数据,用不同的颜色展示出来,如下图: 产生的随机 ...
- vu3响应式原理 代码分析
上个文章我们说了vue2的原理,看这里. 现在简单说一下vue3的原理. vue3 建议大家还是多多看看官网,毕竟语法都变了,虽然兼容vue2,但是最好按照官网说的取用vue3 .不然会有一系列的报错 ...
- 手写内存池以及原理代码分析【C语言】
内存池是对堆进行管理 当进程执行时,操作系统会分出0~4G的虚拟内存空间给进程,程序员可以自行管理(分配.释放)的部分就是mmap映射区.heap堆区,而内存池管理的部分就是用户进程的堆区. 为什么要 ...
- OpenStack 虚拟机冷/热迁移的实现原理与代码分析
目录 文章目录 目录 前文列表 冷迁移代码分析(基于 Newton) Nova 冷迁移实现原理 热迁移代码分析 Nova 热迁移实现原理 向 libvirtd 发出 Live Migration 指令 ...
- 对dpdk的rte_ring实现原理和代码分析
对dpdk的rte_ring实现原理和代码分析 前言 dpdk的rte_ring是借鉴了linux内核的kfifo实现原理,这里统称为无锁环形缓冲队列. 环形缓冲区通常有一个读指针和一个写指针.读指针 ...
- TrueCrypt 6.2a原理及代码分析
TrueCrypt 6.2a原理及代码分析 3 comments 25th Apr 10 rafa 1 项目物理布局 Project |____ Boot /* MBR部分的代码 */ ...
- 免费的Lucene 原理与代码分析完整版下载
Lucene是一个基于Java的高效的全文检索库. 那么什么是全文检索,为什么需要全文检索? 目前人们生活中出现的数据总的来说分为两类:结构化数据和非结构化数据.很容易理解,结构化数据是有固定格式和结 ...
最新文章
- 通过anaconda2安装python2.7和安装pytorch
- java正则表达式课程_通过此免费课程学习正则表达式
- c语言链表创建递归,递归创建二叉树c语言实现+详细解释
- pageContext.findAttribute()与pageContext.getAttribute()的区别
- python箱线图_Python 箱线图 plt.boxplot() 参数详解
- boost线程之类成员函数
- centos安装nginx步骤
- 【搬砖】【Python数据分析】Pycharm中plot绘图不能显示出来
- centos php 开启socket,CentOS 配置PHP支持socket扩展
- Kaggle 发布首份数据科学从业报告 | 不及美国同行1/3,中国数据科学家平均年薪约3万美元
- mysql 装载dump文件_mysql命令、mysqldump命令找不到解决
- PYTHON_SPLIT
- 几款好用的录屏软件推荐
- treetable怎么带参数_treeTable的使用(ajax异步获取数据,动态渲染treeTable)
- sim868 建立tcp链接时的步骤所对应hex码
- 2048游戏规则及玩法技巧攻略
- map拼接URL参数
- 记一个 Harvester SNAT 案例
- 链表的概念以及它的作用
- mock_httpserver
热门文章
- 夏日php登录系统源码,夏日PHP企业管理系统 v0.1
- 计算机不能启动 如何排除故障,开工发现电脑无法开机 如何排查故障?
- android电视视频app下载,央视频APP智能电视版下载-央视频电视版客户端 1.9.0.53139 安卓版-玩友游戏网...
- centos7 nginx安装_手把手教你PHP(一) Centos7上的LEMP配置
- common pool2 mysql_连接池Commons Pool2的使用
- 用法 stl_PoEdu培训第四课-C++之STL
- 中介者模式 调停者 Mediator 行为型 设计模式(二十一)
- 如何成为一名Android架构师,乃至高级架构师,文末有路线图
- Idea--Tomcat配置中的On Upate Action 与 On Frame Deactivation
- 只改一个值!马上加快宽带上网速度