个人博客:http://www.chenjianqu.com/

原文链接:http://www.chenjianqu.com/show-40.html

前一篇文章中使用简单的seq2seq搭建了单词级聊天机器人《聊天机器人-基于QQ聊天记录训练》,里面也简单介绍了seq2seq的原理。这里尝试用seq2seq做一下字符级的翻译:英语->粤语。

seq2seq的训练过程是'teacher forcing'的,英法翻译举例如下:

推断过程不实用'teacher forcing',如下:

因此seq2seq的训练和预测过程的模型是不一样的,它们使用相同的层。

训练模型的架构:

用tensorboard画出来的结构如下:

推断模型:

编码器

解码器

加载数据集

数据集来源:http://www.manythings.org/anki/.

import os
num_samples=3200
input_texts = []
target_texts = []
input_characters = set()
target_characters = set()
data_path=r'D:\NLP\dataset\机器翻译\yue-eng\yue.txt'with open(data_path, 'r', encoding='utf-8') as f:lines = f.read().split('\n')for line in lines[: min(num_samples, len(lines) - 1)]:input_text, target_text = line.split('\t')target_text = '\t' + target_text + '\n'input_texts.append(input_text)#输入target_texts.append(target_text)#输出for char in input_text:if char not in input_characters:input_characters.add(char)#输入字符集for char in target_text:if char not in target_characters:target_characters.add(char)#输出字符集#排序
input_characters = sorted(list(input_characters))
target_characters = sorted(list(target_characters))
#字符数
num_encoder_tokens = len(input_characters)
num_decoder_tokens = len(target_characters)
#输出输出的最大句子长度
max_encoder_seq_length = max([len(txt) for txt in input_texts])
max_decoder_seq_length = max([len(txt) for txt in target_texts])

数据预处理

将文本映射为向量。

import numpy as np#字典:把字符映射为数字
input_token_index = dict([(char, i) for i, char in enumerate(input_characters)])
target_token_index = dict([(char, i) for i, char in enumerate(target_characters)])
#字典:把数字映射回字符
rinput_dict = dict( (i, char) for char, i in input_token_index.items())
rtarget_dict = dict((i, char) for char, i in target_token_index.items())#将数据集映射为向量
encoder_input_data = np.zeros((len(input_texts), max_encoder_seq_length, num_encoder_tokens),dtype='float32')
decoder_input_data = np.zeros((len(input_texts), max_decoder_seq_length, num_decoder_tokens),dtype='float32')
decoder_target_data = np.zeros((len(input_texts), max_decoder_seq_length, num_decoder_tokens),dtype='float32')
for i, (input_text, target_text) in enumerate(zip(input_texts, target_texts)):for t, char in enumerate(input_text):encoder_input_data[i, t, input_token_index[char]] = 1.for t, char in enumerate(target_text):# decoder_target_data is ahead of decoder_input_data by one timestepdecoder_input_data[i, t, target_token_index[char]] = 1.if t > 0:# decoder_target_data will be ahead by one timestep# and will not include the start character.decoder_target_data[i, t - 1, target_token_index[char]] = 1.
print(encoder_input_data.shape)
print(decoder_input_data.shape)
print(decoder_target_data.shape)

定义模型

lstm的cell数量为256维。

