PyTorch框架学习二十——模型微调(Finetune)

  • 一、Transfer Learning:迁移学习
  • 二、Model Finetune:模型的迁移学习
  • 三、看个例子:用ResNet18预训练模型训练一个图片二分类任务

因为模型微调的内容没有实际使用过,但是后面是肯定会要了解的,所以这里算是一个引子,简单从概念上介绍一下迁移学习与模型微调,后面有时间或需要用到时再去详细了解。

一、Transfer Learning:迁移学习

是机器学习(ML)的一项分支,主要研究源域的知识如何应用到目标域。将源域所学习到的知识应用到目标任务当中,用于提升在目标任务里模型的性能

所以迁移学习的主要目的就是借助其他的知识提升模型性能。

详细了解可以参考这篇综述:《A Survey on Transfer Learning》

二、Model Finetune:模型的迁移学习

训练一个Model,就是去更新它的权值,这里的权值可以称为知识,从AlexNet的卷积核可视化中,我们可以看到大多数卷积核为边缘等信息,这些信息就是AlexNet在ImageNet上学习到的知识,所以可以把权值理解为神经网络在特定任务中学习到的知识,而这些知识可以迁移,将其迁移到新任务中,这样就完成了一个Transfer Learning,这就是模型微调,这就是为什么称Model Finetune为Transfer Learning,它其实是将权值认为是知识,把这些知识应用到新任务中去。

为什么要 Model Finetune?

一般来说需要模型微调的任务都有如下特点:在新任务中数据量较小,不足以训练一个较大的Model。可以用Model Finetune的方式辅助我们在新任务中训练一个较好的模型,让训练过程更快。

模型微调的步骤

一般来说,一个神经网络模型可以分为Features ExtractorClassifer两部分,前者用于提取特征,后者用于合理分类,通常我们习惯对Features Extractor的结构和参数进行保留,而仅修改Classifer来适应新任务。这是因为新任务的数据量太小,预训练参数已经具有共性,不再需要改变,如果再用这些小数据训练,可能反而过拟合。

所以步骤如下:

  1. 获取预训练模型参数
  2. 加载参数至模型(load_state_dict)
  3. 修改输出层以适应新任务

模型微调训练方法

因为需要保留Features Extractor的结构和参数,提出了两种训练方法:

  1. 固定预训练的参数:requires_grad = False 或者 lr = 0,即不更新参数;
  2. Features Extractor部分设置很小的学习率,这里用到参数组(params_group)的概念,分组设置优化器的参数。

三、看个例子:用ResNet18预训练模型训练一个图片二分类任务

涉及到的data:https://pan.baidu.com/s/115grxHrq6kMZBg6oC2fatg
提取码:yld7

# -*- coding: utf-8 -*-
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as pltimport sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)from tools.my_dataset import AntsDataset
from tools.common_tools import set_seed
import torchvision.models as models
import torchvision
BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device :{}".format(device))set_seed(1)  # 设置随机种子
label_name = {"ants": 0, "bees": 1}# 参数设置
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10
val_interval = 1
classes = 2
start_epoch = -1
lr_decay_step = 7# ============================ step 1/5 数据 ============================
data_dir = os.path.abspath(os.path.join(BASEDIR, "..", "..", "data", "hymenoptera_data"))
if not os.path.exists(data_dir):raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip  放到\n{} 下,并解压即可".format(data_dir, os.path.dirname(data_dir)))train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])valid_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(norm_mean, norm_std),
])# 构建MyDataset实例
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)# ============================ step 2/5 模型 ============================# 1/3 构建模型
resnet18_ft = models.resnet18()# 2/3 加载参数
# flag = 0
flag = 1
if flag:path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data", "finetune_resnet18-5c106cde.pth")if not os.path.exists(path_pretrained_model):raise Exception("\n{} 不存在,请下载 07-02-数据-模型finetune.zip\n放到 {}下,并解压即可".format(path_pretrained_model, os.path.dirname(path_pretrained_model)))state_dict_load = torch.load(path_pretrained_model)resnet18_ft.load_state_dict(state_dict_load)# 法1 : 冻结卷积层
flag_m1 = 0
# flag_m1 = 1
if flag_m1:for param in resnet18_ft.parameters():param.requires_grad = Falseprint("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))# 3/3 替换fc层
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)resnet18_ft.to(device)
# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数# ============================ step 4/5 优化器 ============================
# 法2 : conv 小学习率
# flag = 0
flag = 1
if flag:fc_params_id = list(map(id, resnet18_ft.fc.parameters()))     # 返回的是parameters的 内存地址base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())optimizer = optim.SGD([{'params': base_params, 'lr': LR*0},   # 0{'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)else:optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)               # 选择优化器scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)     # 设置学习率下降策略# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()for epoch in range(start_epoch + 1, MAX_EPOCH):loss_mean = 0.correct = 0.total = 0.resnet18_ft.train()for i, data in enumerate(train_loader):# forwardinputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = resnet18_ft(inputs)# backwardoptimizer.zero_grad()loss = criterion(outputs, labels)loss.backward()# update weightsoptimizer.step()# 统计分类情况_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).squeeze().cpu().sum().numpy()# 打印训练信息loss_mean += loss.item()train_curve.append(loss.item())if (i+1) % log_interval == 0:loss_mean = loss_mean / log_intervalprint("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))loss_mean = 0.# if flag_m1:print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))scheduler.step()  # 更新学习率# validate the modelif (epoch+1) % val_interval == 0:correct_val = 0.total_val = 0.loss_val = 0.resnet18_ft.eval()with torch.no_grad():for j, data in enumerate(valid_loader):inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)outputs = resnet18_ft(inputs)loss = criterion(outputs, labels)_, predicted = torch.max(outputs.data, 1)total_val += labels.size(0)correct_val += (predicted == labels).squeeze().cpu().sum().numpy()loss_val += loss.item()loss_val_mean = loss_val/len(valid_loader)valid_curve.append(loss_val_mean)print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_mean, correct_val / total_val))resnet18_ft.train()train_x = range(len(train_curve))
train_y = train_curvetrain_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curveplt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

