最近跑了一下王晋东博士迁移学习简明手册上的深度网络自适应DDC(Deep Domain Confusion)的代码实现,在这里做一下笔记。
来源:Githup开源链接

总结代码的大体框架如下:
1.数据集选择:office31
2.模型选择:Resnet50

3.所用到的.py文件如下图所示:

下面来一个模块一个模块分析:

data_loader.py

from torchvision import datasets, transforms
import torch#参数为 下载数据集的路径、batch_size、布尔型变量判断是否是训练集、数据加载器中的进程数
def load_data(data_folder, batch_size, train, kwargs):transform = {'train': transforms.Compose([transforms.Resize([256, 256]),transforms.RandomCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]),'test': transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])}data = datasets.ImageFolder(root = data_folder, transform=transform['train' if train else 'test'])data_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, **kwargs, drop_last = True if train else False)return data_loader

分析:
这部分代码与我之前写过的的finetune代码中的dataload部分大同小异,具体可参考我的上一篇文章Pytorch_finetune代码解读,这部分主要是处理实验所用的数据,使之可以直接输入到模型,参数在注释里列出。

bckbone.py

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torchvision import models
from torch.autograd import Variable#这里列出的是resnet50的网络
class ResNet50Fc(nn.Module):def __init__(self):super(ResNet50Fc, self).__init__()model_resnet50 = models.resnet50(pretrained=True)self.conv1 = model_resnet50.conv1self.bn1 = model_resnet50.bn1self.relu = model_resnet50.reluself.maxpool = model_resnet50.maxpool#resnet有四个block,每个block的层数分别为layers=[3,4,6,3]self.layer1 = model_resnet50.layer1self.layer2 = model_resnet50.layer2self.layer3 = model_resnet50.layer3self.layer4 = model_resnet50.layer4self.avgpool = model_resnet50.avgpool#获取全连接层的输入特征self.__in_features = model_resnet50.fc.in_featuresdef forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = x.view(x.size(0), -1)return xdef output_num(self):return self.__in_featuresnetwork_dict = {"alexnet": AlexNetFc,"resnet18": ResNet18Fc,"resnet34": ResNet34Fc,"resnet50": ResNet50Fc,"resnet101": ResNet101Fc,"resnet152": ResNet152Fc}

分析:
这部分代码实现了预模型参数的下载,这里给出了多个模型,我们只关注resnet50的模型参数即可,所以我把其他模型的配置删去了。
注意这里需要了解resnet的基本网络架构,参考资料如下:
resnet18 50网络结构以及pytorch实现代码
ResNet网络结构分析
ResNet的pytorch实现与解析

mmd.py

