• module.py

这里对于VGG19的网络模型只进行了一点改动,就是最后一层输出层,定义了我需要输出的类。修改num_class参数即可。这里input的图片大小是224,也可以自由修改。

import torch.nn as nnclass VGG(nn.Module):# initialize modeldef __init__(self, img_size=224, input_channel=3, num_class=8):###img_size=224代表图片尺寸大小,num_class代表图片种类的数量super().__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=input_channel, out_channels=64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),  # default parameter:nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2,stride=2,padding=0))self.conv3 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),)self.conv4 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))self.conv5 = nn.Sequential(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),)self.conv6 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),)self.conv7 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),)self.conv8 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(256),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))self.conv9 = nn.Sequential(nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv10 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv11 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv12 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))self.conv13 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv14 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv15 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),)self.conv16 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),nn.BatchNorm2d(512),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=2, stride=2, padding=0))self.fc17 = nn.Sequential(nn.Linear(int(512 * img_size * img_size / 32 / 32), 4096),nn.ReLU(inplace=True),nn.Dropout(p=0.5)  # 默认就是0.5)self.fc18 = nn.Sequential(nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Dropout(p=0.5))self.fc19 = nn.Sequential(nn.Linear(4096, num_class))self.conv_list = [self.conv1, self.conv2, self.conv3, self.conv4, self.conv5, self.conv6, self.conv7,self.conv8, self.conv9, self.conv10, self.conv11, self.conv12, self.conv13, self.conv14,self.conv15, self.conv16]self.fc_list = [self.fc17, self.fc18, self.fc19]print("VGG Model Initialize Successfully!")# forwarddef forward(self, x):for conv in self.conv_list:    # 16 CONVx = conv(x)output = x.view(x.size()[0], -1)for fc in self.fc_list:        # 3 FCoutput = fc(output)return outputif __name__ == '__main__':vgg19 = VGG()# 检查模型每一层的参数# print(vgg19)
  • dataset.py

写这个文件的时候出现了很多次报错,关于标签和图片格式的报错。所以在文件的最后,有保留检验时候的代码。修改的时候,方便查看一些输出和变量类型。

import os
import torch
import glob
import numpy as np
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader,Datasetclass Mydata(Dataset):def __init__(self,root,transforms=None):# 初始化函数,读取所有data_path下的图片self.root = rootself.transform = transformsself.classes_data = {}for name in sorted(os.listdir(os.path.join(root))):self.classes_data[name] = len(self.classes_data)self.images = []for name in self.classes_data.keys():self.images += glob.glob(os.path.join(root, name, '*')) # print(self.classes_data)def __getitem__(self,index):#根据索引index返回dataset[index]img_path = self.images[index]#根据索引index获取图片路径img = Image.open(img_path).convert('RGB')# 读取该图片
#         target= img_path.split('\\')[-1].split('.')[0]if self.transform is not None:img = self.transform(img)  target= self.classes_data[img_path.split('/')[-2]]target = torch.tensor(target)# 根据该图片的路径名获取该图片的label,具体根据路径名进行分割。我这里是"E:\\Python Project\\Pytorch\\dogs-vs-cats\\train\\cat.0.jpg",所以先用"\\"分割,选取最后一个为['cat.0.jpg'],然后使用"."分割,选取[cat]作为该图片的标签return img,target       def __len__(self):return len(self.images)def Myclasses(self):return self.classes_data########下面的是测试这个自定义文件的dataset文件能否正常运行,以及上面的函数在调用后的功能测试
# if __name__=='__main__':
#     train_root = "./dam-v4/train"
#     mydata = Mydata(train_root, transforms=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()]))
#     print(mydata.classes)# dataloader = DataLoader(mydata,batch_size=32,shuffle=True)#使用DataLoader加载数据# # # #     #   查看batch
#     for data in dataloader:
#         imgs,targets = data
  • train.py

训练代码没什么好讲的,就是导入数据,模型,然后开始训练

这里我只打印了train的loss,acc和test的acc

import os
import time
import torch
import torchvision
from torchvision import datasets, models, transforms
import torch.optim as optim
import numpy as np
from  module_vgg import VGG
from Mydataset import Mydata#######定义一会要调用的训练和验证的函数
def train_model(model, criterion, optimizer, num_epochs=25):since = time.time()# 记录训练过程的指标print('开始训练')for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 迭代数据train_loss = 0train_acc = 0# Iterate over data.for i, data in enumerate(dataloaders['train']):inputs , labels = datainputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad()# 清空梯度with torch.set_grad_enabled(True):outputs  = model(inputs)loss = criterion(outputs, labels)loss.backward()# 反向传播计算梯度optimizer.step()# 更新网络train_loss += loss.item() * inputs.size(0)train_acc += (outputs.argmax(1)==labels).sum()print('Epoch{}acc: '.format(epoch),train_acc.item()/dataset_sizes['train'])print('Epoch{}loss: '.format(epoch),train_loss/dataset_sizes['train'])time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))return model####### load data 加载数据
data_dir = "./dam-v4"
input_shape = 224   ##统一图片大小
batch_size = 32image_datasets = {x: Mydata(os.path.join(data_dir, x),transforms=transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor()]))for x in ['train', 'validation']
}dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size,shuffle=True, num_workers=4)for x in ['train', 'validation']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'validation']}
# print(dataset_sizes['train'],dataset_sizes['validation'])class_names = image_datasets['train'].Myclasses()
print("数据读取成功")
print('训练集和验证集图片数量:',dataset_sizes)
print('图片分类:',class_names)
# {'train': 3277, 'validation': 351}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")####导入模型开始训练
vgg_based = VGG()
vgg_based = vgg_based.to(device) ## 定义损失函数和优化器
criterion = torch.nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(vgg_based.parameters(), lr=0.001, momentum=0.9)
####模型开始训练vgg_based = train_model(vgg_based, criterion, optimizer_ft, num_epochs=25)####对模型进行测试vgg_based.eval() #在验证状态
val_loss_sum = 0.0
val_metric_sum = 0.0
val_step =1
with torch.no_grad(): # 验证的部分,不是训练所以不要带入梯度print('验证状态')for i, data_val in enumerate(dataloaders['validation']):features,labels = data_valfeatures = features.to(device)labels = labels.to(device)pred = vgg_based(features)val_loss = criterion(pred,labels)val_step += 1val_metric_sum += (pred.argmax(1)==labels).sum()# val_metric_sum = val_metric_sum*features.size(0)print(val_metric_sum)print("模型在验证集上的准确率为{}".format(val_step,(val_metric_sum.item()/dataset_sizes['validation'])))

