pytorch加载VGG16及进行fine-tuning训练
- 加载VGG16模型并打印查看
from torchvision import models
net=models.vgg16()
print(net)
1.1结果说明
1.2查看某一部分
**
- 加载模型进行预训练,改变classifier层,固定feature层参数
**
2.1模型搭建
import torch
import torch.nn as nn
from torchvision import models
from torchsummary import summary
net=models.vgg16()class VGGnet(nn.Module):def __init__(self,feature_extract=True,num_classes=5):super(VGGnet, self).__init__()#导入VGG16模型model = models.vgg16(pretrained=True)#加载features部分self.features = model.features#固定特征提取层参数set_parameter_requires_grad(self.features, feature_extract)#加载avgpool层self.avgpool=model.avgpool#改变classifier:分类输出层self.classifier = nn.Sequential(nn.Linear(512*7*7 , 1024),nn.ReLU(),nn.Linear(1024, 1024),nn.ReLU(),nn.Linear(1024, num_classes))def forward(self, x):x = self.features(x)x = self.avgpool(x)x = x.view(x.size(0), 512*7*7)out=self.classifier(x)return out#固定参数,不进行训练
def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False
2.2 模型对比
net_self=VGGnet()print('model_build','**'*20)print(net_self)
结果:与原始模型(1.1中的结果)改变的是classifier层
- 模型训练
3.1数据及数据加载代码来源:Tensorflow2.1.0 自定义数据集:精灵宝可梦数据集
import os
from torch.utils.data import Dataset, DataLoader #自定义的母类,必须的
from torchvision.transforms import transforms
from PIL import Image
import torch
import glob
import csv
import randomclass Pokemon(Dataset):def __init__(self, root, resize, mode):super(Pokemon, self).__init__()self.root = rootself.resize = resizeself.name2label = {} # "sq...":0for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continueself.name2label[name] = len(self.name2label.keys()) #将英文标签名转化数字0-4# print(self.name2label)# image, labelself.images, self.labels = self.load_csv('images.csv') #csv文件存在 直接读取if mode == 'train': # 60% self.images = self.images[:int(0.6 * len(self.images))]self.labels = self.labels[:int(0.6 * len(self.labels))]elif mode == 'val': # 20% = 60%->80%self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]else: # 20% = 80%->100%self.images = self.images[int(0.8 * len(self.images)):]self.labels = self.labels[int(0.8 * len(self.labels)):]def __len__(self):return len(self.images)def __getitem__(self, idx):# idx~[0~len(images)]# self.images, self.labels# img: 'pokemon\\bulbasaur\\00000000.png'# label: 0img, label = self.images[idx], self.labels[idx]tf = transforms.Compose([ #常用的数据变换器lambda x:Image.open(x).convert('RGB'), # string path= > image data #这里开始读取了数据的内容了transforms.Resize( #数据预处理部分(int(self.resize * 1.25), int(self.resize * 1.25))), transforms.RandomRotation(15), transforms.CenterCrop(self.resize), #防止旋转后边界出现黑框部分transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])img = tf(img)label = torch.tensor(label) #转化tensorreturn img, label #返回当前的数据内容和标签def load_csv(self, filename):if not os.path.exists(os.path.join(self.root, filename)): #如果没有保存csv文件,那么我们需要写一个csv文件,如果有了直接读取csv文件images = []for name in self.name2label.keys(): # 'pokemon\\mewtwo\\00001.pngimages += glob.glob(os.path.join(self.root, name, '*.png'))images += glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))random.shuffle(images)with open(os.path.join(self.root, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images: # 'pokemon\\bulbasaur\\00000000.png'name = img.split(os.sep)[-2] #从名字就可以读取标签label = self.name2label[name]# 'pokemon\\bulbasaur\\00000000.png', 0writer.writerow([img, label]) #写进csv文件# read from csv fileimages, labels = [], []with open(os.path.join(self.root, filename)) as f:reader = csv.reader(f)for row in reader:# 'pokemon\\bulbasaur\\00000000.png', 0img, label = rowlabel = int(label)images.append(img)labels.append(label)assert len(images) == len(labels)return images, labelsdef denormalize(self, x_hat):mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]# x_hat = (x-mean)/std# x = x_hat*std = mean# x: [c, h, w]# mean: [3] => [3, 1, 1]mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)std = torch.tensor(std).unsqueeze(1).unsqueeze(1)# print(mean.shape, std.shape)x = x_hat * std + meanreturn x
if __name__=='__main__':db = Pokemon('pokeman', 224, 'train')loader = DataLoader(db, batch_size=32, shuffle=True)for x, y in loader: #此时x,y是批量的数据print(x.shape)
3.2 训练
import torch
import torch.nn as nn
from torchvision import modelsclass VGGnet(nn.Module):def __init__(self,feature_extract=True,num_classes=5):super(VGGnet, self).__init__()model = models.vgg16(pretrained=True)self.features = model.featuresset_parameter_requires_grad(self.features, feature_extract)#固定特征提取层参数self.avgpool=model.avgpoolself.classifier = nn.Sequential(nn.Linear(512*7*7 , 1024),nn.ReLU(),nn.Linear(1024, 1024),nn.ReLU(),nn.Linear(1024, num_classes))def forward(self, x):x = self.features(x)x = self.avgpool(x)x = x.view(x.size(0), 512*7*7)out=self.classifier(x)return outdef set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = Falseif __name__=="__main__":import torch.nn as nnfrom torch.utils.data import DataLoaderfrom data_read import Pokemon# In[]learning_rate=0.001num_epochs = 2 # train the training data n times, to save time, we just train 1 epochbatch_size = 32LR = 0.01 # learning rate# In[]train_dataset = Pokemon('pokeman', 224, 'train')val_dataset = Pokemon('pokeman', 224, 'val')test_dataset = Pokemon('pokeman', 224, 'test')train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)test_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)val_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)# In[]device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model=VGGnet().to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# In[]total_step = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 2 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# In[] # Test the modelmodel.eval() #with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy {} %'.format(100 * correct / total))
3.3 结果
- 数据集链接:所有分享的数据集都在这个文件夹,分类,迁移学习,图像分割等
数据集名称:VGG16(fine-tuning_pokeman)
百度云链接:https://pan.baidu.com/s/1eyDFV6YVaOHwr9QrRcKXFg
提取码:rlne
pytorch加载VGG16及进行fine-tuning训练相关推荐
- PyTorch加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features..,Expected .
希望将训练好的模型加载到新的网络上.如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题. Unexpected key(s) in state_dict: "mod ...
- pytorch 驱动不兼容_解决Pytorch 加载训练好的模型 遇到的error问题
这是一个非常愚蠢的错误 debug的时候要好好看error信息 提醒自己切记好好对待error!切记!切记! -----------------------分割线---------------- py ...
- 用pytorch加载训练模型
用pytorch加载.pth格式的训练模型 在pytorch/vision/models网页上有很多现成的经典网络模型可以调用,其中包括alexnet.vgg.googlenet.resnet.inc ...
- Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法
需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...
- pytorch加载自己的图片数据集的两种方法
目录 ImageFolder 加载数据集 使用pytorch提供的Dataset类创建自己的数据集. Dataset加载数据集 接下来我们就可以构建我们的网络架构: 训练我们的网络: 保存网络模型(这 ...
- Pytorch加载模型并进行图像分类预测
目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...
- pytorch加载训练数据集dataloader操作耗费时间太久,该如何解决?
笔者在使用pytorch加载训练数据进行模型训练的时候,发现数据加载需要耗费太多时间,该如何缩短数据加载的时间消耗呢?经过查询相关文档,总结实际操作过程如下: 1.尽量将jpg等格式的文件保存为bmp ...
- Pytorch加载torchvision从本地下载好的预训练模型的简单解决方案
大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.喜 ...
- pytorch加载自己的数据集,数据集载入-视频合集
pytorch加载数据主要学习了两种:只有图片的数据集和有scv保存标签的数据集 而第一种只有图片的数据集的加 载又分为两种:标签在文件夹上的和标签在图片名上的 1.第一种标签在文件夹上的数据加载方法 ...
最新文章
- 2021年大数据Hadoop(二十九):​​​​​​​关于YARN常用参数设置
- ddr test DCD CFG file CBT
- vim代码格式化工具autopep8
- php探针源码,服务器探针 (刘海探针)—开源PHP探针
- java使用varargs,Java 实例 – Varargs 可变参数使用 - Java 基础教程
- MySQL extract()函数
- POJ3274Gold Balanced Lineup(哈希)
- 怎么查看计算机的系统内存大小,Windows10系统怎么查看电脑内存大小
- python3里面的图片处理库 pillow
- Jumpserver0.4.0基于Centos7安装
- 神经网络进行自然语言处理最佳实践
- css层叠优先级,CSS样式的优先级(层叠)
- Servlet、ServletConfig、ServletContext
- Backdrop CMS介绍
- 使用OpenSSL生成证书
- 串行口数据缓冲寄存器 SBUF 之 初步了解
- 人工智能的发展历程,AI ,路在何方(文章分享)
- CSS_后端工程师必备知识-从入门到劝退详解-呕心沥血撰写(滑稽)
- PS色彩算法理解记录 1 Darken Lighten
- 无线传感网络的基本结构
热门文章
- go 条件变量简介 sync.Cond
- Ashen的成长,从CSDN博客开始!
- Excel如何条件求和
- 如何看待IT行业发展前景,就业前景和人才需求趋势
- 纸箱制作机器人邮箱_纸箱机器人衣服制作方法
- SONET/SDH帧格式
- 计算机网络术语sonet,计算机网络(第七版)谢希仁-第2章 物理层(示例代码)
- debian 10的安装DVD
- com.alibaba.nacos.shaded.io.grpc.StatusRuntimeException UNAVAILABLE io exception
- 知物由学 | 一文读懂Android资源文件保护