import torch
import torch.nn as nnclass MMD_loss(nn.Module):def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5):super(MMD_loss, self).__init__()self.kernel_num = kernel_numself.kernel_mul = kernel_mulself.fix_sigma = Noneself.kernel_type = kernel_typedef guassian_kernel(self, source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):n_samples = int(source.size()[0]) + int(target.size()[0])total = torch.cat([source, target], dim=0)total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))L2_distance = ((total0-total1)**2).sum(2)if fix_sigma:bandwidth = fix_sigmaelse:bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)bandwidth /= kernel_mul ** (kernel_num // 2)bandwidth_list = [bandwidth * (kernel_mul**i)for i in range(kernel_num)]kernel_val = [torch.exp(-L2_distance / bandwidth_temp)for bandwidth_temp in bandwidth_list]return sum(kernel_val)def linear_mmd2(self, f_of_X, f_of_Y):loss = 0.0delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)loss = delta.dot(delta.T)return lossdef forward(self, source, target):if self.kernel_type == 'linear':return self.linear_mmd2(source, target)elif self.kernel_type == 'rbf':batch_size = int(source.size()[0])kernels = self.guassian_kernel(source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)with torch.no_grad():XX = torch.mean(kernels[:batch_size, :batch_size])YY = torch.mean(kernels[batch_size:, batch_size:])XY = torch.mean(kernels[:batch_size, batch_size:])YX = torch.mean(kernels[batch_size:, :batch_size])loss = torch.mean(XX + YY - XY - YX)torch.cuda.empty_cache()return loss

分析:
这部分代码是深度网络自适应的核心之一,这里使用mmd算法作为自适应度量。
下图loss计算公式中红框标出部分,说明了这一部分代码的作用,作为总的损失函数的一部份组成,主要度量源域和目标域的数据分布是否达到一致。

注意: coral.py模块也是一种度量方法,我们这里使用了mmd方法,coral就不再列出,功能是一样的。

model.py

import torch.nn as nn
from Coral import CORAL
import mmd
import backbone#注意有两个网络流向,adaptation layer跟mmd有关(比较两个网络流向)
#classifier跟网络本身有关
class Transfer_Net(nn.Module):def __init__(self, num_class, base_net='resnet50', transfer_loss='mmd', use_bottleneck=False, bottleneck_width=256, width=1024):super(Transfer_Net, self).__init__()#引入backbone.py#bottleneck 的层,用来将最高维的特征进行降维,然后进行距离计算。#定义网络,获得网络名self.base_network = backbone.network_dict[base_net]()#确定使用bottleneck层self.use_bottleneck = use_bottleneck#定义mmd距离来计算transfer_lossself.transfer_loss = transfer_loss#定义(瓶颈)全连接层、规范化bottleneck_list = [nn.Linear(self.base_network.output_num(), bottleneck_width), nn.BatchNorm1d(bottleneck_width), nn.ReLU(), nn.Dropout(0.5)]#合并进程self.bottleneck_layer = nn.Sequential(*bottleneck_list)# 定义(分类)全连接层、规范化classifier_layer_list = [nn.Linear(self.base_network.output_num(), width), nn.ReLU(), nn.Dropout(0.5),nn.Linear(width, num_class)]# 合并进程self.classifier_layer = nn.Sequential(*classifier_layer_list)#???self.bottleneck_layer[0].weight.data.normal_(0, 0.005)self.bottleneck_layer[0].bias.data.fill_(0.1)for i in range(2):self.classifier_layer[i * 3].weight.data.normal_(0, 0.01)self.classifier_layer[i * 3].bias.data.fill_(0.0)def forward(self, source, target):#选择网络source = self.base_network(source)target = self.base_network(target)#源域的数据进入网络source_clf = self.classifier_layer(source)#是否使用瓶颈层,这里不适用在前面改为Falseif self.use_bottleneck:source = self.bottleneck_layer(source)target = self.bottleneck_layer(target)#加入适应层!!!#分析两个不同网络的距离分布transfer_loss = self.adapt_loss(source, target, self.transfer_loss)return source_clf, transfer_lossdef predict(self, x):features = self.base_network(x)clf = self.classifier_layer(features)return clf#引入mmd,这里参数为源域网络矩阵、目标域矩阵网络矩阵、计算loss的方法def adapt_loss(self, X, Y, adapt_loss):"""Compute adaptation loss, currently we support mmd and coralArguments:X {tensor} -- source matrixY {tensor} -- target matrixadapt_loss {string} -- loss type, 'mmd' or 'coral'. You can add your own lossReturns:[tensor] -- adaptation loss tensor"""if adapt_loss == 'mmd':mmd_loss = mmd.MMD_loss()loss = mmd_loss(X, Y)elif adapt_loss == 'coral':loss = CORAL(X, Y)else:loss = 0return loss

分析:
那么这一部分代码就是整个深度网络自适应算法的最核心之处,实现了源域和目标域的距离的输出。得出了上面图中公式右端的第二个loss的的具体数值。这里bottleneck我们先不要管,只关注self.classifier_layer部分,这一部分是常规的网络训练的构建。下面forward中self.adapt_loss函数是这部分代码核心指出,理解时要以这个函数为核心,向外延申,这部分理解透了,自适应部分也就明白了,后续就是一些常规的模型训练。

utlis.py

class AverageMeter(object):"""Computes and stores the average and current value"""def __init__(self):self.reset()def reset(self):self.val = 0self.avg = 0self.sum = 0self.count = 0def update(self, val, n=1):self.val = valself.sum += val * nself.count += nself.avg = self.sum / self.count

分析:
用来辅助acc的平均loss的计算。

main.py

import argparse
import torch
import os
import data_loader
import models
import utils
import numpy as npDEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
log = []# Command setting
parser = argparse.ArgumentParser(description='DDC_DCORAL')
parser.add_argument('--model', type=str, default='resnet50')
parser.add_argument('--batchsize', type=int, default=32)
parser.add_argument('--src', type=str, default='amazon')
parser.add_argument('--tar', type=str, default='webcam')
parser.add_argument('--n_class', type=int, default=31)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--n_epoch', type=int, default=10)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--decay', type=float, default=5e-4)
parser.add_argument('--data', type=str, default='D:\迁移学习\Original_images')
parser.add_argument('--early_stop', type=int, default=20)
parser.add_argument('--lamb', type=float, default=10)
parser.add_argument('--trans_loss', type=str, default='mmd')
args = parser.parse_args()def test(model, target_test_loader):model.eval()test_loss = utils.AverageMeter()correct = 0criterion = torch.nn.CrossEntropyLoss()len_target_dataset = len(target_test_loader.dataset)with torch.no_grad():for data, target in target_test_loader:data, target = data.to(DEVICE), target.to(DEVICE)s_output = model.predict(data)loss = criterion(s_output, target)test_loss.update(loss.item())pred = torch.max(s_output, 1)[1]correct += torch.sum(pred == target)acc = 100. * correct / len_target_datasetreturn acc#参数:源域数据、目标域数据、测试数据、模型数据、优化器数据
def train(source_loader, target_train_loader, target_test_loader, model, optimizer):len_source_loader = len(source_loader)len_target_loader = len(target_train_loader)best_acc = 0stop = 0for e in range(args.n_epoch):stop += 1#传递计算数据的函数train_loss_clf = utils.AverageMeter()train_loss_transfer = utils.AverageMeter()train_loss_total = utils.AverageMeter()#训练模式model.train()#iter:用来生成迭代器iter_source, iter_target = iter(source_loader), iter(target_train_loader)#定义每次循环的次数n_batch = min(len_source_loader, len_target_loader)#定义损失函数criterion = torch.nn.CrossEntropyLoss()for _ in range(n_batch):#获得数据与标签(target域没有标签)data_source, label_source = iter_source.next()data_target, _ = iter_target.next()#选择设备data_source, label_source = data_source.to(DEVICE), label_source.to(DEVICE)data_target = data_target.to(DEVICE)optimizer.zero_grad()#将数据投入到模型,这里model为forward,因为model参数在main中已经定义完毕label_source_pred, transfer_loss = model(data_source, data_target)#计算网络的lossclf_loss = criterion(label_source_pred, label_source)#核心部分,计算源域网络和目标域网络的loss,由两部分组成#一个是原网络的损失函数,另一个是两个域的mmd距离loss = clf_loss + args.lamb * transfer_lossloss.backward()optimizer.step()train_loss_clf.update(clf_loss.item())train_loss_transfer.update(transfer_loss.item())train_loss_total.update(loss.item())# Test,获取准确率,这里每次训练都测试一下acc = test(model, target_test_loader)log.append([train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg])np_log = np.array(log, dtype=float)np.savetxt('train_log.csv', np_log, delimiter=',', fmt='%.6f')print('Epoch: [{:2d}/{}], cls_loss: {:.4f}, transfer_loss: {:.4f}, total_Loss: {:.4f}, acc: {:.4f}'.format(e, args.n_epoch, train_loss_clf.avg, train_loss_transfer.avg, train_loss_total.avg, acc))if best_acc < acc:best_acc = accstop = 0#连续20次acc无增加,跳出循环if stop >= args.early_stop:breakprint('Transfer result: {:.4f}'.format(best_acc))#下载数据集,参数为源域文件名,目标域文件名,数据所在目录
#于finetune中dataload.py里实现功能一致
def load_data(src, tar, root_dir):folder_src = os.path.join(root_dir, src)folder_tar = os.path.join(root_dir, tar)source_loader = data_loader.load_data(folder_src, args.batchsize, True, {'num_workers': 4})target_train_loader = data_loader.load_data(folder_tar, args.batchsize, True, {'num_workers': 4})target_test_loader = data_loader.load_data(folder_tar, args.batchsize, False, {'num_workers': 4})return source_loader, target_train_loader, target_test_loaderif __name__ == '__main__':torch.manual_seed(0)source_name = "amazon"target_name = "webcam"print('Src: %s, Tar: %s' % (source_name, target_name))source_loader, target_train_loader, target_test_loader = load_data(source_name, target_name, args.data)#网络模型选择,参数为:最后输出类别数(31)、loss距离名(mmd)、网络模型名(resnet50)model = models.Transfer_Net(args.n_class, transfer_loss=args.trans_loss, base_net=args.model).to(DEVICE)#优化器#注意最后训练的层学习率扩大十倍optimizer = torch.optim.SGD([{'params': model.base_network.parameters()},{'params': model.bottleneck_layer.parameters(), 'lr': 10 * args.lr},{'params': model.classifier_layer.parameters(), 'lr': 10 * args.lr},], lr=args.lr, momentum=args.momentum, weight_decay=args.decay)#训练数据+测试train(source_loader, target_train_loader,target_test_loader, model, optimizer)

