参考这篇文章LSTM唐诗生成器Keras版

将相关的 keras 模型代码进行修改,改成对应的 pytorch 模型,现将有区别的部分放在这里。

训练模型

搭建网络

# 把keras 模型改成 pytorch 模型
# 建立LSTM模型
import torch
import torch.nn as nn
import torch.nn.functional as F# 设置 CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = Sequential()
# model.add(Embedding(10000, 128, input_length=20))
# model.add(LSTM(128, return_sequences=True))
# model.add(Dropout(0.2))
# model.add(LSTM(128))
# model.add(Dropout(0.2))
# model.add(Dense(10000, activation='softmax'))# 参考上述的 keras 模型,建立 pytorch 模型# 第二层 LSTM 只取最后一个输出,所以 return_sequences=Falseclass LSTMNet(nn.Module):def __init__(self):super(LSTMNet, self).__init__()self.embedding = nn.Embedding(10000, 128)self.lstm1 = nn.LSTM(input_size=128, hidden_size=128, num_layers=1, batch_first=True)self.dropout1 = nn.Dropout(0.2)self.lstm2 = nn.LSTM(input_size=128, hidden_size=128, num_layers=1, batch_first=True)self.dropout2 = nn.Dropout(0.2)self.fc = nn.Linear(128, 10000)def forward(self, x):x = self.embedding(x) # [batch_size, seq_len, embedding_size]x, _ = self.lstm1(x)  # [batch_size, seq_len, hidden_size]x = self.dropout1(x)  # [batch_size, seq_len, hidden_size]x, _ = self.lstm2(x)  # [batch_size, seq_len, hidden_size]x = self.dropout2(x)  # [batch_size, seq_len, hidden_size]x = x[:, -1, :] #       这里-1的意思是:取最后一个输出 [batch_size, hidden_size]x = self.fc(x)  #       [batch_size, 10000]return x
# 实例化模型
model = LSTMNet().to(device)
model

LSTMNet(
(embedding): Embedding(10000, 128)
(lstm1): LSTM(128, 128, batch_first=True)
(dropout1): Dropout(p=0.2, inplace=False)
(lstm2): LSTM(128, 128, batch_first=True)
(dropout2): Dropout(p=0.2, inplace=False)
(fc): Linear(in_features=128, out_features=10000, bias=True)
)

Pytorch 数据转换

注意:因为 y_train 和 y_test [batch, 1] 最后一个维度是没用的,
所以要把它去掉,变成 [batch] 才能正常给交叉熵损失函数计算

# 先把 x_train, x_test, y_train, y_test 转化为 tensor
x_train = torch.tensor(x_train).to(device)
x_test = torch.tensor(x_test).to(device)
y_train = torch.tensor(y_train).to(device)
y_test = torch.tensor(y_test).to(device)
# 测试样本能否正常输入网络
pred = model(x_train[0:3].to(device))
print(x_train[0:3].shape) # [3, 20] # 3个样本,每个样本20个词
print(pred.shape) # [3, 10000]     #  3个样本,每个样本10000个分类

torch.Size([3, 20])
torch.Size([3, 10000])

# 因为 y_train 和 y_test [batch, 1] 最后一个维度是没用的,
# 所以要把它去掉,变成 [batch] 才能正常给交叉熵损失函数计算
y_train = y_train.squeeze()
y_test = y_test.squeeze()# 转化成 Long
y_train = y_train.long()
y_test = y_test.long()# 查看形状
y_train.shape,y_test.shape

(torch.Size([39405]), torch.Size([16889]))

训练模型

# 训练模型
import torch.optim as optim
from tqdm import tqdm
optimizer = optim.Adam(model.parameters(), lr=0.001)batch_size = 256
epochs = 20# 注意,这里 y_train, y_test 的形状都是 [batch, 1] ,也就是说,并不是 one-hot 编码
# 所以,损失函数用的是 CrossEntropyLossloss_func = nn.CrossEntropyLoss()
for epoch in range(epochs):print('Epoch: ', epoch)for i in tqdm(range(0, len(x_train), batch_size)):x_batch = x_train[i:i+batch_size]y_batch = y_train[i:i+batch_size]pred = model(x_batch)loss = loss_func(pred, y_batch)optimizer.zero_grad()loss.backward()optimizer.step()# 每个 epoch 结束后,计算一下准确率# 训练集准确率pred = model(x_train)pred = torch.argmax(pred, dim=1)acc = (pred == y_train).sum().item() / len(y_train)print('Train acc: ', acc)# 测试集准确率pred = model(x_test)pred = torch.argmax(pred, dim=1)acc = (pred == y_test).sum().item() / len(y_test)print('Test acc: ', acc)

Epoch: 0
100%|██████████| 154/154 [00:38<00:00, 4.01it/s]
Train acc: 0.10216977540921203
Test acc: 0.10320326839955
Epoch: 1

Epoch: 19
100%|██████████| 154/154 [00:37<00:00, 4.09it/s]
Train acc: 0.20576069026773253
Test acc: 0.17970276511338742

test_string = '白日依山盡,黃河入海流,欲窮千里目,更上一'for i in range(300):# 循环 300 步,每步都要预测一个字test_string_token = tokenizer.texts_to_sequences([test_string[-20:]]) # 取最后20个字test_string_mat = np.array(test_string_token)pred = model(torch.tensor(test_string_mat).to(device)) # pred 的形状是 [1, 10000]pred_argmax = torch.argmax(pred, dim=1).item()         # pred_argmax 的形状是 [1]# 把预测的字转化为文字tokenizer.index_word[pred_argmax]test_string = test_string + tokenizer.index_word[pred_argmax]
print(test_string)

