PyTorch迁移学习
实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集]。通常的做法是在一个很大的数据集上进行预训练,得到卷积网络ConvNet, 然后,将这个ConvNet的参数,作为目标任务的初始化参数,或者固定这些参数。
转移学习的两个主要场景:
• 微调Convnet:使用预训练的网络(如在imagenet 1000上训练而来的网络),来初始化自己的网络,而不是随机初始化。其它的训练步骤不变。
• 将Convnet看成固定的特征提取器: 首先固定ConvNet,除了最后的全连接层外的其他所有层。最后的全连接层被替换成一个新的随机初始化的层,只有这个新的层会被训练[只有这层参数会在反向传播时更新]
下面是利用PyTorch进行迁移学习步骤,要解决的问题是,训练一个模型来对蚂蚁和蜜蜂进行分类。
1.导入相关的包

License: BSD

Author: Sasank Chilamkurthy

from future import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy

plt.ion() # interactive mode
2.加载数据
要解决的问题是,训练一个模型来分类蚂蚁ants和蜜蜂bees。ants和bees各有约120张训练图片。每个类有75张验证图片。从零开始,在如此小的数据集上进行训练,通常是很难泛化的。由于使用迁移学习,模型的泛化能力会相当好。该数据集是imagenet的一个非常小的子集。下载数据,并将其解压缩到当前目录。
#训练集数据扩充和归一化
#在验证集上仅需要归一化
data_transforms = {
‘train’: transforms.Compose([
transforms.RandomResizedCrop(224), #随机裁剪一个area然后再resize
transforms.RandomHorizontalFlip(), #随机水平翻转
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
‘val’: transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}

data_dir = ‘data/hymenoptera_data’
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in [‘train’, ‘val’]}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in [‘train’, ‘val’]}
dataset_sizes = {x: len(image_datasets[x]) for x in [‘train’, ‘val’]}
class_names = image_datasets[‘train’].classes

device = torch.device(“cuda:0” if torch.cuda.is_available() else “cpu”)
3.可视化部分图像数据
可视化部分训练图像,以便了解数据扩充。
def imshow(inp, title=None):
“”“Imshow for Tensor.”""
inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated

获取一批训练数据

inputs, classes = next(iter(dataloaders[‘train’]))

批量制作网格

out = torchvision.utils.make_grid(inputs)

imshow(out, title=[class_names[x] for x in classes])

4.训练模型
编写一个通用函数来训练模型。下面将说明: * 调整学习速率 * 保存最好的模型
下面的参数scheduler,是一个来自 torch.optim.lr_scheduler的学习速率调整类的对象(LR scheduler object)。
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
since = time.time()

best_model_wts = copy.deepcopy(model.state_dict())
best_acc = 0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# 每个epoch都有一个训练和验证阶段for phase in ['train', 'val']:if phase == 'train':scheduler.step()model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# 迭代数据.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 零参数梯度optimizer.zero_grad()# 前向# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# 后向+仅在训练阶段进行优化if phase == 'train':loss.backward()optimizer.step()# 统计running_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# 深度复制moif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
print('Best val Acc: {:4f}'.format(best_acc))# 加载最佳模型权重
model.load_state_dict(best_model_wts)
return model

5.可视化模型的预测结果
#一个通用的展示少量预测图片的函数
def visualize_model(model, num_images=6):
was_training = model.training
model.eval()
images_so_far = 0
fig = plt.figure()

with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title('predicted: {}'.format(class_names[preds[j]]))imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)

6.场景1:微调ConvNet
加载预训练模型,重置最终完全连接的图层。
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)

model_ft = model_ft.to(device)

criterion = nn.CrossEntropyLoss()

观察所有参数都正在优化

optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

每7个epochs衰减LR通过设置gamma=0.1

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
训练和评估模型
(1)训练模型 该过程在CPU上,需要大约15-25分钟,但是在GPU上,它只需不到一分钟。
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
num_epochs=25)
• 输出
Epoch 0/24

train Loss: 0.7032 Acc: 0.6025
val Loss: 0.1698 Acc: 0.9412

Epoch 1/24

train Loss: 0.6411 Acc: 0.7787
val Loss: 0.1981 Acc: 0.9281
·
·
·
Epoch 24/24

train Loss: 0.2812 Acc: 0.8730
val Loss: 0.2647 Acc: 0.9150

Training complete in 1m 7s
Best val Acc: 0.941176
(2)模型评估效果可视化
visualize_model(model_ft)
• 输出

7.场景2:ConvNet作为固定特征提取器
需要冻结除最后一层之外的所有网络。通过设置requires_grad == Falsebackward()
来冻结参数,这样在反向传播backward()的时候,梯度就不会被计算。
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
param.requires_grad = False

Parameters of newly constructed modules have requires_grad=True by default

num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)

model_conv = model_conv.to(device)

criterion = nn.CrossEntropyLoss()

Observe that only parameters of final layer are being optimized as

opposed to before.

optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

Decay LR by a factor of 0.1 every 7 epochs

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
训练和评估
(1)训练模型 在CPU上,与前一个场景相比,这将花费大约一半的时间,因为不需要为大多数网络计算梯度。但需要计算转发。
model_conv = train_model(model_conv, criterion, optimizer_conv,
exp_lr_scheduler, num_epochs=25)
• 输出
Epoch 0/24

