机器学习入门0005 tensorflow_NMT模型

1.简介

nmt(Neural Machine Translation)是一个序列到序列的模型。可以用来做【聊天机器人】,【翻译】,【关键词提取】,【文章摘要】,【图像描述】等功能。用法简单,只需要安装Tensorflow1.4+ 版本即可运行。这个地址是Tensorflow 官方github https://github.com/tensorflow/nmt,里面内容很全面。

2.运行官网github的例子

下面内容是从斯坦福大学下载英语到越南语的平行语料库,然后通过nmt模型使用语料库训练一个 英语-越南语 或者 越南语-英语 的翻译模型。

2.1  下载平行语料库

在控制台输入这个命令nmt/scripts/download_iwslt15.sh /tmp/nmt_data,就可以下载小的平行语料库了。这个命令需要在nmt目录外边执行,会把数据下载到/tmp/nmt_data/下,其实是8个文件。在国内,由于网络原因,这个命令下载总是会中止,很蛋疼。可以借助其他下载工具来下载(比如某雷,某度云盘,某旋风等),这些文件的地址是:

https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/train.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/train.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2012.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2013.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/tst2013.vi
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/vocab.en
https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi/vocab.vi

下载完成后放在/tmp/nmt_data/下,也可以下载到自己喜欢的位置,但是需要修改之后的命令,会造成不必要的麻烦。

2.2 训练模型(英语-越南语模型)

在控制台中输入下面命令:

mkdir /tmp/nmt_model
python -m nmt.nmt --src=en --tgt=vi --vocab_prefix=/tmp/nmt_data/vocab  --train_prefix=/tmp/nmt_data/train 
--dev_prefix=/tmp/nmt_data/tst2012  --test_prefix=/tmp/nmt_data/tst2013 --out_dir=/tmp/nmt_model --num_train_steps=12000 
--steps_per_stats=100 --num_layers=2 --num_units=128 --dropout=0.2 --metrics=bleu

命令共有两行,第一行创建一个文件夹,用于存放训练的模型(若干个矩阵)

第二行很长只有一行,使用来进行训练的。其中的参数用来指明数据的位置,模型存放的位置,训练时的参数:总共训练12000步 每次训练100组 rnn共2层128个单元等。

注意:这个训练时间和机器性能有关可能达到1周时间 python 3.X的用户执行第二条命令,要这样写python3 -m nmt.nmt --src=en --tgt=v...

2.3 使用训练好的模型

在/tmp/下创建一个文件,名字是my_infer_file.en 里面写上几行英语:

i am a studenthow are you....i want to be a super programer

然后执行下面命令

python -m nmt.nmt --out_dir=/tmp/nmt_model --inference_input_file=/tmp/my_infer_file.en
--inference_output_file=/tmp/nmt_model/output_infer

很快可以执行完毕,然后到/tmp/nmt_model/目录下看这个文件output_infer,里面是对应的越南语翻译。

3.怎么用nmt

3.1 数据格式

仿照着之前下载的8个文件,做好数据对应,其中三个文件是英文的一句一行单词之间通过空格分隔,还有三个是越南语,格式和英语一样。vocab.vi 是越南语的词汇表取了常用的5000个词,vocab.en 是英语词汇表取了最常用的前5000个词语,但是它们前三个词语是<unk> 代表不认识的词语 <s>开始 </s>结束,这三个词必须在词汇表中否则nmt模型不能工作,具体原因官方github上有解释。

3.2 模型参数

python -m nmt.nmt --src=en --tgt=vi --vocab_prefix=/tmp/nmt_data/vocab  --train_prefix=/tmp/nmt_data/train --dev_prefix=/tmp/nmt_data/tst2012  --test_prefix=/tmp/nmt_data/tst2013 --out_dir=/tmp/nmt_model --num_train_steps=12000 --steps_per_stats=100 --num_layers=2 --num_units=128 --dropout=0.2 --metrics=bleu

这条命令中只是使用了个别的参数,还有一些其他有用的参数,如下:

forget_bias=1.0 这个是lstm的记忆力参数,取值范围在[0.0,1.0]越大代表记性越好

batch_size=128 这个代表每次训练128条数据,取值是正整数,如果太大,需要的内存会增大

learning_rate=1 学习率,正常情况下设置成小于等于1的值,默认值 1

num_gpus=1 机器中gpu个数,默认值是1

