padded_batch API如下

padded_batch(batch_size, padded_shapes=None, padding_values=None, drop_remainder=False
)

注意参数drop_remainder用来约束最后一个batch是不是要丢掉,当这个batch样本数少于batch_size时,比如batch_size = 3,最后一个batch只有2个样本。默认是不丢掉

padded_batch是非常见的一个操作,比如对一个变长序列,通过padding操作将每个序列补成一样的长度。

特点:

1)padded_shapes使用默认值或者设置为-1,那么每个batch padding后每个维度就是跟这个者个batch的样本各个维度最大值保持一致

2)当shape固定为特定的size时,那么每个batch的shape就是一样的。如果


A = tf.data.Dataset.range(1, 6, output_type=tf.int32).map(lambda x: tf.fill([x], x))
for item in A.as_numpy_iterator():print(item)

结果如下:

[1]
[2 2]
[3 3 3]
[4 4 4 4]
[5 5 5 5 5]

padded_batch操作:

  • padded_shapes不设置或者设置为-1

padded_shapes设置为-1跟不设置该参数的效果一样,就是按每个batch里的最大的size去进行padding

B = A.padded_batch(2, padded_shapes = [-1])
for item in B.as_numpy_iterator():print("*" * 20)print(item)

打印结果如下:

可以看出事每个batch的里的shape保持一致,长度不够的补0

********************
[[1 0][2 2]]
********************
[[3 3 3 0][4 4 4 4]]
********************
[[5 5 5 5 5]]
  • padded_shapes设置为固定值

B = A.padded_batch(2, padded_shapes = [6])
for item in B.as_numpy_iterator():print("*" * 20)print(item)

打印结果:

可见每个batch的每个序列长度都是6,不足就补0

********************
[[1 0 0 0 0 0][2 2 0 0 0 0]]
********************
[[3 3 3 0 0 0][4 4 4 4 0 0]]
********************
[[5 5 5 5 5 0]]

TensorFlow的Dataset的padded_batch使用相关推荐

  1. Tensorflow关于Dataset的一般操作

    Dataset封装了很好的关于数据集的一些基本操作,在这里做一下总结.该对象的路径是:tensorflow.data.Dataset(这是1.4版本之后的)很大程度上参考了这篇博客 同时再推荐一个特别 ...

  2. tensorflow中Dataset.shuffle函数的buffer size的含义解读

    Reference tensorflow - Meaning of buffer_size in Dataset.map , Dataset.prefetch and Dataset.shuffle ...

  3. Tensorflow(02)——dataset与sequential

    目录 0.学习地址推荐 1. tensorflow.dataset数据集操作 1.1 自定义生成数据集 1.2 从CSV文件生成数据集 2. keras中的sequential模型 2.1 Seque ...

  4. tensorflow的Dataset对象报错ValueError: GraphDef cannot be larger than 2GB

    Dataset对象报错ValueError: GraphDef cannot be larger than 2GB或Cannot create a tensor proto whose content ...

  5. tensorflow的数据读取 tf.data.DataSet、tf.data.Iterator

    tensorflow的工程有使用python的多进程读取数据,然后给feed给神经网络进行训练. 也有tensorflow中的 tf.data.DataSet的使用.并且由于是tensorflow框架 ...

  6. 使用PaddleFluid和TensorFlow训练序列标注模型

    专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...

  7. 聊天机器人chatbot搭建及思考(TensorFlow)(附代码)

    端到端的对话系统 环境 Python 3.7 TensorFlow 1.14 模型结构 使用seq2seq + attention 模型 NLP应用 词向量层 单词->实数向量 降低输入维度(o ...

  8. 【TensorFlow实战笔记】对于TED(en-zh)数据集进行Seq2Seq模型实战,以及对应的Attention机制(tf保存模型读取模型)

    个人公众号 AI蜗牛车 作者是南京985AI硕士,CSDN博客专家,研究方向主要是时空序列预测和时间序列数据挖掘,获国家奖学金,校十佳大学生,省优秀毕业生,阿里天池时空序列比赛rank3.公众号致力于 ...

  9. TensorFlow学习之LSTM ---机器翻译(Seq2Seq + Attention模型)

    一.理论知识 Seq2Seq模型的基本思想:使用一个循环神经网络读取输入句子,将这个句子的信息压缩到一个固定维度的编码中:再使用另一个循环神经网络读取这个编码,将其"解压"为目标语 ...

  10. tf.data.Dataset.from_tensor_slices 的用法

    将python列表和numpy数组转换成tensorflow的dataset 只有dataset才能被model.fit函数训练 import tensorflow as tf import nump ...

最新文章

  1. Kubernetes删除一直处于Terminating状态的namespace
  2. win 7 DHCP获取不到ip的解决办法
  3. [HDF]hdf-4.2.6类库的使用
  4. 黑马程序员的课程不如兄弟连兄弟会好
  5. Make Them Equal 埃氏筛法(1200)
  6. oracle数sqlplus,sqlplus查询oracle数据库数据容量
  7. datetimepicker获取年月日_bootstrap-datetimepicker 获取时间
  8. blob_buf,blobmsg内存图详解
  9. php页面重定向到html,javascript-页面重定向(PHP,HTML)
  10. 牛客网 牛客小白月赛2 H.武-最短路(Dijkstra)
  11. [数据结构] 非旋Treap
  12. keil如何下载STM32芯片F1/F4固件库
  13. Liunx Mint下载方案Aria2、web面板AriaNG搭建
  14. Hark的数据结构与算法练习之若领图排序ProxymapSort
  15. nginx worker shutting down状态超时退出配置
  16. 2022浙江理工校赛 问题 I: Tournament Seeding
  17. __getattr__和__setattr__
  18. React 详细教程
  19. 计算机操作系统-操作系统概述
  20. 30分钟初步掌握LaTeX--转自新浪博客

热门文章

  1. Qt开发 — WindowType详解
  2. 身份证二要素,帮你轻松搞定实名认证
  3. 开源的去马赛克神器 修复受损漫画无压力
  4. 基于Python制作的一个打砖块小游戏
  5. 产品读书《社群经济:移动互联网时代未来商业驱动力》
  6. python自动排版_你熟悉Python的代码规范吗?如何一键实现代码排版
  7. Linux命令详解之 cp
  8. java报错establishing_JDBC连接SQLServer时出现错误Error establishing socket.的解决。
  9. labview—电子表格文件读写
  10. ROS下里程计辅助2D激光雷达去运动畸变