文章目录

  • 1.导入相关的包
  • 2.加载数据
  • 3.可视化部分图像数据
  • 4.训练模型
  • 5.可视化模型的预测结果
  • 6.场景1:微调ConvNet
  • 7.场景2:ConvNet作为固定特征提取器

实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集]。通常的做法是在一个很大的数据集上进行预训练得到卷积网络ConvNet, 然后将这个ConvNet的参数作为目标任务的初始化参数或者固定这些参数。

转移学习的两个主要场景:

  • 微调Convnet:使用预训练的网络(如在imagenet 1000上训练而来的网络)来初始化自己的网络,而不是随机初始化。其他的训练步骤不变。
  • Convnet看成固定的特征提取器:首先固定ConvNet除了最后的全连接层外的其他所有层。最后的全连接层被替换成一个新的随机 初始化的层,只有这个新的层会被训练[只有这层参数会在反向传播时更新]

下面是利用PyTorch进行迁移学习步骤,要解决的问题是训练一个模型来对蚂蚁和蜜蜂进行分类。

1.导入相关的包

from __future__ import print_function, divisionimport torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copyplt.ion()   # 交互模式

注释1:交互模式详情

2.加载数据

今天要解决的问题是训练一个模型来分类蚂蚁ants和蜜蜂bees。ants和bees各有约120张训练图片。每个类有75张验证图片。从零开始在 如此小的数据集上进行训练通常是很难泛化的。由于我们使用迁移学习,模型的泛化能力会相当好。 该数据集是imagenet的一个非常小的子集。从此处下载数据,并将其解压缩到当前目录。

# 训练集数据扩充和归一化
# 在验证集上仅需要归一化
data_transforms = {'train': transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),'val': transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
}data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x])for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,shuffle=True, num_workers=4)for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classesdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

注释2:ImageFolder
torchvision中有一个更常用的数据集类ImageFolder。 它假定了数据集是以如下方式构造的:

root/ants/xxx.png
root/ants/xxy.jpeg
root/ants/xxz.png
.
.
.
root/bees/123.jpg
root/bees/nsdf3.png
root/bees/asd932_.png

也就是"根目录/类别名称/该类别对应的图片"

3.可视化部分图像数据

可视化部分训练图像,以便了解数据扩充。

def imshow(inp, title=None):"""Imshow for Tensor."""inp = inp.numpy().transpose((1, 2, 0))mean = np.array([0.485, 0.456, 0.406])std = np.array([0.229, 0.224, 0.225])inp = std * inp + meaninp = np.clip(inp, 0, 1)plt.imshow(inp)if title is not None:plt.title(title)plt.pause(0.001)  # pause a bit so that plots are updated# 获得一批训练数据
inputs, classes = next(iter(dataloaders['train']))# 批量制作网格
out = torchvision.utils.make_grid(inputs)imshow(out, title=[class_names[x] for x in classes])

4.训练模型

编写一个通用函数来训练模型。下面将说明:

  • 调整学习速率
  • 保存最好的模型

下面的参数scheduler是一个来自 torch.optim.lr_scheduler的学习速率调整类的对象(LRscheduler object)。

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):since = time.time()best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0for epoch in range(num_epochs):print('Epoch {}/{}'.format(epoch, num_epochs - 1))print('-' * 10)# Each epoch has a training and validation phasefor phase in ['train', 'val']:if phase == 'train':model.train()  # Set model to training modeelse:model.eval()   # Set model to evaluate moderunning_loss = 0.0running_corrects = 0# Iterate over data.for inputs, labels in dataloaders[phase]:inputs = inputs.to(device)labels = labels.to(device)# 将参数梯度归零optimizer.zero_grad()# forward# 只在训练上追踪历史with torch.set_grad_enabled(phase == 'train'):outputs = model(inputs)_, preds = torch.max(outputs, 1)loss = criterion(outputs, labels)# backward + optimize 只在训练进行if phase == 'train':loss.backward()optimizer.step()# statisticsrunning_loss += loss.item() * inputs.size(0)running_corrects += torch.sum(preds == labels.data)if phase == 'train':scheduler.step()epoch_loss = running_loss / dataset_sizes[phase]epoch_acc = running_corrects.double() / dataset_sizes[phase]print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))# deep copy the modelif phase == 'val' and epoch_acc > best_acc:best_acc = epoch_accbest_model_wts = copy.deepcopy(model.state_dict())print()time_elapsed = time.time() - sinceprint('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))print('Best val Acc: {:4f}'.format(best_acc))# load best model weightsmodel.load_state_dict(best_model_wts)return model

5.可视化模型的预测结果

一个通用的展示少量预测图片的函数

