Table of Contents

  • Splitting VGG into two parts
    • Feature-extraction network
    • Classification network
  • model
    • Input: non-keyword arguments or an OrderedDict
      • [Python - non-keyword and keyword arguments (*args, **kw)](https://blog.csdn.net/weixin_44023658/article/details/105925199?utm_medium=distribute.wap_relevant.none-task-blog-title-1)
  • predict
    • Many people subtract three values from the RGB channels: the ImageNet per-channel means. Transfer learning may require subtracting them
  • train

Splitting VGG into two parts

The network is split into a feature-extraction part (the convolution and pooling layers) and a classification part (the fully connected layers). Also note that this network is large, trains slowly, and needs a lot of data.

Feature-extraction network

Classification network

model

import torch.nn as nn
import torch


class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=False):  # the feature-extraction sub-network is passed in
        super(VGG, self).__init__()
        self.features = features
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            nn.Linear(512*7*7, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(True),
            nn.Linear(2048, num_classes)
        )
        if init_weights:  # only initialize the weights when explicitly requested
            self._initialize_weights()

    def forward(self, x):
        # N x 3 x 224 x 224
        x = self.features(x)
        # N x 512 x 7 x 7
        x = torch.flatten(x, start_dim=1)  # flatten starting from dim 1; dim 0 is the batch dimension
        # N x 512*7*7
        x = self.classifier(x)
        return x

    def _initialize_weights(self):  # weight initialization: iterate over every layer
        for m in self.modules():
            if isinstance(m, nn.Conv2d):  # convolution layers use the Xavier method
                # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:  # if a bias is used, zero it
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):  # fully connected layers
                nn.init.xavier_uniform_(m.weight)
                # nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_features(cfg: list):  # pass in the configuration list of the chosen variant
    layers = []
    in_channels = 3  # RGB
    for v in cfg:
        if v == "M":
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]  # pooling kernel size and stride are both 2
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)  # stride defaults to 1
            layers += [conv2d, nn.ReLU(True)]
            in_channels = v  # the output depth becomes v
    return nn.Sequential(*layers)  # unpack the list as non-keyword arguments
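A small illustration (not from the original post) of what torch.flatten(x, start_dim=1) does: it keeps the batch dimension and merges everything after it, which is why the classifier's first nn.Linear expects 512*7*7 = 25088 inputs.

import torch

# flatten keeps dim 0 (the batch dimension) and merges the rest: N x 512 x 7 x 7 -> N x 25088
x = torch.randn(2, 512, 7, 7)
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([2, 25088])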

Input: non-keyword arguments or an OrderedDict

Python - non-keyword and keyword arguments (*args, **kw)
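As a quick illustration of that article's point (this snippet is mine, not from the original post): nn.Sequential can be built either from unpacked non-keyword arguments, which is exactly what nn.Sequential(*layers) does at the end of make_features, or from an OrderedDict that also gives each layer a name.

import torch.nn as nn
from collections import OrderedDict

layers = [nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.ReLU(True)]

# non-keyword arguments: the list is unpacked with *, layers are auto-numbered "0", "1", ...
seq_args = nn.Sequential(*layers)

# OrderedDict: each layer gets an explicit name
seq_dict = nn.Sequential(OrderedDict([
    ("conv1", nn.Conv2d(3, 64, kernel_size=3, padding=1)),
    ("relu1", nn.ReLU(True)),
]))

print(seq_args)
print(seq_dict)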


# model configuration dictionary
cfgs = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # config A: numbers are conv output channels, 'M' is a max-pooling (downsampling) layer
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],  # config B
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],  # config D
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],  # config E
}


def vgg(model_name="vgg16", **kwargs):
    try:
        cfg = cfgs[model_name]
    except KeyError:
        print("Warning: model name {} not in cfgs dict!".format(model_name))
        exit(-1)
    model = VGG(make_features(cfg), **kwargs)  # first argument is features; the rest is a variable-length keyword dict (num_classes=..., init_weights=...)
    return model
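A hedged usage sketch (not part of the original file, but it could be appended to the bottom of model.py): since vgg() forwards its keyword arguments straight to VGG.__init__, any key of cfgs can be instantiated the same way. The prints confirm the shapes promised in forward() and show how heavy the model is.

if __name__ == '__main__':
    import torch
    net = vgg(model_name="vgg16", num_classes=5, init_weights=True)
    x = torch.randn(1, 3, 224, 224)        # N x 3 x 224 x 224
    print(net.features(x).shape)           # torch.Size([1, 512, 7, 7]) after five max-pool stages
    print(net(x).shape)                    # torch.Size([1, 5])
    print(sum(p.numel() for p in net.parameters()), "parameters")  # why the post calls VGG large and slow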

predict

import torch
from model import vgg
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("../tulip.jpg")
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = vgg(model_name="vgg16", num_classes=5)
# load model weights
model_weight_path = "./vgg16Net.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)])
plt.show()

Many implementations subtract three values from the RGB channels instead of normalizing with 0.5: these are the per-channel means computed on ImageNet. If you do transfer learning from ImageNet-pretrained weights, you may need to subtract them as well.
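For reference, the values in question are the ImageNet per-channel means (with matching standard deviations) that torchvision's pretrained models expect. When fine-tuning from ImageNet-pretrained weights, a normalization like the following, rather than the 0.5/0.5 used above, is the usual choice; this snippet is illustrative and not part of the original scripts.

from torchvision import transforms

# ImageNet per-channel mean and std (RGB), as used by torchvision's pretrained models
imagenet_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # scales pixel values to [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])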

train

import torch.nn as nn
from torchvision import transforms, datasets
import json
import os
import torch.optim as optim
from model import vgg
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
image_path = data_root + "/data_set/flower_data/"  # flower data set path

train_dataset = datasets.ImageFolder(root=image_path + "train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# write dict into json file
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)

validate_dataset = datasets.ImageFolder(root=image_path + "val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=False,
                                              num_workers=0)

# test_data_iter = iter(validate_loader)
# test_image, test_label = test_data_iter.next()

model_name = "vgg16"  # use the 16-layer configuration
net = vgg(model_name=model_name, num_classes=5, init_weights=True)
net.to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.0001)

best_acc = 0.0
save_path = './{}Net.pth'.format(model_name)
for epoch in range(30):
    # train
    net.train()
    running_loss = 0.0
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad()
        outputs = net(images.to(device))
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # print train process
        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()

    # validate
    net.eval()
    acc = 0.0  # accumulate accurate number / epoch
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            predict_y = torch.max(outputs, dim=1)[1]
            acc += (predict_y == val_labels.to(device)).sum().item()
        val_accurate = acc / val_num
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / (step + 1), val_accurate))

print('Finished Training')
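Because the post stresses that VGG is large, slow to train, and data-hungry, a common alternative to training from scratch is transfer learning from torchvision's ImageNet-pretrained VGG16 (torchvision's own VGG class with 4096-wide fully connected layers, not the smaller classifier defined in model.py above). The sketch below is illustrative only; freezing strategy and learning rate are assumptions, not from the original post.

import torch.nn as nn
import torch.optim as optim
from torchvision import models

# ImageNet-pretrained VGG16 (newer torchvision versions use the weights= argument instead of pretrained=True)
net = models.vgg16(pretrained=True)

# freeze the convolutional feature extractor; only the classifier will be trained
for param in net.features.parameters():
    param.requires_grad = False

# replace the last fully connected layer: 1000 ImageNet classes -> 5 flower classes
net.classifier[6] = nn.Linear(net.classifier[6].in_features, 5)

# optimize only the parameters that still require gradients
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.0001)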
