• 是什么
    神经网络需要数据来训练,从数据中获得信息,进而转化成相应的权重。这些权重能够被提取出来,迁移到其他的神经网络中。
    迁移学习:通过使用之前在大数据集上经过训练的预训练模型,我们可以直接使用相应的结构和权重,将他们应用在我们正在面对的问题上。即将预训练的模型“迁移”到我们正在应对的特定问题中。
    在选择预训练模型时需要注意,如果我们的问题与预训练模型训练情景有很大出入,那么模型所得到的的预测结果会非常不准确。举例来说,如果把一个原本用于语音识别的模型用作用户识别,那结果肯定是不理想的。
    ImageNet数据集已经被广泛用作训练集,因为它规模足够大(包括120万张图片),有助于训练普适模型。在迁移学习中,这些预训练的网络对于ImageNet以外的图片表现出了很好的泛化性能。
    微调(fine tuning)可以省去大量的计算资源和计算时间,提高计算效率,甚至提高准确率。

  • 什么时候用
    使用的数据集和预训练模型的数据集相似;
    自己搭建或使用的CNN模型正确率太低;
    数据集相似,但数据集数量少;
    计算资源少。

  • 怎么用
    数据量少,且数据高度相似: - 在这种情况下,我们所做的只是修改最后几层或最终的softmax图层的输出类别。
    数据量少,但数据相似度低: 在这种情况下,我们可以冻结预训练模型的初始层(比如k层),并再次训练剩余的(n-k)层。由于新数据集的相似度较低,因此根据新数据集对较高层进行重新训练具有重要意义。
    数据量大,数据相似度低:此时最好根据我们自己的数据从头开始训练神经网络(Training from scatch)。
    数据量大,数据相似度高: 这是理想情况。在这种情况下,预训练模型应该是最有效的。使用模型的最好方法是保留模型的体系结构和模型的初始权重。然后,我们可以使用在预先训练的模型中的权重来重新训练该模型。

  • 注意事项

  1. 使用较小的学习率来训练网络。由于我们预计预先训练的权重相对于随机初始化的权重已经相当不错,我们不想过快地扭曲它们太多。通常的做法是使初始学习率比用于从头开始训练(Training from scratch)的初始学习率小10倍。

  2. 如果数据集数量过少,我们进来只训练最后一层,如果数据集数量中等,冻结预训练网络的前几层的权重也是一种常见做法。这是因为前几个图层捕捉了与我们的新问题相关的通用特征,如曲线和边。我们希望保持这些权重不变。相反,我们会让网络专注于学习后续深层中特定于数据集的特征。

  • 预训练模型修剪+微调:

    1. 在已经训练好的基网络上添加自定义网络;
    2. 冻结基网络,训练自定义网络;
    3. 解冻部分基网络,联合训练解冻层和自定义网络。

注意,在联合训练解冻层和自定义网络之前,通常要先训练自定义网络,否则,随机初始化的自定义网络权重会将误差信号传到解冻层,破坏解冻层以前学到的表示,使得训练成本增大。

pytorch四种冻结层的方式:
假设模型定义如下:

class Char3SeqModel(nn.Module):def __init__(self, char_sz, n_fac, n_h):super().__init__()self.em = nn.Embedding(char_sz, n_fac)self.fc1 = nn.Linear(n_fac, n_h)self.fc2 = nn.Linear(n_h, n_h)self.fc3 = nn.Linear(n_h, char_sz)def forward(self, ch1, ch2, ch3):# do somethingout = #....return outmodel = Char3SeqModel(10000, 50, 25)

假设需要冻结FC1

  1. 方法1:设置requires_grad为False
# 冻结
model.fc1.weight.requires_grad = False
optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=0.1)
#
# compute loss
# loss.backward()
# optmizer.step()# 解冻
model.fc1.weight.requires_grad = True
optimizer.add_param_group({'params': model.fc1.parameters()})
  1. 方法2:最简单的方式是在定义optimizer的时候,不要加入你想冻结的那一层的参数。