输出结果为:

use device :cpu
Training:Epoch[000/025] Iteration[010/016] Loss: 0.6572 Acc:60.62%
epoch:0 conv1.weights[0, 0, ...] :tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],[ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],[-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],[-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],[ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],[-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],grad_fn=<SelectBackward>)
Valid:   Epoch[000/025] Iteration[010/010] Loss: 0.4565 Acc:84.97%
Training:Epoch[001/025] Iteration[010/016] Loss: 0.4074 Acc:85.00%
epoch:1 conv1.weights[0, 0, ...] :tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],[ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],[-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],[-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],[ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],[-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],grad_fn=<SelectBackward>)
Valid:   Epoch[001/025] Iteration[010/010] Loss: 0.2846 Acc:93.46%
Training:Epoch[002/025] Iteration[010/016] Loss: 0.3542 Acc:83.12%
epoch:2 conv1.weights[0, 0, ...] :tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],[ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],[-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],[-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],[ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],[-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],grad_fn=<SelectBackward>)
Valid:   Epoch[002/025] Iteration[010/010] Loss: 0.2904 Acc:89.54%
Training:Epoch[003/025] Iteration[010/016] Loss: 0.2266 Acc:93.12%
epoch:3 conv1.weights[0, 0, ...] :tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],[ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],[-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],[-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],[ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],[-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],grad_fn=<SelectBackward>)
Valid:   Epoch[003/025] Iteration[010/010] Loss: 0.2252 Acc:94.12%
Training:Epoch[004/025] Iteration[010/016] Loss: 0.2805 Acc:87.50%
epoch:4 conv1.weights[0, 0, ...] :tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],[ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],[-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],[-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],[ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],[-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],grad_fn=<SelectBackward>)
Valid:   Epoch[004/025] Iteration[010/010] Loss: 0.1953 Acc:95.42%
Training:Epoch[005/025] Iteration[010/016] Loss: 0.2423 Acc:91.88%
epoch:5 conv1.weights[0, 0, ...] :tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],[ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],[-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],[-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],[ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],[-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],grad_fn=<SelectBackward>)
Valid:   Epoch[005/025] Iteration[010/010] Loss: 0.2399 Acc:92.16%
Training:Epoch[006/025] Iteration[010/016] Loss: 0.2455 Acc:90.00%
epoch:6 conv1.weights[0, 0, ...] :tensor([[-0.0104, -0.0061, -0.0018,  0.0748,  0.0566,  0.0171, -0.0127],[ 0.0111,  0.0095, -0.1099, -0.2805, -0.2712, -0.1291,  0.0037],[-0.0069,  0.0591,  0.2955,  0.5872,  0.5197,  0.2563,  0.0636],[ 0.0305, -0.0670, -0.2984, -0.4387, -0.2709, -0.0006,  0.0576],[-0.0275,  0.0160,  0.0726, -0.0541, -0.3328, -0.4206, -0.2578],[ 0.0306,  0.0410,  0.0628,  0.2390,  0.4138,  0.3936,  0.1661],[-0.0137, -0.0037, -0.0241, -0.0659, -0.1507, -0.0822, -0.0058]],grad_fn=<SelectBackward>)

可以看出,模型的训练从一开始就有了较高的准确率,比较快速地进入了较好训练状态,相比于不借助其他知识的普通训练,速度上要快很多。

