最近在使用迁移学习跑实验,遇到要将网络的部分层的参数固定住这一问题,经过多次尝试找到了常用的几种方法。

案例介绍

有两个网络串联训练(model、model1),现在想固定住model的网络参数,网络训练过程中只更新model1的权重。
其中model为仅包含两个卷积层的网络,model1为仅包含一个全连接层的简单网络。
model和model1代码如下:

import torch
import torch.nn as nn
from torch import optim# 定义两个包含两个卷积层,一个全连接层的简单网络
class Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)self.conv2 = nn.Conv2d(16, 16, 3, 1, padding=1, bias=False)# self.FC = nn.Linear(16*10*10, 20, bias=False)def forward(self, x):x = self.conv1(x)x = self.conv2(x)# x = self.FC(x.view(x.size(0), -1))return xclass Net1(nn.Module):def __init__(self):super().__init__()self.FC = nn.Linear(16*10*10, 20, bias=False)def forward(self, x):x = self.FC(x.view(x.size(0), -1))return xmodel = Net()
model1 = Net1()

训练过程如下

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD([{'params': model.parameters()}, {'params': model1.parameters()}], lr=1.0)model.train()input = torch.ones(3, 3, 10, 10)
label = torch.ones(3, dtype=torch.long)for i in range(1):output = model(input)last_output = model1(output)loss = criterion(last_output, label)optimizer.zero_grad()loss.backward()optimizer.step()

一、require_grad=False

使用方法为令要冻结的层的网络参数的require_grad=False

for p in model.parameters():p.requires_grad = False

则训练过程修改为

for p in model.parameters():p.data.fill_(2)for p in model1.parameters():p.data.fill_(1)print(model.conv1.weight[0][0][0])
print(model1.FC.weight[0])criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD([{'params': model.parameters()}, {'params': model1.parameters()}], lr=1.0)model.train()input = torch.ones(3, 3, 10, 10)
label = torch.ones(3, dtype=torch.long)for i in range(1):output = model(input)for p in model.parameters():p.requires_grad = Falselast_output = model1(output)loss = criterion(last_output, label)optimizer.zero_grad()print(model.conv1.weight[0][0][0])print(model1.FC.weight[0])loss.backward()print(model.conv1.weight[0][0][0])optimizer.step()print(model.conv1.weight[0][0][0])print(model1.FC.weight[0])

其中为了便于观察是否冻结,展示了网络的几个参数。运行结果如下

从结果可以看到model网络层被冻结了。
关于require_grad=False的作用推荐阅读PyTorch中关于backward、grad、autograd的计算原理的深度剖析

with torch.no_grad()

该方法也是利用了require_grad=False,该做法在模型进行测试时比较常见,用法也比较简单,网上说的是放在需要冻结的网络层的forward函数里面,这里拿model举例子

class Net(nn.Module):def __init__(self):super().__init__()self.conv1 = nn.Conv2d(3, 16, 3, 1, padding=1, bias=False)self.conv2 = nn.Conv2d(16, 16, 3, 1, padding=1, bias=False)# self.FC = nn.Linear(16*10*10, 20, bias=False)def forward(self, x):with torch.no_grad():x = self.conv1(x)x = self.conv2(x)# x = self.FC(x.view(x.size(0), -1))return x

三、 torch.detach()和torch.data()

该方法直接在model的输出结果output后添加.detach()或者.data()
其实.detach()或者.data()就是令require_grad=False

for p in model.parameters():p.data.fill_(2)for p in model1.parameters():p.data.fill_(1)print(model.conv1.weight[0][0][0])
print(model1.FC.weight[0])criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD([{'params': model.parameters()}, {'params': model1.parameters()}], lr=1.0)model.train()input = torch.ones(3, 3, 10, 10)
label = torch.ones(3, dtype=torch.long)for i in range(1):output = model(input)for p in model.parameters():p.requires_grad = False# last_output = model1(output)last_output = model1(output.detach())# last_output = model1(output.data)loss = criterion(last_output, label)optimizer.zero_grad()print(model.conv1.weight[0][0][0])print(model1.FC.weight[0])loss.backward()print(model.conv1.weight[0][0][0])optimizer.step()print(model.conv1.weight[0][0][0])print(model1.FC.weight[0])

关于detach和data的区别大家可以看一下浅谈 PyTorch 中的 tensor 及使用

注 在 0.4.0 版本以前,.data 是用来取 Variable 中的 tensor 的,但是之后 Variable 被取消,.data 却留了下来。现在我们调用 tensor.data,可以得到 tensor的数据 + requires_grad=False 的版本,而且二者共享储存空间,也就是如果修改其中一个,另一个也会变。因为 PyTorch 的自动求导系统不会追踪 tensor.data 的变化,所以使用它的话可能会导致求导结果出错。官方建议使用 tensor.detach() 来替代它,二者作用相似,但是 detach 会被自动求导系统追踪,使用起来很安全。

四、设置优化器更新参数

如果不想让某一个网络层进行更新,比较简单的做法就是不把该网络层的参数放到优化器里面。拿本文案例来讲,优化器原来为:

optimizer = optim.SGD([{'params': model.parameters()}, {'params': model1.parameters()}], lr=1.0)