def visualize_model(model, num_images=6):was_training = model.trainingmodel.eval()images_so_far = 0fig = plt.figure()with torch.no_grad():for i, (inputs, labels) in enumerate(dataloaders['val']):inputs = inputs.to(device)labels = labels.to(device)outputs = model(inputs)_, preds = torch.max(outputs, 1)for j in range(inputs.size()[0]):images_so_far += 1ax = plt.subplot(num_images//2, 2, images_so_far)ax.axis('off')ax.set_title('predicted: {}'.format(class_names[preds[j]]))imshow(inputs.cpu().data[j])if images_so_far == num_images:model.train(mode=was_training)returnmodel.train(mode=was_training)

6.场景1:微调ConvNet

加载预训练模型并重置最终完全连接的图层。

model_ft = models.resnet18(pretrained=True)
# in_features 是fc线性层的输入数量
num_ftrs = model_ft.fc.in_features
# 这里,每个输出样本的大小设置为2。
# 或者,它可以推广到nn.Linear(num_ftrs,len(类名称))。
model_ft.fc = nn.Linear(num_ftrs, 2)model_ft = model_ft.to(device)criterion = nn.CrossEntropyLoss()# 观察所有参数都正在优化
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)# 每7个epochs衰减LR通过设置gamma=0.1
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

训练和评估模型
(1)训练模型 该过程在CPU上需要大约15-25分钟,但是在GPU上,它只需不到一分钟。(我在CPU上跑的)

model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,num_epochs=25)
  • 输出
Epoch 0/24
----------
train Loss: 0.6633 Acc: 0.6721
val Loss: 0.2236 Acc: 0.9216Epoch 1/24
----------
train Loss: 0.4831 Acc: 0.7705
val Loss: 0.2813 Acc: 0.9020Epoch 2/24
----------
train Loss: 0.4444 Acc: 0.7828
val Loss: 0.1721 Acc: 0.9477Epoch 3/24
----------
train Loss: 0.4857 Acc: 0.7828
val Loss: 0.1651 Acc: 0.9477Epoch 4/24
----------
train Loss: 0.5882 Acc: 0.7705
val Loss: 0.4952 Acc: 0.8301Epoch 5/24
----------
train Loss: 0.5504 Acc: 0.7869
val Loss: 0.1896 Acc: 0.9412Epoch 6/24
----------
train Loss: 0.5927 Acc: 0.7992
val Loss: 0.3142 Acc: 0.8954Epoch 7/24
----------
train Loss: 0.4391 Acc: 0.8361
val Loss: 0.1705 Acc: 0.9477Epoch 8/24
----------
train Loss: 0.3733 Acc: 0.8566
val Loss: 0.1884 Acc: 0.9346Epoch 9/24
----------
train Loss: 0.3567 Acc: 0.8484
val Loss: 0.2050 Acc: 0.9281Epoch 10/24
----------
train Loss: 0.3769 Acc: 0.8279
val Loss: 0.2070 Acc: 0.9346Epoch 11/24
----------
train Loss: 0.3473 Acc: 0.8648
val Loss: 0.2191 Acc: 0.9281Epoch 12/24
----------
train Loss: 0.3654 Acc: 0.8566
val Loss: 0.1732 Acc: 0.9412Epoch 13/24
----------
train Loss: 0.2885 Acc: 0.8689
val Loss: 0.1959 Acc: 0.9346Epoch 14/24
----------
train Loss: 0.3242 Acc: 0.8525
val Loss: 0.2066 Acc: 0.9281Epoch 15/24
----------
train Loss: 0.3471 Acc: 0.8279
val Loss: 0.1821 Acc: 0.9477Epoch 16/24
----------
train Loss: 0.4058 Acc: 0.8443
val Loss: 0.1773 Acc: 0.9346Epoch 17/24
----------
train Loss: 0.4398 Acc: 0.8279
val Loss: 0.1726 Acc: 0.9477Epoch 18/24
----------
train Loss: 0.3293 Acc: 0.8689
val Loss: 0.1841 Acc: 0.9346Epoch 19/24
----------
train Loss: 0.3484 Acc: 0.8361
val Loss: 0.1846 Acc: 0.9412Epoch 20/24
----------
train Loss: 0.3164 Acc: 0.8402
val Loss: 0.1702 Acc: 0.9542Epoch 21/24
----------
train Loss: 0.3769 Acc: 0.8197
val Loss: 0.1828 Acc: 0.9346Epoch 22/24
----------
train Loss: 0.3204 Acc: 0.8852
val Loss: 0.2065 Acc: 0.9412Epoch 23/24
----------
train Loss: 0.3201 Acc: 0.8852
val Loss: 0.1970 Acc: 0.9346Epoch 24/24
----------
train Loss: 0.2603 Acc: 0.8730
val Loss: 0.2063 Acc: 0.9412Training complete in 33m 18s
Best val Acc: 0.954248Process finished with exit code 0

