Author:
Robin Brügger, CV Lab, ETH Zürich
Code: https://github.com/RobinBruegger/RevTorch
https://github.com/RobinBruegger/PartiallyReversibleUnet


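Medical imaging commonly relies on 3D networks, and GPU memory consumption often limits network architecture and depth, which in turn hurts final accuracy. The paper borrows the idea of the reversible block to address this problem.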

Reversible block

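The block is cleverly designed. The input x is first split along the channel dimension into two halves, x1 and x2. Equation (1) produces the outputs y1 and y2; thanks to the special structure, x1 and x2 can in turn be computed from y1 and y2 with Equation (2):

$$ y_1 = x_1 + F(x_2), \qquad y_2 = x_2 + G(y_1) \tag{1} $$

$$ x_2 = y_2 - G(y_1), \qquad x_1 = y_1 - F(x_2) \tag{2} $$

F and G (f_block and g_block in the code below) are arbitrary subnetworks whose output shape equals their input shape; they do not need to be invertible themselves.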
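To see Equations (1) and (2) at work on volumetric data, here is a minimal PyTorch sketch. It assumes a 5D tensor of shape (N, C, D, H, W) and uses hypothetical F and G built from a 3×3×3 Conv3d plus LeakyReLU; this is only an illustrative choice of shape-preserving subnetwork, not necessarily the one used in the paper. The check confirms that the input is recovered from the output alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical F and G: any shape-preserving subnetworks would do.
# Each maps C/2 channels to C/2 channels; padding=1 keeps D, H, W unchanged.
half = 4
f_block = nn.Sequential(nn.Conv3d(half, half, kernel_size=3, padding=1), nn.LeakyReLU()).double()
g_block = nn.Sequential(nn.Conv3d(half, half, kernel_size=3, padding=1), nn.LeakyReLU()).double()


def coupling_forward(x):
    # Eq. (1): split channels, then y1 = x1 + F(x2), y2 = x2 + G(y1)
    x1, x2 = torch.chunk(x, 2, dim=1)
    y1 = x1 + f_block(x2)
    y2 = x2 + g_block(y1)
    return torch.cat([y1, y2], dim=1)


def coupling_inverse(y):
    # Eq. (2): undo the two additions in reverse order
    y1, y2 = torch.chunk(y, 2, dim=1)
    x2 = y2 - g_block(y1)
    x1 = y1 - f_block(x2)
    return torch.cat([x1, x2], dim=1)


with torch.no_grad():
    x = torch.randn(1, 2 * half, 8, 8, 8, dtype=torch.float64)  # (N, C, D, H, W)
    y = coupling_forward(x)
    x_rec = coupling_inverse(y)

print(torch.allclose(x, x_rec))  # True: input recovered up to float64 rounding
```

Note that F and G only ever run in the forward direction; the inverse just subtracts their outputs, which is why they can be ordinary (non-invertible) convolutional subnetworks.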

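A large part of the GPU memory used during training goes to storing the intermediate results of the forward pass, because the backward pass needs them. With reversible blocks these intermediate results do not have to be stored: only the final output is kept, and every intermediate result can be recomputed backwards from it.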
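A back-of-the-envelope sketch of why this matters for 3D volumes. The layer sizes are hypothetical (16 channels on a 128×128×128 float32 feature map, a chain of 4 shape-preserving blocks), not figures from the paper, and the model is deliberately simplified: it only counts block inputs/outputs and ignores the intermediates inside F and G (in the reversible case those are needed for just one block at a time during recomputation).

```python
def feature_map_mib(channels, depth, height, width, bytes_per_elem=4):
    """Size in MiB of one feature map of shape (channels, depth, height, width)."""
    return channels * depth * height * width * bytes_per_elem / 2**20


def stored_activation_mib(n_blocks, fmap_mib, reversible):
    """Rough activation memory kept for backprop across a chain of n_blocks
    shape-preserving blocks (simplified: block inputs/outputs only)."""
    if reversible:
        # Only the final output is kept; earlier activations are recomputed from it.
        return fmap_mib
    # Standard backprop keeps the chain input plus every block's output.
    return (n_blocks + 1) * fmap_mib


fmap = feature_map_mib(16, 128, 128, 128)
print(fmap)                                   # 128.0
print(stored_activation_mib(4, fmap, False))  # 640.0
print(stored_activation_mib(4, fmap, True))   # 128.0
```

Under these assumptions the stored activations stop growing with the number of blocks, which is what frees memory for deeper networks or full-volume inputs.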

Method

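The paper builds on the architecture of No-New-Net, the second-place entry in the MICCAI BraTS 2018 challenge, and introduces reversible blocks into it; the resulting network architecture is shown in the original paper.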

Results

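The results are strong. Comparing the first two rows of the paper's results table, using reversible blocks saves 2.5 GB of GPU memory, which makes full-volume training feasible on a 12 GB GPU. The single model also outperforms the No-New-Net single model.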

Code

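The code of the reversible block module is shown below. It took me some time to roughly understand the backward pass. A call such as f.backward(dy) applies the chain rule: autograd multiplies the local gradient of f by the gradient dy that has already been propagated back from the later layers (a vector-Jacobian product) and accumulates the result into the .grad of f's inputs and parameters.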

import torch
import torch.nn as nn


class ReversibleBlock(nn.Module):
    '''
    Elementary building block for building (partially) reversible architectures

    Implementation of the Reversible block described in the RevNet paper
    (https://arxiv.org/abs/1707.04585). Must be used inside a :class:`revtorch.ReversibleSequence`
    for autograd support.

    Arguments:
        f_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
        g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
    '''

    def __init__(self, f_block, g_block):
        super(ReversibleBlock, self).__init__()
        self.f_block = f_block
        self.g_block = g_block

    def forward(self, x):
        """
        Performs the forward pass of the reversible block. Does not record any gradients.
        :param x: Input tensor. Must be splittable along dimension 1.
        :return: Output tensor of the same shape as the input tensor
        """
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1, y2 = None, None
        with torch.no_grad():
            y1 = x1 + self.f_block(x2)
            y2 = x2 + self.g_block(y1)

        return torch.cat([y1, y2], dim=1)

    def backward_pass(self, y, dy):
        """
        Performs the backward pass of the reversible block.

        Calculates the derivatives of the block's parameters in f_block and g_block, as well as the inputs of the
        forward pass and its gradients.

        :param y: Outputs of the reversible block
        :param dy: Derivatives of the outputs
        :return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outputs.
        """

        # Split the arguments channel-wise
        y1, y2 = torch.chunk(y, 2, dim=1)
        del y
        assert (not y1.requires_grad), "y1 must already be detached"
        assert (not y2.requires_grad), "y2 must already be detached"
        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy
        assert (not dy1.requires_grad), "dy1 must not require grad"
        assert (not dy2.requires_grad), "dy2 must not require grad"

        # Enable autograd for y1 and y2. This ensures that PyTorch
        # keeps track of ops that use y1 and y2 as inputs in a DAG
        y1.requires_grad = True
        y2.requires_grad = True

        # Ensures that PyTorch tracks the operations in a DAG
        with torch.enable_grad():
            gy1 = self.g_block(y1)

            # Use autograd framework to differentiate the calculation. The
            # derivatives of the parameters of G are set as a side effect
            gy1.backward(dy2)

        with torch.no_grad():
            x2 = y2 - gy1  # Restore x2, the second half of the forward() input
            del y2, gy1

            # The gradient of x1 is the sum of the gradient of the output
            # y1 as well as the gradient that flows back through G
            # (The gradient that flows back through G is stored in y1.grad)
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f_block(x2)

            # Use autograd framework to differentiate the calculation. The
            # derivatives of the parameters of F are set as a side effect
            fx2.backward(dx1)

        with torch.no_grad():
            x1 = y1 - fx2  # Restore x1, the first half of the forward() input
            del y1, fx2

            # The gradient of x2 is the sum of the gradient of the output
            # y2 as well as the gradient that flows back through F
            # (The gradient that flows back through F is stored in x2.grad)
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            # Undo the channel-wise split
            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx
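The calls gy1.backward(dy2) and fx2.backward(dx1) in backward_pass are exactly the chain-rule step described above. A tiny standalone check with a linear map, whose Jacobian is just the matrix W (values chosen arbitrarily for illustration), shows that backward(dy) stores the vector-Jacobian product dy^T (df/dx) in x.grad rather than the Jacobian itself:

```python
import torch

# f(x) = W @ x, so the Jacobian df/dx is simply W.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
W = torch.tensor([[1.0, 0.0, 2.0],
                  [0.0, 3.0, 1.0]])
f = W @ x

# Upstream gradient dL/df, as if it came back from later layers.
dy = torch.tensor([0.5, -1.0])

# Computes dy^T (df/dx) and accumulates it into x.grad.
f.backward(dy)

print(x.grad)   # values 0.5, -3.0, 0.0
print(dy @ W)   # the same product computed by hand
```

In backward_pass, dy2 plays the role of dy for G and dx1 plays it for F, so y1.grad and x2.grad receive the gradient contributions that flow back through G and F respectively.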

My notes

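I think the idea of this paper is excellent. First, it targets a real pain point in medical image processing, namely GPU memory consumption: most researchers are memory-constrained, and 12 GB cards are the most common hardware. Second, it brings the reversible-block idea over from another field (RevNet, originally proposed for image classification) as a solution to this problem, and the experimental results are strong. The paper is a valuable source of inspiration for my own research.
Of course, the memory savings come at the cost of longer training time.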
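Where the extra training time comes from can be seen in a small experiment: the forward pass stores only the block output, so backward_pass must evaluate F and G a second time. The sketch below is a hypothetical single-block autograd wrapper modelled on the idea behind RevTorch's ReversibleSequence; RevCoupling and Counted are illustrative names, not RevTorch's API, and the tiny Linear+Tanh subnetworks are arbitrary. It counts how often F and G run and checks that the gradients match ordinary autograd.

```python
import torch
import torch.nn as nn


class Counted(nn.Module):
    """Hypothetical helper: wraps a module and counts its forward() calls."""

    def __init__(self, module):
        super().__init__()
        self.module = module
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return self.module(x)


class RevCoupling(torch.autograd.Function):
    """Hypothetical single-block reversible coupling (not RevTorch's API).
    forward() keeps only the output; backward() recomputes F and G from it."""

    @staticmethod
    def forward(ctx, x, f, g):
        ctx.f, ctx.g = f, g
        with torch.no_grad():
            x1, x2 = torch.chunk(x, 2, dim=1)
            y1 = x1 + f(x2)
            y2 = x2 + g(y1)
            y = torch.cat([y1, y2], dim=1)
        ctx.save_for_backward(y)  # only the block output is stored
        return y

    @staticmethod
    def backward(ctx, dy):
        (y,) = ctx.saved_tensors
        f, g = ctx.f, ctx.g
        y1, y2 = [t.detach() for t in torch.chunk(y, 2, dim=1)]
        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        with torch.enable_grad():
            y1.requires_grad_(True)
            gy1 = g(y1)          # second evaluation of G (recomputation)
            gy1.backward(dy2)    # G's parameter grads and y1.grad
        with torch.no_grad():
            x2 = y2 - gy1        # reconstruct x2 from the output
            dx1 = dy1 + y1.grad
        with torch.enable_grad():
            x2.requires_grad_(True)
            fx2 = f(x2)          # second evaluation of F (recomputation)
            fx2.backward(dx1)    # F's parameter grads and x2.grad
        dx2 = dy2 + x2.grad
        # x1 = y1 - fx2 would only be needed by a preceding reversible block.
        return torch.cat([dx1, dx2], dim=1), None, None


torch.manual_seed(0)
f = Counted(nn.Sequential(nn.Linear(3, 3), nn.Tanh())).double()
g = Counted(nn.Sequential(nn.Linear(3, 3), nn.Tanh())).double()
params = list(f.parameters()) + list(g.parameters())
x = torch.randn(5, 6, dtype=torch.float64, requires_grad=True)

# Reversible path: F and G each run once forward and once more in backward.
RevCoupling.apply(x, f, g).sum().backward()
rev_calls = (f.calls, g.calls)
rev_grads = [x.grad.clone()] + [p.grad.clone() for p in params]

# Reference: ordinary autograd through the same equations (keeps the graph).
x.grad = None
for p in params:
    p.grad = None
x1, x2 = torch.chunk(x, 2, dim=1)
y1 = x1 + f(x2)
y2 = x2 + g(y1)
torch.cat([y1, y2], dim=1).sum().backward()
ref_grads = [x.grad] + [p.grad for p in params]

grads_match = all(torch.allclose(a, b) for a, b in zip(rev_grads, ref_grads))
print(rev_calls, grads_match)  # (2, 2) True
```

The gradients are the same; the price is one extra evaluation of every F and G per training step, which is the source of the longer training time.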

[MICCAI2019] A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation