链接:https://www.kaggle.com/leighplt/pytorch-tta-flip-left-right
tta 见过不少了,今天发现一个python的代码技巧记录一下

import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pdimport torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data
import torchvision
from torchvision import modelsimport cv2
from pathlib import Path
import glob#============= tta ===================
#这里的tta是可插拔的,用在训练和预测上都行,下面有使用的方法,一看就很明了,其中这里面的staticmethod返回函数的静态方法,
#该方法不强制要求传递参数,并且无需实例化就可以调用,也可以实例化调用,很灵活。
#实例化调用方法就是 形如:C = TTAFunction()  然后调用时C.tta()这样,不实例化的话可以直接TTAFunction.tta()
class TTAFunction:"""Simple TTA function"""@staticmethoddef hflip(x):return x.flip(3)@staticmethoddef vflip(x):return x.flip(2)def tta(self, x):self.eval()with torch.no_grad():result = self.forward(x)x = self.hflip(x)result += self.hflip(self.forward(x))return 0.5*result
#============= model ===================
def conv3x3(in_, out):return nn.Conv2d(in_, out, 3, padding=1)class ConvRelu(nn.Module):def __init__(self, in_, out):super().__init__()self.conv = conv3x3(in_, out)self.activation = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.activation(x)return xclass DecoderBlock(nn.Module):def __init__(self, in_channels, middle_channels, out_channels):super().__init__()self.block = nn.Sequential(ConvRelu(in_channels, middle_channels),nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),nn.ReLU(inplace=True))def forward(self, x):return self.block(x)class UNet11(TTAFunction, nn.Module): # use our class with TTA functiondef __init__(self, num_filters=32):""":param num_classes::param num_filters:"""super().__init__()self.pool = nn.MaxPool2d(2, 2)# Convolutions are from VGG11self.encoder = models.vgg11().features# "relu" layer is taken from VGG probably for generality, but it's not clear self.relu = self.encoder[1]self.conv1 = self.encoder[0]self.conv2 = self.encoder[3]self.conv3s = self.encoder[6]self.conv3 = self.encoder[8]self.conv4s = self.encoder[11]self.conv4 = self.encoder[13]self.conv5s = self.encoder[16]self.conv5 = self.encoder[18]self.center = DecoderBlock(num_filters * 8 * 2, num_filters * 8 * 2, num_filters * 8)self.dec5 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 8)self.dec4 = DecoderBlock(num_filters * (16 + 8), num_filters * 8 * 2, num_filters * 4)self.dec3 = DecoderBlock(num_filters * (8 + 4), num_filters * 4 * 2, num_filters * 2)self.dec2 = DecoderBlock(num_filters * (4 + 2), num_filters * 2 * 2, num_filters)self.dec1 = ConvRelu(num_filters * (2 + 1), num_filters)self.final = nn.Conv2d(num_filters, 1, kernel_size=1, )def forward(self, x):conv1 = self.relu(self.conv1(x))conv2 = self.relu(self.conv2(self.pool(conv1)))conv3s = self.relu(self.conv3s(self.pool(conv2)))conv3 = self.relu(self.conv3(conv3s))conv4s = self.relu(self.conv4s(self.pool(conv3)))conv4 = self.relu(self.conv4(conv4s))conv5s = self.relu(self.conv5s(self.pool(conv4)))conv5 = self.relu(self.conv5(conv5s))center = self.center(self.pool(conv5))# Deconvolutions with copies of VGG11 layers of corresponding size dec5 = self.dec5(torch.cat([center, conv5], 1))dec4 = self.dec4(torch.cat([dec5, conv4], 1))dec3 = self.dec3(torch.cat([dec4, conv3], 1))dec2 = self.dec2(torch.cat([dec3, conv2], 1))dec1 = self.dec1(torch.cat([dec2, conv1], 1))return torch.sigmoid(self.final(dec1))def unet11(**kwargs):model = UNet11(**kwargs)return modeldef get_model():np.random.seed(717)torch.cuda.manual_seed(717);torch.manual_seed(717);model = unet11()model.train()return model.to(device)
#============= use tta for predict===================
model = get_model()
model.load_state_dict(torch.load(model_pth)['state_dict'])test_dataset = TGSSaltDataset(test_path, test_file_list, is_test = True)  #这个函数原来链接里有all_predictions = []
for image in data.DataLoader(test_dataset, batch_size = 30):image = image[0].type(torch.FloatTensor).to(device)y_pred = model.tta(image).cpu().data.numpy() # use tta_flipall_predictions.append(y_pred)
all_predictions_stacked = np.vstack(all_predictions)[:, 0, :, :]

