下图展示了传统卷积与DW卷积的差异,在传统卷积中,每个卷积核的channel与输入特征矩阵的channel相等(每个卷积核都会与输入特征矩阵的每一个维度进行卷积运算)。而在DW卷积中,每个卷积核的channel都是等于1的(每个卷积核只负责输入特征矩阵的一个channel,故卷积核的个数必须等于输入特征矩阵的channel数,从而使得输出特征矩阵的channel数也等于输入特征矩阵的channel数)

如果想改变输出特征矩阵的channel,只需要在DW卷积后接上一个PW卷积即可,如下图所示,其实PW卷积就是普通的卷积而已(只不过卷积核大小为1)。通常DW卷积和PW卷积是放在一起使用的,一起叫做Depthwise Separable Convolution(深度可分卷积)

左侧是ResNet网络中的残差结构,右侧就是MobileNet v2中的倒残差结构。在残差结构中是1x1卷积降维->3x3卷积->1x1卷积升维,在倒残差结构中是1x1卷积升维->3x3DW卷积->1x1卷积降维。(注意倒残差结构中基本使用的都是ReLU6激活函数,但是最后一个1x1的卷积层使用的是线性激活函数)

输入特征矩阵为h*w*k,经过1*1conv(卷积核个数为tk)后为h*w*tk,【t为一个扩展因子,对应倒残差结构中第一层1*1conv卷积核的扩展倍率】,再经过一个3*3步距为s的DW卷积后为h/s*w/s*tk,再经过1*1conv(卷积核个数为k')后为h/s*w/s*k'

t为一个扩展因子,对应倒残差结构中第一层1*1conv卷积核的扩展倍率,

c代表输出特征矩阵的channel,

n代表bottlenect(倒残差结构)重复的次数,

s代表每一个block中第一层bottlenect(倒残差结构)所对应的步距,该block中其它层bottlenect所对应的步距都为1,步距指的是DW卷积的步距

model_v2.py

from torch import nn
import torchdef _make_divisible(ch, divisor=8, min_ch=None):"""This function is taken from the original tf repo.It ensures that all layers have a channel number that is divisible by 8It can be seen here:https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py"""if min_ch is None:min_ch = divisornew_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor) # 保证ch是divisor的整数倍# Make sure that round down does not go down by more than 10%.if new_ch < 0.9 * ch:new_ch += divisorreturn new_chclass ConvBNReLU(nn.Sequential): # 定义普通卷积def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, groups=1): # groups=1为普通卷积,groups=in_channel为depthwise卷积padding = (kernel_size - 1) // 2 # kernel_size=3则padding=1;kernel_size=1则padding=0super(ConvBNReLU, self).__init__(nn.Conv2d(in_channel, out_channel, kernel_size, stride, padding, groups=groups, bias=False), # 如果要使用BN层,就不用使用偏置了nn.BatchNorm2d(out_channel),nn.ReLU6(inplace=True))class InvertedResidual(nn.Module): # 定义倒残差结构def __init__(self, in_channel, out_channel, stride, expand_ratio): # expand_ratio为扩展因子,就是表格中的tsuper(InvertedResidual, self).__init__()hidden_channel = in_channel * expand_ratio # 第一层卷积层的卷积核的个数self.use_shortcut = stride == 1 and in_channel == out_channel # 用于判断在正向传播过程中是否使用shortcutlayers = []if expand_ratio != 1:# 倒残差结构的第一层  1x1 pointwise convlayers.append(ConvBNReLU(in_channel, hidden_channel, kernel_size=1)) # 扩展因子等于1,这个卷积层可以省略layers.extend([     # .extend批量插入很多函数# 倒残差结构的第二层  3x3 depthwise convConvBNReLU(hidden_channel, hidden_channel, stride=stride, groups=hidden_channel),# 倒残差结构的第三层  1x1 pointwise conv(linear) 线性激活函数就是不用添加激活函数(y=x)nn.Conv2d(hidden_channel, out_channel, kernel_size=1, bias=False),nn.BatchNorm2d(out_channel),])self.conv = nn.Sequential(*layers) # 将一系列层结构打包成一个整体def forward(self, x): # 定义正向传播过程if self.use_shortcut: # 使用shortcutreturn x + self.conv(x)else: # 不使用shortcutreturn self.conv(x)class MobileNetV2(nn.Module): # 定义MobileNetV2结构def __init__(self, num_classes=1000, alpha=1.0, round_nearest=8): # alpha为一个超参数,卷积核的倍率super(MobileNetV2, self).__init__()block = InvertedResidual # 倒残差结构传给blockinput_channel = _make_divisible(32 * alpha, round_nearest) # 将卷积核的个数调整到8的整数倍last_channel = _make_divisible(1280 * alpha, round_nearest)inverted_residual_setting = [# t, c, n, s[1, 16, 1, 1],[6, 24, 2, 2],[6, 32, 3, 2],[6, 64, 4, 2],[6, 96, 3, 1],[6, 160, 3, 2],[6, 320, 1, 1],]features = []# 定义第一层卷积层,输入为RGB三通道,输出为input_channel  conv1 layerfeatures.append(ConvBNReLU(3, input_channel, stride=2))# 定义一系列bottleneck层  building inverted residual residual blockesfor t, c, n, s in inverted_residual_setting:output_channel = _make_divisible(c * alpha, round_nearest)for i in range(n):stride = s if i == 0 else 1features.append(block(input_channel, output_channel, stride, expand_ratio=t))input_channel = output_channel# 定义倒数第三层的卷积层  building last several layersfeatures.append(ConvBNReLU(input_channel, last_channel, kernel_size=1))# combine feature layersself.features = nn.Sequential(*features) # 将一系列层结构打包成一个整体#-----------------------以上是特征提取部分-------------------------# 定义分类器部分(表格中的最后两层)  building classifierself.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # 自适应的平均池化下采样操作self.classifier = nn.Sequential(nn.Dropout(0.2),nn.Linear(last_channel, num_classes))# ------------------------以上是分类器部分-----------------------# 权重初始化 weight initializationfor m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out')if m.bias is not None:nn.init.zeros_(m.bias)elif isinstance(m, nn.BatchNorm2d):nn.init.ones_(m.weight)nn.init.zeros_(m.bias)elif isinstance(m, nn.Linear):nn.init.normal_(m.weight, 0, 0.01)nn.init.zeros_(m.bias)def forward(self, x): # 前向传播过程x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

