关于用VGG19网络来做8分类任务的总结
- module.py
这里对于VGG19的网络模型只进行了一点改动,就是最后一层输出层,定义了我需要输出的类。修改num_class参数即可。这里input的图片大小是224,也可以自由修改。
import torch.nn as nnclass VGG(nn.Module):# initialize modeldef __init__(self, img_size=224, input_channel=3, num_class=8):###img_size=224代表图片尺寸大小,num_class代表图片种类的数量super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=input_channel, out_channels=64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64), # default parameter:nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2,padding=0))self.conv3 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),)self.conv4 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))self.conv5 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),)self.conv6 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),)self.conv7 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),)self.conv8 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))self.conv9 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv10 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv11 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv12 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))self.conv13 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv14 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv15 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv16 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))self.fc17 = nn.Sequential(nn.Linear(int(512 * img_size * img_size / 32 / 32), 4096),nn.ReLU(inplace=True),nn.Dropout(p=0.5) # 默认就是0.5)self.fc18 = nn.Sequential(nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(p=0.5))self.fc19 = nn.Sequential(nn.Linear(4096, num_class))self.conv_list = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7,self.conv8, self.conv9, self.conv10, self.conv11, self.conv12, self.conv13, self.conv14,self.conv15, self.conv16]self.fc_list = [self.fc17, self.fc18, self.fc19]print("VGG Model Initialize Successfully!")# forwarddef forward(self, x):for conv in self.conv_list: # 16 CONVx = conv(x)output = x.view(x.size()[0], -1)for fc in self.fc_list: # 3 FCoutput = fc(output)return outputif __name__ == '__main__':vgg19 = VGG()# 检查模型每一层的参数# print(vgg19)
- dataset.py
写这个文件的时候出现了很多次报错,关于标签和图片格式的报错。所以在文件的最后,有保留检验时候的代码。修改的时候,方便查看一些输出和变量类型。
import os
import torch
import glob
import numpy as np
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader,Datasetclass Mydata(Dataset):def __init__(self,root,transforms=None):# 初始化函数,读取所有data_path下的图片self.root = rootself.transform = transformsself.classes_data = {}for name in sorted(os.listdir(os.path.join(root))):self.classes_data[name] = len(self.classes_data)self.images = []for name in self.classes_data.keys():self.images += glob.glob(os.path.join(root, name, '*')) # print(self.classes_data)def __getitem__(self,index):#根据索引index返回dataset[index]img_path = self.images[index]#根据索引index获取图片路径img = Image.open(img_path).convert('RGB')# 读取该图片
# target= img_path.split('\\')[-1].split('.')[0]if self.transform is not None:img = self.transform(img) target= self.classes_data[img_path.split('/')[-2]]target = torch.tensor(target)# 根据该图片的路径名获取该图片的label,具体根据路径名进行分割。我这里是"E:\\Python Project\\Pytorch\\dogs-vs-cats\\train\\cat.0.jpg",所以先用"\\"分割,选取最后一个为['cat.0.jpg'],然后使用"."分割,选取[cat]作为该图片的标签return img,target def __len__(self):return len(self.images)def Myclasses(self):return self.classes_data########下面的是测试这个自定义文件的dataset文件能否正常运行,以及上面的函数在调用后的功能测试
# if __name__=='__main__':
# train_root = "./dam-v4/train"
# mydata = Mydata(train_root, transforms=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()]))
# print(mydata.classes)# dataloader = DataLoader(mydata,batch_size=32,shuffle=True)#使用DataLoader加载数据# # # # # 查看batch
# for data in dataloader:
# imgs,targets = data
- train.py
训练代码没什么好讲的,就是导入数据,模型,然后开始训练
这里我只打印了train的loss,acc和test的acc
import os
import time
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.optim as optim
import numpy as np
from module_vgg import VGG
from Mydataset import Mydata#######定义一会要调用的训练和验证的函数
def train_model(model, criterion, optimizer, num_epochs=25):since = time.time()# 记录训练过程的指标print('开始训练')for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 迭代数据train_loss = 0train_acc = 0# Iterate over data.for i, data in enumerate(dataloaders['train']):inputs , labels = datainputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()# 清空梯度with torch.set_grad_enabled(True):outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()# 反向传播计算梯度optimizer.step()# 更新网络train_loss += loss.item() * inputs.size(0)train_acc += (outputs.argmax(1)==labels).sum()print('Epoch{}acc: '.format(epoch),train_acc.item()/dataset_sizes['train'])print('Epoch{}loss: '.format(epoch),train_loss/dataset_sizes['train'])time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))return model####### load data 加载数据
data_dir = "./dam-v4"
input_shape = 224 ##统一图片大小
batch_size = 32image_datasets = {x: Mydata(os.path.join(data_dir, x),transforms=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()]))for x in ['train', 'validation']
}dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size,shuffle=True, num_workers=4)for x in ['train', 'validation']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'validation']}
# print(dataset_sizes['train'],dataset_sizes['validation'])class_names = image_datasets['train'].Myclasses()
print("数据读取成功")
print('训练集和验证集图片数量:',dataset_sizes)
print('图片分类:',class_names)
# {'train': 3277, 'validation': 351}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")####导入模型开始训练
vgg_based = VGG()
vgg_based = vgg_based.to(device) ## 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(vgg_based.parameters(), lr=0.001, momentum=0.9)
####模型开始训练vgg_based = train_model(vgg_based, criterion, optimizer_ft, num_epochs=25)####对模型进行测试vgg_based.eval() #在验证状态
val_loss_sum = 0.0
val_metric_sum = 0.0
val_step =1
with torch.no_grad(): # 验证的部分,不是训练所以不要带入梯度print('验证状态')for i, data_val in enumerate(dataloaders['validation']):features,labels = data_valfeatures = features.to(device)labels = labels.to(device)pred = vgg_based(features)val_loss = criterion(pred,labels)val_step += 1val_metric_sum += (pred.argmax(1)==labels).sum()# val_metric_sum = val_metric_sum*features.size(0)print(val_metric_sum)print("模型在验证集上的准确率为{}".format(val_step,(val_metric_sum.item()/dataset_sizes['validation'])))
关于用VGG19网络来做8分类任务的总结相关推荐
- 【PyTorch】构造VGG19网络进行本地图片分类(超详细过程)——项目介绍
本篇博客主要解决以下3个问题: 如何自定义网络(以VGG19为例). 如何自建数据集并加载至模型中. 如何使用自定义数据训练自定义模型. 第一篇:[PyTorch]构造VGG19网络进行本地图片分类( ...
- 论文解读:胶囊网络在小样本做文本分类中的应用(下)
Dynamic Memory Induction Networks for Few-Shot Text Classification 论文提出Dynamic Memory Induction Netw ...
- 使用CNN做文本分类——将图像2维卷积换成1维
使用CNN做文本分类from __future__ importdivision, print_function, absolute_importimporttensorflow as tfimpor ...
- php 相册分类,这款不需要网络就可以智能识别分类照片,让你的相册不再混乱...
记得在<我的前半生>中吴越饰演的凌玲凭借着一手整理的才能把陈俊生迷的死死的,而「EASY 相册」则凭借一手智能整理的本领让我抛弃了其他相册管理 APP,专注与它-- 作为一名热爱生活,热爱 ...
- 分享 | 基于图像分类网络ResNet50_vd实现桃子分类
随着时代的快速发展,人工智能已经融入我们生活的方方面面.中国的农业也因人工智能而受益进入高速发展阶段.现今,看庄稼长势有卫星遥感技术,水果分拣有智能分拣系统,灌溉施肥有自动化机械-- 具体以水果分拣场 ...
- 使用bert模型做句子分类
使用bert模型微调做下游任务,在goole发布的bert代码和huggingface的transformer项目中都有相应的任务,有的时候只需要把代码做简单的修改即可使用.发现代码很多,我尝试着自己 ...
- 【深度学习】使用tensorflow实现VGG19网络
转载注明出处:http://blog.csdn.net/accepthjp/article/details/70170217 接上一篇AlexNet,本文讲述使用tensorflow实现VGG19网络 ...
- 贝叶斯分类器做文本分类案例
贝叶斯分类器做文本分类 文本分类是现代机器学习应用中的一大模块,更是自然语言处理的基础之一.我们可以通过将文字数据处理成数字数据,然后使用贝叶斯来帮助我们判断一段话,或者一篇文章中的主题分类,感情倾向 ...
- 基于图像分类网络ResNet50_vd实现桃子分类
基于图像分类网络ResNet50_vd实现桃子分类 随着时代的快速发展,人工智能已经融入我们生活的方方面面.中国的农业也因人工智能而受益进入高速发展阶段.现今,看庄稼长势有卫星遥感技术,水果分拣有智能 ...
- transformer做文本分类的keras实现完整版
背景 目前csdn上搜索到的keras的版本实现,排在前面的是: https://blog.csdn.net/xiaosongshine/article/details/86595847 但是,这个文 ...
最新文章
- dateutils 工具类_五金工具泡壳封边机
- log4j每天产生一日志文件
- 我关注的那些程序员大佬
- extjs tree下拉列表_Extjs中ComboBoxTree实现的下拉框树效果(自写)_extjs
- oracle数据库查看建表语句,oracle 查看建表语句
- 计算机网络的商业价值和应用,计算机网络建模数学工具的分析与比较
- JavaScript打开APP
- 零的突破!6所高校、2所“双非”顶刊发文
- 计网 | 网络层 SDN控制器 / 远程控制器
- canvas教程5-绘制路径
- AspNetCore.Mvc写Get方法运行显示该网页无法正常运作 http err 500问题
- VJ框架 与 人脸检测/物体检测 详解
- NoSQL之 Redis配置与优化
- C51单片机实现串口通信
- 解决CitSpace分析新版本web of science文献报错“the timing slicing setting is outside the range of your data”
- 追逐日月,不苟于山川。
- CSDN中编辑文章时,如何去除图片水印?
- 50个好用的前端工具,建议收藏!
- bootstrap栅栏系统
- IE7.0,IE8.0卸载方法,回到IE6.0
热门文章
- 运动控制卡选型和特点讨论
- xcode ios 怎么导入p12证书
- ThinkPHP商城系统与外部系统用户互通,集成UCenter
- Java多线程电影院_java 多线程-快乐订座电影院
- mysql设计一个网上购物系统_网上购物系统的设计与实现(MyEclipse,MySQL)
- 使用 Python 编写一个聊天小程序
- 莫烦python之python基础学习备忘
- python 东方财富接口_东方财富开放交易api,我只想要东方财富软件交易功能
- 牛客练习赛63 牛牛的树行棋
- WPF_界面_图片/界面/文字模糊解决之道整理