a GAN using Wasserstein loss and resnet to generate anime pics.

一个resnet-WGAN用于生成各种二次元头像(你也可以使用别的图像数据集,用于生成图片)

@本项目用于深度学习中的学习交流,如有任何问题,欢迎联系我!联系方式QQ:741533684

#我使用了残差模块设计了了两个相对对称的残差网络,分别做生成对抗网络的的生成器与判别器,基本原理其实与DCGAN类似。在此基础上,使用了不同于Binary cross entropy loss的Wasserstein loss, 并将优化器从Adam修改为RMSprop(注意:Adam会导致训练不稳定,建议使用RMSprop或者SGD,且学习率不能太大,最好使用学习率衰减。)

之后我会上传我训练的模型,以供大家使用作为预训练模型。

train.py文件重要代码如下:


def weights_init(m):  # 初始化模型权重classname = m.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(m.weight.data, 1.0, 0.02)nn.init.constant_(m.bias.data, 0)def get_lr(optimizer):for param_group in optimizer.param_groups:return param_group['lr']parser = argparse.ArgumentParser()
parser.add_argument('--batchSize', type=int, default=64)
parser.add_argument('--imageSize', type=int, default=96)
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--epoch', type=int, default=500, help='number of epochs to train for')
parser.add_argument('--lrd', type=float, default=5e-5,help="Discriminator's learning rate, default=0.00005")  # Discriminator's learning rate
parser.add_argument('--lrg', type=float, default=5e-5,help="Generator's learning rate, default=0.00005")  # Generator's learning rate
parser.add_argument('--data_path', default='data/', help='folder to train data')  # 将数据集放在此处
parser.add_argument('--outf', default='imgv3/',help='folder to output images and model checkpoints')  # 输出生成图片以及保存模型的位置
opt = parser.parse_args()
# 定义是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 图像读入与预处理
transforms = torchvision.transforms.Compose([torchvision.transforms.Resize(opt.imageSize),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])dataset = torchvision.datasets.ImageFolder(opt.data_path, transform=transforms)dataloader = torch.utils.data.DataLoader(dataset=dataset,batch_size=opt.batchSize,shuffle=True,drop_last=True,
)
netG = NetG().to(device)
netG.apply(weights_init)
print('Generator:' )
print(sum(p.numel() for p in netG.parameters()))netD = NetD().to(device)
netD.apply(weights_init)
print('Discriminator')
print(sum(p.numel() for p in netD.parameters()))print(dataset)netG.load_state_dict(torch.load('imgv2.5/netG_0280.pth', map_location=device))  # 这两句用来读取预训练模型
netD.load_state_dict(torch.load('imgv2.5/netD_0280.pth', map_location=device))  # 这两句用来读取预训练模型
criterionG = Hinge()
optimizerG = torch.optim.RMSprop(netG.parameters(), lr=opt.lrg)
optimizerD = torch.optim.RMSprop(netD.parameters(), lr=opt.lrd)
lrd_scheduler    = torch.optim.lr_scheduler.StepLR(optimizerD, step_size=5, gamma=0.92)
lrg_scheduler    = torch.optim.lr_scheduler.StepLR(optimizerG, step_size=5, gamma=0.92)
criterionD = Hinge()
criterion = torch.nn.BCELoss()
label = torch.FloatTensor(opt.batchSize)
real_label = 1
fake_label = 0
total_lossD = 0.0
total_lossG = 0.0
label = label.unsqueeze(1)start_epoch = 280 # 设置初始epoch大小
for epoch in range(start_epoch + 1, opt.epoch + 1):with tqdm(total=len(dataloader), desc=f'Epoch {epoch}/{opt.epoch}', postfix=dict, mininterval=0.3) as pbar:for i, (imgs, _) in enumerate(dataloader):# 固定生成器G,训练鉴别器D# 让D尽可能的把真图片判别为1imgs = imgs.to(device)# for k in range(1,5):outputreal = netD(imgs)optimizerD.zero_grad()## 让D尽可能把假图片判别为0# label.data.fill_(fake_label)noise = torch.randn(opt.batchSize, opt.nz)# noise = torch.randn(opt.batchSize, opt.nz)noise = noise.to(device)fake = netG(noise)  # 生成假图outputfake = netD(fake.detach())  # 避免梯度传到G,因为G不用更新lossD = criterionD(outputreal, outputfake)total_lossD += lossD.item()lossD.backward()optimizerD.step()# 固定鉴别器D,训练生成器GoptimizerG.zero_grad()# 让D尽可能把G生成的假图判别为1output = netD(fake)lossG = criterionG(output)total_lossG += lossG.item()lossG.backward()optimizerG.step()# print('[%d/%d][%d/%d] Loss_D: %.3f Loss_G %.3f'% (epoch, opt.epoch, i, len(dataloader), lossD.item(), lossG.item()))pbar.set_postfix(**{'total_lossD': total_lossD / (i + 1),'lrd':get_lr(optimizerD), 'total_lossG': total_lossG / (i + 1), 'lrg': get_lr(optimizerG)})pbar.update(1)lrg_scheduler.step()lrd_scheduler.step()vutils.save_image(fake.data,'%s/fake_samples_epoch_%03d.png' % (opt.outf, epoch),normalize=True)log = open("./log.txt", 'a')print('[%d/%d] total_Loss_D: %.3f total_Loss_G %.3f' % (epoch, opt.epoch, total_lossD / (len(dataloader)), total_lossG / (len(dataloader))),file=log)total_lossG = 0.0total_lossD = 0.0log.close()if epoch % 5 == 0:  # 每5个epoch,保存一次模型参数.torch.save(netG.state_dict(), '%s/netG_%04d.pth' % (opt.outf, epoch))torch.save(netD.state_dict(), '%s/netD_%04d.pth' % (opt.outf, epoch))

