Deep Learning Network Models: RepVGG Explained, with a Complete Project Training RepVGG on a Flower Classification Dataset
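- 0 Preface
- 1 The RepVGG Block in Detail
- 2 Structural Re-parameterization
- 2.1 Fusing Conv2d and BN
- 2.2 Conv2d + BN Fusion Experiment (PyTorch)
- 2.3 Converting a 1x1 Convolution into a 3x3 Convolution
- 2.4 Converting BN into a 3x3 Convolution
- 2.5 Multi-Branch Fusion
- 2.6 Structural Re-parameterization Experiment (PyTorch)
- 3 Model Configuration
- 4 Complete Project: Training RepVGG on the Flower Classification Dataset
- (1) Model construction: model.py
- (2) Model training: train.py
- (3) Inference test: predict.py
- (5) Complete project code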
Paper: RepVGG: Making VGG-style ConvNets Great Again
Paper download: https://arxiv.org/abs/2101.03697
Official source code (PyTorch implementation): https://github.com/DingXiaoH/RepVGG
0 Preface
1 The RepVGG Block in Detail
2 Structural Re-parameterization
2.1 Fusing Conv2d and BN
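The fusion that the experiment in Section 2.2 verifies can be written out explicitly. This is a sketch of the inference-time identity, using the same quantities as the code: \gamma and \beta are the BN weight and bias, \mu and \sigma^2 the running mean and variance, \epsilon is eps, and W is the bias-free convolution kernel. The per-channel index i is notation introduced here for illustration.

```latex
\begin{aligned}
\mathrm{BN}(W * x)_i
  &= \gamma_i \,\frac{(W * x)_i - \mu_i}{\sqrt{\sigma_i^2 + \epsilon}} + \beta_i \\
  &= \underbrace{\left(\frac{\gamma_i}{\sqrt{\sigma_i^2 + \epsilon}}\, W_i\right)}_{W'_i} * x
     \;+\; \underbrace{\left(\beta_i - \frac{\mu_i\,\gamma_i}{\sqrt{\sigma_i^2 + \epsilon}}\right)}_{b'_i}
\end{aligned}
```

In the code, W'_i corresponds to kernel * t with t = gamma / std, and b'_i to beta - running_mean * gamma / std. The identity holds only in eval mode, where BN uses its running statistics rather than batch statistics.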
2.2 Conv2d + BN Fusion Experiment (PyTorch)
from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def main():
    torch.random.manual_seed(0)

    f1 = torch.randn(1, 2, 3, 3)

    module = nn.Sequential(OrderedDict(
        conv=nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False),
        bn=nn.BatchNorm2d(num_features=2)
    ))

    module.eval()

    with torch.no_grad():
        output1 = module(f1)
        print(output1)

    # fuse conv + bn
    kernel = module.conv.weight
    running_mean = module.bn.running_mean
    running_var = module.bn.running_var
    gamma = module.bn.weight
    beta = module.bn.bias
    eps = module.bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)  # [ch] -> [ch, 1, 1, 1]
    kernel = kernel * t
    bias = beta - running_mean * gamma / std
    fused_conv = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True)
    fused_conv.load_state_dict(OrderedDict(weight=kernel, bias=bias))

    with torch.no_grad():
        output2 = fused_conv(f1)
        print(output2)

    np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
    print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()
Terminal output: [screenshot omitted]
2.3 Converting a 1x1 Convolution into a 3x3 Convolution
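The conversion can be checked in a few lines. The sketch below zero-pads a 1x1 kernel to 3x3 with torch.nn.functional.pad, the same call used by _pad_1x1_to_3x3_tensor in Section 2.6, and compares the two convolutions. The tensor shapes and the seed are arbitrary choices for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 2, 5, 5)          # arbitrary input feature map
kernel1x1 = torch.randn(2, 2, 1, 1)  # arbitrary 1x1 kernel

# Zero-pad one element on each side of the last two dims: 1x1 -> 3x3,
# with the original weight sitting at the centre of the 3x3 kernel.
kernel3x3 = F.pad(kernel1x1, [1, 1, 1, 1])

# The 1x1 conv needs no padding; the 3x3 conv needs padding=1 to keep the size.
out_1x1 = F.conv2d(x, kernel1x1, stride=1, padding=0)
out_3x3 = F.conv2d(x, kernel3x3, stride=1, padding=1)

print(tuple(kernel3x3.shape))
print((out_1x1 - out_3x3).abs().max().item())
```

Because every tap except the centre is zero, the padded kernel ignores the neighbouring pixels and the extra input padding, so the two outputs agree up to floating-point error.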
2.4 Converting BN into a 3x3 Convolution
Code screenshot: [image omitted]
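As a stand-in for the idea, here is a minimal sketch that follows the identity branch of _fuse_bn_tensor in Section 2.6 with groups=1: build a 3x3 identity kernel whose centre tap maps channel i to channel i, then fold the BN statistics into it. The channel count and the random BN statistics are assumptions made only so that the check is non-trivial.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
channels = 4
x = torch.randn(1, channels, 6, 6)

bn = nn.BatchNorm2d(channels)
with torch.no_grad():
    # Arbitrary non-default statistics and affine parameters (illustration only).
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
bn.eval()  # use running statistics, as at inference time

# Identity 3x3 kernel: output channel i copies input channel i via the centre tap.
id_kernel = torch.zeros(channels, channels, 3, 3)
for i in range(channels):
    id_kernel[i, i, 1, 1] = 1.0

with torch.no_grad():
    # Fold BN into the identity kernel, same formula as the conv + BN fusion.
    std = (bn.running_var + bn.eps).sqrt()
    kernel = id_kernel * (bn.weight / std).reshape(-1, 1, 1, 1)
    bias = bn.bias - bn.running_mean * bn.weight / std

    out_bn = bn(x)
    out_conv = F.conv2d(x, kernel, bias, stride=1, padding=1)

print((out_bn - out_conv).abs().max().item())
```

The identity branch exists only when in_channels equals out_channels and stride is 1, which is why the block in Section 2.6 sets rbr_identity to None otherwise.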
2.5 Multi-Branch Fusion
Code screenshot: [image omitted]
Illustration: [image omitted]
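Once every branch has been expressed as a 3x3 kernel plus a bias, merging them relies only on the linearity of convolution: summing the outputs of several convolutions over the same input equals one convolution with the summed kernels and summed biases. The sketch below checks this with random stand-in kernels, which are assumptions for illustration rather than trained weights.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
c = 3
x = torch.randn(1, c, 8, 8)

# Three branches, each already in "3x3 kernel + bias" form:
k3, b3 = torch.randn(c, c, 3, 3), torch.randn(c)                        # 3x3 branch
k1, b1 = F.pad(torch.randn(c, c, 1, 1), [1, 1, 1, 1]), torch.randn(c)   # padded 1x1 branch
k_id, b_id = torch.zeros(c, c, 3, 3), torch.randn(c)                    # identity (BN) branch
for i in range(c):
    k_id[i, i, 1, 1] = 1.0

# Multi-branch form: three convolutions, outputs added.
out_branches = (F.conv2d(x, k3, b3, padding=1)
                + F.conv2d(x, k1, b1, padding=1)
                + F.conv2d(x, k_id, b_id, padding=1))

# Single-branch form: one convolution with summed kernel and bias.
out_merged = F.conv2d(x, k3 + k1 + k_id, b3 + b1 + b_id, padding=1)

print((out_branches - out_merged).abs().max().item())
```

This is the step performed by get_equivalent_kernel_bias in Section 2.6. It works because the nonlinearity is applied after the branches are added, not inside each branch.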
2.6 Structural Re-parameterization Experiment (PyTorch)
import time

import numpy as np
import torch
import torch.nn as nn


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding,
                                        groups=groups, bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels
        self.nonlinearity = nn.ReLU()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups,
                                         bias=True, padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) \
                if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                     stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                   stride=stride, padding=0, groups=groups)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels,
                                     out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation,
                                     groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True


def main():
    f1 = torch.randn(1, 64, 64, 64)
    block = RepVGGBlock(in_channels=64, out_channels=64)
    block.eval()
    with torch.no_grad():
        output1 = block(f1)
        start_time = time.time()
        for _ in range(100):
            block(f1)
        print(f"consume time: {time.time() - start_time}")

        # re-parameterization
        block.switch_to_deploy()
        output2 = block(f1)
        start_time = time.time()
        for _ in range(100):
            block(f1)
        print(f"consume time: {time.time() - start_time}")

        np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
        print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()
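Terminal output: [screenshot omitted]
Comparing the two timings shows that inference was roughly twice as fast after structural re-parameterization in this run, and that the outputs before and after conversion agree.
3 Model Configuration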
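The configuration of each variant can be read directly from model.py in Section 4: every model is defined by num_blocks (the number of blocks in stages 1 to 4) and width_multiplier (scaling factors applied to the base widths 64, 128, 256 and 512), while stage 0 is a single block with min(64, 64 * a) channels. The sketch below recomputes the per-stage channel widths for a few variants. The helper name stage_widths and the configs dictionary are introduced here for illustration and are not part of the project code.

```python
def stage_widths(width_multiplier):
    """Channel widths of stage0..stage4, following RepVGG.__init__ in model.py."""
    m0, m1, m2, m3 = width_multiplier
    return [
        min(64, int(64 * m0)),  # stage0: capped at 64 channels
        int(64 * m0),           # stage1
        int(128 * m1),          # stage2
        int(256 * m2),          # stage3
        int(512 * m3),          # stage4 (input width of the final linear layer)
    ]


# (num_blocks, width_multiplier) copied from the create_RepVGG_* functions.
configs = {
    "RepVGG-A0": ([2, 4, 14, 1], [0.75, 0.75, 0.75, 2.5]),
    "RepVGG-A1": ([2, 4, 14, 1], [1, 1, 1, 2.5]),
    "RepVGG-B0": ([4, 6, 16, 1], [1, 1, 1, 2.5]),
    "RepVGG-B3": ([4, 6, 16, 1], [3, 3, 3, 5]),
}

for name, (num_blocks, multiplier) in configs.items():
    # One block in stage0 plus the blocks of the four stages.
    total_blocks = 1 + sum(num_blocks)
    print(name, "blocks:", total_blocks, "widths:", stage_widths(multiplier))
```

The A and B series differ only in depth (num_blocks), so RepVGG-A1 and RepVGG-B0 share the same widths. The g2 and g4 variants additionally use grouped convolutions in the even-indexed layers listed in optional_groupwise_layers.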
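4 Complete Project: Training RepVGG on the Flower Classification Dataset
The overall project directory is shown in the figure below: [image omitted]
model.py contains the RepVGG model implementation. The available model variants are: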
func_dict = {'RepVGG-A0': create_RepVGG_A0,
'RepVGG-A1': create_RepVGG_A1,
'RepVGG-A2': create_RepVGG_A2,
'RepVGG-B0': create_RepVGG_B0,
'RepVGG-B1': create_RepVGG_B1,
'RepVGG-B1g2': create_RepVGG_B1g2,
'RepVGG-B1g4': create_RepVGG_B1g4,
'RepVGG-B2': create_RepVGG_B2,
'RepVGG-B2g2': create_RepVGG_B2g2,
'RepVGG-B2g4': create_RepVGG_B2g4,
'RepVGG-B3': create_RepVGG_B3,
'RepVGG-B3g2': create_RepVGG_B3g2,
'RepVGG-B3g4': create_RepVGG_B3g4,
'RepVGG-D2se': create_RepVGG_D2se, # Updated at April 25, 2021. This is not reported in the CVPR paper.
}
(1) Model construction: model.py
# --------------------------------------------------------
# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
# Github source: https://github.com/DingXiaoH/RepVGG
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch.nn as nn
import numpy as np
import torch
import copy
from se_block import SEBlock
import torch.utils.checkpoint as checkpoint


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        if use_se:
            # Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models uses SE after nonlinearity.
            self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups, bias=True, padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, padding=padding_11, groups=groups)
            print('RepVGG Block, identity = ', self.rbr_identity)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    # Optional. This may improve the accuracy and facilitates quantization in some cases.
    # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
    # 2. Use like this.
    #     loss = criterion(....)
    #     for every RepVGGBlock blk:
    #         loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
    #     optimizer.zero_grad()
    #     loss.backward()
    def get_custom_L2(self):
        K3 = self.rbr_dense.conv.weight
        K1 = self.rbr_1x1.conv.weight
        t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()

        l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()  # The L2 loss of the "circle" of weights in 3x3 kernel. Use regular L2 on them.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1  # The equivalent resultant central point of 3x3 kernel.
        l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()  # Normalize for an L2 coefficient comparable to regular L2.
        return l2_loss_eq_kernel + l2_loss_circle

    # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
    # You can get the equivalent kernel and bias at any time and do whatever you want,
    # for example, apply some penalties or constraints during training, just like you do to the other models.
    # May be useful for quantization or pruning.
    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels, out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation, groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True


class RepVGG(nn.Module):

    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None, deploy=False, use_se=False, use_checkpoint=False):
        super(RepVGG, self).__init__()
        assert len(width_multiplier) == 4
        self.deploy = deploy
        self.override_groups_map = override_groups_map or dict()
        assert 0 not in self.override_groups_map
        self.use_se = use_se
        self.use_checkpoint = use_checkpoint

        self.in_planes = min(64, int(64 * width_multiplier[0]))
        self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1, deploy=self.deploy, use_se=self.use_se)
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2)
        self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
        self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2)
        self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        print("----------------------------", num_classes)
        self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)

    def _make_stage(self, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for stride in strides:
            cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
                                      stride=stride, padding=1, groups=cur_groups, deploy=self.deploy, use_se=self.use_se))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.ModuleList(blocks)

    def forward(self, x):
        out = self.stage0(x)
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            for block in stage:
                if self.use_checkpoint:
                    out = checkpoint.checkpoint(block, out)
                else:
                    out = block(out)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_map = {l: 2 for l in optional_groupwise_layers}
g4_map = {l: 4 for l in optional_groupwise_layers}


# Every factory accepts num_classes so that train.py can swap in any variant;
# the default of 1000 matches the ImageNet setting.
def create_RepVGG_A0(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=num_classes,
                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_A1(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=num_classes,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_A2(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=num_classes,
                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B0(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B1(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B1g2(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B1g4(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B2(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B2g2(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B2g4(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B3(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=None, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B3g2(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_B3g4(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=num_classes,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map, deploy=deploy, use_checkpoint=use_checkpoint)


def create_RepVGG_D2se(deploy=False, use_checkpoint=False, num_classes=1000):
    return RepVGG(num_blocks=[8, 14, 24, 1], num_classes=num_classes,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None, deploy=deploy, use_se=True, use_checkpoint=use_checkpoint)


func_dict = {'RepVGG-A0': create_RepVGG_A0,
'RepVGG-A1': create_RepVGG_A1,
'RepVGG-A2': create_RepVGG_A2,
'RepVGG-B0': create_RepVGG_B0,
'RepVGG-B1': create_RepVGG_B1,
'RepVGG-B1g2': create_RepVGG_B1g2,
'RepVGG-B1g4': create_RepVGG_B1g4,
'RepVGG-B2': create_RepVGG_B2,
'RepVGG-B2g2': create_RepVGG_B2g2,
'RepVGG-B2g4': create_RepVGG_B2g4,
'RepVGG-B3': create_RepVGG_B3,
'RepVGG-B3g2': create_RepVGG_B3g2,
'RepVGG-B3g4': create_RepVGG_B3g4,
'RepVGG-D2se': create_RepVGG_D2se, # Updated at April 25, 2021. This is not reported in the CVPR paper.
}
def get_RepVGG_func_by_name(name):
    return func_dict[name]


# Use this for converting a RepVGG model or a bigger model with RepVGG as its component
# Use like this
# model = create_RepVGG_A0(deploy=False)
# train model or load weights
# repvgg_model_convert(model, save_path='repvgg_deploy.pth')
# If you want to preserve the original model, call with do_copy=True

# ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
# train_backbone = create_RepVGG_B2(deploy=False)
# train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
# train_pspnet = build_pspnet(backbone=train_backbone)
# segmentation_train(train_pspnet)
# deploy_pspnet = repvgg_model_convert(train_pspnet)
# segmentation_test(deploy_pspnet)
# ===================== example_pspnet.py shows an example

def repvgg_model_convert(model: torch.nn.Module, save_path=None, do_copy=True):
    if do_copy:
        model = copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, 'switch_to_deploy'):
            module.switch_to_deploy()
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return model
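(2) Model training: train.py
For this experiment I trained the lightweight RepVGG-A0 starting from the official pretrained weights. It reached 86% accuracy on the test set, possibly limited by the model being relatively shallow. The official pretrained weights are linked below:
Official RepVGG pretrained models
Extraction code: rvgg
When training, import whichever model variant suits your needs (the "from model import ... as create_model" line in train.py).
The number of classes is set through the --num_classes argument in train.py:
Complete training code: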
train.py
import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import create_RepVGG_A0 as create_model
from utils import read_split_data, create_lr_scheduler, get_params_groups, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(f"using {device} device.")

    if os.path.exists("weights_") is False:
        os.makedirs("weights_")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' does not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)
        # some checkpoints wrap the state dict under a "model" key; unwrap it if present
        if "model" in weights_dict:
            weights_dict = weights_dict["model"]
        # drop the classifier weights (the layer is named "linear" in model.py),
        # because the number of classes differs from the pretrained model
        for k in list(weights_dict.keys()):
            if "linear" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze every weight except the classifier ("linear")
            if "linear" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    # pg = [p for p in model.parameters() if p.requires_grad]
    pg = get_params_groups(model, weight_decay=args.wd)
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
                                       warmup=True, warmup_epochs=1)

    best_acc = 0.
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch,
                                                lr_scheduler=lr_scheduler)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if best_acc < val_acc:
            torch.save(model.state_dict(), "./weights_/best_model.pth")
            best_acc = val_acc


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--wd', type=float, default=5e-2)

    # root directory of the dataset
    parser.add_argument('--data-path', type=str,
                        default="./flower_photos")

    # path to the pretrained weights; set to an empty string to skip loading
    parser.add_argument('--weights', type=str, default='',
                        help='initial weights path')
    # whether to freeze all weights except the classifier
    # (note: with type=bool, argparse treats any non-empty string as True)
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)
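(3) Inference test: predict.py
RepVGG decouples training from inference: the network structure used for training differs from the one used for inference. During training the network contains residual-style branches and convolution kernels of different sizes (3x3 and 1x1); at inference time it contains only 3x3 convolutions in a plain structure. The trained model therefore has to be converted accordingly: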
# create model
from model import repvgg_model_convert, create_RepVGG_A0
train_model = create_RepVGG_A0(deploy=False,num_classes=5).to(device)
train_model.load_state_dict(torch.load("weights_/best_model.pth"))
deploy_model = repvgg_model_convert(train_model)
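Alternatively, build the model in deploy mode and load weights that have already been converted:
# The training checkpoint cannot be loaded here because its keys (rbr_dense, rbr_1x1, ...)
# differ from the deploy structure (rbr_reparam). This assumes the converted weights were
# saved beforehand, e.g. repvgg_model_convert(train_model, save_path='weights_/repvgg_deploy.pth')
from model import create_RepVGG_A0
deploy_model = create_RepVGG_A0(deploy=True, num_classes=5)
deploy_model.load_state_dict(torch.load('weights_/repvgg_deploy.pth'))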
Test result: [image omitted]
Test code: predict.py
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import repvgg_model_convert, create_RepVGG_A0

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # use the same normalization statistics as in train.py
    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "test.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    print(img)
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # expand batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create the training-structure model, load its weights, then convert it for deployment
    train_model = create_RepVGG_A0(deploy=False, num_classes=5).to(device)
    train_model.load_state_dict(torch.load("weights_/best_model.pth", map_location=torch.device('cpu')))
    deploy_model = repvgg_model_convert(train_model)
    deploy_model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(deploy_model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
(5) Complete project code
The complete project has been uploaded to CSDN.