在pytorch下,以数万首唐诗为素材,训练双层LSTM神经网络,使其能够以唐诗的方式写诗。

代码结构分为四部分,分别为

1.model.py,定义了双层LSTM模型

2.data.py,定义了从网上得到的唐诗数据的处理方法

3.utlis.py 定义了损失可视化的函数

4.main.py定义了模型参数,以及训练、唐诗生成函数。

参考:电子工业出版社的《深度学习框架PyTorch:入门与实践》第九章

main代码及注释如下

import sys, os

import torch as t

from data import get_data

from model import PoetryModel

from torch import nn

from torch.autograd import Variable

from utils import Visualizer

import tqdm

from torchnet import meter

import ipdb

class Config(object):

data_path = 'data/'

pickle_path = 'tang.npz'

author = None

constrain = None

category = 'poet.tang' #or poet.song

lr = 1e-3

weight_decay = 1e-4

use_gpu = True

epoch = 20

batch_size = 128

maxlen = 125

plot_every = 20

#use_env = True #是否使用visodm

env = 'poety'

#visdom env

max_gen_len = 200

debug_file = '/tmp/debugp'

model_path = None

prefix_words = '细雨鱼儿出,微风燕子斜。'

#不是诗歌组成部分,是意境

start_words = '闲云潭影日悠悠'

#诗歌开始

acrostic = False

#是否藏头

model_prefix = 'checkpoints/tang'

#模型保存路径

opt = Config()

def generate(model, start_words, ix2word, word2ix, prefix_words=None):

'''

给定几个词,根据这几个词接着生成一首完整的诗歌

'''

results = list(start_words)

start_word_len = len(start_words)

# 手动设置第一个词为

# 这个地方有问题,最后需要再看一下

input = Variable(t.Tensor([word2ix['']]).view(1,1).long())

if opt.use_gpu:input=input.cuda()

hidden = None

if prefix_words:

for word in prefix_words:

output,hidden = model(input,hidden)

# 下边这句话是为了把input变成1*1?

input = Variable(input.data.new([word2ix[word]])).view(1,1)

for i in range(opt.max_gen_len):

output,hidden = model(input,hidden)

if i

w = results[i]

input = Variable(input.data.new([word2ix[w]])).view(1,1)

else:

top_index = output.data[0].topk(1)[1][0]

w = ix2word[top_index]

results.append(w)

input = Variable(input.data.new([top_index])).view(1,1)

if w=='':

del results[-1] #-1的意思是倒数第一个

break

return results

def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):

'''

生成藏头诗

start_words : u'深度学习'

生成:

深木通中岳,青苔半日脂。

度山分地险,逆浪到南巴。

学道兵犹毒,当时燕不移。

习根通古岸,开镜出清羸。

'''

results = []

start_word_len = len(start_words)

input = Variable(t.Tensor([word2ix['']]).view(1,1).long())

if opt.use_gpu:input=input.cuda()

hidden = None

index=0 # 用来指示已经生成了多少句藏头诗

# 上一个词

pre_word=''

if prefix_words:

for word in prefix_words:

output,hidden = model(input,hidden)

input = Variable(input.data.new([word2ix[word]])).view(1,1)

for i in range(opt.max_gen_len):

output,hidden = model(input,hidden)

top_index = output.data[0].topk(1)[1][0]

w = ix2word[top_index]

if (pre_word in {u'。',u'!',''} ):

# 如果遇到句号,藏头的词送进去生成

if index==start_word_len:

# 如果生成的诗歌已经包含全部藏头的词,则结束

break

else:

# 把藏头的词作为输入送入模型

w = start_words[index]

index+=1

input = Variable(input.data.new([word2ix[w]])).view(1,1)

else:

# 否则的话,把上一次预测是词作为下一个词输入

input = Variable(input.data.new([word2ix[w]])).view(1,1)

results.append(w)

pre_word = w

return results

def train(**kwargs):

for k,v in kwargs.items():

setattr(opt,k,v) #设置apt里属性的值

vis = Visualizer(env=opt.env)

#获取数据

data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函数

data = t.from_numpy(data)

#这个地方出错了,是大写的L

dataloader = t.utils.data.DataLoader(data,

batch_size = opt.batch_size,

shuffle = True,

num_workers = 1) #在python里,这样写程序可以吗?

#模型定义

model = PoetryModel(len(word2ix), 128, 256)

optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)

criterion = nn.CrossEntropyLoss()

if opt.model_path:

model.load_state_dict(t.load(opt.model_path))

