1. 加载VGG16模型并打印查看
from torchvision import models
net=models.vgg16()
print(net)

1.1结果说明

1.2查看某一部分

**

  1. 加载模型进行预训练,改变classifier层,固定feature层参数

**
2.1模型搭建

import torch
import torch.nn as nn
from torchvision import models
from torchsummary import summary
net=models.vgg16()class VGGnet(nn.Module):def __init__(self,feature_extract=True,num_classes=5):super(VGGnet, self).__init__()#导入VGG16模型model = models.vgg16(pretrained=True)#加载features部分self.features = model.features#固定特征提取层参数set_parameter_requires_grad(self.features, feature_extract)#加载avgpool层self.avgpool=model.avgpool#改变classifier:分类输出层self.classifier = nn.Sequential(nn.Linear(512*7*7 , 1024),nn.ReLU(),nn.Linear(1024, 1024),nn.ReLU(),nn.Linear(1024, num_classes))def forward(self, x):x = self.features(x)x = self.avgpool(x)x = x.view(x.size(0), 512*7*7)out=self.classifier(x)return out#固定参数,不进行训练
def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = False

2.2 模型对比

net_self=VGGnet()print('model_build','**'*20)print(net_self)

结果:与原始模型(1.1中的结果)改变的是classifier层

  1. 模型训练
    3.1数据及数据加载代码来源:Tensorflow2.1.0 自定义数据集:精灵宝可梦数据集
import os
from torch.utils.data import Dataset, DataLoader  #自定义的母类,必须的
from torchvision.transforms import transforms
from PIL import Image
import torch
import glob
import csv
import randomclass Pokemon(Dataset):def __init__(self, root, resize, mode):super(Pokemon, self).__init__()self.root = rootself.resize = resizeself.name2label = {}  # "sq...":0for name in sorted(os.listdir(os.path.join(root))):if not os.path.isdir(os.path.join(root, name)):continueself.name2label[name] = len(self.name2label.keys()) #将英文标签名转化数字0-4# print(self.name2label)# image, labelself.images, self.labels = self.load_csv('images.csv')  #csv文件存在 直接读取if mode == 'train':  # 60%                   self.images = self.images[:int(0.6 * len(self.images))]self.labels = self.labels[:int(0.6 * len(self.labels))]elif mode == 'val':  # 20% = 60%->80%self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]else:  # 20% = 80%->100%self.images = self.images[int(0.8 * len(self.images)):]self.labels = self.labels[int(0.8 * len(self.labels)):]def __len__(self):return len(self.images)def __getitem__(self, idx):# idx~[0~len(images)]# self.images, self.labels# img: 'pokemon\\bulbasaur\\00000000.png'# label: 0img, label = self.images[idx], self.labels[idx]tf = transforms.Compose([   #常用的数据变换器lambda x:Image.open(x).convert('RGB'),  # string path= > image data #这里开始读取了数据的内容了transforms.Resize(   #数据预处理部分(int(self.resize * 1.25), int(self.resize * 1.25))), transforms.RandomRotation(15), transforms.CenterCrop(self.resize), #防止旋转后边界出现黑框部分transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])img = tf(img)label = torch.tensor(label)  #转化tensorreturn img, label       #返回当前的数据内容和标签def load_csv(self, filename):if not os.path.exists(os.path.join(self.root, filename)): #如果没有保存csv文件,那么我们需要写一个csv文件,如果有了直接读取csv文件images = []for name in self.name2label.keys():   # 'pokemon\\mewtwo\\00001.pngimages += glob.glob(os.path.join(self.root, name, '*.png'))images += glob.glob(os.path.join(self.root, name, '*.jpg'))images += glob.glob(os.path.join(self.root, name, '*.jpeg'))random.shuffle(images)with open(os.path.join(self.root, filename), mode='w', newline='') as f:writer = csv.writer(f)for img in images:  # 'pokemon\\bulbasaur\\00000000.png'name = img.split(os.sep)[-2]        #从名字就可以读取标签label = self.name2label[name]# 'pokemon\\bulbasaur\\00000000.png', 0writer.writerow([img, label])  #写进csv文件# read from csv fileimages, labels = [], []with open(os.path.join(self.root, filename)) as f:reader = csv.reader(f)for row in reader:# 'pokemon\\bulbasaur\\00000000.png', 0img, label = rowlabel = int(label)images.append(img)labels.append(label)assert len(images) == len(labels)return images, labelsdef denormalize(self, x_hat):mean = [0.485, 0.456, 0.406]std = [0.229, 0.224, 0.225]# x_hat = (x-mean)/std# x = x_hat*std = mean# x: [c, h, w]# mean: [3] => [3, 1, 1]mean = torch.tensor(mean).unsqueeze(1).unsqueeze(1)std = torch.tensor(std).unsqueeze(1).unsqueeze(1)# print(mean.shape, std.shape)x = x_hat * std + meanreturn x
if __name__=='__main__':db = Pokemon('pokeman', 224, 'train')loader = DataLoader(db, batch_size=32, shuffle=True)for x, y in loader: #此时x,y是批量的数据print(x.shape)

3.2 训练

