GitHub 地址

在本教程中,您将学习如何使用迁移学习来训练您的网络。您可以在 cs231n 笔记 上关于迁移学习的信息

在实践中,很少有人从头开始训练整个卷积网络(随机初始化),因为拥有足够大小的数据集是相对罕见的。相反,通常在非常大的数据集(例如 ImageNet,其包含具有1000个类别的120万个图像)上预先训练 ConvNet,然后使用 ConvNet 作为感兴趣任务的初始化或固定特征提取器。

如下是两个主要的迁移学习场景:

1)Finetuning the convnet: 我们使用预训练网络初始化网络,而不是随机初始化,就像在imagenet 1000数据集上训练的网络一样。其余训练看起来像往常一样。

2)ConvNet as fixed feature extractor: 在这里,我们将冻结除最终完全连接层之外的所有网络的权重(用预训练网络初始化)。最后一个全连接层被替换为具有随机权重的新层,并且仅训练该层。

# License: BSD
# Author: Sasank Chilamkurthyfrom __future__ import print_function, divisionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copyplt.ion()   # interactive mode
  • 加载数据

我们将使用 torchvision 和 torch.utils.data 包来加载数据。

我们今天要解决的问题是训练一个模型来对 蚂蚁 和 蜜蜂 进行分类。我们有大约120个训练图像,每个图像用于 蚂蚁 和 蜜蜂。每个类有75个验证图像。通常,如果从头开始训练,这是一个非常小的数据集。由于我们正在使用迁移学习,我们应该能够合理地推广。

该数据集是 imagenet 的一个非常小的子集。

从 此处 下载数据并将其解压缩到当前目录。

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  • 可视化一些图像

让我们可视化一些训练图像,以便了解数据增强。