关于用VGG19网络来做8分类任务的总结相关推荐

  1. 【PyTorch】构造VGG19网络进行本地图片分类(超详细过程)——项目介绍

    本篇博客主要解决以下3个问题: 如何自定义网络(以VGG19为例). 如何自建数据集并加载至模型中. 如何使用自定义数据训练自定义模型. 第一篇:[PyTorch]构造VGG19网络进行本地图片分类( ...

  2. 论文解读:胶囊网络在小样本做文本分类中的应用(下)

    Dynamic Memory Induction Networks for Few-Shot Text Classification 论文提出Dynamic Memory Induction Netw ...

  3. 使用CNN做文本分类——将图像2维卷积换成1维

    使用CNN做文本分类from __future__ importdivision, print_function, absolute_importimporttensorflow as tfimpor ...

  4. php 相册分类,这款不需要网络就可以智能识别分类照片,让你的相册不再混乱...

    记得在<我的前半生>中吴越饰演的凌玲凭借着一手整理的才能把陈俊生迷的死死的,而「EASY 相册」则凭借一手智能整理的本领让我抛弃了其他相册管理 APP,专注与它-- 作为一名热爱生活,热爱 ...

  5. 分享 | 基于图像分类网络ResNet50_vd实现桃子分类

    随着时代的快速发展,人工智能已经融入我们生活的方方面面.中国的农业也因人工智能而受益进入高速发展阶段.现今,看庄稼长势有卫星遥感技术,水果分拣有智能分拣系统,灌溉施肥有自动化机械-- 具体以水果分拣场 ...

  6. 使用bert模型做句子分类

    使用bert模型微调做下游任务,在goole发布的bert代码和huggingface的transformer项目中都有相应的任务,有的时候只需要把代码做简单的修改即可使用.发现代码很多,我尝试着自己 ...

  7. 【深度学习】使用tensorflow实现VGG19网络

    转载注明出处:http://blog.csdn.net/accepthjp/article/details/70170217 接上一篇AlexNet,本文讲述使用tensorflow实现VGG19网络 ...

  8. 贝叶斯分类器做文本分类案例

    贝叶斯分类器做文本分类 文本分类是现代机器学习应用中的一大模块,更是自然语言处理的基础之一.我们可以通过将文字数据处理成数字数据,然后使用贝叶斯来帮助我们判断一段话,或者一篇文章中的主题分类,感情倾向 ...

  9. 基于图像分类网络ResNet50_vd实现桃子分类

    基于图像分类网络ResNet50_vd实现桃子分类 随着时代的快速发展,人工智能已经融入我们生活的方方面面.中国的农业也因人工智能而受益进入高速发展阶段.现今,看庄稼长势有卫星遥感技术,水果分拣有智能 ...

  10. transformer做文本分类的keras实现完整版

    背景 目前csdn上搜索到的keras的版本实现,排在前面的是: https://blog.csdn.net/xiaosongshine/article/details/86595847 但是,这个文 ...

最新文章

  1. dateutils 工具类_五金工具泡壳封边机
  2. log4j每天产生一日志文件
  3. 我关注的那些程序员大佬
  4. extjs tree下拉列表_Extjs中ComboBoxTree实现的下拉框树效果(自写)_extjs
  5. oracle数据库查看建表语句,oracle 查看建表语句
  6. 计算机网络的商业价值和应用,计算机网络建模数学工具的分析与比较
  7. JavaScript打开APP
  8. 零的突破!6所高校、2所“双非”顶刊发文
  9. 计网 | 网络层 SDN控制器 / 远程控制器
  10. canvas教程5-绘制路径
  11. AspNetCore.Mvc写Get方法运行显示该网页无法正常运作 http err 500问题
  12. VJ框架 与 人脸检测/物体检测 详解
  13. NoSQL之 Redis配置与优化
  14. C51单片机实现串口通信
  15. 解决CitSpace分析新版本web of science文献报错“the timing slicing setting is outside the range of your data”
  16. 追逐日月,不苟于山川。
  17. CSDN中编辑文章时,如何去除图片水印?
  18. 50个好用的前端工具,建议收藏!
  19. bootstrap栅栏系统
  20. IE7.0,IE8.0卸载方法,回到IE6.0

热门文章

  1. 运动控制卡选型和特点讨论
  2. xcode ios 怎么导入p12证书
  3. ThinkPHP商城系统与外部系统用户互通,集成UCenter
  4. Java多线程电影院_java 多线程-快乐订座电影院
  5. mysql设计一个网上购物系统_网上购物系统的设计与实现(MyEclipse,MySQL)
  6. 使用 Python 编写一个聊天小程序
  7. 莫烦python之python基础学习备忘
  8. python 东方财富接口_东方财富开放交易api,我只想要东方财富软件交易功能
  9. 牛客练习赛63 牛牛的树行棋
  10. WPF_界面_图片/界面/文字模糊解决之道整理