在pytorch下,以数万首唐诗为素材,训练双层LSTM神经网络,使其能够以唐诗的方式写诗。

代码结构分为四部分,分别为

1.model.py,定义了双层LSTM模型

2.data.py,定义了从网上得到的唐诗数据的处理方法

3.utlis.py 定义了损失可视化的函数

4.main.py定义了模型参数,以及训练、唐诗生成函数。

参考:电子工业出版社的《深度学习框架PyTorch:入门与实践》第九章

main代码及注释如下

import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdbclass Config(object):data_path = 'data/'pickle_path = 'tang.npz'author = Noneconstrain = Nonecategory = 'poet.tang' #or poet.songlr = 1e-3weight_decay = 1e-4use_gpu = Trueepoch = 20batch_size = 128maxlen = 125plot_every = 20#use_env = True #是否使用visodmenv = 'poety' #visdom envmax_gen_len = 200debug_file = '/tmp/debugp'model_path = Noneprefix_words = '细雨鱼儿出,微风燕子斜。' #不是诗歌组成部分,是意境start_words = '闲云潭影日悠悠' #诗歌开始acrostic = False #是否藏头model_prefix = 'checkpoints/tang' #模型保存路径
opt = Config()def generate(model, start_words, ix2word, word2ix, prefix_words=None):'''给定几个词,根据这几个词接着生成一首完整的诗歌'''results = list(start_words)start_word_len = len(start_words)# 手动设置第一个词为<START># 这个地方有问题,最后需要再看一下input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())if opt.use_gpu:input=input.cuda()hidden = Noneif prefix_words:for word in prefix_words:output,hidden = model(input,hidden)# 下边这句话是为了把input变成1*1?input = Variable(input.data.new([word2ix[word]])).view(1,1)for i in range(opt.max_gen_len):output,hidden = model(input,hidden)if i<start_word_len:w = results[i]input = Variable(input.data.new([word2ix[w]])).view(1,1)else:top_index = output.data[0].topk(1)[1][0]w = ix2word[top_index]results.append(w)input = Variable(input.data.new([top_index])).view(1,1)if w=='<EOP>':del results[-1] #-1的意思是倒数第一个breakreturn resultsdef gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):'''生成藏头诗start_words : u'深度学习'生成:深木通中岳,青苔半日脂。度山分地险,逆浪到南巴。学道兵犹毒,当时燕不移。习根通古岸,开镜出清羸。'''results = []start_word_len = len(start_words)input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())if opt.use_gpu:input=input.cuda()hidden = Noneindex=0 # 用来指示已经生成了多少句藏头诗# 上一个词pre_word='<START>'if prefix_words:for word in prefix_words:output,hidden = model(input,hidden)input = Variable(input.data.new([word2ix[word]])).view(1,1)for i in range(opt.max_gen_len):output,hidden = model(input,hidden)top_index  = output.data[0].topk(1)[1][0]w = ix2word[top_index]if (pre_word  in {u'。',u'!','<START>'} ):# 如果遇到句号,藏头的词送进去生成if index==start_word_len:# 如果生成的诗歌已经包含全部藏头的词,则结束breakelse:  # 把藏头的词作为输入送入模型w = start_words[index]index+=1input = Variable(input.data.new([word2ix[w]])).view(1,1)    else:# 否则的话,把上一次预测是词作为下一个词输入input = Variable(input.data.new([word2ix[w]])).view(1,1)results.append(w)pre_word = wreturn resultsdef train(**kwargs):for k,v in kwargs.items():setattr(opt,k,v) #设置apt里属性的值vis = Visualizer(env=opt.env)#获取数据data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函数data = t.from_numpy(data)#这个地方出错了,是大写的Ldataloader = t.utils.data.DataLoader(data, batch_size = opt.batch_size,shuffle = True,num_workers = 1) #在python里,这样写程序可以吗?#模型定义model = PoetryModel(len(word2ix), 128, 256)optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)criterion = nn.CrossEntropyLoss()if opt.model_path:model.load_state_dict(t.load(opt.model_path))if opt.use_gpu:model.cuda()criterion.cuda()#The tnt.AverageValueMeter measures and returns the average value #and the standard deviation of any collection of numbers that are #added to it. It is useful, for instance, to measure the average #loss over a collection of examples.#The add() function expects as input a Lua number value, which #is the value that needs to be added to the list of values to #average. It also takes as input an optional parameter n that #assigns a weight to value in the average, in order to facilitate #computing weighted averages (default = 1).#The tnt.AverageValueMeter has no parameters to be set at initialization time. loss_meter = meter.AverageValueMeter()for epoch in range(opt.epoch):loss_meter.reset()for ii,data_ in tqdm.tqdm(enumerate(dataloader)):#tqdm是python中的进度条#训练data_ = data_.long().transpose(1,0).contiguous()#上边一句话,把data_变成long类型,把1维和0维转置,把内存调成连续的if opt.use_gpu: data_ = data_.cuda()optimizer.zero_grad()input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])#上边一句,将输入的诗句错开一个字,形成训练和目标output,_ = model(input_)loss = criterion(output, target.view(-1))loss.backward()optimizer.step()loss_meter.add(loss.data[0]) #为什么是data[0]?#可视化用到的是utlis.py里的函数if (1+ii)%opt.plot_every ==0:if os.path.exists(opt.debug_file):ipdb.set_trace()vis.plot('loss',loss_meter.value()[0])# 下面是对目前模型情况的测试,诗歌原文poetrys = [[ix2word[_word] for _word in data_[:,_iii]] for _iii in range(data_.size(1))][:16]#上面句子嵌套了两个循环,主要是将诗歌索引的前十六个字变成原文vis.text('</br>'.join([''.join(poetry) for poetry in poetrys]),win = u'origin_poem')gen_poetries = []#分别以以下几个字作为诗歌的第一个字,生成8首诗for word in list(u'春江花月夜凉如水'):gen_poetry = ''.join(generate(model,word,ix2word,word2ix))gen_poetries.append(gen_poetry)vis.text('</br>'.join([''.join(poetry) for poetry in gen_poetries]), win = u'gen_poem')t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))def gen(**kwargs):'''提供命令行接口,用以生成相应的诗'''for k,v in kwargs.items():setattr(opt,k,v)data, word2ix, ix2word = get_data(opt)model = PoetryModel(len(word2ix), 128, 256)map_location = lambda s,l:s# 上边句子里的map_location是在load里用的,用以加载到指定的CPU或GPU,# 上边句子的意思是将模型加载到默认的GPU上state_dict = t.load(opt.model_path, map_location = map_location)model.load_state_dict(state_dict)if opt.use_gpu:model.cuda()if sys.version_info.major == 3:if opt.start_words.insprintable():start_words = opt.start_wordsprefix_words = opt.prefix_words if opt.prefix_words else Noneelse:start_words = opt.start_words.encode('ascii',\'surrogateescape').decode('utf8')prefix_words = opt.prefix_words.encode('ascii',\'surrogateescape').decode('utf8') if opt.prefix_words else Nonestart_words = start_words.replace(',',u',')\.replace('.',u'。')\.replace('?',u'?')gen_poetry = gen_acrostic if opt.acrostic else generateresult = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)print(''.join(result))
if __name__ == '__main__':import firefire.Fire()

