EvoNorms_PyTorch

https://github.com/lonePatient/EvoNorms_PyTorch

原版说精度提升了一个多点,但是内存占用比原来大了很多,也变慢了

import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.parameter import Parameterdef instance_std(x, eps=1e-5):N,C,H,W = x.size()x1 = x.reshape(N*C,-1)var = x1.var(dim=-1, keepdim=True)+epsreturn var.sqrt().reshape(N,C,1,1)def group_std(x, groups, eps = 1e-5):N, C, H, W = x.size()x1 = x.reshape(N,groups,-1)var = (x1.var(dim=-1, keepdim = True)+eps).reshape(N,groups,-1)return (x1 / var.sqrt()).reshape(N,C,H,W)class BatchNorm2dRelu(nn.Module):def __init__(self,in_channels):super(BatchNorm2dRelu,self).__init__()self.layer = nn.Sequential(nn.BatchNorm2d(in_channels),nn.ReLU(inplace=True))def forward(self, x):output = self.layer(x)return outputclass EvoNorm2dB0(nn.Module):def __init__(self,in_channels,nonlinear=True,momentum=0.9,eps = 1e-5):super(EvoNorm2dB0, self).__init__()self.nonlinear = nonlinearself.momentum = momentumself.eps = epsself.gamma = Parameter(torch.Tensor(1,in_channels,1,1))self.beta = Parameter(torch.Tensor(1,in_channels,1,1))if nonlinear:self.v = Parameter(torch.Tensor(1,in_channels,1,1))self.register_buffer('running_var', torch.ones(1, in_channels, 1, 1))self.reset_parameters()def reset_parameters(self):init.ones_(self.gamma)init.zeros_(self.beta)if self.nonlinear:init.ones_(self.v)def forward(self, x):N, C, H, W = x.size()if self.training:x1 = x.permute(1, 0, 2, 3).reshape(C, -1)var = x1.var(dim=1).reshape(1, C, 1, 1)self.running_var.copy_(self.momentum * self.running_var + (1 - self.momentum) * var)else:var = self.running_varif self.nonlinear:den = torch.max((var+self.eps).sqrt(), self.v * x + instance_std(x))return x / den * self.gamma + self.betaelse:return x * self.gamma + self.betaclass EvoNorm2dS0(nn.Module):def __init__(self,in_channels,groups=8,nonlinear=True):super(EvoNorm2dS0, self).__init__()self.nonlinear = nonlinearself.groups = groupsself.gamma = Parameter(torch.Tensor(1,in_channels,1,1))self.beta = Parameter(torch.Tensor(1,in_channels,1,1))if nonlinear:self.v = Parameter(torch.Tensor(1,in_channels,1,1))self.reset_parameters()def reset_parameters(self):init.ones_(self.gamma)init.zeros_(self.beta)if self.nonlinear:init.ones_(self.v)def forward(self, x):if self.nonlinear:num = torch.sigmoid(self.v * x)std = group_std(x,self.groups)return num * std * self.gamma + self.betaelse:return x * self.gamma + self.beta

