1. Group Normalization 介绍

Batch Normalization(BN)称为批量归一化,可加速网络收敛利于网络训练。但BN的误差会随着批量batch的减小而迅速增大。FAIR 研究工程师吴育昕和研究科学家何恺明合作的一篇论文 提出了一种新的与批量无关的Normalization 方法-[[1803.08494] Group Normalization]。GN 的主要工作是将通道分成组,并在每组内计算归一化的均值和方差。GN 的计算与批量大小无关,并且其准确度在各种批量大小下都很稳定。具体如下图(摘自论文):

BN 时的小批量会导致批量数据的统计两估算不准确,会显著增加模型误差。而无批量无关的GN方法得到的误差则相对稳定。

2. Keras自定义层方法

关于Keras中如何自定义层,可参考官方中文文档[编写你自己的层 - Keras 中文文档 ],[Keras简单自定义层例子]。自定义层中主要包括4种方法:

  • __init__(**kwargs):初始化方法,关键字参数保留,否则自定义层加载会报错。
  • build(input_shape):用于定义权重的方法
  • call(x): 自定义层具体功能的实现方法
  • get_config : 返回一个字典,获取当前层的参数信息。自定义层保存和加载时需要定义
  • compute_output_shape(input_shape):用于Keras可以自动推断shape

自定义层的保存和加载需要注意以下3点:

  • __init__(self, arg, **kwargs)初始化方法中关键字参数保留,否则自定义层加载会报错。

缺少**kwargs,TypeError: __init__() got an unexpected keyword argument 'name'

  • get_config(self)方法需要重写,否则网络结构无法保存。父类的config也需一并保存,将父类及继承类的config组装为字典形式,继承类config依据__init__方法传入的参数而定,具体如下:

缺少get_config方法,NotImplementedError: Layers with arguments in `__init__` must override `get_config`.

def get_config(self):base_config = super(LayerName, self).get_config() #父类config字典base_config['arg'] = self.arg #继承类config字典,__init__传入参数argreturn base_config #返回组装后的字典

  • load_model()需为custom_objects参数赋值

缺少custom_objects,ValueError: Unknown layer: LayerName

_custom_objects = {"LayerName":LayerName} #定义custom_objects
model = keras.models.load_model(model_path, custom_objects=_custom_objects) #加载模型


Keras官网提供了两种Normalization的源码,分别是:

  • 批量归一化 Keras-BatchNormalization
  • 实例归一化 Keras-InstanceNormalization

两者的不同在于IN的统计量估算是批量无关的基于单张图片单个通道,不需要用滑动平均项来记录全局的统计量,体现在源码的差异为:

# BN code
class BatchNormalization(Layer):def __init__(self,**kwargs):super(BatchNormalization, self).__init__(**kwargs)...def build(self, input_shape):...self.gamma = self.add_weight(...)self.beta = self.add_weight(...)self.moving_mean = self.add_weight(trainable=False)self.moving_variance = self.add_weight(trainable=False)...def call(sekf, x):...self.add_update([K.moving_average_update(self.moving_mean, mean,self.momentum),K.moving_average_update(self.moving_variance,variance,self.momentum)],inputs)...return K.in_train_phase(...)

# IN code
class BatchNormalization(Layer):def __init__(self,**kwargs):super(InstanceNormalization, self).__init__(**kwargs)...def build(self, input_shape):...self.gamma = self.add_weight(...)self.beta = self.add_weight(...)...

解释:

  • 所有自定义层都需要继承基础层Layer,并添加super().__init__(**kwargs)
  • **kwargs代表以字典方式继承父类
  • self.add_weight()是继承层Layer的方法,用于为变量添加权重,其中有参数trainable代表该参数的权重是否为可训练权重; 若trainable==True时,会执行self._trainable_weights.append(weight).
  • BN中需要添加moving_mean/variance滑动平均项的权重,且需要设置trainable==False,即为非训练参数。
  • self.add_update()用于更新滑动平均项
  • K.in_train_phase()针对训练状态选择不同的mean/variance计算BN

3. 定义Group Normalization层

源代码位置Bingohong/GroupNormalization-tensorflow-keras,里面包含了2个GN文件,分别是tensorflow和keras的实现版本,其中都包含了moving_average操作。