train.py

import os
import sys
import jsonimport torch
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from torchvision import transforms, datasets
from torchvision.datasets import ImageFolder
from tqdm import tqdmfrom model_v2 import MobileNetV2# 下载预训练权重
import torchvision.models.mobilenetimport os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'# 解决中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = FalseROOT_TRAIN = r'D:/cnn/All Classfication/ResNet/data/train'
ROOT_TEST = r'D:/cnn/All Classfication/ResNet/data/val'def main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]), # 这里的标准化参数是官网提供的,不做修改"val": transforms.Compose([transforms.Resize(256), # 将原图像长宽比固定,再将其最小边缩放到256transforms.CenterCrop(224), # 在使用中心裁剪到224 * 224大小transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}train_dataset = ImageFolder(ROOT_TRAIN, transform=data_transform["train"])  # 加载训练集train_num = len(train_dataset)  # 打印训练集有多少张图片animal_list = train_dataset.class_to_idx  # 获取类别名称以及对应的索引cla_dict = dict((val, key) for key, val in animal_list.items())  # 将上面的键值对位置对调一下json_str = json.dumps(cla_dict, indent=4)  # 把类别和对应的索引写入根目录下class_indices.json文件中with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = ImageFolder(ROOT_TEST, transform=data_transform["val"])  # 载入测试集val_num = len(validate_dataset)  # 打印测试集有多少张图片validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=16, shuffle=False,num_workers=0)print("using {} images for training, {} images for validation.".format(train_num, val_num))# create modelnet = MobileNetV2(num_classes=2) # 实例化模型,定义类别个数# load pretrain weightsmodel_weight_path = "./mobilenet_v2.pth"assert os.path.exists(model_weight_path), "file {} dose not exist.".format(model_weight_path)pre_weights = torch.load(model_weight_path, map_location='cpu') # 通过torch.load载入预训练模型参数# delete classifier weights 便利权重字典,去除含classifier的层pre_dict = {k: v for k, v in pre_weights.items() if net.state_dict()[k].numel() == v.numel()}missing_keys, unexpected_keys = net.load_state_dict(pre_dict, strict=False)# freeze features weights 冻结特征提取部分的权重for param in net.features.parameters():param.requires_grad = Falsenet.to(device)# define loss functionloss_function = nn.CrossEntropyLoss()# construct an optimizerparams = [p for p in net.parameters() if p.requires_grad]optimizer = optim.Adam(params, lr=0.0001)epochs = 10best_acc = 0.0save_path = './MobileNetV2.pth'train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()logits = net(images.to(device))loss = loss_function(logits, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader, file=sys.stdout)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))# loss = loss_function(outputs, test_labels)predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,epochs)val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

reference

MobileNet(v1、v2)网络详解与模型的搭建_太阳花的小绿豆的博客-CSDN博客

7.1 MobileNet网络详解_哔哩哔哩_bilibili

