前言:
之前的博文部分地讲解了RNN的标准结构,也用pytorch的RNNCell类和RNN类实现了前向传播的计算,这里我们再举一个例子,做一个特别简单特别简单特别简单特别简单的翻译器,目标如下:
将英文hello and thank you翻译成汉语拼音ni hao qie duo xie ni

篇幅有限,我们拿这五个数据练手吧。
如下两个例子也是为了让我们学习一下RNN相关维度的概念。

一:使用RNNCell类实现简单的字符串翻译
这里我们需要用到one-hot vector技术,也就是我们建立一个字符表,包含空格字符和26个小写英文字符,也就是一共有27个字符的字典表,空格符号标号是0,az编号是126,我们定义有22个序列的输入(也就是最多一次能接受22个字符的输入),每个输入的x 也是27维度的向量,输出层也是27维度的,也就是多个分类的问题,用RNN实现。代码如下:

import torch
import numpy as np
from torch.autograd import Variablebatch_size = 1      # 批量数据的大小,数据的批次,比如这里有5个需要测试的单词。
seq_len = 22    # 一共是多少个时间序列,是指时间的维度,在这里的意思就是一次性最大能接受20个字符的输入。
input_size = 27     # 某一时刻下,输入的X的维度向量,比如这里的4指的是维度是[4, 1]
hidden_size = 27    # 每一层的激活后 h 的维度,比如这里的2指的是维度是[2, 1]# num_layers = 1# a~z标号是1~26, 空格是1,这个字典表,每个字符有个对应的编号,也就是下标就是编号
idx2char = [' ', # 0'a', 'b', 'c', 'd', 'e', 'f', 'g', # 1-7'h', 'i', 'j', 'k', 'l', 'm', 'n', # 8-14'o', 'p', 'q', 'r', 's', 't', # 15-20'u', 'v', 'w', 'x', 'y', 'z'] # 21-26
idx2char = np.array(idx2char)
# 每个编号有个对应的向量,也就是下标对应的向量
one_hot_look = np.array(np.eye(idx2char.shape[0])).astype(int)
print('idx2char.shape[0]=', idx2char.shape[0])x_data = np.zeros((batch_size, seq_len)).astype(int)
y_data = np.zeros((batch_size, seq_len)).astype(int)# 待翻译的句子
x_data_char = ['hello and thank you']
y_data_char = ['ni hao qie duo xie you']# 填充每个字符的下标
i=0
for item in x_data_char:# fill each dataj=0for letter in item:x_data[i][j] = np.where(idx2char == letter)[0][0]j = j + 1i = i + 1i=0
for item in y_data_char:# fill each dataj=0for letter in item:y_data[i][j] = np.where(idx2char == letter)[0][0]j = j + 1i = i + 1#print(x_data)
#print(y_data)# Step 2:============================定义一个RNNCell的模型===================
class RNNCellModel(torch.nn.Module):def __init__(self, input_size, hidden_size):super(RNNCellModel, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.rnncell = torch.nn.RNNCell(input_size=self.input_size, hidden_size=self.hidden_size)def forward(self, input, hidden):hidden = self.rnncell(input, hidden)return hiddenmodel = RNNCellModel(input_size, hidden_size)# Step 3:============================定义损失函数和优化器===================
# 定义 loss 函数,这里用的是交叉熵损失函数(Cross Entropy),这种损失函数之前博文也讲过的。
criterion = torch.nn.CrossEntropyLoss()
# 我们优先使用Adam下降,lr是学习率: 0.1
optimizer = torch.optim.Adam(model.parameters(), 1e-1)# Step 4:============================开始训练===================
for e in range(200):loss = 0optimizer.zero_grad()hidden = torch.zeros(batch_size, hidden_size)  # 随机初始化的h0,且必须是Tensor类型的i = 0print('e============================')print('predicted str: ', end='')for cur_seq in zip(*x_data):   # for each seqcur_seq = np.array(cur_seq)#input 必须是Tensor类型的input = torch.Tensor([one_hot_look[x] for x in cur_seq])# print(input.shape)# y_label 必须是长整数型,且必须是Tensor类型的y_label = torch.LongTensor(np.array(y_data[:, i]))#print('real y_label: ', idx2char[y_label.item()], end='')# 前向传播hidden = model(input, hidden)_, idx = hidden.max(dim=1)print(idx2char[idx.item()], end='')# print(hidden)# 累加损失loss += criterion(hidden, y_label)i = i+1loss.backward()optimizer.step()print(',epoch [%d/200] loss=%.4f' % (e+1, loss.item()), end='')

输出结果如下:训练的字符串是hello and thank you

这个模型基本就差不多了训练出来是 ni hao qie duo xie ni

二:使用RNN类实现简单的字符串翻译
或者还可以用RNN模型来做,这样省去了自己写循环迭代序列的过程
,这里需要注意模型的建立需要传入层数,以及需要传递batch_size,因为hidden节点需要这个参数。

import torch
import numpy as np
from torch.autograd import Variablebatch_size = 1      # 批量数据的大小,数据的批次,比如这里有5个需要测试的单词。
seq_len = 22    # 一共是多少个时间序列,是指时间的维度,在这里的意思就是一次性最大能接受22个字符的输入。
input_size = 27     # 某一时刻下,输入的X的维度向量,比如这里的4指的是维度是[4, 1]
hidden_size = 27    # 每一层的激活后 h 的维度,比如这里的2指的是维度是[2, 1]
num_layers = 3      # 共有三层RNN# a~z标号是1~26, 空格是1,这个字典表,每个字符有个对应的编号,也就是下标就是编号
idx2char = [' ', # 0'a', 'b', 'c', 'd', 'e', 'f', 'g', # 1-7'h', 'i', 'j', 'k', 'l', 'm', 'n', # 8-14'o', 'p', 'q', 'r', 's', 't', # 15-20'u', 'v', 'w', 'x', 'y', 'z'] # 21-26
idx2char = np.array(idx2char)
# 每个编号有个对应的向量,也就是下标对应的向量
one_hot_look = np.array(np.eye(idx2char.shape[0])).astype(int)
print('idx2char.shape[0]=', idx2char.shape[0])x_data = np.zeros((batch_size, seq_len)).astype(int)
y_data = np.zeros((batch_size, seq_len)).astype(int)# 待翻译的句子
x_data_char = ['hello and thank you']
y_data_char = ['ni hao qie duo xie you']# 填充每个字符的下标
i=0
for item in x_data_char:# fill each dataj=0for letter in item:x_data[i][j] = np.where(idx2char == letter)[0][0]j = j + 1i = i + 1i=0
for item in y_data_char:# fill each dataj=0for letter in item:y_data[i][j] = np.where(idx2char == letter)[0][0]j = j + 1i = i + 1#print(x_data)
#print(y_data)
x_one_hot = [one_hot_look[x] for x in x_data]
inputs = torch.Tensor(x_one_hot).view(seq_len, batch_size, input_size)
lables = torch.LongTensor(y_data).view(-1)# Step 2:============================定义一个RNNCell的模型===================
class RNNModel(torch.nn.Module):def __init__(self, input_size, hidden_size, batch_size, num_layers):super(RNNModel, self).__init__()self.input_size = input_sizeself.hidden_size = hidden_sizeself.batch_size = batch_sizeself.num_layers = num_layersself.rnn = torch.nn.RNN(input_size=self.input_size, hidden_size=self.hidden_size,num_layers=num_layers)def forward(self, input):# 由于不需要自己写序列的迭代,因此直接将随机初始化h0写道这里hidden = torch.zeros(num_layers, batch_size, hidden_size)  # 随机初始化的h0,且必须是Tensor类型的# 我们只关心output: [seq_len, batch_size, hidden_size]# input是: [seq_len, batch_size, input_size]output, _ = self.rnn(input, hidden)# 将out reshape到[seq_len * batch_size, hidden_size]return output.view(-1, self.hidden_size)model = RNNModel(input_size, hidden_size, batch_size, num_layers)# Step 3:============================定义损失函数和优化器===================
# 定义 loss 函数,这里用的是交叉熵损失函数(Cross Entropy),这种损失函数之前博文也讲过的。
criterion = torch.nn.CrossEntropyLoss()
# 我们优先使用Adam下降,lr是学习率: 0.1
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)# Step 4:============================开始训练===================
for e in range(200):loss = 0optimizer.zero_grad()# 前向传播hidden = model(inputs)# 计算损失loss = criterion(hidden, lables)# 反向传播,更新参数loss.backward()optimizer.step()_, idx = hidden.max(dim=1)idx = idx.data.numpy()print('e============================')print('predicted str: ', end='')print(''.join([idx2char[x] for x in idx]), end='')print(',epoch [%d/200] loss=%.4f' % (e+1, loss.item()), end='')

效果如下:很明显,损失好像更低了些。

《Pytorch - RNN模型》相关推荐

  1. ComeFuture英伽学院——2020年 全国大学生英语竞赛【C类初赛真题解析】(持续更新)

    视频:ComeFuture英伽学院--2019年 全国大学生英语竞赛[C类初赛真题解析]大小作文--详细解析 课件:[课件]2019年大学生英语竞赛C类初赛.pdf 视频:2020年全国大学生英语竞赛 ...

  2. ComeFuture英伽学院——2019年 全国大学生英语竞赛【C类初赛真题解析】大小作文——详细解析

    视频:ComeFuture英伽学院--2019年 全国大学生英语竞赛[C类初赛真题解析]大小作文--详细解析 课件:[课件]2019年大学生英语竞赛C类初赛.pdf 视频:2020年全国大学生英语竞赛 ...

  3. 信息学奥赛真题解析(玩具谜题)

    玩具谜题(2016年信息学奥赛提高组真题) 题目描述 小南有一套可爱的玩具小人, 它们各有不同的职业.有一天, 这些玩具小人把小南的眼镜藏了起来.小南发现玩具小人们围成了一个圈,它们有的面朝圈内,有的 ...

  4. 信息学奥赛之初赛 第1轮 讲解(01-08课)

    信息学奥赛之初赛讲解 01 计算机概述 系统基本结构 信息学奥赛之初赛讲解 01 计算机概述 系统基本结构_哔哩哔哩_bilibili 信息学奥赛之初赛讲解 02 软件系统 计算机语言 进制转换 信息 ...

  5. 信息学奥赛一本通习题答案(五)

    最近在给小学生做C++的入门培训,用的教程是信息学奥赛一本通,刷题网址 http://ybt.ssoier.cn:8088/index.php 现将部分习题的答案放在博客上,希望能给其他有需要的人带来 ...

  6. 信息学奥赛一本通习题答案(三)

    最近在给小学生做C++的入门培训,用的教程是信息学奥赛一本通,刷题网址 http://ybt.ssoier.cn:8088/index.php 现将部分习题的答案放在博客上,希望能给其他有需要的人带来 ...

  7. 信息学奥赛一本通 提高篇 第六部分 数学基础 相关的真题

    第1章   快速幂 1875:[13NOIP提高组]转圈游戏 信息学奥赛一本通(C++版)在线评测系统 第2 章  素数 第 3 章  约数 第 4 章  同余问题 第 5 章  矩阵乘法 第 6 章 ...

  8. 信息学奥赛一本通题目代码(非题库)

    为了完善自己学c++,很多人都去读相关文献,就比如<信息学奥赛一本通>,可又对题目无从下手,从今天开始,我将把书上的题目一 一的解析下来,可以做参考,如果有错,可以告诉我,将在下次解析里重 ...

  9. 信息学奥赛一本通(C++版) 刷题 记录

    总目录详见:https://blog.csdn.net/mrcrack/article/details/86501716 信息学奥赛一本通(C++版) 刷题 记录 http://ybt.ssoier. ...

  10. 最近公共祖先三种算法详解 + 模板题 建议新手收藏 例题: 信息学奥赛一本通 祖孙询问 距离

    首先什么是最近公共祖先?? 如图:红色节点的祖先为红色的1, 2, 3. 绿色节点的祖先为绿色的1, 2, 3, 4. 他们的最近公共祖先即他们最先相交的地方,如在上图中黄色的点就是他们的最近公共祖先 ...

最新文章

  1. AI 经典书单 | 人工智能学习该读哪些书
  2. C++之stdafx.h的用法说明
  3. HashMap与垃圾回收
  4. UITabBarController 基本用法
  5. 华为鸿蒙vogtloop30pro价格,华为Mate30系列基本确认:首发麒麟985+鸿蒙系统,价格感人!...
  6. 8086指令系统 操作数地址,双操作数,单操作数,无操作数指令。一,传送类指令;二,二、算数运算类指令
  7. 从零开始学习mitmproxy源码阅读
  8. Django构建简介
  9. 通过PXE网络启动WinPE的方法
  10. 用数据分析头部微信公众号到底有多牛
  11. 2018蓝桥杯第几个幸运数(C语言)
  12. python将多个表的数据合并到一个表
  13. 天融信数通小百科:无线AP的Soul mate—POE交换机
  14. 共享租车平台“车便利租车”完成A轮融资
  15. mysql密码为空包密码错误_apk空包签名方法及工具
  16. React.createClass()方法
  17. 网站故障和安全事件的应急预案
  18. 鼠标悬浮移出控制div的显示与隐藏
  19. 【Spring Cloud】OpenFeign和Spring Cloud Loadbalancer调用失败后的重试机制比较
  20. App打造自定义的统计SDK

热门文章

  1. 2017 最值得关注的十大 APP、Web 界面设计趋势
  2. 基于HTML5 的人脸识别技术
  3. SSH集成log4j日志环境
  4. 2015第25周三iframe小结
  5. Custom Looks using Qt Style Sheets
  6. Web容器和Servlet生命周期
  7. sql交叉报表实例(转)
  8. 内网穿透工具 Ngrok
  9. NET Core 3.0 项目中使用 AutoFac
  10. 粗谈MySQL事务的特性和隔离级别