tf keras Dense源码解析
环境
package | version |
---|---|
tensorflow | 2.3.0 |
keras | 2.4.3 |
源码
class Dense(Layer):def __init__(self,units,activation=None,use_bias=True,kernel_initializer='glorot_uniform',bias_initializer='zeros',kernel_regularizer=None,bias_regularizer=None,activity_regularizer=None,kernel_constraint=None,bias_constraint=None,**kwargs):super(Dense, self).__init__(activity_regularizer=activity_regularizer, **kwargs)self.units = int(units) if not isinstance(units, int) else unitsself.activation = activations.get(activation)self.use_bias = use_biasself.kernel_initializer = initializers.get(kernel_initializer)self.bias_initializer = initializers.get(bias_initializer)self.kernel_regularizer = regularizers.get(kernel_regularizer)self.bias_regularizer = regularizers.get(bias_regularizer)self.kernel_constraint = constraints.get(kernel_constraint)self.bias_constraint = constraints.get(bias_constraint)self.input_spec = InputSpec(min_ndim=2)self.supports_masking = Truedef build(self, input_shape):dtype = dtypes.as_dtype(self.dtype or K.floatx())if not (dtype.is_floating or dtype.is_complex):raise TypeError('Unable to build `Dense` layer with non-floating point ''dtype %s' % (dtype,))input_shape = tensor_shape.TensorShape(input_shape)last_dim = tensor_shape.dimension_value(input_shape[-1])if last_dim is None:raise ValueError('The last dimension of the inputs to `Dense` ''should be defined. Found `None`.')self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})self.kernel = self.add_weight('kernel',shape=[last_dim, self.units],initializer=self.kernel_initializer,regularizer=self.kernel_regularizer,constraint=self.kernel_constraint,dtype=self.dtype,trainable=True)if self.use_bias:self.bias = self.add_weight('bias',shape=[self.units,],initializer=self.bias_initializer,regularizer=self.bias_regularizer,constraint=self.bias_constraint,dtype=self.dtype,trainable=True)else:self.bias = Noneself.built = Truedef call(self, inputs):return core_ops.dense(inputs,self.kernel,self.bias,self.activation,dtype=self._compute_dtype_object)def compute_output_shape(self, input_shape):input_shape = tensor_shape.TensorShape(input_shape)input_shape = input_shape.with_rank_at_least(2)if tensor_shape.dimension_value(input_shape[-1]) is None:raise ValueError('The innermost dimension of input_shape must be defined, but saw: %s'% input_shape)return input_shape[:-1].concatenate(self.units)def get_config(self):config = super(Dense, self).get_config()config.update({'units':self.units,'activation':activations.serialize(self.activation),'use_bias':self.use_bias,'kernel_initializer':initializers.serialize(self.kernel_initializer),'bias_initializer':initializers.serialize(self.bias_initializer),'kernel_regularizer':regularizers.serialize(self.kernel_regularizer),'bias_regularizer':regularizers.serialize(self.bias_regularizer),'activity_regularizer':regularizers.serialize(self.activity_regularizer),'kernel_constraint':constraints.serialize(self.kernel_constraint),'bias_constraint':constraints.serialize(self.bias_constraint)})return config
查看源码可以看到最简单的Dense总共有四个方法
- init 初始该层
- build 初始weight和bias
- call 计算
- get_config 获取config
init
创建时各个参数的含义
parms | x |
---|---|
units | 激活单元 |
activation | 激活函数 |
use_bias | 是否用偏移量 |
initializer | 矩阵初始化的方法 |
regularizer | 权重正则化的方法 |
constraint | 限制方法 |
build
初始化后就可以创建权重矩阵和偏移矩阵了(weight bias),主要运用的add_weight方法
call
计算,用的是core_ops.dense方法,以下是dense源码
def dense(inputs, kernel, bias=None, activation=None, dtype=None):if dtype:if inputs.dtype.base_dtype != dtype.base_dtype:inputs = math_ops.cast(inputs, dtype=dtype)rank = inputs.shape.rankif rank == 2 or rank is None:if isinstance(inputs, sparse_tensor.SparseTensor):outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, kernel)else:outputs = gen_math_ops.mat_mul(inputs, kernel)# Broadcast kernel to inputs.else:outputs = standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]])# Reshape the output back to the original ndim of the input.if not context.executing_eagerly():shape = inputs.shape.as_list()output_shape = shape[:-1] + [kernel.shape[-1]]outputs.set_shape(output_shape)if bias is not None:outputs = nn_ops.bias_add(outputs, bias)if activation is not None:outputs = activation(outputs)return outputs## TODO:乘法区别
这里input是个tensor,所以有rank变量,rank即tensor是几维的
一个是 sparse_ops.sparse_tensor_dense_matmul 和 gen_math_ops.mat_mul
一个是 standard_ops.tensordot
compute_output_shape
根据input和units,计算output_shape
get_config
返回config dict
tf keras Dense源码解析相关推荐
- tf keras SimpleRNN源码解析
环境 package version tensorflow 2.3.0 keras 2.4.3 源码 部分主要源码 class RNN(Layer):def __init__(self,cell,re ...
- 谷歌BERT预训练源码解析(二):模型构建
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_39470744/arti ...
- 谷歌BERT预训练源码解析(三):训练过程
目录 前言 源码解析 主函数 自定义模型 遮蔽词预测 下一句预测 规范化数据集 前言 本部分介绍BERT训练过程,BERT模型训练过程是在自己的TPU上进行的,这部分我没做过研究所以不做深入探讨.BE ...
- The Wide and Deep Learning Model(译文+Tensorlfow源码解析) 原创 2017年11月03日 22:14:47 标签: 深度学习 / 谷歌 / tensorf
The Wide and Deep Learning Model(译文+Tensorlfow源码解析) 原创 2017年11月03日 22:14:47 标签: 深度学习 / 谷歌 / tensorfl ...
- [源码解析] NVIDIA HugeCTR,GPU版本参数服务器--- (5) 嵌入式hash表
[源码解析] NVIDIA HugeCTR,GPU版本参数服务器- (5) 嵌入式hash表 文章目录 [源码解析] NVIDIA HugeCTR,GPU版本参数服务器--- (5) 嵌入式hash表 ...
- 《Attention is all you need》源码解析+算法详解
Attention is all you need 源码解析 最近学习Transformer模型的时候,并且好好读了一下Google的<Attention is all you need> ...
- [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现
[源码解析] 深度学习流水线并行Gpipe(1)-流水线基本实现 文章目录 [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现 0x00 摘要 0x01 概述 1.1 什么是GPip ...
- [源码解析] TensorFlow 分布式之 ClusterCoordinator
[源码解析] TensorFlow 分布式之 ClusterCoordinator 文章目录 [源码解析] TensorFlow 分布式之 ClusterCoordinator 1. 思路 1.1 使 ...
- 谷歌BERT预训练源码解析(一):训练数据生成
目录 预训练源码结构简介 输入输出 源码解析 参数 主函数 创建训练实例 下一句预测&实例生成 随机遮蔽 输出 结果一览 预训练源码结构简介 关于BERT,简单来说,它是一个基于Transfo ...
最新文章
- Codeforces Round #539 (Div. 2) C. Sasha and a Bit of Relax
- 管理-Tomcat和Resin如何配置对指定后缀文件(如:.pptx)下载支持
- 游戏中应用强化学习技术,目的就是要打败人类玩家?
- java linux路径 home_根据linux自带的JDK,配置JAVA_HOME目录
- 3.菜鸟教你一步一步开发 web service 之 axis 服务端创建
- 使用yum时,保留下载包设置
- 【LaTeX】E喵的LaTeX新手入门教程(2)基础排版
- Test Article
- 硬件服务器采购指南,硬件组装_服务器采购指南_太平洋电脑网PConline
- 基于融合计算?蚂蚁金服的在线机器学习是如何做的
- [android] 请求码和结果码的作用
- JavaScript基础面试题
- 拓端tecdat|R语言基于协方差的结构方程拟合的卡方检验
- PostgreSQL:安装
- cmd命令实现百度云盘光速下载
- 夏日汽车保养 雨季汽车保养
- 【技术文档】《算法设计与分析导论》R.C.T.Lee等·第4章 分治策略
- 获得Windows主机的主机序列号
- 【RW007系列综合实战1】STM32+RW007实现BLE透传功能
- HTML5+CSS3小实例:酷炫的菱形加载动画
热门文章
- python(matplotlib6)——打印图像(imshow)3D数据(contourf等高线)
- 前、中、后缀表达式概述及转换+栈的计算器原理及代码分析(含完整源码)
- python pyqt5 窗体自适应_Pyqt5自适应布局实例
- 机器学习中为什么需要梯度下降_机器学习,梯度下降算法,问题引入
- Git 索引文件(index file)
- 软件工程 / 为什么基于接口而非实现编程?
- Tyznn人脸识别温度检测智能门禁系统现货发售,助力疫情防控
- python画误差棒_给妹子讲python-S02E06matplotlib散点图、频次直方图与误差线图
- rust油桶用什么打_选什么样的柜子才好用?别再选定制柜了,还是老手艺人打的柜子好...
- mfc 子窗体任何消息都不触发_winform让窗体一直显示在桌面上以及FindWindow