1. 问题来源

最近在用mmcv复现Partial FC模型,看到源码中,有单独写的前向反向传播,甚是疑惑~
源码:

 # Features all-gather
total_features = torch.zeros(features.size()[0] * cfg.world_size,cfg.embedding_size,device=local_rank)
dist.all_gather(list(total_features.chunk(cfg.world_size, dim=0)),    features.data)
total_features.requires_grad = True...
#计算 loss 以及 backward
...if total_features.grad is not None:total_features.grad.detach_()
x_grad = torch.zeros_like(features)# Feature gradient all-reduce
dist.reduce_scatter(x_grad, list(total_features.grad.chunk(cfg.world_size, dim=0)))
x_grad.mul_(cfg.world_size)
# Backward backbone
features.backward(x_grad)
optimizer.step()

2. all_gather做了啥

dist.all_gather官方样例:

>>> # All tensors below are of torch.int64 dtype.
>>> # We have 2 process groups, 2 ranks.
>>> tensor_list = [torch.zero(2, dtype=torch.int64) for _ in range(2)]
>>> tensor_list
[tensor([0, 0]), tensor([0, 0])] # Rank 0 and 1
>>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
>>> tensor
tensor([1, 2]) # Rank 0
tensor([3, 4]) # Rank 1
>>> dist.all_gather(tensor_list, tensor)
>>> tensor_list
[tensor([1, 2]), tensor([3, 4])] # Rank 0
[tensor([1, 2]), tensor([3, 4])] # Rank 1

发现 dist.all_gather的输出值 tensor_list,是在all_gather计算之前 需要新初始化的一个list,虽然前向传播时 tensor–>tensor_list, 但是反向传播时,由于 tensor_list 是新初始化的叶子结点,所以并不能实现 tensor_list.grad–>tensor.grad 。

所以解决这个问题,需要将该部分的反向传播搭建起来。

3. 解决方案

(1)第一种方法就是像PFC源码里写的那样,反向传播时分段处理:

# 1. 直接loss反向传播
loss.backward()
...
# 2. 在all_gather部分,对梯度进行衔接
if tensor_list.grad is not None:tensor_list.grad.detach_()
x_grad = torch.zeros_like(tensor)# 将梯度对应分配到各个GPU上
# Feature gradient all-reduce
dist.reduce_scatter(x_grad, list(tensor_list.grad.chunk(cfg.world_size, dim=0)))
x_grad.mul_(cfg.world_size)# 3. 剩余部分反向传播 Backward backbone
tensor.backward(x_grad)
# 梯度更新
optimizer.step()

(2)自定义torch.autograd.function类
对于这些不可自动求导的操作,pytorch给出了扩展 torch.autograd.function 来实现自定义求导方式,pytorch文档里也给出了使用样例:

class Exp(Function):# 定义一些前向骚操作@staticmethoddef forward(ctx, i):result = i.exp()ctx.save_for_backward(result)return result# 前向操作太骚,只好自己写反传啦~。~@staticmethoddef backward(ctx, grad_output):result, = ctx.saved_tensorsreturn grad_output * result# 应用时就可以这么搞拉
#Use it by calling the apply method:
output = Exp.apply(input)

所以前面那个问题就可以解决啦:

class BwFunction(Function):@staticmethoddef forward(ctx, x):world_size = dist.get_world_size()total_features = torch.zeros(x.size()[0]*world_size, x.size()[1], device=x.device)       dist.all_gather(list(total_features.chunk(world_size, dim=0)), x.data)  total_features.requires_grad = Truereturn total_features@staticmethoddef backward(ctx, grad_output): world_size = dist.get_world_size()grad_x = Noneif grad_output is not None:grad_output.detach_()x_grad = torch.zeros_like(x)# Feature gradient all-reducedist.reduce_scatter(x_grad, list(grad_output.chunk(world_size, dim=0)))x_grad.div_(world_size)grad_x = x_gradreturn grad_x

拖了这么久,好不容易写完,撒花!

参考:

  1. https://github.com/deepinsight/insightface/tree/master/recognition/partial_fc
  2. https://discuss.pytorch.org/t/will-dist-all-gather-break-the-auto-gradient-graph/47350
  3. https://pytorch.org/docs/stable/autograd.html?highlight=autograd#module-torch.autograd
  4. https://blog.csdn.net/Hungryof/article/details/78346304

