目标:使用 github 唐诗宋词数据库的 json 数据,训练一共唐诗生成器
数据源:https://github.com/chinese-poetry/chinese-poetry

查看原始数据格式

原始数据包含 作者、正文、标题、id四个部分,这里仅仅使用诗词正文进行训练

制作唐诗数据集

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
path_tang = './tang/poet.tang.' # 读取 json 文件
def get_json(path):with open(path, 'r', encoding='utf-8') as f:data = json.load(f)return data# 对 ./tang/ 文件夹下的所有 json 文件进行遍历
import os# 获取文件夹下的所有文件名
def get_file_name(path):file_name = []for root, dirs, files in os.walk(path):for file in files:file_name.append(file)return file_name
file_name_ls = get_file_name('./tang/')
ret_ls = []for file_name in file_name_ls:ls = get_json('./tang/' + file_name)n_ls = len(ls)for i in range(n_ls):para = ls[i]['paragraphs']para = ''.join(para)ret_ls.append(para)
len(ret_ls) # 一共 57607 首诗歌
n_poet = len(ret_ls)
n_poet

查看前10首诗

for i in range(10):print(ret_ls[i])

秦川雄帝宅,函谷壯皇居。綺殿千尋起,離宮百雉餘。連甍遙接漢,飛觀迥凌虛。雲日隱層闕,風煙出綺疎。
巖廊罷機務,崇文聊駐輦。玉匣啓龍圖,金繩披鳳篆。韋編斷仍續,縹帙舒還卷。對此乃淹留,欹案觀墳典。

以茲遊觀極,悠然獨長想。披卷覽前蹤,撫躬尋既往。望古茅茨約,瞻今蘭殿廣。人道惡高危,虛心戒盈蕩。奉天竭誠敬,臨民思惠養。納善察忠諫,明科慎刑賞。六五誠難繼,四三非易仰。廣待淳化敷,方嗣云亭響。

进行 tokenizer 统计

# 全部合并成一个 string
str_all =  ''.join(ret_ls)
str_all[:1000]

‘秦川雄帝宅,函谷壯皇居。綺殿千尋起,離宮百雉餘。連甍遙接漢,飛觀迥凌虛。雲日隱層闕,風煙出綺疎。巖廊罷機務,崇文聊駐輦。玉匣啓龍圖,金繩披鳳篆。韋編斷仍續,縹帙舒還卷。對此乃淹留,欹案觀墳典。移步出

百蠻奉遐賮,萬國朝未央。雖無舜禹迹,幸欣天地康。車軌同八表,書文混四方。赫奕儼冠蓋,紛綸盛服章。羽旄飛馳道,鐘鼓震巖廊。組練輝霞色,霜戟耀朝光。晨宵懷至理,終媿撫遐荒。壽丘惟舊跡,酆邑乃前基。粵予承累’

# 用 keras 的 Tokenizer 进行统计
from keras.preprocessing.text import Tokenizer# 设置最大词汇量为 10000 个词
tokenizer = Tokenizer(num_words=10000,char_level=True)   # 注意:这里是 char_level=True,因为是字符级别的统计
# 在 str_all 上进行训练
tokenizer.fit_on_texts([str_all])
## 使用word_index属性查看每个词对应的编码
## 使用word_counts属性查看每个词对应的频数
for ii,iterm in enumerate(tokenizer.word_index.items()):if ii < 10:print(iterm)else:break
print("===================")
for ii,iterm in enumerate(tokenizer.word_counts.items()):if ii < 10:print(iterm)else:break

(‘,’, 1)
(‘。’, 2)
(‘不’, 3)
(‘人’, 4)
(‘一’, 5)
(‘無’, 6)
(‘山’, 7)
(‘風’, 8)
(‘日’, 9)
(‘有’, 10)
===================
(‘秦’, 1983)
(‘川’, 2503)
(‘雄’, 1060)
(‘帝’, 2099)
(‘宅’, 771)
(‘,’, 268814)
(‘函’, 268)
(‘谷’, 1301)
(‘壯’, 748)
(‘皇’, 1900)

制作数据集

# 制作数据集的方法:从 ret_ls 中,采样一首诗,然后从这首诗随机采样一个长度为 20 的子串,作为输入,然后预测下一个字符
# 数据集大小为 1w 样本对,采样方法是随机采样
import random
from tqdm import tqdm
x_seq_ls = []
y_seq_ls = []for i in tqdm(range(n_poet)):if len(ret_ls[i])-21<=0:        # 如果这首诗歌的长度小于等于 21,就跳过continue# 随机选一个子串start = random.randint(0, len(ret_ls[i])-21)end = start + 20# 保存到 x_seq_ls 和 y_seq_ls 中x_seq_ls.append(ret_ls[i][start:end])y_seq_ls.append(ret_ls[i][end])

100%|██████████| 57607/57607 [00:00<00:00, 654635.20it/s]

len(x_seq_ls),len(y_seq_ls)

(56294, 56294)

# 查看一下 x_seq_ls 和 y_seq_ls
for i in range(20):print(x_seq_ls[i],y_seq_ls[i])

綺殿千尋起,離宮百雉餘。連甍遙接漢,飛觀 迥
匣啓龍圖,金繩披鳳篆。韋編斷仍續,縹帙舒 還
電。驚雁落虛弦,啼猿悲急箭。閱賞誠多美, 於

輦駐新豐。紐落藤披架,花殘菊破叢。葉鋪荒 草
濟世豈邀名。星旂紛電舉,日羽肅天行。徧野 屯
參差影,寒猿斷續聲。冠蓋往來合,風塵朝夕 驚

# 把 x_seq_ls 和 y_seq_ls 用 tokenizer 进行编码x_token = tokenizer.texts_to_sequences(x_seq_ls)
y_token = tokenizer.texts_to_sequences(y_seq_ls)
# 查看一下 x_seq_ls 和 y_seq_ls
for i in range(20):print(x_token[i],y_token[i])

[854, 486, 68, 268, 183, 1, 187, 203, 182, 1953, 252, 2, 238, 3516, 283, 610, 210, 1, 87, 578] [1219]
[1599, 1378, 152, 629, 1, 56, 1814, 833, 326, 2625, 2, 2001, 1248, 258, 692, 1131, 1, 2359, 2958, 1306] [103]
[1585, 2, 265, 606, 74, 299, 613, 1, 513, 1062, 304, 636, 1289, 2, 2161, 723, 855, 58, 603, 1] [362]

[1479, 1061, 96, 1270, 2, 4298, 74, 1361, 833, 1862, 1, 18, 278, 891, 582, 739, 2, 181, 1527, 485] [95]
[1021, 189, 260, 1533, 134, 2, 286, 3148, 624, 1585, 663, 1, 9, 483, 1127, 12, 38, 2, 1396, 214] [1978]
[695, 872, 243, 1, 71, 1062, 258, 1131, 69, 2, 707, 789, 402, 13, 309, 1, 8, 153, 79, 325] [265]

# 转化为 numpy
x_mat = np.array(x_token)
y_mat = np.array(y_token)
# 查看一下 x_seq_ls 和 y_seq_ls
for i in range(20):print(x_mat[i],y_mat[i])

[ 854 486 68 268 183 1 187 203 182 1953 252 2 238 3516
283 610 210 1 87 578] [1219]
[1599 1378 152 629 1 56 1814 833 326 2625 2 2001 1248 258
692 1131 1 2359 2958 1306] [103]

[1021 189 260 1533 134 2 286 3148 624 1585 663 1 9 483
1127 12 38 2 1396 214] [1978]
[ 695 872 243 1 71 1062 258 1131 69 2 707 789 402 13
309 1 8 153 79 325] [265]

查看数据集形状

x_mat,x_mat.shape

(array([[ 854, 486, 68, …, 1, 87, 578],
[1599, 1378, 152, …, 2359, 2958, 1306],
[1585, 2, 265, …, 58, 603, 1],
…,
[ 781, 1, 644, …, 1, 103, 155],
[ 108, 552, 819, …, 472, 231, 231],
[ 19, 165, 2, …, 12, 14, 794]]),
(56294, 20))

y_mat,y_mat.shape

(array([[1219],
[ 103],
[ 362],
…,
[ 4],
[ 145],
[ 1]]),
(56294, 1))

划分训练集、测试集

# 用 sklearn 对数据集进行划分
from sklearn.model_selection import train_test_split
# 划分为训练集和测试集(7:3)
x_train, x_test, y_train, y_test = train_test_split(x_mat, y_mat, test_size=0.3, random_state=42, shuffle=True)
x_train.shape, x_test.shape, y_train.shape, y_test.shape

((39405, 20), (16889, 20), (39405, 1), (16889, 1))

训练模型

搭建网络

