图像迁移学习

  • 3.PyTorch实现迁移学习
    • 3.1数据集预处理
    • 3.2构建模型
    • 3.3模型训练与验证

3.PyTorch实现迁移学习

文件目录

3.1数据集预处理

这里实现一个蚂蚁与蜜蜂的图像分类,用到的数据集data下载
dataset.py

from torchvision import datasets, transforms
import torchtrain=transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪一个area然后再resizetransforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])val=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])trainset=datasets.ImageFolder(root='hymenoptera_data/train',transform=train)
valset=datasets.ImageFolder(root='hymenoptera_data/val',transform=val)trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True, num_workers=4)
valloader=torch.utils.data.DataLoader(valset,batch_size=4,shuffle=True, num_workers=4)

3.2构建模型

model.py

from torchvision import models
import torch.nn as nn#初始化模型#保证模型不改变的层的参数,不发生梯度变化
def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = Falsedef initialize_model(model_name, num_classes, feature_extract):model_ft=Noneinput_size=0if model_name =='resnet':#resnet18model_ft = models.resnet18(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "alexnet":model_ft = models.alexnet(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier[6].in_featuresmodel_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "vgg":#vgg11model_ft = models.vgg11_bn(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier[6].in_featuresmodel_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "squeezenet":model_ft = models.squeezenet1_0(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))model_ft.num_classes = num_classesinput_size = 224elif model_name == "densenet":model_ft = models.densenet121(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier.in_featuresmodel_ft.classifier = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "inception":model_ft = models.inception_v3(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)# Handle the auxilary netnum_ftrs = model_ft.AuxLogits.fc.in_featuresmodel_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)# Handle the primary netnum_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, num_classes)input_size = 299else:print("没有合适的模型...")return model_ft, input_size

3.3模型训练与验证

run.py

