[MICCAI2019] A Partially Reversible U-Net for Memory-Efficient Volumetric Image Segmentation
Author information:
Robin Brügger, CV Lab, ETH Zürich
Code: https://github.com/RobinBruegger/RevTorch
https://github.com/RobinBruegger/PartiallyReversibleUnet
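3D networks are common in medical imaging, and GPU memory consumption often limits the network architecture and depth, which in turn affects the final accuracy. The paper mainly borrows the idea of the reversible block to address this problem.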
Reversible block
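The block is cleverly designed. The input x is first split channel-wise into two halves, x1 and x2. Equation (1) maps them to y1 and y2; because of the coupling structure, x1 and x2 can in turn be recovered from y1 and y2 with Equation (2). Here F and G are arbitrary subnetworks whose output shape equals their input shape; they do not need to be invertible themselves.

$$y_1 = x_1 + \mathcal{F}(x_2), \qquad y_2 = x_2 + \mathcal{G}(y_1) \tag{1}$$

$$x_2 = y_2 - \mathcal{G}(y_1), \qquad x_1 = y_1 - \mathcal{F}(x_2) \tag{2}$$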
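A minimal sketch of Equations (1) and (2), assuming toy element-wise functions in place of the F and G subnetworks (f_block and g_block here are hypothetical plain Python functions, not nn.Module objects). It uses plain lists so it runs without PyTorch, and shows that the inputs are recovered from the outputs alone even though neither F nor G is invertible on its own.

```python
import math

# Toy stand-ins for the F and G subnetworks (hypothetical). Neither is invertible:
# squaring loses the sign, and sin is periodic.
def f_block(v):
    return [a * a for a in v]

def g_block(v):
    return [math.sin(3.0 * a) for a in v]

def rev_forward(x1, x2):
    """Equation (1): y1 = x1 + F(x2), y2 = x2 + G(y1)."""
    y1 = [a + b for a, b in zip(x1, f_block(x2))]
    y2 = [a + b for a, b in zip(x2, g_block(y1))]
    return y1, y2

def rev_inverse(y1, y2):
    """Equation (2): x2 = y2 - G(y1), then x1 = y1 - F(x2)."""
    x2 = [a - b for a, b in zip(y2, g_block(y1))]
    x1 = [a - b for a, b in zip(y1, f_block(x2))]
    return x1, x2

x = [0.3, -1.2, 2.0, 0.7]        # a 4-"channel" input
x1, x2 = x[:2], x[2:]            # channel-wise split into two halves
y1, y2 = rev_forward(x1, x2)
r1, r2 = rev_inverse(y1, y2)     # uses only y1 and y2
max_err = max(abs(a - b) for a, b in zip(x, r1 + r2))
print(f"max reconstruction error: {max_err:.2e}")
```

The reconstruction is exact up to floating-point rounding, because each step of Equation (2) subtracts exactly the quantity that Equation (1) added.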
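During training, a large share of GPU memory goes to storing the intermediate results of the forward pass, because backpropagation needs them. With reversible blocks these intermediate results do not have to be stored: only the final output is kept, and every intermediate result can be recomputed backwards from it with Equation (2).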
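A hypothetical sketch of the "keep only the final output" idea for a chain of reversible blocks: the forward loop overwrites every intermediate activation, and the backward sweep rebuilds each block's input from its output, last block first. The toy F and G (make_block) and the block count are illustrative assumptions, not the paper's configuration.

```python
import math

def make_block(scale):
    """Hypothetical F and G for one block; `scale` just makes each block different."""
    def f(v):
        return [math.tanh(scale * a) for a in v]
    def g(v):
        return [scale * a * a for a in v]
    return f, g

def block_forward(f, g, x1, x2):
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2

def block_inverse(f, g, y1, y2):
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1, x2

blocks = [make_block(s) for s in (0.5, 1.0, 1.5)]
x1, x2 = [0.2, -0.4], [0.1, 0.3]

# Forward pass: intermediate activations are overwritten; only the final output is kept.
h1, h2 = x1, x2
for f, g in blocks:
    h1, h2 = block_forward(f, g, h1, h2)
final_output = (h1, h2)

# Backward sweep: walk the blocks in reverse and rebuild each block's input from its output.
# In real training, each rebuilt input would be used right away to backpropagate through
# that block and could then be discarded.
rebuilt_inputs = []
for f, g in reversed(blocks):
    h1, h2 = block_inverse(f, g, h1, h2)
    rebuilt_inputs.append((h1, h2))
```

Memory for activations therefore no longer grows with the number of reversible blocks; only one block's worth of tensors is needed at a time during the backward sweep.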
Method
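The paper builds on No-New-Net, the second-place entry in the MICCAI BraTS 2018 challenge. The network structure after introducing reversible blocks is shown in the paper's architecture figure.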
Results
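The results are strong. Comparing the first and second rows of the results table shows that reversible blocks save 2.5 GB of GPU memory, which makes full-volume training feasible on a 12 GB GPU. The model also outperforms No-New-Net when comparing single models.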
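A back-of-the-envelope sketch of why full-volume training is so memory-hungry: even one channel of a full-resolution 3D feature map is large, and a conventional network stores one such map per layer for backpropagation. It assumes float32, batch size 1, and an uncropped 240 × 240 × 155 BraTS volume; the channel counts are illustrative, not the paper's actual configuration.

```python
def feature_map_mib(channels, depth, height, width, bytes_per_elem=4, batch=1):
    """Size of one stored activation tensor in MiB (float32 by default)."""
    return batch * channels * depth * height * width * bytes_per_elem / 2**20

# Assumption: an uncropped 240 x 240 x 155 volume; channel counts are hypothetical.
one_channel = feature_map_mib(1, 155, 240, 240)
sixteen_channels = feature_map_mib(16, 155, 240, 240)
print(f"{one_channel:.2f} MiB per channel, {sixteen_channels:.2f} MiB for 16 channels")
# prints: 34.06 MiB per channel, 544.92 MiB for 16 channels
```

At roughly half a GiB per 16-channel full-resolution activation, storing a few dozen of them already approaches a 12 GB budget, which is why avoiding stored activations matters at full volume.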
Code
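Below is the code of the reversible block module. The backward pass took me some time to roughly understand. f.backward(dy) is the chain rule at work: the gradient of f is multiplied by the gradient dy propagated back from the later layers. More precisely, for a non-scalar output, f.backward(dy) computes the vector-Jacobian product (∂f/∂input)ᵀ · dy and accumulates it into the .grad of every input and parameter that requires gradients. See this reference.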
```python
import torch
import torch.nn as nn
# import torch.autograd.function as func


class ReversibleBlock(nn.Module):
    '''
    Elementary building block for building (partially) reversible architectures

    Implementation of the Reversible block described in the RevNet paper
    (https://arxiv.org/abs/1707.04585). Must be used inside a :class:`revtorch.ReversibleSequence`
    for autograd support.

    Arguments:
        f_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
        g_block (nn.Module): arbitrary subnetwork whose output shape is equal to its input shape
    '''

    def __init__(self, f_block, g_block):
        super(ReversibleBlock, self).__init__()
        self.f_block = f_block
        self.g_block = g_block

    def forward(self, x):
        """
        Performs the forward pass of the reversible block. Does not record any gradients.
        :param x: Input tensor. Must be splittable along dimension 1.
        :return: Output tensor of the same shape as the input tensor
        """
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1, y2 = None, None
        with torch.no_grad():
            y1 = x1 + self.f_block(x2)  # Eq. (1)
            y2 = x2 + self.g_block(y1)
        return torch.cat([y1, y2], dim=1)

    def backward_pass(self, y, dy):
        """
        Performs the backward pass of the reversible block.

        Calculates the derivatives of the block's parameters in f_block and g_block, as well as the inputs of the
        forward pass and its gradients.

        :param y: Outputs of the reversible block
        :param dy: Derivatives of the outputs
        :return: A tuple of (block input, block input derivatives). The block inputs are the same shape as the block outputs.
        """
        # Split the arguments channel-wise
        y1, y2 = torch.chunk(y, 2, dim=1)
        del y
        assert (not y1.requires_grad), "y1 must already be detached"
        assert (not y2.requires_grad), "y2 must already be detached"
        dy1, dy2 = torch.chunk(dy, 2, dim=1)
        del dy
        assert (not dy1.requires_grad), "dy1 must not require grad"
        assert (not dy2.requires_grad), "dy2 must not require grad"

        # Enable autograd for y1 and y2. This ensures that PyTorch
        # keeps track of ops that use y1 and y2 as inputs in a DAG
        y1.requires_grad = True
        y2.requires_grad = True

        # Ensures that PyTorch tracks the operations in a DAG
        with torch.enable_grad():
            gy1 = self.g_block(y1)
            # Use autograd framework to differentiate the calculation. The
            # derivatives of the parameters of G are set as a side effect
            gy1.backward(dy2)

        with torch.no_grad():
            x2 = y2 - gy1  # Restore x2 = y2 - G(y1), Eq. (2)
            del y2, gy1
            # The gradient of x1 is the sum of the gradient of the output
            # y1 as well as the gradient that flows back through G
            # (The gradient that flows back through G is stored in y1.grad)
            dx1 = dy1 + y1.grad
            del dy1
            y1.grad = None

        with torch.enable_grad():
            x2.requires_grad = True
            fx2 = self.f_block(x2)
            # Use autograd framework to differentiate the calculation. The
            # derivatives of the parameters of F are set as a side effect
            fx2.backward(dx1)

        with torch.no_grad():
            x1 = y1 - fx2  # Restore x1 = y1 - F(x2), Eq. (2)
            del y1, fx2
            # The gradient of x2 is the sum of the gradient of the output
            # y2 as well as the gradient that flows back through F
            # (The gradient that flows back through F is stored in x2.grad)
            dx2 = dy2 + x2.grad
            del dy2
            x2.grad = None

            # Undo the channel-wise split
            x = torch.cat([x1, x2.detach()], dim=1)
            dx = torch.cat([dx1, dx2], dim=1)

        return x, dx
```
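To check the gradient logic of backward_pass without PyTorch, here is a plain-Python sketch that mirrors its steps for toy element-wise F and G. For element-wise functions the Jacobian is diagonal, so each "backward(dy)" reduces to the upstream gradient times the element-wise derivative. W_F, W_G and the loss weights c1, c2 are hypothetical; parameter gradients are omitted, and the analytic input gradients are compared against central finite differences.

```python
import math

W_F, W_G = 0.8, 0.5  # hypothetical scalar weights of the toy F and G

def f(v):
    return [W_F * math.tanh(a) for a in v]

def df(v):
    # element-wise derivative of f, i.e. the diagonal of its Jacobian
    return [W_F * (1.0 - math.tanh(a) ** 2) for a in v]

def g(v):
    return [W_G * a * a for a in v]

def dg(v):
    # element-wise derivative of g
    return [2.0 * W_G * a for a in v]

def forward(x1, x2):
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2

def backward_pass(y1, y2, dy1, dy2):
    """Rebuild (x1, x2) and (dx1, dx2) from the outputs and output gradients only."""
    # Counterpart of gy1.backward(dy2): vector-Jacobian product of G at y1.
    y1_grad = [d * s for d, s in zip(dy2, dg(y1))]
    x2 = [a - b for a, b in zip(y2, g(y1))]        # x2 = y2 - G(y1)
    dx1 = [a + b for a, b in zip(dy1, y1_grad)]     # dx1 = dy1 + y1.grad
    # Counterpart of fx2.backward(dx1): vector-Jacobian product of F at x2.
    x2_grad = [d * s for d, s in zip(dx1, df(x2))]
    x1 = [a - b for a, b in zip(y1, f(x2))]         # x1 = y1 - F(x2)
    dx2 = [a + b for a, b in zip(dy2, x2_grad)]     # dx2 = dy2 + x2.grad
    return (x1, x2), (dx1, dx2)

# Toy loss L = sum(c1 * y1) + sum(c2 * y2), so dL/dy1 = c1 and dL/dy2 = c2.
c1, c2 = [1.0, -2.0], [0.5, 3.0]

def loss(x):
    y1, y2 = forward(x[:2], x[2:])
    return sum(a * b for a, b in zip(c1, y1)) + sum(a * b for a, b in zip(c2, y2))

x = [0.3, -0.7, 1.1, -0.4]
y1, y2 = forward(x[:2], x[2:])
(rx1, rx2), (dx1, dx2) = backward_pass(y1, y2, c1, c2)
analytic = dx1 + dx2  # list concatenation: gradient w.r.t. all four input channels

# Independent check with central finite differences.
eps = 1e-6
numeric = []
for i in range(len(x)):
    xp, xm = list(x), list(x)
    xp[i] += eps
    xm[i] -= eps
    numeric.append((loss(xp) - loss(xm)) / (2 * eps))

print(max(abs(a - b) for a, b in zip(analytic, numeric)))
```

The two sums dx1 = dy1 + y1.grad and dx2 = dy2 + x2.grad are exactly the "direct path plus path through the coupling function" terms of the chain rule, which the finite-difference check confirms for this toy case.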
My notes
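I think the idea of this paper is excellent. First, it targets a real pain point in medical image processing, namely GPU memory consumption: most researchers are memory-constrained, and 12 GB is the most common GPU setup. Second, it brings the reversible block over from another field as a solution to this problem, and the final experimental results are strong. The paper is a good source of inspiration for my own research.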
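Of course, the memory savings are paid for with longer training time, because the inputs of each reversible block, and the F and G activations needed for its gradients, are recomputed during the backward pass.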
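A rough way to see where the extra time goes is to count how often F and G run. In this hypothetical sketch (toy F and G shared by 4 blocks), the forward pass evaluates each once per block, and rebuilding the inputs during the backward pass evaluates each once more per block; a conventional network that stores its activations would skip that second set of evaluations. This counts only the recomputation, not the gradient computation itself, so it is a proxy rather than a runtime measurement.

```python
import math
from collections import Counter

calls = Counter()  # how many times each subnetwork has been evaluated

def f(v):
    calls["F"] += 1
    return [math.tanh(a) for a in v]

def g(v):
    calls["G"] += 1
    return [0.5 * a for a in v]

def block_forward(x1, x2):
    y1 = [a + b for a, b in zip(x1, f(x2))]
    y2 = [a + b for a, b in zip(x2, g(y1))]
    return y1, y2

def block_inverse(y1, y2):
    x2 = [a - b for a, b in zip(y2, g(y1))]
    x1 = [a - b for a, b in zip(y1, f(x2))]
    return x1, x2

n_blocks = 4
h = ([0.1, 0.2], [0.3, 0.4])
for _ in range(n_blocks):
    h = block_forward(*h)
forward_calls = dict(calls)   # evaluations needed by the forward pass alone

for _ in range(n_blocks):
    h = block_inverse(*h)     # recomputation done during the backward pass
total_calls = dict(calls)

print(forward_calls, total_calls)
# prints: {'F': 4, 'G': 4} {'F': 8, 'G': 8}
```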