迁移学习(Transfer Learning)

迁移学习是机器学习的分支,提出的初衷是节省人工标注样本的时间,让模型可以通过一个已有的标记数据向未标记数据领域进行迁移从而训练出适用于该领域的模型,直接对目标域从头开始学习成本太高,我们故而转向运用已有的相关知识来辅助尽快地学习新知识。
举一个例子就能很好的说明问题,我们学习编程的时候会学习什么?语法、特定语言的API、流程处理、面向对象,设计模式和面向对象是不用去学的,因为原理都是一样的,甚至在学习C#的时候语法都可以少学很多,这就是迁移学习的概念呢,把统一的概念抽象出来,只学习不同的内容。
其实Transfer Learing和Fine tune并没有严格的区分,只不过后者跟常用于形容迁移学习的后期微调。

模型微调(Fine tuning)

什么是模型微调?

针对于某个任务,自己的训练数据不多时,那怎么办?我们可以找到一个同类的别人训练好的模型,把别人现成的训练好了的模型(预训练模型)拿过来,换成自己的数据。调整一下参数,再训练一遍,这就是微调。PyTorch里面提供的网络模型都是官方通过Imagenet的数据集与训练好的数据,如果我们的训练数据不够,这些数据是可以作为基础模型来使用的。在此我个人理解简言之就是:预训练模型拿过来用,改一下参数或者结构就叫微调。

为什么要微调?

1、对于数据集本身很小的情况,从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的魔型对数据量的要求越大,过拟合无法避免。这时候如果还想用上大型神经网络的超强特征提取能力,只能靠微调已经训练好的模型。
2、可以降低训练成本;如哦使用到处特征向量的方法进行迁移学习,后期的训练成本非常低,用CPU都完全无压力,没有深度学习机器也可以做。
3、前人花很大精力训练出来的模型在大概率上会比你自己从零开始搭的要强悍,没有必要重复造轮子。

怎么微调?

对于不同的领域微调的方法也不一样,比如语音识别领域一般微调前几层,图片识别问题微调后面几层。对于图片来说,我们CNN的前几层学习到的都是低级的特征,比如点、线、面,这些低级的特征对于任何图片来说都是可以抽象出来的。所以我们将他作为通用数据,只微调这些低级特征组合起来的高级特征即可,例如,这些点、线、面,组成的是圆还是椭圆,还是正方形,这些代表的含义都是我们需要后面训练出来的。
对于语音来说,每个单词表达的意思都是一样的,只不过发音或者是单词的拼写不一样,比如苹果,apple,apfel,都表示的是同一个东西,只不过发音和单词不一样,但是他具体代表的含义是一样的,就是高级特征是相同的,所以我们只要微调低级的特征就可以了。

下面只介绍下计算机视觉反向的微调,摘自https://cs231n.github.io/transfer-learning/
ConvNet as fixed feature extractor:其实这里有两种做法:
1、使用最后一个fc layer之前的fc layer获得的特征,学习个线性分类器(比如SVM);
2、重新训练最后一个fc layer。
Fine tuing the ConvNet:固定前几层的参数,只对最后几层进行fine tuing,

对于以上两种方案有一些微调的小技巧,比如先计算出预训练模型的卷积层对所有训练和测试数据的特征向量,然后抛开预训练模型,只训练自己定制的简配版全连接网络。这个方式的一个好处就是节省计算资源,每次迭代都不会再去跑全部的数据,而只是跑一下简配版的全连接

Pretrained models:这个其实和第二种是一个意思,不过比较极端,使用整个pre trained的model作为初始化,然后fine tuning整个网络而不是某些层,但是这个的计算量是非常大的,就只相当于做了一个初始化。

注意事项

