Chapter1.代码详解

完整代码github链接,Untitled.ipynb文件内。
【里面的测试是还没训练完的时候测试的,今晚会更新训练完成后的测试结果】
修复了网上一些代码的bug,解决了由于tensorflow版本不同引起的一些问题。

数据集链接 ,下载数据集后,解压提取dgk_shooter_min.conv文件,最好进行转码操作。建议用记事本打开后将其另存为,选择编码为utf-8后进行保存。

代码详解

(1)数据预处理

#coding=utf-8
#(1)数据预处理
import os
import random
from io import open
conv_path = 'dgk_shooter_min.conv.txt'
#判断数据集是否存在?
if not os.path.exists(conv_path):print('数据集不存在')exit()# 数据集格式
"""
E
M 畹/华/吾/侄/
M 你/接/到/这/封/信/的/时/候/
M 不/知/道/大/伯/还/在/不/在/人/世/了/
E
M 咱/们/梅/家/从/你/爷/爷/起/
M 就/一/直/小/心/翼/翼/地/唱/戏/
M 侍/奉/宫/廷/侍/奉/百/姓/
M 从/来/不/曾/遭/此/大/祸/
M 太/后/的/万/寿/节/谁/敢/不/穿/红/
M 就/你/胆/儿/大/
M 唉/这/我/舅/母/出/殡/
M 我/不/敢/穿/红/啊/
M 唉/呦/唉/呦/爷/
M 您/打/得/好/我/该/打/
M 就/因/为/没/穿/红/让/人/赏/咱/一/纸/枷/锁/
M 爷/您/别/给/我/戴/这/纸/枷/锁/呀/
E
M 您/多/打/我/几/下/不/就/得/了/吗/
M 走/
M 这/是/哪/一/出/啊/…/ / /这/是/
M 撕/破/一/点/就/弄/死/你/
M 唉/
M 记/着/唱/戏/的/再/红/
M 还/是/让/人/瞧/不/起/
M 大/伯/不/想/让/你/挨/了/打/
M 还/得/跟/人/家/说/打/得/好/
M 大/伯/不/想/让/你/再/戴/上/那/纸/枷/锁/
M 畹/华/开/开/门/哪/
E
...
"""# 我首先使用文本编辑器sublime把dgk_shooter_min.conv文件编码转为UTF-8,一下子省了不少麻烦
convs = []  # 对话集合
with open(conv_path, encoding="utf8") as f:one_conv = []  # 一次完整对话for line in f:line = line.strip('\n').replace('/', '')#将分隔符去掉if line == '':continueif line[0] == 'E':if one_conv:convs.append(one_conv)one_conv = []elif line[0] == 'M':one_conv.append(line.split(' ')[1])
#将对话转成utf-8格式,并将其保存在dgk_shooter_min.conv文件中print(convs[:3])  # 个人感觉对白数据集有点不给力啊
#[ ['畹华吾侄', '你接到这封信的时候', '不知道大伯还在不在人世了'],
#  ['咱们梅家从你爷爷起', '就一直小心翼翼地唱戏', '侍奉宫廷侍奉百姓', '从来不曾遭此大祸', '太后的万寿节谁敢不穿红', '就你胆儿大', '唉这我舅母出殡', '我不敢穿红啊', '唉呦唉呦爷', '您打得好我该打', '就因为没穿红让人赏咱一纸枷锁', '爷您别给我戴这纸枷锁呀'],
#  ['您多打我几下不就得了吗', '走', '这是哪一出啊 ', '撕破一点就弄死你', '唉', '记着唱戏的再红', '还是让人瞧不起', '大伯不想让你挨了打', '还得跟人家说打得好', '大伯不想让你再戴上那纸枷锁', '畹华开开门哪'], ....]# 把对话分成问与答
ask = []        # 问
response = []   # 答
for conv in convs:if len(conv) == 1:continueif len(conv) % 2 != 0:  # 奇数对话数, 转为偶数对话conv = conv[:-1]for i in range(len(conv)):if i % 2 == 0:ask.append(conv[i])#偶数对,填写问题else:response.append(conv[i])#回答print(len(ask), len(response))
print(ask[:3])
print(response[:3])
#['畹华吾侄', '咱们梅家从你爷爷起', '侍奉宫廷侍奉百姓']
#['你接到这封信的时候', '就一直小心翼翼地唱戏', '从来不曾遭此大祸']def convert_seq2seq_files(questions, answers, TESTSET_SIZE=8000):# 创建文件train_enc = open('train.enc', 'w',encoding='utf-8')  # 问train_dec = open('train.dec', 'w',encoding='utf-8')  # 答test_enc = open('test.enc', 'w',encoding='utf-8')  # 问test_dec = open('test.dec', 'w',encoding='utf-8')  # 答# 选择8000数据作为测试数据test_index = random.sample([i for i in range(len(questions))], TESTSET_SIZE)for i in range(len(questions)):if i in test_index:#创建测试文件test_enc.write(questions[i] + '\n')test_dec.write(answers[i] + '\n')else:#创建训练文件train_enc.write(questions[i] + '\n')train_dec.write(answers[i] + '\n')if i % 1000 == 0:#表示处理了多少个iprint(len(range(len(questions))), '处理进度:', i)train_enc.close()train_dec.close()test_enc.close()test_dec.close()convert_seq2seq_files(ask, response)
# 生成的*.enc文件保存了问题
# 生成的*.dec文件保存了回答