把冻结的网络参数不放入优化器中,此刻优化器为

optimizer = optim.SGD(model1.parameters(), lr=1.0)

注: 此刻被冻结的参数在进行反向传播时依旧进行求导,只是参数没有更新。
可以看到,如果采用该方法可以减少内存使用,同时如果提前使用require_grad=False会使得模型跳过不需要计算的参数,提高运算速度,所以可以将这两种方法结合在一起使用。

参考文献

Pytorch autograd,backward详解
Pytorch在训练时冻结某些层使其不参与训练(更新梯度)
浅谈 PyTorch 中的 tensor 及使用
【PyTorch】冻结网络参数
PyTorch中关于backward、grad、autograd的计算原理的深度剖析

pytroch冻结某些层的常用方法相关推荐

  1. pytorch 之 冻结某层参数,即训练时不更新

    首先,我们知道,深度学习网络中的参数是通过计算梯度,在反向传播进行更新的,从而能得到一个优秀的参数,但是有的时候,我们想固定其中的某些层的参数不参与反向传播.比如说,进行微调时,我们想固定已经加载预训 ...

  2. pytorch——冻结某层参数

    参考链接: https://blog.csdn.net/qq_41368074/article/details/107860126 https://blog.csdn.net/Code_Mart/ar ...

  3. Pytorch 加载部分预训练模型并冻结某些层

    目录 1  pytorch的版本: 2  数据下载地址: 3  原始版本代码下载: 4  直接上代码: 1  pytorch的版本: 2  数据下载地址: <https://download.p ...

  4. pytorch训练网络冻结某些层

    引言:首先我们应该很清楚地知道冻结网络中的某些层有什么作用?如何进行相关的冻结设置?代码何如呢? 话不多说说,首先我们探讨第一个问题: 1.冻结网络的某些层有什么作用? 这个问题顾名思义就是冻结网络中 ...

  5. 网络中BN层的作用以及为什么冻结BN层

    BN层的作用主要有三个: 加快网络的训练和收敛的速度 控制梯度爆炸防止梯度消失 防止过拟合

  6. 预训练+微调+Rethinking ImageNet Pre-training论文阅读笔记

    文章目录 一.前言 二.预训练+微调 1.预训练 2.微调 3.Pytroch实现 三.Rethinking ImageNet Pre-training论文笔记 参考文献 一.前言 近期在阅读何凯明大 ...

  7. pytorch 冻结层操作 + 学习率超参数设置

    pytorch finetune冻结层操作 知乎文章:pytorch 两种冻结层的方式 - 知乎 文章说了两种冻结层的方法: 一.设置requires_grad为False 第一步: for para ...

  8. 【yolov5 v6.0】中断以后重新训练,增加epochs,冻结层

    中断以后重新训练 有个resume的参数,将default从False改成True,然后他就会自己去找最新的权重继续训练了. 然后有个需要注意的点就是,不要为了想要备份最新的权重,然后把它复制一份出来 ...

  9. keras冻结_Keras 实现加载预训练模型并冻结网络的层

    在解决一个任务时,我会选择加载预训练模型并逐步fine-tune.比如,分类任务中,优异的深度学习网络有很多. ResNet, VGG, Xception等等... 并且这些模型参数已经在imagen ...

最新文章

  1. 用Python轻松搞定Excel中的20个常用操作
  2. [翻译]AKKA笔记 - CHILD ACTORS与ACTORPATH -6
  3. 山西DOT NET俱乐部
  4. mongodb聚合操作之group
  5. Spring框架对JDBC的简单封装。提供了一个JDBCTemplate对象简化JDBC的开发
  6. SSAS的MDX的基础函数(二)
  7. 测试人员转型是大势所趋:我的十年自动化测试经验分享
  8. c语言智能指针是什么,C++ 智能指针深入解析
  9. c 调用 android jar包,Unity调用AndroidStudio导出的Jar包
  10. 【渝粤教育】国家开放大学2018年秋季 2632T城市轨道交通客运组织 参考试题
  11. win10 64位 Compaq Visual Fortran(CVF)安装教程
  12. 计算机学业水平考试反思总结8百,考试反思与总结
  13. SylixOS命令行下内存操作/测试工具
  14. 创建多媒体APP 之 音频播放:管理音频焦点
  15. maven 打包打出带依赖的和不带依赖的jiar包
  16. ZBrush自定义笔刷
  17. asp毕业设计—— 基于asp+access的网上教学系统设计与实现(毕业论文+程序源码)——网上教学系统
  18. 初中数学教师资格证考试成功通过前辈复习经验分享
  19. 大型游戏行业网络技术解决方案
  20. Git命令及分支操作

热门文章

  1. python 使用cv2、io.BytesIO处理图片二进制数据
  2. Android第三方SDK集成 —— 极光推送
  3. 解决开发板不兼容earpods问题
  4. 将 ChatGPT 引入你的飞书
  5. 阿卜杜拉·法兹里和两个哥哥的故事(二)
  6. Probability and Hypothesis Testing
  7. Turbo还是那个Turbo吗?
  8. 7_Arya_superbeyone_新浪博客
  9. Qt+mpv制作windows/linux 下的动态壁纸软件(含源码)
  10. go语言的魔幻旅程28-go命令