# 冻结
optimizer = optim.Adam([{'params':[ param for name, param in model.named_parameters() if 'fc1' not in name]}], lr=0.1)
# compute loss
# loss.backward()
# optimizer.step()# 解冻
optimizer.add_param_group({'params': model.fc1.parameters()})
  1. 方法3:将原来layer的weight缓存下来,每次反向传播之后,再将原来的weight赋值给相应的layer
fc1_old_weights = Variable(model.fc1.weight.data.clone())
# compute loss
# loss.backward()
# optimizer.step()
model.fc1.weight.data = fc1_old_weights.data
  1. 方法4:使用 torch.no_grad()
    这种方式只需要在网络定义中的forward方法中,将需要冻结的层放在使用 torch.no_grad()下。
class xxnet(nn.Module):def __init__():....self.layer1 = xxself.layer2 = xxself.fc = xxdef forward(self.x):with torch.no_grad():x = self.layer1(x)x = self.layer2(x)x = self.fc(x)return x

这种方式则是将layer1和layer2定义的层冻结,只训练fc层的参数。
5. 终极方法实现

作者:肥波喇齐
链接:https://www.zhihu.com/question/311095447/answer/589307812
来源:知乎
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。from collections.abc import Iterabledef set_freeze_by_names(model, layer_names, freeze=True):if not isinstance(layer_names, Iterable):layer_names = [layer_names]for name, child in model.named_children():if name not in layer_names:continuefor param in child.parameters():param.requires_grad = not freezedef freeze_by_names(model, layer_names):set_freeze_by_names(model, layer_names, True)def unfreeze_by_names(model, layer_names):set_freeze_by_names(model, layer_names, False)def set_freeze_by_idxs(model, idxs, freeze=True):if not isinstance(idxs, Iterable):idxs = [idxs]num_child = len(list(model.children()))idxs = tuple(map(lambda idx: num_child + idx if idx < 0 else idx, idxs))for idx, child in enumerate(model.children()):if idx not in idxs:continuefor param in child.parameters():param.requires_grad = not freezedef freeze_by_idxs(model, idxs):set_freeze_by_idxs(model, idxs, True)def unfreeze_by_idxs(model, idxs):set_freeze_by_idxs(model, idxs, False)
# 冻结第一层
freeze_by_idxs(model, 0)
# 冻结第一、二层
freeze_by_idxs(model, [0, 1])
#冻结倒数第一层
freeze_by_idxs(model, -1)
# 解冻第一层
unfreeze_by_idxs(model, 0)
# 解冻倒数第一层
unfreeze_by_idxs(model, -1)# 冻结 em层
freeze_by_names(model, 'em')
# 冻结 fc1, fc3层
freeze_by_names(model, ('fc1', 'fc3'))
# 解冻em, fc1, fc3层
unfreeze_by_names(model, ('em', 'fc1', 'fc3'))

代码参考地址