归一化EvoNorms相关推荐

  1. 超越BN-ReLU!谷歌大脑等提出EvoNorms:归一化激活层的进化

    点击上方"视学算法",选择加"星标"或"置顶" 重磅干货,第一时间送达 本文转载自:机器之心  | 作者:Hanxiao Liu等 最近,谷 ...

  2. 归一化激活层的进化:谷歌Quoc Le等人利用AutoML 技术发现新型ML模块

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 最近,谷歌大脑团队和 DeepMind 合作发布了一篇论文,利用 AutoML 技 ...

  3. pytorch 归一化_用PyTorch进行语义分割

    点击上方"机器学习与生成对抗网络",关注"星标" 获取有趣.好玩的前沿干货! 木易 发自 凹非寺  量子位 报道 | 公众号 QbitAI 很久没给大家带来教程 ...

  4. 关于使用sklearn进行数据预处理 —— 归一化/标准化/正则化

    20220121 z-score标准化 模型存储和load再调用其实没有关系 再load计算的时候,也是以实际的数据重新计算 并不是以save模型的边界来计算的 20211227 onehot训练集保 ...

  5. 一篇文章告诉你标准化和归一化的区别?

    一篇文章告诉你标准化和归一化的区别? 2019-02-28 17:12:39 融融网融融网阅读量:484 进一步推进企业的标准化工作,使之发展水平适应经济全球化下市场竞争的要求,促进企业综合实力的提升 ...

  6. 机器学习——标准化/归一化的目的、作用和场景

    对每个特征进行归一化处理,使得每个特征的取值缩放到0~1之间.这样做有两个好处: 模型训练更高效. 特征前的权重大小可代表该变量对预测结果的贡献度(因为每个特征值本身的范围相同). (一)归一化的作用 ...

  7. 机器学习入门(13)— Affine 仿射层、Softmax 归一化指数函数层实现

    1. 一维 Affine 仿射层 我们回顾下之前为了计算加权信号的总和,使用了矩阵的乘积运算 NumPy 中是 np.dot() , 参照代码如下: In [7]: X = np.random.ran ...

  8. 批标准归一化(Batch Normalization)解析

    1,背景 网络一旦train起来,那么参数就要发生更新,除了输入层的数据外(因为输入层数据,我们已经人为的为每个样本归一化),后面网络每一层的输入数据分布是一直在发生变化的,因为在训练的时候,前面层训 ...

  9. AlexNet中的局部响应归一化(LRN)

    1,简介 局部响应归一化(Local Response Normalization,LRN),提出于2012年的AlexNet中.首先要引入一个神经生物学的概念:侧抑制(lateral inhibit ...

最新文章

  1. Laravel 加载第三方类库的方法
  2. python在excel中的应用-python中的excel操作
  3. Kali Linux GRUB修复
  4. 达摩院 AI 进入中国科技馆,首张 AI 识别新冠 CT 成科技抗疫历史见证
  5. 中断处理函数_ARM的中断处理 [二]
  6. 带有.NET Core 3和Electron.NET的多平台桌面HTML编辑器
  7. ASP.NET MVC学习---(一)ORM框架,EF实体数据模型简介
  8. 吴恩达机器学习(六)神经网络(前向传播)
  9. NUC1076 LCD-Display【打印图案】
  10. VMWaer克隆centos后网络的问题解决
  11. css实现LED液晶数码字体
  12. 联想服务器风扇智能调节,联想怎么调风扇转速
  13. gradle下载很慢的解决方式
  14. 用极大似然法估计因子载荷矩阵_关于因子分析|stata
  15. ubuntu 安装 网易云音乐
  16. 1人30天44587行代码,分享舍得网开发过程
  17. MySQL数据库维护手册
  18. 软考高项 : (22)2016年下半年论文写作真题
  19. 【非原创】Ubuntu14.04+cuda6.5+opencv2.4.9+caffe配置记录
  20. 5口千兆工业以太网交换机宽温导轨式二层非网管全千兆工业级交换机

热门文章

  1. Android--AudioManager控制音量
  2. 新手坐高铁怎么找车厢_一女子坐高铁回桂平坐过站,到了平南南站,怎么办?...
  3. linux系统编码修改
  4. 计算机原理指令系统测试卷,计算机组成原理(下)第7章 指令系统测试
  5. oracle 某天 减一天,案例一:shell脚本指定日期减去一天
  6. python多态的例子_Python编程之多态用法实例详解
  7. Mysql的sql注入_MySQL SQL注入
  8. 技师学院计算机老师,技师学院计算机教学课堂改革探索论文
  9. linux redis ruby,redis requires ruby version 2.2.2的解决方案
  10. java webdriver page object_Selenium+PageObject+Java实现测试用例