【NLP】LSTM 唐诗生成器 pytorch 版相关推荐

  1. gorm 密码字段隐藏_【财富密码】第1期:《LSTM大战上证指数-PyTorch版》

    前言: Hello大家好,我是瑟林洞仙人!这里是[财富密码]系列第1期:<LSTM大战上证指数-PyTorch版>.在这里,我将用我的"意识流"代码,手把手教会大家如何 ...

  2. 最强NLP模型BERT喜迎PyTorch版!谷歌官方推荐,也会支持中文

    郭一璞 夏乙 发自 凹非寺  量子位 报道 | 公众号 QbitAI 谷歌的最强NLP模型BERT发布以来,一直非常受关注,上周开源的官方TensorFlow实现在GitHub上已经收获了近6000星 ...

  3. 【NLP】文本分类TorchText实战-AG_NEWS 新闻主题分类任务(PyTorch版)

    AG_NEWS 新闻主题分类任务(PyTorch版) 前言 1. 使用 N 元组加载数据 2. 安装 Torch-GPU&TorchText 3. 访问原始数据集迭代器 4. 准备数据处理管道 ...

  4. 364 页 PyTorch 版《动手学深度学习》分享(全中文,支持 Jupyter 运行)

    1 前言 最近有朋友留言要求分享一下李沐老师的<动手学深度学习>,小汤本着一直坚持的"好资源大家一起分享,共同学习,共同进步"的初衷,于是便去找了资料,而且还是中文版的 ...

  5. 伯禹公益AI《动手学深度学习PyTorch版》Task 03 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 03 学习笔记 Task 03:过拟合.欠拟合及其解决方案:梯度消失.梯度爆炸:循环神经网络进阶 微信昵称:WarmIce 过拟合. ...

  6. <计算机视觉四> pytorch版yolov3网络搭建

    鼠标点击下载     项目源代码免费下载地址 <计算机视觉一> 使用标定工具标定自己的目标检测 <计算机视觉二> labelme标定的数据转换成yolo训练格式 <计算机 ...

  7. YOLOV1详解——Pytorch版

    YOLOV1详解--Pytorch版 1 YOLOV1 1 数据处理 1.1 数据集划分 1.2 读入xml文件 1.3 数据增强 2 训练 2.1 Backbone 2.2 Loss 2.3 tra ...

  8. 彩色星球图片生成1:使用Gan实现(pytorch版)

    彩色星球图片生成1:使用Gan实现(pytorch版) 1. 描述 2. 代码 2.1 模型代码model.py 2.2 训练代码main.py 3. 效果 4. 趣图 上一集: 使用Gan实现MNI ...

  9. 伯禹公益AI《动手学深度学习PyTorch版》Task 04 学习笔记

    伯禹公益AI<动手学深度学习PyTorch版>Task 04 学习笔记 Task 04:机器翻译及相关技术:注意力机制与Seq2seq模型:Transformer 微信昵称:WarmIce ...

  10. 彩色星球图片生成4:转置卷积+插值缩放+卷积收缩(pytorch版)

    彩色星球图片生成4:转置卷积层+插值缩放+卷积收缩(pytorch版) 1. 改进方面 1.1 优化器与优化步长 1.2 交叉熵损失函数 1.3 Patch判别器 1.4 输入分辨率 1.5 转置卷积 ...

最新文章

  1. spring cloud集成 consul源码分析
  2. WSDM 2020 | RMRN:社区问答中的深度关联推理模型
  3. go语言的计数器iota
  4. Linux C 学习 单向链表
  5. React 产品实现 -任务管理工具“氢”
  6. 浅谈渗透测试之前期信息搜集
  7. 欧姆龙CP-X显示 END重复 以及 条 0 -重叠条
  8. 阿里P8工程师强烈推荐,60本工程师必备读本
  9. ae合成设置快捷键_AE脚本使用快捷键控制关键帧操作 Keyboard v1.2.2【资源分享1449】...
  10. BitOffer携手高盛推出保本保息量化基金,无风险年化收益20%
  11. 华为P9 回退android6.0,华为P9 Plus从EMUI5.0 版本回退EMUI 4.1官方稳定版本
  12. 关于java的国内外论坛地址分享
  13. zynq-7000系列基于zynq-7015的vivado初步设计之linux下控制PL扩展的光以太网(1000BASE-X)
  14. 马云被约谈 传递了什么信号?
  15. Scratch-陶陶摘苹果
  16. RTMP协议与RTSP协议比较
  17. Cordova 打包签名 Android release app
  18. 2018金华高一计算机考试题目,2018年9月金华十校信息技术考试试题(含解析).docx...
  19. FX5U系列添加CClink模块并配置参数
  20. 远程控制任我行的使用

热门文章

  1. 用来向服务器发送邮件的协议是,电子邮件协议中用于发送邮件的协议是
  2. 41、财务总账科目余额表,三栏式总账,三栏式明细账 查询条件科目增加多选查询
  3. OSChina 周四乱弹 —— 懦夫!你就不能找富婆吗
  4. word2010里脚注横线如何顶格
  5. 删除的文件怎么恢复?
  6. 爬取淘宝网站的商品数据
  7. [走过的路]联想时光——人艰不拆(店长篇)
  8. 谷歌大脑组合模型霸榜 SuperGLUE,什么模型这么高?
  9. 细胞穿膜肽-MnO2复合物(TAT-MnO2)多肽偶联氧化锰纳米粒|MnO2包裹聚多巴胺的纳米颗粒
  10. 31页智慧文旅云服务平台建设方案【附下载】