尝试用char-RNN生成古诗,本来是想要尝试用来生成广告文案的,测试一波生成古诗的效果。嘛,虽然我对业务兴趣不大,不过这个模型居然把我硬盘跑挂了,也是醉。

其实Char-RNN来生成文本的逻辑非常简单,就是一个字一个字放进去,让RNN开始学,按照前面的字预测下面的字。所以就要想办法把文本揉成我们需要的格式。

比如说,我们现在有一句诗“床前明月光,疑是地上霜”。那么我们的输入就是“床前明月光”,那么我们的预测就是“前明月光,”,其实就是错位一位。

然后我们要考虑的是如何批量的把数据喂进去,这里参考了gluon的教程上面的一个操作,因为诗歌是有上下文联系的,如果我们用随机选取的话,很可能就会丢掉很多有用的信息,所以我们还要想办法将诗歌的这种连续性保留下来。

mxnet教程的方法是先将所有的文本串成一行。所有的换行符替换为空格,所以空格在这里起到了分段的作用,空格也就有了意义。然后我们因为我们要批量训练,所以先按照我们每批打算训练多少行文本,将这一个超长的文本截断成这样,然后按照我们一次想看多少个字的窗口扫描过去。代码实现上如下: 1

2

3

4

5

6

7

8

9

10

11

12def data_iter_consecutive(corpus_indices, batch_size, num_steps):

corpus_indices = torch.tensor(corpus_indices)

data_len = len(corpus_indices)

batch_len = data_len // batch_size

indices = corpus_indices[0: batch_size*batch_len].reshape((

batch_size, batch_len))

epoch_size = (batch_len - 1) // num_steps

for i in range(epoch_size):

i = i * num_steps

X = indices[:, i: i + num_steps]

Y = indices[:, i + 1: i + num_steps + 1]

yield X, Y

这样有一个好处就是可以保持诗句的连续性,效果上大概是:

1

2

3

4

5

6

7# 所有诗句拼成一行

