一、准备数据

1.seq_example代表问题,seq_answer代表答案,数据内容如下所示:

seq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "你有父母吗"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "我没有父母"]

2.将数据进行jieba分词并加入索引index,其中SOS代表单词开头,EOS代表单词结尾,PAD补全,数据如下:

{'你': 3, '认识': 4, '我': 5, '吗': 6, '住': 7, '在': 8, '哪里': 9, '知道': 10, '的': 11, '名字': 12, '是': 13, '谁': 14, '会': 15, '唱歌': 16, '有': 17, '父母': 18, '当然': 19, '成都': 20, '不': 21, '机器人': 22, '不会': 23, '没有': 24, 'PAD': 0, 'SOS': 1, 'EOS': 2}

3. 最后将seq_example与seq_answer分词后使用索引表示

二、模型构建

1.encoder

采用双向LSTM处理输入向量,代码如下:

class lstm_encoder(nn.Module):def __init__(self):super(lstm_encoder, self).__init__()
#         双向LSTMself.encoder = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)def forward(self, embedding_input):encoder_output, (encoder_h_n, encoder_c_n) = self.encoder(embedding_input)# 拼接前向和后向最后一个隐层encoder_h_n = torch.cat([encoder_h_n[0], encoder_h_n[1]], dim=1)encoder_c_n = torch.cat([encoder_c_n[0], encoder_c_n[1]], dim=1)return encoder_output, encoder_h_n.unsqueeze(0), encoder_c_n.unsqueeze(0)

2.decoder + Attention

decoder采用单向LSTM并加入Attention机制,即将decoder输出与encoder输出通过Atention拼接后进入全连接层做预测,Attention机制采用的General方式,具体过程如下所示:

代码如下:

class lstm_decoder(nn.Module):def __init__(self):super(lstm_decoder, self).__init__()# 单向LSTMself.decoder = nn.LSTM(embedding_size, n_hidden * 2, 1)# attention参数self.att_weight = nn.Linear(n_hidden * 2, n_hidden * 2)# attention_joint参数self.att_joint = nn.Linear(n_hidden * 4, n_hidden * 2)# 定义全连接层self.fc = nn.Linear(n_hidden * 2, num_classes)def forward(self, input_x, encoder_output, hn, cn):decoder_output, (decoder_h_n, decoder_c_n) = self.decoder(input_x, (hn, cn))decoder_output = decoder_output.permute(1, 0, 2)encoder_output = encoder_output.permute(1, 0, 2)decoder_output_att = self.att_weight(encoder_output)decoder_output_att = decoder_output_att.permute(0, 2, 1)# 计算分数scoredecoder_output_score = decoder_output.bmm(decoder_output_att)# 计算权重atat = nn.functional.softmax(decoder_output_score, dim=2)# 计算新的context向量ctct = at.bmm(encoder_output)# 拼接ct和decoder_htht_joint = torch.cat((ct, decoder_output), dim=2)fc_joint = torch.tanh(self.att_joint(ht_joint))fc_out = self.fc(fc_joint)return fc_out, decoder_h_n, decoder_c_n

三、具体代码