而且这里是用分组参数的方法将特征提取部分的学习率设置为0,这样就不改变特征提取部分的参数了,而将全连接层的学习率正常设置,从上面的结果也能看出特征提取部分的权值一直没有改变(改变的是全连接层的权值,所以准确率才会提升)。

ps:这次笔记涉及的迁移学习的知识还只是基础,以后若有需要还要更加深入。

PyTorch框架学习二十——模型微调(Finetune)相关推荐

  1. PyTorch框架学习二——基本数据结构(张量)

    PyTorch框架学习二--基本数据结构(张量) 一.什么是张量? 二.Tensor与Variable(PyTorch中) 1.Variable 2.Tensor 三.Tensor的创建 1.直接创建 ...

  2. PyTorch框架学习十九——模型加载与保存

    PyTorch框架学习十九--模型加载与保存 一.序列化与反序列化 二.PyTorch中的序列化与反序列化 1.torch.save 2.torch.load 三.模型的保存 1.方法一:保存整个Mo ...

  3. PyTorch框架学习十二——损失函数

    PyTorch框架学习十二--损失函数 一.损失函数的作用 二.18种常见损失函数简述 1.L1Loss(MAE) 2.MSELoss 3.SmoothL1Loss 4.交叉熵CrossEntropy ...

  4. PyTorch框架学习十八——Layer Normalization、Instance Normalization、Group Normalization

    PyTorch框架学习十八--Layer Normalization.Instance Normalization.Group Normalization 一.为什么要标准化? 二.BN.LN.IN. ...

  5. PyTorch框架学习十六——正则化与Dropout

    PyTorch框架学习十六--正则化与Dropout 一.泛化误差 二.L2正则化与权值衰减 三.正则化之Dropout 补充: 这次笔记主要关注防止模型过拟合的两种方法:正则化与Dropout. 一 ...

  6. PyTorch框架学习十五——可视化工具TensorBoard

    PyTorch框架学习十五--可视化工具TensorBoard 一.TensorBoard简介 二.TensorBoard安装及测试 三.TensorBoard的使用 1.add_scalar() 2 ...

  7. PyTorch框架学习十四——学习率调整策略

    PyTorch框架学习十四--学习率调整策略 一._LRScheduler类 二.六种常见的学习率调整策略 1.StepLR 2.MultiStepLR 3.ExponentialLR 4.Cosin ...

  8. PyTorch框架学习十——基础网络层(卷积、转置卷积、池化、反池化、线性、激活函数)

    PyTorch框架学习十--基础网络层(卷积.转置卷积.池化.反池化.线性.激活函数) 一.卷积层 二.转置卷积层 三.池化层 1.最大池化nn.MaxPool2d 2.平均池化nn.AvgPool2 ...

  9. PyTorch框架学习六——图像预处理transforms(二)

    PyTorch框架学习六--图像预处理transforms(二) (续)二.transforms的具体方法 4.图像变换 (1)尺寸变换:transforms.Resize() (2)标准化:tran ...

最新文章

  1. 注意力机制又一大作!DCANet:学习卷积神经网络的连接注意力
  2. 如何退出Activity?如何安全退出已调用多个Activity的Application?
  3. elastic job review
  4. 类variant解剖
  5. 系出名门Android(2) - 布局(Layout)和菜单(Menu)
  6. JAVAWEB入门之Sevlet的执行原理
  7. c#+wpf项目性能优化之OutOfMemoryException解密
  8. 整理的C++面试题,大厂面试总遇到!
  9. Cannot load JDBC driver class 'com.mysql.jdbc.Driver '
  10. Python3 数据类型-Number
  11. OpenWrt系列教程汇总 OpenWrt简体中文Wiki
  12. 用 Vue 做一个简单的购物app
  13. 大牛手把手教你用树莓派玩红警OPENRA
  14. 移动边缘计算(MEC)
  15. 不再因BT吃官司 Magnet能否将BT漂白?
  16. Linux版phpstudy搭建
  17. Java进阶之——飞翔的小鸟游戏项目
  18. 赢在中国 - 史玉柱经典语录(转载)
  19. python 数字转换为汉字大写
  20. 须知年少凌云志 曾许人间第一流

热门文章

  1. 神经网络 - 用单层感知器实现多个神经元的分类 - (Matlab建模)
  2. mysql slow log 分析工具_mysql slow log分析工具的比较
  3. 2019 年,智能问答(Question Answering)的主要研究方向有哪些?
  4. Spring Boot中使用@Async实现异步调用
  5. 史上最强Dubbo面试28题答案详解:核心功能+服务治理+架构设计等
  6. 阿里P8架构师谈:MySQL行锁、表锁、悲观锁、乐观锁的特点与应用
  7. 实时事理学习与搜索平台DemoV1.0正式对外发布
  8. 06.动态SQL和foreach
  9. 一至七-----小东西
  10. CentOS 6快捷安装RabbitMQ教程