环境
package version
tensorflow 2.3.0
keras 2.4.3
源码
class Dense(Layer):def __init__(self,units,activation=None,use_bias=True,kernel_initializer='glorot_uniform',bias_initializer='zeros',kernel_regularizer=None,bias_regularizer=None,activity_regularizer=None,kernel_constraint=None,bias_constraint=None,**kwargs):super(Dense, self).__init__(activity_regularizer=activity_regularizer, **kwargs)self.units = int(units) if not isinstance(units, int) else unitsself.activation = activations.get(activation)self.use_bias = use_biasself.kernel_initializer = initializers.get(kernel_initializer)self.bias_initializer = initializers.get(bias_initializer)self.kernel_regularizer = regularizers.get(kernel_regularizer)self.bias_regularizer = regularizers.get(bias_regularizer)self.kernel_constraint = constraints.get(kernel_constraint)self.bias_constraint = constraints.get(bias_constraint)self.input_spec = InputSpec(min_ndim=2)self.supports_masking = Truedef build(self, input_shape):dtype = dtypes.as_dtype(self.dtype or K.floatx())if not (dtype.is_floating or dtype.is_complex):raise TypeError('Unable to build `Dense` layer with non-floating point ''dtype %s' % (dtype,))input_shape = tensor_shape.TensorShape(input_shape)last_dim = tensor_shape.dimension_value(input_shape[-1])if last_dim is None:raise ValueError('The last dimension of the inputs to `Dense` ''should be defined. Found `None`.')self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})self.kernel = self.add_weight('kernel',shape=[last_dim, self.units],initializer=self.kernel_initializer,regularizer=self.kernel_regularizer,constraint=self.kernel_constraint,dtype=self.dtype,trainable=True)if self.use_bias:self.bias = self.add_weight('bias',shape=[self.units,],initializer=self.bias_initializer,regularizer=self.bias_regularizer,constraint=self.bias_constraint,dtype=self.dtype,trainable=True)else:self.bias = Noneself.built = Truedef call(self, inputs):return core_ops.dense(inputs,self.kernel,self.bias,self.activation,dtype=self._compute_dtype_object)def compute_output_shape(self, input_shape):input_shape = tensor_shape.TensorShape(input_shape)input_shape = input_shape.with_rank_at_least(2)if tensor_shape.dimension_value(input_shape[-1]) is None:raise ValueError('The innermost dimension of input_shape must be defined, but saw: %s'% input_shape)return input_shape[:-1].concatenate(self.units)def get_config(self):config = super(Dense, self).get_config()config.update({'units':self.units,'activation':activations.serialize(self.activation),'use_bias':self.use_bias,'kernel_initializer':initializers.serialize(self.kernel_initializer),'bias_initializer':initializers.serialize(self.bias_initializer),'kernel_regularizer':regularizers.serialize(self.kernel_regularizer),'bias_regularizer':regularizers.serialize(self.bias_regularizer),'activity_regularizer':regularizers.serialize(self.activity_regularizer),'kernel_constraint':constraints.serialize(self.kernel_constraint),'bias_constraint':constraints.serialize(self.bias_constraint)})return config

查看源码可以看到最简单的Dense总共有四个方法

  1. init 初始该层
  2. build 初始weight和bias
  3. call 计算
  4. get_config 获取config
init

创建时各个参数的含义

parms x
units 激活单元
activation 激活函数
use_bias 是否用偏移量
initializer 矩阵初始化的方法
regularizer 权重正则化的方法
constraint 限制方法
build

初始化后就可以创建权重矩阵和偏移矩阵了(weight bias),主要运用的add_weight方法

call

计算,用的是core_ops.dense方法,以下是dense源码

def dense(inputs, kernel, bias=None, activation=None, dtype=None):if dtype:if inputs.dtype.base_dtype != dtype.base_dtype:inputs = math_ops.cast(inputs, dtype=dtype)rank = inputs.shape.rankif rank == 2 or rank is None:if isinstance(inputs, sparse_tensor.SparseTensor):outputs = sparse_ops.sparse_tensor_dense_matmul(inputs, kernel)else:outputs = gen_math_ops.mat_mul(inputs, kernel)# Broadcast kernel to inputs.else:outputs = standard_ops.tensordot(inputs, kernel, [[rank - 1], [0]])# Reshape the output back to the original ndim of the input.if not context.executing_eagerly():shape = inputs.shape.as_list()output_shape = shape[:-1] + [kernel.shape[-1]]outputs.set_shape(output_shape)if bias is not None:outputs = nn_ops.bias_add(outputs, bias)if activation is not None:outputs = activation(outputs)return outputs## TODO:乘法区别

这里input是个tensor,所以有rank变量,rank即tensor是几维的
一个是 sparse_ops.sparse_tensor_dense_matmul 和 gen_math_ops.mat_mul
一个是 standard_ops.tensordot

