分享一下文章PCL: Proxy-based Contrastive Learning for Domain Generalization,代码已经在GitHub上已经开源,其使用的是在DomainBed框架基础上实现的优化框架SWAD上改进的框架本文主要就放一些精简的代码,对于每个模块只保留了一个算法。
DomainBed库主要是为了DG领域的多种方法的实现,所以框架写的很复杂,封装了很多东西,对初次使用的同学真的很不友好,甚至可能连输入输出都看不懂,如果对DG和DA感兴趣的同学,这里推荐一个大佬实现的DA和DG的库,迁移学习代码库,比较容易看懂!!
话不多说,直接上代码

主文件

main.py

import torch
import algorithm
from torch.autograd import Variable
from torchvision import datasets, transforms
#使用swad调优的话,异步文章最后
#import swa_utils
#import swad as swad_module
import torch.nn as nntrain_transforms= transforms.Compose([transforms.Resize(256),transforms.RandomRotation((5), expand=True),transforms.CenterCrop(224),transforms.ToTensor(),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.ColorJitter(.3, .3, .3, .3),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
val_dataTrans = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])train_data_dir = '../../data/train'
val_data_dir = '../../data/val'
test_data_dir='../../data/test'train_dataset = datasets.ImageFolder(train_data_dir, train_transforms)
val_dataset = datasets.ImageFolder(val_data_dir,val_dataTrans)
test_dataset = datasets.ImageFolder(test_data_dir, val_dataTrans)
#根据需要可重新划分数据集
# train_dataset = torch.utils.data.ConcatDataset([train_dataset1, test_dataset])
# val_dataset = datasets.ImageFolder(val_data_dir, _dataTrans)
# val_dataset, val_dataset_ = torch.utils.data.random_split(val_dataset, [5, len(val_dataset) - 5])
# train_dataset = torch.utils.data.ConcatDataset([train_dataset, val_dataset_])train_dataloder = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True)
val_dataloder = torch.utils.data.DataLoader(val_dataset,batch_size=4,shuffle=True)
device = "cuda" if torch.cuda.is_available() else"cpu"
# setup hparams
algorithm = algorithm.ERM(input_shape=[3, 244, 244], num_classes=4)
use_swad=False#是否使用swad优化
#优化器的选择在algorithm文件里
if use_swad:swad_algorithm = swa_utils.AveragedModel(algorithm)swad_cls = getattr(swad_module, 'LossValley')swad_kwargs={'n_converge': 3, 'n_tolerance': 6, 'tolerance_ratio': 0.3}swad = swad_cls(**swad_kwargs)
algorithm.to(device)
lossfunc=nn.CrossEntropyLoss()
epochs=10
if __name__ == '__main__':for epoch in range(epochs):running_loss = 0running_corrects = 0algorithm.train()for step ,(inputs, labels) in enumerate(train_dataloder):inputs = Variable(inputs.cuda())labels = Variable(labels.cuda())step_vals = algorithm.update(inputs, labels)if use_swad:swad_algorithm.update_parameters(algorithm, step=step)_, outputs = algorithm.predict(inputs)_, preds = torch.max(outputs.data, 1)train_loss = lossfunc(outputs, labels)# statisticsrunning_loss += loss.datatrain_acc=torch.sum(preds == labels.data).cpu().to(torch.float32)running_corrects += train_acctr_epoch_loss = running_loss / len(train_dataloder)tr_epoch_acc = running_corrects / len(train_dataloder)print('{} Loss: {:.4f} Acc: {:.4f}'.format(epoch, tr_epoch_loss, tr_epoch_acc))with torch.no_grad():algorithm.eval()running_loss = 0running_corrects = 0for step, (inputs, labels) in enumerate(val_dataloder ):inputs = Variable(inputs.cuda())labels = Variable(labels.cuda())_, outputs = algorithm.predict(inputs)_, preds = torch.max(outputs.data, 1)loss = lossfunc(outputs, labels)# statisticsrunning_loss += loss.dataval_acc = torch.sum(preds == labels.data).cpu().to(torch.float32)running_corrects += val_accte_epoch_loss = running_loss / len(val_dataloder)te_epoch_acc = running_corrects / len(val_dataloder)if use_swad:swad.update_and_evaluate(swad_algorithm, te_epoch_acc)swad_algorithm = swa_utils.AveragedModel(algorithm)  # resetfilename = r'epoch{}_Loss{:.4f}_Acc{:.4f}_Loss{:.4f}_Acc{:.4f}.pth'.format(epoch, tr_epoch_loss, tr_epoch_acc, te_epoch_loss, te_epoch_acc)torch.save(algorithm.state_dict(), filename, _use_new_zipfile_serialization=False)

