Learning Image Classification from the Experts series, → link ←

Links to the image classification posts on this blog:

AlexNet
VGG
GoogleNet (this post)
ResNet

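Preface

Image classification is the "quantitative accumulation" groundwork for learning object detection. So, without further ado, let's get started!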


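I. What is GoogleNet?

GoogleNet was proposed by a Google team in 2014 and took first place in the Classification Task of that year's ImageNet competition (VGG, which I just studied, came second, haha). The paper spells the name GoogLeNet; this post keeps the spelling GoogleNet to match the class name used in the code.

The GoogleNet paper: → link ←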

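II. Network Structure

1. Network features

  • Introduces the Inception structure (fuses feature information at different scales)
  • Uses 1*1 convolution kernels for dimensionality reduction and projection

The left figure (Figure 2(a) in the paper) shows the initial version of the Inception structure. In the networks studied earlier, convolutional layers, and convolutional and pooling layers, are connected in series; here the branches run in parallel. The right figure (Figure 2(b) in the paper) adds three 1*1 convolutions, which lower the depth of the feature maps, greatly reducing the number of parameters and therefore the amount of computation.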
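To make the saving concrete, here is a rough weight count for a single 5x5 branch. It is only a sketch: it assumes the channel numbers of the first Inception block used in the code later in this post (192 input channels, 16 reduction channels, 32 output channels) and ignores biases. The helper name `conv_weights` is made up for illustration.

```python
def conv_weights(in_channels, out_channels, kernel_size):
    """Number of weights in a square convolution layer (biases ignored)."""
    return kernel_size * kernel_size * in_channels * out_channels


in_channels, reduce_channels, out_channels = 192, 16, 32

# naive branch: a 5x5 convolution applied directly to the input
naive = conv_weights(in_channels, out_channels, 5)

# reduced branch: a 1x1 convolution first, then the 5x5 convolution
reduced = (conv_weights(in_channels, reduce_channels, 1)
           + conv_weights(reduce_channels, out_channels, 5))

print(naive)    # 153600
print(reduced)  # 15872
```

Under these assumptions the 1x1 reduction cuts the branch to roughly a tenth of its original weight count.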

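  • Adds two auxiliary classifiers to help training
    (the structures attached next to (4a) and (4d))
  • Drops the large fully connected layers in favor of an average pooling layer, leaving only one final linear classifier (this greatly reduces the number of model parameters)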
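A back-of-the-envelope comparison shows why the pooling choice matters. The sketch assumes the feature map before the classifier is 1024 x 7 x 7 (its shape for a 224 x 224 input) and 1000 output classes. The flatten-then-linear head is a hypothetical baseline for comparison, not something taken from the paper.

```python
channels, height, width, num_classes = 1024, 7, 7, 1000

# hypothetical baseline: flatten the whole feature map, then one linear layer
flatten_then_linear = channels * height * width * num_classes

# GoogleNet-style head: global average pooling to 1024 values, then one linear layer
pool_then_linear = channels * num_classes

print(flatten_then_linear)                      # 50176000
print(pool_then_linear)                         # 1024000
print(flatten_then_linear // pool_then_linear)  # 49
```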

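2. Structure

(Parameter notes)

Column name: meaning
  • type: the name of the structure at each layer of the network
  • patch size / stride: the structure's parameters (in the first row, for example, the structure is a convolutional layer with a 7*7 kernel and a stride of 2)
  • output size: the size of the output feature map
  • depth: how many layers the row's structure contains (in the third row, for example, the structure is convolutional with depth=2, meaning the data passes through two convolutional layers; compare with the network diagram in the paper, Figure 3, which is too long to paste here)
  • remaining 8 columns: six Inception settings (#1x1, #3x3 reduce, #3x3, #5x5 reduce, #5x5, pool proj), followed by params and ops

(As the figure shows, the column named #1x1 corresponds to the 1x1 convolutions in the Inception structure, and the number in each row is the number of convolution kernels.)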
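One way to read these columns: the four branches of an Inception block are concatenated along the channel dimension, so its output depth is #1x1 + #3x3 + #5x5 + pool proj. The sketch below checks this against the block settings used in the model code in this post. The dictionary and helper names are made up for illustration.

```python
# (in_channels, #1x1, #3x3 reduce, #3x3, #5x5 reduce, #5x5, pool proj)
INCEPTION_CONFIGS = {
    "3a": (192, 64, 96, 128, 16, 32, 32),
    "3b": (256, 128, 128, 192, 32, 96, 64),
    "4a": (480, 192, 96, 208, 16, 48, 64),
    "4b": (512, 160, 112, 224, 24, 64, 64),
    "4c": (512, 128, 128, 256, 24, 64, 64),
    "4d": (512, 112, 144, 288, 32, 64, 64),
    "4e": (528, 256, 160, 320, 32, 128, 128),
    "5a": (832, 256, 160, 320, 32, 128, 128),
    "5b": (832, 384, 192, 384, 48, 128, 128),
}


def output_channels(config):
    """Depth after concatenating the four branches; the reduce columns do not count."""
    _, ch1x1, _, ch3x3, _, ch5x5, pool_proj = config
    return ch1x1 + ch3x3 + ch5x5 + pool_proj


# the max pooling layers between blocks keep the depth unchanged, so each
# block's output depth must equal the next block's input depth
names = list(INCEPTION_CONFIGS)
for name, next_name in zip(names, names[1:]):
    assert output_channels(INCEPTION_CONFIGS[name]) == INCEPTION_CONFIGS[next_name][0]

print(output_channels(INCEPTION_CONFIGS["3a"]))  # 256
print(output_channels(INCEPTION_CONFIGS["5b"]))  # 1024
```

The last value, 1024, is the input size of the final linear layer in the model.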

III. Building GoogleNet with PyTorch

This code uses the "flower classification" dataset, → link ← (see README.md in the data_set folder for details)

  • model.py (builds the GoogleNet model)
import torch
import torch.nn as nn
import torch.nn.functional as F


class GoogleNet(nn.Module):
    # aux_logits: whether to use the auxiliary classifiers (True when training, False when validating)
    def __init__(self, num_classes=1000, aux_logits=True, init_weights=False):
        super(GoogleNet, self).__init__()
        self.aux_logits = aux_logits

        self.conv1 = BasicConv2d(3, 64, kernel_size=7, stride=2, padding=3)
        # when the computed output size is fractional, ceil_mode=True rounds up and False rounds down
        self.maxpool1 = nn.MaxPool2d(3, stride=2, ceil_mode=True)
        # nn.LocalResponseNorm (omitted here)
        self.conv2 = nn.Sequential(
            BasicConv2d(64, 64, kernel_size=1),
            BasicConv2d(64, 192, kernel_size=3, padding=1)
        )
        self.maxpool2 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception3a = Inception(192, 64, 96, 128, 16, 32, 32)
        self.inception3b = Inception(256, 128, 128, 192, 32, 96, 64)
        self.maxpool3 = nn.MaxPool2d(3, stride=2, ceil_mode=True)

        self.inception4a = Inception(480, 192, 96, 208, 16, 48, 64)
        self.inception4b = Inception(512, 160, 112, 224, 24, 64, 64)
        self.inception4c = Inception(512, 128, 128, 256, 24, 64, 64)
        self.inception4d = Inception(512, 112, 144, 288, 32, 64, 64)
        self.inception4e = Inception(528, 256, 160, 320, 32, 128, 128)
        self.maxpool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.inception5a = Inception(832, 256, 160, 320, 32, 128, 128)
        self.inception5b = Inception(832, 384, 192, 384, 48, 128, 128)

        if aux_logits:  # use the auxiliary classifiers
            self.aux1 = InceptionAux(512, num_classes)
            self.aux2 = InceptionAux(528, num_classes)

        # the input here is a 4-D feature map, so the 2d adaptive pooling layer is needed
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, num_classes)

        if init_weights:
            self._initialize_weight()

    def forward(self, x):
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)

        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.maxpool3(x)

        x = self.inception4a(x)
        if self.training and self.aux_logits:
            aux1 = self.aux1(x)
        x = self.inception4b(x)
        x = self.inception4c(x)
        x = self.inception4d(x)
        if self.training and self.aux_logits:
            aux2 = self.aux2(x)
        x = self.inception4e(x)
        x = self.maxpool4(x)

        x = self.inception5a(x)
        x = self.inception5b(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        if self.training and self.aux_logits:
            return x, aux1, aux2
        return x

    def _initialize_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


# Inception block (template)
class Inception(nn.Module):
    # the arguments are the kernel counts of the Inception block (see the table for details)
    def __init__(self, in_channels, ch1x1, ch3x3red, ch3x3, ch5x5red, ch5x5, pool_proj):
        super(Inception, self).__init__()
        # four parallel branches
        self.branch1 = BasicConv2d(in_channels, ch1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channels, ch3x3red, kernel_size=1),
            BasicConv2d(ch3x3red, ch3x3, kernel_size=3, padding=1)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channels, ch5x5red, kernel_size=1),
            BasicConv2d(ch5x5red, ch5x5, kernel_size=5, padding=2)
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            BasicConv2d(in_channels, pool_proj, kernel_size=1)
        )

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)

        # concatenate the branch outputs along the channel dimension
        outputs = [branch1, branch2, branch3, branch4]
        return torch.cat(outputs, 1)


# auxiliary classifier (template)
class InceptionAux(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.avgPool = nn.AvgPool2d(kernel_size=5, stride=3)
        self.conv = BasicConv2d(in_channels, 128, kernel_size=1)
        self.fc1 = nn.Linear(2048, 1024)  # 2048 = 128 x 4 x 4
        self.fc2 = nn.Linear(1024, num_classes)

    def forward(self, x):
        # input: aux1: N x 512 x 14 x 14   aux2: N x 528 x 14 x 14
        x = self.avgPool(x)
        # after pooling: aux1: N x 512 x 4 x 4   aux2: N x 528 x 4 x 4, since 4 = (14 - 5) / 3 + 1
        x = self.conv(x)
        # N x 128 x 4 x 4
        x = torch.flatten(x, 1)  # flatten to N x 2048
        x = F.dropout(x, 0.5, training=self.training)
        x = F.relu(self.fc1(x), inplace=True)
        x = F.dropout(x, 0.5, training=self.training)
        x = self.fc2(x)
        return x


# convolution followed by ReLU (template)
class BasicConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x
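The shape comments in the model (14 x 14 into the auxiliary classifiers, 4 x 4 after their average pooling, 2048 inputs to fc1) can be checked with plain arithmetic. This is only a sketch: it assumes a 224 x 224 input, relies on the Inception blocks keeping the spatial size unchanged, and uses a helper `out_size` that is made up for illustration rather than taken from any PyTorch API.

```python
import math


def out_size(size, kernel, stride=1, padding=0, ceil_mode=False):
    """Output length along one spatial dimension of a conv or pooling layer."""
    span = (size + 2 * padding - kernel) / stride
    out = (math.ceil(span) if ceil_mode else math.floor(span)) + 1
    # with ceil_mode, a window that would start past the (left-padded) input is dropped
    if ceil_mode and (out - 1) * stride >= size + padding:
        out -= 1
    return out


size = 224
size = out_size(size, 7, stride=2, padding=3)        # conv1 -> 112
size = out_size(size, 3, stride=2, ceil_mode=True)   # maxpool1 -> 56
size = out_size(size, 3, padding=1)                  # 3x3 conv in conv2 keeps 56
size = out_size(size, 3, stride=2, ceil_mode=True)   # maxpool2 -> 28
size = out_size(size, 3, stride=2, ceil_mode=True)   # maxpool3 -> 14
aux = out_size(size, 5, stride=3)                    # auxiliary avg pool -> 4
final = out_size(size, 2, stride=2, ceil_mode=True)  # maxpool4 -> 7

print(size, aux, final)  # 14 4 7
print(128 * aux * aux)   # 2048
```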
  • train.py (trains the network)
import os
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

from model import GoogleNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    net = GoogleNet(num_classes=5, aux_logits=True, init_weights=True)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0003)

    epochs = 30
    best_acc = 0.0
    save_path = './googleNet.pth'
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader)
        for step, data in enumerate(train_bar):
            images, labels = data
            optimizer.zero_grad()
            # in training mode the auxiliary classifiers are active, so the model
            # returns three outputs: the main logits, then aux1, then aux2
            logits, aux_logits1, aux_logits2 = net(images.to(device))
            loss0 = loss_function(logits, labels.to(device))
            loss1 = loss_function(aux_logits1, labels.to(device))
            loss2 = loss_function(aux_logits2, labels.to(device))
            loss = loss0 + loss1 * 0.3 + loss2 * 0.3
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss.item())

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))  # in eval mode the model returns only the main output
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()
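The training loss adds the two auxiliary losses to the main loss with a weight of 0.3 each. As a rough illustration of that combination, the sketch below computes cross-entropy for a single sample by hand (negative log-softmax at the target index) and then applies the same weighting. The logits are made up, and the helper names `cross_entropy` and `combined_loss` are hypothetical, not part of the training script.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: negative log-softmax at the target index."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return log_sum_exp - logits[target]


def combined_loss(main_loss, aux_losses, aux_weight=0.3):
    """Main loss plus each auxiliary loss scaled by aux_weight."""
    return main_loss + aux_weight * sum(aux_losses)


# made-up logits for one sample with 5 classes; the true class is index 4
main = cross_entropy([0.1, 0.2, 0.3, 0.4, 2.0], 4)
aux1 = cross_entropy([0.0, 0.0, 0.0, 0.0, 0.0], 4)  # uniform guess gives log(5)
aux2 = cross_entropy([0.0, 0.0, 0.0, 0.0, 0.0], 4)

total = combined_loss(main, [aux1, aux2])
print(round(aux1, 4))  # 1.6094
```

Because the auxiliary classifiers are only used during training, this weighting plays no role at validation or prediction time.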
  • predict.py (classifies an image with the trained model)
import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import GoogleNet


def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize((224, 224)),
         transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # load image
    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)
    img = Image.open(img_path)
    plt.imshow(img)
    # [C, H, W]
    img = data_transform(img)
    # expand batch dimension: [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model (no auxiliary classifiers at prediction time)
    model = GoogleNet(num_classes=5, aux_logits=False).to(device)

    # load model weights; strict=False skips the auxiliary-classifier weights saved during training
    weights_path = "./googleNet.pth"
    assert os.path.exists(weights_path), "file: '{}' does not exist.".format(weights_path)
    missing_keys, unexpected_keys = model.load_state_dict(torch.load(weights_path, map_location=device),
                                                          strict=False)

    model.eval()
    with torch.no_grad():
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_res)
    print(print_res)
    plt.show()


if __name__ == '__main__':
    main()
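The prediction script looks up the class name with `class_indict[str(predict_cla)]` because JSON object keys are always strings. The sketch below reproduces that round trip and the softmax/argmax step in plain Python. The class names come from the comment in the training script, while the logits are made up for illustration.

```python
import json
import math

# class mapping from the training script's comment, inverted to index -> name
class_to_idx = {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

# JSON object keys are always strings, so the integer keys come back as '0', '1', ...
class_indict = json.loads(json.dumps(idx_to_class))

# made-up logits standing in for the model output on one image
logits = [0.2, -1.0, 0.5, 0.1, 3.0]
peak = max(logits)
exps = [math.exp(v - peak) for v in logits]
probs = [e / sum(exps) for e in exps]

# argmax over the probabilities
best = max(range(len(probs)), key=probs.__getitem__)

print(sorted(class_indict))     # ['0', '1', '2', '3', '4']
print(class_indict[str(best)])  # tulips
```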

Code link: https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/Test4_googlenet
