Deep Learning Network Models: RepVGG in Detail, and a Complete Project Training RepVGG on a Flower-Classification Dataset

  • 0 Preface
  • 1 The RepVGG Block in Detail
  • 2 Structural Re-parameterization
    • 2.1 Fusing Conv2d and BN
    • 2.2 Conv2d+BN Fusion Experiment (PyTorch)
    • 2.3 Converting a 1x1 Convolution into a 3x3 Convolution
    • 2.4 Converting BN into a 3x3 Convolution
    • 2.5 Multi-branch Fusion
    • 2.6 Structural Re-parameterization Experiment (PyTorch)
  • 3 Model Configurations
  • 4 Training RepVGG on a Flower-Classification Dataset: The Complete Project
    • (1) Model definition: model.py
    • (2) Model training: train.py
    • (3) Inference testing: predict.py
    • (4) Full project code

Paper: RepVGG: Making VGG-style ConvNets Great Again
Paper link: https://arxiv.org/abs/2101.03697
Official source code (PyTorch): https://github.com/DingXiaoH/RepVGG

0 Preface

1 The RepVGG Block in Detail
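The original post explains the block structure with figures. In brief, and matching the code in Sections 2.6 and 4: during training, a RepVGGBlock runs three parallel branches, a 3x3 conv + BN (rbr_dense), a 1x1 conv + BN (rbr_1x1), and a BN-only identity branch (rbr_identity, present only when out_channels == in_channels and stride == 1). Their outputs are summed and passed through ReLU. At inference time the three branches are merged into a single 3x3 convolution (rbr_reparam), yielding a plain VGG-style network.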

2 Structural Re-parameterization

2.1 Fusing Conv2d and BN
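The derivation is shown as images in the original post. As a brief reconstruction of the standard Conv+BN folding (it matches the code in Section 2.2 and assumes the convolution carries no bias of its own, as in the experiment below): for each channel, BN computes

$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta$$

where $\mu$ and $\sigma^2$ are the running mean and variance and $\gamma$, $\beta$ are the learned affine parameters. Substituting the convolution output for $x$ yields an equivalent convolution with

$$W' = \frac{\gamma}{\sqrt{\sigma^2 + \varepsilon}} \cdot W, \qquad b' = \beta - \frac{\gamma \mu}{\sqrt{\sigma^2 + \varepsilon}},$$

applied per output channel.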


2.2 Conv2d+BN Fusion Experiment (PyTorch)

from collections import OrderedDict

import numpy as np
import torch
import torch.nn as nn


def main():
    torch.random.manual_seed(0)

    f1 = torch.randn(1, 2, 3, 3)

    module = nn.Sequential(OrderedDict(
        conv=nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False),
        bn=nn.BatchNorm2d(num_features=2)
    ))

    module.eval()

    with torch.no_grad():
        output1 = module(f1)
        print(output1)

    # fuse conv + bn
    kernel = module.conv.weight
    running_mean = module.bn.running_mean
    running_var = module.bn.running_var
    gamma = module.bn.weight
    beta = module.bn.bias
    eps = module.bn.eps
    std = (running_var + eps).sqrt()
    t = (gamma / std).reshape(-1, 1, 1, 1)  # [ch] -> [ch, 1, 1, 1]
    kernel = kernel * t
    bias = beta - running_mean * gamma / std
    fused_conv = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, stride=1, padding=1, bias=True)
    fused_conv.load_state_dict(OrderedDict(weight=kernel, bias=bias))

    with torch.no_grad():
        output2 = fused_conv(f1)
        print(output2)

    np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
    print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()
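Note that module.eval() is called before the comparison: the fusion folds in the BN running statistics, so it is only valid when BN is in inference mode (using running_mean and running_var rather than per-batch statistics).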

Terminal output (shown as a screenshot in the original post): the two print-outs match and the final assertion passes.

2.3 Converting a 1x1 Convolution into a 3x3 Convolution
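The original post illustrates this step with a figure. The conversion itself is just zero-padding the 1x1 kernel by one pixel on each side, so it behaves like a 3x3 convolution (with padding 1) whose border weights are all zero. A minimal sketch, mirroring _pad_1x1_to_3x3_tensor from the code in Section 2.6:

import torch
import torch.nn.functional as F

kernel1x1 = torch.randn(2, 2, 1, 1)         # [out_ch, in_ch, 1, 1]
kernel3x3 = F.pad(kernel1x1, [1, 1, 1, 1])  # zero-pad left/right/top/bottom -> [out_ch, in_ch, 3, 3]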

2.4 Converting BN into a 3x3 Convolution


The corresponding code (shown as a screenshot in the original post) reduces to the following:
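A minimal sketch, mirroring _fuse_bn_tensor in the code of Section 2.6: the stand-alone BN branch is first viewed as BN applied after an identity 3x3 convolution (a kernel with a single 1 at the centre of the matching input channel), after which the Conv+BN folding from Section 2.1 applies unchanged.

import numpy as np
import torch

in_channels, groups = 2, 1
input_dim = in_channels // groups
kernel_value = np.zeros((in_channels, input_dim, 3, 3), dtype=np.float32)
for i in range(in_channels):
    kernel_value[i, i % input_dim, 1, 1] = 1  # identity: 1 at the kernel centre of the matching channel
id_tensor = torch.from_numpy(kernel_value)    # convolving with this (padding=1) reproduces the input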

2.5 Multi-branch Fusion


Once the three branches have each been expressed as a 3x3 convolution (Sections 2.1 to 2.4), they can be merged into a single convolution by adding the kernels and biases elementwise. The code and a step-by-step illustration appear as screenshots in the original post.
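In symbols, matching get_equivalent_kernel_bias in the code of Section 2.6:

$$W' = W_{3\times 3} + \operatorname{pad}(W_{1\times 1}) + W_{\text{id}}, \qquad b' = b_{3\times 3} + b_{1\times 1} + b_{\text{id}}.$$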

2.6 Structural Re-parameterization Experiment (PyTorch)

import time
import torch.nn as nn
import numpy as np
import torch


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding,
                                        groups=groups, bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3,
                 stride=1, padding=1, dilation=1, groups=1, padding_mode='zeros', deploy=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        self.nonlinearity = nn.ReLU()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups,
                                         bias=True, padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) \
                if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                     stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                   stride=stride, padding=0, groups=groups)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.rbr_reparam(inputs))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)

    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels,
                                     out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation,
                                     groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True


def main():
    f1 = torch.randn(1, 64, 64, 64)
    block = RepVGGBlock(in_channels=64, out_channels=64)
    block.eval()
    with torch.no_grad():
        output1 = block(f1)
        start_time = time.time()
        for _ in range(100):
            block(f1)
        print(f"consume time: {time.time() - start_time}")

        # re-parameterization
        block.switch_to_deploy()
        output2 = block(f1)
        start_time = time.time()
        for _ in range(100):
            block(f1)
        print(f"consume time: {time.time() - start_time}")

        np.testing.assert_allclose(output1.numpy(), output2.numpy(), rtol=1e-03, atol=1e-05)
        print("convert module has been tested, and the result looks good!")


if __name__ == '__main__':
    main()

Terminal output (shown as a screenshot in the original post):

Comparing the two runs shows that inference is roughly twice as fast after structural re-parameterization, and that the outputs before and after the conversion are identical.

3 Model Configurations
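The original post presents the configuration table as an image. The same information can be read off the create_RepVGG_* factory functions in model.py below (blocks per stage, and per-stage width multipliers applied to the 64/128/256/512 base widths):

Model        Blocks per stage   Width multipliers
RepVGG-A0    [2, 4, 14, 1]      [0.75, 0.75, 0.75, 2.5]
RepVGG-A1    [2, 4, 14, 1]      [1, 1, 1, 2.5]
RepVGG-A2    [2, 4, 14, 1]      [1.5, 1.5, 1.5, 2.75]
RepVGG-B0    [4, 6, 16, 1]      [1, 1, 1, 2.5]
RepVGG-B1    [4, 6, 16, 1]      [2, 2, 2, 4]
RepVGG-B2    [4, 6, 16, 1]      [2.5, 2.5, 2.5, 5]
RepVGG-B3    [4, 6, 16, 1]      [3, 3, 3, 5]
RepVGG-D2se  [8, 14, 24, 1]     [2.5, 2.5, 2.5, 5] (with SE blocks)

The g2/g4 variants (e.g. RepVGG-B1g2, RepVGG-B1g4) share their base model's configuration but use group counts of 2 or 4 in the layers listed in optional_groupwise_layers.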

4 Training RepVGG on a Flower-Classification Dataset: The Complete Project

The overall project layout is shown as a screenshot in the original post; judging from the imports, it includes model.py, se_block.py, my_dataset.py, utils.py, train.py, and predict.py.

model.py contains the RepVGG model implementation; the available models are listed below:

func_dict = {
    'RepVGG-A0': create_RepVGG_A0,
    'RepVGG-A1': create_RepVGG_A1,
    'RepVGG-A2': create_RepVGG_A2,
    'RepVGG-B0': create_RepVGG_B0,
    'RepVGG-B1': create_RepVGG_B1,
    'RepVGG-B1g2': create_RepVGG_B1g2,
    'RepVGG-B1g4': create_RepVGG_B1g4,
    'RepVGG-B2': create_RepVGG_B2,
    'RepVGG-B2g2': create_RepVGG_B2g2,
    'RepVGG-B2g4': create_RepVGG_B2g4,
    'RepVGG-B3': create_RepVGG_B3,
    'RepVGG-B3g2': create_RepVGG_B3g2,
    'RepVGG-B3g4': create_RepVGG_B3g4,
    'RepVGG-D2se': create_RepVGG_D2se,  #   Updated at April 25, 2021. This is not reported in the CVPR paper.
}
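For example, a model can be built by name via get_RepVGG_func_by_name (defined in model.py). Note that in this modified model.py, only create_RepVGG_A0 exposes a num_classes argument; the other factory functions are hard-coded to 1000 classes:

from model import get_RepVGG_func_by_name

create_fn = get_RepVGG_func_by_name('RepVGG-A0')
model = create_fn(deploy=False, num_classes=5)  # training-time structure, 5 flower classes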

(1) Model definition: model.py

# --------------------------------------------------------
# RepVGG: Making VGG-style ConvNets Great Again (https://openaccess.thecvf.com/content/CVPR2021/papers/Ding_RepVGG_Making_VGG-Style_ConvNets_Great_Again_CVPR_2021_paper.pdf)
# Github source: https://github.com/DingXiaoH/RepVGG
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
import torch.nn as nn
import numpy as np
import torch
import copy
from se_block import SEBlock
import torch.utils.checkpoint as checkpoint


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding,
                                        groups=groups, bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, groups=1, padding_mode='zeros', deploy=False, use_se=False):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        assert kernel_size == 3
        assert padding == 1

        padding_11 = padding - kernel_size // 2

        self.nonlinearity = nn.ReLU()

        if use_se:
            #   Note that RepVGG-D2se uses SE before nonlinearity. But RepVGGplus models use SE after nonlinearity.
            self.se = SEBlock(out_channels, internal_neurons=out_channels // 16)
        else:
            self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                         kernel_size=kernel_size, stride=stride,
                                         padding=padding, dilation=dilation, groups=groups,
                                         bias=True, padding_mode=padding_mode)
        else:
            self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) \
                if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                                     stride=stride, padding=padding, groups=groups)
            self.rbr_1x1 = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
                                   stride=stride, padding=padding_11, groups=groups)
            print('RepVGG Block, identity = ', self.rbr_identity)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    #   Optional. This may improve the accuracy and facilitates quantization in some cases.
    #   1.  Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
    #   2.  Use like this.
    #       loss = criterion(....)
    #       for every RepVGGBlock blk:
    #           loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
    #       optimizer.zero_grad()
    #       loss.backward()
    def get_custom_L2(self):
        K3 = self.rbr_dense.conv.weight
        K1 = self.rbr_1x1.conv.weight
        t3 = (self.rbr_dense.bn.weight / ((self.rbr_dense.bn.running_var + self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight / ((self.rbr_1x1.bn.running_var + self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()

        l2_loss_circle = (K3 ** 2).sum() - (K3[:, :, 1:2, 1:2] ** 2).sum()   # The L2 loss of the "circle" of weights in the 3x3 kernel. Use regular L2 on them.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1                        # The equivalent resultant central point of the 3x3 kernel.
        l2_loss_eq_kernel = (eq_kernel ** 2 / (t3 ** 2 + t1 ** 2)).sum()     # Normalize for an L2 coefficient comparable to regular L2.
        return l2_loss_eq_kernel + l2_loss_circle

    #   This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
    #   You can get the equivalent kernel and bias at any time and do whatever you want,
    #   for example, apply some penalties or constraints during training, just like you do to the other models.
    #   May be useful for quantization or pruning.
    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(in_channels=self.rbr_dense.conv.in_channels,
                                     out_channels=self.rbr_dense.conv.out_channels,
                                     kernel_size=self.rbr_dense.conv.kernel_size, stride=self.rbr_dense.conv.stride,
                                     padding=self.rbr_dense.conv.padding, dilation=self.rbr_dense.conv.dilation,
                                     groups=self.rbr_dense.conv.groups, bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True


class RepVGG(nn.Module):

    def __init__(self, num_blocks, num_classes=1000, width_multiplier=None, override_groups_map=None,
                 deploy=False, use_se=False, use_checkpoint=False):
        super(RepVGG, self).__init__()
        assert len(width_multiplier) == 4
        self.deploy = deploy
        self.override_groups_map = override_groups_map or dict()
        assert 0 not in self.override_groups_map
        self.use_se = use_se
        self.use_checkpoint = use_checkpoint

        self.in_planes = min(64, int(64 * width_multiplier[0]))
        self.stage0 = RepVGGBlock(in_channels=3, out_channels=self.in_planes, kernel_size=3, stride=2, padding=1,
                                  deploy=self.deploy, use_se=self.use_se)
        self.cur_layer_idx = 1
        self.stage1 = self._make_stage(int(64 * width_multiplier[0]), num_blocks[0], stride=2)
        self.stage2 = self._make_stage(int(128 * width_multiplier[1]), num_blocks[1], stride=2)
        self.stage3 = self._make_stage(int(256 * width_multiplier[2]), num_blocks[2], stride=2)
        self.stage4 = self._make_stage(int(512 * width_multiplier[3]), num_blocks[3], stride=2)
        self.gap = nn.AdaptiveAvgPool2d(output_size=1)
        self.linear = nn.Linear(int(512 * width_multiplier[3]), num_classes)

    def _make_stage(self, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for stride in strides:
            cur_groups = self.override_groups_map.get(self.cur_layer_idx, 1)
            blocks.append(RepVGGBlock(in_channels=self.in_planes, out_channels=planes, kernel_size=3,
                                      stride=stride, padding=1, groups=cur_groups,
                                      deploy=self.deploy, use_se=self.use_se))
            self.in_planes = planes
            self.cur_layer_idx += 1
        return nn.ModuleList(blocks)

    def forward(self, x):
        out = self.stage0(x)
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            for block in stage:
                if self.use_checkpoint:
                    out = checkpoint.checkpoint(block, out)
                else:
                    out = block(out)
        out = self.gap(out)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


optional_groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
g2_map = {l: 2 for l in optional_groupwise_layers}
g4_map = {l: 4 for l in optional_groupwise_layers}


def create_RepVGG_A0(deploy=False, use_checkpoint=False, num_classes=None):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=num_classes,
                  width_multiplier=[0.75, 0.75, 0.75, 2.5], override_groups_map=None,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_A1(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_A2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[2, 4, 14, 1], num_classes=1000,
                  width_multiplier=[1.5, 1.5, 1.5, 2.75], override_groups_map=None,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B0(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[1, 1, 1, 2.5], override_groups_map=None,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B1(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=None,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B1g2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g2_map,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B1g4(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2, 2, 2, 4], override_groups_map=g4_map,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B2g2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g2_map,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B2g4(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=g4_map,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B3(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=None,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B3g2(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g2_map,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_B3g4(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[4, 6, 16, 1], num_classes=1000,
                  width_multiplier=[3, 3, 3, 5], override_groups_map=g4_map,
                  deploy=deploy, use_checkpoint=use_checkpoint)

def create_RepVGG_D2se(deploy=False, use_checkpoint=False):
    return RepVGG(num_blocks=[8, 14, 24, 1], num_classes=1000,
                  width_multiplier=[2.5, 2.5, 2.5, 5], override_groups_map=None,
                  deploy=deploy, use_se=True, use_checkpoint=use_checkpoint)


func_dict = {
    'RepVGG-A0': create_RepVGG_A0,
    'RepVGG-A1': create_RepVGG_A1,
    'RepVGG-A2': create_RepVGG_A2,
    'RepVGG-B0': create_RepVGG_B0,
    'RepVGG-B1': create_RepVGG_B1,
    'RepVGG-B1g2': create_RepVGG_B1g2,
    'RepVGG-B1g4': create_RepVGG_B1g4,
    'RepVGG-B2': create_RepVGG_B2,
    'RepVGG-B2g2': create_RepVGG_B2g2,
    'RepVGG-B2g4': create_RepVGG_B2g4,
    'RepVGG-B3': create_RepVGG_B3,
    'RepVGG-B3g2': create_RepVGG_B3g2,
    'RepVGG-B3g4': create_RepVGG_B3g4,
    'RepVGG-D2se': create_RepVGG_D2se,  #   Updated at April 25, 2021. This is not reported in the CVPR paper.
}

def get_RepVGG_func_by_name(name):
    return func_dict[name]


#   Use this for converting a RepVGG model or a bigger model with RepVGG as its component
#   Use like this
#   model = create_RepVGG_A0(deploy=False)
#   train model or load weights
#   repvgg_model_convert(model, save_path='repvgg_deploy.pth')
#   If you want to preserve the original model, call with do_copy=True

#   ====================== for using RepVGG as the backbone of a bigger model, e.g., PSPNet, the pseudo code will be like
#   train_backbone = create_RepVGG_B2(deploy=False)
#   train_backbone.load_state_dict(torch.load('RepVGG-B2-train.pth'))
#   train_pspnet = build_pspnet(backbone=train_backbone)
#   segmentation_train(train_pspnet)
#   deploy_pspnet = repvgg_model_convert(train_pspnet)
#   segmentation_test(deploy_pspnet)
#   =====================   example_pspnet.py shows an example

def repvgg_model_convert(model: torch.nn.Module, save_path=None, do_copy=True):
    if do_copy:
        model = copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, 'switch_to_deploy'):
            module.switch_to_deploy()
    if save_path is not None:
        torch.save(model.state_dict(), save_path)
    return model

(2) Model training: train.py

For this project I trained the lightweight RepVGG-A0 model, starting from the official pre-trained weights, and reached 86% accuracy on the test set, probably limited by how shallow the model is. The official pre-trained weights are linked below:
RepVGG official pre-trained models
Extraction code: rvgg
During training, import whichever model depth suits your needs.

The number of classification classes is set at the following position in the code:
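(an excerpt from the argparse section at the bottom of train.py shown below)

parser.add_argument('--num_classes', type=int, default=5)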

The full training code, train.py:

import os
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from my_dataset import MyDataSet
from model import create_RepVGG_A0 as create_model
from utils import read_split_data, create_lr_scheduler, get_params_groups, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    print(f"using {device} device.")

    if os.path.exists("weights_") is False:
        os.makedirs("weights_")

    tb_writer = SummaryWriter()

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = 224
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(int(img_size * 1.143)),
                                   transforms.CenterCrop(img_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = create_model(num_classes=args.num_classes).to(device)

    if args.weights != "":
        assert os.path.exists(args.weights), "weights file: '{}' does not exist.".format(args.weights)
        weights_dict = torch.load(args.weights, map_location=device)["model"]
        # remove weights tied to the number of classification classes
        for k in list(weights_dict.keys()):
            if "head" in k:
                del weights_dict[k]
        print(model.load_state_dict(weights_dict, strict=False))

    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze all weights except the head
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    # pg = [p for p in model.parameters() if p.requires_grad]
    pg = get_params_groups(model, weight_decay=args.wd)
    optimizer = optim.AdamW(pg, lr=args.lr, weight_decay=args.wd)
    lr_scheduler = create_lr_scheduler(optimizer, len(train_loader), args.epochs,
                                       warmup=True, warmup_epochs=1)

    best_acc = 0.
    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch,
                                                lr_scheduler=lr_scheduler)

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        if best_acc < val_acc:
            torch.save(model.state_dict(), "./weights_/best_model.pth")
            best_acc = val_acc


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=200)
    parser.add_argument('--batch-size', type=int, default=256)
    parser.add_argument('--lr', type=float, default=5e-4)
    parser.add_argument('--wd', type=float, default=5e-2)
    # root directory of the dataset
    parser.add_argument('--data-path', type=str, default="./flower_photos")
    # path to pre-trained weights; set to an empty string to skip loading
    parser.add_argument('--weights', type=str, default='', help='initial weights path')
    # whether to freeze all weights except the head
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()
    main(opt)
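A typical invocation might look like the following (the data path and weight file name are only examples; every flag is defined in the argparse section above, and passing --weights '' trains from scratch):

python train.py --data-path ./flower_photos --num_classes 5 --weights ./RepVGG-A0-train.pth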

(3) Inference testing: predict.py

Because RepVGG decouples training from inference, its network structure differs between the two phases: during training, the network contains shortcut branches and convolutions of different kernel sizes (3x3 and 1x1), while at inference it contains only 3x3 convolutions in a plain topology. The trained model therefore needs to be converted accordingly:

# create model
from model import repvgg_model_convert, create_RepVGG_A0
train_model = create_RepVGG_A0(deploy=False, num_classes=5).to(device)
train_model.load_state_dict(torch.load("weights_/best_model.pth"))
deploy_model = repvgg_model_convert(train_model)

Or, if a converted checkpoint has already been saved (e.g. via repvgg_model_convert(train_model, save_path='weights_/repvgg_deploy.pth')), build the deploy-time model directly and load it. Note that the deploy model cannot load the training-time best_model.pth, since the parameter names differ before and after conversion:

from model import repvgg_model_convert, create_RepVGG_A0
deploy_model = create_RepVGG_A0(deploy=True, num_classes=5)
deploy_model.load_state_dict(torch.load('weights_/repvgg_deploy.pth'))

The test result is shown as a screenshot in the original post.

Test code, predict.py:

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "test.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    print(img)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    from model import repvgg_model_convert, create_RepVGG_A0
    train_model = create_RepVGG_A0(deploy=False, num_classes=5).to(device)
    train_model.load_state_dict(torch.load("weights_/best_model.pth", map_location=torch.device('cpu')))
    deploy_model = repvgg_model_convert(train_model)
    deploy_model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(deploy_model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()

(4) Full project code

The complete project has been uploaded to CSDN!
