PyTorch TTA（测试时增强）源码阅读

1. ttach/wrappers.py

TTA主要调用的接口

这些包装类都继承自 PyTorch 的 nn.Module。

import torch
import torch.nn as nn
# 做类型注解的库
# 参考 https://www.bilibili.com/read/cv3249320/
from typing import Optional, Mapping, Union, Tuple
from .base import Merger, Compose


def _merge_tta_predictions(model, image, args, transforms, merge_mode, output_key, deaugment):
    """Run ``model`` on every augmented view of ``image`` and merge the de-augmented outputs.

    Shared loop of all TTA wrappers below.

    Args:
        model: callable invoked as ``model(augmented_image, *args)``.
        image: input batch handed to each ``transformer.augment_image``.
        args: extra positional arguments forwarded unchanged to ``model``.
        transforms: ``Compose`` instance; iterating it yields one transformer per
            augmentation, and ``len(transforms)`` is the number of augmentations.
        merge_mode: merge strategy forwarded to ``Merger``.
        output_key: if not None, each model output is indexed with this key first.
        deaugment: callable ``(transformer, output) -> prediction`` mapping one
            prediction back to the un-augmented frame.

    Returns:
        The merged prediction (``Merger.result``).
    """
    merger = Merger(type=merge_mode, n=len(transforms))
    for transformer in transforms:
        augmented_image = transformer.augment_image(image)
        augmented_output = model(augmented_image, *args)
        if output_key is not None:
            augmented_output = augmented_output[output_key]
        merger.append(deaugment(transformer, augmented_output))
    return merger.result


class SegmentationTTAWrapper(nn.Module):
    """Wrap PyTorch nn.Module (segmentation model) with test time augmentation transforms

    Args:
        model (torch.nn.Module): segmentation model with single input and single output
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_mask_key (str): if model output is `dict`, specify which key belong to `mask`
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_mask_key: Optional[str] = None,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_mask_key

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Predict a mask for every augmented view, undo each augmentation and merge.

        Each predicted mask is mapped back with ``transformer.deaugment_mask``
        before merging. If ``output_mask_key`` was given, the merged mask is
        returned as ``{output_mask_key: mask}``.
        """
        result = _merge_tta_predictions(
            self.model,
            image,
            args,
            self.transforms,
            self.merge_mode,
            self.output_key,
            lambda transformer, output: transformer.deaugment_mask(output),
        )
        if self.output_key is not None:
            result = {self.output_key: result}
        return result


class ClassificationTTAWrapper(nn.Module):
    """Wrap PyTorch nn.Module (classification model) with test time augmentation transforms

    Args:
        model (torch.nn.Module): classification model with single input and single output
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_label_key (str): if model output is `dict`, specify which key belong to `label`
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_label_key: Optional[str] = None,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_label_key

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Predict labels for every augmented view, undo each augmentation and merge.

        Each prediction is mapped back with ``transformer.deaugment_label``
        before merging. If ``output_label_key`` was given, the merged result is
        returned as ``{output_label_key: labels}``.
        """
        result = _merge_tta_predictions(
            self.model,
            image,
            args,
            self.transforms,
            self.merge_mode,
            self.output_key,
            lambda transformer, output: transformer.deaugment_label(output),
        )
        if self.output_key is not None:
            result = {self.output_key: result}
        return result


