三、VGG

论文下载地址

VGG
是Oxford的Visual Geometry Group的组提出的(大家应该能看出VGG名字的由来了)。该网络是在ILSVRC 2014上的相关工作,主要工作是证明了增加网络的深度能够在一定程度上影响网络最终的性能。VGG有两种结构,分别是VGG16和VGG19,两者并没有本质上的区别,只是网络深度不一样。

VGG原理

VGG16相比AlexNet的一个改进是采用连续的几个3x3的卷积核代替AlexNet中的较大卷积核(11x11,7x7,5x5)。对于给定的感受野(与输出有关的输入图片的局部大小),采用堆积的小卷积核是优于采用大的卷积核,因为多层非线性层可以增加网络深度来保证学习更复杂的模式,而且代价还比较小(参数更少)。

简单来说,在VGG中,使用了3个3x3卷积核来代替7x7卷积核,使用了2个3x3卷积核来代替5*5卷积核,这样做的主要目的是在保证具有相同感知野的条件下,提升了网络的深度,在一定程度上提升了神经网络的效果。

比如,3个步长为1的3x3卷积核的一层层叠加作用可看成一个大小为7的感受野(其实就表示3个3x3连续卷积相当于一个7x7卷积),其参数总量为 3x(9xC^2) ,如果直接使用7x7卷积核,其参数总量为 49xC^2 ,这里 C 指的是输入和输出的通道数。很明显,27xC2小于49xC2,即减少了参数;而且3x3卷积核有利于更好地保持图像性质。

这里解释一下为什么使用2个3x3卷积核可以来代替5*5卷积核:

5x5卷积看做一个小的全连接网络在5x5区域滑动,我们可以先用一个3x3的卷积滤波器卷积,然后再用一个全连接层连接这个3x3卷积输出,这个全连接层我们也可以看做一个3x3卷积层。这样我们就可以用两个3x3卷积级联(叠加)起来代替一个 5x5卷积。

基本概念拓展——cnn感受野
在卷积神经网络中,决定某一层输出 结果中一个元素所对应的输入层的区域大 小,被称作感受野(receptive field)。通俗 的解释是,输出feature map上的一个单元 对应输入层上的区域大小。
比如上图为:
计算公式为:

 F(i)为第i层感受野 Stride为第i层的步距Ksize为卷积核或池化核尺寸

因此

Feature map: F = 1
Pool1:  F = (1 - 1) x 2 + 2 = 2
Conv1:  F = (2 - 1) x 2 + 3 = 5

具体如下图所示:

至于为什么使用3个3x3卷积核可以来代替7*7卷积核,推导过程与上述类似,大家可以自行绘图理解。

VGG网络结构
下面是VGG网络的结构(VGG16和VGG19都在):

  • VGG16包含了16个隐藏层(13个卷积层和3个全连接层),如上图中的D列所示

  • VGG19包含了19个隐藏层(16个卷积层和3个全连接层),如上图中的E列所示

VGG网络的结构非常一致,从头到尾全部使用的是3x3的卷积和2x2的max pooling。

虽然网络层数加深,但VGG在训练的过程中比AlexNet收敛的要快一些,主要因为:
(1)使用小卷积核和更深的网络进行的正则化;
(2)在特定的层使用了预训练得到的数据进行参数的初始化。对于较浅的网络,如网络A,可以直接使用随机数进行随机初始化,而对于比较深的网络,则使用前面已经训练好的较浅的网络中的参数值对其前几层的卷积层和最后的全连接层进行初始化。

VGG优点

  • VGGNet的结构非常简洁,整个网络都使用了同样大小的卷积核尺寸(3x3)和最大池化尺寸(2x2)。
  • 几个小滤波器(3x3)卷积层的组合比一个大滤波器(5x5或7x7)卷积层好:
  • 验证了通过不断加深网络结构可以提升性能。

VGG缺点

  • VGG耗费更多计算资源,并且使用了更多的参数(这里不是3x3卷积的锅),导致更多的内存占用(140M)。其中绝大多数的参数都是来自于第一个全连接层。VGG可是有3个全连接层啊!

注:很多pretrained的方法就是使用VGG的model(主要是16和19),VGG相对其他的方法,参数空间很大,最终的model有500多m,AlexNet只有200m,GoogLeNet更少,所以train一个vgg模型通常要花费更长的时间,所幸有公开的pretrained model让我们很方便的使用。

代码篇:VGG训练与测试

这里推荐两个开源库,训练请参考tensorflow-vgg
快速测试请参考VGG-in TensorFlow。