# 开始搭建网络
from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM, Dropout
from keras.optimizers import Adammodel = Sequential()
model.add(Embedding(10000, 128, input_length=20))
model.add(LSTM(128, return_sequences=True))
model.add(Dropout(0.2))
model.add(LSTM(128))
model.add(Dropout(0.2))
model.add(Dense(10000, activation='softmax'))# 编译模型
model.compile(loss='categorical_crossentropy', optimizer=Adam(learning_rate=0.001), metrics=['accuracy'])
model.summary()

# 测试样本能否正常输入网络
pred = model.predict(x_train[:1])
pred.shape

(1, 10000)

# y 标签转化为 one-hot 编码 (因为使用的损失函数是 categorical_crossentropy,而不是 sparse_categorical_crossentropy)
# 如果使用 sparse_categorical_crossentropy,就不需要转化为 one-hot 编码
# 【重要】from keras.utils import to_categorical
y_train_onehot = to_categorical(y_train, num_classes=10000)
y_test_onehot = to_categorical(y_test, num_classes=10000)
# 开始训练
history = model.fit(x_train, y_train_onehot, batch_size=32, epochs=40, validation_data=(x_test, y_test_onehot))

保存模型 和 tokenizer

# 保存模型
model.save('tang_poet_LSTM_v1.h5')
# 保存 tokenizer
import pickle
# saving
with open('tokenizer_tang_poet_LSTM_v1.pkl', 'wb') as f: pickle.dump(tokenizer, f) 

测试模型

test_string = '白日依山盡,黃河入海流,欲窮千里目,更上一'for i in range(300):# 循环100步,每步生成一个字符test_string_token = tokenizer.texts_to_sequences([test_string[-20:]])test_string_mat = np.array(test_string_token)pred = model.predict(test_string_mat)pred_argmax = np.argmax(pred, axis=1)[0]# 把 pred_argmax 对应的词语找出来tokenizer.index_word[pred_argmax]test_string = test_string + tokenizer.index_word[pred_argmax]
test_string

繁体输出

‘白日依山盡,黃河入海流,欲窮千里目,更上一拙羣。胡杯越泛起。自恐承雙苔,未由不相聽。況思無里息,無里白雲疎。早石侵飯客,天風亦相然。更知何處在,病路得成情。羽縫當貴者,不知無子娛,何年何處在?晨棄信音名。但見千年在,病酒洞花香。況思不未盡,歸處亦無然。韜知如揀咽郡,霜區隔柴陽。況月未相至,風寡獨相然。何日明雲久,歸郭得杉空。彈穴螢浦色,秋月入樓花。獨道分霖水,歸郭亦離春。臥寢分傳日,美人見此塵。獨思千情客,不知不相過。況思無限客,不知亦不然。更知如此在,目我是公情。獨憶分霖色,搖連隔鏡稀。歛暖沙野續,風雲不不經。何日空空急,風山一北花。況思分氣散,花聲蜀翠翎。風水白山月,歸郭雪頭風。獨照塵峰外,秋痕入抽湯。短苒崦亭裏,韶悵滿羅氳。自得’

简体输出

‘白日依山尽,黄河入海流,欲穷千里目,更上一拙羣。胡杯越泛起。自恐承双苔,未由不相听。况思无里息,无里白云疎。早石侵饭客,天风亦相然。更知何处在,病路得成情。羽缝当贵者,不知无子娱,何年何处在?晨弃信音名。但见千年在,病酒洞花香。况思不未尽,归处亦无然。韬知如拣咽郡,霜区隔柴阳。况月未相至,风寡独相然。何日明云久,归郭得杉空。弹穴萤浦色,秋月入楼花。独道分霖水,归郭亦离春。卧寝分传日,美人见此尘。独思千情客,不知不相过。况思无限客,不知亦不然。更知如此在,目我是公情。独忆分霖色,摇连隔镜稀。歛暖沙野续,风云不不经。何日空空急,风山一北花。况思分气散,花声蜀翠翎。风水白山月,归郭雪头风。独照尘峰外,秋痕入抽汤。短苒崦亭里,韶怅满罗氲。自得’