class KeypointsTTAWrapper(nn.Module):
    """Wrap PyTorch nn.Module (keypoints model) with test time augmentation transforms

    Args:
        model (torch.nn.Module): keypoints model with single input and single output
            in format [x1,y1, x2, y2, ..., xn, yn]
            (.forward(x) should return either torch.Tensor or Mapping[str, torch.Tensor])
        transforms (ttach.Compose): composition of test time transforms
        merge_mode (str): method to merge augmented predictions mean/gmean/max/min/sum/tsharpen
        output_keypoints_key (str): if model output is `dict`, specify which key belong to `keypoints`
        scaled (bool): True if model return x, y scaled values in [0, 1], else False
    """

    def __init__(
        self,
        model: nn.Module,
        transforms: Compose,
        merge_mode: str = "mean",
        output_keypoints_key: Optional[str] = None,
        scaled: bool = False,
    ):
        super().__init__()
        self.model = model
        self.transforms = transforms
        self.merge_mode = merge_mode
        self.output_key = output_keypoints_key
        self.scaled = scaled

    def forward(
        self, image: torch.Tensor, *args
    ) -> Union[torch.Tensor, Mapping[str, torch.Tensor]]:
        """Predict keypoints for every augmented view, undo each augmentation and merge.

        Predictions are regrouped into ``(batch, -1, 2)`` (x, y) pairs and, unless
        ``scaled`` is True, normalised by the image width/height before
        ``transformer.deaugment_keypoints`` is applied. After merging they are
        rescaled to pixels and flattened back to ``(batch, -1)``.
        """
        size = image.size()
        # Indices 0, 2, 3 are read as batch, height, width (NCHW layout assumed).
        batch_size, image_height, image_width = size[0], size[2], size[3]

        def deaugment(transformer, output):
            output = output.reshape(batch_size, -1, 2)
            if not self.scaled:
                # Divide out of place: reshape() may return a view of the model's
                # output, so an in-place ``/=`` would silently modify that tensor
                # (and raises for integer-typed outputs).
                output = output / output.new_tensor([image_width, image_height])
            return transformer.deaugment_keypoints(output)

        result = _merge_tta_predictions(
            self.model,
            image,
            args,
            self.transforms,
            self.merge_mode,
            self.output_key,
            deaugment,
        )
        if not self.scaled:
            # Rescale a copy so we never write into a tensor shared with an
            # intermediate prediction.
            result = result.clone()
            result[..., 0] *= image_width
            result[..., 1] *= image_height
        result = result.reshape(batch_size, -1)
        if self.output_key is not None:
            result = {self.output_key: result}
        return result

1.1 Merger

class Merger:
    """Accumulate predictions from several augmented views and merge them.

    Args:
        type: merge strategy -- 'mean', 'gmean', 'sum', 'max', 'min' or
            'tsharpen' (square-root each prediction, then average).
        n: number of predictions that will be appended; 'mean', 'tsharpen'
            and 'gmean' use it to normalise the accumulated value.

    Raises:
        ValueError: if ``type`` is not one of the supported strategies.
    """

    def __init__(
            self,
            type: str = 'mean',
            n: int = 1,
    ):
        if type not in ['mean', 'gmean', 'sum', 'max', 'min', 'tsharpen']:
            raise ValueError('Not correct merge type `{}`.'.format(type))

        self.output = None  # running accumulator; None until the first append
        self.type = type
        self.n = n

    def append(self, x):
        """Fold one prediction ``x`` into the running accumulator."""
        if self.type == 'tsharpen':
            x = x ** 0.5

        if self.output is None:
            # First prediction: nothing to combine with yet.
            self.output = x
        elif self.type in ['mean', 'sum', 'tsharpen']:
            self.output = self.output + x
        elif self.type == 'gmean':
            # Running product; the n-th root is taken in ``result``.
            self.output = self.output * x
        elif self.type in ['max', 'min']:
            # Element-wise extremum via torch's two-tensor form, so this block
            # does not depend on a separately imported ``F`` helper module.
            import torch

            pick = torch.max if self.type == 'max' else torch.min
            self.output = pick(self.output, x)

    @property
    def result(self):
        """Merged prediction, computed from the accumulator on each access.

        'sum'/'max'/'min' return the accumulator as is; 'mean' and 'tsharpen'
        divide it by ``n``; 'gmean' takes its ``n``-th root.
        """
        if self.type in ['sum', 'max', 'min']:
            result = self.output
        elif self.type in ['mean', 'tsharpen']:
            result = self.output / self.n
        elif self.type in ['gmean']:
            result = self.output ** (1 / self.n)
        else:
            raise ValueError('Not correct merge type `{}`.'.format(self.type))
        return result

1.2 Compose