if opt.use_gpu:

model.cuda()

criterion.cuda()

#The tnt.AverageValueMeter measures and returns the average value

#and the standard deviation of any collection of numbers that are

#added to it. It is useful, for instance, to measure the average

#loss over a collection of examples.

#The add() function expects as input a Lua number value, which

#is the value that needs to be added to the list of values to

#average. It also takes as input an optional parameter n that

#assigns a weight to value in the average, in order to facilitate

#computing weighted averages (default = 1).

#The tnt.AverageValueMeter has no parameters to be set at initialization time.

loss_meter = meter.AverageValueMeter()

for epoch in range(opt.epoch):

loss_meter.reset()

for ii,data_ in tqdm.tqdm(enumerate(dataloader)):

#tqdm是python中的进度条

#训练

data_ = data_.long().transpose(1,0).contiguous()

#上边一句话,把data_变成long类型,把1维和0维转置,把内存调成连续的

if opt.use_gpu: data_ = data_.cuda()

optimizer.zero_grad()

input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])

#上边一句,将输入的诗句错开一个字,形成训练和目标

output,_ = model(input_)

loss = criterion(output, target.view(-1))

loss.backward()

optimizer.step()

loss_meter.add(loss.data[0]) #为什么是data[0]?

#可视化用到的是utlis.py里的函数

if (1+ii)%opt.plot_every ==0:

if os.path.exists(opt.debug_file):

ipdb.set_trace()

vis.plot('loss',loss_meter.value()[0])

# 下面是对目前模型情况的测试,诗歌原文

poetrys = [[ix2word[_word] for _word in data_[:,_iii]]

for _iii in range(data_.size(1))][:16]

#上面句子嵌套了两个循环,主要是将诗歌索引的前十六个字变成原文

vis.text(''.join([''.join(poetry) for poetry in

poetrys]),win = u'origin_poem')

gen_poetries = []

#分别以以下几个字作为诗歌的第一个字,生成8首诗

for word in list(u'春江花月夜凉如水'):

gen_poetry = ''.join(generate(model,word,ix2word,word2ix))

gen_poetries.append(gen_poetry)

vis.text(''.join([''.join(poetry) for poetry in

gen_poetries]), win = u'gen_poem')

t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))

def gen(**kwargs):

'''

提供命令行接口,用以生成相应的诗

'''

for k,v in kwargs.items():

setattr(opt,k,v)

data, word2ix, ix2word = get_data(opt)

model = PoetryModel(len(word2ix), 128, 256)

map_location = lambda s,l:s

# 上边句子里的map_location是在load里用的,用以加载到指定的CPU或GPU,

# 上边句子的意思是将模型加载到默认的GPU上

state_dict = t.load(opt.model_path, map_location = map_location)

model.load_state_dict(state_dict)

if opt.use_gpu:

model.cuda()

if sys.version_info.major == 3:

if opt.start_words.insprintable():

start_words = opt.start_words

prefix_words = opt.prefix_words if opt.prefix_words else None

else:

start_words = opt.start_words.encode('ascii',\

'surrogateescape').decode('utf8')

prefix_words = opt.prefix_words.encode('ascii',\

'surrogateescape').decode('utf8') if opt.prefix_words else None

start_words = start_words.replace(',',u',')\

.replace('.',u'。')\

.replace('?',u'?')

gen_poetry = gen_acrostic if opt.acrostic else generate

result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)

print(''.join(result))

if __name__ == '__main__':

import fire

fire.Fire()

以上代码给我一些经验,

1. 了解python的编程方式,如空格、换行等;进一步了解python的各个基本模块;

2. 可能出的错误:函数名写错,大小写,变量名写错,括号不全。

3. 对cuda()的用法有了进一步认识;

4. 学会了调试程序(fire);

5. 学会了训练结果的可视化(visdom);

6. 进一步的了解了LSTM,对深度学习的架构、实现有了宏观把控。

这篇pytorch下使用LSTM神经网络写诗实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持我们。

本文标题: pytorch下使用LSTM神经网络写诗实例

本文地址: http://www.cppcns.com/jiaoben/python/298353.html