分析:
对以上模块进行整合,训练网络,得到loss和准确率并打印。

end

Pytorch_DDC(深度网络自适应,以resnet50为例)代码解读相关推荐

  1. 鱼眼图像自监督深度估计原理分析和Omnidet核心代码解读

    作者丨苹果姐@知乎 来源丨https://zhuanlan.zhihu.com/p/508090405 编辑丨3D视觉工坊 在自动驾驶实际应用中,对相机传感器的要求之一是拥有尽可能大的视野范围,鱼眼相 ...

  2. 【三维深度学习】多视角立体视觉 MVSNet代码解读

    MVSNet通过将相机几何参数编码到网络中,实现了端到端的多视角三维重建,并在性能和视觉效果上超越了先前算法,并在eccv2018 oral中发表. 模型主要包含四个主要步骤:图像特征抽取.多视角可微 ...

  3. 深度残差网络+自适应参数化ReLU激活函数(调参记录3)

    续上一篇: 深度残差网络+自适应参数化ReLU激活函数(调参记录2) https://blog.csdn.net/dangqing1988/article/details/105595917 本文继续 ...

  4. 以VGG为例,分析深度网络的计算量和参数量

    本文原载于https://imlogm.github.io,转载请注明出处~ 摘要:我第一次读到ResNet时,完全不敢相信152层的残差网络,竟然在时间复杂度(计算量)上和16层的VGG是一样大的. ...

  5. 深度残差网络+自适应参数化ReLU激活函数(调参记录8)

    续上一篇: 深度残差网络+自适应参数化ReLU激活函数(调参记录7) https://blog.csdn.net/dangqing1988/article/details/105670981 本文将层 ...

  6. input自适应_深度残差网络+自适应参数化ReLU(调参记录18)Cifar10~94.28%

    本文在调参记录17的基础上,将残差模块的数量增加到27个.其实之前也这样做过,现在的区别在于,自适应参数化ReLU激活函数中第一个全连接层中的神经元个数设置成了特征通道数量的1/16.同样是在Cifa ...

  7. 深度网络学习调研报告

     深度网络学习调研报告 目录 1.前言...............................................3 1.1课题研究的背景及意义................. ...

  8. 《深度学习》(美)Ian Goodfellow 花书简要笔记(第二部分:深度网络)

    本部分是目前应用比较成熟的深度学习基础方法.推荐李飞飞老师的CS231n课程(网易云课堂有全部视频和课件,建议把编程作业刷了)配合学习~ 第六章 深度前馈网络 1.我们最好将前馈神经网络想成是为了实现 ...

  9. 清华大学丁霄汉:深度网络重参数化——让你的模型更快更强

    不到现场,照样看最干货的学术报告! 嗨,大家好.这里是学术报告专栏,读芯术小编不定期挑选并亲自跑会,为大家奉献科技领域最优秀的学术报告,为同学们记录报告干货,并想方设法搞到一手的PPT和现场视频--足 ...

