训练网络指定层pytorch实现方法
最近在研究Mask R-CNN,该网络一部分是跟Faster R-CNN(https://arxiv.org/pdf/1506.01497v3.pdf)相似的,同样的,在模型训练实现时,其中一种方法叫做交替训练(Alternating training),想利用该方法就涉及到如何对网络进行指定层的训练,今天就总结一下pytorch中的实现方法,既然写了指定层训练,那就把参数的单独设置也介绍一下,pytorch提供的这些方法提高了网络训练的灵活度,是很好用的方法。
指定层训练
每个变量都有一个标记:requires_grad
允许从梯度计算中细分排除子图,并可以提高效率。当我们把参数属性设置为requires_grad=False时,该参数固定不变,不参与训练,不更新,具体操作如下:
# 载入预训练模型参数后...
for name, value in model.named_parameters():if name 满足某些条件:value.requires_grad = False# setup optimizer
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-4)
检查是否固定:被固定的输出False,未被固定的输出True
for name, value in model.named_parameters():print(name,value.requires_grad) #打印所有参数requires_grad属性,True或False
将满足条件的参数的 requires_grad 属性设置为False, 同时 filter 函数将模型中属性 requires_grad = True 的参数筛选出来,传到优化器(以Adam为例)中,只有这些参数会被求导数和更新。
当我们指定的参数比较多时,可以利用正则表达式来定义参与训练的层,具体实现如下:
def set_trainable(net, layer_regex, model=None, indent=0, verbose=1):"""Sets model layers as trainable if their names matchthe given regular expression."""for layer_name,param in net.named_parameters():trainable = bool(re.fullmatch(layer_regex, layer_name)) # re.fullmatch()返回一个和模式串完全匹配的字符串if not trainable:param.requires_grad = False#将与layer_regex不匹配的层固定layer_regex = {# all layers but the backbone"heads": r"(fpn.P5\_.*)|(fpn.P4\_.*)|(fpn.P3\_.*)|(fpn.P2\_.*)|(rpn.*)|(classifier.*)|(mask.*)",# From a specific Resnet stage and up"3+": r"(fpn.C3.*)|(fpn.C4.*)|(fpn.C5.*)|(fpn.P5\_.*)|(fpn.P4\_.*)|(fpn.P3\_.*)|(fpn.P2\_.*)|(rpn.*)|(classifier.*)|(mask.*)","4+": r"(fpn.C4.*)|(fpn.C5.*)|(fpn.P5\_.*)|(fpn.P4\_.*)|(fpn.P3\_.*)|(fpn.P2\_.*)|(rpn.*)|(classifier.*)|(mask.*)","5+": r"(fpn.C5.*)|(fpn.P5\_.*)|(fpn.P4\_.*)|(fpn.P3\_.*)|(fpn.P2\_.*)|(rpn.*)|(classifier.*)|(mask.*)",# All layers"all": ".*",
}#用正则表达式标识要训练的层if layers in layer_regex.keys():layers = layer_regex[layers]self.set_trainable(layers)
optimizer = optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=5e-4)model.train_model(dataset_train, dataset_val,learning_rate=config.LEARNING_RATE,epochs=40,layers='heads')
通过选择layers来选择我们要训练更新的层。
为每个参数单独设置选项
当我们想指定每一层的学习率时,可以这样操作:
optim.SGD([{'params': model.base.parameters()},{'params': model.classifier.parameters(), 'lr': 1e-3}], lr=1e-2, momentum=0.9)
这意味着model.base
的参数将会使用1e-2
的学习率,model.classifier
的参数将会使用1e-3
的学习率,并且0.9
的momentum将会被用于所 有的参数。
训练网络指定层pytorch实现方法相关推荐
- Pytorch加载预训练网络,替换分类层并重新训练
定义网络时,在网络类的构造函数网络结构定义中添加如下语句: for p in self.parameters():p.requires_grad = False 该语句的功能是固定定义在该语句之前的网 ...
- pytorch训练网络冻结某些层
引言:首先我们应该很清楚地知道冻结网络中的某些层有什么作用?如何进行相关的冻结设置?代码何如呢? 话不多说说,首先我们探讨第一个问题: 1.冻结网络的某些层有什么作用? 这个问题顾名思义就是冻结网络中 ...
- pytorch网络冻结的三种方法区别:detach、requires_grad、with_no_grad
pytorch网络冻结的三种方法区别:detach.requires_grad.with_no_grad 文章目录 pytorch网络冻结的三种方法区别:detach.requires_grad.wi ...
- Pytorch:图像语义分割-FCN, U-Net, SegNet, 预训练网络
Pytorch: 图像语义分割-FCN, U-Net, SegNet, 预训练网络 Copyright: Jingmin Wei, Pattern Recognition and Intelligen ...
- 【最强ResNet改进系列】IResNet:涨点不涨计算量,可训练网络超过3000层!
点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要15分钟 Follow小博主,每天更新前沿干货 [导读]本篇文章是[最强ResNet改进系列]的第四篇文章,前面我们已经介绍了Res2Net和 ...
- 预训练网络的模型微调方法
是什么 神经网络需要数据来训练,从数据中获得信息,进而转化成相应的权重.这些权重能够被提取出来,迁移到其他的神经网络中. 迁移学习:通过使用之前在大数据集上经过训练的预训练模型,我们可以直接使用相应的 ...
- pytorch自带网络_使用PyTorch Lightning自动训练你的深度神经网络
作者:Erfandi Maula Yusnu, Lalu 编译:ronghuaiyang 原文链接 使用PyTorch Lightning自动训练你的深度神经网络mp.weixin.qq.com 导 ...
- ResNet改进版来了!可训练网络超过3000层!相同深度精度更高
来自阿联酋起源人工智能研究院(IIAI)的研究人员公布了一篇论文Improved Residual Networks for Image and Video Recognition,深入研究了残差网络 ...
- pytorch深度学习实战——预训练网络
来源:<Pytorch深度学习实战>,2.1,一个识别图像主体的预训练网络 from torchvision import models from torchvision import t ...
最新文章
- 2019研究生新生大数据出炉!清华园迎来8900多名新主人
- .NET Core版本七牛云SDK使用
- tcp/ip 协议栈Linux内核源码分析七 路由子系统分析二 策略路由
- 启明云端方案分享| ESP32-S2 摄像头 WIFI方案应用于智能猫眼
- Chrome 开发者工具 live expression 的用法
- sqlalchemy与mysql区别_sqlite3和sqlalchemy有什么区别?
- Altium Designer画元器件封装三种方法
- Mybatis与JDBC批量插入MySQL数据库性能测试及解决方案
- 西方主要管理思想简介
- python 字符串分割_python拆分字符串到列表
- 图像检索哈希算法综述
- input 框隐藏光标问题
- Landsat系列卫星数据应用介绍
- 动态改变Input和Textarea值Vue数据没有绑定的解决办法
- 测试管理工具的基本功能有哪些?
- 编译程序原理VS解释程序原理
- DT(密集轨迹)算法和iDT(改善的密集轨迹)算法
- iPhoneX的faceID到底是一种怎样风骚的操作?
- Python 日期时间格式化输出,带年、月、日、时、分、秒
- 知识图谱构建(入门)
热门文章
- 绑定到对象上的copyWithin方法
- js优化阿里云图片加载(二)
- fusion 360安装程序的多个实例正在同时运行。_阿里架构师实例讲解——Java多线程编程;详细的不能再详细了...
- java剑姬_ListView和Adapter(文字列表)
- 为什么网页背景图片都切开
- ProE常用曲线方程:Python Matplotlib 版本代码(玫瑰曲线)
- 使用replace pioneer批量修改文件名
- DeepMind用ReinforcementLearning玩游戏
- drawable如何只让两个叫圆角_cad怎么使用圆角?cad的圆角怎么使用?
- asp.net webapi 自定义身份验证