import torch
import torch.nn as nn
import torch.optim as optim
import jieba
import osseq_example = ["你认识我吗", "你住在哪里", "你知道我的名字吗", "你是谁", "你会唱歌吗", "你有父母吗"]
seq_answer = ["当然认识", "我住在成都", "我不知道", "我是机器人", "我不会", "我没有父母"]
# 所有词
example_cut = []
answer_cut = []
word_all = []
# 分词
for i in seq_example:example_cut.append(list(jieba.cut(i)))
for i in seq_answer:answer_cut.append(list(jieba.cut(i)))
#   所有词
for i in example_cut + answer_cut:for word in i:if word not in word_all:word_all.append(word)
# 词语索引表
word2index = {w: i+3 for i, w in enumerate(word_all)}
# 补全
word2index['PAD'] = 0
# 句子开始
word2index['SOS'] = 1
# 句子结束
word2index['EOS'] = 2
index2word = {value: key for key, value in word2index.items()}
# 一些参数
vocab_size = len(word2index)
seq_length = max([len(i) for i in example_cut + answer_cut]) + 1
embedding_size = 5
num_classes = vocab_size
n_hidden = 10# 将句子用索引表示
def make_data(seq_list):result = []for word in seq_list:seq_index = [word2index[i] for i in word]if len(seq_index) < seq_length:seq_index += [0] * (seq_length - len(seq_index))result.append(seq_index)return result
encoder_input = make_data(example_cut)
decoder_input = make_data([['SOS'] + i for i in answer_cut])
decoder_target = make_data([i + ['EOS'] for i in answer_cut])
# 训练数据
encoder_input, decoder_input, decoder_target = torch.LongTensor(encoder_input), torch.LongTensor(decoder_input), torch.LongTensor(decoder_target)# 建立encoder模型
class lstm_encoder(nn.Module):def __init__(self):super(lstm_encoder, self).__init__()
#         双向LSTMself.encoder = nn.LSTM(embedding_size, n_hidden, 1, bidirectional=True)def forward(self, embedding_input):encoder_output, (encoder_h_n, encoder_c_n) = self.encoder(embedding_input)# 拼接前向和后向最后一个隐层encoder_h_n = torch.cat([encoder_h_n[0], encoder_h_n[1]], dim=1)encoder_c_n = torch.cat([encoder_c_n[0], encoder_c_n[1]], dim=1)return encoder_output, encoder_h_n.unsqueeze(0), encoder_c_n.unsqueeze(0)# 建立attention_decoder模型
class lstm_decoder(nn.Module):def __init__(self):super(lstm_decoder, self).__init__()# 单向LSTMself.decoder = nn.LSTM(embedding_size, n_hidden * 2, 1)# attention参数self.att_weight = nn.Linear(n_hidden * 2, n_hidden * 2)# attention_joint参数self.att_joint = nn.Linear(n_hidden * 4, n_hidden * 2)# 定义全连接层self.fc = nn.Linear(n_hidden * 2, num_classes)def forward(self, input_x, encoder_output, hn, cn):decoder_output, (decoder_h_n, decoder_c_n) = self.decoder(input_x, (hn, cn))decoder_output = decoder_output.permute(1, 0, 2)encoder_output = encoder_output.permute(1, 0, 2)decoder_output_att = self.att_weight(encoder_output)decoder_output_att = decoder_output_att.permute(0, 2, 1)# 计算分数scoredecoder_output_score = decoder_output.bmm(decoder_output_att)# 计算权重atat = nn.functional.softmax(decoder_output_score, dim=2)# 计算新的context向量ctct = at.bmm(encoder_output)# 拼接ct和decoder_htht_joint = torch.cat((ct, decoder_output), dim=2)fc_joint = torch.tanh(self.att_joint(ht_joint))fc_out = self.fc(fc_joint)return fc_out, decoder_h_n, decoder_c_nclass seq2seq(nn.Module):def __init__(self):super(seq2seq, self).__init__()self.word_vec = nn.Embedding(vocab_size, embedding_size)#     encoderself.seq2seq_encoder = lstm_encoder()#     decoderself.seq2seq_decoder = lstm_decoder()def forward(self, encoder_input, decoder_input, inference_threshold=0):embedding_encoder_input = self.word_vec(encoder_input)embedding_decoder_input = self.word_vec(decoder_input)# 调换第一维和第二维度embedding_encoder_input = embedding_encoder_input.permute(1, 0, 2)embedding_decoder_input = embedding_decoder_input.permute(1, 0, 2)# 编码器encoder_output, h_n, c_n = self.seq2seq_encoder(embedding_encoder_input)# 判断为训练还是预测if inference_threshold:# 解码器decoder_output, h_n, c_n = self.seq2seq_decoder(embedding_decoder_input, encoder_output, h_n, c_n)return decoder_outputelse:# 创建outputs张量存储Decoder的输出outputs = []for i in range(seq_length):decoder_output, h_n, c_n = self.seq2seq_decoder(embedding_decoder_input, encoder_output, h_n, c_n)decoder_x = torch.max(decoder_output.reshape(-1, 25), dim=1)[1].item()if decoder_x in [0, 2]:return outputsoutputs.append(decoder_x)embedding_decoder_input = self.word_vec(torch.LongTensor([[decoder_x]]))embedding_decoder_input = embedding_decoder_input.permute(1, 0, 2)return outputsmodel = seq2seq()
print(model)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.05)# 判断是否有模型文件
if os.path.exists("./seq2seqModel.pkl"):model.load_state_dict(torch.load('./seq2seqModel.pkl'))
else:# 训练model.train()for epoch in range(10000):pred = model(encoder_input, decoder_input, 1)loss = criterion(pred.reshape(-1, 25), decoder_target.view(-1))optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 1000 == 0:print("Epoch: %d,  loss: %.5f " % (epoch + 1, loss))# 保存模型torch.save(model.state_dict(), './seq2seqModel.pkl')
# 测试
model.eval()
question_text = '你住在哪里'
question_cut = list(jieba.cut(question_text))
encoder_x = make_data([question_cut])
decoder_x = [[word2index['SOS']]]
encoder_x,  decoder_x = torch.LongTensor(encoder_x), torch.LongTensor(decoder_x)
out = model(encoder_x, decoder_x)
answer = ''
for i in out:answer += index2word[i]
print('问题:', question_text)
print('回答:', answer)

