PyTorch 实现 chatbot 聊天机器人
涉及的论文
Neural Conversational Model https://arxiv.org/abs/1506.05869
Luong attention mechanism(s) https://arxiv.org/abs/1508.04025
Sutskever et al. https://arxiv.org/abs/1409.3215
GRU Cho et al. https://arxiv.org/pdf/1406.1078v3.pdf
Bahdanau et al. https://arxiv.org/abs/1409.0473
使用的数据集
Corpus website https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
Corpus link http://www.cs.cornell.edu/~cristian/data/cornell_movie_dialogs_corpus.zip
代码列表
chatbot_test.py
chatbot_train.py
corpus_dataset.py
vocabulary.py
graph.py
model.py
etc.py
main.py
chatbot_test.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import corpus_dataset
import graph
import etc


def run_test():
    """Build the vocabulary, create the model in "test" mode and start the chat loop."""
    settings = etc.config
    # The sentence pairs are returned alongside the vocabulary but are not
    # needed for interactive evaluation.
    voc, _pairs = corpus_dataset.load_vocabulary_and_pairs(settings)
    corpus_graph = graph.CorpusGraph(settings)
    restored_model = corpus_graph.create_train_model(voc, "test")
    corpus_graph.evaluate_input(voc, restored_model)
chatbot_train.py
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import corpus_dataset
import graph
import etc


def run_train():
    """Build the vocabulary and sentence pairs, create the model and train it."""
    settings = etc.config
    voc, pairs = corpus_dataset.load_vocabulary_and_pairs(settings)
    corpus_graph = graph.CorpusGraph(settings)

    print("Create model")
    model_bundle = corpus_graph.create_train_model(voc)

    print("Starting Training!")
    corpus_graph.trainIters(voc, pairs, model_bundle)
    # Interactive evaluation can be started right after training with
    # corpus_graph.evaluate_input(voc, model_bundle); it is left disabled here.
corpus_dataset.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : corpus_dataset.py
# Create date : 2019-01-16 11:16
# Modified date : 2019-02-02 14:55
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import ast
import os
import re
import csv
import codecs
import unicodedata

import vocabulary


def _check_is_have_file(file_name):
    """Return True when ``file_name`` exists on disk."""
    return os.path.exists(file_name)


def _filter_pair(p, max_length):
    """Return True when both sentences of pair ``p`` have fewer than ``max_length`` words."""
    return len(p[0].split(' ')) < max_length and len(p[1].split(' ')) < max_length


def _filter_pairs(pairs, max_length):
    """Keep only the pairs accepted by ``_filter_pair``."""
    return [pair for pair in pairs if _filter_pair(pair, max_length)]


def _read_vocabulary(datafile, corpus_name):
    """Read the formatted file and return an empty ``Voc`` plus the normalized pairs.

    Every line of ``datafile`` is split on tabs and each field is passed
    through ``normalize_string``.
    """
    print("Reading lines...")
    # Context manager so the handle is released as soon as the text is read.
    with open(datafile, encoding='utf-8') as f:
        lines = f.read().strip().split('\n')
    pairs = [[normalize_string(s) for s in l.split('\t')] for l in lines]
    voc = vocabulary.Voc(corpus_name)
    return voc, pairs


def _unicode_to_ascii(s):
    """Strip combining marks (Unicode category ``Mn``) from ``s`` after NFD decomposition."""
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )


def _get_delimiter(config):
    """Return ``config["delimiter"]`` with escape sequences decoded."""
    delimiter = config["delimiter"]
    return str(codecs.decode(delimiter, "unicode_escape"))


def _get_object(line, fields):
    """Split a corpus line on ``" +++$+++ "`` and map the values onto ``fields``.

    Raises:
        IndexError: if the line holds fewer values than ``fields``.
    """
    values = line.split(" +++$+++ ")
    obj = {}
    for i, field in enumerate(fields):
        obj[field] = values[i]
    return obj


def _load_lines(config):
    """Load the movie-lines file into a dict keyed by ``lineID``."""
    lines_file_full_path = "%s/%s" % (config["corpus_path"], config["lines_file_name"])
    fields = config["movie_lines_fields"]
    lines = {}
    # ``with`` guarantees the file is closed even if a malformed line raises.
    with open(lines_file_full_path, 'r', encoding='iso-8859-1') as f:
        for line in f:
            line_obj = _get_object(line, fields)
            lines[line_obj['lineID']] = line_obj
    return lines


def _cellect_lines(conv_obj, lines):
    """Attach the referenced line objects to ``conv_obj`` under the key ``"lines"``.

    ``conv_obj["utteranceIDs"]`` is the textual form of a Python list, e.g.
    ``"['L598485', 'L598486', ...]"``.
    """
    # literal_eval only accepts Python literals, so content of the corpus
    # file can never be executed as code (the previous eval() could).
    line_ids = ast.literal_eval(conv_obj["utteranceIDs"].strip())
    conv_obj["lines"] = [lines[line_id] for line_id in line_ids]
    return conv_obj


def _load_conversations(lines, config):
    """Load the conversations file and resolve every conversation's lines."""
    conversation_file_full_path = "%s/%s" % (
        config["corpus_path"], config["conversation_file_name"])
    fields = config["movie_conversations_fields"]
    conversations = []
    with open(conversation_file_full_path, 'r', encoding='iso-8859-1') as f:
        for line in f:
            conv_obj = _get_object(line, fields)
            conversations.append(_cellect_lines(conv_obj, lines))
    return conversations


def _get_conversations(config):
    """Load lines and conversations, printing how many of each were found."""
    lines = _load_lines(config)
    print("lines count:", len(lines))
    conversations = _load_conversations(lines, config)
    print("conversations count:", len(conversations))
    return conversations


def _extract_sentence_pairs(conversations):
    """Turn consecutive utterances of each conversation into [input, target] pairs."""
    pairs = []
    for conversation in conversations:
        conv_lines = conversation["lines"]
        # Pair each line with its successor; the last line has no answer.
        for current, following in zip(conv_lines, conv_lines[1:]):
            inputLine = current["text"].strip()
            targetLine = following["text"].strip()
            # Skip samples where either side is empty.
            if inputLine and targetLine:
                pairs.append([inputLine, targetLine])
    return pairs


def _load_formatted_data(config):
    """Read the formatted file, drop over-long pairs and fill the vocabulary."""
    max_length = config["max_length"]
    corpus_name = config["corpus_name"]
    formatted_file_full_path = get_formatted_file_full_path(config)

    print("Start preparing training data ...")
    voc, pairs = _read_vocabulary(formatted_file_full_path, corpus_name)
    print("Read {!s} sentence pairs".format(len(pairs)))
    pairs = _filter_pairs(pairs, max_length)
    print("Trimmed to {!s} sentence pairs".format(len(pairs)))
    for pair in pairs:
        voc.addSentence(pair[0])
        voc.addSentence(pair[1])
    print("Counted words:", voc.num_words)
    return voc, pairs


def _trim_rare_words(voc, pairs, min_count):
    """Trim ``voc`` to words seen at least ``min_count`` times and drop pairs using removed words."""
    voc.trim(min_count)

    def _all_known(sentence):
        return all(word in voc.word2index for word in sentence.split(' '))

    keep_pairs = [pair for pair in pairs if _all_known(pair[0]) and _all_known(pair[1])]
    # Guard the ratio so an empty pair list does not raise ZeroDivisionError.
    ratio = len(keep_pairs) / len(pairs) if pairs else 0.0
    print("Trimmed from {} pairs to {}, {:.4f} of total".format(
        len(pairs), len(keep_pairs), ratio))
    return keep_pairs


def _write_newly_formatted_file(config):
    """Create the tab-separated pair file unless it already exists."""
    formatted_file_full_path = get_formatted_file_full_path(config)
    if _check_is_have_file(formatted_file_full_path):
        print("%s already has the formatted file,so we do not write" % formatted_file_full_path)
        return

    delimiter = _get_delimiter(config)
    conversations = _get_conversations(config)
    # Build the pairs before opening the output: a failure while parsing the
    # corpus must not leave an empty file that later runs would accept.
    pairs = _extract_sentence_pairs(conversations)
    print("pairs count:", len(pairs))
    print("\nWriting newly formatted file...")
    # The file has to be closed (flushed) before _load_formatted_data reads it.
    with open(formatted_file_full_path, 'w', encoding='utf-8') as outputfile:
        writer = csv.writer(outputfile, delimiter=delimiter, lineterminator='\n')
        for pair in pairs:
            writer.writerow(pair)


def load_vocabulary_and_pairs(config):
    """Return ``(voc, pairs)`` for the corpus described by ``config``.

    Writes the formatted pair file when it is missing, loads it, and removes
    pairs containing words rarer than ``config["min_count"]``.
    """
    _write_newly_formatted_file(config)
    voc, pairs = _load_formatted_data(config)
    pairs = _trim_rare_words(voc, pairs, config["min_count"])
    return voc, pairs


def get_formatted_file_full_path(config):
    """Return ``<corpus_path>/<formatted_file_name>`` built from ``config``."""
    return "%s/%s" % (config["corpus_path"], config["formatted_file_name"])


def normalize_string(s):
    """Lower-case ``s``, strip accents, isolate ``. ! ?`` and drop every other non-letter."""
    s = _unicode_to_ascii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)       # put a space before sentence punctuation
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)   # replace anything else with a space
    s = re.sub(r"\s+", r" ", s).strip()     # collapse runs of whitespace
    return s
vocabulary.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : vocabulary.py
# Create date : 2019-01-16 11:21
# Modified date : 2019-02-02 13:37
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