(2)模型评估结果可视化

visualize_model(model_ft)
  • 输出

7.场景2:ConvNet作为固定特征提取器

在这里需要冻结除最后一层之外的所有网络。通过设置requires_grad == Falsebackward()来冻结参数,这样在反向传播backward()的时候他们的梯度就不会被计算。

model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():param.requires_grad = False# 默认情况下,新建模块的参数需要requires_grad=True
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 2)model_conv = model_conv.to(device)criterion = nn.CrossEntropyLoss()# 请注意,只有最后一层的参数被优化为
# 与以前相反。
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)# 每7个epochs将LR衰减0.1倍
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

训练和评估
(1)训练模型 在CPU上,与前一个场景相比,这将花费大约一半的时间,因为不需要为大多数网络计算梯度。但需要计算转发。

model_conv = train_model(model_conv, criterion, optimizer_conv,exp_lr_scheduler, num_epochs=25)
  • 输出
Epoch 0/24
----------
train Loss: 0.5987 Acc: 0.6393
val Loss: 0.3066 Acc: 0.8693Epoch 1/24
----------
train Loss: 0.5367 Acc: 0.7172
val Loss: 0.1962 Acc: 0.9477Epoch 2/24
----------
train Loss: 0.4591 Acc: 0.7992
val Loss: 0.1974 Acc: 0.9477Epoch 3/24
----------
train Loss: 0.5040 Acc: 0.7787
val Loss: 0.1922 Acc: 0.9346Epoch 4/24
----------
train Loss: 0.5343 Acc: 0.7705
val Loss: 0.2062 Acc: 0.9542Epoch 5/24
----------
train Loss: 0.5048 Acc: 0.7828
val Loss: 0.2255 Acc: 0.9281Epoch 6/24
----------
train Loss: 0.5591 Acc: 0.7951
val Loss: 0.2150 Acc: 0.9346Epoch 7/24
----------
train Loss: 0.3710 Acc: 0.8402
val Loss: 0.2361 Acc: 0.9346Epoch 8/24
----------
train Loss: 0.2645 Acc: 0.8934
val Loss: 0.2024 Acc: 0.9346Epoch 9/24
----------
train Loss: 0.3999 Acc: 0.8402
val Loss: 0.1959 Acc: 0.9477Epoch 10/24
----------
train Loss: 0.3806 Acc: 0.8238
val Loss: 0.2191 Acc: 0.9346Epoch 11/24
----------
train Loss: 0.4044 Acc: 0.8402
val Loss: 0.1941 Acc: 0.9542Epoch 12/24
----------
train Loss: 0.3234 Acc: 0.8648
val Loss: 0.1977 Acc: 0.9477Epoch 13/24
----------
train Loss: 0.3640 Acc: 0.8361
val Loss: 0.2026 Acc: 0.9346Epoch 14/24
----------
train Loss: 0.4070 Acc: 0.8115
val Loss: 0.1912 Acc: 0.9542Epoch 15/24
----------
train Loss: 0.3331 Acc: 0.8484
val Loss: 0.2011 Acc: 0.9346Epoch 16/24
----------
train Loss: 0.3006 Acc: 0.8770
val Loss: 0.1766 Acc: 0.9542Epoch 17/24
----------
train Loss: 0.3397 Acc: 0.8443
val Loss: 0.2180 Acc: 0.9412Epoch 18/24
----------
train Loss: 0.3332 Acc: 0.8443
val Loss: 0.1928 Acc: 0.9477Epoch 19/24
----------
train Loss: 0.3563 Acc: 0.8238
val Loss: 0.1982 Acc: 0.9477Epoch 20/24
----------
train Loss: 0.3222 Acc: 0.8566
val Loss: 0.2268 Acc: 0.9281Epoch 21/24
----------
train Loss: 0.4554 Acc: 0.8115
val Loss: 0.2420 Acc: 0.9281Epoch 22/24
----------
train Loss: 0.3066 Acc: 0.8648
val Loss: 0.1828 Acc: 0.9542Epoch 23/24
----------
train Loss: 0.4099 Acc: 0.8279
val Loss: 0.2061 Acc: 0.9477Epoch 24/24
----------
train Loss: 0.3176 Acc: 0.8648
val Loss: 0.2098 Acc: 0.9346Training complete in 0m 34s
Best val Acc: 0.954248

(2)模型评测结果可视化

visualize_model(model_conv)plt.ioff()
plt.show()
  • 输出

