VGG16网络结构及代码

下图为VGG网络结构图,最常用的就是表中的D结构,16层结构(13层卷积+3层全连接层),卷积的stride为1,padding为1,maxpool的大小为2,stride为2(池化只改变图像的大小,不改变图像的深度)

vgg网络结构可以看作两部分:特征提取网络(连接层之前)+分类网络(3层全连接层)

VGG模型搭建

VGG模型一共分为两部分,特征提取部分和分类网络部分,我们分别进行搭建

特征提取网络

1、定义字典文件,定义了四个网络结构

cfgs = {  'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], # 列表的数字代表卷积层卷积核的个数,字符M代表池化层的结构'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

2、定义一个函数,生成vgg网络第一部分:特征提取网络

def make_features(cfg: list): # 传入一个配置变量layers = []  # 定义一个空列表in_channels = 3for v in cfg:if v == "M": # 判断是否是池化层layers += [nn.MaxPool2d(kernel_size=2, stride=2)]else:conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) # v表示输出通道layers += [conv2d, nn.ReLU(True)]in_channels = v # 卷积之后,输出通道变为vreturn nn.Sequential(*layers) # *layers代表通过非关键字参数的形式传入进去

分类网络

1、定义VGG类

# vgg类
class VGG(nn.Module):  # features代表提取特征网络def __init__(self, features, num_classes=1000, init_weights=False):super(VGG, self).__init__()self.features = featuresself.classifier = nn.Sequential(nn.Dropout(p=0.5), # 减少过拟合,50%比例随机失活神经元nn.Linear(512*7*7, 4096),nn.ReLU(True),nn.Dropout(p=0.5),nn.Linear(4096, 4096),nn.ReLU(True),nn.Linear(4096, num_classes))if init_weights:self._initialize_weights()def forward(self, x):# N x 3 x 224 x 224x = self.features(x)# N x 512 x 7 x 7 展平操作x = torch.flatten(x, start_dim=1) # 从第一个维度开始展平,第0个维度是batch# N x 512*7*7x = self.classifier(x)return x# 初始化权重函数,会便利网络的每一个子模块,也就是遍历每一层def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d): # 如果当前层为卷积层# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')nn.init.xavier_uniform_(m.weight) # 初始化卷积核参数if m.bias is not None: # 如果卷积核有偏置,设置偏置为0nn.init.constant_(m.bias, 0)elif isinstance(m, nn.Linear): # 如果当前层为全连接层nn.init.xavier_uniform_(m.weight)# nn.init.normal_(m.weight, 0, 0.01)nn.init.constant_(m.bias, 0)

2、实例化vgg

# 实例化vgg
def vgg(model_name="vgg16", **kwargs):try:cfg = cfgs[model_name]except:print("Warning: model number {} not in cfgs dict!".format(model_name))exit(-1)model = VGG(make_features(cfg), **kwargs) # **kwargs可变长度的字典变量return modelvgg_model = vgg(model_name='vgg13')

VGG训练

import os
import jsonimport torch
import torch.nn as nn
from torchvision import transforms, datasets
import torch.optim as optim
from tqdm import tqdmfrom model import vggdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")print("using {} device.".format(device))
# 数据预处理data_transform = {"train": transforms.Compose([transforms.RandomResizedCrop(224), # 随即裁剪transforms.RandomHorizontalFlip(), # 随机翻转transforms.ToTensor(), # 转为tensortransforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),# 标准化处理"val": transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root pathimage_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set pathassert os.path.exists(image_path), "{} path does not exist.".format(image_path)train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),transform=data_transform["train"])train_num = len(train_dataset)# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}flower_list = train_dataset.class_to_idxcla_dict = dict((val, key) for key, val in flower_list.items())# write dict into json filejson_str = json.dumps(cla_dict, indent=4)with open('class_indices.json', 'w') as json_file:json_file.write(json_str)batch_size = 32nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workersprint('Using {} dataloader workers every process'.format(nw))train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size, shuffle=True,num_workers=nw)validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),transform=data_transform["val"])val_num = len(validate_dataset)validate_loader = torch.utils.data.DataLoader(validate_dataset,batch_size=batch_size, shuffle=False,num_workers=nw)print("using {} images for training, {} images for validation.".format(train_num,val_num))# test_data_iter = iter(validate_loader)# test_image, test_label = test_data_iter.next()model_name = "vgg16"net = vgg(model_name=model_name, num_classes=5, init_weights=True)net.to(device)loss_function = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.0001)epochs = 30best_acc = 0.0save_path = './{}Net.pth'.format(model_name)train_steps = len(train_loader)for epoch in range(epochs):# trainnet.train()running_loss = 0.0train_bar = tqdm(train_loader)for step, data in enumerate(train_bar):images, labels = dataoptimizer.zero_grad()outputs = net(images.to(device))loss = loss_function(outputs, labels.to(device))loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,epochs,loss)# validatenet.eval()acc = 0.0  # accumulate accurate number / epochwith torch.no_grad():val_bar = tqdm(validate_loader)for val_data in val_bar:val_images, val_labels = val_dataoutputs = net(val_images.to(device))predict_y = torch.max(outputs, dim=1)[1]acc += torch.eq(predict_y, val_labels.to(device)).sum().item()val_accurate = acc / val_numprint('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %(epoch + 1, running_loss / train_steps, val_accurate))if val_accurate > best_acc:best_acc = val_accuratetorch.save(net.state_dict(), save_path)print('Finished Training')if __name__ == '__main__':main()

VGG预测

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import vggdef main():device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")data_transform = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# load imageimg_path = "../tulip.jpg"assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)img = Image.open(img_path)plt.imshow(img)# [N, C, H, W]img = data_transform(img)# expand batch dimensionimg = torch.unsqueeze(img, dim=0)# read class_indictjson_path = './class_indices.json'assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)json_file = open(json_path, "r")class_indict = json.load(json_file)# create modelmodel = vgg(model_name="vgg16", num_classes=5).to(device)# load model weightsweights_path = "./vgg16Net.pth"assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path)model.load_state_dict(torch.load(weights_path, map_location=device))model.eval()with torch.no_grad():# predict classoutput = torch.squeeze(model(img.to(device))).cpu()predict = torch.softmax(output, dim=0)predict_cla = torch.argmax(predict).numpy()print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],predict[predict_cla].numpy())plt.title(print_res)print(print_res)plt.show()if __name__ == '__main__':main()

参考视频:https://www.bilibili.com/video/BV1i7411T7ZN?spm_id_from=333.999.0.0

经典卷积神经网络---VGG16网络相关推荐

  1. 卷积神经网络——vgg16网络及其python实现

    1.介绍      VGG-16网络包括13个卷积层和3个全连接层,网络结构较LeNet-5等网络变得十分复杂,但同时也有不错的效果.VGG16有强大的拟合能力在当时取得了非常的效果,但同时VGG也有 ...

  2. 【深度学习基础】经典卷积神经网络

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 导语 卷积神经网络(Convolutional Neural Ne ...

  3. AI基础:经典卷积神经网络

    导语 卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深 ...

  4. 一文总结经典卷积神经网络CNN模型

    一般的DNN直接将全部信息拉成一维进行全连接,会丢失图像的位置等信息. CNN(卷积神经网络)更适合计算机视觉领域.下面总结从1998年至今的优秀CNN模型,包括LeNet.AlexNet.ZFNet ...

  5. CNN(经典卷积神经网络)来了!

    导语 卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Networks),是深 ...

  6. 卷积神经网络VGG16权重数量的计算和理解(转载)

    VGG16网络结构是: _________________________________________________________________ Layer (type)           ...

  7. 卷积神经网络resent网络实践

    文章目录 前言 一.技术介绍 二.实现途径 三.总结 前言 上篇文章,讲了经典卷积神经网络-resnet,这篇文章通过resnet网络,做一些具体的事情. 一.技术介绍 总的来说,第一步首先要加载数据 ...

  8. tensorflow预定义经典卷积神经网络和数据集tf.keras.applications

    自己开发了一个股票软件,功能很强大,需要的点击下面的链接获取: https://www.cnblogs.com/bclshuai/p/11380657.html 1.1  tensorflow预定义经 ...

  9. vgg16卷积层的计算量_卷积神经网络VGG16参数数量的计算和理解

    先说一下我对神经网络的理解:神经网络就是用巨量的简单的非线性函数组合起来拟合复杂的未知函数.比如,人类识别不同的物体.识别不同动物.不同植物是个复杂的未知函数.虽然未知,但没事,我们的神经网络可以用巨 ...

  10. Tensorflow系列 | TensorFlowNews五大经典卷积神经网络介绍

    编译 | fendouai 编辑 | 安可 [导读]:这个系列文章将会从经典的卷积神经网络历史开始,然后逐个讲解卷积神经网络结构,代码实现和优化方向.下一篇文章将会是 LeNet 卷积神经网络结构,代 ...

最新文章

  1. OpenCV中minAreaRect()最小外接矩形 cvBoxPoints()计算矩形顶点 RotatedRect和CvBox2D详解
  2. Linux用户查看系统有多少用户在登录
  3. sklearn.preprocessing下的数据标准化(scale、MinMaxScaler)
  4. python中匿名函数的作用_什么是Python中的匿名函数
  5. html间数据传送,Express框架与html之间如何进行数据传递(示例代码)
  6. jumpserver 使用教程_Jumpserver之快速入门
  7. 5月第二周全球五大顶级域名总量新增10.5万个
  8. 读者投稿 | 写Go满一年啦,来聊聊进程、线程与协程
  9. Flutter: MobX和flutter_mobx状态管理器
  10. 520用Java制作一个表白app
  11. Python使用PIL工具、ImageDraw函数在图像上根据坐标点依次连线画矩形框,可画选择倾斜的框和折线
  12. 教你如何使用Python破解WIFI密码
  13. ​网易游戏实时 HTAP 计费风控平台建设
  14. 免费的视频转换软件。包括qlv全可以转
  15. WPS Excel表格日期转文本 为数字问题
  16. 选择云服务器主要看那方面的参数和性能
  17. so文件反汇编反编译到C源码
  18. Android 颜色值转换
  19. 这五类人可以了解下蛋白粉哪个牌子好!
  20. 【C语言】(错题整理) 寻找完数、字符串中各类字符数的统计、最大公约数和最小公倍数、回文数计算 (循环、函数相关内容)

热门文章

  1. 常用地图经纬度转换,以及遇到的问题和解决方式
  2. 【图像标注】使用vue3实现图像标注功能
  3. 手机屏幕测试html,华为手机屏幕检测代码是什么
  4. 一键生成 Android 录屏 gif 的脚本
  5. 虽迟但到,手眼标定代码实现篇
  6. WinRAR 6.0 永久去除广告
  7. 旁站,子域名,C段的含义
  8. PMP®考试通过率多少
  9. 大疆文档(2)-指南
  10. java编写一个汽车类,有属性:品牌、型号、排量、速度,有方法:启动、加速、转弯、刹车、息火...