tensorflow nmt的数据预处理过程  

在tensorflow/nmt项目中,训练数据和推断数据的输入使用了新的Dataset API,应该是tensorflow 1.2之后引入的API,方便数据的操作。如果你还在使用老的Queue和Coordinator的方式,建议升级高版本的tensorflow并且使用Dataset API。

本教程将从训练数据推断数据两个方面,详解解析数据的具体处理过程,你将看到文本数据如何转化为模型所需要的实数,以及中间的张量的维度是怎么样的,batch_size和其他超参数又是如何作用的。

训练数据的处理

先来看看训练数据的处理。训练数据的处理比推断数据的处理稍微复杂一些,弄懂了训练数据的处理过程,就可以很轻松地理解推断数据的处理。
训练数据的处理代码位于nmt/utils/iterator_utils.py文件内的get_iterator函数。我们先来看看这个函数所需要的参数是什么意思:

参数 解释
src_dataset 源数据集
tgt_dataset 目标数据集
src_vocab_table 源数据单词查找表,就是个单词和int类型数据的对应表
tgt_vocab_table 目标数据单词查找表,就是个单词和int类型数据的对应表
batch_size 批大小
sos 句子开始标记
eos 句子结尾标记
random_seed 随机种子,用来打乱数据集的
num_buckets 桶数量
src_max_len 源数据最大长度
tgt_max_len 目标数据最大长度
num_parallel_calls 并发处理数据的并发数
output_buffer_size 输出缓冲区大小
skip_count 跳过数据行数
num_shards 将数据集分片的数量,分布式训练中有用
shard_index 数据集分片后的id
reshuffle_each_iteration 是否每次迭代都重新打乱顺序

上面的解释,如果有不清楚的,可以查看我之前一片介绍超参数的文章:
tensorflow_nmt的超参数详解

该函数处理训练数据的主要代码如下:

if not output_buffer_size:output_buffer_size = batch_size * 1000src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)if skip_count is not None:src_tgt_dataset = src_tgt_dataset.skip(skip_count)src_tgt_dataset = src_tgt_dataset.shuffle(output_buffer_size, random_seed, reshuffle_each_iteration)src_tgt_dataset = src_tgt_dataset.map(lambda src, tgt: (tf.string_split([src]).values, tf.string_split([tgt]).values),num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)# Filter zero length input sequences.src_tgt_dataset = src_tgt_dataset.filter(lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))if src_max_len:src_tgt_dataset = src_tgt_dataset.map(lambda src, tgt: (src[:src_max_len], tgt),num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)if tgt_max_len:src_tgt_dataset = src_tgt_dataset.map(lambda src, tgt: (src, tgt[:tgt_max_len]),num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)# Convert the word strings to ids.  Word strings that are not in the# vocab get the lookup table's default_value integer.src_tgt_dataset = src_tgt_dataset.map(lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)# Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.src_tgt_dataset = src_tgt_dataset.map(lambda src, tgt: (src,tf.concat(([tgt_sos_id], tgt), 0),tf.concat((tgt, [tgt_eos_id]), 0)),num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)# Add in sequence lengths.src_tgt_dataset = src_tgt_dataset.map(lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

我们逐步来分析,这个过程到底做了什么,数据张量又是如何变化的。

如何对齐数据

num_buckets到底起什么作用

num_buckets起作用的代码如下:  

 if num_buckets > 1:def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):# Calculate bucket_width by maximum source sequence length.# Pairs with length [0, bucket_width) go to bucket 0, length# [bucket_width, 2 * bucket_width) go to bucket 1, etc.  Pairs with length# over ((num_bucket-1) * bucket_width) words all go into the last bucket.if src_max_len:bucket_width = (src_max_len + num_buckets - 1) // num_bucketselse:bucket_width = 10# Bucket sentence pairs by the length of their source sentence and target# sentence.bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)return tf.to_int64(tf.minimum(num_buckets, bucket_id))def reduce_func(unused_key, windowed_data):return batching_func(windowed_data)batched_dataset = src_tgt_dataset.apply(tf.contrib.data.group_by_window(key_func=key_func, reduce_func=reduce_func, window_size=batch_size))

