环境
| package | version |
| --- | --- |
| tensorflow | 2.3.0 |
| keras | 2.4.3 |
源码

部分主要源码

class RNN(Layer):
    """Recurrent layer that runs `cell` step by step over the time axis of its input.

    Excerpt of the tf.keras (TensorFlow 2.3.0) `RNN` layer. Names such as
    `StackedRNNCells`, `_standardize_args`, `_is_multiple_state`,
    `_generate_zero_filled_state` and the method `_maybe_reset_cell_dropout_mask`
    are referenced here but not defined in this excerpt.

    Args:
        cell: An RNN cell, or a list/tuple of cells (wrapped in `StackedRNNCells`).
        return_sequences: If True, `call` returns the outputs for every time step;
            otherwise only the last output.
        return_state: If True, the final states are appended to the returned list.
        go_backwards: Forwarded to `K.rnn` as `go_backwards`.
        stateful: If True, states live in variables that are updated after each
            call and reused by the next one. Needs a known batch size and is
            rejected when a `tf.distribute.Strategy` is active.
        unroll: Forwarded to `K.rnn`; requires a statically known time dimension.
        time_major: If True, the time axis is axis 0 and the batch axis is axis 1;
            otherwise batch is axis 0 and time is axis 1.
        **kwargs: Forwarded to `Layer`. `zero_output_for_mask` is popped here, and
            `input_dim` / `input_length` are folded into `input_shape`.
    """

    def __init__(self,
                 cell,
                 return_sequences=False,
                 return_state=False,
                 go_backwards=False,
                 stateful=False,
                 unroll=False,
                 time_major=False,
                 **kwargs):
        if isinstance(cell, (list, tuple)):
            cell = StackedRNNCells(cell)
        # If True, the output for masked timestep will be zeros, whereas in the
        # False case, output from previous timestep is returned for masked timestep.
        self.zero_output_for_mask = kwargs.pop('zero_output_for_mask', False)

        # Legacy `input_dim` / `input_length` arguments are converted into an
        # `input_shape` of (time, features) unless `input_shape` is given.
        if 'input_shape' not in kwargs and (
                'input_dim' in kwargs or 'input_length' in kwargs):
            input_shape = (kwargs.pop('input_length', None),
                           kwargs.pop('input_dim', None))
            kwargs['input_shape'] = input_shape

        super(RNN, self).__init__(**kwargs)
        self.cell = cell
        self.return_sequences = return_sequences
        self.return_state = return_state
        self.go_backwards = go_backwards
        self.stateful = stateful
        self.unroll = unroll
        self.time_major = time_major

        self.supports_masking = True
        # Specs are filled in by build() and/or __call__().
        self.input_spec = None
        self.state_spec = None
        self._states = None
        self.constants_spec = None
        self._num_constants = 0

        if stateful:
            if ds_context.has_strategy():
                raise ValueError('RNNs with stateful=True not yet supported with '
                                 'tf.distribute.Strategy.')

    @property
    def states(self):
        """Recorded states, or a `None` placeholder per state before they exist."""
        if self._states is None:
            state = nest.map_structure(lambda _: None, self.cell.state_size)
            return state if nest.is_sequence(self.cell.state_size) else [state]
        return self._states

    # Fix: the setter must be registered on the property. Without
    # `@states.setter`, this `def` would replace the property with a plain
    # method and break every `self.states` read below.
    @states.setter
    # Opts the stored states out of automatic dependency tracking (presumably so
    # the state variables are not tracked as extra weights -- confirm upstream).
    @trackable.no_automatic_dependency_tracking
    def states(self, states):
        self._states = states

    def compute_mask(self, inputs, mask):
        """Return the output mask, plus one `None` mask per state if `return_state`."""
        # Time step masks must be the same for each input.
        # This is because the mask for an RNN is of size [batch, time_steps, 1],
        # and specifies which time steps should be skipped, and a time step
        # must be skipped for all inputs.
        # TODO(scottzhu): Should we accept multiple different masks?
        mask = nest.flatten(mask)[0]
        output_mask = mask if self.return_sequences else None
        if self.return_state:
            state_mask = [None for _ in self.states]
            return [output_mask] + state_mask
        else:
            return output_mask

    def build(self, input_shape):
        """Create input/state specs, build the cell, and init stateful states.

        Args:
            input_shape: Shape of the input. It may be a list whose first element
                is the main input shape, or a nested structure of shapes.
        """
        if isinstance(input_shape, list):
            input_shape = input_shape[0]
            # The input_shape here could be a nest structure.

        # do the tensor_shape to shapes here. The input could be single tensor, or a
        # nested structure of tensors.
        def get_input_spec(shape):
            """Convert input shape to InputSpec."""
            if isinstance(shape, tensor_shape.TensorShape):
                input_spec_shape = shape.as_list()
            else:
                input_spec_shape = list(shape)
            batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
            # Stateful layers keep a fixed batch size; the time axis is always
            # left unconstrained.
            if not self.stateful:
                input_spec_shape[batch_index] = None
            input_spec_shape[time_step_index] = None
            return InputSpec(shape=tuple(input_spec_shape))

        def get_step_input_shape(shape):
            """Return `shape` with the time axis removed."""
            if isinstance(shape, tensor_shape.TensorShape):
                shape = tuple(shape.as_list())
            # remove the timestep from the input_shape
            return shape[1:] if self.time_major else (shape[0],) + shape[2:]

        # Check whether the input shape contains any nested shapes. It could be
        # (tensor_shape(1, 2), tensor_shape(3, 4)) or (1, 2, 3) which is from numpy
        # inputs.
        try:
            input_shape = tensor_shape.as_shape(input_shape)
        except (ValueError, TypeError):
            # A nested tensor input
            pass

        if not nest.is_sequence(input_shape):
            # This indicates the there is only one input.
            if self.input_spec is not None:
                self.input_spec[0] = get_input_spec(input_shape)
            else:
                self.input_spec = [get_input_spec(input_shape)]
            step_input_shape = get_step_input_shape(input_shape)
        else:
            if self.input_spec is not None:
                self.input_spec[0] = nest.map_structure(get_input_spec, input_shape)
            else:
                self.input_spec = generic_utils.to_list(
                    nest.map_structure(get_input_spec, input_shape))
            step_input_shape = nest.map_structure(get_step_input_shape, input_shape)

        # allow cell (if layer) to build before we set or validate state_spec.
        if isinstance(self.cell, Layer) and not self.cell.built:
            with K.name_scope(self.cell.name):
                self.cell.build(step_input_shape)
                self.cell.built = True

        # set or validate state_spec
        if _is_multiple_state(self.cell.state_size):
            state_size = list(self.cell.state_size)
        else:
            state_size = [self.cell.state_size]

        if self.state_spec is not None:
            # initial_state was passed in call, check compatibility
            self._validate_state_spec(state_size, self.state_spec)
        else:
            self.state_spec = [
                InputSpec(shape=[None] + tensor_shape.as_shape(dim).as_list())
                for dim in state_size
            ]
        if self.stateful:
            self.reset_states()
        self.built = True

    @staticmethod
    def _validate_state_spec(cell_state_sizes, init_state_specs):
        """Validate the state spec between the initial_state and the state_size.

        Args:
            cell_state_sizes: list, the `state_size` attribute from the cell.
            init_state_specs: list, the `state_spec` from the initial_state that is
                passed in `call()`.

        Raises:
            ValueError: When initial state spec is not compatible with the state size.
        """
        validation_error = ValueError(
            'An `initial_state` was passed that is not compatible with '
            '`cell.state_size`. Received `state_spec`={}; '
            'however `cell.state_size` is '
            '{}'.format(init_state_specs, cell_state_sizes))
        flat_cell_state_sizes = nest.flatten(cell_state_sizes)
        flat_state_specs = nest.flatten(init_state_specs)

        if len(flat_cell_state_sizes) != len(flat_state_specs):
            raise validation_error
        for cell_state_spec, cell_state_size in zip(flat_state_specs,
                                                    flat_cell_state_sizes):
            if not tensor_shape.TensorShape(
                    # Ignore the first axis for init_state which is for batch
                    cell_state_spec.shape[1:]).is_compatible_with(
                        tensor_shape.TensorShape(cell_state_size)):
                raise validation_error

    @doc_controls.do_not_doc_inheritable
    def get_initial_state(self, inputs):
        """Return the initial states for `inputs` as a list.

        Uses `cell.get_initial_state` when the cell defines it, otherwise
        `_generate_zero_filled_state`. Batch size and dtype are taken from
        `inputs` (the first flattened element if `inputs` is nested).
        """
        get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)

        if nest.is_sequence(inputs):
            # The input are nested sequences. Use the first element in the seq to get
            # batch size and dtype.
            inputs = nest.flatten(inputs)[0]

        input_shape = array_ops.shape(inputs)
        batch_size = input_shape[1] if self.time_major else input_shape[0]
        dtype = inputs.dtype
        if get_initial_state_fn:
            init_state = get_initial_state_fn(
                inputs=None, batch_size=batch_size, dtype=dtype)
        else:
            init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
                                                     dtype)
        # Keras RNN expect the states in a list, even if it's a single state tensor.
        if not nest.is_sequence(init_state):
            init_state = [init_state]
        # Force the state to be a list in case it is a namedtuple eg LSTMStateTuple.
        return list(init_state)

    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
        """Call the layer, routing `initial_state` / `constants` appropriately.

        If they are Keras tensors, they are appended to the layer inputs and
        `input_spec` is extended for the duration of the call. Otherwise they are
        passed through as keyword arguments.

        Raises:
            ValueError: If Keras and non-Keras tensors are mixed in
                `initial_state` / `constants`.
        """
        inputs, initial_state, constants = _standardize_args(inputs,
                                                             initial_state,
                                                             constants,
                                                             self._num_constants)

        if initial_state is None and constants is None:
            return super(RNN, self).__call__(inputs, **kwargs)

        # If any of `initial_state` or `constants` are specified and are Keras
        # tensors, then add them to the inputs and temporarily modify the
        # input_spec to include them.

        additional_inputs = []
        additional_specs = []
        if initial_state is not None:
            additional_inputs += initial_state
            self.state_spec = nest.map_structure(
                lambda s: InputSpec(shape=K.int_shape(s)), initial_state)
            additional_specs += self.state_spec
        if constants is not None:
            additional_inputs += constants
            self.constants_spec = [
                InputSpec(shape=K.int_shape(constant)) for constant in constants
            ]
            self._num_constants = len(constants)
            additional_specs += self.constants_spec
        # additional_inputs can be empty if initial_state or constants are provided
        # but empty (e.g. the cell is stateless).
        flat_additional_inputs = nest.flatten(additional_inputs)
        is_keras_tensor = K.is_keras_tensor(
            flat_additional_inputs[0]) if flat_additional_inputs else True
        for tensor in flat_additional_inputs:
            if K.is_keras_tensor(tensor) != is_keras_tensor:
                raise ValueError('The initial state or constants of an RNN'
                                 ' layer cannot be specified with a mix of'
                                 ' Keras tensors and non-Keras tensors'
                                 ' (a "Keras tensor" is a tensor that was'
                                 ' returned by a Keras layer, or by `Input`)')

        if is_keras_tensor:
            # Compute the full input spec, including state and constants
            full_input = [inputs] + additional_inputs
            if self.built:
                # Keep the input_spec since it has been populated in build() method.
                full_input_spec = self.input_spec + additional_specs
            else:
                # The original input_spec is None since there could be a nested tensor
                # input. Update the input_spec to match the inputs.
                full_input_spec = generic_utils.to_list(
                    nest.map_structure(lambda _: None, inputs)) + additional_specs
            # Perform the call with temporarily replaced input_spec
            self.input_spec = full_input_spec
            output = super(RNN, self).__call__(full_input, **kwargs)
            # Remove the additional_specs from input spec and keep the rest. It is
            # important to keep since the input spec was populated by build(), and
            # will be reused in the stateful=True.
            # Fix: with no additional specs, `[:-0]` would evaluate to `[]` and
            # wipe the spec instead of leaving it unchanged.
            if additional_specs:
                self.input_spec = self.input_spec[:-len(additional_specs)]
            return output
        else:
            if initial_state is not None:
                kwargs['initial_state'] = initial_state
            if constants is not None:
                kwargs['constants'] = constants
            return super(RNN, self).__call__(inputs, **kwargs)

    def call(self,
             inputs,
             mask=None,
             training=None,
             initial_state=None,
             constants=None):
        """Run the cell over the time dimension with `K.rnn`.

        Returns:
            The full output sequence if `return_sequences`, else the last output.
            If `return_state`, a list of that output followed by the final states.

        Raises:
            ValueError: If `unroll` is set but the time dimension is unknown, or
                if the cell does not accept the given `constants`.
        """
        # The input should be dense, padded with zeros. If a ragged input is fed
        # into the layer, it is padded and the row lengths are used for masking.
        inputs, row_lengths = K.convert_inputs_if_ragged(inputs)
        is_ragged_input = (row_lengths is not None)
        self._validate_args_if_ragged(is_ragged_input, mask)

        inputs, initial_state, constants = self._process_inputs(
            inputs, initial_state, constants)

        self._maybe_reset_cell_dropout_mask(self.cell)
        if isinstance(self.cell, StackedRNNCells):
            for cell in self.cell.cells:
                self._maybe_reset_cell_dropout_mask(cell)

        if mask is not None:
            # Time step masks must be the same for each input.
            # TODO(scottzhu): Should we accept multiple different masks?
            mask = nest.flatten(mask)[0]

        if nest.is_sequence(inputs):
            # In the case of nested input, use the first element for shape check.
            input_shape = K.int_shape(nest.flatten(inputs)[0])
        else:
            input_shape = K.int_shape(inputs)
        timesteps = input_shape[0] if self.time_major else input_shape[1]
        if self.unroll and timesteps is None:
            raise ValueError('Cannot unroll a RNN if the '
                             'time dimension is undefined. \n'
                             '- If using a Sequential model, '
                             'specify the time dimension by passing '
                             'an `input_shape` or `batch_input_shape` '
                             'argument to your first layer. If your '
                             'first layer is an Embedding, you can '
                             'also use the `input_length` argument.\n'
                             '- If using the functional API, specify '
                             'the time dimension by passing a `shape` '
                             'or `batch_shape` argument to your Input layer.')

        kwargs = {}
        if generic_utils.has_arg(self.cell.call, 'training'):
            kwargs['training'] = training

        # TF RNN cells expect single tensor as state instead of list wrapped tensor.
        is_tf_rnn_cell = getattr(self.cell, '_is_tf_rnn_cell', None) is not None
        # Use the __call__ function for callable objects, eg layers, so that it
        # will have the proper name scopes for the ops, etc.
        cell_call_fn = self.cell.__call__ if callable(self.cell) else self.cell.call
        if constants:
            if not generic_utils.has_arg(self.cell.call, 'constants'):
                raise ValueError('RNN cell does not support constants')

            def step(inputs, states):
                # Constants ride along at the end of the state list inside K.rnn.
                constants = states[-self._num_constants:]  # pylint: disable=invalid-unary-operand-type
                states = states[:-self._num_constants]  # pylint: disable=invalid-unary-operand-type

                states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
                output, new_states = cell_call_fn(
                    inputs, states, constants=constants, **kwargs)
                if not nest.is_sequence(new_states):
                    new_states = [new_states]
                return output, new_states
        else:

            def step(inputs, states):
                states = states[0] if len(states) == 1 and is_tf_rnn_cell else states
                output, new_states = cell_call_fn(inputs, states, **kwargs)
                if not nest.is_sequence(new_states):
                    new_states = [new_states]
                return output, new_states

        last_output, outputs, states = K.rnn(
            step,
            inputs,
            initial_state,
            constants=constants,
            go_backwards=self.go_backwards,
            mask=mask,
            unroll=self.unroll,
            input_length=row_lengths if row_lengths is not None else timesteps,
            time_major=self.time_major,
            zero_output_for_mask=self.zero_output_for_mask)

        if self.stateful:
            # Write the final states back into the recorded state variables.
            updates = [
                state_ops.assign(self_state, state) for self_state, state in zip(
                    nest.flatten(self.states), nest.flatten(states))
            ]
            self.add_update(updates)

        if self.return_sequences:
            output = K.maybe_convert_to_ragged(is_ragged_input, outputs, row_lengths)
        else:
            output = last_output

        if self.return_state:
            if not isinstance(states, (list, tuple)):
                states = [states]
            else:
                states = list(states)
            return generic_utils.to_list(output) + states
        else:
            return output

    def _process_inputs(self, inputs, initial_state, constants):
        """Unpack list inputs and resolve the initial state.

        Returns:
            A tuple `(inputs, initial_state, constants)`.

        Raises:
            ValueError: If the number of initial states differs from `self.states`.
        """
        # input shape: `(samples, time (padded with zeros), input_dim)`
        # note that the .build() method of subclasses MUST define
        # self.input_spec and self.state_spec with complete input shapes.
        if (isinstance(inputs, collections_abc.Sequence)
                and not isinstance(inputs, tuple)):
            # get initial_state from full input spec
            # as they could be copied to multiple GPU.
            if not self._num_constants:
                initial_state = inputs[1:]
            else:
                initial_state = inputs[1:-self._num_constants]
                constants = inputs[-self._num_constants:]
            if len(initial_state) == 0:
                initial_state = None
            inputs = inputs[0]

        if self.stateful:
            if initial_state is not None:
                # When layer is stateful and initial_state is provided, check if the
                # recorded state is same as the default value (zeros). Use the recorded
                # state if it is not same as the default.
                non_zero_count = math_ops.add_n([math_ops.count_nonzero_v2(s)
                                                 for s in nest.flatten(self.states)])
                # Set strict = True to keep the original structure of the state.
                initial_state = control_flow_ops.cond(non_zero_count > 0,
                                                      true_fn=lambda: self.states,
                                                      false_fn=lambda: initial_state,
                                                      strict=True)
            else:
                initial_state = self.states
        elif initial_state is None:
            initial_state = self.get_initial_state(inputs)

        if len(initial_state) != len(self.states):
            raise ValueError('Layer has ' + str(len(self.states)) +
                             ' states but was passed ' + str(len(initial_state)) +
                             ' initial states.')
        return inputs, initial_state, constants

    def _validate_args_if_ragged(self, is_ragged_input, mask):
        """Reject an explicit mask or `unroll=True` when the input was ragged."""
        if not is_ragged_input:
            return

        if mask is not None:
            raise ValueError('The mask that was passed in was ' + str(mask) +
                             ' and cannot be applied to RaggedTensor inputs. Please '
                             'make sure that there is no mask passed in by upstream '
                             'layers.')
        if self.unroll:
            raise ValueError('The input received contains RaggedTensors and does '
                             'not support unrolling. Disable unrolling by passing '
                             '`unroll=False` in the RNN Layer constructor.')

    def reset_states(self, states=None):
        """Reset the recorded states for the stateful RNN layer.

        Can only be used when the RNN layer is constructed with `stateful=True`.

        Args:
            states: Numpy arrays that contain the values for the initial state,
                which will be fed to the cell at the first time step. When the
                value is None, zero-filled numpy arrays are created based on the
                cell state size.

        Raises:
            AttributeError: When the RNN layer is not stateful.
            ValueError: When the batch size of the RNN layer is unknown.
            ValueError: When the input numpy array is not compatible with the RNN
                layer state, either size-wise or dtype-wise.
        """
        if not self.stateful:
            raise AttributeError('Layer must be stateful.')
        spec_shape = None
        if self.input_spec is not None:
            spec_shape = nest.flatten(self.input_spec[0])[0].shape
        if spec_shape is None:
            # It is possible to have spec shape to be None, eg when construct a RNN
            # with a custom cell, or standard RNN layers (LSTM/GRU) which we only know
            # it has 3 dim input, but not its full shape spec before build().
            batch_size = None
        else:
            batch_size = spec_shape[1] if self.time_major else spec_shape[0]
        if not batch_size:
            raise ValueError('If a RNN is stateful, it needs to know '
                             'its batch size. Specify the batch size '
                             'of your input tensors: \n'
                             '- If using a Sequential model, '
                             'specify the batch size by passing '
                             'a `batch_input_shape` '
                             'argument to your first layer.\n'
                             '- If using the functional API, specify '
                             'the batch size by passing a '
                             '`batch_shape` argument to your Input layer.')
        # initialize state if None
        if nest.flatten(self.states)[0] is None:
            def create_state_variable(state):
                return K.zeros([batch_size] + tensor_shape.as_shape(state).as_list())
            self.states = nest.map_structure(
                create_state_variable, self.cell.state_size)
            if not nest.is_sequence(self.states):
                self.states = [self.states]
        elif states is None:
            for state, size in zip(nest.flatten(self.states),
                                   nest.flatten(self.cell.state_size)):
                K.set_value(state, np.zeros([batch_size] +
                                            tensor_shape.as_shape(size).as_list()))
        else:
            flat_states = nest.flatten(self.states)
            flat_input_states = nest.flatten(states)
            if len(flat_input_states) != len(flat_states):
                raise ValueError('Layer ' + self.name + ' expects ' +
                                 str(len(flat_states)) + ' states, '
                                 'but it received ' + str(len(flat_input_states)) +
                                 ' state values. Input received: ' + str(states))
            set_value_tuples = []
            for i, (value, state) in enumerate(zip(flat_input_states,
                                                   flat_states)):
                if value.shape != state.shape:
                    # Fix: report the expected *shape*; the message previously
                    # printed the tuple `(batch_size, state)`, i.e. the variable.
                    raise ValueError(
                        'State ' + str(i) + ' is incompatible with layer ' +
                        self.name + ': expected shape=' + str(state.shape) +
                        ', found shape=' + str(value.shape))
                set_value_tuples.append((state, value))
            K.batch_set_value(set_value_tuples)
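流程

`build` 的主要步骤：

1. `input_shape`：由输入形状生成 `input_spec`（非 stateful 时把 batch 维设为 `None`，时间步维始终设为 `None`）；
2. `step_input_shape`：去掉时间步维，得到单个时间步的输入形状，并用它构建 `cell`；
3. `state_size`：读取 `cell.state_size`，据此生成（或与传入的 `initial_state` 校验）`state_spec`；若 `stateful=True`，再调用 `reset_states()` 初始化状态。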

tf keras SimpleRNN源码解析相关推荐

  1. tf keras Dense源码解析

    环境 package version tensorflow 2.3.0 keras 2.4.3 源码 class Dense(Layer):def __init__(self,units,activa ...

  2. [源码解析] TensorFlow 分布式之 ClusterCoordinator

    [源码解析] TensorFlow 分布式之 ClusterCoordinator 文章目录 [源码解析] TensorFlow 分布式之 ClusterCoordinator 1. 思路 1.1 使 ...

  3. 谷歌BERT预训练源码解析(二):模型构建

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_39470744/arti ...

  4. 谷歌BERT预训练源码解析(三):训练过程

    目录 前言 源码解析 主函数 自定义模型 遮蔽词预测 下一句预测 规范化数据集 前言 本部分介绍BERT训练过程,BERT模型训练过程是在自己的TPU上进行的,这部分我没做过研究所以不做深入探讨.BE ...

  5. 谷歌BERT预训练源码解析(一):训练数据生成

    目录 预训练源码结构简介 输入输出 源码解析 参数 主函数 创建训练实例 下一句预测&实例生成 随机遮蔽 输出 结果一览 预训练源码结构简介 关于BERT,简单来说,它是一个基于Transfo ...

  6. dataset__getitem___PyTorch源码解析与实践(1):数据加载Dataset,Sampler与DataLoader

    献给学习PyTorch在路上或者计划较深入理解PyTorch的同行者们 写在前面 笔者一直使用tf,大势所趋决定转PyTorch,这个系列就作为我学习PyTorch的笔记与心得. 网络上PyTorch ...

  7. ElasticSearch源码解析(五):排序(评分公式)

    ElasticSearch源码解析(五):排序(评分公式) 转载自:http://blog.csdn.net/molong1208/article/details/50623948   一.目的 一个 ...

  8. The Wide and Deep Learning Model(译文+Tensorlfow源码解析) 原创 2017年11月03日 22:14:47 标签: 深度学习 / 谷歌 / tensorf

    The Wide and Deep Learning Model(译文+Tensorlfow源码解析) 原创 2017年11月03日 22:14:47 标签: 深度学习 / 谷歌 / tensorfl ...

  9. 判定两个tensor维度相同_Tensorflow源码解析5 -- 图的边 - Tensor

    1 概述 前文两篇文章分别讲解了TensorFlow核心对象Graph,和Graph的节点Operation.Graph另外一大成员,即为其边Tensor.边用来表示计算的数据,它经过上游节点计算后得 ...

最新文章

  1. Java获取照片的Exif信息,并解析GPS
  2. 使用webpack2.0 搭建react.js项目
  3. PostgreSQL 与 MySQL 相比,优势何在?[转]
  4. python使用界面-python 可视化界面
  5. 如何用pyecharts绘制柱状图,条形图,折线图,饼图,环形图,散点图
  6. clickhouse原理解析与应用实践_Hybrid App (混合应用) 技术全解析 方案原理篇
  7. 扫地机器人隔板_【扫地机器人使用】_摘要频道_什么值得买
  8. String Table MFC
  9. mac 启动php70 fpm,mac 启动php-fpm
  10. 【Python实例第18讲】affinity propagation聚类算法
  11. 小学生十大计算机专业书排行,小学教辅十大排行榜2018 小学教辅书那些比较好...
  12. 干货资源共享之阿里云大学的学习路线和免费课程
  13. 当你对未来迷茫的时候,请打开这个锦囊
  14. 电脑桌面便签软件怎么新建内容?
  15. Trend趋势反垃圾邮件黑名单申诉方法
  16. 以CRM系统为案例讲解数据分析(重要性介绍及分析方法)
  17. 多视几何 003 二次曲线
  18. strstr函数实现(C语言)
  19. Linux htop命令
  20. Android OpenCV实现文字识别

热门文章

  1. C++设计模式--单例模式(Singleton)及单例通用模板
  2. python 描述器 详解_Python描述器descriptor详解
  3. Clion生成动态链接库.dll
  4. 类加载、类加载器、反射
  5. 手把手教Linux驱动4-进程、文件描述符、file、inode关系详解
  6. linux / 终端常用快捷键
  7. 启明云端分享| ESP8266\ESP32-C3\ESP32-C2三款芯片从核心系统、WIFI射频和基带、外围设备等都有哪些区别
  8. 文件上传打满服务器带宽,文件上传云服务器 带宽选择
  9. Kubernetes入门——深入浅出讲Docker
  10. wifi分析仪怎么看哪个信道好_游戏工作室用什么路由器好?合理选择组建手机工作室网络...