PyTorch中文教程 | (4) 迁移学习教程
GitHub 地址
在本教程中,您将学习如何使用迁移学习来训练您的网络。您可以在 cs231n 笔记 上关于迁移学习的信息
在实践中,很少有人从头开始训练整个卷积网络(随机初始化),因为拥有足够大小的数据集是相对罕见的。相反,通常在非常大的数据集(例如 ImageNet,其包含具有1000个类别的120万个图像)上预先训练 ConvNet,然后使用 ConvNet 作为感兴趣任务的初始化或固定特征提取器。
如下是两个主要的迁移学习场景:
1)Finetuning the convnet: 我们使用预训练网络初始化网络,而不是随机初始化,就像在imagenet 1000数据集上训练的网络一样。其余训练看起来像往常一样。
2)ConvNet as fixed feature extractor: 在这里,我们将冻结除最终完全连接层之外的所有网络的权重(用预训练网络初始化)。最后一个全连接层被替换为具有随机权重的新层,并且仅训练该层。
# License: BSD
# Author: Sasank Chilamkurthyfrom __future__ import print_function, divisionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copyplt.ion() # interactive mode
- 加载数据
我们将使用 torchvision 和 torch.utils.data 包来加载数据。
我们今天要解决的问题是训练一个模型来对 蚂蚁 和 蜜蜂 进行分类。我们有大约120个训练图像,每个图像用于 蚂蚁 和 蜜蜂。每个类有75个验证图像。通常,如果从头开始训练,这是一个非常小的数据集。由于我们正在使用迁移学习,我们应该能够合理地推广。
该数据集是 imagenet 的一个非常小的子集。
从 此处 下载数据并将其解压缩到当前目录。
# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- 可视化一些图像
让我们可视化一些训练图像,以便了解数据增强。
def imshow(inp, title=None):"""Imshow for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001) # pause a bit so that plots are updated# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))# Make a grid from batch
out = torchvision.utils.make_grid(inputs)imshow(out, title=[class_names[x] for x in classes])
- 训练模型
现在, 让我们编写一个通用函数来训练模型. 这里, 我们将会举例说明:
1)调度学习率
2)保存最佳的学习模型
下面函数中, scheduler
参数是 torch.optim.lr_scheduler
中的 LR scheduler 对象.
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':scheduler.step()model.train() # Set model to training modeelse:model.eval() # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward# track history if only in trainwith torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize only if in training phaseif phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# load best model weightsmodel.load_state_dict(best_model_wts)return model
- 可视化模型预测
用于显示少量图像预测的通用功能。
def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title('predicted: {}'.format(class_names[preds[j]]))imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)
- 微调卷积网络
加载预训练模型并重置最终的全连接层。
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features #最后一个全连接层的输入特征数
model_ft.fc = nn.Linear(num_ftrs, 2) #重置最后一个全连接层model_ft = model_ft.to(device)criterion = nn.CrossEntropyLoss()# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1) #学习率衰减
- 训练和评估
CPU上需要大约15-25分钟。但是在GPU上,它只需不到一分钟。
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)
visualize_model(model_ft)
- ConvNet作为固定特征提取器
在这里,我们需要冻结除最后一层之外的所有网络。我们需要设置 requires_grad == False
冻结参数,以便在 backward()
中不计算梯度。
您可以在 此处 的文档中相关信息。
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():param.requires_grad = False #冻结# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
- 训练和评估
在CPU上,与前一个场景相比,这将花费大约一半的时间。这是预期的,因为不需要为大多数网络计算梯度。但是,前向传递需要计算梯度。
model_conv = train_model(model_conv, criterion, optimizer_conv,exp_lr_scheduler, num_epochs=25)
visualize_model(model_conv)plt.ioff()
PyTorch中文教程 | (4) 迁移学习教程相关推荐
- PyTorch 1.0 中文官方教程:迁移学习教程
译者:片刻 作者: Sasank Chilamkurthy 在本教程中,您将学习如何使用迁移学习来训练您的网络.您可以在 cs231n 笔记 上关于迁移学习的信息 引用这些笔记: 在实践中,很少有人从 ...
- pytorch与keras_Keras vs PyTorch:如何通过迁移学习区分外星人与掠食者
pytorch与keras by Patryk Miziuła 通过PatrykMiziuła Keras vs PyTorch:如何通过迁移学习区分外星人与掠食者 (Keras vs PyTorch ...
- PyTorch实战使用Resnet迁移学习
PyTorch实战使用Resnet迁移学习 项目结构 项目任务 项目代码 网络模型测试 项目结构 数据集存放在flower_data文件夹 cat_to_name.json是makejson文件运行生 ...
- pytorch1.7教程实验——迁移学习训练卷积神经网络进行图像分类
只是贴上跑通的代码以供参考学习 参考网址:迁移学习训练卷积神经网络进行图像分类 需要用到的数据集下载网址: https://download.pytorch.org/tutorial/hymenopt ...
- PyTorch系列 | 快速入门迁移学习
点击上方"算法猿的成长",选择"加为星标" 第一时间关注 AI 和 Python 知识 图片来源:Pexels,作者:Arthur Ogleznev 2019 ...
- oracle安装搜狗输入法教程,Linux入门学习教程:在Ubuntu 14.04中安装使用搜狗拼音输入法...
然后,访问搜狗输入法Linux版的官网,http://pinyin.sogou.com/linux,下载搜狗输入法Linux版.从官网可以看到,该输入法官方只支持Ubuntu(不过网上有人通过将deb ...
- ★教程2:fpga学习教程入门100例目录
1.订阅本教程用户可以免费获得本博任意2个(包括所有免费专栏和付费专栏)博文对应代码: 2.本FPGA课程的所有案例(部分理论知识点除外)均由博主编写而成,供有兴趣的朋友们自己订阅学习使用.未经本人允 ...
- PyTorch实现基于ResNet18迁移学习的宝可梦数据集分类
一.实现过程 1.数据集描述 数据集分为5类,分别如下: 皮卡丘:234 超梦:239 杰尼龟:223 小火龙:238 妙蛙种子:234 自取链接:https://pan.baidu.com/s/1b ...
- Java 开源中文分词器Ansj 学习教程
Java有11大开源中文分词器,分别是word分词器,Ansj分词器,Stanford分词器,FudanNLP分词器,Jieba分词器,Jcseg分词器,MMSeg4j分词器,IKAnalyzer分词 ...
- ★教程3:Simulink学习教程入门60例目录
1.订阅本教程用户可以免费获得本博任意1个(包括所有免费专栏和付费专栏)博文对应代码: (私信博主给出代码博文的链接和邮箱) 2.本Simulink课程的所有案例(部分理论知识点除外)均由博主编写而成 ...
最新文章
- Win95架构师发布移动设备富媒体文档创建平台
- python培训班排行榜-深圳python培训机构排行榜
- VC++调试程序、快捷键以及Debug版本与Release版本
- linux查看java jdk安装路径和设置环境变量
- 拆解声网Q4财报:除了“元宇宙”,我们还应该关注什么?
- LeetCode 多线程 1117. H2O 生成
- mysqlbinlog工具_mysqlbinlog命令详解 Part 1-实验环境准备
- ini_set() 函数的使用 以及 post_max_size,upload_max_filesize的修改方法
- OPENCV 实现png绘制,alpha通道叠加。
- scala练习——fold函数
- Rust学习教程30 - Panic原理剖析
- 使用ZYNQ实现单LUT内容的动态修改(一)PL端OOC设计流程
- ios 通过代码调整焦距
- 爱德泰科普 | 电信级单模光纤跳线在综合布线中的连接方法
- 2022广东最新八大员之(安全员)模拟试题题库及答案
- php使用eval上传文件,PHP一句话实现单个文件批量上传?
- 计算机网络(ISP,因特网组成,分组交换,计算机网络性能,网络体系机构)
- leaflet 加载腾讯地图
- c语言贪吃蛇大作业报告,C语言贪吃蛇实验报告
- php iphone壁纸,经典!iPhone 8全套超高清壁纸出炉