def imshow(inp, title=None):"""Imshow for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updated# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))# Make a grid from batch
out = torchvision.utils.make_grid(inputs)imshow(out, title=[class_names[x] for x in classes])

  • 训练模型

现在, 让我们编写一个通用函数来训练模型. 这里, 我们将会举例说明:

1)调度学习率

2)保存最佳的学习模型

下面函数中, scheduler 参数是 torch.optim.lr_scheduler 中的 LR scheduler 对象.

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':scheduler.step()model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize only if in training phaseif phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# load best model weightsmodel.load_state_dict(best_model_wts)return model
  • 可视化模型预测

用于显示少量图像预测的通用功能。

def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title('predicted: {}'.format(class_names[preds[j]]))imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)
  • 微调卷积网络

加载预训练模型并重置最终的全连接层。

model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features #最后一个全连接层的输入特征数
model_ft.fc = nn.Linear(num_ftrs, 2) #重置最后一个全连接层model_ft = model_ft.to(device)criterion = nn.CrossEntropyLoss()# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) #学习率衰减
  • 训练和评估

CPU上需要大约15-25分钟。但是在GPU上,它只需不到一分钟。

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)

visualize_model(model_ft)

  • ConvNet作为固定特征提取器

在这里,我们需要冻结除最后一层之外的所有网络。我们需要设置 requires_grad == False 冻结参数,以便在 backward() 中不计算梯度。

您可以在 此处 的文档中相关信息。

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():param.requires_grad = False #冻结# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
  • 训练和评估

在CPU上,与前一个场景相比,这将花费大约一半的时间。这是预期的,因为不需要为大多数网络计算梯度。但是,前向传递需要计算梯度。

model_conv = train_model(model_conv, criterion, optimizer_conv,exp_lr_scheduler, num_epochs=25)

visualize_model(model_conv)plt.ioff()

PyTorch中文教程 | (4) 迁移学习教程相关推荐

  1. PyTorch 1.0 中文官方教程:迁移学习教程

    译者:片刻 作者: Sasank Chilamkurthy 在本教程中,您将学习如何使用迁移学习来训练您的网络.您可以在 cs231n 笔记 上关于迁移学习的信息 引用这些笔记: 在实践中,很少有人从 ...

  2. pytorch与keras_Keras vs PyTorch:如何通过迁移学习区分外星人与掠食者

    pytorch与keras by Patryk Miziuła 通过PatrykMiziuła Keras vs PyTorch:如何通过迁移学习区分外星人与掠食者 (Keras vs PyTorch ...

  3. PyTorch实战使用Resnet迁移学习

    PyTorch实战使用Resnet迁移学习 项目结构 项目任务 项目代码 网络模型测试 项目结构 数据集存放在flower_data文件夹 cat_to_name.json是makejson文件运行生 ...

  4. pytorch1.7教程实验——迁移学习训练卷积神经网络进行图像分类

    只是贴上跑通的代码以供参考学习 参考网址:迁移学习训练卷积神经网络进行图像分类 需要用到的数据集下载网址: https://download.pytorch.org/tutorial/hymenopt ...

  5. PyTorch系列 | 快速入门迁移学习

    点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来源:Pexels,作者:Arthur Ogleznev 2019 ...

  6. oracle安装搜狗输入法教程,Linux入门学习教程:在Ubuntu 14.04中安装使用搜狗拼音输入法...

    然后,访问搜狗输入法Linux版的官网,http://pinyin.sogou.com/linux,下载搜狗输入法Linux版.从官网可以看到,该输入法官方只支持Ubuntu(不过网上有人通过将deb ...

  7. ★教程2:fpga学习教程入门100例目录

    1.订阅本教程用户可以免费获得本博任意2个(包括所有免费专栏和付费专栏)博文对应代码: 2.本FPGA课程的所有案例(部分理论知识点除外)均由博主编写而成,供有兴趣的朋友们自己订阅学习使用.未经本人允 ...

  8. PyTorch实现基于ResNet18迁移学习的宝可梦数据集分类

    一.实现过程 1.数据集描述 数据集分为5类,分别如下: 皮卡丘:234 超梦:239 杰尼龟:223 小火龙:238 妙蛙种子:234 自取链接:https://pan.baidu.com/s/1b ...

  9. Java 开源中文分词器Ansj 学习教程

    Java有11大开源中文分词器,分别是word分词器,Ansj分词器,Stanford分词器,FudanNLP分词器,Jieba分词器,Jcseg分词器,MMSeg4j分词器,IKAnalyzer分词 ...

  10. ★教程3:Simulink学习教程入门60例目录

    1.订阅本教程用户可以免费获得本博任意1个(包括所有免费专栏和付费专栏)博文对应代码: (私信博主给出代码博文的链接和邮箱) 2.本Simulink课程的所有案例(部分理论知识点除外)均由博主编写而成 ...

最新文章

  1. Win95架构师发布移动设备富媒体文档创建平台
  2. python培训班排行榜-深圳python培训机构排行榜
  3. VC++调试程序、快捷键以及Debug版本与Release版本
  4. linux查看java jdk安装路径和设置环境变量
  5. 拆解声网Q4财报:除了“元宇宙”,我们还应该关注什么?
  6. LeetCode 多线程 1117. H2O 生成
  7. mysqlbinlog工具_mysqlbinlog命令详解 Part 1-实验环境准备
  8. ini_set() 函数的使用 以及 post_max_size,upload_max_filesize的修改方法
  9. OPENCV 实现png绘制,alpha通道叠加。
  10. scala练习——fold函数
  11. Rust学习教程30 - Panic原理剖析
  12. 使用ZYNQ实现单LUT内容的动态修改(一)PL端OOC设计流程
  13. ios 通过代码调整焦距
  14. 爱德泰科普 | 电信级单模光纤跳线在综合布线中的连接方法
  15. 2022广东最新八大员之(安全员)模拟试题题库及答案
  16. php使用eval上传文件,PHP一句话实现单个文件批量上传?
  17. 计算机网络(ISP,因特网组成,分组交换,计算机网络性能,网络体系机构)
  18. leaflet 加载腾讯地图
  19. c语言贪吃蛇大作业报告,C语言贪吃蛇实验报告
  20. php iphone壁纸,经典!iPhone 8全套超高清壁纸出炉

热门文章

  1. WES7@IIC-China
  2. 《女士品茶》与统计检验
  3. 数据库 的日志已满,备份该数据库的事务日志以释放一些日志空间的解决办法 ...
  4. 别找了,这就是你心心念念想要的年会活动抽奖软件
  5. 图片节点html,Qunee for HTML5 - 中文 : 节点图片
  6. 学习数据库系统概念,设计及应用心得
  7. Ubuntu以及CentOS7修改ssh端口号详细步骤
  8. 有趣有用的PCA——PCA压缩图片
  9. ubuntu svn命令
  10. docker视频教程下载