文章目录

  • 生成训练数据
  • 构建TextRNN
  • 开始训练
    • 构建训练数据集
    • 训练三件套:模型,loss,优化器
    • 开始训练
  • 完整代码

生成训练数据

这里使用随机数生成训练数据,大家在自己写的时候只需要替换这里就OK了:

def get_total_train_data(word_embedding_size, class_count):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, 20, word_embedding_size)))  # 维度是 [ 数据量, 一句话的单词个数, 句子的embedding]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long()  # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_train

由于多数是用于文本训练,因此这里将word_embedding_size换成自己的embedding,对于分类任务,class_count = N表示N分类任务;回归任务也可以自定义修改代码

构建TextRNN

class TextRnnModel(nn.Module):def __init__(self, embedding_size, lstm_hidden_size, output_size, num_layers=2, dropout=0.3):super(TextRnnModel, self).__init__()self.lstm = nn.LSTM(embedding_size,  # 词嵌入模型词语维度lstm_hidden_size,  # 隐层神经元的维度,为输出的维度num_layers,  # 构建两层的LSTM:堆叠LSTMbidirectional=True,  # 双向的LSTM:词向量从前往后走,再重后往前走,最后拼接起来batch_first=True,  # 把第一个维度的输入作为batch输入的维度dropout=dropout)self.fc = nn.Linear(lstm_hidden_size * 2, output_size)# nn.Linear(输入维度,输出维度):[ 上一步输出的LSTM维度*2(双向) , 10分类 ]def forward(self, x):"""前向传播"""out, _ = self.lstm(x)  # 过一个LSTM [batch_size, seq_len, embedding]out = self.fc(out[:, -1, :])return out

开始训练

构建训练数据集

构建batch数据

if __name__ == '__main__':epochs = 1000batch_size = 30embedding_size = 350output_class = 14x_train, y_train = get_total_train_data(embedding_size, output_class)train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train),  # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size,  # 每块的大小shuffle=True,  # 要不要打乱数据 (打乱比较好)num_workers=6,  # 多进程(multiprocess)来读数据drop_last=True,)

训练三件套:模型,loss,优化器

   model = TextRnnModel(embedding_size=embedding_size, lstm_hidden_size=200, output_size=output_class)cross_loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器

开始训练

   model.train()for i in range(epochs):for seq, labels in train_loader:optimizer.zero_grad()y_pred = model(seq)  # 压缩维度:得到输出,并将维度为1的去除single_loss = cross_loss(y_pred, labels.squeeze())# single_loss = cross_loss(y_pred, labels)single_loss.backward()optimizer.step()print("Step: " + str(i) + " loss : " + str(single_loss.detach().numpy()))

完整代码