from __future__ import print_function
from __future__ import division
import torch.nn as nn
import torch.optim as optim
from model import initialize_model
from torch.optim import lr_scheduler
import time
import copy
from dataset import *
import argparseparser=argparse.ArgumentParser()
#模型选择
parser.add_argument('-m','--model_name',type=str,choices=['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],help="input model_name",default='resnet')
#分类类别数
parser.add_argument('-n','--num_classes',type=int,help="input num_classes",default=2)
#定义一个批次的样本数
parser.add_argument('-b','--batch_size',type=int,help="input batch_size",default=8)
#定义迭代批次
parser.add_argument('-e','--num_epochs',type=int,help="input num_epochs",default=25)
args=parser.parse_args()#用于特征提取的标志。如果为False,则对整个模型进行微调,
#如果为True,则仅更新重塑的图层参数
feature_extract = True#定义数据字典
datasets={train:trainset,val:valset}
#定义数据集字典
dataloaders={train:trainloader,val:valloader}model_ft, input_size = initialize_model(args.model_name, args.num_classes, feature_extract)
criterion = nn.CrossEntropyLoss()# 观察所有参数都正在优化
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# 每7个epochs衰减LR通过设置gamma=0.1
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)def train_model(model,criterion,optimizer,scheduler,num_epochs):since=time.time()val_acc_history = []#获取模型初始参数best_model_wts=copy.deepcopy(model.state_dict())best_acc=0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch,num_epochs-1))print('-'*10)for data in ['train','val']:if data=='train':scheduler.step()model.train()else:model.eval()running_loss = 0.0running_corrects = 0for inputs,labels in dataloaders[data]:optimizer.zero_grad()with torch.set_grad_enabled(data=='train'):outputs=model(inputs)_,preds=torch.max(outputs,1)loss=criterion(outputs,labels)if data=='train':loss.backward()optimizer.step()running_loss+=loss.item()*inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(datasets[data])epoch_acc = running_corrects.double() / len(datasets[data])print('{} Loss: {:.4f} Acc: {:.4f}'.format(data, epoch_loss, epoch_acc))# 深度复制moif data=='val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))model.load_state_dict(best_model_wts)return modeltrain_model(model_ft,criterion, optimizer_ft, exp_lr_scheduler,args.num_epochs)

计算机视觉PyTorch迁移学习 - (二)相关推荐

  1. PyTorch框架学习二十——模型微调(Finetune)

    PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...

  2. PyTorch迁移学习

    PyTorch迁移学习 实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集].通常的做法是在一个很大的数据集上进 ...

  3. PyTorch框架学习二——基本数据结构(张量)

    PyTorch框架学习二--基本数据结构(张量) 一.什么是张量? 二.Tensor与Variable(PyTorch中) 1.Variable 2.Tensor 三.Tensor的创建 1.直接创建 ...

  4. PyTorch 迁移学习 (Transfer Learning) 代码详解

    PyTorch 迁移学习 代码详解 概述 为什么使用迁移学习 更好的结果 节省时间 加载模型 ResNet152 冻层实现 模型初始化 获取需更新参数 训练模型 获取数据 完整代码 概述 迁移学习 ( ...

  5. PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类

    1. 迁移学习的两个主要场景 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练 下 ...

  6. Pytorch迁移学习加载部分预训练权重

    迁移学习在图像分类领域非常常见,利用在超大数据集上训练得到的网络权重,迁移到自己的数据上进行训练可以节约大量的训练时间,降低欠拟合/过拟合的风险. 如果用原生网络进行迁移学习非常简单,其核心是 mod ...

  7. pytorch构造可迭代的Dataset——IterableDataset(pytorch Data学习二)

    如果是可以一次性加载进内存的数据,上一篇博客:pytorch 构造读取数据的工具类 Dataset 与 DataLoader (pytorch Data学习一),已经足以应付了,但是很多时候数据集较大 ...

  8. pytorch迁移学习载入部分权重

    载入权重是迁移学习的重要部分,这个权重的来源可以是官方发布的预训练权重,也可以是你自己训练的权重并载入模型进行继续学习.使用官方预训练权重,这样的权重包含的信息量大且全面,可以适配一些小数据的任务,即 ...

  9. 【学习笔记】pytorch迁移学习-猫狗分类实战

    1.迁移学习入门 什么是迁移学习:在深度神经网络算法的引用过程中,如果我们面对的是数据规模较大的问题,那么在搭建好深度神经网络模型后,我们势必要花费大量的算力和时间去训练模型和优化参数,最后耗费了这么 ...

最新文章

  1. eScan Internet Security Suite 2006
  2. Xilinx® 7 series FPGAs CLBs专题介绍(一)
  3. Keras之MLP:利用MLP【Input(8)→(12)(relu)→O(sigmoid+二元交叉)】模型实现预测新数据(利用糖尿病数据集的八个特征实现二分类预测
  4. hdu4714 Tree2cycle 把树剪成链
  5. 树状数组--前n项和;
  6. 5G协议学习(38.300-物理层)
  7. wps里的茶色字体怎么设置_VRay茶色玻璃材质参数是什么,要怎么设置?
  8. 【pytorch】Conv2d()里面的参数bias什么时候加,什么时候不加?
  9. SQL_Xbar代码
  10. 视觉和imu(惯性传感器)( 一)
  11. 蚁群算法求解TSP问题
  12. UPS不间断电源系统安全使用要领
  13. python实战-02-基础语法及pip安装
  14. Unity灯光渲染之自发光材质
  15. 里恩EDC论临床试验中与第三方中心实验室实时电子化传输和接收的技术实现以及风险应对
  16. powerbi如何创建参数_Power BI中参数的用法
  17. C语言 校园歌手比赛系统源码
  18. 扒勒索病毒史,聊真CDP与准CDP
  19. Android毕业设计及论文答辩经验分享
  20. 有种爱,不会提起不会忘记

热门文章

  1. 高级语言.汇编语言和机器语言
  2. Keycloak实现手机验证码登录
  3. 奥特曼系列ol进不去服务器,奥特曼系列OL闪退怎么办?解决方案
  4. 独享云虚拟主机、共享云虚拟主机、云服务器 ECS 的区别
  5. 解决IDEA 前端返回值乱码问题
  6. c 语言编辑器 win7旗舰版,如何使用大地win7旗舰版内置字符编辑程序
  7. IT人生之猎人和猎狗
  8. AlexNet_tensorflow2.1_实现狼狗分类
  9. Unity URP DOTS Animator
  10. JDK自带的Timer定时器实现每天24点修改数据