计算机视觉PyTorch迁移学习 - (二)
图像迁移学习
- 3.PyTorch实现迁移学习
- 3.1数据集预处理
- 3.2构建模型
- 3.3模型训练与验证
3.PyTorch实现迁移学习
文件目录
3.1数据集预处理
这里实现一个蚂蚁与蜜蜂的图像分类,用到的数据集data下载
dataset.py
from torchvision import datasets, transforms
import torchtrain=transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪一个area然后再resizetransforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])val=transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])trainset=datasets.ImageFolder(root='hymenoptera_data/train',transform=train)
valset=datasets.ImageFolder(root='hymenoptera_data/val',transform=val)trainloader=torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True, num_workers=4)
valloader=torch.utils.data.DataLoader(valset,batch_size=4,shuffle=True, num_workers=4)
3.2构建模型
model.py
from torchvision import models
import torch.nn as nn#初始化模型#保证模型不改变的层的参数,不发生梯度变化
def set_parameter_requires_grad(model, feature_extracting):if feature_extracting:for param in model.parameters():param.requires_grad = Falsedef initialize_model(model_name, num_classes, feature_extract):model_ft=Noneinput_size=0if model_name =='resnet':#resnet18model_ft = models.resnet18(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "alexnet":model_ft = models.alexnet(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier[6].in_featuresmodel_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "vgg":#vgg11model_ft = models.vgg11_bn(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier[6].in_featuresmodel_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "squeezenet":model_ft = models.squeezenet1_0(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1, 1), stride=(1, 1))model_ft.num_classes = num_classesinput_size = 224elif model_name == "densenet":model_ft = models.densenet121(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)num_ftrs = model_ft.classifier.in_featuresmodel_ft.classifier = nn.Linear(num_ftrs, num_classes)input_size = 224elif model_name == "inception":model_ft = models.inception_v3(pretrained=True)set_parameter_requires_grad(model_ft, feature_extract)# Handle the auxilary netnum_ftrs = model_ft.AuxLogits.fc.in_featuresmodel_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)# Handle the primary netnum_ftrs = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_ftrs, num_classes)input_size = 299else:print("没有合适的模型...")return model_ft, input_size
3.3模型训练与验证
run.py
from __future__ import print_function
from __future__ import division
import torch.nn as nn
import torch.optim as optim
from model import initialize_model
from torch.optim import lr_scheduler
import time
import copy
from dataset import *
import argparseparser=argparse.ArgumentParser()
#模型选择
parser.add_argument('-m','--model_name',type=str,choices=['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],help="input model_name",default='resnet')
#分类类别数
parser.add_argument('-n','--num_classes',type=int,help="input num_classes",default=2)
#定义一个批次的样本数
parser.add_argument('-b','--batch_size',type=int,help="input batch_size",default=8)
#定义迭代批次
parser.add_argument('-e','--num_epochs',type=int,help="input num_epochs",default=25)
args=parser.parse_args()#用于特征提取的标志。如果为False,则对整个模型进行微调,
#如果为True,则仅更新重塑的图层参数
feature_extract = True#定义数据字典
datasets={train:trainset,val:valset}
#定义数据集字典
dataloaders={train:trainloader,val:valloader}model_ft, input_size = initialize_model(args.model_name, args.num_classes, feature_extract)
criterion = nn.CrossEntropyLoss()# 观察所有参数都正在优化
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# 每7个epochs衰减LR通过设置gamma=0.1
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)def train_model(model,criterion,optimizer,scheduler,num_epochs):since=time.time()val_acc_history = []#获取模型初始参数best_model_wts=copy.deepcopy(model.state_dict())best_acc=0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch,num_epochs-1))print('-'*10)for data in ['train','val']:if data=='train':scheduler.step()model.train()else:model.eval()running_loss = 0.0running_corrects = 0for inputs,labels in dataloaders[data]:optimizer.zero_grad()with torch.set_grad_enabled(data=='train'):outputs=model(inputs)_,preds=torch.max(outputs,1)loss=criterion(outputs,labels)if data=='train':loss.backward()optimizer.step()running_loss+=loss.item()*inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / len(datasets[data])epoch_acc = running_corrects.double() / len(datasets[data])print('{} Loss: {:.4f} Acc: {:.4f}'.format(data, epoch_loss, epoch_acc))# 深度复制moif data=='val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))model.load_state_dict(best_model_wts)return modeltrain_model(model_ft,criterion, optimizer_ft, exp_lr_scheduler,args.num_epochs)
计算机视觉PyTorch迁移学习 - (二)相关推荐
- PyTorch框架学习二十——模型微调(Finetune)
PyTorch框架学习二十--模型微调(Finetune) 一.Transfer Learning:迁移学习 二.Model Finetune:模型的迁移学习 三.看个例子:用ResNet18预训练模 ...
- PyTorch迁移学习
PyTorch迁移学习 实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集].通常的做法是在一个很大的数据集上进 ...
- PyTorch框架学习二——基本数据结构(张量)
PyTorch框架学习二--基本数据结构(张量) 一.什么是张量? 二.Tensor与Variable(PyTorch中) 1.Variable 2.Tensor 三.Tensor的创建 1.直接创建 ...
- PyTorch 迁移学习 (Transfer Learning) 代码详解
PyTorch 迁移学习 代码详解 概述 为什么使用迁移学习 更好的结果 节省时间 加载模型 ResNet152 冻层实现 模型初始化 获取需更新参数 训练模型 获取数据 完整代码 概述 迁移学习 ( ...
- PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类
1. 迁移学习的两个主要场景 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练 下 ...
- Pytorch迁移学习加载部分预训练权重
迁移学习在图像分类领域非常常见,利用在超大数据集上训练得到的网络权重,迁移到自己的数据上进行训练可以节约大量的训练时间,降低欠拟合/过拟合的风险. 如果用原生网络进行迁移学习非常简单,其核心是 mod ...
- pytorch构造可迭代的Dataset——IterableDataset(pytorch Data学习二)
如果是可以一次性加载进内存的数据,上一篇博客:pytorch 构造读取数据的工具类 Dataset 与 DataLoader (pytorch Data学习一),已经足以应付了,但是很多时候数据集较大 ...
- pytorch迁移学习载入部分权重
载入权重是迁移学习的重要部分,这个权重的来源可以是官方发布的预训练权重,也可以是你自己训练的权重并载入模型进行继续学习.使用官方预训练权重,这样的权重包含的信息量大且全面,可以适配一些小数据的任务,即 ...
- 【学习笔记】pytorch迁移学习-猫狗分类实战
1.迁移学习入门 什么是迁移学习:在深度神经网络算法的引用过程中,如果我们面对的是数据规模较大的问题,那么在搭建好深度神经网络模型后,我们势必要花费大量的算力和时间去训练模型和优化参数,最后耗费了这么 ...
最新文章
- React+Redux+中间件
- 技术人员关注的几个优质公众号
- DevOps:软件架构师行动指南1.7 障碍
- esp8266 wifi模组入网案例
- java高效遍历匹配,使用cypher或遍历api仅匹配路径极端的单个节点
- 滤波算法(二)—— 中位值滤波算法
- 易语言模拟按键 c打不出,易语言怎么编写模拟按键
- linux vim中文使用教程
- 反转链表 c++实现
- 软件工程——软件开发步骤
- 按键精灵通过句柄获取窗口坐标_按键精灵9 得到鼠标指向的窗口句柄
- 微信浮窗是不是服务器保存,微信浮窗,能解决小程序留存难题吗?
- 多项式展开的逆过程的MATLAB实现
- 因机构系统维护服务暂不可用_因合作方系统维护,暂时无法使用是什么意思?...
- [转发]行列视(RCV)——生产数据应用系统
- 转 纸牌屋1-4集分析
- 三维场景注记的配置相关(学习记录)
- c语言c 哪个好学,C语言好学吗?
- HTML基础-笔记1标签
- 上个礼拜公司组织去浙江旅游的照片