from keras.models import Model
from keras.layers import Input, LSTM, Dense
latent_dim = 256
encoder_inputs = Input(shape=(None, num_encoder_tokens))
encoder = LSTM(latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder(encoder_inputs)
encoder_states = [state_h, state_c]decoder_inputs = Input(shape=(None, num_decoder_tokens))
decoder_lstm = LSTM(latent_dim, return_sequences=True, return_state=True)
decoder_outputs, _, _ = decoder_lstm(decoder_inputs,initial_state=encoder_states)
decoder_dense_1 = Dense(int(num_decoder_tokens/4), activation='relu')
decoder_outputs_1 = decoder_dense_1(decoder_outputs)
decoder_dense = Dense(num_decoder_tokens, activation='softmax')
decoder_outputs = decoder_dense(decoder_outputs_1)model = Model([encoder_inputs, decoder_inputs], decoder_outputs)
model.summary()
plot_model(model,to_file='eng2yue_model.png',show_shapes=True)

训练并保存模型

设置回调函数,在训练过程中保存最优模型。

import kerasbatch_size = 16
epochs = 100 callbacks_list=[keras.callbacks.EarlyStopping(monitor='acc',patience=10,),keras.callbacks.ModelCheckpoint(filepath='eng2yue_model_checkpoint.h5',monitor='val_loss',#如果val_loss不改善,则不需要覆盖模型文件save_best_only=True),keras.callbacks.TensorBoard(log_dir='my_log_dir',histogram_freq=1#每一轮之后记录直方图)
]model.compile(optimizer='rmsprop', loss='categorical_crossentropy',metrics=['acc'])model.fit([encoder_input_data, decoder_input_data], decoder_target_data,batch_size=batch_size,epochs=epochs,validation_split=0.1,callbacks=callbacks_list)model.save('eng2yue_model.h5')

训练结果:

精度很低,这是训练集太小的缘故。

定义预测模型

编码器和解码器分开。

#编码器
encoder_model = Model(encoder_inputs, encoder_states)
#解码器
decoder_state_input_h = Input(shape=(latent_dim,))
decoder_state_input_c = Input(shape=(latent_dim,))
decoder_states_inputs = [decoder_state_input_h, decoder_state_input_c]
decoder_outputs, state_h, state_c = decoder_lstm(decoder_inputs, initial_state=decoder_states_inputs)
decoder_states = [state_h, state_c]
decoder_outputs = decoder_dense(decoder_outputs)
decoder_model = Model([decoder_inputs] + decoder_states_inputs,[decoder_outputs] + decoder_states
)

预测

测试结果不是很理想,从训练集的验证精度上也看得出来。

#输入文本
text='I get it.'text_seq = np.zeros((1, max_encoder_seq_length, num_encoder_tokens),dtype='float32')
for t, char in enumerate(text):text_seq[0, t, input_token_index[char]] = 1.result = ''
states_value = encoder_model.predict(text_seq)#编码输入得到状态变量target_seq = np.zeros((1, 1, num_decoder_tokens))
target_seq[0, 0, target_token_index['\t']] = 1.#初始化解码器输入向量
stop_condition = Falsewhile not stop_condition:output_tokens, h, c = decoder_model.predict([target_seq] + states_value)sampled_token_index = np.argmax(output_tokens[0, -1, :])sampled_char = rtarget_dict[sampled_token_index]result += sampled_char#退出循环if (sampled_char == '\n' or len(result) > max_decoder_seq_length):stop_condition = True#更新decoder输入target_seq = np.zeros((1, 1, num_decoder_tokens))target_seq[0, 0, sampled_token_index] = 1.# 更新状态states_value = [h, c]
print(result)

参考文献

[1]Francois Chollet.A ten-minute introduction to sequence-to-sequence learning in Keras.https://blog.keras.io/a-ten-minute-introduction-to-sequence-to-sequence-learning-in-keras.html. 2017-09-29

字符级Seq2Seq-英语粤语翻译的简单实现相关推荐

  1. 当莎士比亚遇见Google Flax:教你用​字符级语言模型和归递神经网络写“莎士比亚”式句子...

    作者 | Fabian Deuser 译者 | 天道酬勤 责编 | Carol 出品 | AI科技大本营(ID:rgznai100) 有些人生来伟大,有些人成就伟大,而另一些人则拥有伟大. -- 威廉 ...

  2. pytorch dropout_手把手带你使用字符级RNN生成名字 | PyTorch

    作者 | News 编辑 | 奇予纪 出品 | 磐创AI团队出品 [磐创AI 导读]:本篇文章讲解了PyTorch专栏的第五章中的使用字符级RNN生成名字.查看专栏历史文章,请点击下方蓝色字体进入相应 ...

  3. CS224n自然语言处理(三)——问答系统、字符级模型和自然语言生成

    文章目录 一.问答系统 1.Stanford Question Answering Dataset (SQuAD) 2.Stanford Attentive Reader Stanford Atten ...

  4. 自然语言处理(十四):从零开始构建使用注意力机制的Seq2Seq网络实现翻译任务

    自然语言处理笔记总目录 本案例取自PyTorch官网的NLP FROM SCRATCH: TRANSLATION WITH A SEQUENCE TO SEQUENCE NETWORK AND ATT ...

  5. 我主修计算机科学专业英语翻译,计算机专业英语教程翻译.docx

    计算机专业英语教程翻译.docx 1.1细看处理器与主存储器我们已经了解到所有计算机有类似的能力且能执行相同的功能,尽管一些可能比其他的快.我们知道电脑系统有输入.输出.仓储.加工的元件,还知道处理器 ...

  6. 【Mo 人工智能技术博客】使用 Seq2Seq 实现中英文翻译

    1. 介绍 1.1 Deep NLP 自然语言处理(Natural Language Processing,NLP)是计算机科学.人工智能和语言学领域交叉的分支学科,主要让计算机处理或理解自然语言,如 ...

  7. 计算机病毒的依赖性,计算机辅助框架英语依赖性翻译研究-计算机病毒论文-计算机论文.docx...

    计算机辅助框架英语依赖性翻译研究-计算机病毒论文-计算机论文 --文章均为WORD文档,下载后可直接编辑使用亦可打印-- 摘要:英语翻译的核心在于对上下文依赖性的分析,本文介绍了最小依赖翻译(Mini ...

  8. 她没有你会使用计算机英语,2017中考英语句子翻译题解题方法加真题演练附答案...

    下面的内容,都是官方语言,我很少在课上这样讲解翻译题. 这里只是想作为参考让同学们系统的了解下中考英语翻译问题. 我对翻译的理解是:语法+词汇,短语,句型.然后不要自己去编造句子. 而是大量的输入固定 ...

  9. 日常英语---五、英语句子翻译和读的选择什么工具好

    日常英语---五.英语句子翻译和读的选择什么工具好 一.总结 一句话总结:英语翻译用有道+google啊,找相似句啊,读的话要google 翻译:有道+google 找相似句 读:google 1.后 ...

最新文章

  1. 使用metasploit中Evasion模块
  2. 让机器听懂世界,触及人类梦想还有多远?
  3. Spring从菜鸟到高手(二)AOP的真正实现
  4. python从入门到精通书-Python从入门到精通,跟着《这本书》学就够了?
  5. python集合属性方法运算_Python基础__字典、集合、运算符
  6. 最长回文子串 hihocode 1032 hdu 3068
  7. 【算法基础】漫画:什么是 “跳表” ?
  8. MySQL的内连和外连
  9. php基础教程 第八步循环补充
  10. 前端和后端哪个工资高_新媒体运营和网络运维哪个好,哪个工资待遇高,门槛低?...
  11. day16 java的访问控制权限
  12. 关于Python ord()和chr()返回ASCII码和Unicode码的看法
  13. 新建cordova应用,插件开发教程系列(总目录)
  14. 热量的传递 —— 热辐射
  15. ACL'21 | 对比学习论文一句话总结
  16. ssis oracle配置,从SSIS包SQL Server连接Oracle数据库
  17. 【运筹学】线性规划数学模型 ( 求解基矩阵示例 | 矩阵的可逆性 | 线性规划表示为 基矩阵 基向量 非基矩阵 非基向量 形式 )
  18. Mac 下修改文件的 md5 值
  19. 【报错】unknown error: DevToolsActivePort file doesn‘t exis
  20. php empty 和空字符串区别

热门文章

  1. maven添加子工程_Maven建立父子项目和跨项目调用内容的步骤—佳佳小白
  2. P3275 [SCOI2011]糖果
  3. react native ios 上架
  4. 终端设备文件与进程之间的关系
  5. android studio shell 命令行自动打包(mac 平台)
  6. JavaScript四大家族之scroll家族
  7. java 数据库查询Date类型字段 没有了时分秒 全为 00 的解决办法
  8. MVVM(Knockout.js)的新尝试:多个Page,一个ViewModel
  9. Video-Touch:手势识别实现多用户远程控制机器人
  10. 每年扫码千亿次!微信官方开源了自家优化的二维码引擎!3行代码让你拥有微信扫码能力...