以下是残差模块,residual block的定义:


class BasicBlock(nn.Module):def __init__(self, in1):super(BasicBlock, self).__init__()self.conv1 = nn.Conv2d(in1, in1 * 2, kernel_size=1,stride=1, padding=0, bias=False)self.bn1 =nn.BatchNorm2d(in1*2)self.relu1 = nn.LeakyReLU(0.2)self.conv2 = nn.Conv2d(in1*2, in1, kernel_size=3,stride=1, padding=1, bias=False)self.bn2 =nn.BatchNorm2d(in1)self.relu2 = nn.LeakyReLU(0.2)def forward(self, x):residual = xout = self.conv1(x)#  out = self.bn1(out)out = self.relu1(out)out = self.conv2(out)#  out = self.bn2(out)out = self.relu2(out)out += residualreturn out

损失函数使用了wasserstein loss,相比于BCEloss(JS距离),能准确衡量生成器产生图片的质量,而Hinge loss相对于W loss来说,能解决其梯度爆炸导致训练不稳定的问题。

class Wasserstein(nn.Module):def forward(self, pred_real, pred_fake=None):if pred_fake is not None:loss_real = -pred_real.mean()loss_fake = pred_fake.mean()loss = loss_real + loss_fakereturn losselse:loss = -pred_real.mean()return lossclass Hinge(nn.Module):#与Wasserstein相比,Hinge能防止梯度暴增。def forward(self, pred_real, pred_fake=None):if pred_fake is not None:loss_real = F.relu(1 - pred_real).mean()loss_fake = F.relu(1 + pred_fake).mean()return loss_real + loss_fakeelse:loss = -pred_real.mean()return loss

所使用的判别器代码:


