pytorch 搭建 VGG 网络
目录
1. VGG 网络介绍
2. 搭建VGG 网络
3. code
1. VGG 网络介绍
VGG16 的网络结构如图:
VGG 网络是由卷积层和池化层构成基础的CNN
它的CONV卷积层的参数全部是由 stride = 1,padding = 1
它的maxpool 最大池化层的参数都是 size = 2 , stride = 2 组成的
VGG 网络的亮点是 它的卷积层全部都是由 3*3 的小型卷积核连续进行的,通过重复进行---卷积层重叠2-4次,然后再由池化层将size 减半进行处理
VGG 网络提出了一个新的概念,就是通过叠加 3*3 的卷积核来替代大的卷积核,这样可以减少网络训练的参数。2 个3*3 卷积核连续卷积代替 5*5 的卷积核,3个 3*3 的卷积核连续卷积代替 7*7 的卷积核
再介绍怎么可以通过连续卷积小的kernel 可以替代大的kernel 之前,先介绍一下感受野
感受野就是说,卷积层的输出 对应的输入区域的范围。例如下图所示,输入是 9*9 大小的图片,经过卷积层的输出size为:output = (9 - 3 + 2 * 0) / 2 + 1 =4 4*4大小的输出,然后经过池化层变为 2*2 的输出,那也就是说最后2*2的一个像素的图像是由卷积后2*2像素的大小决定的,也是由输入图像 5*5 范围内的图像决定的。那么这里的2*2就是池化后一个像素点的感受野,5*5就是卷积后2*2 图像的感受野,也可以说输入的5*5 是卷积-池化后一个像素点的感受野
所以计算卷积后区域大小的公式: ,反过来就是感受野的计算公式
因此感受野size 的计算公式:
TIP:这里不计算pad的原因,是因为这里pad的作用大都是防止图像缩小,而这里的证明就是为了让图像通过CONV层提取关键特征的
感受野介绍完,我们就可以了解为什么连续小的卷积核等于大的卷积核的运算了
假设输出一个像素点,那么对应 3*3 卷积核的感受野是3*3大小的,再往前对应的 3*3 卷积核的感受野是5*5大小的,再往前对应的 3*3 卷积核的感受野是7*7大小的。那么如果对7*7大小的图片做卷积,用kernel_size 是7*7的话,带入公式output = (7 - 7)/2 + 1 =1 对应的也是一个像素点
因此:2 个3*3 卷积核连续卷积代替 5*5 的卷积核,3个 3*3 的卷积核连续卷积代替 7*7 的卷积核
这样做的好处就是可以减少卷积核的参数:因为3*3*3 = 27 个权重参数,7*7 =49 个权重参数。这样可能感受不到差别,但是算上输出的channel和输出的channel呢?前者就是27*C*C,而后者是49*C*C,这样参数差别就很大了。
所以,CONV卷积层的参数全部是由 stride = 1,padding = 1的情况下,连续2次3*3卷积等于5*5的卷积,连续3次3*3的卷积等于7*7的卷积
2. 搭建VGG 网络
VGG网络的结构有很多种形式,这里常用的是D,16个权重层的形式
首先,先建立一个字典文件存放不同VGG网络的配置列表
然后通过传入对应的key,建立对应的VGG网络卷积和池化层
然后,通过make_features 创建的特征提取层,可以建立最终的VGG网络
最后就是定义生成VGG网络的函数
这里vgg参数传递的顺序为:
实参里面的vgg16-->形参model_name-->cfgs取出key对应的value赋值给cfg-->cfg传递给make_feature建立卷积层-池化层layers,返回给nn.Sequential-->最后传递给VGG里面的feature生成特征提取层
生成的VGG16为:
3. code
因为网络太大 , 代码跑了很久都没有结果 , 所以这里就不放训练和预测的结果了
训练和预测的代码也不做讲解了,和 pytorch 搭建 LeNet 网络对 CIFAR-10 图片分类 代码的重合度很高
model代码
import torch.nn as nn
import torchclass VGG(nn.Module): # 定义VGG网络def __init__(self, features, num_classes=1000): # num_classed 为分类的个数super(VGG, self).__init__()self.features = features # 特征提取层通过make_features 创建self.classifier = nn.Sequential(nn.Dropout(p=0.5), # dropout 随机失活nn.Linear(512*7*7, 2048), # 特征提取最后的size是(512*7*7)nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(True),nn.Linear(2048, num_classes))def forward(self, x):x = self.features(x) # 特征提取层x = torch.flatten(x, start_dim=1) # ddata维度为(batch_size,512,7,7),从第二个维度开始flattenx = self.classifier(x) # 分类层return xdef make_features(cfg: list): # 生成特征提取层,就是VGG前面的卷积池化层layers = [] # 保存每一层网络结构in_channels = 3 # 输入图片的深度channels,起始输入是RGB 3 通道的for v in cfg: # 遍历配置列表 cfgsif v == "M": # M 代表最大池化层,VGG中maxpool的size=2,stride = 2layers += [nn.MaxPool2d(kernel_size=2, stride=2)] # M 代表最大池化层else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) # 数字代表卷积核的个数==输出的channelslayers += [conv2d, nn.ReLU(True)] # 添加卷积层in_channels = v # 输出的channels == 下次输入的channelsreturn nn.Sequential(*layers) # 解引用,将大的list里面的小list拿出来# 特征提取层的 网络结构参数
cfgs = { # 建立网络的字典文件,对应的key可以生成对应网络结构参数的value值'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], # 数字代表卷积核的个数,M代表池化层'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}# 定义生成VGG 网络函数
def vgg(model_name="vgg16", num_classes = 10): # 创建VGG网络,常用的为 VGG16 结构,如果不指定分类个数,默认是10cfg = cfgs[model_name] # 先定义特征提取层的结构model = VGG(make_features(cfg), num_classes=num_classes) # 将cfgs里面某个参数传给make_features,并且生成VGG netreturn modelnet = vgg(model_name='vgg16',num_classes=5)
print(net)
train部分代码:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from model import vgg # 应该导入创建网络的vgg,而不是空的框架VGGdata_transform = transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 图像预处理
batch_size = 32# 载入训练集
train_dataset= torchvision.datasets.CIFAR10(root='./data',train=True,download=False,transform=data_transform) # 下载数据集
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True) # 读取数据集# 载入测试集
test_dataset = torchvision.datasets.CIFAR10(root='./data',train=False,download=False,transform=data_transform) # 下载数据集
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size, shuffle=False) # 读取数据集classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') # 十个分类的labelsnet = vgg(model_name='vgg16', num_classes=10) # 实例化网络
loss_function = nn.CrossEntropyLoss() # 定义交叉熵损失函数
optimizer = optim.Adam(net.parameters(), lr=0.0001) # 定义优化器best_acc = 0.0
save_path = './VGG.pth' # 保存的路径for epoch in range(5):net.train() # 开启dropoutrunning_loss = 0.0for step, data in enumerate(train_loader, start=0):images, labels = dataoptimizer.zero_grad() # 梯度下降outputs = net(images) # 前向传播loss = loss_function(outputs, labels) # 计算损失loss.backward() # 反向传播optimizer.step() # 梯度更新running_loss += loss.item()# testnet.eval() # 关闭dropoutacc = 0.0total = 0with torch.no_grad():for test_data in test_loader:test_images, test_labels = test_dataoutputs = net(test_images)predicted = torch.max(outputs, dim=1)[1]acc += (predicted == test_labels).sum().item()total += test_labels.size(0) # total += batch_sizeaccurate = acc / total # 计算正确率print('[epoch %d] train_loss: %.3f accuracy: %.3f' %(epoch + 1, running_loss /step, accurate))if accurate > best_acc:best_acc = accuratetorch.save(net.state_dict(), save_path)print('Finished Training')
predict部分代码:
import torch
from PIL import Image
from torchvision import transforms
from model import vggdata_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])img = Image.open('./dog.png') # 载入图片
img = data_transform(img) # 预处理
img = torch.unsqueeze(img, dim=0) # 增加维度classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')model = vgg(model_name="vgg16", num_classes=10)
model.load_state_dict(torch.load('./VGG.pth')) # 读取网络参数
model.eval() # 预测的时候不需要随机失活with torch.no_grad():output = model(img)predict = torch.max(output, dim=1)[1]print(classes[int(predict)])
pytorch 搭建 VGG 网络相关推荐
- 4.2 使用pytorch搭建VGG网络
文章目录 将VGG分成两部分 提取特征网络结构 分类网络结构 model 输入:非关键字参数或有序字典 P[ython-非关键字参数和关键字参数(*args **kw)](https://blog.c ...
- Pytorch搭建FCN网络
Pytorch搭建FCN网络 前言 原理 代码实现 前言 FCN 全卷积网络,用卷积层替代CNN的全连接层,最后通过转置卷积层得到一个和输入尺寸一致的预测结果: 原理 为了得到更好的分割结果,论文中提 ...
- 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记
使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...
- 实战:使用Pytorch搭建分类网络(肺结节假阳性剔除)
实战:使用Pytorch搭建分类网络(肺结节假阳性剔除) 阅前可看: 实战:使用yolov3完成肺结节检测(Luna16数据集)及肺实质分割 其中的脚本资源getMat.py文件是对肺结节进行切割. ...
- Pytorch搭建LeNet5网络
本讲目标: 介绍Pytorch搭建LeNet5网络的流程. Pytorch八股法搭建LeNet5网络 1.LeNet5网络介绍 2.Pytorch搭建LeNet5网络 2.1搭建LeNet网络 2 ...
- 使用PyTorch搭建ResNet50网络
ResNet18的搭建请移步:使用PyTorch搭建ResNet18网络并使用CIFAR10数据集训练测试 ResNet34的搭建请移步:使用PyTorch搭建ResNet34网络 ResNet101 ...
- pytorch搭建孪生网络比较人脸相似性
参考文献: 神经网络学习小记录52--Pytorch搭建孪生神经网络(Siamese network)比较图片相似性_Bubbliiiing的博客-CSDN博客_神经网络图片相似性 Python - ...
- 关于用pytorch构建vgg网络实现花卉分类的学习笔记
需要的第三方库: pytorch.matplotlib.json.os.tqdm 一.model.py的编写 (1)准备工作 1.参照vgg网络结构图(如下图1),定义一个字典,用于存放各种vgg网络 ...
- 使用Keras来搭建VGG网络
上述VGG网络结构图 VGG网络是在Very Deep Convolutional Network For Large-Scale Image Recognition这篇论文中提出,VGG是2014年 ...
最新文章
- 科研助力|计算机科学方向一对一科研项目
- Leetcode刷题 463题:岛屿的周长(基于Java语言)
- java default locale_Java JSON.defaultLocale方法代码示例
- 内存体系 用共享段于进程间联系
- mysql 存储过程 嵌套if_mysql存储过程if嵌套if的写法
- ElasticSearch 知识点整理(入门)
- python_统计数组中指定范围的数据占的比例
- Linux Cgroups详解(二)
- 【原】使用IPV6,nbsp;10M/s高速BT互传…
- 软考之系统架构师考试经验分享
- 蓝牙、Wifi与ZigBee无线传输技术中,谁比较占有优势
- scala方法抛出异常_Scala异常| Scala方法如何引发异常?
- 简单的有限状态机Unity独家写法(呸~厚颜无耻之人)
- VMware Horizon 8 2111 部署系列(八)准备虚拟机模板
- Cortex M3 数据观察点与跟踪(DWT)
- ROS msg 文件修改 报错
- springboot使用flyway
- 2022开源PHP留言反馈管理系统 v2.0
- sqlmap中的columns哪里看_ROC,AUC 还是看我的吧,别人都千篇一律
- Dusk network 生态图
热门文章
- 二次曲线指数平滑怎么用计算机运行,利用Excel进行指数平滑分析(2)
- 程序xf—adsk20无法打开 mac M1芯片看这里
- UnicodeDecodeError: 'utf-8' codec can't decode byte 0xf3 in position 4645: invalid continuation byte
- IT软件开发常用英语词汇
- 读书笔记之大数据计算模式
- win7虚拟计算机名,Win7笔记本电脑启用虚拟wifi共享上网(图文介绍)
- 计算机重启后桌面文件没有反应,电脑重启后桌面文件丢失背景变黑怎么办
- python调试器 ipdb
- 结合RBAC模型讲解权限管理系统需求及表结构创建
- 2023年中职网络安全技能竞赛网页渗透(审计版)