kaggle可插拔tta应用记录相关推荐

  1. 记录关于监听HDMI插拔广播

    记录关于监听HDMI插拔广播 hdmi的广播有两种,目前大部分文章讲诉的都是使用android.intent.action.HDMI_PLUGGED来监听hdmi插拔的状态变化,但是这个方法在高版本中 ...

  2. 删除u盘插拔记录linux,如何删除电脑里中的u盘使用记录

    五.打开注册表(在"开始"→"运行"输入框中输入"regedit"可直接打开注册表),注册表打开后依次在"编辑→查找", ...

  3. 删除u盘插拔记录linux,电脑u盘插拔记录_电脑u盘插拔时间记录

    2016-06-07 18:50:08 按照如下步骤进行处理,尝试解决问题. 1.USB接口损坏.接触不良也会出现此情况,换一个USB接口,台式机插入主机后接口试试. 2.插入电脑,右下角有图标,不显 ...

  4. k8s kubesphere启用可插拔组件(安装前、后均可)

    启用可插拔组件 本教程演示如何在安装前或安装后启用 KubeSphere 的可插拔组件.KubeSphere 具有以下列出的十个可插拔组件. 配置项 功能组件 描述 alerting KubeSphe ...

  5. sim插拔识别时间_特斯拉+树莓派实现车牌识别检测系统

    转自机器之心 | 作者:Robert Lucian Chiriac | 参与:王子嘉.思.一鸣 怎样在不换车的前提下打造一个智能车系统呢?一段时间以来,本文作者 Robert Lucian Chiri ...

  6. Foundatio - .Net Core用于构建分布式应用程序的可插拔基础块

    简介 Foundatio - 用于构建分布式应用程序的可插拔基础块 •想要针对抽象接口进行构建,以便我们可以轻松更改实现.希望这些块对依赖注入友好.•缓存:我们最初使用的是开源 Redis 缓存客户端 ...

  7. oracle关闭数据库容器,Oracle12cr1新特性之容器数据库(CDB)和可插拔数据库(PDB) 的启动和关闭...

    Oracle12c中引入的多宿主选项(multitenant option)允许一个容器数据库容纳多个独立的可插拔数据库(PDB).本文将说明如何启动和关闭容器数据库(CDB)和可插拔数据库(PDB) ...

  8. java 可插拔注解_20200311 8. 注解和可插拔性

    8. 注解和可插拔性 8.1 注解和可插拔性 在 web 应用中,使用注解的类仅当它们位于 WEB-INF/classes 目录中,或它们被打包到位于应用的WEB-INF/lib 中的 jar 文件中 ...

  9. 我心中的核心组件(可插拔的AOP)~大话开篇及目录

    我心中的核心组件(可插拔的AOP)~大话开篇及目录 http://www.cnblogs.com/lori/p/3247905.html 回到占占推荐博客索引 核心组件 我心中的核心组件,核心组件就是 ...

最新文章

  1. android Json详解
  2. NLTK基础教程学习笔记(一)
  3. Spring注解开发-@Scope作用域注解
  4. eclipse中properties文件编辑插件:PropertiesEditor
  5. RS-232转RS-485/422串口转换器产品介绍
  6. 数据科学中的数据可视化
  7. 【Python CheckiO 题解】First Word
  8. 计算机安全事故由谁整改,信息安全检查整改方案 整改方案 .doc
  9. php mysql密码验证_php 连接数据库 验证用户名密码
  10. 喜庆普通铁路也要跑动车了
  11. mysql怎么查看记录时间戳_mysql TIMESTAMP(时间戳)详解——查询最近一段时间操作的记录...
  12. Chango的数学Shader世界(十六)RayTrace三维分形(一)—— ue4中最简单的RayMarch
  13. XSS Overview
  14. 全闪存存储的数据库加速场景应用
  15. Unity与服务器通信方式有哪些?
  16. ARM架构SMMU驱动详解
  17. JDK8的介绍下载和安装(附网盘地址)
  18. 使用uEdit时,在线管理图片功能不可用
  19. leetcode刷题之 树(14)-递归:找出二叉树中第二小的节点
  20. 祝愿天下所有的有情人都终成眷属

热门文章

  1. 很多人嘲笑谷歌,这没什么好奇怪的。在娼妓云集的地方,贞洁的女子会受到诽谤和讥笑。
  2. 一文搞懂——软件模拟SPI
  3. adams功能区不显示_技巧 | Word 文本突出显示颜色,原来还隐藏了一种颜色
  4. win10 怎么禁用win键盘
  5. 详解可微神经网络架构搜索框架(DNAS)
  6. 护卫神mysql无法启动_MySQL降权:MySQL以Guests帐户启动设置方法_护卫神
  7. sEMG项目总结(4)康复手上位机系统
  8. JVM垃圾回收算法和回收器
  9. maven-war-plugin插件 overlays maven-war-plugin翻译
  10. 怎么在线把一个PDF文件分割为多个PDF文件