将数据集进行处理后分成问与答的形式进行保存,选择其中的8000数据作为测试数据。处理完毕后生成的.enc文件保存了问题,.dec文件保存了回答。
问题文件*.enc预览:

爷爷您戏改得真好
您怎么不进去呀
王老板
见过
地球再也无法承受人类的数量
我现在是和摩兰达说话吗?
我们不是告诉他们应该想什么

回答文件*.dec预览:

这回跟您可真是一棵菜了
我等人拿钥匙呢
唉
什么事
我们发现了一个新的太阳系
不是
我们仅仅是想告诉他们应该怎么想

(2)创建词汇表

#coding=utf-8
#(2)创建词汇表
# 前一步生成的问答文件路径
train_encode_file = 'train.enc'
train_decode_file = 'train.dec'
test_encode_file = 'test.enc'
test_decode_file = 'test.dec'print('开始创建词汇表...')
# 特殊标记,用来填充标记对话
PAD = "__PAD__"
GO = "__GO__"
EOS = "__EOS__"  # 对话结束
UNK = "__UNK__"  # 标记未出现在词汇表中的字符
START_VOCABULART = [PAD, GO, EOS, UNK]
PAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3
# 参看tensorflow.models.rnn.translate.data_utilsvocabulary_size = 5000# 生成词汇表文件
def gen_vocabulary_file(input_file, output_file):vocabulary = {}with open(input_file, encoding="utf8") as f:counter = 0for line in f:counter += 1tokens = [word for word in line.strip()]for word in tokens:if word in vocabulary:vocabulary[word] += 1else:vocabulary[word] = 1vocabulary_list = START_VOCABULART + sorted(vocabulary, key=vocabulary.get, reverse=True)# 取前5000个常用汉字, 应该差不多够用了(额, 好多无用字符, 最好整理一下. 我就不整理了)if len(vocabulary_list) > 5000:vocabulary_list = vocabulary_list[:5000]print(input_file + " 词汇表大小:", len(vocabulary_list))with open(output_file, "w", encoding="utf8") as ff:for word in vocabulary_list:ff.write(word + "\n")gen_vocabulary_file(train_encode_file, "train_encode_vocabulary")
gen_vocabulary_file(train_decode_file, "train_decode_vocabulary")train_encode_vocabulary_file = 'train_encode_vocabulary'
train_decode_vocabulary_file = 'train_decode_vocabulary'print("对话转向量...")# 把对话字符串转为向量形式
def convert_to_vector(input_file, vocabulary_file, output_file):tmp_vocab = []with open(vocabulary_file, "r", encoding="utf8") as f:tmp_vocab.extend(f.readlines())tmp_vocab = [line.strip() for line in tmp_vocab]vocab = dict([(x, y) for (y, x) in enumerate(tmp_vocab)])# {'硕': 3142, 'v': 577, 'I': 4789, '\ue796': 4515, '拖': 1333, '疤': 2201 ...}output_f = open(output_file, 'w')with open(input_file, 'r', encoding="utf8") as f:for line in f:line_vec = []for words in line.strip():line_vec.append(vocab.get(words, UNK_ID))output_f.write(" ".join([str(num) for num in line_vec]) + "\n")output_f.close()convert_to_vector(train_encode_file, train_encode_vocabulary_file, 'train_encode.vec')
convert_to_vector(train_decode_file, train_decode_vocabulary_file, 'train_decode.vec')convert_to_vector(test_encode_file, train_encode_vocabulary_file, 'test_encode.vec')
convert_to_vector(test_decode_file, train_decode_vocabulary_file, 'test_decode.vec')

提取前5000个常用的汉字创建词汇表
词汇表文件*_vocabulary预览:

__PAD__
__GO__
__EOS__
__UNK__
我
的
你
是
,
不
了
们

对话转向量,把对话字符串转为向量形式
向量文件*.vec预览:

6 269 31 13 1022 157 5 60 190
28 14 226 92 113 2047 2047 98 909 724
137 22 9 644 1331 278 63 1685
28 6 1363 118 63
4 9 652 514 824 88
433 131 51 24 4 127 131
1093 433 94 81 4 884 13 840 3435 1010 366

生成的train_encode.vec和train_decode.vec用于训练,对应的词汇表train_encode_vocabulary和train_decode_vocabulary。

(3)训练

这里选取部分代码进行讲解,完整代码链接。
导入向量文件进行训练,定义一个read_data的函数对训练集与测试集的问题向量文件encode.vec,回答向量文件decode.vec,进行读取。

读取的时候将问题向量文件encode.vec中的每一行默认以空格为分隔符,构成一个源序列。将回答向量文件decode.vec中的每一行默认以空格为分隔符,构成一个目标序列。然后将两个序列添加到data_set中。对文件中的每一行都进行处理与添加后,将得到的data_set返回。

# 读取*encode.vec和*decode.vec数据(数据还不算太多, 一次读入到内存)
def read_data(source_path, target_path, max_size=None):data_set = [[] for _ in buckets]#生成了[[],[],[],[]],即当值与参数不一样with tf.gfile.GFile(source_path, mode="r") as source_file:#以读格式打开源文件(source_file)with tf.gfile.GFile(target_path, mode="r") as target_file:#以读格式打开目标文件source, target = source_file.readline(), target_file.readline()#只读取一行counter = 0#计数器为0while source and target and ( not max_size or counter < max_size):#当读入的还存在时counter += 1source_ids = [int(x) for x in source.split()]#source的目标序列号,默认分隔符为空格,组成了一个源序列target_ids = [int(x) for x in target.split()]#target组成一个目标序列,为目标序列target_ids.append(EOS_ID)#加上结束标记的序列号for bucket_id, (source_size, target_size) in enumerate(buckets):#enumerate()遍历序列中的元素和其下标if len(source_ids) < source_size and len(target_ids) < target_size:#判断是否超越了最大长度data_set[bucket_id].append([source_ids, target_ids])#读取到数据集文件中区break#一次即可,跳出当前循环source, target = source_file.readline(), target_file.readline()#读取了下一行return data_set

构建模型

model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm=5.0,batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.97,forward_only=False)

开始训练

with tf.Session(config=config) as sess:# 恢复前一次训练ckpt = tf.train.get_checkpoint_state('.')if ckpt != None:print(ckpt.model_checkpoint_path)model.saver.restore(sess, ckpt.model_checkpoint_path)else:sess.run(tf.global_variables_initializer())train_set = read_data(train_encode_vec, train_decode_vec)test_set = read_data(test_encode_vec, test_decode_vec)train_bucket_sizes = [len(train_set[b]) for b in range(len(buckets))]#分别计算出训练集中的长度【1,2,3,4】train_total_size = float(sum(train_bucket_sizes))#训练实例总数train_buckets_scale = [sum(train_bucket_sizes[:i + 1]) / train_total_size for i in range(len(train_bucket_sizes))]#计算了之前所有的数的首战百分比loss = 0.0#损失置位0total_step = 0previous_losses = []# 一直训练,每过一段时间保存一次模型while True:random_number_01 = np.random.random_sample()#每一次循环结果不一样#选出最小的大于随机采样的值的索引号bucket_id = min([i for i in range(len(train_buckets_scale)) if train_buckets_scale[i] > random_number_01])encoder_inputs, decoder_inputs, target_weights = model.get_batch(train_set, bucket_id)#get_batch()函数首先获取bucket的encoder_size与decoder_size_, step_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, False)#损失loss += step_loss / 500total_step += 1print(total_step)if total_step % 500 == 0:print(model.global_step.eval(), model.learning_rate.eval(), loss)# 如果模型没有得到提升,减小learning rateif len(previous_losses) > 2 and loss > max(previous_losses[-3:]):#即损失比以前的大则降低学习率sess.run(model.learning_rate_decay_op)previous_losses.append(loss)# 保存模型checkpoint_path = "./chatbot_seq2seq.ckpt"model.saver.save(sess, checkpoint_path, global_step=model.global_step)#返回路径checkpoint_file = "%s-%s" % (save_path, "{:08d}".format(global_step))loss = 0.0#置当前损失为0# 使用测试数据评估模型for bucket_id in range(len(buckets)):if len(test_set[bucket_id]) == 0:continue#获取当前bucket的encoder_inputs, decoder_inputs, target_weightsencoder_inputs, decoder_inputs, target_weights = model.get_batch(test_set, bucket_id)#计算bucket_id的损失权重_, eval_loss, _ = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)eval_ppx = math.exp(eval_loss) if eval_loss < 300 else float('inf')print(bucket_id, eval_ppx)#输出的是bucket_id与eval_ppx