1、新数据集和原始数据集合类似,那么直接可以微调一个最后的FC层或者重新指定一个新的分类器
2、新数据集比较小和原始数据集合差异性比较大,那么可以使用从模型的中部开始训练,只对最后几层进行fine tuing
3、新数据集比较小和原始数据集合差异性比较大,如果上面方法还是不行的话,那么最好是重新训练,只将预训练的模型作为一个新模型初始化的数据。
4、新数据集的大小一定要与原始数据集相同,比如CNN中输入的图片大小一定要相同,才不会报错
5、如果数据集大小不同的话,可以在最后的fc层之前添加conv或者pool层,使得最后的输出于fc层一致,但这样会导致准确度大幅下降,所以不建议这样做
6、对于不同的层可以设置不同的学习率,一般情况下建议,对于使用的原始数据做初始化的层设置的学习率要小于初始化的学习率,这样保证对于已经初始化的数据不会扭曲的过快,而使用初始化学习率的新层可以快速的收敛。

微调实例

这里我们使用官方训练好的resnet50来参加kaggle上的 dog breed狗的种类识别来做一个简单的微调实例。


import torch, os, torchvision
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models, transforms
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit# 定义一个数据集格式化
class DogDataset(Dataset):def __init__(self, labels_df, img_path, transform=None):self.labels_df = labels_dfself.img_path = img_pathself.transform = transformdef __len__(self):return self.labels_df.shape[0]def __getitem__(self, idx):image_name = os.path.join(self.img_path, self.labels_df.id[idx]) + '.jpg'img = Image.open(image_name)label = self.labels_df.label_idx[idx]if self.transform:img = self.transform(img)return img, labeldef train(model, device, train_loader, epoch):model.train()for batch_idx, data in enumerate(train_loader):x, y = datax = x.to(device)y = y.to(device)optimizer.zero_grad()y_hat = model(x)loss = criterion(y_hat, y)loss.backward()optimizer.step()print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss.item()))def test(model, device, test_loader):model.eval()test_loss = 0correct = 0with torch.no_grad():for i, data in enumerate(test_loader):x, y = datax = x.to(device)y = y.to(device)optimizer.zero_grad()y_hat = model(x)test_loss += criterion(y_hat, y).item()predict = y_hat.max(1, keepdim=True)[1]  # get the index of the max log-probabilitycorrect += predict.eq(y.view_as(predict)).sum().item()test_loss /= len(test_loader.dataset)print('\n Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(100. * correct / len(val_dataset)))# 固定层的向量导出
def hook(module, input, output):for i in range(input[0].size(0)):in_list.append(input[0][i].cpu().numpy())if __name__ == '__main__':DATA_ROOT = '自己的文件目录\Dog_Dataset'all_labels_df = pd.read_csv(os.path.join(DATA_ROOT, 'labels.csv'))# all_labels_df.head() 显示头部数据 默认前5个breeds = all_labels_df.breed.unique()  # breed读取全部标签部分 unique剔除重复统计的标签breed2idx = dict((breed, idx) for idx, breed in enumerate(breeds))  # dict的字典all_labels_df['label_idx'] = [breed2idx[b] for b in all_labels_df.breed]all_labels_df.head()#定义一些超参数IMG_SIZE = 224BATCH_SIZE = 8IMG_MEAN = [0.485, 0.456, 0.406]IMG_STD = [0.229, 0.224, 0.225]CUDA = torch.cuda.is_available()DEVICE = torch.device("cuda" if CUDA else "cpu")#定义数据初始化train_transforms = transforms.Compose([transforms.Resize(IMG_SIZE),  # 缩放尺寸transforms.RandomResizedCrop(IMG_SIZE),  # 随机长宽比裁剪transforms.RandomHorizontalFlip(),  # 依概率p水平翻转ttransforms.RandomRotation(30),  # 随机旋转 选择30°transforms.ToTensor(),transforms.Normalize(IMG_MEAN, IMG_STD)])val_transforms = transforms.Compose([transforms.Resize(IMG_SIZE),transforms.CenterCrop(IMG_SIZE),transforms.ToTensor(),transforms.Normalize(IMG_MEAN, IMG_STD)])# 分割10%的数据dataset_names = ['train', 'valid']stratified_split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)  # 提供分层抽样功能,确保每个标签对应的样本的比例train_split_idx, val_split_idx = next(iter(stratified_split.split(all_labels_df.id, all_labels_df.breed)))train_df = all_labels_df.iloc[train_split_idx].reset_index()val_df = all_labels_df.iloc[val_split_idx].reset_index()image_transforms = {'train': train_transforms, 'valid': val_transforms}train_dataset = DogDataset(train_df, os.path.join(DATA_ROOT, 'train'), transform=image_transforms['train'])val_dataset = DogDataset(val_df, os.path.join(DATA_ROOT, 'train'), transform=image_transforms['valid'])image_dataset = {'train': train_dataset, 'valid': val_dataset}# 使用官方的dataloader载入数据image_dataloader = {x: DataLoader(image_dataset[x], batch_size=BATCH_SIZE, shuffle=True, num_workers=0) for x in dataset_names}dataset_sizes = {x: len(image_dataset[x]) for x in dataset_names}#开始配置网络,由于ImageNet是识别1000个物体,狗的分类一共只有120,所以需要对模型的最后一层全连接层#进行微调,将输出从1000改为120 model_ft = models.resnet50(pretrained=True) #自动下载官方的预训练模型for param in model_ft.parameters():param.requires_grad = Falsenum_fc_ftr = model_ft.fc.in_featuresmodel_ft.fc = nn.Linear(num_fc_ftr, len(breeds))model_ft = model_ft.to(DEVICE)criterion = nn.CrossEntropyLoss()optimizer = torch.optim.Adam([{'params': model_ft.fc.parameters()}], lr=0.001)#训练查看效果  这里要拼显卡,实在不行整个服务器。for epoch in range(1, 10):train(model=model_ft, device=DEVICE, train_loader=image_dataloader["train"], epoch=epoch)test(model=model_ft, device=DEVICE, test_loader=image_dataloader["valid"])#但是这里每次训练都需要将一张图片在全部网络中进行计算,而且每次的计算结果都是一样的。为了提高效率,#我们将这些不更新网络权值参数层的计算结果保存下来,这样就可以直接将这些结果输入到fc层或者以这些结#果建立新的网络层。#如何将固定层的向量输出in_list = []model_ft.avgpool.register_forward_hook(hook)with torch.no_gard():for batch_idx, data in enumerate(image_dataloader["train"]):x, y = data;x = x.to(DEVICE)y = y.to(DEVICE)y_hat = model_ft(x)features = np.array(in_list)np.save("features",features)#这样再训练时我们就只需将这个数组读出来,然后就可以直接使用这个数组在输入到线性层或者sigmoid层就可以了。