【NLP】LSTM 唐诗生成器相关推荐

  1. 【NLP】keras Transformer 唐诗生成器

    本文使用了 keras-transformer 第三方库,库的模型结构和使用方法如下图,需要构造 encoder_input (x1) decoder_input (x2) decoder_outpu ...

  2. NLP深度学习:近期趋势概述

    NLP&深度学习:近期趋势概述 https://www.cnblogs.com/DicksonJYL/p/9686204.html 摘要:当NLP遇上深度学习,到底发生了什么样的变化呢? 在最 ...

  3. NLP深度学习:近期趋势概述 1

    摘要:当NLP遇上深度学习,到底发生了什么样的变化呢? 在最近发表的论文中,Young及其同事汇总了基于深度学习的自然语言处理(NLP)系统和应用程序的一些最新趋势.本文的重点介绍是对各种NLP任务( ...

  4. NLP深度学习:近期趋势概述(二)

    递归神经网络(RNN) RNN是专门用于处理顺序信息的神经网络的方法.RNN将计算应用于以先前计算结果为条件的输入序列.这些序列通常由固定大小的标记向量表示,他们被顺序送至循环单元.下图说明了一个简单 ...

  5. Keras LSTM教程

    **本文更完整的内容请参考极客教程的深度学习专栏:https://geek-docs.com/deep-learning/rnn/keras-lstm-tutorial.html,欢迎提出您的宝贵意见 ...

  6. 【NLP】基于GAN的文本生成综述

    论文一.<Generative Adversarial Nets>NIPS 2014 1.模型简述 这篇论文是最早提出 GAN 的文章,作者 Ian J. Goodfellow 提出了一种 ...

  7. 第七课 循环神经网络与自然语言处理

    文章目录 1 循环神经网络 1.1 场景与多种应用 1.2 RNN网络结构 1.2.1为什么需要RNN 1.2.2 RNN 结构 1.3 多种RNN 1.4 BPTT算法 1.5 生成模型与图像描述 ...

  8. 深度学习04-RNN

    文章目录 1 为什么需要RNN 1.1RNN的应用场景 1.2 DNN和CNN不能解决的问题 2 RNN的网络结构 2.1 RNN基础结构 2.2 不同类型的RNN 3 RNN的优化算法BPTT 4 ...

  9. 假新闻无处不在:我开源了一个深度学习标记假新闻项目

    来源 | DeepHub IMBA(ID:deephub-imba) 虚假新闻的兴起迫使拥有社交媒体帐户的每个人都成为一名侦探,负责在发布前确定帖子是否真实.但是,虚假新闻仍然会越过我们的防线,在网络 ...

  10. 【论文笔记】Adversarial Multi-task Learning for Text Classification

    一.概要   该文章发于ACL 2017,针对于已有的大多数神经网络多任务学习模型进行文本分类,存在的共享特征(shared features)可能再次包含特定任务特征(task-specific f ...

最新文章

  1. 1.2.1 分层结构 协议 接口 服务
  2. 汇编-理解call,ret
  3. linux判断网卡是否挂载,检查网卡是否加载 - Linux服务器网卡驱动安装及故障排除_服务器应用_Linux公社-Linux系统门户网站...
  4. [转]url传递中文的解决方案总结
  5. Javascript中字符串转数字
  6. linux mysql 确认命令_LINUX启动/重启/停上MYSQL的命令(详解)
  7. POJ 1716 Integer Intervals
  8. 手摸手带你入门ionic3(六):界面跳转
  9. NSGA2算法MATLAB
  10. Nginx 反向代理解决浏览器跨域问题
  11. Unity中设置对象匀速移动
  12. win7 IE11浏览器怎么改成中文
  13. 并网逆变器PI控制(并网模式)
  14. 专家学者热议智慧交通:大数据云计算,出行有“千里眼”
  15. IntelliJ Idea入门教程:如何新建项目
  16. 配置JAVA 环境变量
  17. 程序员一直在寻找的日程安排工具
  18. 计算机专业英语2013版第二章答案,2013计算机专业英语第2章.ppt
  19. 芯洲科技SCT2462Q车规DCDC,完全替代LM61460。
  20. 别再手动比对文件啦,Python 让你轻松实现文件内容以及目录对比!!!

热门文章

  1. 集合例题3.:现在有一个map集合如下:Map<Integer,String> map = new HashMap<Integer, String>();map.put(1, “张三丰“);map.
  2. dna计算机 论文,科学网—阿德勒曼DNA计算机(科普) - 沈海军的博文
  3. P5_扩展卡尔曼滤波器_udacity无人驾驶
  4. 炫酷渐变色背景粒子线条折线连接canvas动画
  5. Android开发方向,我们如何选择?
  6. 软件外包公司到底干啥的?要不要去外包公司?
  7. 钉钉小程序摸索二:钉钉小程序开发过程中错误解决过程
  8. python二级考试大纲小学生_Python语言二级考试大纲
  9. JavaScript下载后端返回的文件流
  10. 如何用艺术字制作水印?教你一招快速制作专属水印