(4)模型测试

#coding=utf-8
#(4)使用训练好的模型
import tensorflow as tf  # 0.12
# from tensorflow.models.rnn.translate import seq2seq_model
from tensorflow.models.tutorials.rnn.chatbot import seq2seq_model#注意 seq2seq_model这个需要自己去下载,根据自己的路径进行导入
# 本人将seq2seq_model模块下载后 复制到tensorflow/models/tutorials/rnn/chatbot/路径下,所以才这样进行导入
import os
import numpy as npPAD_ID = 0
GO_ID = 1
EOS_ID = 2
UNK_ID = 3
tf.reset_default_graph()
#词汇表路径path
train_encode_vocabulary = 'train_encode_vocabulary'
train_decode_vocabulary = 'train_decode_vocabulary'#读取词汇表
def read_vocabulary(input_file):tmp_vocab = []with open(input_file, "r",encoding='utf-8') as f:tmp_vocab.extend(f.readlines())#打开的文件全部读入input_file中tmp_vocab = [line.strip() for line in tmp_vocab]#转换成列表vocab = dict([(x, y) for (y, x) in enumerate(tmp_vocab)])return vocab, tmp_vocab#返回字典,列表vocab_en, _, = read_vocabulary(train_encode_vocabulary)#得到词汇字典
_, vocab_de, = read_vocabulary(train_decode_vocabulary)#得到词汇列表# 词汇表大小5000
vocabulary_encode_size = 5000
vocabulary_decode_size = 5000buckets = [(5, 10), (10, 15), (20, 25), (40, 50)]
layer_size = 256  # 每层大小
num_layers = 3  # 层数
batch_size = 1model = seq2seq_model.Seq2SeqModel(source_vocab_size=vocabulary_encode_size, target_vocab_size=vocabulary_decode_size,buckets=buckets, size=layer_size, num_layers=num_layers, max_gradient_norm=5.0,batch_size=batch_size, learning_rate=0.5, learning_rate_decay_factor=0.99,forward_only=True)
#模型说明:源,目标词汇尺寸=vocabulary_encode(decode)_size;batch_size:训练期间使用的批次的大小;#forward_only:仅前向不传递误差model.batch_size = 1#batch_size=1with tf.Session() as sess:#打开作为一次会话# 恢复前一次训练ckpt = tf.train.get_checkpoint_state('.')#从检查点文件中返回一个状态(ckpt)#如果ckpt存在,输出模型路径if ckpt != None:print(ckpt.model_checkpoint_path)model.saver.restore(sess, ckpt.model_checkpoint_path)#储存模型参数else:print("没找到模型")#测试该模型的能力while True:input_string = input('me > ')# 退出if input_string == 'quit':exit()input_string_vec = []#输入字符串向量化for words in input_string.strip():input_string_vec.append(vocab_en.get(words, UNK_ID))#get()函数:如果words在词表中,返回索引号;否则,返回UNK_IDbucket_id = min([b for b in range(len(buckets)) if buckets[b][0] > len(input_string_vec)])#保留最小的大于输入的bucket的idencoder_inputs, decoder_inputs, target_weights = model.get_batch({bucket_id: [(input_string_vec, [])]}, bucket_id)#get_batch(A,B):两个参数,A为大小为len(buckets)的元组,返回了指定bucket_id的encoder_inputs,decoder_inputs,target_weights_, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs, target_weights, bucket_id, True)#得到其输出outputs = [int(np.argmax(logit, axis=1)) for logit in output_logits]#求得最大的预测范围列表if EOS_ID in outputs:#如果EOS_ID在输出内部,则输出列表为[,,,,:End]outputs = outputs[:outputs.index(EOS_ID)]response = "".join([tf.compat.as_str(vocab_de[output]) for output in outputs])#转为解码词汇分别添加到回复中print('AI--PigPig > ' + response)#输出回复

测试结果:

以下为训练5500步后的测试结果:
【最终结果有待更新】

傲娇属性get

训练10000步后开始变得可爱了 ^_^

