pytorch学习笔记(二十九):简洁实现循环神经网络
本节将使用PyTorch来更简洁地实现基于循环神经网络的语言模型。首先,我们读取周杰伦专辑歌词数据集。
import time
import math
import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as Fimport sys
sys.path.append("..")
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()
1. 定义模型
PyTorch中的nn
模块提供了循环神经网络的实现。下面构造一个含单隐藏层、隐藏单元个数为256的循环神经网络层rnn_layer
。
num_hiddens = 256
# rnn_layer = nn.LSTM(input_size=vocab_size, hidden_size=num_hiddens) # 已测试
rnn_layer = nn.RNN(input_size=vocab_size, hidden_size=num_hiddens)
与上一节中实现的循环神经网络不同,这里rnn_layer
的输入形状为(时间步数, 批量大小, 输入个数)。其中输入个数即one-hot向量长度(词典大小)。此外,rnn_layer
作为nn.RNN
实例,在前向计算后会分别返回输出和隐藏状态h,其中输出指的是隐藏层在各个时间步上计算并输出的隐藏状态,它们通常作为后续输出层的输入。需要强调的是,该“输出”本身并不涉及输出层计算,形状为(时间步数, 批量大小, 隐藏单元个数)。而nn.RNN
实例在前向计算返回的隐藏状态指的是隐藏层在最后时间步的隐藏状态:当隐藏层有多层时,每一层的隐藏状态都会记录在该变量中。
来看看我们的例子,输出形状为(时间步数, 批量大小, 隐藏单元个数),隐藏状态h的形状为(层数, 批量大小, 隐藏单元个数)。
num_steps = 35
batch_size = 2
state = None
X = torch.rand(num_steps, batch_size, vocab_size)
Y, state_new = rnn_layer(X, state)
print(Y.shape, len(state_new), state_new[0].shape)
输出:
torch.Size([35, 2, 256]) 1 torch.Size([2, 256])
接下来我们继承Module
类来定义一个完整的循环神经网络。它首先将输入数据使用one-hot向量表示后输入到rnn_layer
中,然后使用全连接输出层得到输出。输出个数等于词典大小vocab_size
。
# 本类已保存在d2lzh_pytorch包中方便以后使用
class RNNModel(nn.Module):def __init__(self, rnn_layer, vocab_size):super(RNNModel, self).__init__()self.rnn = rnn_layerself.hidden_size = rnn_layer.hidden_size * (2 if rnn_layer.bidirectional else 1) self.vocab_size = vocab_sizeself.dense = nn.Linear(self.hidden_size, vocab_size)self.state = Nonedef forward(self, inputs, state): # inputs: (batch, seq_len)# 获取one-hot向量表示X = d2l.to_onehot(inputs, self.vocab_size) # X是个listY, self.state = self.rnn(torch.stack(X), state)# 全连接层会首先将Y的形状变成(num_steps * batch_size, num_hiddens),它的输出# 形状为(num_steps * batch_size, vocab_size)output = self.dense(Y.view(-1, Y.shape[-1]))return output, self.state
2. 训练模型
同上一节一样,下面定义一个预测函数。这里的实现区别在于前向计算和初始化隐藏状态的函数接口。
# 本函数已保存在d2lzh_pytorch包中方便以后使用
def predict_rnn_pytorch(prefix, num_chars, model, vocab_size, device, idx_to_char,char_to_idx):state = Noneoutput = [char_to_idx[prefix[0]]] # output会记录prefix加上输出for t in range(num_chars + len(prefix) - 1):X = torch.tensor([output[-1]], device=device).view(1, 1)if state is not None:if isinstance(state, tuple): # LSTM, state:(h, c) state = (state[0].to(device), state[1].to(device))else: state = state.to(device)(Y, state) = model(X, state)if t < len(prefix) - 1:output.append(char_to_idx[prefix[t + 1]])else:output.append(int(Y.argmax(dim=1).item()))return ''.join([idx_to_char[i] for i in output])
让我们使用权重为随机值的模型来预测一次。
model = RNNModel(rnn_layer, vocab_size).to(device)
predict_rnn_pytorch('分开', 10, model, vocab_size, device, idx_to_char, char_to_idx)
输出:
'分开戏想暖迎凉想征凉征征'
接下来实现训练函数。算法同上一节的一样,但这里只使用了相邻采样来读取数据。
# 本函数已保存在d2lzh_pytorch包中方便以后使用
def train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes):loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=lr)model.to(device)state = Nonefor epoch in range(num_epochs):l_sum, n, start = 0.0, 0, time.time()data_iter = d2l.data_iter_consecutive(corpus_indices, batch_size, num_steps, device) # 相邻采样for X, Y in data_iter:if state is not None:# 使用detach函数从计算图分离隐藏状态, 这是为了# 使模型参数的梯度计算只依赖一次迭代读取的小批量序列(防止梯度计算开销太大)if isinstance (state, tuple): # LSTM, state:(h, c) state = (state[0].detach(), state[1].detach())else: state = state.detach()(output, state) = model(X, state) # output: 形状为(num_steps * batch_size, vocab_size)# Y的形状是(batch_size, num_steps),转置后再变成长度为# batch * num_steps 的向量,这样跟输出的行一一对应y = torch.transpose(Y, 0, 1).contiguous().view(-1)l = loss(output, y.long())optimizer.zero_grad()l.backward()# 梯度裁剪d2l.grad_clipping(model.parameters(), clipping_theta, device)optimizer.step()l_sum += l.item() * y.shape[0]n += y.shape[0]try:perplexity = math.exp(l_sum / n)except OverflowError:perplexity = float('inf')if (epoch + 1) % pred_period == 0:print('epoch %d, perplexity %f, time %.2f sec' % (epoch + 1, perplexity, time.time() - start))for prefix in prefixes:print(' -', predict_rnn_pytorch(prefix, pred_len, model, vocab_size, device, idx_to_char,char_to_idx))
使用和上一节实验中一样的超参数(除了学习率)来训练模型。
num_epochs, batch_size, lr, clipping_theta = 250, 32, 1e-3, 1e-2 # 注意这里的学习率设置
pred_period, pred_len, prefixes = 50, 50, ['分开', '不分开']
train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes)
输出:
epoch 50, perplexity 1.012515, time 0.16 sec- 分开 我去定不 却已在不要再 我爱你看着我已经 我跟 爸 爱是我的快像大只 妈妈整全画怕跟 没有你- 不分开 我有多年样 后不知不觉 你已经很久 不能就能知不觉 你已经离开我 不知不觉 我跟了这节奏 后知
epoch 100, perplexity 1.009353, time 0.15 sec- 分开 我在定的让我有多知道 有一个人已什么会痛 有教堂有你烦故事 有教堂有城堡 每天忙碌地的寻找 到- 不分开不了口多著一口默默默 我妈 我不要再想 我爱 你的微我有多这样 生妈 这爱你 后悔着对不起 一枚铜
epoch 150, perplexity 1.008741, time 0.16 sec- 分开 我在定的可以让我知道你说你只会我依不知错搞错 拜托 我想是你的脑袋有问题 随便说说 其实我早已经猜- 不分开不了口不著 纪录那最心始的美丽 纪录第一次遇见的你 如果我遇见你是一场悲剧 我想我这辈子注定一个人演
epoch 200, perplexity 1.007901, time 0.16 sec- 分开 我有人看着你 手过去对 全来像伊斯坦堡 你说你说啊不会不要 这样打我妈妈 我想的声宣布 对你依依- 不分开不想就多著 心碎起 娘子 从前 教育别人的家庭 别人的爸爸种种的暴力因素一定都会有原因 但是呢
epoch 250, perplexity 1.007771, time 0.16 sec- 分开 我时空被你在的对 泪拆封动 所有回忆对着我进攻 我 我不人 漂亮心让我面红的可爱女人 温柔- 不分开不想就多难熬 穿过 在因为闷了很久 是因为想说太多 是心理起了作用 你的的苦笑常常陪着你 在一起有
小结
- PyTorch的
nn
模块提供了循环神经网络层的实现。 - PyTorch的
nn.RNN
实例在前向计算后会分别返回输出和隐藏状态。该前向计算并不涉及输出层计算。
pytorch学习笔记(二十九):简洁实现循环神经网络相关推荐
- pytorch学习笔记(十九):二维卷积层
文章目录 1. 二维互相关运算 2. 二维卷积层 3. 图像中物体边缘检测 4. 通过数据学习核数组 卷积神经网络(convolutional neural network)是含有卷积层(convol ...
- C++语法学习笔记二十九: 详解decltype含义,decltype主要用途
实例代码 // 详解decltype含义,decltype主要用途#include <iostream> #include <functional> #include < ...
- opencv学习笔记二十九:SIFT特征点检测与匹配
SIFT(Scale-invariant feature transform)是一种检测局部特征的算法,该算法通过求一幅图中的特征点(interest points,or corner points) ...
- Mr.J-- jQuery学习笔记(二十九)--属性操作方法(获取属性判断)
获取 attr() <span class="span1" name="it666"></span> <span class=&q ...
- Linux学习笔记二十九——http服务
基础概念: HTTP:Hyper Text Transfer Protocol 超文本传输协议 versions: HTTP/0.9:只接收GET一种请求方法,只支持纯文本 HTTP/1.0:支持PU ...
- 深度学习入门笔记(十九):卷积神经网络(二)
欢迎关注WX公众号:[程序员管小亮] 专栏--深度学习入门笔记 声明 1)该文章整理自网上的大牛和机器学习专家无私奉献的资料,具体引用的资料请看参考文献. 2)本文仅供学术交流,非商用.所以每一部分具 ...
- python数据挖掘学习笔记】十九.鸢尾花数据集可视化、线性回归、决策树花样分析
#2018-04-05 16:57:26 April Thursday the 14 week, the 095 day SZ SSMR python数据挖掘学习笔记]十九.鸢尾花数据集可视化.线性回 ...
- Python学习笔记(十九)面向对象 - 继承
Python学习笔记(十九)面向对象 - 继承 一.继承的概念 # 继承:子类继承父类的所有方法和属性# 1. 子类 class A(object):def __init__(self):self.n ...
- PyTorch学习笔记(二)——回归
PyTorch学习笔记(二)--回归 本文主要是用PyTorch来实现一个简单的回归任务. 编辑器:spyder 1.引入相应的包及生成伪数据 import torch import torch.nn ...
- Mr.J-- jQuery学习笔记(二十八)--DOM操作方法(添加方法总结)
Table of Contents appendTo appendTo(source, target) 源代码 append prependTo prependTo源码 prepend ...
最新文章
- 【iOS】iOS10.3新增API:应用内评分
- 阿里面试官:给我说说Netty是如何在Dubbo中应用的?
- 张一鸣:10年面试2000人,我发现混的好的人,全都有同一个特质!
- 全栈技术实践经历告诉你:开发一个商城小程序要多少钱?
- React Native学习(七)—— FlatList实现横向滑动列表效果
- 什么是RPC?RPC框架dubbo的核心流程
- 谈谈设计模式的几个原则
- bzoj2561 最小生成树
- php读取excel的数据,php读取excel文件数据
- xpath选择当前结点的子节点
- 计算机应用与基础进制,计算机应用基础选择题-删进制、字符题(2).docx
- 面对1.3 亿用户数据泄露,企业如何围绕核心数据构建安全管理体系?
- php 开发环境配置,开发环境配置
- 02_HBase集群部署
- Layui 数据表格table 重载reload 保留上次where条件的问题
- 电子罗盘简单介绍和应用
- --随笔--带你轻松理解TCP中的三次握手
- 浅析企业应收账款保理融资
- 拇指玩」制作的「谷歌安装器」app
- 微信公众号url接口配置,使微信公众号更多功能化(python简单解决)
热门文章
- UITableViewStyleGrouped顶部留白问题
- UDK游戏开发基础命令
- VoltDB开篇 简介
- rz/sz:工作原理
- SpringBoot项目报错Cannot determine embedded database driver class for database type NONE
- Unity 3D 中NGUI插件设置中文label
- Yandex.Algorithm 2011 Round 2 D. Powerful array 莫队算法
- OpenShift:外国的免费云平台
- tp中自定义跳转页面
- 表视图(UITableView)与表视图控制器(UITableViewController)