图像迁移学习

  • 3.PyTorch实现迁移学习
    • 3.1数据集预处理
    • 3.2构建模型
    • 3.3模型训练与验证

3.PyTorch实现迁移学习

文件目录

3.1数据集预处理

这里实现一个蚂蚁与蜜蜂的图像分类,用到的数据集data下载
dataset.py

from torchvision import datasets, transforms
import torchtrain=transforms.Compose([transforms.RandomResizedCrop(224),  # 随机裁剪一个area然后再resizetransforms.RandomHorizontalFlip(),  # 随机水平翻转transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])val=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])trainset=datasets.ImageFolder(root='hymenoptera_data/train',transform=train)
valset=datasets.ImageFolder(root='hymenoptera_data/val',transform=val)trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True, num_workers=4)
valloader=torch.utils.data.DataLoader(valset,batch_size=4,shuffle=True, num_workers=4)

3.2构建模型

model.py

from torchvision import models
import torch.nn as nn#初始化模型#保证模型不改变的层的参数,不发生梯度变化
def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = Falsedef initialize_model(model_name, num_classes, feature_extract):model_ft=Noneinput_size=0if model_name =='resnet':#resnet18model_ft = models.resnet18(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "alexnet":model_ft = models.alexnet(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier[6].in_featuresmodel_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "vgg":#vgg11model_ft = models.vgg11_bn(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier[6].in_featuresmodel_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "squeezenet":model_ft = models.squeezenet1_0(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))model_ft.num_classes = num_classesinput_size = 224elif model_name == "densenet":model_ft = models.densenet121(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier.in_featuresmodel_ft.classifier = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "inception":model_ft = models.inception_v3(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)# Handle the auxilary netnum_ftrs = model_ft.AuxLogits.fc.in_featuresmodel_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)# Handle the primary netnum_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, num_classes)input_size = 299else:print("没有合适的模型...")return model_ft, input_size

3.3模型训练与验证

run.py

from __future__ import print_function
from __future__ import division
import torch.nn as nn
import torch.optim as optim
from model import initialize_model
from torch.optim import lr_scheduler
import time
import copy
from dataset import *
import argparseparser=argparse.ArgumentParser()
#模型选择
parser.add_argument('-m','--model_name',type=str,choices=['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],help="input model_name",default='resnet')
#分类类别数
parser.add_argument('-n','--num_classes',type=int,help="input num_classes",default=2)
#定义一个批次的样本数
parser.add_argument('-b','--batch_size',type=int,help="input batch_size",default=8)
#定义迭代批次
parser.add_argument('-e','--num_epochs',type=int,help="input num_epochs",default=25)
args=parser.parse_args()#用于特征提取的标志。如果为False,则对整个模型进行微调,
#如果为True,则仅更新重塑的图层参数
feature_extract = True#定义数据字典
datasets={train:trainset,val:valset}
#定义数据集字典
dataloaders={train:trainloader,val:valloader}model_ft, input_size = initialize_model(args.model_name, args.num_classes, feature_extract)
criterion = nn.CrossEntropyLoss()# 观察所有参数都正在优化
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# 每7个epochs衰减LR通过设置gamma=0.1
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)def train_model(model,criterion,optimizer,scheduler,num_epochs):since=time.time()val_acc_history = []#获取模型初始参数best_model_wts=copy.deepcopy(model.state_dict())best_acc=0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch,num_epochs-1))print('-'*10)for data in ['train','val']:if data=='train':scheduler.step()model.train()else:model.eval()running_loss = 0.0running_corrects = 0for inputs,labels in dataloaders[data]:optimizer.zero_grad()with torch.set_grad_enabled(data=='train'):outputs=model(inputs)_,preds=torch.max(outputs,1)loss=criterion(outputs,labels)if data=='train':loss.backward()optimizer.step()running_loss+=loss.item()*inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(datasets[data])epoch_acc = running_corrects.double() / len(datasets[data])print('{} Loss: {:.4f} Acc: {:.4f}'.format(data, epoch_loss, epoch_acc))# 深度复制moif data=='val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))model.load_state_dict(best_model_wts)return modeltrain_model(model_ft,criterion, optimizer_ft, exp_lr_scheduler,args.num_epochs)

计算机视觉PyTorch迁移学习 - (二)相关推荐

  1. PyTorch框架学习二十——模型微调(Finetune)

    PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...

  2. PyTorch迁移学习

    PyTorch迁移学习 实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集].通常的做法是在一个很大的数据集上进 ...

  3. PyTorch框架学习二——基本数据结构(张量)

    PyTorch框架学习二--基本数据结构(张量) 一.什么是张量? 二.Tensor与Variable(PyTorch中) 1.Variable 2.Tensor 三.Tensor的创建 1.直接创建 ...

  4. PyTorch 迁移学习 (Transfer Learning) 代码详解

    PyTorch 迁移学习 代码详解 概述 为什么使用迁移学习 更好的结果 节省时间 加载模型 ResNet152 冻层实现 模型初始化 获取需更新参数 训练模型 获取数据 完整代码 概述 迁移学习 ( ...

  5. PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类

    1. 迁移学习的两个主要场景 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练 下 ...

  6. Pytorch迁移学习加载部分预训练权重

    迁移学习在图像分类领域非常常见,利用在超大数据集上训练得到的网络权重,迁移到自己的数据上进行训练可以节约大量的训练时间,降低欠拟合/过拟合的风险. 如果用原生网络进行迁移学习非常简单,其核心是 mod ...

  7. pytorch构造可迭代的Dataset——IterableDataset(pytorch Data学习二)

    如果是可以一次性加载进内存的数据,上一篇博客:pytorch 构造读取数据的工具类 Dataset 与 DataLoader (pytorch Data学习一),已经足以应付了,但是很多时候数据集较大 ...

  8. pytorch迁移学习载入部分权重

    载入权重是迁移学习的重要部分,这个权重的来源可以是官方发布的预训练权重,也可以是你自己训练的权重并载入模型进行继续学习.使用官方预训练权重,这样的权重包含的信息量大且全面,可以适配一些小数据的任务,即 ...

  9. 【学习笔记】pytorch迁移学习-猫狗分类实战

    1.迁移学习入门 什么是迁移学习:在深度神经网络算法的引用过程中,如果我们面对的是数据规模较大的问题,那么在搭建好深度神经网络模型后,我们势必要花费大量的算力和时间去训练模型和优化参数,最后耗费了这么 ...

最新文章

  1. React+Redux+中间件
  2. 技术人员关注的几个优质公众号
  3. DevOps:软件架构师行动指南1.7 障碍
  4. esp8266 wifi模组入网案例
  5. java高效遍历匹配,使用cypher或遍历api仅匹配路径极端的单个节点
  6. 滤波算法(二)—— 中位值滤波算法
  7. 易语言模拟按键 c打不出,易语言怎么编写模拟按键
  8. linux vim中文使用教程
  9. 反转链表 c++实现
  10. 软件工程——软件开发步骤
  11. 按键精灵通过句柄获取窗口坐标_按键精灵9 得到鼠标指向的窗口句柄
  12. 微信浮窗是不是服务器保存,微信浮窗,能解决小程序留存难题吗?
  13. 多项式展开的逆过程的MATLAB实现
  14. 因机构系统维护服务暂不可用_因合作方系统维护,暂时无法使用是什么意思?...
  15. [转发]行列视(RCV)——生产数据应用系统
  16. 转 纸牌屋1-4集分析
  17. 三维场景注记的配置相关(学习记录)
  18. c语言c 哪个好学,C语言好学吗?
  19. HTML基础-笔记1标签
  20. 上个礼拜公司组织去浙江旅游的照片

热门文章

  1. 新概念二册 Lesson 45 A clear conscience问心无愧(复习被动语态+过去完成时被动语态)
  2. AV3680A天馈线测试仪使用方式
  3. python27读书笔记0.1
  4. matlab近似计算求积分,matlab 实验二 定积分的近似计算
  5. 圆周率近似计算matlab,matlab 圆周率的近似计算 实验报告.doc
  6. Android开发范例实战宝典
  7. html中ol和li,HTML ol和li标签
  8. GSM PDU模式发中文短信
  9. gorilla websocket简易介绍
  10. 一只蝴蝶挥了挥翅膀,于是有了AI读心术