代码详解|tensorflow实现 聊天AI--PigPig养成记(1)相关推荐

  1. GTK+实现linux聊天室代码详解-clientr端

    查看原代码请点击此超链接 注意!!此聊天室对红帽无兼容.需在其他linux系统上运行,如"深度". 加油学习! GTK+实现linux聊天室代码详解-server端:GTK+实现l ...

  2. 目标检测Tensorflow:Yolo v3代码详解 (2)

    目标检测Tensorflow:Yolo v3代码详解 (2) 三.解析Dataset()数据预处理部分 四. 模型训练 yolo_train.py 五. 模型冻结 model_freeze.py 六. ...

  3. Tensorflow官网CIFAR-10数据分类教程代码详解

    标题 概述 对CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题,本教程代码通过解决CIFAR-10数据分类任务,介绍了Tensorflow的一些高阶用法,演示了构建大型复杂模型的一些重 ...

  4. 基于U-Net的的图像分割代码详解及应用实现

    摘要 U-Net是基于卷积神经网络(CNN)体系结构设计而成的,由Olaf Ronneberger,Phillip Fischer和Thomas Brox于2015年首次提出应用于计算机视觉领域完成语 ...

  5. batchnorm原理及代码详解

    转载自:http://www.ishenping.com/ArtInfo/156473.html batchnorm原理及代码详解 原博文 原微信推文 见到原作者的这篇微信小文整理得很详尽.故在csd ...

  6. NLP【05】pytorch实现glove词向量(附代码详解)

    上一篇:NLP[04]tensorflow 实现Wordvec(附代码详解) 下一篇:NLP[06]RCNN原理及文本分类实战(附代码详解) 完整代码下载:https://github.com/ttj ...

  7. 基于神经网络的依存句法分析总结及代码详解

    上一篇文章CS224n之句法分析总结,介绍了句法分析以及具体的依存分析中的arc-standard算法.arc-standard系统是transition systems中最流行的一个系统之一.而本文 ...

  8. 石头剪刀布人工智能代码详解

    石头剪刀布人工智能代码详解 #include <iostream> #include <fstream> #include <stdlib.h> #include ...

  9. 深度篇——目标检测史(七) 细说 YOLO-V3目标检测 之 代码详解

    返回主目录 返回 目标检测史 目录 上一章:深度篇--目标检测史(六) 细说 YOLO-V3目标检测 下一章:深度篇--目标检测史(八) 细说 CornerNet-Lite 目标检测 论文地址:< ...

最新文章

  1. 2021年大数据Hadoop(二十二):MapReduce的自定义分组
  2. XGBOOST带试验源码
  3. vb计算机考试试题及答案,计算机二级考试《VB》操作试题及答案2016
  4. 【Linux】一步一步学Linux——dircolors命令(239)
  5. python内建时间模块 time和datetime
  6. 服务器修改用户组权限设置,如何:修改用户的权限
  7. BM模式匹配算法原理(图解)
  8. 西部数码服务器自动备份,西部数码-云服务器
  9. c语言转化音乐格式转换器安卓版,MP3格式转换器APP
  10. html点击图片局部放大,鼠标单击实现放大镜图片局部放大效果
  11. 分享5个宝藏文字转语音配音软件,错过太可惜
  12. linux 查看gnu,查看GNU/Linux信息
  13. c语言设计一个语音识别程序,用 C# 开发自己的语音识别程序
  14. python爬虫爬取深交所数据
  15. matlab gui设计入门与实战,matlab gui编程教程
  16. win10系统设置选择电源键按钮功能设置步骤
  17. 【PID优化】基于正余弦算法 (SCA)优化PID实现微型机器人系统位置控制附simulink模型和matlab代码
  18. 计算机c盘program,电脑c盘program files(x86)文件夹可以删除吗
  19. js实现圆柱形轮播图
  20. 大数据学习——相关资源

热门文章

  1. linux系统学习第八天-工程师技术
  2. 《MongoDB管理与开发精要》——1.4节本章小结
  3. 《深入理解C++11:C++ 11新特性解析与应用》——2.4 宏__cplusplus
  4. CSipSimple通话记录分组
  5. linux定时关机命令_电脑设置定时关机你会吗?Windows自带的这行命令真好用
  6. 使用OPENROWSET爆破SQL Server密码
  7. Xamarin iOS开发中的编辑、连接、运行
  8. 苹果手机换了屏显示无服务器,苹果将​​为存在显示问题的iPhone 11提供免费更换...
  9. python控制苹果手机触摸屏失灵怎么办_iphone触摸屏失灵怎么办 iphone触摸屏失灵解决办法【详解】...
  10. 脑电实验注意事项及实验过程中伪迹识别