Attention U-Net网络pytorch构建

from torch import nn
from torch.nn import functional as F
import torch
from torchvision import models
import torchvisionclass conv_block(nn.Module):def __init__(self,ch_in,ch_out):super(conv_block,self).__init__()self.conv = nn.Sequential(nn.Conv2d(ch_in, ch_out, kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True),nn.Conv2d(ch_out, ch_out, kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self,x):x = self.conv(x)return xclass up_conv(nn.Module):def __init__(self,ch_in,ch_out):super(up_conv,self).__init__()self.up = nn.Sequential(nn.Upsample(scale_factor=2),nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),nn.BatchNorm2d(ch_out),nn.ReLU(inplace=True))def forward(self,x):x = self.up(x)return xclass Attention_block(nn.Module):def __init__(self, F_g, F_l, F_int):super(Attention_block, self).__init__()self.W_g = nn.Sequential(nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),nn.BatchNorm2d(F_int))self.W_x = nn.Sequential(nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),nn.BatchNorm2d(F_int))self.psi = nn.Sequential(nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),nn.BatchNorm2d(1),nn.Sigmoid())self.relu = nn.ReLU(inplace=True)def forward(self, g, x):g1 = self.W_g(g)x1 = self.W_x(x)psi = self.relu(g1 + x1)psi = self.psi(psi)return x * psiclass AttU_Net(nn.Module):def __init__(self, img_ch=3, output_ch=1):super(AttU_Net, self).__init__()self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.Conv1 = conv_block(ch_in=img_ch, ch_out=64)self.Conv2 = conv_block(ch_in=64, ch_out=128)self.Conv3 = conv_block(ch_in=128, ch_out=256)self.Conv4 = conv_block(ch_in=256, ch_out=512)self.Conv5 = conv_block(ch_in=512, ch_out=1024)self.Up5 = up_conv(ch_in=1024, ch_out=512)self.Att5 = Attention_block(F_g=512, F_l=512, F_int=256)self.Up_conv5 = conv_block(ch_in=1024, ch_out=512)self.Up4 = up_conv(ch_in=512, ch_out=256)self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)self.Up_conv4 = conv_block(ch_in=512, ch_out=256)self.Up3 = up_conv(ch_in=256, ch_out=128)self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)self.Up_conv3 = conv_block(ch_in=256, ch_out=128)self.Up2 = up_conv(ch_in=128, ch_out=64)self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)self.Up_conv2 = conv_block(ch_in=128, ch_out=64)self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)self.sigmoid = nn.Sigmoid()def forward(self, x):# encoding pathx1 = self.Conv1(x)x2 = self.Maxpool(x1)x2 = self.Conv2(x2)x3 = self.Maxpool(x2)x3 = self.Conv3(x3)x4 = self.Maxpool(x3)x4 = self.Conv4(x4)x5 = self.Maxpool(x4)x5 = self.Conv5(x5)d5 = self.Up5(x5)x4 = self.Att5(g=d5, x=x4)d5 = torch.cat((x4, d5), dim=1)d5 = self.Up_conv5(d5)d4 = self.Up4(d5)x3 = self.Att4(g=d4, x=x3)d4 = torch.cat((x3, d4), dim=1)d4 = self.Up_conv4(d4)d3 = self.Up3(d4)x2 = self.Att3(g=d3, x=x2)d3 = torch.cat((x2, d3), dim=1)d3 = self.Up_conv3(d3)d2 = self.Up2(d3)x1 = self.Att2(g=d2, x=x1)d2 = torch.cat((x1, d2), dim=1)d2 = self.Up_conv2(d2)d1 = self.Conv_1x1(d2)d1 = self.sigmoid(d1)return d1