class Compose:
    """Every combination of parameter values across a list of test-time transforms.

    ``itertools.product`` over the transforms' ``params`` gives the combinations
    (one value per transform). For example, three transforms with three values
    each give 3 * 3 * 3 = 27 combinations, not 3. Iterating yields one
    ``Transformer`` per combination:

    - its image pipeline calls ``apply_aug_image`` of each transform in the
      given order;
    - its mask/label/keypoints pipelines call the ``apply_deaug_*`` methods in
      reverse order, with the parameter tuple reversed to match.

    Args:
        transforms: sequence of transforms; each must expose ``params``,
            ``pname`` and the ``apply_aug_image`` / ``apply_deaug_*`` methods.
    """

    def __init__(
        self,
        transforms: List[BaseTransform],
    ):
        self.aug_transforms = transforms
        value_lists = [transform.params for transform in transforms]
        self.aug_transform_parameters = list(itertools.product(*value_lists))
        # Undoing a sequence of transforms means undoing the last one first.
        self.deaug_transforms = transforms[::-1]
        self.deaug_transform_parameters = [
            combo[::-1] for combo in self.aug_transform_parameters
        ]

    @staticmethod
    def _bind(method_name, transforms, values):
        # Chain of `method_name` from each transform, its parameter value pre-bound by keyword.
        return Chain([
            partial(getattr(transform, method_name), **{transform.pname: value})
            for transform, value in zip(transforms, values)
        ])

    # NOTE(review): this is a generator, so it really returns an iterator of
    # Transformer objects; the annotation is kept as-is for compatibility.
    def __iter__(self) -> Transformer:
        pairs = zip(self.aug_transform_parameters, self.deaug_transform_parameters)
        for forward_values, inverse_values in pairs:
            yield Transformer(
                image_pipeline=self._bind(
                    'apply_aug_image', self.aug_transforms, forward_values),
                mask_pipeline=self._bind(
                    'apply_deaug_mask', self.deaug_transforms, inverse_values),
                label_pipeline=self._bind(
                    'apply_deaug_label', self.deaug_transforms, inverse_values),
                keypoints_pipeline=self._bind(
                    'apply_deaug_keypoints', self.deaug_transforms, inverse_values),
            )

    def __len__(self) -> int:
        """Number of parameter combinations, i.e. of yielded ``Transformer``s."""
        return len(self.aug_transform_parameters)

1.2.1 BaseTransform

class BaseTransform:
    """Interface for a test-time transform with a discrete set of parameter values.

    Args:
        name: keyword under which a parameter value is passed to the
            ``apply_*`` methods (stored as ``pname``).
        params: candidate parameter values for this transform.

    Subclasses are expected to override the four ``apply_*`` methods; the base
    versions only raise ``NotImplementedError``.
    """

    # Class-level default; presumably the parameter value that leaves the input
    # unchanged, set by subclasses -- TODO confirm against concrete transforms.
    identity_param = None

    def __init__(self, name: str, params: Union[list, tuple]):
        self.pname = name
        self.params = params

    def apply_aug_image(self, image, *args, **params):
        """Augment ``image``; must be provided by a subclass."""
        raise NotImplementedError

    def apply_deaug_mask(self, mask, *args, **params):
        """Undo the augmentation on a predicted mask; subclass responsibility."""
        raise NotImplementedError

    def apply_deaug_label(self, label, *args, **params):
        """Undo the augmentation on a predicted label; subclass responsibility."""
        raise NotImplementedError

    def apply_deaug_keypoints(self, keypoints, *args, **params):
        """Undo the augmentation on predicted keypoints; subclass responsibility."""
        raise NotImplementedError

1.2.2 Transformer

class Transformer:
    """One forward augmentation pipeline paired with its inverse pipelines.

    Every method feeds its single argument through the matching callable
    pipeline (a ``Chain``) and returns the result.
    """

    def __init__(
        self,
        image_pipeline: Chain,
        mask_pipeline: Chain,
        label_pipeline: Chain,
        keypoints_pipeline: Chain
    ):
        self.keypoints_pipeline = keypoints_pipeline
        self.label_pipeline = label_pipeline
        self.mask_pipeline = mask_pipeline
        self.image_pipeline = image_pipeline

    def augment_image(self, image):
        """Return ``image`` passed through the forward augmentation pipeline."""
        run = self.image_pipeline
        return run(image)

    def deaugment_mask(self, mask):
        """Return ``mask`` passed through the mask inverse pipeline."""
        run = self.mask_pipeline
        return run(mask)

    def deaugment_label(self, label):
        """Return ``label`` passed through the label inverse pipeline."""
        run = self.label_pipeline
        return run(label)

    def deaugment_keypoints(self, keypoints):
        """Return ``keypoints`` passed through the keypoints inverse pipeline."""
        run = self.keypoints_pipeline
        return run(keypoints)

1.2.3 Chain

class Chain:
    """Callable that applies a sequence of single-argument functions in order.

    ``Chain([f, g])(x)`` returns ``g(f(x))``. A falsy ``functions`` (``None`` or
    an empty sequence) is stored as ``[]``, making the chain the identity.
    """

    def __init__(
        self,
        functions: List[callable]
    ):
        self.functions = [] if not functions else functions

    def __call__(self, x):
        value = x
        for step in self.functions:
            value = step(value)
        return value

其实主要功能都实现在 base 这个文件里：对外的几个包装接口写法都类似，关键在于理解 base 中这几个类（Merger、Compose、BaseTransform、Transformer、Chain）的用法。