以上代码给我一些经验,

1. 了解python的编程方式,如空格、换行等;进一步了解python的各个基本模块;

2. 可能出的错误:函数名写错,大小写,变量名写错,括号不全。

3. 对cuda()的用法有了进一步认识;

4. 学会了调试程序(fire);

5. 学会了训练结果的可视化(visdom);

6. 进一步的了解了LSTM,对深度学习的架构、实现有了宏观把控。

pytorch下使用LSTM神经网络写诗相关推荐

  1. python程序写诗_pytorch下使用LSTM神经网络写诗实例

    在pytorch下,以数万首唐诗为素材,训练双层LSTM神经网络,使其能够以唐诗的方式写诗. 代码结构分为四部分,分别为 1.model.py,定义了双层LSTM模型 2.data.py,定义了从网上 ...

  2. 【PyTorch实战】用RNN写诗

    用RNN写诗 1. 背景 1.1 词向量 1.2 RNN 2. CharRNN 3. 用PyTorch实现CharRNN 4. 结果分析 参考资料 1. 背景 自然语言处理(Natural Langu ...

  3. PyTorch基础-使用LSTM神经网络实现手写数据集识别-08

    import numpy as np import torch from torch import nn,optim from torch.autograd import Variable from ...

  4. 基于pytorch下用LSTM做股票预测——超详细

    理论 LSTM理论详解 代码 请转到链接:文章详情 另外,欢迎大家打赏!

  5. Pytorch+LSTM+AI自动写诗实战

    文章目录 1.数据集和任务定义 2.读取数据集 3.数据预处理 4.数据制作 5.定义网络结构: 6.测试网络 7.可视化 8.总结 1.数据集和任务定义 本次采用的是唐诗数据集,一共有接近60000 ...

  6. 深度学习框架PyTorch入门与实践:第九章 AI诗人:用RNN写诗

    我们先来看一首诗. 深宫有奇物,璞玉冠何有. 度岁忽如何,遐龄复何欲. 学来玉阶上,仰望金闺籍. 习协万壑间,高高万象逼. 这是一首藏头诗,每句诗的第一个字连起来就是"深度学习". ...

  7. 深度学习(三)之LSTM写诗

    Python微信订餐小程序课程视频 https://edu.csdn.net/course/detail/36074 Python实战量化交易理财系统 https://edu.csdn.net/cou ...

  8. 干货 | 简简单单,用 LSTM 创造一个写诗机器人

    作者 | Carly Stambaugh 来源 | AI 科技评论 人们已经给神经网络找到了越来越多的事情做,比如画画和写诗,微软的小冰都已经出版了一本诗集了.而其实训练一个能写诗的神经网络并不难,A ...

  9. 【深度学习】我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!

    今天我们将使用 Pytorch 来实现 LeNet-5 模型,并用它来解决 MNIST数据集的识别. 正文开始! 一.使用 LeNet-5 网络结构创建 MNIST 手写数字识别分类器 MNIST是一 ...

  10. 我用 PyTorch 复现了 LeNet-5 神经网络(MNIST 手写数据集篇)!

    大家好,我是红色石头! 在上一篇文章: 这可能是神经网络 LeNet-5 最详细的解释了! 详细介绍了卷积神经网络 LeNet-5 的理论部分.今天我们将使用 Pytorch 来实现 LeNet-5 ...

最新文章

  1. 麦子的第一个注解+spring小案例 欢迎指点学习。
  2. 月读の自动读取 根据注释解析列名和字典
  3. Jquery中使用setInterval和setTimeout
  4. 实践1-qq邮箱主页
  5. Linux sed 删除行命令常见使用详解
  6. 763. 划分字母区间009(贪心算法+思路+详解+图示)
  7. LSGO软件技术团队2015~2016学年第十二周(1116~1122)总结
  8. [Python] L1-045 宇宙无敌大招呼-PAT团体程序设计天梯赛GPLT
  9. AngularDart Material Design 选项树
  10. Format - Numeric
  11. ActivityMq下载、安装、使用
  12. tplink迷你路由器中继模式_[转载]TL-WR800N迷你型无线路由器Repeater模式(中继模式)设置教程...
  13. 基于51单片机的教室人数检测
  14. 利用线性回归进行销售预测
  15. 手机停机后你们知道怎么打电话?教你鲜为人知的手机锦囊
  16. uniapp - app 获取短信内容
  17. 开源商城WSTMart支付开发研究[转]
  18. 听歌学日语2 五十音图 たなは行
  19. [USACO2.1] 健康的荷斯坦奶牛 dfs
  20. 计算之魂算法复杂度的相关概念

热门文章

  1. 使用com.alibaba.druid.filter.config.ConfigTools进行加密和解密工具类
  2. 【51单片机】SG90舵机控制
  3. 单片机c语言内部ram移动,2012年微型计算机原理与接口技术自考题模拟(9)
  4. java测量麦克风音量_Android 获取麦克风的音量(分贝)
  5. 【转】PIC单片机入门笔记(新手学PIC必看)——基于PIC16F886
  6. Linux系统搭建房卡游戏教程,2020年H5电玩房卡游戏源码合集运营级:超强后台控制+部署教程文档...
  7. 基于TMC4361-超静音闭环步进电机驱动方案
  8. 如何看懂蓝桥杯单片机(CT107S)原理图
  9. 三极管开关电路_利用三极管设计开关电路
  10. matlab 频域采样定理,信号时域和频域采样函数周期性与原信号的关系