模型微调------学习笔记相关推荐

  1. MATLAB simulink 模型验证学习笔记

    MATLAB simulink 模型验证学习笔记 一.静态验证 1.Model Advisor 模型验证意思是用matlab自带的规范检查工具来检查自己画的模型是否符合规范. 进行模型验证需要用到的模 ...

  2. MPC模型预测控制学习笔记-2021.10.27

    MPC模型预测控制学习笔记-点击目录就可以跳转 1. 笔者介绍 2. 参考资料 3. MPC分类 4. 数据的标准化与归一化 5. MATLAB-MPC学习笔记 5.1 获取测试信号:gensig( ...

  3. 图像分类模型的学习笔记

    1 查找ImageNet数据集上最好的模型--paperswithcode 我们在paperswithcode上查找现在ImageNet数据集上最好的模型 2 ImageNet上最好的模型--FixE ...

  4. 数据挖掘算法之时间序列算法(平稳时间序列模型,AR(p),MA(q),(平稳时间序列模型,AR(p),MA(q),ARMA(p,q)模型和非平稳时间序列模型,ARIMA(p,d,q)模型)学习笔记梳理

    时间序列算法 一.时间序列的预处理 二.平稳时间序列模型 (一).自回归模型AR( p ) (二).移动平均模型MA(q) (三).自回归移动平均模型ARMA(p,q) 三.非平稳时间序列模型 四.确 ...

  5. Java虚拟机(JVM)与Java内存模型(JMM)学习笔记

    Java虚拟机[JVM]与Java内存模型[JMM]学习笔记 Java虚拟机(JVM) 三种JVM JVM 位置 JVM的主要组成部分及其作用 类加载器 双亲委派机制 沙箱安全机制 Java本地接口( ...

  6. css中怎么加入立体模型,CSS学习笔记二:css 画立体图形

    继上一次学了如何去运用css画平面图形,这一次学如何去画正方体,从2D向着3D学习,虽然有点满,但总是一个过程,一点一点积累,然后记录起来. Transfrom3D 在这一次中运用到了一下几种属性: ...

  7. 双重差分模型DID学习笔记

    双重差分模型DID学习 1.DID介绍 1.1 特点 1.2 传统DID 1.3 经典DID 1.4 异时DID 1.5 广义DID 1.6 异质性DID 2. DID 平行趋势检验 3 实践举例 3 ...

  8. 06.Logistic回归与最大熵模型(学习笔记)

    06.Logistic回归与最大熵模型 参考: 袁春老师<大数据机器学习公开课>:https://www.xuetangx.com/course/THU08091001026/103331 ...

  9. 吴恩达深度学习之五《序列模型》学习笔记

    一.循环序列模型 1.1 为什么选择序列模型 如图所示是一些序列数据的例子 1.2 数学符号 如图所示,我们用  表示一个序列的第 t 个元素,t 从 1 开始 NLP中一个单词就是一个元素(又称时间 ...

  10. 【模型检测学习笔记】6:线性时序性质(Linear-time Properties)

    为方便,线性时序性质(linear-time properties)后续均简称LT性质. 在系统分析中,描述线性时序行为(linear-time behavior)可以是基于动作的(action-ba ...

最新文章

  1. Object.defineProperty()
  2. 数据蒋堂 | JOIN延伸 - 维度查询语法
  3. 考驾照选择 AI 教练,心态稳定不骂人
  4. Linux目录和文件中的常用命令(二)
  5. 基于ASP.NET的新闻管理系统(三)代码展示
  6. html自动加https,http自动跳转https的配置方法
  7. 2021考研c语言编程题,2021c语言编程例题及答案.docx
  8. 雷军微博抽奖送的那台蔚来ES6 时隔10个月终于提到车了
  9. 人工智能与深度学习概念(5)——目标检测-RCNN
  10. 微信小程序 eventChannel在页面间传参
  11. 蜂巢式技术阵营简化IoT蓝图
  12. RAC3——RAC原理开始
  13. telegram接入微信
  14. Proteus8.12无法仿真STC15系列单片机解决办法
  15. 佳能最新版DPP免CD安装
  16. 百晓生兵器谱之公有云排名
  17. 「首席架构师推荐」数值分析软件列表
  18. 国内最好用的短网址推荐(2022年最新整理)
  19. 计算机实战项目、毕业设计、课程设计之[含论文+辩论PPT+源码等]微信小程序社区疫情防控+后台管理|前后分离VUE[包运行成功
  20. 人脸识别《一》opencv人脸识别之人脸检测

热门文章

  1. 小颗粒积木步骤图纸_loz小颗粒钻石积木拼图图纸谁有
  2. 电子计算机特征具有什么功能,电子计算机的基本特征有哪些?
  3. 我国东北虎种群增长迅速 但近交风险不容忽视
  4. IIS提示您未被授权查看该页 401.1解决办法
  5. pagefile文件大小设置
  6. 终于有人把搜索引擎讲明白了
  7. Dapr for dotnet | 并发计算模型 - Virtual Actors
  8. java gui 做闹钟,用JAVA怎样编写一个可以在eclipse中运行的闹钟程序?
  9. js 十六进制,八进制,二进制
  10. android桌面小部件开发