在Pytorch中实现SMU激活函数
本文代码来源于githubuSMU源码链接

# coding=utf-8import torch
from torch import nnclass SMU(nn.Module):'''Implementation of SMU activation.Shape:- Input: (N, *) where * means, any number of additionaldimensions- Output: (N, *), same shape as the inputParameters:- alpha: hyper parameterReferences:- See related paper:https://arxiv.org/abs/2111.04682Examples:>>> smu = SMU()>>> x = torch.Tensor([0.6,-0.3])>>> x = smu(x)'''def __init__(self, alpha = 0.25):'''Initialization.INPUT:- alpha: hyper parameteraplha is initialized with zero value by default'''super(SMU,self).__init__()self.alpha = alpha# initialize muself.mu = torch.nn.Parameter(torch.tensor(1000000.0)) def forward(self, x):return ((1+self.alpha)*x + (1-self.alpha)*x*torch.erf(self.mu*(1-self.alpha)*x))/2class SMU1(nn.Module):'''Implementation of SMU-1 activation.Shape:- Input: (N, *) where * means, any number of additionaldimensions- Output: (N, *), same shape as the inputParameters:- alpha: hyper parameterReferences:- See related paper:https://arxiv.org/abs/2111.04682Examples:>>> smu1 = SMU1()>>> x = torch.Tensor([0.6,-0.3])>>> x = smu1(x)'''def __init__(self, alpha = 0.25):'''Initialization.INPUT:- alpha: hyper parameteraplha is initialized with zero value by default'''super(SMU1,self).__init__()self.alpha = alpha# initialize muself.mu = torch.nn.Parameter(torch.tensor(4.352665993287951e-9)) def forward(self, x):return ((1+self.alpha)*x+torch.sqrt(torch.square(x-self.alpha*x)+torch.square(self.mu)))/2def test_SMU(x):smu_activation = SMU()print(smu_activation(x))def test_SMU1(x):smu1_activation=SMU1()print(smu1_activation(x))def test():x = torch.Tensor([0.6,-0.3])test_SMU(x)test_SMU1(x)if __name__ == '__main__':test()

在Tensorflow中实现SMU激活函数

# coding=utf-8import tensorflow as tfdef SMU(x,alpha=0.25):mu = tf.compat.v1.get_variable('SMU_mu', shape=(),initializer=tf.constant_initializer(1000000),dtype=tf.float32)return ((1+alpha)*x + (1-alpha)*x*tf.math.erf(mu*(1-alpha)*x))/2def SMU1(x,alpha=0.25):mu = tf.compat.v1.get_variable('SMU1_mu', shape=(),initializer=tf.constant_initializer(4.352665993287951e-9),dtype=tf.float32)return ((1+alpha)*x+tf.math.sqrt(tf.math.square(x-alpha*x)+tf.math.square(mu)))/2def test_SMU(x):print(SMU(x))def test_SMU1(x):print(SMU1(x))def test():x = tf.convert_to_tensor(np.array([[-0.6],[0.6]]),dtype=tf.float32)test_SMU(x)test_SMU1(x)if __name__ == '__main__':test()

代码及原理讲解可参考博客

