java rnn生成古诗_Char-RNN生成古诗
尝试用char-RNN生成古诗,本来是想要尝试用来生成广告文案的,测试一波生成古诗的效果。嘛,虽然我对业务兴趣不大,不过这个模型居然把我硬盘跑挂了,也是醉。
其实Char-RNN来生成文本的逻辑非常简单,就是一个字一个字放进去,让RNN开始学,按照前面的字预测下面的字。所以就要想办法把文本揉成我们需要的格式。
比如说,我们现在有一句诗“床前明月光,疑是地上霜”。那么我们的输入就是“床前明月光”,那么我们的预测就是“前明月光,”,其实就是错位一位。
然后我们要考虑的是如何批量的把数据喂进去,这里参考了gluon的教程上面的一个操作,因为诗歌是有上下文联系的,如果我们用随机选取的话,很可能就会丢掉很多有用的信息,所以我们还要想办法将诗歌的这种连续性保留下来。
mxnet教程的方法是先将所有的文本串成一行。所有的换行符替换为空格,所以空格在这里起到了分段的作用,空格也就有了意义。然后我们因为我们要批量训练,所以先按照我们每批打算训练多少行文本,将这一个超长的文本截断成这样,然后按照我们一次想看多少个字的窗口扫描过去。代码实现上如下: 1
2
3
4
5
6
7
8
9
10
11
12def data_iter_consecutive(corpus_indices, batch_size, num_steps):
corpus_indices = torch.tensor(corpus_indices)
data_len = len(corpus_indices)
batch_len = data_len // batch_size
indices = corpus_indices[0: batch_size*batch_len].reshape((
batch_size, batch_len))
epoch_size = (batch_len - 1) // num_steps
for i in range(epoch_size):
i = i * num_steps
X = indices[:, i: i + num_steps]
Y = indices[:, i + 1: i + num_steps + 1]
yield X, Y
这样有一个好处就是可以保持诗句的连续性,效果上大概是:
1
2
3
4
5
6
7# 所有诗句拼成一行
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
# batch_size = 2, num_steps = 3
# batch 1
[[1, 2, 3], [7, 8, 9]]
# batch 2
[[4, 5, 6], [10, 11, 12]]
这样一来,一句诗[1, 2, 3, 4, 5, 6]就能在不同batch里面保持连贯性了。
然后就是很简单设计网络:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33class lyricNet(nn.Module):
def __init__(self, hidden_dim, embed_dim, num_layers, weight,
num_labels, bidirectional, dropout=0.5, **kwargs):
super(lyricNet, self).__init__(**kwargs)
self.hidden_dim = hidden_dim
self.embed_dim = embed_dim
self.num_layers = num_layers
self.num_labels = num_labels
self.bidirectional = bidirectional
if num_layers <= 1:
self.dropout = 0
else:
self.dropout = dropout
self.embedding = nn.Embedding.from_pretrained(weight)
self.embedding.weight.requires_grad = False
# self.embedding = nn.Embedding(num_labels, self.embed_dim)
self.rnn = nn.GRU(input_size=self.embed_dim, hidden_size=self.hidden_dim,
num_layers=self.num_layers, bidirectional=self.bidirectional,
dropout=self.dropout)
if self.bidirectional:
self.decoder = nn.Linear(hidden_dim * 2, self.num_labels)
else:
self.decoder = nn.Linear(hidden_dim, self.num_labels)
def forward(self, inputs, hidden=None):
embeddings = self.embedding(inputs)
states, hidden = self.rnn(embeddings.permute([1, 0, 2]), hidden)
outputs = self.decoder(states.reshape((-1, states.shape[-1])))
return(outputs, hidden)
def init_hidden(self, num_layers, batch_size, hidden_dim, **kwargs):
hidden = torch.zeros(num_layers, batch_size, hidden_dim)
return hidden
这里我用的是很简单的one-hot做词向量,当然数据量大一点可以考虑pretrained的字向量。不过直观感受上用白话文训练的字向量应该效果不会太好吧。
接着就可以开始训练了:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28for epoch in range(num_epoch):
start = time.time()
num, total_loss = 0, 0
data = data_iter_consecutive(corpus_indice, batch_size, 35)
hidden = model.init_hidden(num_layers, batch_size, hidden_dim)
for X, Y in data:
num += 1
hidden.detach_()
if use_gpu:
X = X.to(device)
Y = Y.to(device)
hidden = hidden.to(device)
optimizer.zero_grad()
output, hidden = model(X, hidden)
l = loss_function(output, Y.t().reshape((-1,)))
l.backward()
norm = nn.utils.clip_grad_norm_(model.parameters(), 1e-2)
optimizer.step()
total_loss += l.item()
end = time.time()
s = end - since
h = math.floor(s / 3600)
m = s - h * 3600
m = math.floor(m / 60)
s -= m * 60
if (epoch % 10 == 0) or (epoch == (num_epoch - 1)):
print('epoch %d/%d, loss %.4f, norm %.4f, time %.3fs, since %dh %dm %ds'
%(epoch+1, num_epoch, total_loss / num, norm, end-start, h, m, s))
这里的训练过程需要注意两个点,一个是hidden的initial,因为我们想要保持句子的连续性,所以我们hidden的initial只要每个epoch的第一次initial一下就可以了,后面训练的过程中需要从计算图中拿掉。另外就是因为有梯度爆炸的问题,所以我们需要对梯度进行修剪。
最后一个是我自己最容易犯错的地方,死活记不住的就是RNN的输入输出每个dimension都代表了什么含义。原始的RNN接受的输入是(seq_len, batch_size, embedding_dimension),输出的是(seq_len, batch_size, num_direction * hidden_dim)。所以我们习惯的batch在先的数据需要在这里做一个permute,将batch和seq做一下调换。然后就是我们做分类的时候,直接flatten成为一个长向量的时候,其实已经变成了[seq_len, seq_len, ...]这样的样子。简单理解就是本来我们都是横着看诗歌的,现在模型的输出是竖着输出的。所以我们后面算loss的时候,y也需要做一个转置再flatten。
具体的可以看我的这个notebook。
接下来可能想试一下的是如果不用这种方法的话,是不是可以用padding的方法把句子长度统一再训练。
另外强势推荐最全中华古诗词数据库。数据非常非常全了。
后面如果要做到很好的效果可以做的方向一个是做韵脚的信息,还有就是平仄的信息也带进去。
anyway,想了一下,这样训练完的hidden是不是就包含了一个作者的文风信息?!
java rnn生成古诗_Char-RNN生成古诗相关推荐
- Char RNN原理介绍以及文本生成实践
正文共1523张图,3张图,预计阅读时间8分钟. 1.简介 Char-RNN,字符级循环神经网络,出自于Andrej Karpathy写的The Unreasonable Effectiveness ...
- windows7下,Java中利用JNI调用c++生成的动态库的使用步骤
1.从http://www.oracle.com/technetwork/java/javase/downloads/jdk-7u2-download-1377129.html下载jdk-7u2-wi ...
- Java 快速开发二维码生成服务
点击上方蓝色"程序猿DD",选择"设为星标" 回复"资源"获取独家整理的学习资料! 来源 | 公众号「码农小胖哥」 1. 前言 不知道从什么 ...
- java token生成和验证_java生成定长度的随机验证码
平凡也就两个字: 懒和惰;成功也就两个字: 苦和勤;优秀也就两个字: 你和我.跟着我从0学习JAVA.spring全家桶和linux运维等知识,带你从懵懂少年走向人生巅峰,迎娶白富美!每一篇文章都是心 ...
- java生成小图片_JAVA生成缩略小图片类
JAVA生成缩略小图片类 java.awt.image.BufferedImage是缓冲图片类主要将生成的图片对象缓冲起来:javax.imageio.ImageIO是图片IO控制类,可以将缓冲图片对 ...
- 菜鸟学Java(六)——简单验证码生成(Java版)
转载自 菜鸟学Java(六)--简单验证码生成(Java版) 验证码大家都知道,它的作用也不用我多说了吧.如果不太清楚请参见百度百科中的解释,一般验证码的生成就是随机产生字符(数字.字母或者汉字等) ...
- java 生成校验验证码_java生成验证码并进行验证
一实现思路使用BufferedImage用于在内存中存储生成的验证码图片使用Graphics来进行验证码图片的绘制,并将绘制在图片上的验证码存放到session中用于后续验证 最后通过ImageIO将 ...
- java 生成客户端代码_swagger-codegen生成java客户端代码
前后端分离的时候,需要建立契约,Swagger可达到该目的(略). 建立Rest接口后,通过swagger-codegen项目可以自动生成对应的客户端代码(c++.php.java.js.node等等 ...
- apache poi使用例_使用java Apache poi 根据word模板生成word报表例子
[实例简介] 使用java Apache poi 根据word模板生成word报表 仅支持docx格式的word文件,大概是word2010及以后版本,doc格式不支持. 使用说明:https://b ...
- 极客技术专题【003期】:java mvc 增删改查 自动生成工具来袭
日期:2013-4-17 来源:GBin1.com 技术专题介绍 分享专题:java mvc 增删改查 自动生成工具来袭 分享人:激情燃烧的UI 授课时间:2013/04/19 21:00-22: ...
最新文章
- 【洛谷搜索专题Python和C++解】DFS和BFS经典题目(陆续补充)
- 《github一天一道算法题》:分治法求数组最大连续子序列和
- 《算法竞赛进阶指南》打卡-基本算法-AcWing 93. 递归实现组合型枚举:递归与递推、dfs、状态压缩
- 监控服务器怎么增加碟机,微服务业务监控方法及服务器专利_专利申请于2017-12-15_专利查询 - 天眼查...
- 在线ajax测试,在线测试 - SosoApi,简单强大的api接口文档管理平台
- 计算机网络原理(第二章)课后题答案
- 快手2021年营收810亿元 经调整净亏损188亿元
- 新来的同事把公司现有项目的性能优化了一遍,来看看他是怎么做到的
- 谈谈浮动和清除浮动?
- leetcode python3 简单题112. Path Sum
- 国际项目投标那些事(六)投标文件怎么写才能惊呆业主 WTSolutions
- java物业管理系统描述,基于java小区物业管理系统.doc
- AE光效效果插件:Trapcode Shine
- linux下gbd调试基础
- zcmu--1931: wjw的剪纸(dfs+枚举)
- iPhone的设置中,找不到“开发者选项”
- 算法工程 # 深度学习算法落地最后一公里:工业界中的大规模向量检索
- 电商Banner设计背后的12个人性的秘密
- 小米note2 支付宝指纹支付 -10008
- 海边旅行必备物品清单