Tensorflow nmt的数据预处理过程相关推荐

  1. 谷歌发布全新TensorFlow库“tf.Transform” 简化机器学习数据预处理过程

    在实际的机器学习开发中,开发者通常需要对数据集进行大量的耗时费力的预处理过程,以适应各种不同标准的机器学习模型(例如神经网络).这些预处理过程根据待解问题的不同和原始数据的组织形式而各不相同,包括不同 ...

  2. 时间序列预测——深度好文,ARIMA是最难用的(数据预处理过程不适合工业应用),线性回归模型简单适用,预测趋势很不错,xgboost的话,不太适合趋势预测,如果数据平稳也可以使用。...

    补充:https://bmcbioinformatics.biomedcentral.com/articles/10.1186/1471-2105-15-276 如果用arima的话,还不如使用随机森 ...

  3. 数据预处理过程中处理方法

    1.初次清洗,DataFrame中存在数值型字段和非数值型字段技巧 1)将训练集和测试集数据进行合并,统一处理 dataset=pd.concat([train_data,test_data],axi ...

  4. 记录 之 tensorflow常见的数据预处理操作

    今天我们简单的介绍几个比较常用的函数: 一.裁剪及pad操作 0.tf.image.random_crop(image, shape)#按shape随机裁剪 #原始图像 #裁剪形状,[a,b,..] ...

  5. dataframe 修改某列_python dataframe操作大全数据预处理过程(dataframe、md5)

    hive表的特征选择,不同表之间的join 训练数据.测试数据的分开保存 使用pandas进行数据处理 显示所有列:pd.set_option('display.max_columns', None) ...

  6. dataframe修改列名_python dataframe操作大全数据预处理过程(dataframe、md5)

    hive表的特征选择,不同表之间的join 训练数据.测试数据的分开保存 使用pandas进行数据处理 显示所有列:pd.set_option('display.max_columns', None) ...

  7. 深度之眼Pytorch打卡(九):Pytorch数据预处理——预处理过程与数据标准化(transforms过程、Normalize原理、常用数据集均值标准差与数据集均值标准差计算)

    前言   前段时间因为一些事情没有时间或者心情学习,现在两个多月过去了,事情结束了,心态也调整好了,所以又来接着学习Pytorch.这篇笔记主要是关于数据预处理过程.数据集标准化与数据集均值标准差计算 ...

  8. 英伟达DALI加速技巧:使数据预处理比原生PyTorch运算速度快4倍

    关注上方"深度学习技术前沿",选择"星标公众号", 资源干货,第一时间送达! 你的数据处理影响整个训练速度,如果加上英伟达 DALI 库,处理速度比原生 PyT ...

  9. 英伟达DALI加速技巧:让数据预处理速度比原生PyTorch快4倍

    点击我爱计算机视觉标星,更快获取CVML新技术 本文转载自机器之心. 选自towardsdatascience 作者:Pieterluitjens 机器之心编译 参与:一鸣.嘉明.思 你的数据处理影响 ...

最新文章

  1. python 多线程爬虫 实例
  2. 【通俗理解线性代数】 -- 内积与相关
  3. java实现md5加密示例
  4. Python中metaclass解释
  5. Java多线程:线程间通信之Lock
  6. 马斯克开始行动:下调Twitter Blue订阅费 禁止广告
  7. Windows 底层驱动级 Anti-Rootkit 工具 ScDetective 源代码
  8. 拓端tecdat|在R语言中用模拟探索回归的P值
  9. matlab做误差棒图,matlab绘制误差棒
  10. 小程序 | 云数据库模糊查询
  11. 桌面计算机图标名字变了,电脑桌面图标突然变成未知图标怎么回事
  12. 度中心度(Degree Centrality)
  13. excel 数组公式
  14. 【课堂笔记精选】为了能够用“Unity”软件做游戏,我要从最基础的开始复习JavaScript...
  15. 【Spring源码三千问】Advice、Advisor、Advised都是什么接口?
  16. IOS界面push跳转后navigationController不显示
  17. NGUI------UIToggle
  18. 双向广搜-HDU1401 Solitaire
  19. 如何调整图片的dpi?如何修改图片分辨率?
  20. 每日新闻:Gartner报告:这五大新兴科技趋势将模糊人机界限;阿里云肖力:阿里云安全三大“核驱动: 可信、智能、合规...

热门文章

  1. 华硕fl8000u是什么型号_华硕fl8000u怎么样 华硕笔记本fl8000u配置是什么
  2. C/C++编程学习 - 第1周 ⑦ 头文件、强制类型转换、递归
  3. 二维码可以用哪款条码软件打印?
  4. 与学生谈“编程”和“考试”
  5. 浅谈软件开发工具CASE在软件项目开发中发挥的作用认识
  6. noip2014:螺旋矩阵_网页设计:2014年值得关注的20个最热门趋势
  7. html页面打印调用jqprint.js
  8. Java制作报表系统流程_finereport报表制作详细过程
  9. 学习操作系统(4)——进程
  10. 再获肯定,云畅科技旗下腾云畅翼入榜2021腾讯云启创新生态企业年度榜单