其实关于GN操作,是否需要apply moving_average是值得商榷的,论文中貌似没有明确提及,其他实现版本中的实现都是无moving_average操作。但通过对比IN、BN和GN特点及后期的实验对比,觉得GN应该是不需要moving_average操作的。因此这部分内容包括:

  • 主要介绍有moving_average操作的GN层的定义过程,而无moving_average操作时,只需要将对应的代码去掉。
  • 使用BN/GN_with_moving_average/GN_without_moving_average3种Normalization方法,对比U-net的实验结果。

keras GN层

完整代码在这里,以下仅解释部分关键代码。

# GN_with_moving_average code
class GroupNormalization(Layer):def __init__(self,**kwargs):super(GroupNormalization, self).__init__(**kwargs)...def build(self, input_shape):...shape = (self.groups,)broadcast_shape = [-1, self.groups, 1, 1, 1]# 添加滑动平均项参数,并设置为非训练参数# 后续的K.reshape和K.variable操作,是为了在call()方法内进行add_update()时保证# self.moving_mean/variance的维度与inputs一致,且为variable变量self.moving_mean = self.add_weight(shape=shape,trainable=False)self.moving_mean = K.reshape(self.moving_mean,broadcast_shape)self.moving_mean = K.variable(value=self.moving_mean)self.moving_variance = self.add_weight(shape=shape,trainable=False)self.moving_variance = K.reshape(self.moving_variance,broadcast_shape)self.moving_variance = K.variable(value=self.moving_variance)...def call(sekf, inputs):G = self.groups# transpose:[ba,h,w,c] -> [bs,c,h,w]if self.axis in {-1,3}:inputs = K.permute_dimensions(inputs,(0,3,1,2))# GN操作需要根据groups对通道分组input_shape = K.int_shape(inputs)N, C, H, W = input_shapeinputs = K.reshape(inputs,(-1, G, C // G, H, W))#inputs.assign_sub()# 计算分组通道的均值和方差gn_mean = K.mean(inputs,axis=[2,3,4],keepdims=True)gn_variance = K.var(inputs,axis=[2,3,4],keepdims=True)# 当模型用于测试阶段时,使用moving_mean/variance记录的均值/方差def gn_inference():# when in test phase, just return moving_mean & moving_varmean, variance = self.moving_mean, self.moving_varianceoutputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))outputs = K.reshape(outputs,[-1, C, H, W]) * self.gamma + self.beta# transpose: [bs,c,h,w] -> [ba,h,w,c]if self.axis in {-1,3}:outputs = K.permute_dimensions(outputs,(0,2,3,1))return outputsif training in {0,False}:return gn_inference()# 当模型用于训练阶段时,使用分组通道实时计算均值/方差outputs = (inputs - gn_mean) / (K.sqrt(gn_variance + self.epsilon))outputs = K.reshape(outputs,[-1, C, H, W]) * self.gamma + self.beta # transpose: [bs,c,h,w] -> [ba,h,w,c]if self.axis in {-1,3}:outputs = K.permute_dimensions(outputs,(0,2,3,1))# 手动更新self.moving_mean/varianceself.add_update([K.moving_average_update(self.moving_mean, mean,self.momentum),K.moving_average_update(self.moving_variance,variance,self.momentum)],inputs)# 根据模型状态不同选择不同的GN计算方法,train时选择outputs,test时选择gn_inferencereturn K.in_train_phase(outputs, gn_inference,training=training)


实验对比结果

实验日志位于compare_log,包含3个文件:

  • train_bn.log -> unet+bn日志
  • train_gn_ema.log -> unet+gn(有moving_average操作)
  • train_gn_noema.log -> unet+gn(无moving_average操作)

结果说明:

  • gn without moving average 得到的val_loss会更低,可达到 0.2左右
  • gn with moving average 有时会一直存在很高的val_loss, 所以我觉得可能GN并不需要 apply moving average
  • bn 得到的val_ loss约为 0.26, 高于gn without moving average.

欢迎大家批评指正~ 谢谢谢谢~~~