PAD_token = 0  # Used for padding short sentences
SOS_token = 1  # Start-of-sentence token
EOS_token = 2  # End-of-sentence token


class Voc:
    """Word <-> index vocabulary with per-word occurrence counts."""

    def __init__(self, name):
        """Create an empty vocabulary called ``name`` holding only PAD/SOS/EOS."""
        self.name = name
        self.trimmed = False
        self._reset()

    def _reset(self):
        """Forget every word and restore the three default tokens."""
        self.word2index = {}
        self.word2count = {}
        self.index2word = {PAD_token: "PAD", SOS_token: "SOS", EOS_token: "EOS"}
        self.num_words = 3  # Count SOS, EOS, PAD

    def addSentence(self, sentence):
        """Add every space-separated word of ``sentence``."""
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        """Register ``word`` with the next free index, or bump its count if known."""
        if word not in self.word2index:
            self.word2index[word] = self.num_words
            self.word2count[word] = 1
            self.index2word[self.num_words] = word
            self.num_words += 1
        else:
            self.word2count[word] += 1

    def trim(self, min_count):
        """Drop words seen fewer than ``min_count`` times and re-index the rest.

        Only the first call has an effect. Kept words are re-added through
        ``addWord``, so their counts restart at 1.
        """
        if self.trimmed:
            return
        self.trimmed = True

        keep_words = [k for k, v in self.word2count.items() if v >= min_count]
        total = len(self.word2index)
        # An empty vocabulary used to raise ZeroDivisionError here.
        ratio = len(keep_words) / total if total else 0.0
        print('keep_words {} / {} = {:.4f}'.format(len(keep_words), total, ratio))

        self._reset()
        for word in keep_words:
            self.addWord(word)
graph.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : graph.py
# Create date : 2019-01-16 11:44
# Modified date : 2019-02-02 14:55
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import os
import itertools
import random

import torch
import torch.nn as nn
from torch import optim

import vocabulary
import model
import corpus_dataset


def _get_training_batches(voc, pairs, batch_size, n_iteration):
    """Pre-build ``n_iteration`` batches, each sampled with replacement from ``pairs``."""
    training_batches = []
    for _ in range(n_iteration):
        sampled = [random.choice(pairs) for _ in range(batch_size)]
        training_batches.append(_batch2TrainData(voc, sampled))
    return training_batches


def _zero_padding(l, fillvalue=vocabulary.PAD_token):
    """Transpose the index lists in ``l``, padding short ones with ``fillvalue``.

    The result has one row per position and one column per sentence.
    """
    return list(itertools.zip_longest(*l, fillvalue=fillvalue))


def _binary_matrix(lt):
    """Return a matrix shaped like ``lt`` holding 0 for PAD tokens and 1 otherwise."""
    return [[0 if token == vocabulary.PAD_token else 1 for token in seq] for seq in lt]


def _get_indexes_batch(lt, voc):
    """Convert every sentence in ``lt`` to its list of word indexes."""
    return [_indexes_from_sentence(voc, sentence) for sentence in lt]


def _input_var(batch, voc):
    """Return the padded input tensor and the unpadded length of each sentence."""
    indexes_batch = _get_indexes_batch(batch, voc)
    padList = _zero_padding(indexes_batch)
    variable = torch.LongTensor(padList)
    lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
    return variable, lengths


def _output_var(batch, voc):
    """Return the padded target tensor, its padding mask and the longest target length."""
    indexes_batch = _get_indexes_batch(batch, voc)
    padList = _zero_padding(indexes_batch)
    variable = torch.LongTensor(padList)
    max_target_len = max(len(indexes) for indexes in indexes_batch)
    # Boolean mask: uint8 masks are deprecated for masked_select in current PyTorch.
    mask = torch.BoolTensor(_binary_matrix(padList))
    return variable, mask, max_target_len


def _indexes_from_sentence(voc, sentence):
    """Map the words of ``sentence`` to indexes and append the EOS token.

    Raises:
        KeyError: if a word is not in ``voc.word2index``.
    """
    index_lt = [voc.word2index[word] for word in sentence.split(' ')]
    index_lt.append(vocabulary.EOS_token)
    return index_lt


def _batch2TrainData(voc, pair_batch):
    """Turn a list of [input, target] sentence pairs into training tensors.

    Pairs are ordered by descending input word count before conversion.

    Returns:
        (input_variable, lengths, target_variable, mask, max_target_len)
    """
    # sorted() keeps the caller's list untouched; the ordering is the same
    # stable ordering list.sort() would produce.
    ordered = sorted(pair_batch, key=lambda x: len(x[0].split(" ")), reverse=True)
    input_batch = [pair[0] for pair in ordered]
    output_batch = [pair[1] for pair in ordered]
    input_variable, lengths = _input_var(input_batch, voc)
    target_variable, mask, max_target_len = _output_var(output_batch, voc)
    return input_variable, lengths, target_variable, mask, max_target_len