Pytorch简单实现seq2seq+Attention机器人问答相关推荐

  1. 智能客服系列3 seq2seq+attention【Python十分钟写出聊天机器人】基于Keras实现seq2seq模型

    开篇导读 首先复习下前面聊天机器人系列: <一>聊天机器人/翻译系统系列一梳理了聊天机器人网络设计模型原理 (理论篇-图文解锁seq2seq+attention模型原理) <二> ...

  2. 基于PyTorch实现Seq2Seq + Attention的英汉Neural Machine Translation

    NMT(Neural Machine Translation)基于神经网络的机器翻译模型效果越来越好,还记得大学时代Google翻译效果还是差强人意,近些年来使用NMT后已基本能满足非特殊需求了.目前 ...

  3. NLP-生成模型-2017-PGNet:Seq2Seq+Attention+Coverage+Copy【Coverage解决解码端重复解码问题;Copy机制解决解码端OOV问题】【抽取式+生成式】

    PGNet模型训练注意事项: Coverage机制要在训练的最后阶段再加入(约占总训练时间的1%),如果从刚开始训练时就加入则反而影响训练效果: Copy机制在源文本的各个单词上的概率分布直接使用At ...

  4. Seq2Seq+Attention生成式文本摘要

    任务描述: 自动摘要是指给出一段文本,我们从中提取出要点,然后再形成一个短的概括性的文本.自动的文本摘要是非常具有挑战性的,因为当我们作为人类总结一篇文章时,我们通常会完整地阅读它以发展我们的理解,然 ...

  5. 文本生成任务之营销文本生成(Seq2seq+attention、Pointer Generator Network、Converage、Beam Search、优化技巧、文本增强)

    文章目录 引言 项目任务简介 0. 数据预处理 0.1 将json文件转化成txt文件 0.2 词典处理 0.3 自定义数据集SampleDataset(Dataset类) 0.4 生成Dataloa ...

  6. Tensorflow 自动文摘: 基于Seq2Seq+Attention模型的Textsum模型

    Github下载完整代码 https://github.com/rockingdingo/deepnlp/tree/master/deepnlp/textsum 简介 这篇文章中我们将基于Tensor ...

  7. 基于pytorch的sque2suqe with attention实现与介绍

    基于pytorch的sque2suqe with attention实现与介绍 上一篇文章<基于pytorch的ConvGRU神经网络的实现与介绍>https://blog.csdn.ne ...

  8. seq2seq + attention

    1.思考几个问题:① 为什么解码器 一般来说 需要与 编码器的 hidden_size 相同呢? 2.seq2seq + attention 注意的几个问题:① 如果编码器 的 RNNCell 是LS ...

  9. seq2seq + attention 详解

    seq2seq + attention 详解 作者:xy_free \qquad 时间:2018.05.21 1. seq2seq模型 seq2seq模型最早可追溯到2014年的两篇paper [1, ...

最新文章

  1. py3提取json指定内容_python3 取页面指定数据(json)
  2. webapi控制器怎么接收json_一个秒杀系统的登录系统到底是怎么工作的
  3. 每日一皮:这个不要轻易尝试,执行有生命危险
  4. 类似百度输入框自动完成
  5. 图的dfs非递归_如何理解恶心的递归
  6. Java+Selenium爬贴吧
  7. QT的QSplitter类的使用
  8. 分享我第一次做项目的感受
  9. 有名信号量sem_open和内存信号量sem_init创建信号量的区别
  10. Python SQLAlchemy --3
  11. 短信计费(信息学奥赛一本通-T1398)
  12. [计算机网络] - TCP三次握手和四次挥手
  13. 基于SSH框架社区智能化管理系统答辩PPT模板
  14. 用Java求s=a+aa+aaa+.....+aaa...a的值
  15. simulink怎么生成vxworks的执行程序_让天下没有难改的Simulink模型
  16. 手机迅雷打不开html,迅雷打不开了怎么办
  17. The Nicest Word(io优化)
  18. 超级实用的分时图指标 有了本分时图你根本不用看K线了
  19. 西门子S7-1200 HslcommunicationDemo大纲拆解
  20. r语言中的或怎么表示什么不同_R语言中$是什么意思

热门文章

  1. 惠普HP Laser MFP 136a 打印机驱动
  2. 石墨文档入选「2021 数字经济产业 TOP100 榜单」
  3. php远程下载到本地,PHP 下载远程文件到本地的简单示例
  4. nginx静态文件缓存
  5. 有哪些好用的微信群管理工具?
  6. 大数据最佳实践-hbase
  7. AttributeError: module cv2.face has no attribute 'createEigenFaceRecognizer'
  8. 什么是 VxLAN?
  9. Vue 生成海报图的方法
  10. XMUOJ·纸片选择