代码实现

注:本次训练集下载在AlexNet博客有详细解说:https://blog.csdn.net/weixin_44023658/article/details/105798326

#model.pyimport torch.nn as nn
import torch#分类网络
class VGG(nn.Module):def __init__(self, features, num_classes=1000, init_weights=False):super(VGG, self).__init__()self.features = features#先会进行一个展平处理self.classifier = nn.Sequential( #定义分类网络结构nn.Dropout(p=0.5), #减少过拟合nn.Linear(512*7*7, 2048),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(True),nn.Linear(2048, num_classes))if init_weights:self._initialize_weights()def forward(self, x):# N x 3 x 224 x 224x = self.features(x)# N x 512 x 7 x 7x = torch.flatten(x, start_dim=1)#展平处理# N x 512*7*7x = self.classifier(x)return xdef _initialize_weights(self): #初始化权重函数for m in self.modules():if isinstance(m, nn.Conv2d):# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight)if m.bias is not None:nn.init.constant_(m.bias, 0)#初始化偏置为0elif isinstance(m, nn.Linear):nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)def make_features(cfg: list):#提取特征函数layers = [] #存放创建每一层结构in_channels = 3 #RGBfor v in cfg:if v == "M":layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)layers += [conv2d, nn.ReLU(True)]in_channels = vreturn nn.Sequential(*layers)   #通过非关键字参数传入cfgs = {'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], #数字为卷积层个数,M为池化层结构'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}#实例化VGG网络
def vgg(model_name="vgg16", **kwargs): #**字典try:cfg = cfgs[model_name]except:print("Warning: model number {} not in cfgs dict!".format(model_name))exit(-1)model = VGG(make_features(cfg), **kwargs) #return model#vgg_model = vgg(model_name='vgg13')
#train.pyimport torch.nn as nn
from torchvision import transforms, datasets
import json
import os
import torch.optim as optim
from model import vgg
import torch
import timedevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)#数据预处理,从头
data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}#data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
data_root = os.getcwd()
image_path = data_root + "/flower_data/"  # flower data set path
train_dataset = datasets.ImageFolder(root=image_path + "/train",transform=data_transform["train"])
train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=0)validate_dataset = datasets.ImageFolder(root=image_path + "val",transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=0)# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()model_name = "vgg16"
net = vgg(model_name=model_name, num_classes=5, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)best_acc = 0.0
save_path = './{}Net.pth'.format(model_name)
for epoch in range(20):# trainnet.train()running_loss = 0.0t1 = time.perf_counter()for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad()#with torch.no_grad(): #用来消除验证阶段的loss,由于梯度在验证阶段不能传回,造成梯度的累计outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()# print train processrate = (step + 1) / len(train_loader)a = "*" * int(rate * 50)b = "." * int((1 - rate) * 50)print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")print()print(time.perf_counter() - t1)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():for val_data in validate_loader:val_images, val_labels = val_data#optimizer.zero_grad()outputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += (predict_y == val_labels.to(device)).sum().item()val_accurate = acc / val_numif val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, running_loss / step, val_accurate))print('Finished Training')

由于我硬件问题我的batch_size和轮数都比较小,所以最后test_accuracy都会比较低点

#predict.pyimport torch
from model import vgg
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import jsondata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load image
img = Image.open("./sunflower.jpg")  #验证太阳花
#img = Image.open("./roses.jpg")     #验证玫瑰花
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0) #扩充维度,加bench维度# read class_indict
try:json_file = open('./class_indices.json', 'r')class_indict = json.load(json_file)
except Exception as e:print(e)exit(-1)# create model
net = vgg(model_name=img, num_classes=5, init_weights=True)
# load model weights
model_weight_path = "./vgg16Net.pth"
model.load_state_dict(torch.load(model_weight_path)) #载入网络模型
model.eval() #关闭dropout方法
#跟踪准确梯度
with torch.no_grad():# predict classoutput = torch.squeeze(model(img)) #压缩掉beach维度predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy() #获取最大索引
#print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

参考自:
Amusi

