Code: GitHub - aassxun/SEMICON

Environment setup

# Create and activate the virtual environment
conda create -n semicon python==3.8.5
conda activate semicon
# Install the dependencies (this PyTorch build is CPU-only)
conda install pytorch==1.10.0 torchvision==0.11.1 torchaudio==0.10.0 cpuonly -c pytorch
pip install numpy==1.19.2
pip install loguru==0.5.3
pip install tqdm==4.54.1
pip install pandas
pip install scipy

In SEMICON_train.py, SEMICON.py, Hash_mAP.py, baseline_train.py and baseline.py, change import models.resnet as resnet and from models.resnet import * to import models.resnet_torch as resnet and from models.resnet_torch import *.
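The edit above is the same two substitutions repeated in five files, so it can be scripted. The following is a minimal sketch, not part of the repository: the helper name patch_imports is hypothetical, and the demo runs on a throwaway file because the exact location of the five files inside a checkout is not stated here.

```python
import tempfile
from pathlib import Path

# The two substitutions described above (old text -> new text).
REPLACEMENTS = [
    ("import models.resnet as resnet", "import models.resnet_torch as resnet"),
    ("from models.resnet import *", "from models.resnet_torch import *"),
]


def patch_imports(path):
    """Rewrite the resnet imports in one file; return True if the file changed."""
    path = Path(path)
    original = path.read_text(encoding="utf-8")
    patched = original
    for old, new in REPLACEMENTS:
        patched = patched.replace(old, new)
    if patched != original:
        path.write_text(patched, encoding="utf-8")
    return patched != original


# Demo on a temporary file so that nothing in a real checkout is touched.
with tempfile.TemporaryDirectory() as tmp:
    sample = Path(tmp) / "SEMICON_train.py"
    sample.write_text(
        "import models.resnet as resnet\nfrom models.resnet import *\n",
        encoding="utf-8",
    )
    changed_first = patch_imports(sample)
    changed_again = patch_imports(sample)  # a second run finds nothing to replace
    result = sample.read_text(encoding="utf-8")

print(result)
```

To use it on the real project, call patch_imports once for each of the five files, using whatever paths they have in your checkout.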

Download the CUB_200_2011 dataset

Reference blog: "Downloading the CUB-200-2011 bird dataset and loading it with PyTorch" (景唯acr's blog, CSDN)

Running the code

1) Training

python run.py --dataset cub-2011 --root /dataset/CUB2011/CUB_200_2011 --max-epoch 30 --gpu 0 --arch semicon --batch-size 16 --max-iter 40 --code-length 12,24,32,48 --lr 2.5e-4 --wd 1e-4 --optim SGD --lr-step 40 --num-samples 2000 --info 'CUB-SEMICON' --momen=0.91

2) Testing

python run.py --dataset cub-2011 --root /dataset/CUB2011/CUB_200_2011 --gpu 0 --arch test --batch-size 16 --code-length 12,24,32,48 --wd 1e-4 --info 'CUB-SEMICON'

To run without a GPU, set the --gpu argument to False.

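Code walkthrough

1) Fixing the random seed

As in YOLO-X, the random seed is fixed so that all later experiments (ablation studies, for example) run under the same seed, which improves reproducibility. (In my view this only guarantees reproducibility for that particular seed; a different seed may well give different results.)

torch.backends.cudnn.deterministic and torch.backends.cudnn.benchmark: the former makes cuDNN use deterministic convolution algorithms, so the same input gives the same output on every run; the latter lets cuDNN search for the fastest implementation of each convolution layer, which speeds up the network. benchmark is appropriate when the network structure is fixed and the input shape (batch size, image size, number of input channels) does not change, which covers most ordinary cases. Conversely, if the convolution configurations keep changing, the program re-tunes over and over and ends up slower. Note that the function below sets both flags to True; when strict run-to-run determinism is needed, PyTorch's reproducibility notes recommend benchmark = False, because the benchmarking step may select different algorithms on different runs.

Reference blogs:

【pytorch】torch.backends.cudnn.deterministic (Xhfei1224's blog, CSDN)

torch.backends.cudnn.benchmark (Wanderer001's blog, CSDN)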

import os
import random

import numpy as np
import torch


def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True  # as in the repository


seed_everything(68)
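The effect of fixing a seed can be checked directly. This is a small sketch that seeds only the Python and NumPy generators so that it runs without PyTorch; torch.manual_seed plays the same role for PyTorch's generator. The helper name draw is made up for the illustration.

```python
import random

import numpy as np


def draw(seed, n=3):
    """Seed the Python and NumPy generators, then draw n numbers from each."""
    random.seed(seed)
    np.random.seed(seed)
    return [random.random() for _ in range(n)], np.random.rand(n)


py_a, np_a = draw(68)
py_b, np_b = draw(68)  # same seed: the sequences repeat
py_c, np_c = draw(69)  # different seed: the sequences differ

print(py_a == py_b, np.array_equal(np_a, np_b))
print(py_a == py_c, np.array_equal(np_a, np_c))
```

This also illustrates the caveat raised above: reproducibility holds per seed, and nothing says that results obtained under seed 68 carry over to seed 69.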

2) Data loading

Script: data/cub_2011.py

The function load_data() returns three dataloaders: query_dataloader, train_dataloader and retrieval_dataloader.

# Split into training and test sets
Cub2011.init(root)
# Define the query, training and retrieval datasets; the augmentations live in data/transform.py
query_dataset = Cub2011(root, 'query', query_transform())
train_dataset = Cub2011(root, 'train', train_transform())
retrieval_dataset = Cub2011(root, 'retrieval', query_transform())


class Cub2011(Dataset):

    def __init__(self, root, mode, transform=None, loader=default_loader):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.loader = default_loader

        if mode == 'train':
            self.data = Cub2011.TRAIN_DATA
            self.targets = Cub2011.TRAIN_TARGETS
        elif mode == 'query':
            self.data = Cub2011.QUERY_DATA
            self.targets = Cub2011.QUERY_TARGETS
        elif mode == 'retrieval':
            self.data = Cub2011.RETRIEVAL_DATA
            self.targets = Cub2011.RETRIEVAL_TARGETS
        else:
            raise ValueError(r'Invalid arguments: mode, can\'t load dataset!')

    @staticmethod
    def init(root):
        images = pd.read_csv(os.path.join(root, 'images.txt'), sep=' ',
                             names=['img_id', 'filepath'])
        image_class_labels = pd.read_csv(os.path.join(root, 'image_class_labels.txt'),
                                         sep=' ', names=['img_id', 'target'])
        train_test_split = pd.read_csv(os.path.join(root, 'train_test_split.txt'),
                                       sep=' ', names=['img_id', 'is_training_img'])

        data = images.merge(image_class_labels, on='img_id')
        all_data = data.merge(train_test_split, on='img_id')
        all_data['filepath'] = 'images/' + all_data['filepath']
        train_data = all_data[all_data['is_training_img'] == 1]
        test_data = all_data[all_data['is_training_img'] == 0]

        # Split dataset
        Cub2011.QUERY_DATA = test_data['filepath'].to_numpy()
        Cub2011.QUERY_TARGETS = encode_onehot((test_data['target'] - 1).tolist(), 200)

        Cub2011.TRAIN_DATA = train_data['filepath'].to_numpy()
        Cub2011.TRAIN_TARGETS = encode_onehot((train_data['target'] - 1).tolist(), 200)

        Cub2011.RETRIEVAL_DATA = train_data['filepath'].to_numpy()
        Cub2011.RETRIEVAL_TARGETS = encode_onehot((train_data['target'] - 1).tolist(), 200)

    def get_onehot_targets(self):
        return torch.from_numpy(self.targets).float()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = Image.open(os.path.join(self.root, self.data[idx])).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[idx], idx
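The split logic in init() relies on a helper, encode_onehot, whose body is not shown in these notes. The sketch below is an assumed NumPy version of such a helper, applied to a four-image toy stand-in for image_class_labels.txt and train_test_split.txt; the dictionaries and their values are invented for the illustration.

```python
import numpy as np


def encode_onehot(labels, num_classes):
    """Turn integer class ids into a (len(labels), num_classes) 0/1 matrix."""
    onehot = np.zeros((len(labels), num_classes), dtype=np.float32)
    onehot[np.arange(len(labels)), labels] = 1.0
    return onehot


# Toy stand-ins for the dataset files:
# img_id -> 1-based class id, and img_id -> is_training_img flag.
class_of = {1: 1, 2: 1, 3: 2, 4: 3}
is_train = {1: 1, 2: 0, 3: 1, 4: 0}

train_ids = [i for i in class_of if is_train[i] == 1]
query_ids = [i for i in class_of if is_train[i] == 0]

# Class ids in the files start at 1, hence the "- 1" before encoding.
train_targets = encode_onehot([class_of[i] - 1 for i in train_ids], 3)
query_targets = encode_onehot([class_of[i] - 1 for i in query_ids], 3)

print(train_ids, query_ids)
print(train_targets)
```

As in the class above, the images flagged as training images serve both as the training set and as the retrieval database, while the test images become the queries.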

3) Network training

Backbone

Only the first three stages (layer1 to layer3) of resnet50 are used here; see the ResNet_Backbone class in models/SEMICON.py.

model = ResNet_Backbone(Bottleneck, [3, 4, 6], **kwargs)
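To see what a three-stage backbone hands to the later modules, the feature-map sizes can be worked out by hand. The sketch below assumes a 224×224 input and the standard ResNet-50 layer hyperparameters (7×7 stride-2 stem convolution, 3×3 stride-2 max pooling, Bottleneck expansion of 4); neither assumption is stated in these notes, although the 14×14 size agrees with the shape comments in the SEM code.

```python
def conv_out(size, kernel, stride, padding):
    """Spatial output size of a convolution or pooling layer."""
    return (size + 2 * padding - kernel) // stride + 1


def backbone_shapes(input_size=224):
    """Trace (channels, height == width) through the ResNet-50 stem and stages."""
    shapes = {}
    size = conv_out(input_size, kernel=7, stride=2, padding=3)  # conv1
    size = conv_out(size, kernel=3, stride=2, padding=1)        # maxpool
    shapes["stem"] = (64, size)
    # (name, bottleneck width, stride of the first block in the stage)
    stages = [("layer1", 64, 1), ("layer2", 128, 2),
              ("layer3", 256, 2), ("layer4", 512, 2)]
    for name, planes, stride in stages:
        size = conv_out(size, kernel=3, stride=stride, padding=1)
        shapes[name] = (planes * 4, size)  # Bottleneck.expansion is 4
    return shapes


shapes = backbone_shapes()
for name, (channels, size) in shapes.items():
    print(f"{name}: {channels} x {size} x {size}")
```

Under these assumptions the backbone stops at layer3 with a 1024×14×14 map, which is why the modules that follow start from inplanes = 1024, and a stride-2 layer4 then gives the 2048-dimensional feature used at test time.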

Global/local transformation network

class ResNet_Refine(nn.Module):

    def __init__(self, block, layer, is_local=True, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None):
        super(ResNet_Refine, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 1024
        self.dilation = 1
        self.is_local = is_local
        self.groups = groups
        self.base_width = width_per_group
        self.layer4 = self._make_layer(block, 512, layer, stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))
        layers.append(ChannelTransformer(planes * block.expansion, max(planes * block.expansion // 64, 16)))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        x = self.layer4(x)

        pool_x = self.avgpool(x)
        pool_x = torch.flatten(pool_x, 1)
        if self.is_local:
            return x, pool_x
        else:
            return pool_x

    def forward(self, x):
        return self._forward_impl(x)

SEM

class SEM(nn.Module):

    def __init__(self, block, layer, att_size=4, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(SEM, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 1024
        self.dilation = 1
        self.att_size = att_size
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.layer4 = self._make_layer(block, 512, layer, stride=1)

        self.feature1 = nn.Sequential(
            conv1x1(self.inplanes, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True),
        )
        self.feature2 = nn.Sequential(
            conv1x1(self.inplanes, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )
        self.feature3 = nn.Sequential(
            conv1x1(self.inplanes, 1),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        att_expansion = 0.25
        layers = []
        layers.append(block(self.inplanes, int(self.inplanes * att_expansion), stride,
                            downsample, self.groups, self.base_width, previous_dilation, norm_layer))
        for _ in range(1, blocks):
            layers.append(nn.Sequential(
                conv1x1(self.inplanes, int(self.inplanes * att_expansion)),
                nn.BatchNorm2d(int(self.inplanes * att_expansion))
            ))
            self.inplanes = int(self.inplanes * att_expansion)
            layers.append(block(self.inplanes, int(self.inplanes * att_expansion), groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _mask(self, feature, x):
        with torch.no_grad():
            cam1 = feature.mean(1)
            attn = torch.softmax(cam1.view(x.shape[0], x.shape[2] * x.shape[3]), dim=1)  # B,H,W
            std, mean = torch.std_mean(attn)
            attn = (attn - mean) / (std ** 0.3) + 1  # 0.15
            attn = (attn.view((x.shape[0], 1, x.shape[2], x.shape[3]))).clamp(0, 2)
        return attn

    def _forward_impl(self, x):
        x = self.layer4(x)  # bs*64*14*14
        fea1 = self.feature1(x)  # bs*1*14*14
        attn = 2 - self._mask(fea1, x)

        x = x.mul(attn.repeat(1, self.inplanes, 1, 1))
        fea2 = self.feature2(x)
        attn = 2 - self._mask(fea2, x)

        x = x.mul(attn.repeat(1, self.inplanes, 1, 1))
        fea3 = self.feature3(x)

        x = torch.cat([fea1, fea2, fea3], dim=1)
        return x

    def forward(self, x):
        return self._forward_impl(x)
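The suppression step is the part of SEM that is easiest to misread: _mask builds a weight in [0, 2] that is large where the current attention map fires, and _forward_impl multiplies the features by 2 minus that weight, so regions already found are damped before the next attention map is computed. The following is a NumPy re-statement of _mask for illustration only; the function name suppression_mask and the random input are made up.

```python
import numpy as np


def suppression_mask(feature):
    """NumPy re-statement of SEM._mask for a (B, C, H, W) activation array."""
    b, _, h, w = feature.shape
    cam = feature.mean(axis=1).reshape(b, h * w)
    # Softmax over the spatial positions of each sample.
    exp = np.exp(cam - cam.max(axis=1, keepdims=True))
    attn = exp / exp.sum(axis=1, keepdims=True)
    # torch.std_mean reduces over the whole tensor and uses the unbiased std.
    mean, std = attn.mean(), attn.std(ddof=1)
    attn = (attn - mean) / (std ** 0.3) + 1.0
    return np.clip(attn.reshape(b, 1, h, w), 0.0, 2.0)


rng = np.random.default_rng(0)
fea1 = rng.random((2, 1, 14, 14))  # stand-in for self.feature1(x)
mask = suppression_mask(fea1)
weights = 2.0 - mask               # what _forward_impl multiplies x by

# The most strongly activated position of sample 0 receives the smallest weight.
peak = np.unravel_index(fea1[0, 0].argmax(), fea1[0, 0].shape)
print(weights.shape, weights[0, 0][peak], weights[0, 0].min())
```

Because every step in the mask is monotone in the activation, the position with the highest response always ends up with the lowest multiplier, which is what pushes feature2 and feature3 toward other parts of the image.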

ICON

class ChannelTransformer(nn.Module):

    def __init__(self, dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.head_dim = head_dim
        self.norm = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU(inplace=True)
        self.qkv = nn.Conv2d(dim, dim * 3, 1, groups=num_heads)
        self.qkv2 = nn.Conv2d(dim, dim * 3, 1, groups=head_dim)

    def forward(self, x):
        B, C, H, W = x.shape
        qkv = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, H * W).transpose(0, 1)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = torch.sign(attn) * torch.sqrt(torch.abs(attn) + 1e-5)
        attn = attn.softmax(dim=-1)

        x = ((attn @ v).reshape(B, C, H, W) + x).reshape(B, self.num_heads, self.head_dim, H, W).transpose(1, 2).reshape(B, C, H, W)
        y = self.norm(x)
        x = self.relu(y)
        qkv2 = self.qkv2(x).reshape(B, 3, self.head_dim, self.num_heads, H * W).transpose(0, 1)
        q, k, v = qkv2[0], qkv2[1], qkv2[2]

        attn = (q @ k.transpose(-2, -1)) * (self.num_heads ** -0.5)
        attn = torch.sign(attn) * torch.sqrt(torch.abs(attn) + 1e-5)
        attn = attn.softmax(dim=-1)
        x = (attn @ v).reshape(B, self.head_dim, self.num_heads, H, W).transpose(1, 2).reshape(B, C, H, W) + y
        return x
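Unlike the usual spatial self-attention, the attention matrices in ChannelTransformer are computed between channels: q @ k.transpose(-2, -1) contracts over the H*W axis, so each group gets a channels-by-channels matrix. The NumPy sketch below reproduces one such stage on random data to make the shapes visible; the function names and sizes are chosen for the illustration and are not taken from the repository.

```python
import numpy as np


def signed_sqrt(a, eps=1e-5):
    """Sign-preserving square root, as applied to the raw attention scores."""
    return np.sign(a) * np.sqrt(np.abs(a) + eps)


def softmax(a, axis=-1):
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def channel_attention(q, k, v):
    """One attention stage; q, k, v have shape (B, groups, channels_per_group, H*W)."""
    scale = q.shape[2] ** -0.5                    # channels_per_group ** -0.5
    attn = (q @ np.swapaxes(k, -2, -1)) * scale   # (B, groups, c, c): channel vs channel
    attn = softmax(signed_sqrt(attn), axis=-1)
    return attn, attn @ v                         # the output keeps the shape of v


rng = np.random.default_rng(0)
B, groups, c, hw = 2, 4, 8, 49
q, k, v = (rng.standard_normal((B, groups, c, hw)) for _ in range(3))
attn, out = channel_attention(q, k, v)
print(attn.shape, out.shape)
```

The second stage in the class follows the same pattern after swapping the roles of num_heads and head_dim, which regroups the channels so that information can also flow between the groups of the first stage.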

Loss function

class ADSH_Loss(nn.Module):

    def __init__(self, code_length, gamma):
        super(ADSH_Loss, self).__init__()
        self.code_length = code_length
        self.gamma = gamma

    def forward(self, F, B, S, omega):
        hash_loss = ((self.code_length * S - F @ B.t()) ** 2).sum() / (F.shape[0] * B.shape[0]) / self.code_length * 12
        quantization_loss = ((F - B[omega, :]) ** 2).sum() / (F.shape[0] * B.shape[0]) * self.gamma / self.code_length * 12
        loss = hash_loss + quantization_loss
        return loss, hash_loss, quantization_loss
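A tiny worked example helps to read the two terms. Below, the same arithmetic is written in NumPy and evaluated on a hand-made case: two classes whose database codes are all +1 and all -1, with S set to +1 for same-class pairs and -1 otherwise (the convention of the ADSH paper). The value gamma = 200 and all array contents are illustrative assumptions, not the repository's settings.

```python
import numpy as np


def adsh_loss(F, B, S, omega, code_length, gamma):
    """Same arithmetic as ADSH_Loss.forward, written with NumPy arrays."""
    n_pairs = F.shape[0] * B.shape[0]
    hash_loss = ((code_length * S - F @ B.T) ** 2).sum() / n_pairs / code_length * 12
    quantization_loss = ((F - B[omega, :]) ** 2).sum() / n_pairs * gamma / code_length * 12
    return hash_loss + quantization_loss, hash_loss, quantization_loss


code_length, gamma = 12, 200
db_labels = np.array([0, 0, 1, 1])
# Class 0 -> all +1, class 1 -> all -1, so inner products are exactly +-code_length.
B = np.where(db_labels[:, None] == 0, 1.0, -1.0) * np.ones((4, code_length))
omega = np.array([0, 2])  # indices of the sampled training items
S = np.where(db_labels[omega][:, None] == db_labels[None, :], 1.0, -1.0)

perfect = adsh_loss(B[omega], B, S, omega, code_length, gamma)
relaxed = adsh_loss(0.5 * B[omega], B, S, omega, code_length, gamma)
print(perfect)
print(relaxed)
```

When the network output F equals the binary codes, both terms vanish; halving F leaves the signs correct but is penalised by both the similarity term and the quantization term, which is how the loss pulls the real-valued outputs toward ±1.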

4) Network testing

def valid(query_dataloader, train_dataloader, retrieval_dataloader, code_length, args):
    num_classes, att_size, feat_size = args.num_classes, 1, 2048
    model = SEMICON.semicon(code_length=code_length, num_classes=num_classes, att_size=att_size,
                            feat_size=feat_size, device=args.device, pretrained=True)
    model.to(args.device)
    model.load_state_dict(torch.load('./checkpoints/' + args.info + '/' + str(code_length) + '/model.pkl'),
                          strict=False)
    model.eval()

    query_code = torch.load('./checkpoints/' + args.info + '/' + str(code_length) + '/query_code.t')
    query_code = query_code.to(args.device)
    query_dataloader.dataset.get_onehot_targets = torch.load(
        './checkpoints/' + args.info + '/' + str(code_length) + '/query_targets.t')
    B = torch.load('./checkpoints/' + args.info + '/' + str(code_length) + '/database_code.t')
    B = B.to(args.device)
    retrieval_targets = torch.load(
        './checkpoints/' + args.info + '/' + str(code_length) + '/database_targets.t')
    retrieval_targets = retrieval_targets.to(args.device)

    mAP = evaluate.mean_average_precision(
        query_code.to(args.device),
        B,
        query_dataloader.dataset.get_onehot_targets().to(args.device),
        retrieval_targets,
        args.device,
        args.topk,
    )
    print("Code_Length: " + str(code_length), end="; ")
    print('[mAP:{:.5f}]'.format(mAP))
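The body of evaluate.mean_average_precision is not reproduced in these notes. As a rough guide to what such a function computes, here is a NumPy sketch of mAP under Hamming ranking for codes in {-1, +1}; it is an assumed implementation written for this illustration (no device argument, toy data), not the repository's code.

```python
import numpy as np


def mean_average_precision(query_code, database_code, query_labels, database_labels, topk=None):
    """Hamming-ranking mAP for {-1, +1} codes and one-hot labels."""
    code_length = query_code.shape[1]
    aps = []
    for q, y in zip(query_code, query_labels):
        relevant = (database_labels @ y) > 0               # same class as the query
        hamming = 0.5 * (code_length - database_code @ q)  # distance to each database item
        order = np.argsort(hamming, kind="stable")[:topk]
        hits = relevant[order]
        if hits.sum() == 0:
            continue                                       # nothing relevant was retrieved
        ranks = np.flatnonzero(hits) + 1                   # 1-based positions of the hits
        aps.append((np.arange(1, len(ranks) + 1) / ranks).mean())
    return float(np.mean(aps)) if aps else 0.0


query_code = np.array([[1, 1, 1, 1]], dtype=float)
query_labels = np.array([[1, 0]], dtype=float)
database_code = np.array([[1, 1, 1, 1], [1, 1, 1, -1], [1, 1, -1, -1], [-1, -1, -1, -1]], dtype=float)
database_labels = np.array([[1, 0], [0, 1], [1, 0], [0, 1]], dtype=float)

map_all = mean_average_precision(query_code, database_code, query_labels, database_labels)
map_top2 = mean_average_precision(query_code, database_code, query_labels, database_labels, topk=2)
print(map_all, map_top2)
```

In the toy data the relevant items sit at ranks 1 and 3, so the average precision is (1/1 + 2/3) / 2 = 5/6; restricting the ranking to the top 2 keeps only the first hit and gives 1.0.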

5) Network structure (ONNX model)