eos='</s>' 结束符配置成</s>,参考3.1 数据格式

sos='<s>' 同上,这两个参数没有配置的必要

src_max_len=50 源输入最大长度,针对我们训练的英语-越南语模型中,意思是每行最长接受50个英语单词,其余忽略

tgt_max_len=50 目标输出最大长度,默认值50.这个和上面的参数有时很有用,假设我们要做文章摘要,参数可以这样写--src_max_len=800 --tgt_max_len=150,这两个参数都会影响训练和预测速度,他们越大,模型跑的越慢。

share_vocab=False 这个意思是是否公用词汇表,假设做文章摘要,把这个设置成True。因为不是做翻译,输入和输出是同一种语言。

还有一些其他参数,不再列举,可以去源代码中nmt.py文件中查看。

3.3 训练一个聊天机器人(汉语)

3.3.1 准备好训练数据,开发数据,测试数据,汉语常用汉字表(前5000个)即可

仿照3.1中的数据,来准备训练数据。这次不是翻译数据,而是对话数据。比如:

train.src

你 好 !很 高 兴 认 识 你 。当 然 很 激 动 了。....

train.tgt

你 好 呀 !我 也 是 呢 , 你 有 没 有 很 激 动 。激 动 你 妹 啊 。....

vocab.src

<unk>
<s>
</s>
,
的
。
<sp>
一
0
是
1
、
在
有
不
了
2
人
中
大
国
年...

3.3.2 接下来进行训练

python -m nmt.nmt --src=src --tgt=tgt --vocab_prefix=/tmp/chat_data/vocab  --train_prefix=/tmp/chat_data/train --dev_prefix=/tmp/chat_data/dev  --test_prefix=/tmp/chat_data/test --out_dir=/tmp/nmt_model --num_train_steps=192000 --steps_per_stats=100 --num_layers=2 --num_units=256 --dropout=0.2 --metrics=bleu --src_max_len=80 --tgt_max_len=80 --share_vocab=True

经过漫长的训练,聊天模型训练完毕

3.3.3 集成到项目

有三种方案将训练的模型集成到项目中:

(1)对nmt进行部分修改,在项目代码中调用预测,使结果以文件形式展示,然后去文件中提取结果。优点:改动少,可以快速集成。 缺点:运行速度很慢

(2)对nmt进行部分修改,在项目代码中调用预测,只是要给nmt的源代码添加参数和返回值,返回值就是结果。 优点:改动少,可以快速集成。缺点:运行速度慢

(3)把nmt重构,写成一个对象,不要释放session,这样调用的速度会快一些。优点:运行速度快。 缺点:需要对nmt进行深入了解,开发周期长

前两种速度慢的原因是,每次运行都要加载大量的参数,加载词汇。第一种方案还多进行了两次io操作。

4.其他

4.1 数据问题

数据是比较难得到的,可以用自己qq聊天的数据,把聊天数据导出,然后做成nmt需要的数据格式。至于数据量,10万条以上吧,这个还没有详细的研究过。数据质量一定要好。很多公司是自己手动标注数据,这会耗费大量的时间,数据很难得。假设要做关键词抽取,可以通过爬虫爬取某浪新闻的带标签的文章。train_prefix dev_prefix这两个参数所指定的文件的数据量100-500条即可,不要太多了

4.2 内存问题

[src_max_len]  [gt_max_len]  [num_units]  [num_layers]  [batch_size] 这几个参数越大训练速度越慢,消耗内存越多。

词汇表越大,消耗的内存也越大,训练速度也会越慢。

4.3 训练问题

如果训练好久,控制台还没有一点反应,说明参数调的不好,机器性能跟不上,可以适当的降低参数。训练时间可能会很长,要有耐心。

nmt模型可以中断训练,下次输入和上次相同的参数,会接着继续执行。