pytorch和tensorflow中实现SMU激活函数相关推荐

  1. Pytorch以及tensorflow中KLdivergence的计算

    1. KL divergence是什么 KL 散度是一个距离衡量指标,衡量的是两个概率分布之间的差异. y p r e d y_{pred} ypred​指的是模型的输出的预测概率,形如[0.35,0 ...

  2. tensorflow中Leaky Relu激活函数

    tensorflow中Leaky Relu激活函数 引用API:tensorflow.nn.leaky_relu(x) Leaky Relu激活函数 Leaky Relu激活函数引入一个固定斜率a,具 ...

  3. 深度学习PyTorch,TensorFlow中GPU利用率较低,使用率周期性变化的问题

    在用tensorflow训练神经网络时,发现训练迭代的速度时而快时而慢,监督的GPU使用率也是周期性变化,通过了解,发现原因是: GPU在等待CPU读取,预处理,并传输数据过来,因此要提高GPU的使用 ...

  4. SELU︱在keras、tensorflow中使用SELU激活函数

    arXiv 上公开的一篇 NIPS 投稿论文<Self-Normalizing Neural Networks>引起了圈内极大的关注,它提出了缩放指数型线性单元(SELU)而引进了自归一化 ...

  5. 深度学习PyTorch,TensorFlow中GPU利用率较低,CPU利用率很低,且模型训练速度很慢的问题总结与分析

    在深度学习模型训练过程中,在服务器端或者本地pc端,输入nvidia-smi来观察显卡的GPU内存占用率(Memory-Usage),显卡的GPU利用率(GPU-util),然后采用top来查看CPU ...

  6. 深度学习PyTorch、TensorFlow中GPU利用率与内存占用率很低的问题

    上周,在一个使用Pytorch搭建的目标训练项目中,训练时,通过使用命令行执行NVIDIA-SMI(仅支持英伟达显卡)命令发现GPU的利用率基本一直停留在0%,并且显存占用率也较低.CSDN上有一篇分 ...

  7. 编写同时在PyTorch和Tensorflow上工作的代码

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 ❝ "库开发人员不再需要在框架之间进行选择." ...

  8. 【深度学习】编写同时在PyTorch和Tensorflow上工作的代码

    作者 | Ram Sagar 编译 | VK 来源 | Analytics In Diamag ❝ "库开发人员不再需要在框架之间进行选择." ❞ 来自德国图宾根人工智能中心的研究 ...

  9. [TensorFlow 学习笔记-06]激活函数(Activation Function)

    [版权说明] TensorFlow 学习笔记参考:  李嘉璇 著 TensorFlow技术解析与实战 黄文坚 唐源 著 TensorFlow实战郑泽宇  顾思宇 著 TensorFlow实战Googl ...

最新文章

  1. 基于吉日嘎拉的通用权限管理WebForm版扩展:字典选项管理和缓存管理
  2. Nginx学习笔记(二) Nginx--connectionrequest
  3. jquery 选择器 逗号
  4. ViT (Vision Transformer) ---- Text Generation(文本生成器)
  5. Redies安装,修配置,设置密码,
  6. Mxnet框架学习笔记(一):常用数据操作方法学习记录
  7. 每台计算机需要配置网关吗,怎么设置一台电脑作为网关
  8. 多元线性回归分析spss结果解读_多重线性回归的结果解读和报告(SPSS实例教程)...
  9. matlab分析启动子特征,文献编译 | 相对脑血容量(rCBV)可作为MGMT启动子甲基化阳性GBM的辅助预后指标...
  10. canvas 擦除动画_HTML5 实现橡皮擦的擦除效果
  11. choose标签使用
  12. 树莓派平台的ADXL345三轴加速度传感器编程
  13. 公众号第三方平台开发 教程五 代公众号处理消息和事件
  14. 多数据中心架构,异地多活架构
  15. Trizol法提取RNA实验步骤
  16. 怎么做手游性能测试?
  17. mysqld --defaults-file=/myfolder/my.cnf --defaults-extra-file=/myfolder2/my.cnf
  18. 袁永福对北京奥运会的评论
  19. 使用jQuery与后端进行数据传输代码示例
  20. resent101-DSSD报错solution

热门文章

  1. 因子分析factor analysis_spss运用_python建模(推荐AAA)
  2. Win7 Office Outlook客户端报没有默认的邮件客户端,或当前客户端无法实现该邮件的请求。
  3. 如何配置web服务器及发布网页
  4. 学日语、记单词是有规律的(转载)
  5. to 管理员:网站的“技术区文章列表RSS”有问题 我用GUSH连不上!
  6. UIC564-2 附录10 – 橡胶法兰产品的阻燃防火测试
  7. VMware 收费太贵? 试试这款更轻量级的虚拟机, 完全免费!
  8. DNS无法区域传送(axfr,ixfr)
  9. 机器翻译模型一多层LSTM__Pytorch实现
  10. layui中select及submit提交