自定义报错返回_Keras编写自定义层--以GroupNormalization为例
1. Group Normalization 介绍
Batch Normalization(BN)称为批量归一化,可加速网络收敛利于网络训练。但BN的误差会随着批量batch的减小而迅速增大。FAIR 研究工程师吴育昕和研究科学家何恺明合作的一篇论文 提出了一种新的与批量无关的Normalization 方法-[[1803.08494] Group Normalization]。GN 的主要工作是将通道分成组,并在每组内计算归一化的均值和方差。GN 的计算与批量大小无关,并且其准确度在各种批量大小下都很稳定。具体如下图(摘自论文):
2. Keras自定义层方法
关于Keras中如何自定义层,可参考官方中文文档[编写你自己的层 - Keras 中文文档 ],[Keras简单自定义层例子]。自定义层中主要包括4种方法:
- __init__(**kwargs):初始化方法,关键字参数保留,否则自定义层加载会报错。
- build(input_shape):用于定义权重的方法
- call(x): 自定义层具体功能的实现方法
- get_config : 返回一个字典,获取当前层的参数信息。自定义层保存和加载时需要定义
- compute_output_shape(input_shape):用于Keras可以自动推断shape
自定义层的保存和加载需要注意以下3点:
- __init__(self, arg, **kwargs)初始化方法中关键字参数保留,否则自定义层加载会报错。
缺少**kwargs,TypeError: __init__() got an unexpected keyword argument 'name'
- get_config(self)方法需要重写,否则网络结构无法保存。父类的config也需一并保存,将父类及继承类的config组装为字典形式,继承类config依据__init__方法传入的参数而定,具体如下:
缺少get_config方法,NotImplementedError: Layers with arguments in `__init__` must override `get_config`.
def get_config(self):base_config = super(LayerName, self).get_config() #父类config字典base_config['arg'] = self.arg #继承类config字典,__init__传入参数argreturn base_config #返回组装后的字典
- load_model()需为custom_objects参数赋值
缺少custom_objects,ValueError: Unknown layer: LayerName
_custom_objects = {"LayerName":LayerName} #定义custom_objects
model = keras.models.load_model(model_path, custom_objects=_custom_objects) #加载模型
Keras官网提供了两种Normalization的源码,分别是:
- 批量归一化 Keras-BatchNormalization
- 实例归一化 Keras-InstanceNormalization
两者的不同在于IN的统计量估算是批量无关的基于单张图片单个通道,不需要用滑动平均项来记录全局的统计量,体现在源码的差异为:
# BN code
class BatchNormalization(Layer):def __init__(self,**kwargs):super(BatchNormalization, self).__init__(**kwargs)...def build(self, input_shape):...self.gamma = self.add_weight(...)self.beta = self.add_weight(...)self.moving_mean = self.add_weight(trainable=False)self.moving_variance = self.add_weight(trainable=False)...def call(sekf, x):...self.add_update([K.moving_average_update(self.moving_mean, mean,self.momentum),K.moving_average_update(self.moving_variance,variance,self.momentum)],inputs)...return K.in_train_phase(...)
# IN code
class BatchNormalization(Layer):def __init__(self,**kwargs):super(InstanceNormalization, self).__init__(**kwargs)...def build(self, input_shape):...self.gamma = self.add_weight(...)self.beta = self.add_weight(...)...
解释:
- 所有自定义层都需要继承基础层Layer,并添加super().__init__(**kwargs)
- **kwargs代表以字典方式继承父类
- self.add_weight()是继承层Layer的方法,用于为变量添加权重,其中有参数trainable代表该参数的权重是否为可训练权重; 若trainable==True时,会执行self._trainable_weights.append(weight).
- BN中需要添加moving_mean/variance滑动平均项的权重,且需要设置trainable==False,即为非训练参数。
- self.add_update()用于更新滑动平均项
- K.in_train_phase()针对训练状态选择不同的mean/variance计算BN
3. 定义Group Normalization层
源代码位置Bingohong/GroupNormalization-tensorflow-keras,里面包含了2个GN文件,分别是tensorflow和keras的实现版本,其中都包含了moving_average操作。
其实关于GN操作,是否需要apply moving_average是值得商榷的,论文中貌似没有明确提及,其他实现版本中的实现都是无moving_average操作。但通过对比IN、BN和GN特点及后期的实验对比,觉得GN应该是不需要moving_average操作的。因此这部分内容包括:
- 主要介绍有moving_average操作的GN层的定义过程,而无moving_average操作时,只需要将对应的代码去掉。
- 使用BN/GN_with_moving_average/GN_without_moving_average3种Normalization方法,对比U-net的实验结果。
keras GN层
完整代码在这里,以下仅解释部分关键代码。
# GN_with_moving_average code
class GroupNormalization(Layer):def __init__(self,**kwargs):super(GroupNormalization, self).__init__(**kwargs)...def build(self, input_shape):...shape = (self.groups,)broadcast_shape = [-1, self.groups, 1, 1, 1]# 添加滑动平均项参数,并设置为非训练参数# 后续的K.reshape和K.variable操作,是为了在call()方法内进行add_update()时保证# self.moving_mean/variance的维度与inputs一致,且为variable变量self.moving_mean = self.add_weight(shape=shape,trainable=False)self.moving_mean = K.reshape(self.moving_mean,broadcast_shape)self.moving_mean = K.variable(value=self.moving_mean)self.moving_variance = self.add_weight(shape=shape,trainable=False)self.moving_variance = K.reshape(self.moving_variance,broadcast_shape)self.moving_variance = K.variable(value=self.moving_variance)...def call(sekf, inputs):G = self.groups# transpose:[ba,h,w,c] -> [bs,c,h,w]if self.axis in {-1,3}:inputs = K.permute_dimensions(inputs,(0,3,1,2))# GN操作需要根据groups对通道分组input_shape = K.int_shape(inputs)N, C, H, W = input_shapeinputs = K.reshape(inputs,(-1, G, C // G, H, W))#inputs.assign_sub()# 计算分组通道的均值和方差gn_mean = K.mean(inputs,axis=[2,3,4],keepdims=True)gn_variance = K.var(inputs,axis=[2,3,4],keepdims=True)# 当模型用于测试阶段时,使用moving_mean/variance记录的均值/方差def gn_inference():# when in test phase, just return moving_mean & moving_varmean, variance = self.moving_mean, self.moving_varianceoutputs = (inputs - mean) / (K.sqrt(variance + self.epsilon))outputs = K.reshape(outputs,[-1, C, H, W]) * self.gamma + self.beta# transpose: [bs,c,h,w] -> [ba,h,w,c]if self.axis in {-1,3}:outputs = K.permute_dimensions(outputs,(0,2,3,1))return outputsif training in {0,False}:return gn_inference()# 当模型用于训练阶段时,使用分组通道实时计算均值/方差outputs = (inputs - gn_mean) / (K.sqrt(gn_variance + self.epsilon))outputs = K.reshape(outputs,[-1, C, H, W]) * self.gamma + self.beta # transpose: [bs,c,h,w] -> [ba,h,w,c]if self.axis in {-1,3}:outputs = K.permute_dimensions(outputs,(0,2,3,1))# 手动更新self.moving_mean/varianceself.add_update([K.moving_average_update(self.moving_mean, mean,self.momentum),K.moving_average_update(self.moving_variance,variance,self.momentum)],inputs)# 根据模型状态不同选择不同的GN计算方法,train时选择outputs,test时选择gn_inferencereturn K.in_train_phase(outputs, gn_inference,training=training)
实验对比结果
实验日志位于compare_log,包含3个文件:
- train_bn.log -> unet+bn日志
- train_gn_ema.log -> unet+gn(有moving_average操作)
- train_gn_noema.log -> unet+gn(无moving_average操作)
结果说明:
- gn without moving average 得到的val_loss会更低,可达到 0.2左右
- gn with moving average 有时会一直存在很高的val_loss, 所以我觉得可能GN并不需要 apply moving average
- bn 得到的val_ loss约为 0.26, 高于gn without moving average.
欢迎大家批评指正~ 谢谢谢谢~~~
自定义报错返回_Keras编写自定义层--以GroupNormalization为例相关推荐
- 自定义报错返回_MybatisPlus基础篇学习笔记(五)------自定义sql及分页查询
本章目录 自定义sql 分页查询 1. 自定义sql 在dao文件中编写自定义接口,并在方法上使用注解形式注入SQL,如图所示: 第一种: 第二种 ① application.yml加入下面配置 my ...
- 自定义报错返回_Spring Cloud Feign的使用和自定义配置
在上一篇文章 null:Spring Cloud 自定义Eureka Ribbon负载均衡策略zhuanlan.zhihu.com 中,我们使用Ribbon自定义的策略实现了负载均衡,接下来我们介绍 ...
- Nuxt.js - 最新自定义报错、缺省、404、500 定制化 error.vue(页面、接口报错时自动跳转到该自定义页面)支持自定义文案、状态码等功能
前言 在开发 Nuxt.js 时,当页面出错或接口后台数据返回异常时,页面就会 "直接呈现" 报错的信息. 正常情况下,当页面 404.500 或页面报错时, 前端应该 自动跳转到 ...
- spring boot整合SpringSecurity-03 自定义报错信息
spring boot整合SpringSecurity 目录 spring boot整合SpringSecurity-01入门 spring boot整合SpringSecurity-02 基于Ser ...
- python def函数报错详解_python自定义函数def的应用详解
这篇文章主要介绍了python自定义函数def的应用详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧 这里是三岁,来和大家唠唠 ...
- 'int' object has no attribute 'backward'报错 使用Pytorch编写 Hinge loss函数
在编写SVM中的Hinge loss函数的时候报错"'int' object has no attribute 'backward'" for epoch in range(50) ...
- Linux系统 安装飞桨PaddleHub+LAC实现词法分析 实现加载自定义词典分词 (解决Lac服务启动报错问题、解决自定义词典空格无法分词问题)
1.先上链接:飞桨PaddlePaddle-源于产业实践的开源深度学习平台 2.LAC模型简介:Lexical Analysis of Chinese,简称 LAC,是一个联合的词法分析模型,能整体性 ...
- java 8流自定义收集器_Java 8编写自定义收集器简介
java 8流自定义收集器 Java 8引入了收集器的概念. 大多数时候,我们几乎不使用Collectors类中的工厂方法,例如collect(toList()) , toSet()或其他更有趣的方法 ...
- $Ajax构成和请求报错返回值
$.ajax({ type: "post", url: sUrl, data: po ...
最新文章
- python网络编程——简单例子
- 网站制作基本要素了解一下
- 797C C. Minimal string
- MATLAB实战系列(十二)-如何用人工鱼群算法解决带时间窗车辆路径(CVRP)问题(附MATLAB代码)
- 10.ASCII码对照
- Qt编写OpenMP程序--HelloWorld
- solr 启动时指定 solr.home
- [react] 在React中如何避免不必要的render?
- 原来嵌套个网页的技术是这样的
- 国家生物信息中心在核酸研究发表单细胞DNA甲基化数据库—scMethBank
- 函数最值题目及答案_高考数学攻克压轴题:圆锥曲线取值范围和最值问题解题模型...
- 移动花卡服务器系统异常,开通了抖音移动花卡免流服务,为什么使用抖音不显示免流呢?...
- U盘, USB读卡器, U盘读卡器三者技术分析区别
- 关于GTP-4,这是14个被忽略的惊人细节!
- SVG格式文件插入Word/WPS,三种简单快捷的方法,实现图片高清无损
- 【NOI2015】BZOJ4199品酒大会题解(SAM+树形DP)
- 华为6 有没有计算机,华为手机连电脑没有usb存储 华为手机连电脑不显示usb存储怎么回事 - 云骑士一键重装系统...
- 我现在必须new一个对象!!!
- C#使用西门子S7 协议读写PLC DB块
- 2021全年营收净利润双增,李宁财报透露哪些确定与不定?
热门文章
- python3socket非阻塞_利用Python中SocketServer 实现客户端与服务器间非阻塞通信
- C++中string erase函数的使用
- 复制已有的Tomcat作为新的Tomcat,只需修改三个配置文件,五步操作,保证能正常运行!
- Maya游戏角色绑定入门学习教程 Game Character Rigging for Beginners in Maya
- 设计模式:简单工厂、工厂方法、抽象工厂之小结与区别
- Go 分布式学习利器(10)-- Go语言的接口
- [Java in NetBeans] Lesson 01. Java Programming Basics
- spark ml中一个比较通用的transformer
- IIS 7.5 去掉index.php 西数服务器
- WinForm绘制带有升序、降序的柱形图