import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F

载入数据

import sys
sys.path.append("../input/")
import d2l_jay9460 as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()

初始化参数

num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size
print('will use', device)def get_params():  def _one(shape):ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32) #正态分布return torch.nn.Parameter(ts, requires_grad=True)def _three():return (_one((num_inputs, num_hiddens)),_one((num_hiddens, num_hiddens)),torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))W_xz, W_hz, b_z = _three()  # 更新门参数W_xr, W_hr, b_r = _three()  # 重置门参数W_xh, W_hh, b_h = _three()  # 候选隐藏状态参数# 输出层参数W_hq = _one((num_hiddens, num_outputs))b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])def init_gru_state(batch_size, num_hiddens, device):   #隐藏状态初始化return (torch.zeros((batch_size, num_hiddens), device=device), )

GRU 模型

def gru(inputs, state, params):W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []for X in inputs:Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)H_tilda = torch.tanh(torch.matmul(X, W_xh) + R * torch.matmul(H, W_hh) + b_h)H = Z * H + (1 - Z) * H_tildaY = torch.matmul(H, W_hq) + b_qoutputs.append(Y)return outputs, (H,)

训练模型

num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']
d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens,vocab_size, device, corpus_indices, idx_to_char,char_to_idx, False, num_epochs, num_steps, lr,clipping_theta, batch_size, pred_period, pred_len,prefixes)

简洁实现

num_hiddens=256
num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2
pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']lr = 1e-2 # 注意调整学习率
gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)
model = d2l.RNNModel(gru_layer, vocab_size).to(device)
d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device,corpus_indices, idx_to_char, char_to_idx,num_epochs, num_steps, lr, clipping_theta,batch_size, pred_period, pred_len, prefixes)

Pytorch 从零开始实现 GRU相关推荐

  1. 使用PyTorch从零开始实现YOLO-V3目标检测算法 (一)

    原文:https://blog.csdn.net/u011520516/article/details/80222743 点击查看博客原文 标检测是深度学习近期发展过程中受益最多的领域.随着技术的进步 ...

  2. 使用pytorch从零开始实现YOLO-V3目标检测算法 (二)

    原文:https://blog.csdn.net/u011520516/article/details/80212960 博客翻译 这是从零开始实现YOLO v3检测器的教程的第2部分.在上一节中,我 ...

  3. 使用PyTorch从零开始实现YOLO-V3目标检测算法 (四)

    原文:https://blog.csdn.net/u011520516/article/details/80228130 点击查看博客原文 这是从零开始实现YOLO v3检测器的教程的第4部分,在上一 ...

  4. 使用PyTorch从零开始实现YOLO-V3目标检测算法 (三)

    原文:https://blog.csdn.net/u011520516/article/details/80216009 点击查看博客原文 这是从零开始实现YOLO v3检测器的教程的第3部分.第二部 ...

  5. 使用PyTorch从零开始构建Elman循环神经网络

    摘要: 循环神经网络是如何工作的?如何构建一个Elman循环神经网络?在这里,教你手把手创建一个Elman循环神经网络进行简单的序列预测. 本文以最简单的RNNs模型为例:Elman循环神经网络,讲述 ...

  6. DL之GRU:基于2022年6月最新上证指数数据集结合Pytorch框架利用GRU算法预测最新股票上证指数实现回归预测

    DL之GRU:基于2022年6月最新上证指数数据集结合Pytorch框架利用GRU算法预测最新股票上证指数实现回归预测 目录 基于2022年6月最新上证指数数据集结合Pytorch框架利用GRU算法预 ...

  7. 【pytorch】nn.GRU的使用

    官方文档在这里. GRU具体不做介绍了,本篇只做pytorch的API使用介绍. torch.nn.GRU(*args, **kwargs) 公式 下面公式忽略bias,由于输入向量的长度和隐藏层特征 ...

  8. [PyTorch] rnn,lstm,gru中输入输出维度

    本文中的RNN泛指LSTM,GRU等等 CNN中和RNN中batchSize的默认位置是不同的. CNN中:batchsize的位置是position 0. RNN中:batchsize的位置是pos ...

  9. pytorch 预测手写体数字_教你用PyTorch从零开始实现LeNet 5手写数字的识别

    背景 LeNET-5是最早的卷积神经网络之一,曾广泛用于美国银行.手写数字识别正确率在99%以上. PyTorch是Facebook 人工智能研究院在2017年1月,基于Torch退出的一个Pytho ...

最新文章

  1. Nature综述 | 种内多样性:解释微生物组中的菌株
  2. STL之hashtable源代码剖析
  3. Python 程序打包 -- 使用pyinstaller
  4. 《TCP/IP详解》读书笔记
  5. 同样是做大数据分析,你月薪8k他30k,到底差在了哪?
  6. 微信小程序教程笔记4
  7. 安装IBM Data Studio Client
  8. 请解释jsonp的工作原理
  9. 【Axure PR原型模板】微信公众小程序手机移动端高保真交互原型
  10. python http请求时gzip解压
  11. 【zhasite】托业英语阅读技巧有哪些
  12. 命里有时终须有,命里无时莫强求
  13. 简单翻译工具--必应词典第三方api使用方法
  14. REDHAT版本与支持的intel CPU型号
  15. Django+Vue开发生鲜电商平台之2.开发环境搭建
  16. 【陈工笔记】# 微信小程序实现的基础步骤速记,持续更新关键词 #
  17. IDEA类和方法注释模板设置
  18. python中字符串转数组、数组转字符串
  19. 阿里云服务器创建快照、回滚磁盘
  20. 兰州大学最新预测:新冠大流行将于2023年底结束

热门文章

  1. 难怪好人有恶报,原来秘密在这里
  2. 基于python和MATLAB的遗传算法优化函数最小值
  3. LAMMPS学习系列(8)
  4. 物联网行业基建级的平台形态:阿里云Link Develop
  5. 特征选择:概述与方法
  6. 构筑立体世界,AR Engine助力B站会员购打造沉浸式营销
  7. PHP设计网站的编码,PHP网站开发如何高效、准确、自动识别网页编码 ?
  8. 2016阿里巴巴实习生在线笔试
  9. 《阿里巴巴 Java 开发手册 》读书笔记
  10. Windows 安装Docker 打包镜像