def _maskNLLLoss(inp, target, mask, device):
    """Mean negative log-likelihood of ``target`` under ``inp`` over unmasked positions.

    Returns:
        (loss tensor on ``device``, number of unmasked positions as int)
    """
    mask = mask.bool()  # accept legacy byte masks as well
    nTotal = mask.sum()
    crossEntropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    loss = crossEntropy.masked_select(mask).mean()
    loss = loss.to(device)
    return loss, nTotal.item()


class CorpusGraph(nn.Module):
    """Builds, trains, checkpoints and evaluates the encoder/decoder chatbot model."""

    def __init__(self, config):
        """Copy the hyper-parameters and paths needed later out of ``config``."""
        super(CorpusGraph, self).__init__()
        self.model_name = config["model_name"]
        self.save_dir = config["save_dir"]
        self.corpus_name = config["corpus_name"]
        self.encoder_n_layers = config["encoder_n_layers"]
        self.decoder_n_layers = config["decoder_n_layers"]
        self.hidden_size = config["hidden_size"]
        self.checkpoint_iter = config["checkpoint_iter"]
        self.learning_rate = config["learning_rate"]
        self.decoder_learning_ratio = config["decoder_learning_ratio"]
        self.dropout = config["dropout"]
        self.attn_model = config["attn_model"]
        self.device = config["device"]
        self.print_every = config["print_every"]
        self.save_every = config["save_every"]
        self.n_iteration = config["n_iteration"]
        self.batch_size = config["batch_size"]
        self.clip = config["clip"]
        self.max_length = config["max_length"]
        self.teacher_forcing_ratio = config["teacher_forcing_ratio"]
        self.train_load_checkpoint_file = config["train_load_checkpoint_file"]

    def _evaluate(self, voc, sentence, train_model):
        """Greedy-decode a reply to ``sentence`` and return it as a list of words.

        Raises:
            KeyError: if ``sentence`` contains a word missing from ``voc``.
        """
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        # Set dropout layers to eval mode
        encoder.eval()
        decoder.eval()
        searcher = model.GreedySearchDecoder(encoder, decoder)

        indexes_batch = [_indexes_from_sentence(voc, sentence)]
        lengths = torch.tensor([len(indexes) for indexes in indexes_batch])
        # Transpose so that positions come first, matching the training batches.
        input_batch = torch.LongTensor(indexes_batch).transpose(0, 1)
        input_batch = input_batch.to(self.device)
        lengths = lengths.to(self.device)
        # No gradients are needed for inference.
        with torch.no_grad():
            tokens, scores = searcher(input_batch, lengths, self.max_length, self.device)
        return [voc.index2word[token.item()] for token in tokens]

    def _choose_use_teacher_forcing(self):
        """Return True with probability ``teacher_forcing_ratio``."""
        return random.random() < self.teacher_forcing_ratio

    def _train_step(self, decoder, decoder_input, decoder_hidden, encoder_outputs,
                    target_variable, mask, max_target_len):
        """Run the decoder over every target position and accumulate the masked loss.

        Returns:
            (summed loss tensor, list of per-step weighted losses, total unmasked tokens)
        """
        loss = 0
        print_losses = []
        n_totals = 0
        for t in range(max_target_len):
            decoder_output, decoder_hidden = decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            if self._choose_use_teacher_forcing():
                # Teacher forcing: the next input is the current target.
                decoder_input = target_variable[t].view(1, -1)
            else:
                # Otherwise feed back the decoder's own most likely token.
                _, topi = decoder_output.topk(1)
                decoder_input = torch.LongTensor(
                    [[topi[i][0] for i in range(self.batch_size)]])
                decoder_input = decoder_input.to(self.device)
            mask_loss, nTotal = _maskNLLLoss(
                decoder_output, target_variable[t], mask[t], self.device)
            loss += mask_loss
            print_losses.append(mask_loss.item() * nTotal)
            n_totals += nTotal
        return loss, print_losses, n_totals

    def _train_init(self, input_variable, lengths, target_variable, mask, train_model):
        """Zero the gradients, move the batch to the device and run the encoder.

        Returns:
            (decoder, decoder_input, decoder_hidden, encoder_outputs,
             target_variable, mask)
        """
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        train_model["encoder_optimizer"].zero_grad()
        train_model["decoder_optimizer"].zero_grad()

        input_variable = input_variable.to(self.device)
        lengths = lengths.to(self.device)
        target_variable = target_variable.to(self.device)
        mask = mask.to(self.device)

        encoder_outputs, encoder_hidden = encoder(input_variable, lengths)
        # Every sequence starts decoding from the SOS token.
        decoder_input = torch.LongTensor(
            [[vocabulary.SOS_token for _ in range(self.batch_size)]])
        decoder_input = decoder_input.to(self.device)
        # The decoder starts from the first decoder.n_layers encoder hidden states.
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        return decoder, decoder_input, decoder_hidden, encoder_outputs, target_variable, mask

    def _train_backward(self, loss, train_model):
        """Back-propagate ``loss``, clip both gradient norms and step both optimizers."""
        encoder = train_model["encoder"]
        decoder = train_model["decoder"]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(encoder.parameters(), self.clip)
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), self.clip)
        train_model["encoder_optimizer"].step()
        train_model["decoder_optimizer"].step()

    def _train(self, input_variable, lengths, target_variable, mask, max_target_len,
               train_model):
        """Train on one batch and return the average loss per unmasked target token."""
        (decoder, decoder_input, decoder_hidden,
         encoder_outputs, target_variable, mask) = self._train_init(
            input_variable, lengths, target_variable, mask, train_model)
        loss, print_losses, n_totals = self._train_step(
            decoder, decoder_input, decoder_hidden, encoder_outputs,
            target_variable, mask, max_target_len)
        self._train_backward(loss, train_model)
        return sum(print_losses) / n_totals

    def _save_model_dict(self, train_model, iteration, voc, loss):
        """Write the current model, optimizer and vocabulary state to the checkpoint file."""
        model_dict = self._get_model_dict(train_model, iteration, voc, loss)
        checkpoint_file_full_path = self._get_checkpoint_file_full_name()
        torch.save(model_dict, checkpoint_file_full_path)

    def _show_batches(self, batches):
        """Debug helper: print every component of one training batch."""
        input_variable, lengths, target_variable, mask, max_target_len = batches
        print("input_variable:", input_variable)
        print("lengths:", lengths)
        print("target_variable:", target_variable)
        print("mask:", mask)
        print("max_target_len:", max_target_len)

    def _show_train_state(self, print_loss, iteration):
        """Print progress for ``iteration`` and return 0 as the reset running loss."""
        print_loss_avg = print_loss / self.print_every
        print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(
            iteration, iteration / self.n_iteration * 100, print_loss_avg))
        return 0

    def _get_model_dict(self, train_model, iteration, voc, loss):
        """Collect everything that is stored in a checkpoint into one dict."""
        return {
            "en": train_model["encoder"].state_dict(),
            "de": train_model["decoder"].state_dict(),
            "en_opt": train_model["encoder_optimizer"].state_dict(),
            "de_opt": train_model["decoder_optimizer"].state_dict(),
            "embedding": train_model["embedding"].state_dict(),
            "iteration": iteration,
            "loss": loss,
            "voc_dict": voc.__dict__,
        }

    def _load_checkpoint(self, train_model, voc, checkpoint):
        """Restore model, optimizer and vocabulary state from ``checkpoint``.

        ``voc`` is overwritten in place with the vocabulary stored in the checkpoint.
        """
        train_model["encoder"].load_state_dict(checkpoint['en'])
        train_model["decoder"].load_state_dict(checkpoint['de'])
        train_model["encoder_optimizer"].load_state_dict(checkpoint['en_opt'])
        train_model["decoder_optimizer"].load_state_dict(checkpoint['de_opt'])
        train_model["embedding"].load_state_dict(checkpoint['embedding'])
        voc.__dict__ = checkpoint['voc_dict']
        train_model["iteration"] = checkpoint["iteration"]
        return train_model

    def _train_load_checkpoint(self, train_model, voc):
        """Resume from the checkpoint file when it exists and loading is enabled."""
        loadFilename = self._get_checkpoint_file_full_name()
        if os.path.exists(loadFilename) and self.train_load_checkpoint_file:
            checkpoint = torch.load(loadFilename)
            train_model = self._load_checkpoint(train_model, voc, checkpoint)
        return train_model

    def _test_load_checkpoint(self, train_model, voc):
        """Load the checkpoint for evaluation, mapping its tensors to the CPU."""
        loadFilename = self._get_checkpoint_file_full_name()
        if os.path.exists(loadFilename) and self.train_load_checkpoint_file:
            # Load once, with map_location: a plain torch.load() of a checkpoint
            # trained on GPU fails on a CPU-only machine.
            checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))
            train_model = self._load_checkpoint(train_model, voc, checkpoint)
        return train_model

    def _get_save_directory(self):
        """Return the checkpoint directory for this configuration, creating it if needed."""
        directory = os.path.join(
            self.save_dir,
            self.model_name,
            self.corpus_name,
            '{}-{}_{}'.format(self.encoder_n_layers, self.decoder_n_layers, self.hidden_size))
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(directory, exist_ok=True)
        return directory

    def _get_checkpoint_file_full_name(self):
        """Return the full path of ``checkpoint.tar`` inside the save directory."""
        directory = self._get_save_directory()
        return "%s/%s" % (directory, "checkpoint.tar")

    def create_train_model(self, voc, status="train"):
        """Build encoder, decoder, optimizers and shared embedding, then load a checkpoint.

        Args:
            voc: vocabulary; ``voc.num_words`` sizes the embedding and output layer.
            status: ``"train"`` resumes with the training loader, any other value
                uses the test loader.

        Returns:
            dict with keys ``encoder``, ``decoder``, ``encoder_optimizer``,
            ``decoder_optimizer``, ``embedding`` and ``iteration``.
        """
        embedding = nn.Embedding(voc.num_words, self.hidden_size)
        encoder = model.EncoderRNN(
            self.hidden_size, embedding, self.encoder_n_layers, self.dropout)
        encoder = encoder.to(self.device)
        decoder = model.LuongAttnDecoderRNN(
            self.attn_model, embedding, self.hidden_size, voc.num_words,
            self.decoder_n_layers, self.dropout)
        decoder = decoder.to(self.device)
        # Ensure dropout layers are in train mode
        encoder.train()
        decoder.train()

        encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        decoder_optimizer = optim.Adam(
            decoder.parameters(), lr=self.learning_rate * self.decoder_learning_ratio)

        train_model = {
            "encoder": encoder,
            "decoder": decoder,
            "encoder_optimizer": encoder_optimizer,
            "decoder_optimizer": decoder_optimizer,
            "embedding": embedding,
            "iteration": 0,
        }
        if status == "train":
            return self._train_load_checkpoint(train_model, voc)
        return self._test_load_checkpoint(train_model, voc)

    def trainIters(self, voc, pairs, train_model):
        """Run ``n_iteration`` training steps, printing and checkpointing periodically."""
        training_batches = _get_training_batches(
            voc, pairs, self.batch_size, self.n_iteration)
        print_loss = 0
        # Iterations already completed according to the loaded checkpoint.
        completed = train_model['iteration']
        for iteration in range(1, self.n_iteration + 1):
            training_batch = training_batches[iteration - 1]
            input_variable, lengths, target_variable, mask, max_target_len = training_batch
            loss = self._train(
                input_variable, lengths, target_variable, mask, max_target_len, train_model)
            print_loss += loss
            # Previously this was completed + 1 + iteration, which reported and
            # saved every iteration number one too high.
            cur_iteration = completed + iteration
            if iteration % self.print_every == 0:
                print_loss = self._show_train_state(print_loss, cur_iteration)
            if iteration % self.save_every == 0:
                self._save_model_dict(train_model, cur_iteration, voc, loss)

    def evaluate_input(self, voc, train_model):
        """Interactive loop: read a sentence, print the bot's reply; ``q``/``quit`` exits."""
        while True:
            try:
                input_sentence = input('> ')
                if input_sentence == 'q' or input_sentence == 'quit':
                    break
                input_sentence = corpus_dataset.normalize_string(input_sentence)
                output_words = self._evaluate(voc, input_sentence, train_model)
                output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
                print('Bot:', ' '.join(output_words))
            except KeyError:
                # Raised by the vocabulary lookup for words never seen in training.
                print("Error: Encountered unknown word.")
