基于改进注意力机制的U-Net模型实现及应用(keras框架实现)
1.摘要
上节我们基于U-Net模型设计并实现了在医学细胞分割上的应用(ISBI 挑战数据集),并给出了模型的详细代码解释,在上个博客中,我们为了快速训练U-Net模型对其进行了缩减,将庞大的U-Net的转换为很小&的结构,导致其准确率才达到75%左右。为了进一步提高U-Net模型在细胞分割上的准确率,本文将主要研究两个方面:一是基于U-Net的原始模型结构进行改进,引入卷积注意力机制模块(CBAM)和Focal Tversky损失函数;二是引入深监督方法(DEEP SUPERVISION)及多尺度输入作为U-Net模型的原始输入,该模型被命名为DAMU-Net。为了进一步验证该模型的性能,我们同样在ISBI 挑战数据集上进行实验,并给出相应的实验结果。
2.相关技术概述
2.1 Focal Tversky损失函数
医学影像中存在很多的数据不平衡现象,使用不平衡数据进行训练会导致严重偏向高精度但低召回率(sensitivity)的预测,这是我们不希望的,特别是在医学应用中,假阴性比假阳性多更难容忍。而Tversky广义损失函数可以有效解决了三维全卷积深神经网络训练中数据不平衡的问题,在精度和召回率之间找到更好的平衡。与Focal loss相似,Focal Tversky Loss着重于通过通过调整超参数α和β,我们可以控制假阳性和假阴性之间的权衡。较大的β会使召回的准确性高于精确度(通过更加强调假阴性)。其公式如下:
2.2 深监督方法
所谓深监督(Deep Supervision),就是在深度神经网络的某些中间隐藏层加了一个辅助的分类器作为一种网络分支来对主干网络进行监督的技巧,用来解决深度神经网络训练梯度消失和收敛速度过慢等问题。 深监督作为一个训练trick在2014年就已经通过DSN(Deeply-Supervised Nets)提出来了.
通常而言,增加神经网络的深度可以一定程度上提高网络的表征能力,但随着深度加深,会逐渐出现神经网络难以训练的情况,其中就包括像梯度消失和梯度爆炸等现象。为了更好的训练深度网络,人们尝试给神经网络的某些层添加一些辅助的分支分类器来解决这个问题。这种辅助的分支分类器能够起到一种判断隐藏层特征图质量好坏的作用。其结构如下:
其中各个模块含义如下:
可以看到,图中在第四个卷积块之后添加了一个监督分类器作为分支。Conv4输出的特征图除了随着主网络进入Conv5之外,也作为输入进入了分支分类器。往往分支与主网络一起训练。
3.模型实现
3.1 基于卷积注意力机制的U-Net模型
3.2 基于卷积注意力机制和深监督的U-Net模型
其具体代码实现可以查看上篇博客:https://haosen.blog.csdn.net/article/details/117756027;
3.3 模型代码实现
def attn_reg(opt,input_size, lossfxn):img_input = Input(shape=input_size, name='input_scale1')scale_img_2 = AveragePooling2D(pool_size=(2, 2), name='input_scale2')(img_input)scale_img_3 = AveragePooling2D(pool_size=(2, 2), name='input_scale3')(scale_img_2)scale_img_4 = AveragePooling2D(pool_size=(2, 2), name='input_scale4')(scale_img_3)conv1 = UnetConv2D(img_input, 32, is_batchnorm=True, name='conv1')pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)input2 = Conv2D(64, (3, 3), padding='same', activation='relu', name='conv_scale2')(scale_img_2)input2 = concatenate([input2, pool1], axis=3)conv2 = UnetConv2D(input2, 64, is_batchnorm=True, name='conv2')pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)input3 = Conv2D(128, (3, 3), padding='same', activation='relu', name='conv_scale3')(scale_img_3)input3 = concatenate([input3, pool2], axis=3)conv3 = UnetConv2D(input3, 128, is_batchnorm=True, name='conv3')pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)input4 = Conv2D(256, (3, 3), padding='same', activation='relu', name='conv_scale4')(scale_img_4)input4 = concatenate([input4, pool3], axis=3)conv4 = UnetConv2D(input4, 64, is_batchnorm=True, name='conv4')pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)center = UnetConv2D(pool4, 512, is_batchnorm=True, name='center')g1 = UnetGatingSignal(center, is_batchnorm=True, name='g1')attn1 = AttnGatingBlock(conv4, g1, 128, '_1')up1 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(center), attn1], name='up1')g2 = UnetGatingSignal(up1, is_batchnorm=True, name='g2')attn2 = AttnGatingBlock(conv3, g2, 64, '_2')up2 = concatenate([Conv2DTranspose(64, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up1), attn2], name='up2')g3 = UnetGatingSignal(up1, is_batchnorm=True, name='g3')attn3 = AttnGatingBlock(conv2, g3, 32, '_3')up3 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up2), attn3], name='up3')up4 = concatenate([Conv2DTranspose(32, (3,3), strides=(2,2), padding='same', activation='relu', kernel_initializer=kinit)(up3), conv1], name='up4')conv6 = UnetConv2D(up1, 256, is_batchnorm=True, name='conv6')conv7 = UnetConv2D(up2, 128, is_batchnorm=True, name='conv7')conv8 = UnetConv2D(up3, 64, is_batchnorm=True, name='conv8')conv9 = UnetConv2D(up4, 32, is_batchnorm=True, name='conv9')out6 = Conv2D(1, (1, 1), activation='sigmoid', name='pred1')(conv6)out7 = Conv2D(1, (1, 1), activation='sigmoid', name='pred2')(conv7)out8 = Conv2D(1, (1, 1), activation='sigmoid', name='pred3')(conv8)out9 = Conv2D(1, (1, 1), activation='sigmoid', name='final')(conv9)model = Model(inputs=[img_input], outputs=[out6, out7, out8, out9])loss = {'pred1':lossfxn,'pred2':lossfxn,'pred3':lossfxn,'final': losses.tversky_loss}loss_weights = {'pred1':1,'pred2':1,'pred3':1,'final':1}model.compile(optimizer=opt, loss=loss, loss_weights=loss_weights,metrics=[losses.dsc])model.summary()from keras.utils.vis_utils import plot_modelplot_model(model, to_file='model1.png', show_shapes=True)return model
模型参数结构图(点击观看)
4. 实验结果
模型 |
DES |
U-Net |
0.878 |
ATT-U-Net |
? |
DATT-U-Net |
? |
DAMU-Net |
? |
还有一些结果正在用CPU运行,太慢了.....
基于改进注意力机制的U-Net模型实现及应用(keras框架实现)相关推荐
- IJCAI 2019 | 为推荐系统生成高质量的文本解释:基于互注意力机制的多任务学习模型...
编者按:在个性化推荐系统中,如果能在提高推荐准确性的同时生成高质量的文本解释,将更容易获得用户的"芳心".然而,现有方法通常将两者分开优化,或只优化其中一个目标.为了同时兼顾二者, ...
- 基于频谱注意力机制和编码解码模型的时间序列分类研究
文章来源 浙江大学 2021年硕士论文 小论文 IEEE Spectrum Attention Mechanism for Time Series Classification 1 摘要 本文贡献 时 ...
- HAN:基于双层注意力机制的异质图深度神经网络
「论文访谈间」是由 PaperWeekly 和中国中文信息学会社会媒体处理专委会(SMP)联合发起的论文报道栏目,旨在让国内优质论文得到更多关注和认可. 图神经网络是近年来图数据挖掘领域的热门研究方向 ...
- 城市异常事件精确预测:基于交互注意力机制的时空数据预测模型
点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入! 精确实现城市中异常事件的预测,在智能城市中智能交通以及城市公共安全的应用中起着重要的作用.为实现该目的,研究工作从以下两个重要层面对时空 ...
- 基于自注意力机制与无锚点的仔猪姿态识别(农业工程学报)
写在前面的话 该论文于2022年3月底投稿,9月初定稿,1次大改,小改不下10次,暑假几乎都在改论文中度过,非常感谢导师和农工编辑的耐心指导,同时感谢所有对本文作出贡献的实验室同伴,论文可在农业工程学 ...
- IJCAI论文 | 基于改进注意力循环控制门,品牌个性化排序升级系统来了
[小叽导读]:在淘宝网等电子商务网站中,品牌在用户对商品做点击.购买选择正起着越来越重要的作用,部分原因是用户现在越来越关注商品的质量,而品牌是商品质量的一个保证. 但是,现有的排序系统并不是针对用户 ...
- IJCAI论文 | 基于改进注意力循环控制门,品牌个性化排序升级系统来了...
[小叽导读]:在淘宝网等电子商务网站中,品牌在用户对商品做点击.购买选择正起着越来越重要的作用,部分原因是用户现在越来越关注商品的质量,而品牌是商品质量的一个保证. 但是,现有的排序系统并不是针对用户 ...
- IJCAI 阿里论文 | 基于改进注意力循环控制门,品牌个性化排序升级系统来了
阿里妹导读:在电子商务中,品牌在用户对商品做点击.购买选择正起着越来越重要的作用,部分原因是用户现在越来越关注商品的质量,而品牌是商品质量的一个保证. 但是,现有的排序系统并不是针对用户对品牌的偏好设 ...
- IJCAI 阿里论文 | 基于改进注意力循环控制门 品牌个性化排序升级
阿里妹导读:在电子商务中,品牌在用户对商品做点击.购买选择正起着越来越重要的作用,部分原因是用户现在越来越关注商品的质量,而品牌是商品质量的一个保证. 但是,现有的排序系统并不是针对用户对品牌的偏好设 ...
- SIGIR2020|基于自注意力机制和多鉴别器的序列推荐
Sequential Recommendation with Self-Attentive Multi-Adversarial Network https://arxiv.org/pdf/2005.1 ...
最新文章
- hdu5094(上海邀请赛E) 状态压缩bfs:取钥匙开门到目的地
- python3 hasattr getattr setattr delattr 对象属性 反射
- python必背内容-学 Python 必背的42个常见单词,看看你记住了几个?
- Tornado入门三
- 深度学习之卷积神经网络(11)卷积层变种
- P2053-修车【网络流,费用流】
- 记6分的交通违法行为
- @ResponseBody注解學習
- 原生js设置div隐藏或者显示_10种JS控制DIV的显示隐藏代码
- POI导出数据内存溢出问题
- httpd配置ResponseHeader
- mysql预编译表名_JDBC预编译语句表名占位异常
- 如何批量更改Mac视频帧速率
- 外部中断器微型计算机课程设计,课程设计-电子时钟参考.doc
- Poco::TCPServer框架解析
- 计算机主机结构图片,电脑的组成的图文详解
- 【JZOJ 4623】搬运干草捆
- android 不如 ios,安卓永远不如iOS运行流畅的根本原因
- Bug的级别,按照什么划分
- Handler详解(中)
热门文章
- 论文翻译《Computer Vision for Autonomous Vehicles Problems, Datasets and State-of-the-Art》(第六章)
- Apache网站根目录
- 新型城镇化:智慧城市成亮点
- 微信小程序云端图片上传,存储,获取,显示
- 障碍期权定价 python_Python二项期权定价
- 意创坊-移动富媒体平台
- 快捷打开mysql_Windows 平台快速启动MYSQL的方法
- Android仿虾米音乐播放器之布局介绍
- Cassandra Secondary Index 介绍
- python xlwt 写入Excel