import torch
import torch.nn as nn
from torchvision import modelsclass VGGnet(nn.Module):def __init__(self,feature_extract=True,num_classes=5):super(VGGnet, self).__init__()model = models.vgg16(pretrained=True)self.features = model.featuresset_parameter_requires_grad(self.features, feature_extract)#固定特征提取层参数self.avgpool=model.avgpoolself.classifier = nn.Sequential(nn.Linear(512*7*7 , 1024),nn.ReLU(),nn.Linear(1024, 1024),nn.ReLU(),nn.Linear(1024, num_classes))def forward(self, x):x = self.features(x)x = self.avgpool(x)x = x.view(x.size(0), 512*7*7)out=self.classifier(x)return outdef set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = Falseif __name__=="__main__":import torch.nn as nnfrom torch.utils.data import DataLoaderfrom data_read import Pokemon# In[]learning_rate=0.001num_epochs = 2               # train the training data n times, to save time, we just train 1 epochbatch_size = 32LR = 0.01              # learning rate# In[]train_dataset = Pokemon('pokeman', 224, 'train')val_dataset = Pokemon('pokeman', 224, 'val')test_dataset = Pokemon('pokeman', 224, 'test')train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)test_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)val_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)# In[]device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model=VGGnet().to(device)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# In[]total_step = len(train_loader)for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):images = images.to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward()optimizer.step()if (i + 1) % 2 == 0:print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))# In[]      # Test the modelmodel.eval()  #with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy  {} %'.format(100 * correct / total))

3.3 结果

  1. 数据集链接:所有分享的数据集都在这个文件夹,分类,迁移学习,图像分割等
    数据集名称:VGG16(fine-tuning_pokeman)
    百度云链接:https://pan.baidu.com/s/1eyDFV6YVaOHwr9QrRcKXFg
    提取码:rlne

pytorch加载VGG16及进行fine-tuning训练相关推荐

  1. PyTorch加载模型model.load_state_dict()问题,Unexpected key(s) in state_dict: “module.features..,Expected .

    希望将训练好的模型加载到新的网络上.如上面题目所描述的,PyTorch在加载之前保存的模型参数的时候,遇到了问题. Unexpected key(s) in state_dict: "mod ...

  2. pytorch 驱动不兼容_解决Pytorch 加载训练好的模型 遇到的error问题

    这是一个非常愚蠢的错误 debug的时候要好好看error信息 提醒自己切记好好对待error!切记!切记! -----------------------分割线---------------- py ...

  3. 用pytorch加载训练模型

    用pytorch加载.pth格式的训练模型 在pytorch/vision/models网页上有很多现成的经典网络模型可以调用,其中包括alexnet.vgg.googlenet.resnet.inc ...

  4. Pytorch加载模型只导入部分层权重,即跳过指定网络层的方法

    需求 Pytorch加载模型时,只导入部分层权重,跳过部分指定网络层.(权重文件存储为dict形式) 方法一 常见方法:加载权重时用if对网络层进行筛选 ''' # model为定义的网络结构: cl ...

  5. pytorch加载自己的图片数据集的两种方法

    目录 ImageFolder 加载数据集 使用pytorch提供的Dataset类创建自己的数据集. Dataset加载数据集 接下来我们就可以构建我们的网络架构: 训练我们的网络: 保存网络模型(这 ...

  6. Pytorch加载模型并进行图像分类预测

    目录 1. 整体流程 1)实例化模型 2)加载模型 3)输入图像 4)输出分类结果 5)完整代码 2. 处理图像 1) How can i convert an RGB image into gray ...

  7. pytorch加载训练数据集dataloader操作耗费时间太久,该如何解决?

    笔者在使用pytorch加载训练数据进行模型训练的时候,发现数据加载需要耗费太多时间,该如何缩短数据加载的时间消耗呢?经过查询相关文档,总结实际操作过程如下: 1.尽量将jpg等格式的文件保存为bmp ...

  8. Pytorch加载torchvision从本地下载好的预训练模型的简单解决方案

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.喜 ...

  9. pytorch加载自己的数据集,数据集载入-视频合集

    pytorch加载数据主要学习了两种:只有图片的数据集和有scv保存标签的数据集 而第一种只有图片的数据集的加 载又分为两种:标签在文件夹上的和标签在图片名上的 1.第一种标签在文件夹上的数据加载方法 ...

最新文章

  1. 2021年大数据Hadoop(二十九):​​​​​​​关于YARN常用参数设置
  2. ddr test DCD CFG file CBT
  3. vim代码格式化工具autopep8
  4. php探针源码,服务器探针 (刘海探针)—开源PHP探针
  5. java使用varargs,Java 实例 – Varargs 可变参数使用 - Java 基础教程
  6. MySQL extract()函数
  7. POJ3274Gold Balanced Lineup(哈希)
  8. 怎么查看计算机的系统内存大小,Windows10系统怎么查看电脑内存大小
  9. python3里面的图片处理库 pillow
  10. Jumpserver0.4.0基于Centos7安装
  11. 神经网络进行自然语言处理最佳实践
  12. css层叠优先级,CSS样式的优先级(层叠)
  13. Servlet、ServletConfig、ServletContext
  14. Backdrop CMS介绍
  15. 使用OpenSSL生成证书
  16. 串行口数据缓冲寄存器 SBUF 之 初步了解
  17. 人工智能的发展历程,AI ,路在何方(文章分享)
  18. CSS_后端工程师必备知识-从入门到劝退详解-呕心沥血撰写(滑稽)
  19. PS色彩算法理解记录 1 Darken Lighten
  20. 无线传感网络的基本结构

热门文章

  1. go 条件变量简介 sync.Cond
  2. Ashen的成长,从CSDN博客开始!
  3. Excel如何条件求和
  4. 如何看待IT行业发展前景,就业前景和人才需求趋势
  5. 纸箱制作机器人邮箱_纸箱机器人衣服制作方法
  6. SONET/SDH帧格式
  7. 计算机网络术语sonet,计算机网络(第七版)谢希仁-第2章 物理层(示例代码)
  8. debian 10的安装DVD
  9. com.alibaba.nacos.shaded.io.grpc.StatusRuntimeException UNAVAILABLE io exception
  10. 知物由学 | 一文读懂Android资源文件保护