compute_output_shape

根据input和units,计算output_shape

get_config

返回config dict

tf keras Dense源码解析相关推荐

  1. tf keras SimpleRNN源码解析

    环境 package version tensorflow 2.3.0 keras 2.4.3 源码 部分主要源码 class RNN(Layer):def __init__(self,cell,re ...

  2. 谷歌BERT预训练源码解析(二):模型构建

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_39470744/arti ...

  3. 谷歌BERT预训练源码解析(三):训练过程

    目录 前言 源码解析 主函数 自定义模型 遮蔽词预测 下一句预测 规范化数据集 前言 本部分介绍BERT训练过程,BERT模型训练过程是在自己的TPU上进行的,这部分我没做过研究所以不做深入探讨.BE ...

  4. The Wide and Deep Learning Model(译文+Tensorlfow源码解析) 原创 2017年11月03日 22:14:47 标签: 深度学习 / 谷歌 / tensorf

    The Wide and Deep Learning Model(译文+Tensorlfow源码解析) 原创 2017年11月03日 22:14:47 标签: 深度学习 / 谷歌 / tensorfl ...

  5. [源码解析] NVIDIA HugeCTR,GPU版本参数服务器--- (5) 嵌入式hash表

    [源码解析] NVIDIA HugeCTR,GPU版本参数服务器- (5) 嵌入式hash表 文章目录 [源码解析] NVIDIA HugeCTR,GPU版本参数服务器--- (5) 嵌入式hash表 ...

  6. 《Attention is all you need》源码解析+算法详解

    Attention is all you need 源码解析 最近学习Transformer模型的时候,并且好好读了一下Google的<Attention is all you need> ...

  7. [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

    [源码解析] 深度学习流水线并行Gpipe(1)-流水线基本实现 文章目录 [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现 0x00 摘要 0x01 概述 1.1 什么是GPip ...

  8. [源码解析] TensorFlow 分布式之 ClusterCoordinator

    [源码解析] TensorFlow 分布式之 ClusterCoordinator 文章目录 [源码解析] TensorFlow 分布式之 ClusterCoordinator 1. 思路 1.1 使 ...

  9. 谷歌BERT预训练源码解析(一):训练数据生成

    目录 预训练源码结构简介 输入输出 源码解析 参数 主函数 创建训练实例 下一句预测&实例生成 随机遮蔽 输出 结果一览 预训练源码结构简介 关于BERT,简单来说,它是一个基于Transfo ...

最新文章

  1. Codeforces Round #539 (Div. 2) C. Sasha and a Bit of Relax
  2. 管理-Tomcat和Resin如何配置对指定后缀文件(如:.pptx)下载支持
  3. 游戏中应用强化学习技术,目的就是要打败人类玩家?
  4. java linux路径 home_根据linux自带的JDK,配置JAVA_HOME目录
  5. 3.菜鸟教你一步一步开发 web service 之 axis 服务端创建
  6. 使用yum时,保留下载包设置
  7. 【LaTeX】E喵的LaTeX新手入门教程(2)基础排版
  8. Test Article
  9. 硬件服务器采购指南,硬件组装_服务器采购指南_太平洋电脑网PConline
  10. 基于融合计算?蚂蚁金服的在线机器学习是如何做的
  11. [android] 请求码和结果码的作用
  12. JavaScript基础面试题
  13. 拓端tecdat|R语言基于协方差的结构方程拟合的卡方检验
  14. PostgreSQL:安装
  15. cmd命令实现百度云盘光速下载
  16. 夏日汽车保养 雨季汽车保养
  17. 【技术文档】《算法设计与分析导论》R.C.T.Lee等·第4章 分治策略
  18. 获得Windows主机的主机序列号
  19. 【RW007系列综合实战1】STM32+RW007实现BLE透传功能
  20. HTML5+CSS3小实例:酷炫的菱形加载动画

热门文章

  1. python(matplotlib6)——打印图像(imshow)3D数据(contourf等高线)
  2. 前、中、后缀表达式概述及转换+栈的计算器原理及代码分析(含完整源码)
  3. python pyqt5 窗体自适应_Pyqt5自适应布局实例
  4. 机器学习中为什么需要梯度下降_机器学习,梯度下降算法,问题引入
  5. Git 索引文件(index file)
  6. 软件工程 / 为什么基于接口而非实现编程?
  7. Tyznn人脸识别温度检测智能门禁系统现货发售,助力疫情防控
  8. python画误差棒_给妹子讲python-S02E06matplotlib散点图、频次直方图与误差线图
  9. rust油桶用什么打_选什么样的柜子才好用?别再选定制柜了,还是老手艺人打的柜子好...
  10. mfc 子窗体任何消息都不触发_winform让窗体一直显示在桌面上以及FindWindow