使用torch.autograd.function解决dist.all_gather不能反向传播问题
1. 问题来源
最近在用mmcv复现Partial FC模型,看到源码中,有单独写的前向反向传播,甚是疑惑~
源码:
# Features all-gather
total_features = torch.zeros(features.size()[0] * cfg.world_size,cfg.embedding_size,device=local_rank)
dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)), features.data)
total_features.requires_grad = True...
#计算 loss 以及 backward
...if total_features.grad is not None:total_features.grad.detach_()
x_grad = torch.zeros_like(features)# Feature gradient all-reduce
dist.reduce_scatter(x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
x_grad.mul_(cfg.world_size)
# Backward backbone
features.backward(x_grad)
optimizer.step()
2. all_gather做了啥
dist.all_gather官方样例:
>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zero(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1
发现 dist.all_gather的输出值 tensor_list,是在all_gather计算之前 需要新初始化的一个list,虽然前向传播时 tensor–>tensor_list, 但是反向传播时,由于 tensor_list 是新初始化的叶子结点,所以并不能实现 tensor_list.grad–>tensor.grad 。
所以解决这个问题,需要将该部分的反向传播搭建起来。
3. 解决方案
(1)第一种方法就是像PFC源码里写的那样,反向传播时分段处理:
# 1. 直接loss反向传播
loss.backward()
...
# 2. 在all_gather部分,对梯度进行衔接
if tensor_list.grad is not None:tensor_list.grad.detach_()
x_grad = torch.zeros_like(tensor)# 将梯度对应分配到各个GPU上
# Feature gradient all-reduce
dist.reduce_scatter(x_grad, list(tensor_list.grad.chunk(cfg.world_size, dim=0)))
x_grad.mul_(cfg.world_size)# 3. 剩余部分反向传播 Backward backbone
tensor.backward(x_grad)
# 梯度更新
optimizer.step()
(2)自定义torch.autograd.function类
对于这些不可自动求导的操作,pytorch给出了扩展 torch.autograd.function 来实现自定义求导方式,pytorch文档里也给出了使用样例:
class Exp(Function):# 定义一些前向骚操作@staticmethoddef forward(ctx, i):result = i.exp()ctx.save_for_backward(result)return result# 前向操作太骚,只好自己写反传啦~。~@staticmethoddef backward(ctx, grad_output):result, = ctx.saved_tensorsreturn grad_output * result# 应用时就可以这么搞拉
#Use it by calling the apply method:
output = Exp.apply(input)
所以前面那个问题就可以解决啦:
class BwFunction(Function):@staticmethoddef forward(ctx, x):world_size = dist.get_world_size()total_features = torch.zeros(x.size()[0]*world_size, x.size()[1], device=x.device) dist.all_gather(list(total_features.chunk(world_size, dim=0)), x.data) total_features.requires_grad = Truereturn total_features@staticmethoddef backward(ctx, grad_output): world_size = dist.get_world_size()grad_x = Noneif grad_output is not None:grad_output.detach_()x_grad = torch.zeros_like(x)# Feature gradient all-reducedist.reduce_scatter(x_grad, list(grad_output.chunk(world_size, dim=0)))x_grad.div_(world_size)grad_x = x_gradreturn grad_x
拖了这么久,好不容易写完,撒花!
参考:
- https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc
- https://discuss.pytorch.org/t/will-dist-all-gather-break-the-auto-gradient-graph/47350
- https://pytorch.org/docs/stable/autograd.html?highlight=autograd#module-torch.autograd
- https://blog.csdn.net/Hungryof/article/details/78346304
使用torch.autograd.function解决dist.all_gather不能反向传播问题相关推荐
- Pytorch的自定义拓展:torch.nn.Module和torch.autograd.Function
参考链接:pytorch的自定义拓展之(一)--torch.nn.Module和torch.autograd.Function_LoveMIss-Y的博客-CSDN博客_pytorch自定义backw ...
- RuntimeError: Legacy autograd function with non-static forward method is deprecated.
显卡RTX3080 cuda 11.1 cudnn 8.0.5 python3.6.4 在使用pytorch1.0训练densefusion模型时报错,改用pytorch1.7,然后报上面的错误 Tr ...
- 自定义autograd function
在TSN代码中, segmentconsensus是一个自定义函数, 所以要写一下它对应的梯度运算 # tj : https://blog.csdn.net/tsq292978891/article/ ...
- Pytorch使用autograd.Function自定义拓展神经网络
我们知道CNN这类人工神经网络都基于BP算法进行优化,因此需要误差关于权重是连续可导的,这是可以运用BP算法的前提条件:也有一些网络不满足这个条件. 1.可导 对于可连续求导的神经网络构建时采用nn. ...
- Pytorch版本过高产生的RuntimeError: Legacy autograd function with non-static forward method is deprecated.
前言 在尝试用ECO Lite做视频分类的时候,使用了作者的Pytorch实现,然而Pytorch实现是基于Pytorch0.4的,我自己的Pytorch版本是1.4,所以在跑模型的时候出现了一些问题 ...
- 【Pytorch】反向传播为NaN报错的排查解决方法,RuntimeError: Function ‘BmmBackward0‘ returned nan values
最近在训练模型的过程中,反复出现方向传播至为NaN的报错,报错信息如下所示: File "/home/fu/anaconda3/envs/torch/lib/python3.7/site-p ...
- Please use new-style autograd function with static forward method
Please use new-style autograd function with static forward method Legacy autograd function with non- ...
- Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd
Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd ...
- pytorch 笔记: 扩展torch.autograd
1 扩展torch.autograd 向 autograd 添加操作需要为每个操作实现一个新的 Function 子类. 回想一下,函数是 autograd 用来编码操作历史和计算梯度的东西. 2 何 ...
最新文章
- java控制语句练习题_[Java初探实例篇02]__流程控制语句知识相关的实例练习
- php 3d animation,css3D+动画的例子(附完整代码)
- 使用swipemenulistview实现列表的左右滑动
- VS开发C#窗体应用时怎样设置窗体属性
- linux nohup screen注解
- boost::ratio_less_equal相关的测试程序
- java斐波那切数列_Java中的递归方法
- 核心交换机相对于普通交换机的优势
- 十六、Python操作excel(.xlsx)封装类MyPyExce
- PDF 合并软件要收费?程序员自己做一个
- 命令发送广播_那些你不知道的ping命令参数
- fft qt 代码_最简洁的FFT代码(C++实现)
- qq邮箱html模板_用了这么多简历模板,发现只有QQ邮箱自带的模板最好用
- kdchxue讲解V9父栏目调用子栏目的办法
- GIMP教程 3 扭曲变换工具 (瘦脸 瘦腿)
- 容器监控cadvisor
- python itchat_Python itchat模块在微信上的
- 物联网安全硬件修改系列-硬改
- SAP云平台里的三叉戟应用
- 计算机和网络之间有个感叹号,网络有个感叹号!电脑无线网络连接不上的几种常见问题...