最近在研究Mask R-CNN,该网络一部分是跟Faster R-CNN(https://arxiv.org/pdf/1506.01497v3.pdf)相似的,同样的,在模型训练实现时,其中一种方法叫做交替训练(Alternating training),想利用该方法就涉及到如何对网络进行指定层的训练,今天就总结一下pytorch中的实现方法,既然写了指定层训练,那就把参数的单独设置也介绍一下,pytorch提供的这些方法提高了网络训练的灵活度,是很好用的方法。


指定层训练

每个变量都有一个标记:requires_grad允许从梯度计算中细分排除子图,并可以提高效率。当我们把参数属性设置为requires_grad=False时,该参数固定不变,不参与训练,不更新,具体操作如下:

# 载入预训练模型参数后...
for name, value in model.named_parameters():if name 满足某些条件:value.requires_grad = False# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)

检查是否固定:被固定的输出False,未被固定的输出True

for name, value in model.named_parameters():print(name,value.requires_grad)  #打印所有参数requires_grad属性,True或False

将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数筛选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。

当我们指定的参数比较多时,可以利用正则表达式来定义参与训练的层,具体实现如下:

def set_trainable(net, layer_regex, model=None, indent=0, verbose=1):"""Sets model layers as trainable if their names matchthe given regular expression."""for layer_name,param in net.named_parameters():trainable = bool(re.fullmatch(layer_regex, layer_name))  # re.fullmatch()返回一个和模式串完全匹配的字符串if not trainable:param.requires_grad = False#将与layer_regex不匹配的层固定layer_regex = {# all layers but the backbone"heads": r"(fpn.P5\_.*)|(fpn.P4\_.*)|(fpn.P3\_.*)|(fpn.P2\_.*)|(rpn.*)|(classifier.*)|(mask.*)",# From a specific Resnet stage and up"3+": r"(fpn.C3.*)|(fpn.C4.*)|(fpn.C5.*)|(fpn.P5\_.*)|(fpn.P4\_.*)|(fpn.P3\_.*)|(fpn.P2\_.*)|(rpn.*)|(classifier.*)|(mask.*)","4+": r"(fpn.C4.*)|(fpn.C5.*)|(fpn.P5\_.*)|(fpn.P4\_.*)|(fpn.P3\_.*)|(fpn.P2\_.*)|(rpn.*)|(classifier.*)|(mask.*)","5+": r"(fpn.C5.*)|(fpn.P5\_.*)|(fpn.P4\_.*)|(fpn.P3\_.*)|(fpn.P2\_.*)|(rpn.*)|(classifier.*)|(mask.*)",# All layers"all": ".*",
}#用正则表达式标识要训练的层if layers in layer_regex.keys():layers = layer_regex[layers]self.set_trainable(layers)
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)model.train_model(dataset_train, dataset_val,learning_rate=config.LEARNING_RATE,epochs=40,layers='heads')

通过选择layers来选择我们要训练更新的层。


为每个参数单独设置选项


当我们想指定每一层的学习率时,可以这样操作:

optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)

这意味着model.base的参数将会使用1e-2的学习率,model.classifier的参数将会使用1e-3的学习率,并且0.9的momentum将会被用于所 有的参数。

训练网络指定层pytorch实现方法相关推荐

  1. Pytorch加载预训练网络,替换分类层并重新训练

    定义网络时,在网络类的构造函数网络结构定义中添加如下语句: for p in self.parameters():p.requires_grad = False 该语句的功能是固定定义在该语句之前的网 ...

  2. pytorch训练网络冻结某些层

    引言:首先我们应该很清楚地知道冻结网络中的某些层有什么作用?如何进行相关的冻结设置?代码何如呢? 话不多说说,首先我们探讨第一个问题: 1.冻结网络的某些层有什么作用? 这个问题顾名思义就是冻结网络中 ...

  3. pytorch网络冻结的三种方法区别:detach、requires_grad、with_no_grad

    pytorch网络冻结的三种方法区别:detach.requires_grad.with_no_grad 文章目录 pytorch网络冻结的三种方法区别:detach.requires_grad.wi ...

  4. Pytorch:图像语义分割-FCN, U-Net, SegNet, 预训练网络

    Pytorch: 图像语义分割-FCN, U-Net, SegNet, 预训练网络 Copyright: Jingmin Wei, Pattern Recognition and Intelligen ...

  5. 【最强ResNet改进系列】IResNet:涨点不涨计算量,可训练网络超过3000层!

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 [导读]本篇文章是[最强ResNet改进系列]的第四篇文章,前面我们已经介绍了Res2Net和 ...

  6. 预训练网络的模型微调方法

    是什么 神经网络需要数据来训练,从数据中获得信息,进而转化成相应的权重.这些权重能够被提取出来,迁移到其他的神经网络中. 迁移学习:通过使用之前在大数据集上经过训练的预训练模型,我们可以直接使用相应的 ...

  7. pytorch自带网络_使用PyTorch Lightning自动训练你的深度神经网络

    作者:Erfandi Maula Yusnu, Lalu 编译:ronghuaiyang 原文链接 使用PyTorch Lightning自动训练你的深度神经网络​mp.weixin.qq.com 导 ...

  8. ResNet改进版来了!可训练网络超过3000层!相同深度精度更高

    来自阿联酋起源人工智能研究院(IIAI)的研究人员公布了一篇论文Improved Residual Networks for Image and Video Recognition,深入研究了残差网络 ...

  9. pytorch深度学习实战——预训练网络

    来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...

最新文章

  1. 2019研究生新生大数据出炉!清华园迎来8900多名新主人
  2. .NET Core版本七牛云SDK使用
  3. tcp/ip 协议栈Linux内核源码分析七 路由子系统分析二 策略路由
  4. 启明云端方案分享| ESP32-S2 摄像头 WIFI方案应用于智能猫眼
  5. Chrome 开发者工具 live expression 的用法
  6. sqlalchemy与mysql区别_sqlite3和sqlalchemy有什么区别?
  7. Altium Designer画元器件封装三种方法
  8. Mybatis与JDBC批量插入MySQL数据库性能测试及解决方案
  9. 西方主要管理思想简介
  10. python 字符串分割_python拆分字符串到列表
  11. 图像检索哈希算法综述
  12. input 框隐藏光标问题
  13. Landsat系列卫星数据应用介绍
  14. 动态改变Input和Textarea值Vue数据没有绑定的解决办法
  15. 测试管理工具的基本功能有哪些?
  16. 编译程序原理VS解释程序原理
  17. DT(密集轨迹)算法和iDT(改善的密集轨迹)算法
  18. iPhoneX的faceID到底是一种怎样风骚的操作?
  19. Python 日期时间格式化输出,带年、月、日、时、分、秒
  20. 知识图谱构建(入门)

热门文章

  1. 绑定到对象上的copyWithin方法
  2. js优化阿里云图片加载(二)
  3. fusion 360安装程序的多个实例正在同时运行。_阿里架构师实例讲解——Java多线程编程;详细的不能再详细了...
  4. java剑姬_ListView和Adapter(文字列表)
  5. 为什么网页背景图片都切开
  6. ProE常用曲线方程:Python Matplotlib 版本代码(玫瑰曲线)
  7. 使用replace pioneer批量修改文件名
  8. DeepMind用ReinforcementLearning玩游戏
  9. drawable如何只让两个叫圆角_cad怎么使用圆角?cad的圆角怎么使用?
  10. asp.net webapi 自定义身份验证