使用torch.autograd.function解决dist.all_gather不能反向传播问题相关推荐

  1. Pytorch的自定义拓展:torch.nn.Module和torch.autograd.Function

    参考链接:pytorch的自定义拓展之(一)--torch.nn.Module和torch.autograd.Function_LoveMIss-Y的博客-CSDN博客_pytorch自定义backw ...

  2. RuntimeError: Legacy autograd function with non-static forward method is deprecated.

    显卡RTX3080 cuda 11.1 cudnn 8.0.5 python3.6.4 在使用pytorch1.0训练densefusion模型时报错,改用pytorch1.7,然后报上面的错误 Tr ...

  3. 自定义autograd function

    在TSN代码中, segmentconsensus是一个自定义函数, 所以要写一下它对应的梯度运算 # tj : https://blog.csdn.net/tsq292978891/article/ ...

  4. Pytorch使用autograd.Function自定义拓展神经网络

    我们知道CNN这类人工神经网络都基于BP算法进行优化,因此需要误差关于权重是连续可导的,这是可以运用BP算法的前提条件:也有一些网络不满足这个条件. 1.可导 对于可连续求导的神经网络构建时采用nn. ...

  5. Pytorch版本过高产生的RuntimeError: Legacy autograd function with non-static forward method is deprecated.

    前言 在尝试用ECO Lite做视频分类的时候,使用了作者的Pytorch实现,然而Pytorch实现是基于Pytorch0.4的,我自己的Pytorch版本是1.4,所以在跑模型的时候出现了一些问题 ...

  6. 【Pytorch】反向传播为NaN报错的排查解决方法,RuntimeError: Function ‘BmmBackward0‘ returned nan values

    最近在训练模型的过程中,反复出现方向传播至为NaN的报错,报错信息如下所示: File "/home/fu/anaconda3/envs/torch/lib/python3.7/site-p ...

  7. Please use new-style autograd function with static forward method

    Please use new-style autograd function with static forward method Legacy autograd function with non- ...

  8. Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd

    Legacy autograd function with non-static forward method is deprecated. Please use new-style autograd ...

  9. pytorch 笔记: 扩展torch.autograd

    1 扩展torch.autograd 向 autograd 添加操作需要为每个操作实现一个新的 Function 子类. 回想一下,函数是 autograd 用来编码操作历史和计算梯度的东西. 2 何 ...

最新文章

  1. java控制语句练习题_[Java初探实例篇02]__流程控制语句知识相关的实例练习
  2. php 3d animation,css3D+动画的例子(附完整代码)
  3. 使用swipemenulistview实现列表的左右滑动
  4. VS开发C#窗体应用时怎样设置窗体属性
  5. linux nohup screen注解
  6. boost::ratio_less_equal相关的测试程序
  7. java斐波那切数列_Java中的递归方法
  8. 核心交换机相对于普通交换机的优势
  9. 十六、Python操作excel(.xlsx)封装类MyPyExce
  10. PDF 合并软件要收费?程序员自己做一个
  11. 命令发送广播_那些你不知道的ping命令参数
  12. fft qt 代码_最简洁的FFT代码(C++实现)
  13. qq邮箱html模板_用了这么多简历模板,发现只有QQ邮箱自带的模板最好用
  14. kdchxue讲解V9父栏目调用子栏目的办法
  15. GIMP教程 3 扭曲变换工具 (瘦脸 瘦腿)
  16. 容器监控cadvisor
  17. python itchat_Python itchat模块在微信上的
  18. 物联网安全硬件修改系列-硬改
  19. SAP云平台里的三叉戟应用
  20. 计算机和网络之间有个感叹号,网络有个感叹号!电脑无线网络连接不上的几种常见问题...

热门文章

  1. 字节跳动2020春招后端开发工程师笔试复盘
  2. ABAP bgRFC 实例
  3. 转--Python标准库之一句话概括
  4. 华为这么牛?21级程序员月薪看哭众人!网友直呼:我们不一样
  5. MVC、MVVM、MVP
  6. matlab 地形模拟程序,MATLAB模拟小球自由落体运动
  7. 苹果mac电脑重装系统,以及重装之后没有声音、热键不能使用的解决办法
  8. 数论概论笔记(二)勾股数组
  9. 【面经】华为车BU面经
  10. Ubuntu 电脑下插入移动硬盘,显示不能挂载该硬盘