[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

# batch_size = 2, num_steps = 3

# batch 1

[[1, 2, 3], [7, 8, 9]]

# batch 2

[[4, 5, 6], [10, 11, 12]]

这样一来,一句诗[1, 2, 3, 4, 5, 6]就能在不同batch里面保持连贯性了。

然后就是很简单设计网络:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33class lyricNet(nn.Module):

def __init__(self, hidden_dim, embed_dim, num_layers, weight,

num_labels, bidirectional, dropout=0.5, **kwargs):

super(lyricNet, self).__init__(**kwargs)

self.hidden_dim = hidden_dim

self.embed_dim = embed_dim

self.num_layers = num_layers

self.num_labels = num_labels

self.bidirectional = bidirectional

if num_layers <= 1:

self.dropout = 0

else:

self.dropout = dropout

self.embedding = nn.Embedding.from_pretrained(weight)

self.embedding.weight.requires_grad = False

# self.embedding = nn.Embedding(num_labels, self.embed_dim)

self.rnn = nn.GRU(input_size=self.embed_dim, hidden_size=self.hidden_dim,

num_layers=self.num_layers, bidirectional=self.bidirectional,

dropout=self.dropout)

if self.bidirectional:

self.decoder = nn.Linear(hidden_dim * 2, self.num_labels)

else:

self.decoder = nn.Linear(hidden_dim, self.num_labels)

def forward(self, inputs, hidden=None):

embeddings = self.embedding(inputs)

states, hidden = self.rnn(embeddings.permute([1, 0, 2]), hidden)

outputs = self.decoder(states.reshape((-1, states.shape[-1])))

return(outputs, hidden)

def init_hidden(self, num_layers, batch_size, hidden_dim, **kwargs):

hidden = torch.zeros(num_layers, batch_size, hidden_dim)

return hidden

这里我用的是很简单的one-hot做词向量,当然数据量大一点可以考虑pretrained的字向量。不过直观感受上用白话文训练的字向量应该效果不会太好吧。

接着就可以开始训练了:

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28for epoch in range(num_epoch):

start = time.time()

num, total_loss = 0, 0

data = data_iter_consecutive(corpus_indice, batch_size, 35)

hidden = model.init_hidden(num_layers, batch_size, hidden_dim)

for X, Y in data:

num += 1

hidden.detach_()

if use_gpu:

X = X.to(device)

Y = Y.to(device)

hidden = hidden.to(device)

optimizer.zero_grad()

output, hidden = model(X, hidden)

l = loss_function(output, Y.t().reshape((-1,)))

l.backward()

norm = nn.utils.clip_grad_norm_(model.parameters(), 1e-2)

optimizer.step()

total_loss += l.item()

end = time.time()

s = end - since

h = math.floor(s / 3600)

m = s - h * 3600

m = math.floor(m / 60)

s -= m * 60

if (epoch % 10 == 0) or (epoch == (num_epoch - 1)):

print('epoch %d/%d, loss %.4f, norm %.4f, time %.3fs, since %dh %dm %ds'

%(epoch+1, num_epoch, total_loss / num, norm, end-start, h, m, s))

这里的训练过程需要注意两个点,一个是hidden的initial,因为我们想要保持句子的连续性,所以我们hidden的initial只要每个epoch的第一次initial一下就可以了,后面训练的过程中需要从计算图中拿掉。另外就是因为有梯度爆炸的问题,所以我们需要对梯度进行修剪。

最后一个是我自己最容易犯错的地方,死活记不住的就是RNN的输入输出每个dimension都代表了什么含义。原始的RNN接受的输入是(seq_len, batch_size, embedding_dimension),输出的是(seq_len, batch_size, num_direction * hidden_dim)。所以我们习惯的batch在先的数据需要在这里做一个permute,将batch和seq做一下调换。然后就是我们做分类的时候,直接flatten成为一个长向量的时候,其实已经变成了[seq_len, seq_len, ...]这样的样子。简单理解就是本来我们都是横着看诗歌的,现在模型的输出是竖着输出的。所以我们后面算loss的时候,y也需要做一个转置再flatten。

具体的可以看我的这个notebook。

接下来可能想试一下的是如果不用这种方法的话,是不是可以用padding的方法把句子长度统一再训练。

另外强势推荐最全中华古诗词数据库。数据非常非常全了。

后面如果要做到很好的效果可以做的方向一个是做韵脚的信息,还有就是平仄的信息也带进去。

anyway,想了一下,这样训练完的hidden是不是就包含了一个作者的文风信息?!

java rnn生成古诗_Char-RNN生成古诗相关推荐

  1. Char RNN原理介绍以及文本生成实践

    正文共1523张图,3张图,预计阅读时间8分钟. 1.简介 Char-RNN,字符级循环神经网络,出自于Andrej Karpathy写的The Unreasonable Effectiveness ...

  2. windows7下,Java中利用JNI调用c++生成的动态库的使用步骤

    1.从http://www.oracle.com/technetwork/java/javase/downloads/jdk-7u2-download-1377129.html下载jdk-7u2-wi ...

  3. Java 快速开发二维码生成服务

    点击上方蓝色"程序猿DD",选择"设为星标" 回复"资源"获取独家整理的学习资料! 来源 | 公众号「码农小胖哥」 1. 前言 不知道从什么 ...

  4. java token生成和验证_java生成定长度的随机验证码

    平凡也就两个字: 懒和惰;成功也就两个字: 苦和勤;优秀也就两个字: 你和我.跟着我从0学习JAVA.spring全家桶和linux运维等知识,带你从懵懂少年走向人生巅峰,迎娶白富美!每一篇文章都是心 ...

  5. java生成小图片_JAVA生成缩略小图片类

    JAVA生成缩略小图片类 java.awt.image.BufferedImage是缓冲图片类主要将生成的图片对象缓冲起来:javax.imageio.ImageIO是图片IO控制类,可以将缓冲图片对 ...

  6. 菜鸟学Java(六)——简单验证码生成(Java版)

    转载自  菜鸟学Java(六)--简单验证码生成(Java版) 验证码大家都知道,它的作用也不用我多说了吧.如果不太清楚请参见百度百科中的解释,一般验证码的生成就是随机产生字符(数字.字母或者汉字等) ...

  7. java 生成校验验证码_java生成验证码并进行验证

    一实现思路使用BufferedImage用于在内存中存储生成的验证码图片使用Graphics来进行验证码图片的绘制,并将绘制在图片上的验证码存放到session中用于后续验证 最后通过ImageIO将 ...

  8. java 生成客户端代码_swagger-codegen生成java客户端代码

    前后端分离的时候,需要建立契约,Swagger可达到该目的(略). 建立Rest接口后,通过swagger-codegen项目可以自动生成对应的客户端代码(c++.php.java.js.node等等 ...

  9. apache poi使用例_使用java Apache poi 根据word模板生成word报表例子

    [实例简介] 使用java Apache poi 根据word模板生成word报表 仅支持docx格式的word文件,大概是word2010及以后版本,doc格式不支持. 使用说明:https://b ...

  10. 极客技术专题【003期】:java mvc 增删改查 自动生成工具来袭

    日期:2013-4-17  来源:GBin1.com 技术专题介绍 分享专题:java mvc 增删改查 自动生成工具来袭 分享人:激情燃烧的UI 授课时间:2013/04/19  21:00-22: ...

最新文章

  1. 【洛谷搜索专题Python和C++解】DFS和BFS经典题目(陆续补充)
  2. 《github一天一道算法题》:分治法求数组最大连续子序列和
  3. 《算法竞赛进阶指南》打卡-基本算法-AcWing 93. 递归实现组合型枚举:递归与递推、dfs、状态压缩
  4. 监控服务器怎么增加碟机,微服务业务监控方法及服务器专利_专利申请于2017-12-15_专利查询 - 天眼查...
  5. 在线ajax测试,在线测试 - SosoApi,简单强大的api接口文档管理平台
  6. 计算机网络原理(第二章)课后题答案
  7. 快手2021年营收810亿元 经调整净亏损188亿元
  8. 新来的同事把公司现有项目的性能优化了一遍,来看看他是怎么做到的
  9. 谈谈浮动和清除浮动?
  10. leetcode python3 简单题112. Path Sum
  11. 国际项目投标那些事(六)投标文件怎么写才能惊呆业主 WTSolutions
  12. java物业管理系统描述,基于java小区物业管理系统.doc
  13. AE光效效果插件:Trapcode Shine
  14. linux下gbd调试基础
  15. zcmu--1931: wjw的剪纸(dfs+枚举)
  16. iPhone的设置中,找不到“开发者选项”
  17. 算法工程 # 深度学习算法落地最后一公里:工业界中的大规模向量检索
  18. 电商Banner设计背后的12个人性的秘密
  19. 小米note2 支付宝指纹支付 -10008
  20. 海边旅行必备物品清单

热门文章

  1. 使用Post不传Body,出现socket hang up报错
  2. 每日10行代码57: appium测试坚果手机出现socket hang up报错的解决
  3. MySQL复习笔记(三)
  4. 怎样将语音转化为文字
  5. 基于Arduino锂电池容量测试仪
  6. Python super( ) 函数详解
  7. JavaScript是多线程还是单线程?
  8. LC振荡电路L和C 参数越小 频率越高
  9. HPUoj1210: OY问题 [搜索](DFS
  10. FCBF算法的Matlab实现