在用ConvLSTM2D处理图像数据时,通常先将图像经过Conv2D的操作,然后把卷积结果输入到LSTM中。这个过程需要经历的过程是:(1)把shape=[n,h,w,c]的数据输入到Conv2D的卷积层中,得到卷积结果。(2)把卷积结果通过split变形成shape = [n,d,h,w,c]形式。(3)把shape=[n,d,h,w,c]形式的数据输入到LSTM中。的每个step中,这个过程很繁琐。在keras中,用TimeDistributed简化了这一过程。示例如下:

import tensorflow.compat.v1 as tf
from keras.layers import ConvLSTM2D,TimeDistributed,Conv2D
import numpy as npinputs_np = tf.convert_to_tensor(np.random.random((4,6,256,256,3)).astype(np.float32))  # shape = [5,6,10,10,3]# inputs = tf.concat([inputs_np[:,i,...] for i in range(6)],axis=0)
# conv1 = Conv2D(filters=10,kernel_size=(3,3),strides=(1,1))(inputs)
# conv1 =tf.concat([tf.expand_dims(v,axis=1) for v in tf.split(conv1,num_or_size_splits=6)],axis=1)
conv1 = TimeDistributed(Conv2D(filters=10,kernel_size=(3,3),strides=(1,1)),input_shape=(6,256,256,3))(inputs_np)lstm_out= ConvLSTM2D(filters=1,kernel_size=(3,3),strides=(1,1),padding='valid',activation='tanh',return_sequences=True)(conv1)with tf.Session() as sess:sess.run(tf.global_variables_initializer())lstm_out_ = sess.run(lstm_out)print(lstm_out_.shape)

程序中,用TimeDistributed计算conv1的过程等价于注释行的计算过程。

注意:在TimeDistributed的参数中,input_shape=[time_steps,h,w,c]:

(1)不包含batch维度,

(2)第一个维度是clip_len, 或者LSTM的长度

(3)TimeDistributed 的操作对象只能是keras的一个layer对象,不能是自定义的函数

(4)TimeDistributed也是keras的一个层(layer)对象。

(5)用TimeDistributed封装了一个layer之后,模型的默认变量名发生了变化,但是没有新增变量。

如下例,当用TimeDistributed封装了一个Conv2d之后,模型的默认变量名由 “conv2d/kernel:0”变成了“time_distributed/kernel:0”。

(6)针对默认变量名的变化,可以使用“name=”的layer参数进行纠正。如下例中的“NEW1”中的操作。

import tensorflow.compat.v1 as tf
from tensorflow.keras.layers import ConvLSTM2D, TimeDistributed, Conv2D
import numpy as npinputs_np = tf.convert_to_tensor(np.random.random((4, 6, 256, 256, 3)).astype(np.float32))  # shape = [5,6,10,10,3]
with tf.variable_scope("RAW"):inputs = tf.concat([inputs_np[:,i,...] for i in range(6)],axis=0)conv1 = Conv2D(filters=10,kernel_size=(3,3),strides=(1,1))(inputs)conv1 =tf.concat([tf.expand_dims(v,axis=1) for v in tf.split(conv1,num_or_size_splits=6)],axis=1)
with tf.variable_scope("NEW"):conv2 = TimeDistributed(Conv2D(filters=10, kernel_size=(3, 3), strides=(1, 1)), input_shape=(6, 256, 256, 3))(inputs_np)
with tf.variable_scope("NEW1"):conv3 = TimeDistributed(Conv2D(filters=10, kernel_size=(3, 3), strides=(1, 1)), input_shape=(6, 256, 256, 3),name="conv2d")(inputs_np)lstm_out = ConvLSTM2D(filters=1, kernel_size=(3, 3), strides=(1, 1), padding='valid', activation='tanh',return_sequences=True)(conv1)all_vars = tf.global_variables()
for var in all_vars:print(var.name)
"""
打印结果:
RAW/conv2d/kernel:0
RAW/conv2d/bias:0
NEW/time_distributed/kernel:0
NEW/time_distributed/bias:0
NEW1/conv2d/kernel:0
NEW1/conv2d/bias:0
conv_lst_m2d/kernel:0
conv_lst_m2d/recurrent_kernel:0
conv_lst_m2d/bias:0
"""