python程序写诗_pytorch下使用LSTM神经网络写诗实例相关推荐

  1. python程序导入import、规范化和封装自己写的.py文件

    目录 1. 简单地导入自己写的.py文件 2. 将自己写的多个.py文件规范化成外部类,并创建__init__.py 3. 将自己的程序封装成外部包 1. 简单地导入自己写的.py文件 将a.py与b ...

  2. python程序代码是什么_python代码用什么写

    展开全部 对于新手而言,学了大概的语62616964757a686964616fe59b9ee7ad9431333433653336法,七七八八的历史.概念.知识,然而最直接的一个问题却无人解答:到底 ...

  3. python程序怎么运行通讯录管理中心_蛋疼写的python 通讯录管理

    我以后啊,再也不敢把文本界面的程序模仿成图形界面的了,因为需要处理的东西太多了 ,所以我以后写的程序,尽量把菜单写得简单一点,至于图形界面,就留给别人了 现在这个程序的菜单还是有bug 大家凑合着用吧 ...

  4. python程序开机自启动_Linux下Python脚本自启动和定时启动的详细步骤

    一.Python开机自动运行 假如Python自启动脚本为 auto.py .那么用root权限编辑以下文件: sudo vim /etc/rc.local 如果没有 rc.local 请看 这篇文章 ...

  5. vs python生成exe文件_使用VScode编写python程序并打包成.exe文件-文件夹变成exe

    1. 下载vscode并安装 2. 配置Python环境 点击左下角的吃了图标,在弹出的菜单中选择extensions,在左上方搜索框内输入"Python",可以看到好多Pytho ...

  6. Python(2)-第一个python程序、执行python程序三种方式

    第一个Python 程序 1. 第一个Python 程序 2. 常用两Python个版本 3. 程序执行的三种方式 3.1 解释器 3.2 交互式运行Python程序 3.3 IDE(集成开发环境)- ...

  7. python 程序打包 vscode_使用VScode编写python程序并打包成.exe文件

    听说Visual Studio Code(VS Code)的诸多好处,了解了一下果真很喜欢,我喜欢它的缘由主要有3个,一是VS Code开源且跨平台,二是由于其界面很是酷,三是能够知足个人大所属代码需 ...

  8. bat文件指定jdk路径_配置点击就能运行Python程序的bat批处理脚本(Windows)

    0,需求说明 在编写和调试程序时,一般我们会在集成编辑环境里写代码和运行,但如果程序比较完善需要快速运行,或者让同事在其他电脑上快速运行时,再打开IDE(Integrated Development ...

  9. 从零开始开发Python程序(四)—— 抓取每日早报新闻

    这是一片系列文章,最好先看看上一篇 从零开始开发Python程序(三)-- 用文本编辑器来写代码 目录 一.需求说明 二.程序设计 三.从指定网站获取新闻 1.找一个提供新闻的网站 2. 库的安装 3 ...

最新文章

  1. 当前工作目录Python
  2. 【MySQL】4、Select查询语句
  3. SQL2005的配置
  4. Delphi 7学习开发控件
  5. html编辑四则运算,前端四则运算验证
  6. winform 调用外部程序和多线程
  7. linux软件包管理 pdf,中标麒麟Linux系统软件包管理介绍(22页)-原创力文档
  8. APP自动化测试系列之Appium环境安装
  9. Java中常见RuntimeException与其他异常表及Exception逻辑关系详解
  10. dell 虚拟linux,戴尔:Linux是轻松构建虚拟主机的关键
  11. java exception 包_什么是Java中的异常包装?
  12. PHP实现当前文件夹下所有文件和文件夹的遍历
  13. js控制layui radio button选中
  14. java菜鸟疑问1:为什么我的代码总出现cannot be resolved or is not a field这种问题
  15. 简单使用Jconsole
  16. matlab求样本相关系数,matlab中样本相关系数的计算与测试
  17. Uniapp苹果登录sign in Apple
  18. esxtop 指标%RDY,NUMA,Wide-VMs
  19. 计算机专业去,计算机专业去哪个学校_西信院
  20. 计算机 打印格式 不一样,同一个EXCEL表格在不同电脑打印出来的不一样,如何解决?:excel表格自动变形...

热门文章

  1. 利用k-means算法解决简单的无监督图像识别任务
  2. 香农编码的 matlab 实现
  3. DFS(深度优先算法)难
  4. docker安装常用组件(mysql,redis,postgres,rancher,Portainer,蝉道,JIRA,sonarqube,Confluence,pgadmin4,harbor)
  5. python机器学习案例系列教程——优化,寻找使成本函数最小的最优解
  6. LaTex warning:Font shape `TU/ppl/bx/n' undefined(Font) using `TU/ppl/m/n' instead
  7. 贺利坚老师汇编课程56笔记:CMP指令
  8. 菜鸟学Struts2——Interceptors
  9. C# File流操作
  10. vimnbsp;自动识别UTF8和GB2312