机器学习入门0005 tensorflow_NMT模型相关推荐

  1. 机器学习入门实践——线性回归模型(波士顿房价预测)

    机器学习入门实践--线性回归模型(波士顿房价预测) 一.背景介绍 给定一个大小为 n n n的数据集 { y i , x i 1 , . . . , x i d } i = 1 n {\{y_{i}, ...

  2. 机器学习入门-一元线性回归模型的骚操作

  3. 机器学习入门笔记(一):模型性能评价与选择

    文章目录 一.训练误差与测试误差 1.1 基本概念 1.2 训练误差 1.3 泛化误差(测试误差) 1.4 过拟合 二.模型评估方法 2.1 留出法(hold-out) 2.2 正则化 2.3 交叉验 ...

  4. 机器学习入门(九):非监督学习:5种聚类算法+2种评估模型

    机器学习入门专栏其他章节: 机器学习入门(一)线性回归 机器学习入门(二)KNN 机器学习入门(三)朴素贝叶斯 机器学习入门(四)决策树 机器学习入门(五)集成学习 机器学习入门(六)支持向量机 机器 ...

  5. python 非线性回归_机器学习入门之菜鸟之路——机器学习之非线性回归个人理解及python实现...

    本文主要向大家介绍了机器学习入门之菜鸟之路--机器学习之非线性回归个人理解及python实现,通过具体的内容向大家展现,希望对大家学习机器学习入门有所帮助. 梯度下降:就是让数据顺着梯度最大的方向,也 ...

  6. 【机器学习入门】(13) 实战:心脏病预测,补充: ROC曲线、精确率--召回率曲线,附python完整代码和数据集

    各位同学好,经过前几章python机器学习的探索,想必大家对各种预测方法也有了一定的认识.今天我们来进行一次实战,心脏病病例预测,本文对一些基础方法就不进行详细解释,有疑问的同学可以看我前几篇机器学习 ...

  7. 【机器学习入门】(10) 特征工程:特征抽取,字典特征抽取、文本特征抽取,附完整python代码

    各位同学好,今天和大家介绍一下python机器学习中的特征工程.在将数据放入模型之前,需要对数据的一些特征进行特征抽取,方法有: (1) 字典特征抽取 DictVectorizer(),(2) 文本特 ...

  8. 十年公务员转行IT,自学AI三年,他淬炼出746页机器学习入门笔记

    整理 | Jane 编辑 | Just 出品 | AI科技大本营(ID:rgznai100) 近期,梁劲传来该笔记重大更新的消息.<机器学习--从入门到放弃>这本笔记的更新没有停止,在基于 ...

  9. 我的机器学习入门清单及路线!

    Datawhale干货 作者:桔了个仔,南洋理工大学,Datawhale成员 知乎:https://www.zhihu.com/people/huangzhe 这是我个人的机器学习入门清单及路线,所以 ...

最新文章

  1. 更好用的3D打印“活体”墨水来了,合成生物的新工具包!
  2. at指令 fpga_FPGA毕设系列 | 无线通信
  3. JavaScript实现多项式函数在某个点的评估算法(附完整源码)
  4. 基于机器视觉的马达孔直径中心距、齿数线序测量
  5. ctab法提取dna流程图_CTAB法提取植物基因组DNA过程图示
  6. 循环队列的创建Java_Java版-数据结构-队列(循环队列)
  7. 190321每日一句
  8. crypto.js 前端加解密
  9. 人工智能行业数据安全解决方案
  10. Python_Dataframe_去除重复数据
  11. 【华为机考题库学习】--算法篇(更新中……)
  12. springboott整合mybatis-plus和sharding-jdbc实现分库分表和读写分离(含完整项目代码)
  13. 如果获取token?
  14. Chapter 2、不使用代理
  15. KuaiRec 快手首个稠密为99.6%的数据集 相关介绍、下载、处理、使用方法
  16. 全自动过滤器:全自动管道过滤器工作原理
  17. 【DDD落地实践系列】DDD 领域驱动设计落地实践:六步拆解 DDD
  18. 7-4 王小二分饼 (15分) __C++
  19. 计算机维修和应用有什么区别,计算器与计算机的区别有哪些,计算器常见故障以及维修方法...
  20. 计算机网卡号里面有以太网,以太网卡

热门文章

  1. CloseHandle()、TerminateThread()、ExitThread()的区别
  2. lotus notes 闪退_黑鲨研习win7系统Lotus Notes邮箱闪退的技巧
  3. 扩散模型(Diffusion Models)
  4. 微信小程序(游戏)----五子棋(AI篇)
  5. 从“点卡改月卡”谈电子游戏产业中的道德困境
  6. 梯度下降算法与Normal equation
  7. 中日韩大字符集文字编码的比较研究
  8. 高防服务器ddos压力测试的工具推荐
  9. 手机浏览器加载不出来css,如何解决浏览器不加载css文件的问题
  10. emu8086 第一个程序