keras TimeDistributed 描述相关推荐

  1. tensorflow.keras lstm 部分参数整理units input_size

    参考博客: (11条消息) 关于LSTM的units参数_LeoRainy的博客-CSDN博客_lstm units怎么设置 Keras LSTM的参数input_shape, units等的理解_y ...

  2. 最通俗易懂的YOLOv3原理及代码解析

    YOLO是一种端到端的目标检测模型.YOLO算法的基本思想是:首先通过特征提取网络提取输入特征,得到特定大小的特征图输出.输入图像分成13×13的网格单元,接着如果真实框中某个对象的中心坐标落在某个网 ...

  3. yolo系列之yolo v3【深度解析】——讲的挺好,原作者厉害的

    版权申明:转载和引用图片,都必须经过书面同意.获得留言同意即可 本文使用图片多为本人所画,需要高清图片可以留言联系我,先点赞后取图 这篇博文比较推荐的yolo v3代码是qwe的keras版本,复现比 ...

  4. YOLOv3原理及代码解析

    博主完整翻译了YOLOV1和YOLOV3的论文:请移步查看: YOLOV1:https://blog.csdn.net/taifengzikai/article/details/81988891 YO ...

  5. keras训练完以后怎么预测_使用Keras建立Wide Deep神经网络,通过描述预测葡萄酒价格...

    你能通过"优雅的单宁香"."成熟的黑醋栗香气"或"浓郁的酒香"这样的描述,预测葡萄酒的价格吗?事实证明,机器学习模型可以. 在这篇文章中,我 ...

  6. keras 中的keras.preprocessing、Embedding、GlobalMaxPooling1D()、 TimeDistributed

    本文以短问答为背景,串联几个keras下常用的函数 0:keras.preprocessing 该模块是对数据的预处理模块 https://blog.csdn.net/winter_python/ar ...

  7. TensorFlow tf.keras.layers.TimeDistributed

    对时间序列每个timestamp的向量空间做一个层 # as the first layer in a model model = Sequential() model.add(TimeDistrib ...

  8. Keras:基于Theano和TensorFlow的深度学习库

    原文链接:https://www.cnblogs.com/littlehann/p/6442161.html catalogue 引言 一些基本概念 Sequential模型 泛型模型 常用层 卷积层 ...

  9. 深度学习(莫烦 神经网络 lecture 3) Keras

    神经网络 & Keras 目录 神经网络 & Keras 目录 1.Keras简介 1.1 科普: 人工神经网络 VS 生物神经网络 1.2 什么是神经网络 (Neural Netwo ...

最新文章

  1. boot数据加解密 spring_SpringBoot 集成 Jasypt 对数据库加密以及踩坑
  2. maven 主工程 java_Maven创建Java Application工程(既jar包)
  3. jvm类加载、初始化
  4. [CQOI2016]手机号码 数位DP
  5. 物理化学 化学 动力学(上)
  6. promehteus 监控超时_05 . Prometheus监控Nginx
  7. Vue计算属性、方法、侦听器
  8. linux的mysql本地yum安装_linux下使用yum安装mysql
  9. Css 3d轮播样式
  10. Spring Ioc创建对象的方式
  11. mysql中交集,并集,差集,左连接,右连接
  12. GO、Rust 这些新一代高并发编程语言为何都极其讨厌共享内存?
  13. HDU 4558 剑侠情缘
  14. python矩阵转置_Python 矩阵转置的几种方法小结
  15. Windows Server 2016 RTM AVMA Keys
  16. 新百家姓前20位(附前300名)
  17. 男士不得不看的21种经典拍照姿势
  18. 厦理Java期末训练题【附带每题答案,非标准但可通过PTA】
  19. 离职和就职的原因(一)
  20. 使用GORM操作数据库

热门文章

  1. navacate连接不上mysql_解决navicat连接不上mysql服务器
  2. c++获取串口设备名称_RTThread PIN设备学习笔记
  3. 'nmake' 不是内部或外部命令,也不是可运行的程序 或批处理文件。
  4. 计算机考试只读,计算机基础考试试题-20210710011550.docx-原创力文档
  5. android toast_Android Toast
  6. angularjs双向绑定_AngularJS隔离范围双向绑定示例
  7. struts2登录注册示例_Struts2资源包和本地化示例
  8. asp.net 通过IHttpHandler开发接口
  9. 工作2年跳槽阿里,面试官会问哪些?(免费领取Java面试题)
  10. C语言基础教程之可变的参数