train Loss: 0.6400 Acc: 0.6434
val Loss: 0.2539 Acc: 0.9085
·
·
·
Epoch 23/24

train Loss: 0.2988 Acc: 0.8607
val Loss: 0.2151 Acc: 0.9412

Epoch 24/24

train Loss: 0.3519 Acc: 0.8484
val Loss: 0.2045 Acc: 0.9412

Training complete in 0m 35s
Best val Acc: 0.954248
(2)模型评估效果可视化
visualize_model(model_conv)

plt.ioff()
plt.show()
• 输出

8.文件下载
• py文件
• jupyter文件

PyTorch迁移学习相关推荐

  1. PyTorch 迁移学习 (Transfer Learning) 代码详解

    PyTorch 迁移学习 代码详解 概述 为什么使用迁移学习 更好的结果 节省时间 加载模型 ResNet152 冻层实现 模型初始化 获取需更新参数 训练模型 获取数据 完整代码 概述 迁移学习 ( ...

  2. Pytorch迁移学习加载部分预训练权重

    迁移学习在图像分类领域非常常见,利用在超大数据集上训练得到的网络权重,迁移到自己的数据上进行训练可以节约大量的训练时间,降低欠拟合/过拟合的风险. 如果用原生网络进行迁移学习非常简单,其核心是 mod ...

  3. 计算机视觉PyTorch迁移学习 - (二)

    图像迁移学习 3.PyTorch实现迁移学习 3.1数据集预处理 3.2构建模型 3.3模型训练与验证 3.PyTorch实现迁移学习 文件目录 3.1数据集预处理 这里实现一个蚂蚁与蜜蜂的图像分类, ...

  4. PyTorch迁移学习-私人数据集上的蚂蚁蜜蜂分类

    1. 迁移学习的两个主要场景 微调CNN:使用预训练的网络来初始化自己的网络,而不是随机初始化,然后训练即可 将CNN看成固定的特征提取器:固定前面的层,重写最后的全连接层,只有这个新的层会被训练 下 ...

  5. pytorch迁移学习载入部分权重

    载入权重是迁移学习的重要部分,这个权重的来源可以是官方发布的预训练权重,也可以是你自己训练的权重并载入模型进行继续学习.使用官方预训练权重,这样的权重包含的信息量大且全面,可以适配一些小数据的任务,即 ...

  6. 【学习笔记】pytorch迁移学习-猫狗分类实战

    1.迁移学习入门 什么是迁移学习:在深度神经网络算法的引用过程中,如果我们面对的是数据规模较大的问题,那么在搭建好深度神经网络模型后,我们势必要花费大量的算力和时间去训练模型和优化参数,最后耗费了这么 ...

  7. pytorch迁移学习后使用微调策略再次提高模型训练结果

    1.使迁移的模型解冻 for param in model.parameters():param.requires_grad=True 2.此时学习速率设置再小些 optimizer=torch.op ...

  8. pytorch迁移学习--模型建立的代码实现

    1.选择基础模型 model=torchvision.models.vgg16(pretrained=True) 2.看一下模型结构 代码:

  9. 基于Pytorch迁移学习+集成学习的水果霉变区分设计与实现

    1.数据集的介绍 此次采用的数据集中有六种水果,六种水果都有自己的对应的好坏水果集, 数据量:一共12050张图片,包含训练集,测试集和验证集,训练集:共7240张图片,测试集:共1796张图片,验证 ...

最新文章

  1. OpenCV学习笔记(12)——OpenCV中的轮廓
  2. 导师:寒假复现几篇顶会论文?答:3天1篇!
  3. arm-none-linux-gnueabi,安装交叉编译器arm-none-linux-gnueabi-gcc 过程
  4. python 文字语音朗读-python 利用pyttsx3文字转语音
  5. 不同项目之间的控件共享
  6. [AX]AX2012开发新特性-全文索引
  7. 洛谷P1912:诗人小G(二分栈、决策单调性)
  8. plsql例外_大例外背后的真相
  9. 计算机视觉论文-2021-07-01
  10. 查看 linux 网络状态命令,Linux操作系统常用的网络状态查询命令
  11. 计算机教师的幸福,如何成为一名幸福信息技术教师
  12. Tkinter 的 Text 组件
  13. 设计模式(7)——适配器模式
  14. (14)数据结构-二叉排序树
  15. programmer-common-word-pronunciation 程序员常用单词发音
  16. 全基因组关联分析(GWAS)常见问题(工具,概念,脚本)
  17. flashfxp用什么协议连接服务器,flashfxp怎么连接,flashfxp怎么连接,具体的连接方法...
  18. php下载源文件绕开下载地址,Fengcms 最新版v1.24任意文件下载(绕过过滤)
  19. 制导武器的分布式半实物仿真系统研究
  20. PS实战操作之滤镜、通道

热门文章

  1. Go 学习笔记(66)— Go 并发同步原语(sync.Mutex、sync.RWMutex、sync.Once)
  2. 提高班第三周周记(中秋第二天)
  3. BERT可视化工具bertviz体验
  4. Map再整理,从底层源码探究HashMap
  5. 聊一聊Spring中的线程安全性
  6. 摄像头ISP系统原理(上)
  7. Java基础语法运算和控制符
  8. connot not ensure the target project location exist and is accessible
  9. Error:Could not download guava.jar (com.google.guava:guava:19.0): No cached version available for of
  10. Plugin with id 'com.novoda.bintray-release' not found的解决方法