TensorFlow的Dataset的padded_batch使用
padded_batch API如下
padded_batch(batch_size, padded_shapes=None, padding_values=None, drop_remainder=False
)
注意参数drop_remainder用来约束最后一个batch是不是要丢掉,当这个batch样本数少于batch_size时,比如batch_size = 3,最后一个batch只有2个样本。默认是不丢掉
padded_batch是非常见的一个操作,比如对一个变长序列,通过padding操作将每个序列补成一样的长度。
特点:
1)padded_shapes使用默认值或者设置为-1,那么每个batch padding后每个维度就是跟这个者个batch的样本各个维度最大值保持一致
2)当shape固定为特定的size时,那么每个batch的shape就是一样的。如果
A = tf.data.Dataset.range(1, 6, output_type=tf.int32).map(lambda x: tf.fill([x], x))
for item in A.as_numpy_iterator():print(item)
结果如下:
[1]
[2 2]
[3 3 3]
[4 4 4 4]
[5 5 5 5 5]
padded_batch操作:
padded_shapes不设置或者设置为-1
padded_shapes设置为-1跟不设置该参数的效果一样,就是按每个batch里的最大的size去进行padding
B = A.padded_batch(2, padded_shapes = [-1])
for item in B.as_numpy_iterator():print("*" * 20)print(item)
打印结果如下:
可以看出事每个batch的里的shape保持一致,长度不够的补0
********************
[[1 0][2 2]]
********************
[[3 3 3 0][4 4 4 4]]
********************
[[5 5 5 5 5]]
padded_shapes设置为固定值
B = A.padded_batch(2, padded_shapes = [6])
for item in B.as_numpy_iterator():print("*" * 20)print(item)
打印结果:
可见每个batch的每个序列长度都是6,不足就补0
********************
[[1 0 0 0 0 0][2 2 0 0 0 0]]
********************
[[3 3 3 0 0 0][4 4 4 4 0 0]]
********************
[[5 5 5 5 5 0]]
TensorFlow的Dataset的padded_batch使用相关推荐
- Tensorflow关于Dataset的一般操作
Dataset封装了很好的关于数据集的一些基本操作,在这里做一下总结.该对象的路径是:tensorflow.data.Dataset(这是1.4版本之后的)很大程度上参考了这篇博客 同时再推荐一个特别 ...
- tensorflow中Dataset.shuffle函数的buffer size的含义解读
Reference tensorflow - Meaning of buffer_size in Dataset.map , Dataset.prefetch and Dataset.shuffle ...
- Tensorflow(02)——dataset与sequential
目录 0.学习地址推荐 1. tensorflow.dataset数据集操作 1.1 自定义生成数据集 1.2 从CSV文件生成数据集 2. keras中的sequential模型 2.1 Seque ...
- tensorflow的Dataset对象报错ValueError: GraphDef cannot be larger than 2GB
Dataset对象报错ValueError: GraphDef cannot be larger than 2GB或Cannot create a tensor proto whose content ...
- tensorflow的数据读取 tf.data.DataSet、tf.data.Iterator
tensorflow的工程有使用python的多进程读取数据,然后给feed给神经网络进行训练. 也有tensorflow中的 tf.data.DataSet的使用.并且由于是tensorflow框架 ...
- 使用PaddleFluid和TensorFlow训练序列标注模型
专栏介绍:Paddle Fluid 是用来让用户像 PyTorch 和 Tensorflow Eager Execution 一样执行程序.在这些系统中,不再有模型这个概念,应用也不再包含一个用于描述 ...
- 聊天机器人chatbot搭建及思考(TensorFlow)(附代码)
端到端的对话系统 环境 Python 3.7 TensorFlow 1.14 模型结构 使用seq2seq + attention 模型 NLP应用 词向量层 单词->实数向量 降低输入维度(o ...
- 【TensorFlow实战笔记】对于TED(en-zh)数据集进行Seq2Seq模型实战,以及对应的Attention机制(tf保存模型读取模型)
个人公众号 AI蜗牛车 作者是南京985AI硕士,CSDN博客专家,研究方向主要是时空序列预测和时间序列数据挖掘,获国家奖学金,校十佳大学生,省优秀毕业生,阿里天池时空序列比赛rank3.公众号致力于 ...
- TensorFlow学习之LSTM ---机器翻译(Seq2Seq + Attention模型)
一.理论知识 Seq2Seq模型的基本思想:使用一个循环神经网络读取输入句子,将这个句子的信息压缩到一个固定维度的编码中:再使用另一个循环神经网络读取这个编码,将其"解压"为目标语 ...
- tf.data.Dataset.from_tensor_slices 的用法
将python列表和numpy数组转换成tensorflow的dataset 只有dataset才能被model.fit函数训练 import tensorflow as tf import nump ...
最新文章
- Kubernetes删除一直处于Terminating状态的namespace
- win 7 DHCP获取不到ip的解决办法
- [HDF]hdf-4.2.6类库的使用
- 黑马程序员的课程不如兄弟连兄弟会好
- Make Them Equal 埃氏筛法(1200)
- oracle数sqlplus,sqlplus查询oracle数据库数据容量
- datetimepicker获取年月日_bootstrap-datetimepicker 获取时间
- blob_buf,blobmsg内存图详解
- php页面重定向到html,javascript-页面重定向(PHP,HTML)
- 牛客网 牛客小白月赛2 H.武-最短路(Dijkstra)
- [数据结构] 非旋Treap
- keil如何下载STM32芯片F1/F4固件库
- Liunx Mint下载方案Aria2、web面板AriaNG搭建
- Hark的数据结构与算法练习之若领图排序ProxymapSort
- nginx worker shutting down状态超时退出配置
- 2022浙江理工校赛 问题 I: Tournament Seeding
- __getattr__和__setattr__
- React 详细教程
- 计算机操作系统-操作系统概述
- 30分钟初步掌握LaTeX--转自新浪博客