pytorch搭建TextRCNN模型与使用案例
论文地址:http://www.nlpr.ia.ac.cn/cip/~liukang/liukangPageFile/Recurrent%20Convolutional%20Neural%20Networks%20for%20Text%20Classification.pdf
案例代码
直接运行即可,然后根据自己的任务做
#!/usr/bin/env Python
# coding=utf-8
import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.nn.functional as Fclass TextRCNNModel(nn.Module):def __init__(self, config):super(TextRCNNModel, self).__init__()self.lstm = nn.LSTM(config['embedding_size'], config['lstm_hidden_size'], 1,bidirectional=True, batch_first=True)self.maxpool = nn.MaxPool1d(config['pad_size'])self.fc = nn.Linear(config['lstm_hidden_size'] * 2 # 由于是双向LSTM,所以这里 * 2+ config['embedding_size'], config['output_size'])def forward(self, x):out, _ = self.lstm(x)out = torch.cat((x, out), 2)out = F.relu(out)out = out.permute(0, 2, 1)out = self.maxpool(out).squeeze()out = self.fc(out)return outdef get_total_train_data(word_embedding_size, class_count, pad_size):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, pad_size, word_embedding_size))) # 维度是 [ 数据量, 一句话的词个数, 每个词的embedding]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long() # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_trainif __name__ == '__main__':# ================训练参数=================epochs = 1000batch_size = 30output_class = 14embedding_size = 350# ================模型参数=================config = {# 重要参数'embedding_size': embedding_size, # 输入的字的embedding的长度'output_size': output_class, # 最终分类任务的数量'pad_size': 40, # 每句话的字的个数# 次要参数'lstm_hidden_size': 256,}# ================开始训练================x_train, y_train = get_total_train_data(embedding_size, output_class, config['pad_size'])train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train), # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size, # 每块的大小shuffle=True, # 要不要打乱数据 (打乱比较好)num_workers=6, # 多进程(multiprocess)来读数据drop_last=True,)model = TextRCNNModel(config=config)cross_loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 优化器model.train()for i in range(epochs):for seq, labels in train_loader:optimizer.zero_grad()y_pred = model(seq) # 压缩维度:得到输出,并将维度为1的去除single_loss = cross_loss(y_pred, labels.squeeze())single_loss.backward()optimizer.step()print("Step: " + str(i) + " loss : " + str(single_loss.detach().numpy()))
pytorch搭建TextRCNN模型与使用案例相关推荐
- python与机器学习(七)上——PyTorch搭建LeNet模型进行MNIST分类
任务要求:利用PyTorch框架搭建一个LeNet模型,并针对MNIST数据集进行训练和测试. 数据集:MNIST 导入: import torch from torch import nn, opt ...
- 使用Pytorch搭建CNN模型完成食物图片分类(李宏毅视频课2020作业3,附超详细代码讲解)
文章目录 0 前言 1 任务描述 1.1 数据描述 1.2 作业提交 1.3 数据下载 1.3.1 完整数据集 1.3.2 部分数据集 2 过程讲解 2.1 读取数据 2.2 数据预处理 2.3 模型 ...
- 第18课:项目实战——利用 PyTorch 构建 RNN 模型
上一篇,我们主要介绍了基本的 RNN 模型和 LSTM.本文将通过一个实战项目带大家使用 PyTorch 搭建 RNN 模型. 本项目将构建一个 RNN 模型,来对 MNIST 手写数据集进行分类.可 ...
- 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)
目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...
- Pytorch搭建自己的模型
前言 PyTorch.TensorFlow都是主流的深度学习框架,今天主要讲解一下如何快速使用pytorch搭建自己的模型.至于为什么选择讲解pytorch,这里我就简单说明一下自己的使用感受(相对T ...
- 利用Pytorch搭建简单的图像分类模型(之二)---搭建网络
Pytorch搭建网络模型-ResNet 一.ResNet的两个结构 首先来看一下ResNet和一般卷积网络结构上的差异: 图中上面一部分就是ResNet34的网络结构图,下面可以理解为一个含有34层 ...
- Diffusion扩散模型学习1——Pytorch搭建DDPM实现图片生成
Diffusion扩散模型学习1--Pytorch搭建DDPM利用深度卷积神经网络实现图片生成 学习前言 源码下载地址 网络构建 一.什么是Diffusion 1.加噪过程 2.去噪过程 二.DDPM ...
- PyTorch入门(二)搭建MLP模型实现分类任务
本文是PyTorch入门的第二篇文章,后续将会持续更新,作为PyTorch系列文章. 本文将会介绍如何使用PyTorch来搭建简单的MLP(Multi-layer Perceptron,多层感 ...
- ResNet网络详解并使用pytorch搭建模型、并基于迁移学习训练
1.ResNet网络详解 网络中的创新点: (1)超深的网络结构(突破1000层) (2)提出residual模块 (3)使用Batch Normalization加速训练(丢弃dropout) (1 ...
最新文章
- 学习知识[置顶] C++学习方式方法
- ubuntu下创建图标
- Netty完成网络通信(二)
- java字符串是不是整数的函数_java判断字符串是否为整数的方法
- timeval的时间转换成毫秒之后多大的数据类型可以装下
- c和JAVA的安全编码_C、C++ 和 Java安全编码实践提示与技巧
- 数据库设计优化经验谈(转载)
- 注册时,邮箱自动发送验证
- wraper for bootstrap3.0 + simple_form
- nginx RTMP FFmpeg 视频直播
- educoder算法设计与分析 实验三 动态规划实验
- 计算机网络测速创新,一种计算机网络安全测速装置的制作方法
- 举个栗子!Tableau 技巧(131):用烛台图 Candlestick Chart 分析价格波动
- (002)循环语句,数组,方法,走进面向对象(封装)
- 计算机网络【2】—— CSMA/CD协议
- pythonl list 的修改元素
- Vue+bpmn.js自定义流程图之palette(二)
- java cps变换_C#中的递归APS和CPS模式详解
- 第七届中国信息技术服务产业年会 在西安隆重召开
- php yii2 sns,GitHub - 13240031972/iisns: 基于 yii2 的 sns 社区系统,一站式解决社区建站...