MobileNet网络结构详解相关推荐

  1. AlexNet网络结构详解与代码复现

    参考内容来自up:3.1 AlexNet网络结构详解与花分类数据集下载_哔哩哔哩_bilibili up主的CSDN博客:太阳花的小绿豆的博客_CSDN博客-深度学习,软件安装,Tensorflow领 ...

  2. U-Net网络结构详解

    U-Net网络结构详解 U-Net网络结构是对称的,由于网络结构像U型,所以被命名为U-Net.整体而言,U-Net是一个Encoder-Decoder(编码器-解码器)的结构,这一点是与FCN的结构 ...

  3. pytorch图像分类篇:6. ResNet网络结构详解与迁移学习简介

    前言 最近在b站发现了一个非常好的 计算机视觉 + pytorch 的教程,相见恨晚,能让初学者少走很多弯路. 因此决定按着up给的教程路线:图像分类→目标检测→-一步步学习用pytorch实现深度学 ...

  4. 深度学习之目标检测(五)-- RetinaNet网络结构详解

    深度学习之目标检测(五)-- RetinaNet网络结构详解 深度学习之目标检测(五)RetinaNet网络结构详解 1. RetinaNet 1.1 backbone 部分 1.2 预测器部分 1. ...

  5. AlexNet网络结构详解(含各层维度大小计算过程)与PyTorch实现

    AlexNet网络结构详解(含各层维度大小计算过程)与PyTorch实现 1.AlexNet之前的思考 2.AlexNet网络结构 3.AlexNet网络结构的主要贡献 4.PyTorch实现     ...

  6. 基于CIFAR100的VGG网络结构详解

    基于CIFAR100的VGG网络详解 码字不易,点赞收藏 1 数据集概况 1.1 CIFAR100 cifar100包含20个大类,共100类,train集50000张图片,test集10000张图片 ...

  7. ResNet网络结构详解,网络搭建,迁移学习

    前言: 参考内容来自up:6.1 ResNet网络结构,BN以及迁移学习详解_哔哩哔哩_bilibili up的代码和ppt:https://github.com/WZMIAOMIAO/deep-le ...

  8. OSI七层网络结构详解

    OSI模型的分层结构 OSI(Open System Interconnection),开放式系统互联参考模型 ,它把网络协议从逻辑上分为了7层.这7层分别为:物理层.数据链路层.网络层.传输层.会话 ...

  9. Network in Network(NIN)网络结构详解,网络搭建

    一.简介 Network in Network,描述了一种新型卷积神经网络结构. LeNet,AlexNet,VGG都秉承一种设计思路:先用卷积层构成的模块提取空间特征,再用全连接层模块来输出分类结果 ...

最新文章

  1. Java初学者如何迈出AOP第一步--使用Java 动态代理实现AOP
  2. 新加坡计划通过区块链促进东盟金融包容性
  3. python中组合框_PyQt 组合框
  4. php实现sql server数据导入到mysql数据库_php实现SQL Server数据导入Mysql数据库(示例)...
  5. IoU-aware的目标检测,显著提高定位精度
  6. 年味十足的手绘年画风新年春节海报PSD模板
  7. 中继链路,以太网通道,DHCP配置
  8. 牛客网编程题——字符串_确定两串乱序同构
  9. Java模拟实现一个基于文本界面的《家庭记账软件》
  10. 单片机用C语言锯齿波,试用c语言编写一个能输出锯齿波信号的单片机c51程序
  11. Linux网络下载管理工具(lftp, ftp, lftpget, wget)
  12. Oracle pmon是什么,oracle 11g pmon工作内容系列二
  13. linux远程连接交换机,总结:交换机远程登陆的两种方式,Telnet与SSH那种好?
  14. [网络] 数字签名和数字证书的原理机制
  15. 数据一致性、准确性、完整性、及时性、有效性
  16. 手机无线如何共享给台式计算机,教你用手机做热点分享wifi给台式电脑用,不是用数据网络哦...
  17. 通信工程和计算机考研哪个好,通信工程考研还是就业
  18. 【首次起用黑人模特的Prada】
  19. 蓝桥 算法训练 藏匿的刺客(C语言)
  20. sicily 1050——5个数通过加减乘除运算得到目标数

热门文章

  1. android motion linux handle,Android实现刮奖的效果
  2. java实现多级反馈队列_多级反馈队列调度算法
  3. tq2440流水灯实验
  4. ubuntu启动、关闭、重启服务service命令
  5. 【调剂】四川大学计算机学院(软件学院、智能科学与技术学院)2023年非全日制硕士研究生接收调剂生的通知...
  6. js控制html页面缓存,js页面缓存问题
  7. 设计一个三阶巴特沃斯滤波器_二、三阶巴特沃斯滤波器电路设计—电路精选(47)...
  8. 从千年虫,闰年虫,闰秒虫看测试数据设计
  9. 泛微消息服务器ecology9,ecology-9-demo
  10. 百度地图API (1):往地图中添加标注点