VGG——CNN经典网络模型(pytorch实现)相关推荐

  1. CNN经典网络模型综述及发散思考(LeNet/ AlexNet/VGGNet/GoogLeNet/ResNet)

    目录 一. 背景 成功原因 设计目标 二. 经典网络模型 LeNet(1990 年) 网络特点 AlexNet(2012年) 网络特点 VGGNet(2014年) 网络特点 发散思考 GoogLeNe ...

  2. CNN经典网络模型(四):GoogLeNet简介及代码实现(PyTorch超详细注释版)

    目录 一.开发背景 二.网络结构 三.模型特点 四.代码实现 1. model.py 2. train.py 3. predict.py 4. spilit_data.py 五.参考内容 一.开发背景 ...

  3. CNN经典网络模型(二):AlexNet简介及代码实现(PyTorch超详细注释版)

    目录 一.开发背景 二.网络结构 三.模型特点 四.代码实现 1. model.py 2. train.py 3. predict.py 4. spilit_data.py 五.参考内容 一.开发背景 ...

  4. MobileNet(v1、v2)——CNN经典网络模型详解(pytorch实现)

    在之前的文章中讲的AlexNet.VGG.GoogLeNet以及ResNet网络,它们都是传统卷积神经网络(都是使用的传统卷积层),缺点在于内存需求大.运算量大导致无法在移动设备以及嵌入式设备上运行. ...

  5. DenseNet——CNN经典网络模型详解(pytorch实现)

    一.概述 论文:Densely Connected Convolutional Networks 论文链接:https://arxiv.org/pdf/1608.06993.pdf 代码的github ...

  6. LeNet-5——CNN经典网络模型详解(pytorch实现)

    一.LeNet-5 这个是n多年前就有的一个CNN的经典结构,主要是用于手写字体的识别,也是刚入门需要学习熟悉的一个网络. 原论文地址 输入:32*32的手写字体图片,这些手写字体包含0~9数字,也就 ...

  7. 【深度学习】ResNet——CNN经典网络模型详解(pytorch实现)

    建议大家可以实践下,代码都很详细,有不清楚的地方评论区见~ 1.前言 ResNet(Residual Neural Network)由微软研究院的Kaiming He等四名华人提出,通过使用ResNe ...

  8. GoogLeNet——CNN经典网络模型详解(pytorch实现)

    一.前言 论文地址:http://arxiv.org/abs/1602.07261 2014年,GoogLeNet和VGG是当年ImageNet挑战赛(ILSVRC14)的双雄,GoogLeNet获得 ...

  9. CNN经典网络模型:LeNet,Alexnet,VGGNet,GoogleNet,ReSNet

    关于卷积神经网络CNN,网络和文献中有非常多的资料,我在工作/研究中也用了好一段时间各种常见的model了,就想着简单整理一下,以备查阅之需.如果读者是初接触CNN,建议可以先看一看"Dee ...

最新文章

  1. 保护了无数医护人员的N95口罩,原来是华裔科学家和一位学生共同发明的!
  2. Deep learning前的图像预处理
  3. 使用SQLQuery
  4. “财务自由的15个阶段!说说你到哪个阶段了?”
  5. [Java基础]反射获取成员变量并使用
  6. python 插补数据_python 2020中缺少数据插补技术的快速指南
  7. [你必须知道的.NET]第二十八回:说说Name这回事儿
  8. sql服务器默认密码_搭建一个DNS服务器,轻松实现域名解析内容分发,访问速度提高N倍...
  9. 码云、coding拉取项目代码
  10. 追求代码质量: 不要被覆盖报告所迷惑
  11. 如何克隆服务器系统数据,Linode面板clone克隆功能实现服务器数据完整迁移
  12. AlexNet详解2
  13. Ibatis.Net 数据库操作(四)
  14. 任务管理三部曲 - 模板使用说明(超实用模板下载)
  15. 巧用批处理cmd快速切换IP地址
  16. FFT算法实现与分析MATLAB
  17. HandlerThread的使用场景和用法
  18. excel保存快捷键_Excel新手必备的5大技巧,看看你会几个?(附26个超实用快捷键)
  19. 【转自Testerhome】iOS 真机如何安装 WebDriverAgent
  20. 使用关键字搜索公众号文章,

热门文章

  1. 机器学习中的特征选择——决策树模型预测泰坦尼克号乘客获救实例
  2. PHP简易商城(一)概述
  3. 《论语》原文及其全文翻译 学而篇3
  4. 艾伟:memcached全面剖析–2.理解memcached的内存存储
  5. 北京大学,新增设置数据科学与工程博士点!
  6. 未来财富:呼啸而来的数字人民币
  7. 室内模拟高尔夫成潮流新贵,挥杆篙尔夫体现娱乐体验价值
  8. java nio 传统标准io socket 和nio socket比较与学习
  9. Go C 编程 第4课 变色魔法(魔法学院的奇幻之旅 Go C编程绘图)
  10. C# Hello World 实例