主算法

主算法部分,这里使用的是Empirical Risk Minimization (ERM, Vapnik, 1998),原DomainBed框架提供了很多算法,如IRM、GroupDRO、RSC等,可以根据需要自行取用。
————————————————————
algorithm.py

import math
from model import *
from losses import ProxyLoss, ProxyPLoss
import torchclass ERM(torch.nn.Module):"""Empirical Risk Minimization (ERM)"""def __init__(self, input_shape, num_classes):super(ERM, self).__init__()self.encoder, self.scale, self.pcl_weights = encoder()self._initialize_weights(self.encoder)self.fea_proj, self.fc_proj = fea_proj()nn.init.kaiming_uniform_(self.fc_proj, mode='fan_out', a=math.sqrt(5))self.featurizer = ResNet()self.classifier = nn.Parameter(torch.FloatTensor(num_classes,256))nn.init.kaiming_uniform_(self.classifier, mode='fan_out', a=math.sqrt(5))self.optimizer = torch.optim.Adam([{'params': self.featurizer.parameters()},{'params': self.encoder.parameters()},{'params': self.fea_proj.parameters()},{'params': self.fc_proj},{'params': self.classifier},], lr=0.002, weight_decay=0.0)self.proxycloss = ProxyPLoss(num_classes=num_classes, scale=self.scale)def _initialize_weights(self, modules):for m in modules:if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, math.sqrt(2. / n))if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()elif isinstance(m, nn.Linear):n = m.weight.size(1)m.weight.data.normal_(0, 0.01)m.bias.data.zero_()def update(self, x, y, **kwargs):all_x = xall_y = yrep, pred = self.predict(all_x)loss_cls = F.nll_loss(F.log_softmax(pred, dim=1), all_y)fc_proj = F.linear(self.classifier, self.fc_proj)assert fc_proj.requires_grad == Trueloss_pcl = self.proxycloss(rep, all_y, fc_proj)loss = loss_cls + self.pcl_weights * loss_pclself.optimizer.zero_grad()loss.backward()self.optimizer.step()return {"loss_cls": loss_cls.item(), "loss_pcl": loss_pcl.item()}def predict(self, x):x = self.featurizer(x)x = self.encoder(x)rep = self.fea_proj(x)pred = F.linear(x, self.classifier)return rep, pred

网络结构

接下来是主要的网络结构模块,这里是用的ResNet50进行图片的特征提取,然后用全连接层进行encoder
————————————————————
model.py


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.modelsclass Identity(nn.Module):"""An identity layer"""def __init__(self):super(Identity, self).__init__()def forward(self, x):return x
class SqueezeLastTwo(nn.Module):"""A module which squeezes the last two dimensions,ordinary squeeze can be a problem for batch size 1"""def __init__(self):super(SqueezeLastTwo, self).__init__()def forward(self, x):return x.view(x.shape[0], x.shape[1])
class ResNet(torch.nn.Module):"""ResNet with the softmax chopped off and the batchnorm frozen"""def __init__(self):super(ResNet, self).__init__()#如果要用其他的网络进行特征提取,可以在这里改#但是要把下面encoder模块的全连接的层的输入和新的网络最后的全连接输出相同network = torchvision.models.resnet50(pretrained=False)# network = resnet50(pretrained=hparams["pretrained"])self.network = network# adapt number of channels# save memory# del self.network.fc#把新的网络的输出层替换为空,用来提供encoder的接口#tips;最后一层大部分都是model.fc或model.headself.network.fc = Identity()self.dropout = nn.Dropout(0.1)self.freeze_bn()def forward(self, x):"""Encode x into a feature vector of size n_outputs."""return self.dropout(self.network(x))def train(self, mode=True):"""Override the default train() to freeze the BN parameters"""super().train(mode)self.freeze_bn()def freeze_bn(self):for m in self.network.modules():if isinstance(m, nn.BatchNorm2d):m.eval()def encoder():scale_weights = 12pcl_weights = 1dropout = nn.Dropout(0.25)hidden_size = 512out_dim = 256#换了新网络要注意改这里n_outputs = 2048encoder = nn.Sequential(nn.Linear(n_outputs, hidden_size),nn.BatchNorm1d(hidden_size),nn.ReLU(inplace=True),dropout,nn.Linear(hidden_size, out_dim),)return encoder, scale_weights, pcl_weightsdef fea_proj():dropout = nn.Dropout(0.25)hidden_size = 256out_dim = 256fea_proj = nn.Sequential(nn.Linear(out_dim,out_dim),)fc_proj = nn.Parameter(torch.FloatTensor(out_dim,out_dim))return fea_proj, fc_proj