自定义报错返回_Keras编写自定义层--以GroupNormalization为例相关推荐

  1. 自定义报错返回_MybatisPlus基础篇学习笔记(五)------自定义sql及分页查询

    本章目录 自定义sql 分页查询 1. 自定义sql 在dao文件中编写自定义接口,并在方法上使用注解形式注入SQL,如图所示: 第一种: 第二种 ① application.yml加入下面配置 my ...

  2. 自定义报错返回_Spring Cloud Feign的使用和自定义配置

    在上一篇文章 null:Spring Cloud 自定义Eureka Ribbon负载均衡策略​zhuanlan.zhihu.com 中,我们使用Ribbon自定义的策略实现了负载均衡,接下来我们介绍 ...

  3. Nuxt.js - 最新自定义报错、缺省、404、500 定制化 error.vue(页面、接口报错时自动跳转到该自定义页面)支持自定义文案、状态码等功能

    前言 在开发 Nuxt.js 时,当页面出错或接口后台数据返回异常时,页面就会 "直接呈现" 报错的信息. 正常情况下,当页面 404.500 或页面报错时, 前端应该 自动跳转到 ...

  4. spring boot整合SpringSecurity-03 自定义报错信息

    spring boot整合SpringSecurity 目录 spring boot整合SpringSecurity-01入门 spring boot整合SpringSecurity-02 基于Ser ...

  5. python def函数报错详解_python自定义函数def的应用详解

    这篇文章主要介绍了python自定义函数def的应用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧 这里是三岁,来和大家唠唠 ...

  6. 'int' object has no attribute 'backward'报错 使用Pytorch编写 Hinge loss函数

    在编写SVM中的Hinge loss函数的时候报错"'int' object has no attribute 'backward'" for epoch in range(50) ...

  7. Linux系统 安装飞桨PaddleHub+LAC实现词法分析 实现加载自定义词典分词 (解决Lac服务启动报错问题、解决自定义词典空格无法分词问题)

    1.先上链接:飞桨PaddlePaddle-源于产业实践的开源深度学习平台 2.LAC模型简介:Lexical Analysis of Chinese,简称 LAC,是一个联合的词法分析模型,能整体性 ...

  8. java 8流自定义收集器_Java 8编写自定义收集器简介

    java 8流自定义收集器 Java 8引入了收集器的概念. 大多数时候,我们几乎不使用Collectors类中的工厂方法,例如collect(toList()) , toSet()或其他更有趣的方法 ...

  9. $Ajax构成和请求报错返回值

    $.ajax({                 type: "post",                 url: sUrl,                 data: po ...

最新文章

  1. python网络编程——简单例子
  2. 网站制作基本要素了解一下
  3. 797C C. Minimal string
  4. MATLAB实战系列(十二)-如何用人工鱼群算法解决带时间窗车辆路径(CVRP)问题(附MATLAB代码)
  5. 10.ASCII码对照
  6. Qt编写OpenMP程序--HelloWorld
  7. solr 启动时指定 solr.home
  8. [react] 在React中如何避免不必要的render?
  9. 原来嵌套个网页的技术是这样的
  10. 国家生物信息中心在核酸研究发表单细胞DNA甲基化数据库—scMethBank
  11. 函数最值题目及答案_高考数学攻克压轴题:圆锥曲线取值范围和最值问题解题模型...
  12. 移动花卡服务器系统异常,开通了抖音移动花卡免流服务,为什么使用抖音不显示免流呢?...
  13. U盘, USB读卡器, U盘读卡器三者技术分析区别
  14. 关于GTP-4,这是14个被忽略的惊人细节!
  15. SVG格式文件插入Word/WPS,三种简单快捷的方法,实现图片高清无损
  16. 【NOI2015】BZOJ4199品酒大会题解(SAM+树形DP)
  17. 华为6 有没有计算机,华为手机连电脑没有usb存储 华为手机连电脑不显示usb存储怎么回事 - 云骑士一键重装系统...
  18. 我现在必须new一个对象!!!
  19. C#使用西门子S7 协议读写PLC DB块
  20. 2021全年营收净利润双增,李宁财报透露哪些确定与不定?

热门文章

  1. python3socket非阻塞_利用Python中SocketServer 实现客户端与服务器间非阻塞通信
  2. C++中string erase函数的使用
  3. 复制已有的Tomcat作为新的Tomcat,只需修改三个配置文件,五步操作,保证能正常运行!
  4. Maya游戏角色绑定入门学习教程 Game Character Rigging for Beginners in Maya
  5. 设计模式:简单工厂、工厂方法、抽象工厂之小结与区别
  6. Go 分布式学习利器(10)-- Go语言的接口
  7. [Java in NetBeans] Lesson 01. Java Programming Basics
  8. spark ml中一个比较通用的transformer
  9. IIS 7.5 去掉index.php 西数服务器
  10. WinForm绘制带有升序、降序的柱形图