#!/usr/bin/env Python
# coding=utf-8
import torch
import torch.nn as nn
import torch.utils.data as Dataclass TextRnnModel(nn.Module):def __init__(self, embedding_size, lstm_hidden_size, output_size, num_layers=2, dropout=0.3):super(TextRnnModel, self).__init__()self.lstm = nn.LSTM(embedding_size,  # 词嵌入模型词语维度lstm_hidden_size,  # 隐层神经元的维度,为输出的维度num_layers,  # 构建两层的LSTM:堆叠LSTMbidirectional=True,  # 双向的LSTM:词向量从前往后走,再重后往前走,最后拼接起来batch_first=True,  # 把第一个维度的输入作为batch输入的维度dropout=dropout)self.fc = nn.Linear(lstm_hidden_size * 2, output_size)# nn.Linear(输入维度,输出维度):[ 上一步输出的LSTM维度*2(双向) , 10分类 ]def forward(self, x):"""前向传播"""out, _ = self.lstm(x)  # 过一个LSTM [batch_size, seq_len, embedding]out = self.fc(out[:, -1, :])return outdef get_total_train_data(word_embedding_size, class_count):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, 20, word_embedding_size)))  # 维度是 [ 数据量, 一句话的单词个数, 句子的embedding]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long()  # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_trainif __name__ == '__main__':epochs = 1000batch_size = 30embedding_size = 350output_class = 14x_train, y_train = get_total_train_data(embedding_size, output_class)train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train),  # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size,  # 每块的大小shuffle=True,  # 要不要打乱数据 (打乱比较好)num_workers=6,  # 多进程(multiprocess)来读数据drop_last=True,)model = TextRnnModel(embedding_size=embedding_size, lstm_hidden_size=200, output_size=output_class)cross_loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器model.train()for i in range(epochs):for seq, labels in train_loader:optimizer.zero_grad()y_pred = model(seq)  # 压缩维度:得到输出,并将维度为1的去除single_loss = cross_loss(y_pred, labels.squeeze())# single_loss = cross_loss(y_pred, labels)single_loss.backward()optimizer.step()print("Step: " + str(i) + " loss : " + str(single_loss.detach().numpy()))

pytorch搭建TextRNN与使用案例相关推荐

  1. pytorch搭建TextCNN与使用案例

    TextCNN算法流程 整体流程是将词拼接在一起,一句话构成一个特征图 根据卷积核得到多个特征向量 每个特征向量全局池化,选最大的特征作为这个特征向量的值 拼接特征值,得到句子的特征向量 全连接后得到 ...

  2. 一、pytorch搭建实战以及sequential的使用

    一.pytorch搭建实战以及sequential的使用 1.A sequential container 2.搭建cifar10 model structure 3.创建实例进行测试(可以检查网络是 ...

  3. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  4. 基于PyTorch搭建CNN实现视频动作分类任务代码详解

    数据及具体讲解来源: 基于PyTorch搭建CNN实现视频动作分类任务 import torch import torch.nn as nn import torchvision.transforms ...

  5. Flume环境搭建_五种案例(转)

    Flume环境搭建_五种案例 http://flume.apache.org/FlumeUserGuide.html A simple example Here, we give an example ...

  6. Pytorch搭建自己的模型

    前言 PyTorch.TensorFlow都是主流的深度学习框架,今天主要讲解一下如何快速使用pytorch搭建自己的模型.至于为什么选择讲解pytorch,这里我就简单说明一下自己的使用感受(相对T ...

  7. Educoder 机器学习 神经网络 第四关:使用pytorch搭建卷积神经网络识别手写数字

    任务描述 相关知识 卷积神经网络 为什么使用卷积神经网络 卷积 池化 全连接网络 卷积神经网络大致结构 pytorch构建卷积神经网络项目流程 数据集介绍与加载数据 构建模型 训练模型 保存模型 加载 ...

  8. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  9. 睿智的目标检测30——Pytorch搭建YoloV4目标检测平台

    睿智的目标检测30--Pytorch搭建YoloV4目标检测平台 学习前言 什么是YOLOV4 代码下载 YOLOV4改进的部分(不完全) YOLOV4结构解析 1.主干特征提取网络Backbone ...

最新文章

  1. Dajngo admin使用
  2. android界面布局题,【填空题】Android 系统中, 用于定义布局显示在界面上的风格。...
  3. 二分图最小覆盖的Konig定理及其证明
  4. Java中使用Base64进行编码解码的工具类-将验证码图片使用Base64编码并返回给前端
  5. 以主干开发作为持续交付的基础
  6. 第013课_代码重定位
  7. 如何在JS中应用正则表达式
  8. Linux下Socket编程的端口问题( Bind error: Address already in use )
  9. HDU 2017 一系列统计数据
  10. php 保存文件并换行,php是怎样向文件中写入换行_后端开发
  11. Ubuntu更改分辨率
  12. 国军标 软件测评 静态分析常见问题总结
  13. word画流程图工具
  14. ai怎么渐变颜色_AI渐变工具怎么使用?AI渐变工具使用方法介绍
  15. ZOC7 for Mac(终端仿真器)含注册码 v7.22.7激活版
  16. Clover 驱动文件夹_黑苹果(clover文件夹中各个文件的主要功能)
  17. 使用MySQL的binlog日志恢复误删数据
  18. Python中[ : n]、[m : ]、[-n]、[:-n]、[::-n]、[m::-n]和[m:]的含义
  19. GVS视声引入睿住资本,完成A轮融资
  20. 区块链与分布式数据库的区别

热门文章

  1. 使用Python快速获取公众号文章定制电子书(一)
  2. Grafana密码重置为admin
  3. C# Excel导入、导出
  4. iPhone判断是否已插入SIM卡
  5. 在Windows上build Spark
  6. RAID简介[zz]
  7. navigator属性
  8. jquery数组怎么传给后台_我是如何让公司后台管理系统焕然一新的(下)封装组件...
  9. 思科交换机配置试题_【干货】思科交换机路由器怎么配置密码?
  10. python中类的属性一般来说_python中实例属性和类属性之间的关系