Keras实现英文到中文机器翻译 seq2seq+LSTM
该模型实现的是英文到中文的翻译,下图为了更好展示模型架构借用大佬的图(这里没有用到Embeddings):
本文完整代码:Github
目录
一、处理文本数据
1.获得翻译前后的句子
2.创建关于 字符-index 和 index -字符的字典
3.对中文和英文句子One-Hot编码
二、建立模型
三、decoder预测每个字符
四、训练模型
五、展示
整体由encoder和decoder两大部分组成,每部分都有一个LSTM网络,其中encoder输入原始的句子,输出状态向量;decoder输入的是含有开始符号的翻译后的句子,输出目标句子。
具体步骤为:
1.encoder将输入序列进行编码成状态向量
2.decoder从第一个字符开始预测
3.向decoder喂入状态向量(state_h,state_c)和累计包含之前预测字符的独热编码(第一次的状态向量来自于encoder,后来预测每个目标序列的每个字符时,状态向量来源于decoder,predict出来的状态向量)
4.使用argmax预测对下一个字符的位置,再根据字典查找到对应的字符
5.将上一步骤中的字符添加到 target sequence中
6.直到预测到我们指定结束字符时结束循环
一、处理文本数据
这一步骤包含对原数据进行分割获得翻译前、后的句子,生成字符的字典,最后对翻译前后的句子进行One-Hot编码,便于处理数据。
1.获得翻译前后的句子
先看一下原数据的样式:
首先导入需要的库:
代码1.1.1
import pandas as pd
import numpy as np
from keras.layers import Input, LSTM, Dense, merge,concatenate
from keras.optimizers import Adam, SGD
from keras.models import Model,load_model
from keras.utils import plot_model
from keras.models import Sequential#定义神经网络的参数
NUM_SAMPLES=3000 #训练样本的大小
batch_size = 64 #一次训练所选取的样本数
epochs = 100 #训练轮数
latent_dim = 256 #LSTM 的单元个数
用pandas读取文件,然后我们只要前两列内容
代码1.1.2
data_path='data/cmn.txt'
df=pd.read_table(data_path,header=None).iloc[:NUM_SAMPLES,0:2]
#添加标题栏
df.columns=['inputs','targets']
#每句中文举手加上‘\t’作为起始标志,句末加上‘\n’终止标志
df['targets']=df['targets'].apply(lambda x:'\t'+x+'\n')
最后是这样的形式:
然后分别把英文和中文数据转换为list形式
代码1.1.3
#获取英文、中文各自的列表
input_texts=df.inputs.values.tolist()
target_texts=df.targets.values.tolist()#确定中英文各自包含的字符。df.unique()直接取sum可将unique数组中的各个句子拼接成一个长句子
input_characters = sorted(list(set(df.inputs.unique().sum())))
target_characters = sorted(list(set(df.targets.unique().sum())))#英文字符中不同字符的数量
num_encoder_tokens = len(input_characters)
#中文字符中不同字符的数量
num_decoder_tokens = len(target_characters)
#最大输入长度
INUPT_LENGTH = max([ len(txt) for txt in input_texts])
#最大输出长度
OUTPUT_LENGTH = max([ len(txt) for txt in target_texts])
2.创建关于 字符-index 和 index -字符的字典
代码1.2.1
input_token_index = dict( [(char, i)for i, char in enumerate(input_characters)] )
target_token_index = dict( [(char, i) for i, char in enumerate(target_characters)] )reverse_input_char_index = dict([(i, char) for i, char in enumerate(input_characters)])
reverse_target_char_index = dict([(i, char) for i, char in enumerate(target_characters)])
3.对中文和英文句子One-Hot编码
代码1.3.1
#需要把每条语料转换成LSTM需要的三维数据输入[n_samples, timestamp, one-hot feature]到模型中
encoder_input_data =np.zeros((NUM_SAMPLES,INUPT_LENGTH,num_encoder_tokens))
decoder_input_data =np.zeros((NUM_SAMPLES,OUTPUT_LENGTH,num_decoder_tokens))
decoder_target_data = np.zeros((NUM_SAMPLES,OUTPUT_LENGTH,num_decoder_tokens))for i,(input_text,target_text) in enumerate(zip(input_texts,target_texts)):for t,char in enumerate(input_text):encoder_input_data[i,t,input_token_index[char]]=1.0for t, char in enumerate(target_text):decoder_input_data[i,t,target_token_index[char]]=1.0if t > 0:# decoder_target_data 不包含开始字符,并且比decoder_input_data提前一步decoder_target_data[i, t-1, target_token_index[char]] = 1.0
二、建立模型
代码2.1
#定义编码器的输入encoder_inputs=Input(shape=(None,num_encoder_tokens))#定义LSTM层,latent_dim为LSTM单元中每个门的神经元的个数,return_state设为True时才会返回最后时刻的状态h,cencoder=LSTM(latent_dim,return_state=True)# 调用编码器,得到编码器的输出(输入其实不需要),以及状态信息 state_h 和 state_cencoder_outputs,state_h,state_c=encoder(encoder_inputs)# 丢弃encoder_outputs, 我们只需要编码器的状态encoder_state=[state_h,state_c]#定义解码器的输入decoder_inputs=Input(shape=(None,num_decoder_tokens))decoder_lstm=LSTM(latent_dim,return_state=True,return_sequences=True)# 将编码器输出的状态作为初始解码器的初始状态decoder_outputs,_,_=decoder_lstm(decoder_inputs,initial_state=encoder_state)#添加全连接层decoder_dense=Dense(num_decoder_tokens,activation='softmax')decoder_outputs=decoder_dense(decoder_outputs)#定义整个模型model=Model([encoder_inputs,decoder_inputs],decoder_outputs)
model的模型图:

其中decoder在每个timestep有三个输入分别是来自encoder的两个状态向量state_h,state_c和经过One-Hot编码的中文序列
代码2.2
#定义encoder模型,得到输出encoder_statesencoder_model=Model(encoder_inputs,encoder_state)decoder_state_input_h=Input(shape=(latent_dim,))decoder_state_input_c=Input(shape=(latent_dim,))decoder_state_inputs=[decoder_state_input_h,decoder_state_input_c]# 得到解码器的输出以及中间状态decoder_outputs,state_h,state_c=decoder_lstm(decoder_inputs,initial_state=decoder_state_inputs)decoder_states=[state_h,state_c]decoder_outputs=decoder_dense(decoder_outputs)decoder_model=Model([decoder_inputs]+decoder_state_inputs,[decoder_outputs]+decoder_states)plot_model(model=model,show_shapes=True)plot_model(model=encoder_model,show_shapes=True)plot_model(model=decoder_model,show_shapes=True)return model,encoder_model,decoder_model
encoder的模型图:

decoder的模型图:
三、decoder预测每个字符
首先encoder根据输入序列生成状态向量states_value 并结合由包含开始字符"\t"的编码一并传入到decoder的输入层,预测出下个字符的位置sampled_token_index ,将新预测到的字符添加到target_seq中再进行One-Hot编码,用预测上个字符生成的状态向量作为新的状态向量。
以上过程在while中不断循环,直到预测到结束字符"\n",结束循环,返回翻译后的句子。从下图可直观的看出对于decoder部分是一个一个生成翻译后的序列,注意蓝线的指向是target_squence,它是不断被填充的。
代码3.1
def decode_sequence(input_seq,encoder_model,decoder_model):# 将输入序列进行编码生成状态向量states_value = encoder_model.predict(input_seq)# 生成一个size=1的空序列target_seq = np.zeros((1, 1, num_decoder_tokens))# 将这个空序列的内容设置为开始字符target_seq[0, 0, target_token_index['\t']] = 1.# 进行字符恢复# 简单起见,假设batch_size = 1stop_condition = Falsedecoded_sentence = ''while not stop_condition:output_tokens, h, c = decoder_model.predict([target_seq] + states_value)
# print(output_tokens)这里输出的是下个字符出现的位置的概率# 对下个字符采样 sampled_token_index是要预测下个字符最大概率出现在字典中的位置sampled_token_index = np.argmax(output_tokens[0, -1, :])sampled_char = reverse_target_char_index[sampled_token_index]decoded_sentence += sampled_char# 退出条件:生成 \n 或者 超过最大序列长度if sampled_char == '\n' or len(decoded_sentence) >INUPT_LENGTH :stop_condition = True# 更新target_seqtarget_seq = np.zeros((1, 1, num_decoder_tokens))target_seq[0, 0, sampled_token_index] = 1.# 更新中间状态states_value = [h, c]return decoded_sentence
四、训练模型
model,encoder_model,decoder_model=create_model()
#编译模型
model.compile(optimizer='rmsprop',loss='categorical_crossentropy')
#训练模型
model.fit([encoder_input_data,decoder_input_data],decoder_target_data,batch_size=batch_size,epochs=epochs,validation_split=0.2)
#训练不错的模型为了以后测试可是保存
model.save('s2s.h5')
encoder_model.save('encoder_model.h5')
decoder_model.save('decoder_model.h5')
五、展示
if __name__ == '__main__':intro=input("select train model or test model:")if intro=="train":print("训练模式...........")train()else:print("测试模式.........")while(1):test()
训练数据用了3000组 ,大部分是比较短的词组或者单词。效果不能算是太好,但是比起英语渣渣还算可以吧。
Reference:
A ten-minute introduction to sequence-to-sequence learning in Keras
https://towardsdatascience.com/neural-machine-translation-using-seq2seq-with-keras-c23540453c74
Keras实现英文到中文机器翻译 seq2seq+LSTM相关推荐
- 人工智能框架实战精讲:Keras项目-英文语料的DNN、Word2Vec、CNN、LSTM文本分类实战与调参优化
Keras项目-英文语料的文本分类实战 一.机器学习模型 1.1 数据简介 1.2 数据读取与预处理 1.3 数据切分与逻辑回归模型构建 二.全连接神经网络模型 2.1 模型训练 2.2 模型结果展示 ...
- 【转】SQL函数:字符串中提取数字,英文,中文,过滤重复字符
SQL函数:字符串中提取数字,英文,中文,过滤重复字符 --提取数字 IF OBJECT_ID('DBO.GET_NUMBER') IS NOT NULL DROP FUNCTION DBO.GET_ ...
- keras实现简单lstm_深度学习(LSTM)在交通建模中的应用
上方点击蓝字关注? 在简单了解了LSTM原理之后,本期我将以航班延误预测为例为大家介绍一下如何利用Python编程来构建LSTM模型. 这里我们要用到一个高级的深度学习链接库--Keras,它以Ten ...
- IDEA(Pycharm)一家子常用快捷键Keymap对应的英文、中文与具体位置
本博客旨在把idea全家桶中快捷键的英文与中文罗列出来,方便大家自定义 中文 英文 位置 代码提示 Basic code - completion - Basic 同时多行输入 Add or Remo ...
- EasyUI的DataGrid 分页栏英文改中文解决方案
EasyUI的DataGrid 分页栏英文改中文解决方案 参考文章: (1)EasyUI的DataGrid 分页栏英文改中文解决方案 (2)https://www.cnblogs.com/tahn30 ...
- CSS 中的字体兼容写法:用CSS为英文和中文字体分别设置不同的字体
font-family的调用方法: font-family:Arial,'Times New Roman','Microsoft YaHei',SimHei; font:bold 12px/0.75e ...
- 用PS将图片或表格中的英文变成中文
用PS将图片或表格中的英文变成中文 相信现在很多童鞋都面临着毕业压力,开始焦急的撰写毕业论文了.现在教大家用PS将中文变成成英文,马上学起来.如果你们还没有下载photoshop,也可以用在线PS,方 ...
- C++排雷:19.过滤英文和中文标点符号,string与wstring之间的转换
想要过滤一个文本中的标点符号. 对英文标点符号可以使用cctype中的ispunct方法来识别 而对于中文标点符号,则需要一定的转换: C++用string来处理字符串. string是窄字符串ASC ...
- vs智能提示英文转为中文
使用 VS2015 时,在 4.0 下智能提示显示中文,在 4.5 下显示英文,英文转为中文的方法: 找到目录: C:\Program Files (x86)\Reference Assemblies ...
最新文章
- maven 多环境打包
- Vue 系列之 组件
- linux下mysql源码安装
- hdu1816 + POJ 2723开锁(二分+2sat)
- CTFshow php特性 web113
- 【faster-rcnn】训练自己的数据集时的坑
- 南明区将引进和培养大数据高端人才逾千名
- 模拟 http connecttimeout_燃烧室数学模型模拟软件NPSS
- Apache Kylin从入门到精通
- 移动短信回执怎么开通_移动短信回执业务内容及资费介绍
- Live预告 | 地平线李星宇:智能汽车电子构架如何变革迎接数字化重塑?...
- 計算機二級-java08
- 自如总部摘牌?官方回应:更换logo 业务一切正常
- es String 内部实现逻辑标准
- python redis模块_大数据入门4 | Redis安装及python中的redis模块加载
- mongodb聚合内存不足解决方案
- Entity Framework 常用类
- html flash带播放视频源码,HTML嵌套Flash播放视频
- java short 写法_Java Short类shortValue()方法及示例
- 这一刻,听见华为FTTR的星光四重奏
热门文章
- 用一句话概括计算机信息管理,一句话概括ERP
- 授人玫瑰 手留余香 --纪念python3.2.3官方文档翻译结束
- css 新增div滚动条、修改滚动条样式
- el-checkbox中的checked勾选状态问题 2021-08-02
- python加载图片并显示_python读取目录下所有的jpg文件,并显示第一张图片的示例...
- matlab对图像进行线性点运算,图像线性点运算---MATLAB
- 实时图像增强,基于“间距自适应查找表”的方法(CVPR 2022)
- USACO Wormholes
- 像高启强一样学PMP项目管理,我的备考效率一路『狂飙』!
- MySQL实现树的遍历