Attention U-Net网络相关推荐

  1. 基于Attention机制的轻量级网络架构以及代码实现

    点击上方,选择星标或置顶,不定期资源大放送! 阅读大概需要10分钟 Follow小博主,每天更新前沿干货 导读 之前详细介绍了轻量级网络架构的开源项目,详情请看深度学习中的轻量级网络架构总结与代码实现 ...

  2. 【自然语言处理】2. Attention实现详细解析( tfa, keras 方法调用源码分析 自建网络)

    NLP系列讲解笔记 本专题是针对NLP的一些常用知识进行记录,主要由于本人接下来的实验需要用到NLP的一些知识点,但是本人非NLP方向学生,对此不是很熟悉,所以打算做个笔记记录一下自己的学习过程,也是 ...

  3. Occluded Pedestrian Detection Through Guided Attention in CNNs 论文总结

    概述  行人检测在过去几年中取得了显著进展.然而行人检测的遮挡问题仍然是研究的重点和难点,因为行人外表因遮挡模式的不同而有很大差异.在本文中,提出一种基于Faster-rcnn 方法的一种遮挡行人检测 ...

  4. 全民 Transformer (一): Attention 在深度学习中是如何发挥作用的

    <Attention 在深度学习中是如何发挥作用的:理解序列模型中的 Attention>    Transformer 的出现让 Deep Learning 出现了大一统的局面.Tran ...

  5. 【论文笔记】D2A U-Net: Automatic segmentation of COVID-19 CT slices based on dual attention and hybrid di

    声明 不定期更新自己精度论文,通俗易懂,初级小白也可以理解 涉及范围:深度学习方向,包括 CV.NLP.Data Fusion.Digital Twin 论文标题:D2A U-Net: Automat ...

  6. Attention机制【图像】

    1. 什么是Attention机制? 其实我没有找到attention的具体定义,但在计算机视觉的相关应用中大概可以分为两种: 1)学习权重分布:输入数据或特征图上的不同部分对应的专注度不同,对此Ja ...

  7. 【论文阅读】Online Attention Accumulation for Weakly Supervised Semantic Segmentation

    一篇弱监督分割领域的论文,其会议版本为: (ICCV2019)Integral Object Mining via Online Attention Accumulation 论文标题: Online ...

  8. Integral Object Mining via Online Attention Accumulation

    Integral Object Mining via Online Attention Accumulation 摘要 1. Introduction 2. Related Work 3. Metho ...

  9. Mobile-Former来了!微软提出:MobileNet+Transformer轻量化并行网络

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 转载自:集智书童 Mobile-Former: Bridging MobileNet and Transfo ...

  10. CVPR 2022 | Mobile-Former来了!微软提出:MobileNet+Transformer轻量化并行网络

    点击下方卡片,关注"CVer"公众号 AI/CV重磅干货,第一时间送达 转载自:集智书童 Mobile-Former: Bridging MobileNet and Transfo ...

最新文章

  1. Python3 Scrapy爬取猫眼TOP100代码示例
  2. keyloadtool_keytool:术语“keytool”无法识别为cmdlet、函数、脚本文件或可操作程序的名称...
  3. 28Python正则表达式、正则表达式对象、正则表达式修饰符、表达式模板、表达式实例、match函数、search方法、检索和替换、repl、compile、findall等
  4. redis在linux搭建集群,Linux/Centos 7 redis4 集群搭建
  5. 关东升的iOS实战系列图书 《iOS实战:传感器卷(Swift版)》已经上市
  6. json取数据怎么取_干货速递丨书名应该怎么取?
  7. Thinkphp V5.X 远程代码执行漏洞 - POC(精:集群5.0*、5.1*、5.2*)
  8. hat怎么安装mysql_Red Hat Enterprise Linux中怎么安装Mysql+apache+php+zend
  9. 处理工行b2c上海机构问题反思
  10. NET常出现的三十三种代码(1)
  11. Atitit nodejs db api// 加载 mysql modulevar sys = require(“sys“);var mysql = require(“mysql“);va
  12. android 大牛播放组件,大牛直播Android播放端SDK说明
  13. linux返回上一行命令行,linux命令行编辑快捷键
  14. ValueError: n_splits=n cannot be greater than the number of members in each class.
  15. Node.js报错:UnhandledPromiseRejectionWarning: Unhandled promise rejection
  16. 关于文件夹的手动隐藏和恢复
  17. HTML Hover 的巧用。
  18. 四-python爬虫学习--下载电视剧
  19. NCAE(全国工业和信息化应用人才考试 )-- 服务外包 JAVA 软件开发复习整理(二)
  20. 长沙现象-互联网教育行业

热门文章

  1. ASP.NET程序中常用代码汇总-1[转]
  2. wp load.php下载,WP 代码分析:wp-load.php
  3. 如何做好技术演讲-口才提升篇章
  4. 知识蒸馏Knownledge Distillation
  5. 【Python】Django集成Github登陆
  6. JDK7新特性之try-with-resources
  7. 使用putty连接linux服务器拒绝,使用Putty远程连接Linux系统遇到的问题及解决方法...
  8. 【案例篇4】HTML+CSS实现漂亮的套餐价格表页面演示(源码)
  9. Hyperchain 超块链创始人史兴国对谈李国权:为什么新加坡能在全球“Web3桥头堡抢夺战”中突出重围?
  10. 半入耳TWS耳机有哪些?半入耳TWS耳机推荐