pytorch之迁移学习相关推荐

  1. 使用PyTorch进行迁移学习

    概述 迁移学习可以改变你建立机器学习和深度学习模型的方式 了解如何使用PyTorch进行迁移学习,以及如何将其与使用预训练的模型联系起来 我们将使用真实世界的数据集,并比较使用卷积神经网络(CNNs) ...

  2. Pytorch实现迁移学习

    迁移学习 迁移学习是一种机器学习的方法,指的是一个预训练的模型被重新用在另一个任务中,它专注于存储已有问题的解决模型,并将其利用在其他不同但相关问题上.例如我在A的场景下训练了一个模型,而B.C.D等 ...

  3. Pytorch 的迁移学习的理解

    个人理解,迁移学习可以分为三类: 第一类:以训练好的模型参数为基础,对所有参数进行继续优化. 即,先在别的训练数据集上训练模型,达到一定训练标准之后,用当前的数据集继续进行训练. 第二类:将已经训练好 ...

  4. Resnet152对102种花朵图像分类(PyTorch,迁移学习)

    目录 1.介绍 1.1.项目数据及源码 1.2.数据集介绍 1.3.任务介绍 1.4.ResNet网络介绍 2.数据预处理 3.展示数据 4.进行迁移学习 4.1.训练全连接层 4.2.训练所有层 5 ...

  5. 【PyTorch】迁移学习:基于 VGG 实现的光明哨兵与破败军团分类器

    文章目录 简述. 环境配置. PyTorch代码. 导入第三方库. 使用 GPU. 加载数据. 定义可视化函数. 加载预训练模型. 冻结特征层. 修改输出层. 定义优化器. 定义训练函数. 训练过程. ...

  6. pytorch添加迁移学习

    # 参数设置(指定用第几轮的预训练权重) parser = argparse.ArgumentParser(description="PyTorch Net") parser.ad ...

  7. PyTorch迁移学习

    PyTorch迁移学习 实际中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集[网络很深, 需要足够大数据集].通常的做法是在一个很大的数据集上进 ...

  8. PyTorch基础(六)迁移学习

    在实际工程中,基本没有人会从零开始(随机初始化)训练一个完整的卷积网络,因为相对于网络,很难得到一个足够大的数据集(网络很深,需要足够大数据集).通常的做法是在一个很大的数据集上进行预训练得到卷积网络 ...

  9. 【Pytorch实战6】一个完整的分类案例:迁移学习分类蚂蚁和蜜蜂(Res18,VGG16)

    参考资料: <深度学习之pytorch实战计算机视觉> Pytorch官方教程 Pytorch官方文档 本文是采用pytorch进行迁移学习的实战演练,实战目的是为了进一步学习和熟悉pyt ...

最新文章

  1. python实现复制文件功能
  2. SQL Server 2008用'sa'登录失败,启用'sa'登录的办法
  3. Converting slapd.conf to a Directory Based Configu
  4. pandas 对某一行标准化_Python中的神器Pandas,但是有人说Pandas慢...
  5. Python数据分析模块 | pandas做数据分析(三):统计相关函数
  6. Android 系统(90)---JIT 编译器
  7. AndroidStudio安卓原生开发_Activity的IntentFlag的SINGLE_TOP_CLEAR_TOP_REORDER_TO_FRONT的用法---Android原生开发工作笔记90
  8. uniapp:H5页面长按识别二维码
  9. 换了3根高清线后,第四根mini dp转HDMI线终于可以显示4K了
  10. Excel中如何快速汇总带单位的数据
  11. seo人员的每日工作内容应该都有什么?
  12. ATFX:道琼斯指数的反弹,11月能否突破35000关口?
  13. python中美元人汇率_Python获取美元人民币实时汇率
  14. SAP PS 第0节 PS PA有哪些知识点及IDES练习
  15. 模式识别第二课 建立MFC窗口+插入图片+处理+显示图片
  16. 服务器部署sas_如何在阿里云SAS上部署WordPress网站
  17. OpenCV Java入门六 使用神经网算法辩识人脸
  18. Seabron作图:
  19. 关于Python的虚拟环境
  20. html恋爱纪念页面,HTML5适合的情人节礼物有纪念日期功能

热门文章

  1. 【Neo4j】第 10 章:图嵌入 - 从图到矩阵
  2. 计算机的语言是美式英语,为什么电脑的语言栏一直有两国语言“CH中文(中国)”和“EH英语(美国)”...
  3. 榕树说技术支持(Rong Zhiyun technical support)
  4. 椭圆曲线加密(ECC)
  5. 工作展望简短_2018励志句子简短大全 展望2018励志正能量句子最新励志说说
  6. World Streamer学习4
  7. 迅为i.MX6ULL 开发板开机进度条修改文档
  8. 什么是网站备案?如何查询网站是否备案?
  9. RockyLinux9.0系统在VMware虚拟机上【保姆级】安装步骤,并修改网络配置,使用固定IP进行SSH连接【47张过程图】
  10. office2010 word发布博客 博客园