class RestNet18(nn.Module):def __init__(self):super(RestNet18, self).__init__()self.conv1 = nn.Conv2d(3, 64, kernel_size=3 ,stride=1, padding=1)self.layer1 = nn.Sequential(BasicBlock(64),nn.AvgPool2d(3, 2),BasicBlock(64),BasicBlock(64),)self.layer2 = nn.Sequential(nn.AvgPool2d(3,2),BasicBlock(64),BasicBlock(64),)self.layer3 = nn.Sequential(nn.AvgPool2d(3, 2),BasicBlock(64),BasicBlock(64),)self.layer4 = nn.Sequential(nn.AvgPool2d(3, 2),BasicBlock(64),BasicBlock(64)#  nn.LayerNorm([64,5,5]),)self.layer5 = nn.Sequential(nn.BatchNorm2d(64),#   nn.LayerNorm([64,5,5]),nn.ReLU(True))self.fc = nn.Sequential(nn.Linear(1600, 1),)def forward(self, x):out = self.conv1(x)out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.layer5(out)out = torch.flatten(out,start_dim=1)out = self.fc(out)out = F.sigmoid(out)return out

生成器代码如下:


class Generator(nn.Module):def __init__(self):super().__init__()self.linear = nn.Linear(nz, 64*3*3)self.layer1 = nn.Sequential(BasicBlock(64),nn.UpsamplingNearest2d(scale_factor=2),BasicBlock(64),nn.UpsamplingNearest2d(scale_factor=2),)self.layer2 = nn.Sequential(BasicBlock(64),nn.UpsamplingNearest2d(scale_factor=2),BasicBlock(64),nn.UpsamplingNearest2d(scale_factor=2))self.layer3 = nn.Sequential(BasicBlock(64),BasicBlock(64),nn.UpsamplingNearest2d(scale_factor=2))self.layer4 = nn.Sequential(BasicBlock(64),)self.Conv = nn.Sequential(BasicBlock(64),nn.BatchNorm2d(64),#  nn.LayerNorm([64,96,96]),nn.ReLU(True),nn.Conv2d(64, 3, kernel_size=3, padding=1, stride=1),nn.Tanh())def forward(self, z):x = self.linear(z)x = x.view(batch_size,64,3,3)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.Conv(x)return x

所有代码在我的Github中已经上传:

https://github.com/rabbitdeng/anime-WGAN-resnet-pytorch

readme.md中有所使用的数据集的百度云盘链接!

使用残差网络resnet与WGAN制作一个生成二次元人物头像的GAN(pytorch)相关推荐

  1. 使用SAGAN生成二次元人物头像(GAN生成对抗网络)--pytorch实现

    这是训练250epoch左右的成果. 之前的文章里面,我们使用了残差网络的形式实现生成器与辨别器,它理论上可以实现很不错的效果,但有一个很致命的缺点,就是训练太慢,很难见到成果. 这一次,我们实现了一 ...

  2. 使用残差网络与wgan制作二次元人物头像

    使用残差网络与wgan制作二次元人物头像 ref:https://blog.csdn.net/qq_41103479/article/details/119352714 我复现的项目链接:https: ...

  3. (pytorch-深度学习)实现残差网络(ResNet)

    实现残差网络(ResNet) 我们一般认为,增加神经网络模型的层数,充分训练后的模型理论上能更有效地降低训练误差. 理论上,原模型解的空间只是新模型解的空间的子空间.也就是说,如果我们能将新添加的层训 ...

  4. dlibdotnet 人脸相似度源代码_使用dlib中的深度残差网络(ResNet)实现实时人脸识别 - supersayajin - 博客园...

    opencv中提供的基于haar特征级联进行人脸检测的方法效果非常不好,本文使用dlib中提供的人脸检测方法(使用HOG特征或卷积神经网方法),并使用提供的深度残差网络(ResNet)实现实时人脸识别 ...

  5. 残差网络ResNet

    文章目录 ResNet模型 两个注意点 关于x 关于残差单元 核心实验 原因分析 ResNet的效果 题外话 ResNet是由何凯明在论文Deep Residual Learning for Imag ...

  6. 对残差网络resnet shortcut的解释

    重读残差网络--resnet(对百度vd模型解读) 往事如yan 已于 2022-02-25 07:53:37 修改 652 收藏 4 分类专栏: AI基础 深度学习概念 文章标签: 网络 cnn p ...

  7. 深度残差网络RESNET

    一.残差神经网络--ResNet的综述 深度学习网络的深度对最后的分类和识别的效果有着很大的影响,所以正常想法就是能把网络设计的越深越好, 但是事实上却不是这样,常规的网络的堆叠(plain netw ...

  8. 深度学习目标检测 RCNN F-RCNN SPP yolo-v1 v2 v3 残差网络ResNet MobileNet SqueezeNet ShuffleNet

    深度学习目标检测--结构变化顺序是RCNN->SPP->Fast RCNN->Faster RCNN->YOLO->SSD->YOLO2->Mask RCNN ...

  9. 吴教授的CNN课堂:进阶 | 从LeNet到残差网络(ResNet)和Inception Net

    转载自:https://www.jianshu.com/p/841ac51c7961 第二周是关于卷积网络(CNN)进阶部分,学到挺多新东西.因为之前了解过CNN基础后,就大多在用RNN进行自然语言处 ...

最新文章

  1. 关于javascript数据存储机制的一个案例。
  2. [UWP小白日记-3]记账项目-1
  3. 阿里副总裁肖利华:数智化转型的7个关键词
  4. 第七十四期:国内SaaS企业终于活成了自己讨厌的样子!
  5. markdown 转义字符
  6. 深入理解JAVA中的注解
  7. 分段函数返回字符c语言,C++对cin输入字符的判断及分段函数处理方法示例
  8. MAPGIS提示请在“系统设置”里设置好系统库路径(SUVSLIB或者其他)再重新运行程序
  9. STM32下载Bin文件的几种方式
  10. 抓住七月的尾巴,出门放松一下
  11. 人工智能轨道交通行业周刊-第14期(2022.9.12-9.18)
  12. python学习 Day08 字符串和正则表达式
  13. Deepin20-R7000开启显示器扩展
  14. java dbutils工具类_DbUtils工具类使用
  15. ERD ONline 为企业数字化转型助力
  16. Android App Bundle 自动打包原理
  17. DataFrame添加数据
  18. 哈工大计算机学院考研复试分数线2021,哈工大考研分数线2021什么时候出来?
  19. B1019(数字黑洞)
  20. window查看端口占用

热门文章

  1. redis基础数据类型set(无序不重复集合)
  2. golang处理excel打开csv乱码问题
  3. 微前端(一)微前端是什么?为什么要用微前端?
  4. 华捷艾米a200摄像头_华捷艾米:3D MR打破行业边界,优化产业结构,让生活更美好...
  5. C++ sqlite3 0x00007FF9DA42F621 (ucrtbased.dll)处(位于 TestSqlite.exe 中)引发的异常: 0xC0000005: 读取位置 0x000000
  6. 树莓派连接阿里云物联网平台设备
  7. 一文看懂基于内容的推荐算法
  8. GKFX捷凯官网:gkfx-cn com 八种炒汇交易操作技巧详解
  9. 《Effective Java》阅读笔记7 避免使用终结方法
  10. 实战篇2:假猪套天下第一