论文地址:http://www.nlpr.ia.ac.cn/cip/~liukang/liukangPageFile/Recurrent%20Convolutional%20Neural%20Networks%20for%20Text%20Classification.pdf

案例代码

直接运行即可,然后根据自己的任务做

#!/usr/bin/env Python
# coding=utf-8
import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.nn.functional as Fclass TextRCNNModel(nn.Module):def __init__(self, config):super(TextRCNNModel, self).__init__()self.lstm = nn.LSTM(config['embedding_size'], config['lstm_hidden_size'], 1,bidirectional=True, batch_first=True)self.maxpool = nn.MaxPool1d(config['pad_size'])self.fc = nn.Linear(config['lstm_hidden_size'] * 2  # 由于是双向LSTM,所以这里 * 2+ config['embedding_size'], config['output_size'])def forward(self, x):out, _ = self.lstm(x)out = torch.cat((x, out), 2)out = F.relu(out)out = out.permute(0, 2, 1)out = self.maxpool(out).squeeze()out = self.fc(out)return outdef get_total_train_data(word_embedding_size, class_count, pad_size):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, pad_size, word_embedding_size)))  # 维度是 [ 数据量, 一句话的词个数, 每个词的embedding]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long()  # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_trainif __name__ == '__main__':# ================训练参数=================epochs = 1000batch_size = 30output_class = 14embedding_size = 350# ================模型参数=================config = {# 重要参数'embedding_size': embedding_size,  # 输入的字的embedding的长度'output_size': output_class,  # 最终分类任务的数量'pad_size': 40,  # 每句话的字的个数# 次要参数'lstm_hidden_size': 256,}# ================开始训练================x_train, y_train = get_total_train_data(embedding_size, output_class, config['pad_size'])train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train),  # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size,  # 每块的大小shuffle=True,  # 要不要打乱数据 (打乱比较好)num_workers=6,  # 多进程(multiprocess)来读数据drop_last=True,)model = TextRCNNModel(config=config)cross_loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器model.train()for i in range(epochs):for seq, labels in train_loader:optimizer.zero_grad()y_pred = model(seq)  # 压缩维度:得到输出,并将维度为1的去除single_loss = cross_loss(y_pred, labels.squeeze())single_loss.backward()optimizer.step()print("Step: " + str(i) + " loss : " + str(single_loss.detach().numpy()))

pytorch搭建TextRCNN模型与使用案例相关推荐

  1. python与机器学习(七)上——PyTorch搭建LeNet模型进行MNIST分类

    任务要求:利用PyTorch框架搭建一个LeNet模型,并针对MNIST数据集进行训练和测试. 数据集:MNIST 导入: import torch from torch import nn, opt ...

  2. 使用Pytorch搭建CNN模型完成食物图片分类(李宏毅视频课2020作业3,附超详细代码讲解)

    文章目录 0 前言 1 任务描述 1.1 数据描述 1.2 作业提交 1.3 数据下载 1.3.1 完整数据集 1.3.2 部分数据集 2 过程讲解 2.1 读取数据 2.2 数据预处理 2.3 模型 ...

  3. 第18课:项目实战——利用 PyTorch 构建 RNN 模型

    上一篇,我们主要介绍了基本的 RNN 模型和 LSTM.本文将通过一个实战项目带大家使用 PyTorch 搭建 RNN 模型. 本项目将构建一个 RNN 模型,来对 MNIST 手写数据集进行分类.可 ...

  4. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  5. Pytorch搭建自己的模型

    前言 PyTorch.TensorFlow都是主流的深度学习框架,今天主要讲解一下如何快速使用pytorch搭建自己的模型.至于为什么选择讲解pytorch,这里我就简单说明一下自己的使用感受(相对T ...

  6. 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络

    Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...

  7. Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成

    Diffusion扩散模型学习1--Pytorch搭建DDPM利用深度卷积神经网络实现图片生成 学习前言 源码下载地址 网络构建 一.什么是Diffusion 1.加噪过程 2.去噪过程 二.DDPM ...

  8. PyTorch入门(二)搭建MLP模型实现分类任务

      本文是PyTorch入门的第二篇文章,后续将会持续更新,作为PyTorch系列文章.   本文将会介绍如何使用PyTorch来搭建简单的MLP(Multi-layer Perceptron,多层感 ...

  9. ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练

    1.ResNet网络详解 网络中的创新点: (1)超深的网络结构(突破1000层) (2)提出residual模块 (3)使用Batch Normalization加速训练(丢弃dropout) (1 ...

最新文章

  1. 学习知识[置顶] C++学习方式方法
  2. ubuntu下创建图标
  3. Netty完成网络通信(二)
  4. java字符串是不是整数的函数_java判断字符串是否为整数的方法
  5. timeval的时间转换成毫秒之后多大的数据类型可以装下
  6. c和JAVA的安全编码_C、C++ 和 Java安全编码实践提示与技巧
  7. 数据库设计优化经验谈(转载)
  8. 注册时,邮箱自动发送验证
  9. wraper for bootstrap3.0 + simple_form
  10. nginx RTMP FFmpeg 视频直播
  11. educoder算法设计与分析 实验三 动态规划实验
  12. 计算机网络测速创新,一种计算机网络安全测速装置的制作方法
  13. 举个栗子!Tableau 技巧(131):用烛台图 Candlestick Chart 分析价格波动
  14. (002)循环语句,数组,方法,走进面向对象(封装)
  15. 计算机网络【2】—— CSMA/CD协议
  16. pythonl list 的修改元素
  17. Vue+bpmn.js自定义流程图之palette(二)
  18. java cps变换_C#中的递归APS和CPS模式详解
  19. 第七届中国信息技术服务产业年会 在西安隆重召开
  20. php yii2 sns,GitHub - 13240031972/iisns: 基于 yii2 的 sns 社区系统,一站式解决社区建站...

热门文章

  1. iOS内存管理机制解析
  2. Zen Garden驾到:首批Metal游戏已登录iTunes应用商店
  3. 第四十六章:SpringBoot RabbitMQ完成消息延迟消费
  4. 扎根本地连接未来 千米网的电商“红海”生存术
  5. Servlet体系及方法
  6. Spring Boot - 开发Web应用
  7. MBG配置详解及最佳实践
  8. ajax_典型应用_添加商品标题
  9. JSP:src路径里有中文,产生乱码问题
  10. 更新整理本人所有博文中提供的代码与工具(C++,2013.08)