model.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : model.py
# Create date : 2019-01-16 11:38
# Modified date : 2019-02-02 14:50
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

import vocabulary


class EncoderRNN(nn.Module):
    """Bidirectional GRU encoder over embedded input tokens."""

    def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
        super(EncoderRNN, self).__init__()
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        self.embedding = embedding
        # Input and hidden sizes are both hidden_size. Dropout is only passed
        # to the GRU when there is more than one layer.
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout),
                          bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
        """Encode a padded batch.

        Args:
            input_seq: padded token-index tensor.
            input_lengths: unpadded length of every sequence in the batch.
            hidden: optional initial hidden state for the GRU.

        Returns:
            (outputs, hidden): outputs with both directions summed, and the
            final GRU hidden state.
        """
        embedded = self.embedding(input_seq)
        # pack_padded_sequence requires the lengths on the CPU, while callers
        # move them to the compute device together with the batch.
        cpu_lengths = torch.as_tensor(input_lengths).cpu()
        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, cpu_lengths)
        outputs, hidden = self.gru(packed, hidden)
        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(outputs)
        # Sum the two halves of the feature dimension (the two GRU directions).
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden


class Attn(torch.nn.Module):
    """Attention scoring with the ``dot``, ``general`` or ``concat`` method."""

    def __init__(self, method, hidden_size):
        """
        Raises:
            ValueError: if ``method`` is not one of 'dot', 'general', 'concat'.
        """
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
            self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
            self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
            # NOTE(review): torch.FloatTensor(n) allocates without initialising,
            # so v starts from arbitrary memory -- consider an explicit init.
            self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        """Score as the element-wise product summed over dim 2."""
        return torch.sum(hidden * encoder_output, dim=2)

    def general_score(self, hidden, encoder_output):
        """Like ``dot_score`` but with the encoder output passed through a linear layer."""
        energy = self.attn(encoder_output)
        return torch.sum(hidden * energy, dim=2)

    def concat_score(self, hidden, encoder_output):
        """Score from tanh(linear([hidden; encoder_output])) weighted by ``v``."""
        expanded = hidden.expand(encoder_output.size(0), -1, -1)
        energy = self.attn(torch.cat((expanded, encoder_output), 2)).tanh()
        return torch.sum(self.v * energy, dim=2)

    def forward(self, hidden, encoder_outputs):
        """Return softmax-normalised attention weights with a singleton dim inserted at 1."""
        if self.method == 'general':
            attn_energies = self.general_score(hidden, encoder_outputs)
        elif self.method == 'concat':
            attn_energies = self.concat_score(hidden, encoder_outputs)
        elif self.method == 'dot':
            attn_energies = self.dot_score(hidden, encoder_outputs)
        # Transpose so the softmax below runs over the other axis of the scores.
        attn_energies = attn_energies.t()
        return F.softmax(attn_energies, dim=1).unsqueeze(1)