损失函数

PCL中提到的损失函数
————————————————————————
losses.py

# coding: utf-8'''
custom loss function
'''import math
import numpy as npimport torch
import torch.nn as nnimport torch.nn.functional as F# # =========================  proxy Contrastive loss ==========================
class ProxyLoss(nn.Module):'''pass'''def __init__(self, scale=1, thres=0.1):super(ProxyLoss, self).__init__()self.scale = scaleself.thres = thresdef forward(self, feature, pred, target):feature = F.normalize(feature, p=2, dim=1)  # normalizefeature = torch.matmul(feature, feature.transpose(1, 0))  # (B, B)label_matrix = target.unsqueeze(1) == target.unsqueeze(0)feature = feature * ~label_matrix  # get negative matrixfeature = feature.masked_fill(feature < self.thres, -np.inf)pred = torch.cat([pred, feature], dim=1)  # (N, C+N)loss = F.nll_loss(F.log_softmax(self.scale * pred, dim=1), \target)return lossclass ProxyPLoss(nn.Module):'''pass'''def __init__(self, num_classes, scale):super(ProxyPLoss, self).__init__()self.soft_plus = nn.Softplus()self.label = torch.LongTensor([i for i in range(num_classes)]).cuda()self.scale = scaledef forward(self, feature, target, proxy):feature = F.normalize(feature, p=2, dim=1)pred = F.linear(feature, F.normalize(proxy, p=2, dim=1))  # (N, C)label = (self.label.unsqueeze(1) == target.unsqueeze(0))pred = torch.masked_select(pred.transpose(1, 0), label)pred = pred.unsqueeze(1)feature = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)label_matrix = target.unsqueeze(1) == target.unsqueeze(0)index_label = torch.LongTensor([i for i in range(feature.shape[0])])  # generate index labelindex_matrix = index_label.unsqueeze(1) == index_label.unsqueeze(0)  # get index matrixfeature = feature * ~label_matrix  # get negative matrixfeature = feature.masked_fill(feature < 1e-6, -np.inf)logits = torch.cat([pred, feature], dim=1)  # (N, C+N)label = torch.zeros(logits.size(0), dtype=torch.long).cuda()loss = F.nll_loss(F.log_softmax(self.scale * logits, dim=1), label)return lossclass PosAlign(nn.Module):'''pass'''def __init__(self):super(PosAlign, self).__init__()self.soft_plus = nn.Softplus()def forward(self, feature, target):feature = F.normalize(feature, p=2, dim=1)feature = torch.matmul(feature, feature.transpose(1, 0))  # (N, N)label_matrix = target.unsqueeze(1) == target.unsqueeze(0)positive_pair = torch.masked_select(feature, label_matrix)# print("positive_pair.shape", positive_pair.shape)loss = 1. * self.soft_plus(torch.logsumexp(positive_pair, 0))return loss

SWAD调参

如果要使用SWAD调参,可以从去GitHub自取swad和swa_utils。

以后有时间再逐行分析吧…

关于DG(域泛化)领域的PCL方法的代码实例相关推荐

  1. css如何设置透明度?设置透明度的两种方法(代码实例)

    在前端页面开发布局的时候,为了给用户呈现不同的效果,经常需要设置透明度,那么css是怎样设置透明度的?本章给大家介绍用css设置透明度的两种方法(代码实例).有一定的参考价值,有需要的朋友可以参考一下 ...

  2. php绘制一个三角形,如何利用css或html5画出一个三角形?两种不同的制作三角形方法(代码实例)...

    我们在平时的前端开发的时候,有时候是需要一些小图形来丰富一下页面效果,比如:下拉列表的倒三角图形.那么这样的一个三角形是如何制作出来的,本章给大家介绍如何利用css或html画出一个三角形?两种不同的 ...

  3. h5画三角形_如何利用css或html5画出一个三角形?两种不同的制作三角形方法(代码实例)...

    我们在平时的前端开发的时候,有时候是需要一些小图形来丰富一下页面效果,比如:下拉列表的倒三角图形.那么这样的一个三角形是如何制作出来的,本章给大家介绍如何利用css或html画出一个三角形?两种不同的 ...

  4. html中怎么写正六边形,如何用css画正六边形?用css画正六边形的两种方法(代码实例)...

    本章给大家介绍如何用css画正六边形?用css画正六边形的两种方法(代码实例).有一定的参考价值,有需要的朋友可以参考一下,希望对你有所帮助. 在之前要先了解一下正六边形内角和边的关系,正六边形的每个 ...

  5. 数组添加元素的方法PHP,JavaScript如何给数组添加元素?js数组添加元素的3种方法(代码实例)...

    数组是JavaScrip中中一个比较重要的部分,在学习js数组时,数组元素的操作是不可缺少的部分,那么你知道数组元素如何添加吗?本篇文章就给大家介绍如何往js数组(一维)中添加元素,让大家了解往js数 ...

  6. JS字符串转换成数字的三种经典方法和代码实例

    1. 转换函数: js提供了parseInt()和parseFloat()两个转换函数.前者把值转换成整数,后者把值转换成浮点数.只有对String类型调用这些方法,这两个函数才能正确运行: 对其他类 ...

  7. 域适应(DA)---域泛化(DG)

    域泛化数据 分类 PACS数据集 PACS\VLCS\office-home 提取码:tmid ImageNet-C 衡量分类器对损坏的鲁棒性,包含15种corruption,可分为noise(+3) ...

  8. 【域泛化】2022 IJCAI领域泛化教程报告

    目录 1.简介 2.域泛化 背景 AI背景 ​编辑 域适应背景 DA存在问题

  9. 如何通过Meta Learning实现域泛化(Domain Generalization)?

    ©作者 | 丘明姗 单位 | 华南理工大学 研究方向 | 领域泛化 域泛化(Domain Generalization)中有很多工作是用 meta learning 做的.Meta learning ...

最新文章

  1. php点号的意思,[PHP] - 逗号和点号的区别
  2. a prefect storm歌词_Storm s Perfect Storm歌词
  3. 首先声明两者所要实现的功能是一致的(将多维数组降位一维)。这点从两个单词的意也可以看出来,ravel(散开,解开),flatten(变平)。两者的区别在于返回拷贝(copy)还是返回视图(view)
  4. IM热门功能讨论:为什么微信里没有消息“已读”功能?...
  5. Netty之实现一个简单的群聊系统
  6. 御龙在天手游怎么不显示服务器了,御龙在天手游进不去怎么办 闪退原因及解决办法...
  7. ecshop 标签使用 非常好的例子
  8. 最详细的YOLO论文笔记
  9. SQL中的「规则」 constraint 与「约束」 rule 的区别。
  10. 计算机网络苏州大学题库,苏州大学计算机网络样卷B[计科大类].doc
  11. 适合初学者 :用Google map street view api 实现批量下载谷歌地图街景 in python
  12. 计算机ipv4地址修改方法,电脑ip地址的修改方法步骤图
  13. Intel正式发布新一代Atom处理器
  14. B/S系统界面设计与开发详解
  15. Windows应用程序安装向导制作
  16. Microsoft XBOX 360 Project Natal 体感装置2010年6月15正式发布产品正式命名为“Kinect”
  17. 趣图:代码突然又可以运行了,why?
  18. 面试题目之:说出至少4种vue当中的指令和它的用法?
  19. python的super函数详解
  20. AAPT: error: resource drawable...not found.

热门文章

  1. web3 js 连接 metamask 获取账户信息 web3.eth.getAccounts 为空
  2. 渗透测试技术----提权(第三方提权和WCE)
  3. servlet配置信息url-patten的三种匹配形式
  4. 017 如何学习仓储物流自动化知识
  5. Dis-PU复现踩坑
  6. Vue可视化大屏(vue+datav)纯前端
  7. 高盛表示将放弃建立加密货币交易部门的计划
  8. vm安装xp系统提示system not found,一直网卡形式启动
  9. c语言程序设计证书有没有,从未学习过c语言程序设计,10天考取计算机二级c语言程序设计证书可能吗?...
  10. 有1,2,3,4,5,6,7,8,9一共九个数,能组成多少个互不相同且不重复的四位数,分别是多少?