Pytorch TTA(预测增强) 源码阅读相关推荐

  1. pytorch load state dict_pytorch源码阅读(二)optimizer原理

    pytorch包含多种优化算法用于网络参数的更新,比如常用的SGD.Adam.LBFGS以及RMSProp等.使用中可以发现各种优化算法的使用方式几乎相同,是因为父类optimizer[1]定义了各个 ...

  2. pytorch版TTA 源码阅读2

    TTA 源码阅读2 1. Transforms.py 主要是图片增强方法的文件 首先要提一下的是,它这里transform的类继承DualTransform,而这个类又是完全继承上一篇解析过的Base ...

  3. bert模型简介、transformers中bert模型源码阅读、分类任务实战和难点总结

    bert模型简介.transformers中bert模型源码阅读.分类任务实战和难点总结:https://blog.csdn.net/HUSTHY/article/details/105882989 ...

  4. 源码阅读及理论详解《 Informer: Beyond Efficient Transformer for Long Sequence Time-Series Forecasting 》

    Informer论文:https://arxiv.org/pdf/2012.07436.pdf Informer源码:GitHub - zhouhaoyi/Informer2020: The GitH ...

  5. mybatis源码阅读

    说下mybatis执行一个sql语句的流程 执行语句,事务等SqlSession都交给了excutor,excutor又委托给statementHandler SimpleExecutor:每执行一次 ...

  6. 【NLP】NLP实战篇之bert源码阅读(run_classifier)

    本文主要会阅读bert源码 (https://github.com/google-research/bert )中run_classifier.py文件,已完成modeling.py.optimiza ...

  7. mybatis源码阅读(一):SqlSession和SqlSessionFactory

    转载自  mybatis源码阅读(一):SqlSession和SqlSessionFactory 一.接口定义 听名字就知道这里使用了工厂方法模式,SqlSessionFactory负责创建SqlSe ...

  8. 【Flink】Flink 源码阅读笔记(15)- Flink SQL 整体执行框架

    1.概述 转载:Flink 源码阅读笔记(15)- Flink SQL 整体执行框架 在数据处理领域,无论是实时数据处理还是离线数据处理,使用 SQL 简化开发将会是未来的整体发展趋势.尽管 SQL ...

  9. Alibaba Druid 源码阅读(五)数据库连接池 连接关闭探索

    Alibaba Druid 源码阅读(五)数据库连接池 连接关闭探索 简介 在上文中探索了数据库连接池的获取,下面接着初步来探索下数据库连接的关闭,看看其中具体执行了那些操作 连接关闭 下面的具体的代 ...

最新文章

  1. AndroidUI 视图动画-旋转动画效果 (RotateAnimation)
  2. 李沐团队提出最强ResNet改进版,多项任务达到SOTA | 已开源
  3. shell中条件判断语法与判断条件小结
  4. Centos 安装python 3.7 遇到 ModuleNotFoundError: No module named _ctypesmake [install] Error 1(亲测下面的红字内容)
  5. Asterisk Queue呼叫中心的实现
  6. lintcode-517-丑数
  7. shell脚本视频学习1
  8. java jdk 文档下载_JDK8 API文档(下载)
  9. 使用github安装atom插件
  10. SaaS-HRM 需求分析
  11. 斯坦福NLP名课带学详解 | CS224n 第5讲 - 句法分析与依存解析(NLP通关指南·完结)
  12. VelocityTracker笔记
  13. 黑马程序员——双列集合、泛型 笔记第十一篇
  14. mac打开注册机显示“您没有权限来打开应用程序
  15. java较全的面试题
  16. suparc服务器没信号,SupARC街机对战平台
  17. 计算机管理主分区改成逻辑分区,Win7将主分区变为逻辑分区的方法
  18. 什么是Heads-up displays(HUD)
  19. 复习IO流复制文件时,文件损坏并且文件变得超大(FileInputStream和FileOutputStream)数组复制
  20. 2017亿欧创新者年会暨第三届创新奖颁奖盛典 | 互联网行业公会

热门文章

  1. 计算机英语口语面试自我介绍,英语口语面试自我介绍范文
  2. Java项目校园兼职平台(三层架构+设计模式重构版)(含代码)
  3. 2022年湖北省高新技术企业申报材料以及认定条件汇总!
  4. 超好用的免费修图软件推荐
  5. Python、R和SAS哪个适合你?
  6. 轻松关闭QQ2007迷你首页
  7. mkv转mp4格式怎么转,5种便捷工具盘点
  8. sql 对查询出的 结果集 添加 自增序号列/排序列
  9. 嗨,年轻人,一定要做一个让自己不后悔的人哦!
  10. python将txt中的内容导入到excel