batch之间传递state

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(900)class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.lstm = nn.LSTM(input_size=1,hidden_size=2,num_layers=1,batch_first=True)self.out = nn.Linear(2, 1)def forward(self, x,state):# x(batch,seq_len,input_size) h_state(n_layers,batch,hidden_size) r_out(batch,seq_len,hidden_size)r_out,state= self.lstm(x,state) # (batch,seq_len,hidden_size[1]) -> (batch,seq_len,hidden_size[2])outs = self.out(r_out) # (batch,seq_len,hidden_size[2]) -> (batch,seq_len,hidden_size[1])return outs,state
rnn = RNN()optimizer = torch.optim.Adam(rnn.parameters(), lr=0.02) # 同时更新W_hh,W_ih
mse = nn.MSELoss()
state = (torch.randn(1, 1, 2), torch.randn(1, 1, 2))
for step in range(100):# 构建数据start, end = step * np.pi, (step+1)*np.pisteps = torch.linspace(start, end, 10)x = torch.sin(steps).unsqueeze(0).unsqueeze(2)y = torch.cos(steps).unsqueeze(0).unsqueeze(2)# 学习prediction,state= rnn(x,state)state = (state[0].detach(),state[1].detach())loss = mse(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()# 可视化结果plt.plot(steps, y.numpy().flatten(), 'r-')plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')plt.draw(); plt.pause(0.05)
plt.show()

batch之间不传递state

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1000)class RNN(nn.Module):def __init__(self):super(RNN, self).__init__()self.lstm = nn.LSTM(input_size=1,hidden_size=2,num_layers=1,batch_first=True)self.out = nn.Linear(2, 1)def forward(self, x):# x(batch,seq_len,input_size) h_state(n_layers,batch,hidden_size) r_out(batch,seq_len,hidden_size)r_out,_= self.lstm(x,) # (batch,seq_len,hidden_size[1]) -> (batch,seq_len,hidden_size[2])outs = self.out(r_out) # (batch,seq_len,hidden_size[2]) -> (batch,seq_len,hidden_size[1])return outs
rnn = RNN()optimizer = torch.optim.Adam(rnn.parameters(), lr=0.02) # 同时更新W_hh,W_ih
mse = nn.MSELoss()
for step in range(100):# 构建数据start, end = step * np.pi, (step+1)*np.pisteps = torch.linspace(start, end, 10)x = torch.sin(steps).unsqueeze(0).unsqueeze(2)y = torch.cos(steps).unsqueeze(0).unsqueeze(2)# 学习prediction = rnn(x)loss = mse(prediction, y)optimizer.zero_grad()loss.backward()optimizer.step()# 可视化结果plt.plot(steps, y.numpy().flatten(), 'r-')plt.plot(steps, prediction.detach().numpy().flatten(), 'b-')plt.draw(); plt.pause(0.05)
plt.show()

https://zhuanlan.zhihu.com/p/94757947
https://discuss.pytorch.org/t/lstm-how-to-remember-hidden-and-cell-states-across-different-batches/11957
https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html

pytorch LSTM_regression相关推荐

  1. pyTorch api

    应用 pytorch FC_regression pytorch FC_classification pytorch RNN_regression pytorch LSTM_regression py ...

  2. 通过anaconda2安装python2.7和安装pytorch

    ①由于官网下载anaconda2太慢,最好去byrbt下载,然后安装就行 ②安装完anaconda2会自动安装了python2.7(如终端输入python即进入python模式) 但是可能没有设置环境 ...

  3. 记录一次简单、高效、无错误的linux上安装pytorch的过程

    1 准备miniconda Miniconda Miniconda 可以理解成Anaconda的免费.浓缩版.它非常小,只包含了conda.python以及它们依赖的一些包.我们可以根据我们的需要再安 ...

  4. 各种注意力机制PyTorch实现

    给出了整个系列的PyTorch的代码实现,以及使用方法. 各种注意力机制 Pytorch implementation of "Beyond Self-attention: External ...

  5. PyTorch代码调试利器_TorchSnooper

    GitHub 项目地址: https://github.com/zasdfgbnm/TorchSnooper 大家可能遇到这样子的困扰:比如说运行自己编写的 PyTorch 代码的时候,PyTorch ...

  6. pytorch常用代码

    20211228 https://mp.weixin.qq.com/s/4breleAhCh6_9tvMK3WDaw 常用代码段 本文代码基于 PyTorch 1.x 版本,需要用到以下包: impo ...

  7. API pytorch tensorflow

    pytorch与tensorflow API速查表 方法名称 pytroch tensorflow numpy 裁剪 torch.clamp(x, min, max) tf.clip_by_value ...

  8. tensor转换 pytorch tensorflow

    一.tensorflow的numpy与tensor互转 1.数组(numpy)转tensor 利用tf.convert_to_tensor(numpy),将numpy转成tensor >> ...

  9. tensor和模型 保存与加载 PyTorch

    PyTorch教程-7:PyTorch中保存与加载tensor和模型详解 保存和读取Tensor PyTorch中的tensor可以保存成 .pt 或者 .pth 格式的文件,使用torch.save ...

最新文章

  1. ui产品小结 - 包含小程序 前端等
  2. 药师帮完成1.33亿美元D轮融资,投资方为老虎环球基金、H Capital和DCM
  3. 1.MVC的工作流程
  4. PHP-Webshell免杀研究
  5. 数据结构与算法 | 快速排序:Hoare法, 挖坑法,双指针法,非递归, 优化
  6. [LeetCode]题解(python):058-Length of Last Word
  7. 编程大白给编程小白的四点建议
  8. linux内存管理(十三)-内存规整过程分析
  9. 微信小程序云开发教程-微信小程序的API入门-API的类型和语法结构
  10. Nginx配置文档详解
  11. 基于堆栈的缓冲区溢出_基于堆栈溢出问题构建搜索引擎
  12. 云联惠创业经营者认证_广州公安打掉云联惠涉传销组织 零壹财经曾发文警示...
  13. php爬取百度关键词时出现,百度安全验证,解决方法
  14. 三火龙加身战无不胜 TES国际首秀告捷
  15. 第7章第30节:四图排版:四张图片交错对齐排列 [PowerPoint精美幻灯片实战教程]
  16. c罗python可视化分析_鸟枪换炮,利用python3对球员做大数据降维(因子分析得分),为C罗找到合格僚机...
  17. 阿里云网站备案简单流程说明文档
  18. 杨辉三角c语言程序for循环,如何用C语言循环输出杨辉三角?
  19. 上银驱动器使用手册_D1驱动器操作使用手册.pdf
  20. JavaScript 中的 JSON

热门文章

  1. 9-算法 kmp算法
  2. Java实现两个csv文件的对比_Java实现CSV文件差异对比
  3. SQL数据库语言基础之SqlServer表数据的插入、更新与删除
  4. 模拟电子技术不挂科学习笔记3(放大电路的分析方法)
  5. 初学者python笔记(迭代器、生成器、三元表达式、列表解析、send()与yield())
  6. 详解Python序列解包(5)
  7. Python版的百钱买百鸡问题
  8. 1.(单选题) HTML是指,《计算机应用基础》第五阶段在线作业(自测).doc
  9. java mian 方法_Java mian函数
  10. 大数的加减法C语言程序设计,超大数相加C语言程序设计