class LuongAttnDecoderRNN(nn.Module):
    """Single-step GRU decoder that attends over the encoder outputs."""

    def __init__(self, attn_model, embedding, hidden_size, output_size, n_layers=1,
                 dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.embedding = embedding
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers,
                          dropout=(0 if n_layers == 1 else dropout))
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_step, last_hidden, encoder_outputs):
        """Decode one step.

        Returns:
            (output, hidden): softmax probabilities over the ``output_size``
            vocabulary entries, and the updated GRU hidden state.
        """
        embedded = self.embedding(input_step)
        embedded = self.embedding_dropout(embedded)
        rnn_output, hidden = self.gru(embedded, last_hidden)
        # Attention weights from the current GRU output, then the weighted
        # sum of encoder outputs (the context vector).
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        rnn_output = rnn_output.squeeze(0)
        context = context.squeeze(1)
        concat_input = torch.cat((rnn_output, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        output = self.out(concat_output)
        output = F.softmax(output, dim=1)
        return output, hidden


class GreedySearchDecoder(nn.Module):
    """Decodes by always feeding back the highest-probability token."""

    def __init__(self, encoder, decoder):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, input_seq, input_length, max_length, device):
        """Generate exactly ``max_length`` tokens for ``input_seq``.

        Returns:
            (all_tokens, all_scores): the chosen token indexes and their
            probabilities, one entry per decoding step.
        """
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
        decoder_hidden = encoder_hidden[:self.decoder.n_layers]
        # Start decoding from a single SOS token.
        decoder_input = torch.ones(1, 1, device=device, dtype=torch.long) * vocabulary.SOS_token
        all_tokens = torch.zeros([0], device=device, dtype=torch.long)
        all_scores = torch.zeros([0], device=device)
        for _ in range(max_length):
            decoder_output, decoder_hidden = self.decoder(
                decoder_input, decoder_hidden, encoder_outputs)
            decoder_scores, decoder_input = torch.max(decoder_output, dim=1)
            all_tokens = torch.cat((all_tokens, decoder_input), dim=0)
            all_scores = torch.cat((all_scores, decoder_scores), dim=0)
            # Restore the leading dimension expected by the decoder's next step.
            decoder_input = torch.unsqueeze(decoder_input, 0)
        return all_tokens, all_scores
etc.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : etc.py
# Create date : 2019-01-17 22:50
# Modified date : 2019-02-02 14:10
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

import torch

# Central configuration dictionary read by the data, graph and entry-point modules.
config = {}

# Corpus location and file names
config["corpus_name"] = "cornell movie-dialogs corpus"
config["corpus_path"] = "./data/%s" % config["corpus_name"]
config["delimiter"] = '\t'
config["formatted_file_name"] = "formatted_movie_lines.txt"
config["conversation_file_name"] = "movie_conversations.txt"
config["lines_file_name"] = "movie_lines.txt"

# Field names of the two raw corpus files, in column order
config["movie_lines_fields"] = ["lineID", "characterID", "movieID", "character", "text"]
config["movie_conversations_fields"] = ["character1ID", "character2ID", "movieID", "utteranceIDs"]

# Model
config["model_name"] = 'cb_model'
config["attn_model"] = 'dot'  # other values accepted by the model: 'general', 'concat'
config["hidden_size"] = 500
config["encoder_n_layers"] = 2
config["decoder_n_layers"] = 2
config["dropout"] = 0.1

# Training loop
config["print_every"] = 20
config["save_every"] = 500
config["n_iteration"] = 1000
config["clip"] = 50.0
config["learning_rate"] = 0.0001
config["decoder_learning_ratio"] = 5.0
config["batch_size"] = 64
config["save_dir"] = "./data/save"
config["checkpoint_iter"] = 4000
config["min_count"] = 3  # Minimum word count threshold for trimming
config["max_length"] = 10
config["teacher_forcing_ratio"] = 1.0
config["train_load_checkpoint_file"] = True

# Use the GPU when one is available
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
config["device"] = device
main.py
#!/usr/bin/python
# -*- coding: utf-8 -*-
#####################################
# File name : main.py
# Create date : 2019-02-02 13:44
# Modified date : 2019-02-02 13:45
# Author : DARREN
# Describe : not set
# Email : lzygzh@126.com
#####################################
from __future__ import division
from __future__ import print_function

from chatbot_train import run_train
from chatbot_test import run_test


def run():
    """Train the chatbot, then start the interactive test session."""
    run_train()
    run_test()


# Only start training when executed as a script, not when imported.
if __name__ == "__main__":
    run()
github:
https://github.com/darr/chatbot
pytorch实现 chatbot聊天机器人相关推荐
- Chatbot 聊天机器人页面交互设计
目录 一.背景 二.设计要点 三.相关交互细节 四.总结 一.背景 最近在做源码智投app的机器人Neo的原型设计,是一个chatbot聊天机器人.整理了一下关于聊天机器人设计的一些心得. 这是Neo ...
- Pytorch项目实战聊天机器人(02.项目的准备阶段)
Pytorch项目实战聊天机器人(02.项目的准备阶段) 02.项目的准备阶段 二.2-2 NLP涉及知识 三.2-3 NLTK库 四 ,2-4 语料和词性标注 五 ,2-5 分词 六 , 2-6 T ...
- chatbot聊天机器人环境搭建以及项目运行指南
项目地址 网传有一位程序员因忙于工作,无暇陪伴女友,便做了个聊天机器人的软件来陪女友聊天,然后自己就安心工作去了.等到程序员下班时一看,机器人已经和女友聊到了二胎的娃叫啥名了.博主不明觉厉,便去拷贝一 ...
- chatbot聊天机器人技术路线汇总
版权声明:博主原创文章,转载请注明来源,谢谢合作!! https://mp.csdn.net/mdeditor/84481818 聊天机器人实现的技术途径 大约可分为4种:1. 第一种是属于" ...
- chatbot聊天机器人技术路线
聊天机器人实现的技术途径大约可分为以下4种: (其中第一种是属于"调用第三方API",也就是说核心代码和数据库不掌握在自己手里)(第二.三.四种属于开源框架,也就是说我们可以下载其 ...
- 转载:chatbot聊天机器人技术路线
转载&备份: https://blog.csdn.net/tian_panda/article/details/80664578 聊天机器人实现的技术途径大约可分为以下4种: (其中第一种是属 ...
- 【PyTorch】11 聊天机器人实战——Cornell Movie-Dialogs Corpus电影剧本数据集处理、利用Global attention实现Seq2Seq模型
聊天机器人教程 1. 下载数据文件 2. 加载和预处理数据 2.1 创建格式化数据文件 2.2 加载和清洗数据 3.为模型准备数据 4.定义模型 4.1 Seq2Seq模型 4.2 编码器 4.3 解 ...
- pytorch教程 聊天机器人(详细注释attentionrnn输入输出shape等知识点...
最近可能要用到seq2seq模型去解决一些轨迹预测的问题,拿pytorch教程的聊天机器人练了练手. 原文中教程已经写的比较详尽了,在此对原文教程进行一些补充说明,可能更加方便向我这样的小白入门学习. ...
- 【人机对话】对话机器人技术简介:问答系统、对话系统与聊天机器人
点击上方,选择星标或置顶,每天给你送干货! 阅读大概需要16分钟 跟随小博主,每天进步一丢丢 来自:AI算法之心 作者:段清华 个人主页:http://qhduan.com Github链接: htt ...
最新文章
- 机器学习——XGBoost大杀器,XGBoost模型原理,XGBoost参数含义
- Tomcat 安装与使用
- 十五道java开发常遇到的计算机网络协议高频面试题
- 通过FILETIME得到时间
- 清华大学「天机」芯片登上Nature封面:类脑加传统计算融合实现通用人工智能...
- 绑定注意事项——数据源的属性
- vue 打包路由报错_Vue下路由History模式打包后页面空白的解决方法
- CVPR AAAI 2020 |人脸活体检测最新进展
- c语言编程题报文解析,C语言解析pcap文件得到HTTP信息实例
- 东北三省计算机专业好的学校,东北地区哪个大学比较好 各自的王牌专业是什么...
- 一个发散动画的菜单控件(主要记录控件x,y坐标的运动状况)
- VolTE注册流程0001 融合HLR HSS
- 魔兽世界不显示服务器后缀,魔兽世界看不到世界频道?给你看到的方法
- 关于拉勾网的scrapy crawlspider爬虫出现的302问题的解决方式
- 前端javascript总结笔记(一)--js的三座大山
- PIC16 F887 单片机 直流电机PWM调速 PID调速 PID算法
- Java集合 HashSet 和 HashMap
- H5游戏开发-面向对象编程
- Anacoda的用途
- 在html页面插入flv播放器。ie火狐均可用