最新文章

  1. 257.二叉树的所有路径
  2. 快速排序C实现(阿里巴巴 2012年全国校招笔试题)
  3. C程序验证邮件地址是否真实存在(不是验证邮箱格式)
  4. 文本分类从入门到精通
  5. 在jMeter里如何创建用户定义的能生成随机数的变量
  6. 辽源a货翡翠,张掖a货翡翠
  7. HashMap在Jdk1.7和1.8中的实现
  8. Kvm虚拟化性能测试与性能优化实践
  9. 5.2.7 原子操作的释放函数
  10. mysql报错:Host‘IP地址‘ isblocked because of many connection errors;unblock with ‘mysqladmin flush-hosts‘
  11. VS2015 ASSERT(false)直接退出不弹出Assert failed对话框的解决方法
  12. Linux环境变量配置【转】
  13. 2021年软件设计师考试大纲
  14. MySQL安装配置步骤
  15. 为了让你在“口袋奇兵”聊遍全球,java面试代码题
  16. 短视频平台api接口php源码
  17. win10-SW2016工程图关联零件属性链接操作
  18. 谈谈自己对目前新型冠状病毒疫情的想法
  19. 基于多目标算法的冷热电联供型综合能源系统运行优化 多目标粒子群 冷热电联供 综合能源系统 运行优化
  20. Html-网页调用摄像头并拍照效果

热门文章

  1. PHP工资管理系统、考勤管理系统、薪资管理系统
  2. Ubuntu17.04+Nvidia GT 640LE+CUDA9.0+cuDNN7.05+Tensorflow1.5r0(GPU)+Anaconda5.01(python3.6)安装
  3. Python爬取网页的所有内外链
  4. 启动vscode不打开上次文件夹
  5. Android 自动开关机
  6. 聊天室平台搭建【免费下载 无需积分/C币】java、Android、php多平台聊天室源码打包下载
  7. 伦敦金实时行情在线看
  8. 展锐芯片之GPU频率(一百一十四)
  9. 简单毛概刷题网页制作
  10. 设计模式-04.02-结构型设计模式-门面模式组合模式享元模式