预训练网络的模型微调方法相关推荐

  1. 预训练网络的特征提取方法(VGG16)

    预训练网络的特征提取方法 1.知识点 #想要将深度学习应用于小型图像数据集,一种常用且非常高效的方法是使用预训练网络 #预训练网络:一个保存好的网络,之前已经在大型数据集(通常是大规模图像分类任务)上 ...

  2. python 动物分类_《python深度学习》笔记---5.3-1、猫狗分类(使用预训练网络)

    <python深度学习>笔记---5.3-1.猫狗分类(使用预训练网络) 一.总结 一句话总结: [小型图像数据集]:想要将深度学习应用于小型图像数据集,一种常用且非常高效的方法是使用预训 ...

  3. Keras 的预训练权值模型用来进行预测、特征提取和微调(fine-tuning)

    转至:Keras中文文档 https://keras.io/zh/applications/ 应用 Applications Keras 的应用模块(keras.applications)提供了带有预 ...

  4. 基于Keras预训练词向量模型的文本分类方法

    本文语料仍然是上篇所用的搜狗新闻语料,采用中文预训练词向量模型对词进行向量表示.上篇文章将文本分词之后,采用了TF-IDF的特征提取方式对文本进行向量化表示,所产生的文本表示矩阵是一个稀疏矩阵,本篇采 ...

  5. Hugging Face实战(NLP实战/Transformer实战/预训练模型/分词器/模型微调/模型自动选择/PyTorch版本/代码逐行解析)下篇之模型训练

    模型训练的流程代码是不是特别特别多啊?有的童鞋看过Bert那个源码写的特别特别详细,参数贼多,运行一个模型百八十个参数的. Transformer对NLP的理解是一个大道至简的感觉,Hugging F ...

  6. MedicalGPT:基于LLaMA-13B的中英医疗问答模型(LoRA)、实现包括二次预训练、有监督微调、奖励建模、强化学习训练[LLM:含Ziya-LLaMA]。

    项目设计集合(人工智能方向):助力新人快速实战掌握技能.自主完成项目设计升级,提升自身的硬实力(不仅限NLP.知识图谱.计算机视觉等领域):汇总有意义的项目设计集合,助力新人快速实战掌握技能,助力用户 ...

  7. KDD 2022 | 图“预训练、提示、微调”范式下的图神经网络泛化框架

    ©作者 | 社媒派SMP 来源 | 社媒派SMP 本文是SIGKDD 2022入选论文"GPPT: Graph Pre-training and Prompt Tuning to Gener ...

  8. MICCAI 2020 | 基于3D监督预训练的全身病灶检测SOTA(预训练代码和模型已公开)...

    关注公众号,发现CV技术之美 ▊ 研究背景介绍 由于深度学习任务往往依赖于大量的标注数据,医疗图像的样本标注又会涉及到较多的专业知识,标注人员需要对病灶的大小.形状.边缘等信息进行准确的判断,甚至需要 ...

  9. 【深度学习】预训练的卷积模型比Transformer更好?

    引言 这篇文章就是当下很火的用预训练CNN刷爆Transformer的文章,LeCun对这篇文章做出了很有深意的评论:"Hmmm".本文在预训练微调范式下对基于卷积的Seq2Seq ...

最新文章

  1. AI已经融入生活,不懂AI的人已经out了,五分钟了解AI人工智能!
  2. Digital Image Processing 学习笔记3
  3. 服务器-番外篇-搭建samba共享
  4. mysql中利用sql语句修改字段名称,字段长度等操作(亲测)
  5. python android自动化基于java_Appium+Python自动化 1 环境搭建(适用windows系统-Android移动端自动化)...
  6. python 判断字符串是否为空用什么方法?
  7. JAVA进阶教学之(Enum枚举类)
  8. linux进程创建时间,linux进程创建时间计算
  9. kubernets kube-proxy原理分析
  10. HDOJ 1863畅通工程(最小生成树kruskal算法并查集实现)
  11. linux日期时间转换函数,Linux时间戳、日期转换函数
  12. python创建单行文本框_HTML单行文本框
  13. IAR环境下的STM32H750片外QSPI Flash下载仿真
  14. HTML在列表中加图片,HTML + JS 列表显示图片
  15. C++:Trivial、Standard-Layout 和 POD
  16. 一年回顾_2016年:一年回顾
  17. 卸载微软拼音2003
  18. MySQL数据库day01
  19. 星巴克:邮件里的夏趣盎然
  20. 飞秋 飞秋2010 飞秋2010下载 飞秋下载2010正式版

热门文章

  1. 安装ubuntu18.04分区设置
  2. MySQL数据库系列之数据库设计原则
  3. TableLayout 中 stretchColumns的用法
  4. 猎豹移动:公有云快速构建海外移动化应用基础环境
  5. what is IMHO?
  6. 数据增强增广方法及实现
  7. 激光SLAM入门学习笔记
  8. dotnet 用 gcdump 调试应用程序内存占用
  9. 安装和配置Canal
  10. [笔记] vxworks添加静态路由备注routec