1. ShuffleNet V1 Network Architecture

ShuffleNet V1 network

2. ShuffleNet V2 Network Architecture

ShuffleNet V2 network

3. Code Implementation

3.1.model.py

from typing import List, Callable

import torch
from torch import Tensor
import torch.nn as nn


def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    # [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batch_size, -1, height, width)

    return x


class InvertedResidual(nn.Module):
    def __init__(self, input_c: int, output_c: int, stride: int):
        super(InvertedResidual, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride

        assert output_c % 2 == 0
        branch_features = output_c // 2
        # When stride is 1, input_c must be twice branch_features.
        # '<<' is a bit shift; `x << 1` is a fast way to compute x * 2.
        assert (self.stride != 1) or (input_c == branch_features << 1)

        if self.stride == 2:
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            self.branch1 = nn.Sequential()

        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def depthwise_conv(input_c: int,
                       output_c: int,
                       kernel_s: int,
                       stride: int = 1,
                       padding: int = 0,
                       bias: bool = False) -> nn.Conv2d:
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        out = channel_shuffle(out, 2)

        return out


class ShuffleNetV2(nn.Module):
    def __init__(self,
                 stages_repeats: List[int],
                 stages_out_channels: List[int],
                 num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual) -> None:
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # input RGB image
        input_channels = 3
        output_channels = self._stage_out_channels[0]

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential

        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats,
                                                  self._stage_out_channels[1:]):
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global pool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def shufflenet_v2_x1_0(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.

    weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 116, 232, 464, 1024],
                         num_classes=num_classes)
    return model


def shufflenet_v2_x0_5(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.

    weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 48, 96, 192, 1024],
                         num_classes=num_classes)
    return model


model_x1_0 = shufflenet_v2_x1_0(num_classes=5)
print(model_x1_0)
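The effect of channel_shuffle on the channel ordering can be checked without PyTorch by mirroring its reshape -> transpose -> flatten steps on a plain list of channel indices. This is a minimal sketch; the helper name shuffle_order is made up for illustration and is not part of model.py:

```python
def shuffle_order(num_channels, groups):
    """Return the channel permutation produced by channel_shuffle.

    Mirrors the reshape -> transpose -> flatten steps of the PyTorch
    implementation, but applied to a flat list of channel indices.
    """
    cpg = num_channels // groups  # channels per group
    # reshape: [num_channels] -> [groups][channels_per_group]
    grouped = [list(range(g * cpg, (g + 1) * cpg)) for g in range(groups)]
    # transpose: [groups][cpg] -> [cpg][groups]
    transposed = list(zip(*grouped))
    # flatten back into a single channel list
    return [c for row in transposed for c in row]


print(shuffle_order(6, 2))  # -> [0, 3, 1, 4, 2, 5]
```

With 6 channels in 2 groups, the shuffle interleaves the two halves, which is exactly what lets information flow between the branches of consecutive blocks.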

3.2.train.py

import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import shufflenet_v2_x1_0
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    # instantiate the training dataset
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # instantiate the validation dataset
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers per process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # load pre-trained weights if a weights file is given
    model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # optionally freeze weights
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # freeze everything except the final fully connected layer
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)

        scheduler.step()

        # validate
        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)

        print("[epoch {}] accuracy: {}".format(epoch, round(acc, 3)))
        tags = ["loss", "accuracy", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch-size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.1)

    # dataset root directory
    # https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
    parser.add_argument('--data-path', type=str,
                        default="./data/flower_photos")

    # download link for the official shufflenetv2_x1.0 weights
    # https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
    parser.add_argument('--weights', type=str, default='./shufflenetv2_x1.pth',
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()
    main(opt)
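The cosine schedule built by the lf lambda can be inspected on its own: it decays the learning-rate multiplier smoothly from 1.0 down to lrf over the course of training. A standalone sketch, with the hypothetical helper name cosine_lf and the script's defaults (epochs=30, lrf=0.1) baked in:

```python
import math


def cosine_lf(epoch, epochs=30, lrf=0.1):
    # same formula as the lf lambda in train.py:
    # multiplier starts at 1.0 and decays along a half cosine toward lrf
    return ((1 + math.cos(epoch * math.pi / epochs)) / 2) * (1 - lrf) + lrf


# multiplier at the start, middle, and last epoch (base lr in train.py is 0.01)
print(round(cosine_lf(0), 3))   # -> 1.0
print(round(cosine_lf(15), 3))  # -> 0.55
print(round(cosine_lf(29), 3))  # -> 0.102
```

Because LambdaLR multiplies the base learning rate by this value each epoch, the effective learning rate falls from 0.01 to roughly 0.001 by the end of training.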

3.3.predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import shufflenet_v2_x1_0


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "./tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [N, C, H, W]
    img = data_transform(img)
    # expand batch dimension
    img = torch.unsqueeze(img, dim=0)

    # read the class index mapping
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)

    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model
    model = shufflenet_v2_x1_0(num_classes=5).to(device)
    # load model weights
    model_weight_path = "./weights/model-2.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i in range(len(predict)):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  predict[i].numpy()))
    plt.show()


if __name__ == '__main__':
    main()
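The torch.softmax / torch.argmax step at the end of main can be illustrated with a plain-Python, numerically stable softmax. This is a sketch with made-up logits; the helper name softmax here is a local stand-in, not the torch function:

```python
import math


def softmax(logits):
    # numerically stable softmax: subtract the max before exponentiating,
    # matching what torch.softmax computes over dim=0 in predict.py
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.2, 0.3, 4.1, -0.5, 2.0]  # hypothetical raw outputs for 5 classes
probs = softmax(logits)
# the predicted class is the index with the highest probability
print(max(range(len(probs)), key=probs.__getitem__))  # -> 2
```

The probabilities always sum to 1, and the argmax of the probabilities equals the argmax of the raw logits, so the softmax only matters for reporting confidences, not for picking the class.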

3.4.my_dataset.py

from PIL import Image
import torch
from torch.utils.data import Dataset


class MyDataSet(Dataset):
    """Custom dataset"""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        # 'RGB' means a color image, 'L' a grayscale image
        if img.mode != 'RGB':
            raise ValueError("image: {} isn't RGB mode.".format(self.images_path[item]))
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # for the official default_collate implementation, see
        # https://github.com/pytorch/pytorch/blob/67b7e751e6b5931a9f45274653f4f653a4e6cdf6/torch/utils/data/_utils/collate.py
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
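The tuple(zip(*batch)) idiom in collate_fn is just an "unzip": it turns a list of (image, label) pairs into one tuple of images and one tuple of labels, which torch.stack and torch.as_tensor then convert into batch tensors. A sketch with placeholder strings standing in for image tensors:

```python
# a hypothetical batch, as a DataLoader would assemble it from __getitem__
batch = [("img0", 3), ("img1", 1), ("img2", 3)]  # (image, label) pairs

# unzip: list of pairs -> pair of tuples
images, labels = tuple(zip(*batch))

print(images)  # -> ('img0', 'img1', 'img2')
print(labels)  # -> (3, 1, 3)
```

In the real collate_fn, images would then be stacked along a new batch dimension, giving a [N, C, H, W] tensor.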

3.5.utils.py

import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm
import matplotlib.pyplot as plt


def read_split_data(root: str, val_rate: float = 0.2):
    random.seed(0)  # make the random split reproducible
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # walk the root folder; each sub-folder corresponds to one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    # sort to guarantee a consistent order
    flower_class.sort()
    # build the class-name -> numeric-index mapping
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []  # all training image paths
    train_images_label = []  # labels of the training images
    val_images_path = []  # all validation image paths
    val_images_label = []  # labels of the validation images
    every_class_num = []  # number of samples in each class
    supported = [".jpg", ".JPG", ".png", ".PNG"]  # supported file extensions
    # iterate over the files in each class folder
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # collect the paths of all files with a supported extension
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        # numeric index of this class
        image_class = class_indices[cla]
        # record the number of samples in this class
        every_class_num.append(len(images))
        # randomly sample validation images at the given rate
        val_path = random.sample(images, k=int(len(images) * val_rate))

        for img_path in images:
            if img_path in val_path:  # sampled paths go into the validation set
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:  # everything else goes into the training set
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print("{} images were found in the dataset.".format(sum(every_class_num)))
    print("{} images for training.".format(len(train_images_path)))
    print("{} images for validation.".format(len(val_images_path)))

    plot_image = False
    if plot_image:
        # bar chart of the number of samples per class
        plt.bar(range(len(flower_class)), every_class_num, align='center')
        # replace the x ticks 0,1,2,3,4 with the class names
        plt.xticks(range(len(flower_class)), flower_class)
        # add value labels on top of the bars
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label


def plot_data_loader_image(data_loader):
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), json_path + " does not exist."
    json_file = open(json_path, 'r')
    class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C]
            img = images[i].numpy().transpose(1, 2, 0)
            # undo the Normalize transform
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])  # remove x-axis ticks
            plt.yticks([])  # remove y-axis ticks
            plt.imshow(img.astype('uint8'))
        plt.show()


def write_pickle(list_info: list, file_name: str):
    with open(file_name, 'wb') as f:
        pickle.dump(list_info, f)


def read_pickle(file_name: str) -> list:
    with open(file_name, 'rb') as f:
        info_list = pickle.load(f)
    return info_list


def train_one_epoch(model, optimizer, data_loader, device, epoch):
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    optimizer.zero_grad()

    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)  # update mean losses

        data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        if not torch.isfinite(loss):
            print('WARNING: non-finite loss, ending training ', loss)
            sys.exit(1)

        optimizer.step()
        optimizer.zero_grad()

    return mean_loss.item()


@torch.no_grad()
def evaluate(model, data_loader, device):
    model.eval()

    # total number of validation samples
    total_num = len(data_loader.dataset)

    # accumulates the number of correctly predicted samples
    sum_num = torch.zeros(1).to(device)

    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]
        sum_num += torch.eq(pred, labels.to(device)).sum()

    return sum_num.item() / total_num
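The core of read_split_data's split is random.sample picking a val_rate fraction of each class for validation, with everything else going to training. The same scheme on a toy list of paths (the helper name split_paths is illustrative only, not part of utils.py):

```python
import random


def split_paths(paths, val_rate=0.2, seed=0):
    # same sampling scheme as read_split_data: a seeded random.sample
    # takes val_rate of the items for validation, the rest for training
    random.seed(seed)
    val = set(random.sample(paths, k=int(len(paths) * val_rate)))
    train = [p for p in paths if p not in val]
    return train, sorted(val)


paths = ["img{}.jpg".format(i) for i in range(10)]
train, val = split_paths(paths)
print(len(train), len(val))  # -> 8 2
```

Seeding the generator before sampling is what makes the split reproducible across runs, which matters because the class_indices.json written at training time must match the one read at prediction time.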

4. Prediction Results
