TextCNN算法流程

  • 整体流程是将词拼接在一起,一句话构成一个特征图
  • 根据卷积核得到多个特征向量
  • 每个特征向量全局池化,选最大的特征作为这个特征向量的值
  • 拼接特征值,得到句子的特征向量
  • 全连接后得到目标维度的结果

完整代码

#!/usr/bin/env Python
# coding=utf-8
import torch
import torch.nn as nn
import torch.utils.data as Data
import torch.nn.functional as Fclass TextCnnModel(nn.Module):def __init__(self, embedding_size, output_size, channels=256, filter_sizes=(2, 3, 4), dropout=0.5):"""TextCnn 模型做文本分类:param embedding_size: 每个词向量的embedding长度:param output_size: 最后的输出个数、待分类的个数:param channels: 卷积核的数量:param filter_sizes: 卷积核尺寸:param dropout: 随机失活概率"""super(TextCnnModel, self).__init__()self.convs = nn.ModuleList([nn.Conv2d(1, channels, (k, embedding_size)) for k in filter_sizes])self.dropout = nn.Dropout(dropout)self.fc = nn.Linear(channels * len(filter_sizes), output_size)def conv_and_pool(self, x, conv):x = F.relu(conv(x)).squeeze(3)x = F.max_pool1d(x, x.size(2)).squeeze(2)return xdef forward(self, x):out = x.unsqueeze(1)out = torch.cat([self.conv_and_pool(out, conv) for conv in self.convs], 1)out = self.dropout(out)out = self.fc(out)return outdef get_total_train_data(word_embedding_size, class_count):"""得到全部的训练数据,这里需要替换成自己的数据"""import numpy as npx_train = torch.Tensor(np.random.random((1000, 20, word_embedding_size)))  # 维度是 [ 数据量, 一句话的单词个数, 句子的embedding]y_train = torch.Tensor(np.random.randint(0, class_count, size=(1000, 1))).long()  # [ 数据量, 句子的分类], 这里的class_count=4,就是四分类任务return x_train, y_trainif __name__ == '__main__':# ================参数=================epochs = 1000batch_size = 30embedding_size = 350output_class = 14# ================开始训练================x_train, y_train = get_total_train_data(embedding_size, output_class)train_loader = Data.DataLoader(dataset=Data.TensorDataset(x_train, y_train),  # 封装进Data.TensorDataset()类的数据,可以为任意维度batch_size=batch_size,  # 每块的大小shuffle=True,  # 要不要打乱数据 (打乱比较好)num_workers=6,  # 多进程(multiprocess)来读数据drop_last=True,)model = TextCnnModel(embedding_size=embedding_size, output_size=output_class)cross_loss = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # 优化器model.train()for i in range(epochs):for seq, labels in train_loader:optimizer.zero_grad()y_pred = model(seq)  # 压缩维度:得到输出,并将维度为1的去除single_loss = cross_loss(y_pred, labels.squeeze())single_loss.backward()optimizer.step()print("Step: " + str(i) + " loss : " + str(single_loss.detach().numpy()))

pytorch搭建TextCNN与使用案例相关推荐

  1. pytorch搭建TextRNN与使用案例

    文章目录 生成训练数据 构建TextRNN 开始训练 构建训练数据集 训练三件套:模型,loss,优化器 开始训练 完整代码 生成训练数据 这里使用随机数生成训练数据,大家在自己写的时候只需要替换这里 ...

  2. 一、pytorch搭建实战以及sequential的使用

    一.pytorch搭建实战以及sequential的使用 1.A sequential container 2.搭建cifar10 model structure 3.创建实例进行测试(可以检查网络是 ...

  3. 【深度学习】——利用pytorch搭建一个完整的深度学习项目(构建模型、加载数据集、参数配置、训练、模型保存、预测)

    目录 一.深度学习项目的基本构成 二.实战(猫狗分类) 1.数据集下载 2.dataset.py文件 3.model.py 4.config.py 5.predict.py 一.深度学习项目的基本构成 ...

  4. 基于PyTorch搭建CNN实现视频动作分类任务代码详解

    数据及具体讲解来源: 基于PyTorch搭建CNN实现视频动作分类任务 import torch import torch.nn as nn import torchvision.transforms ...

  5. Flume环境搭建_五种案例(转)

    Flume环境搭建_五种案例 http://flume.apache.org/FlumeUserGuide.html A simple example Here, we give an example ...

  6. Pytorch搭建自己的模型

    前言 PyTorch.TensorFlow都是主流的深度学习框架,今天主要讲解一下如何快速使用pytorch搭建自己的模型.至于为什么选择讲解pytorch,这里我就简单说明一下自己的使用感受(相对T ...

  7. Educoder 机器学习 神经网络 第四关:使用pytorch搭建卷积神经网络识别手写数字

    任务描述 相关知识 卷积神经网络 为什么使用卷积神经网络 卷积 池化 全连接网络 卷积神经网络大致结构 pytorch构建卷积神经网络项目流程 数据集介绍与加载数据 构建模型 训练模型 保存模型 加载 ...

  8. 使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记

    使用Pytorch搭建U-Net网络并基于DRIVE数据集训练(语义分割)学习笔记 https://www.bilibili.com/video/BV1rq4y1w7xM?spm_id_from=33 ...

  9. 睿智的目标检测30——Pytorch搭建YoloV4目标检测平台

    睿智的目标检测30--Pytorch搭建YoloV4目标检测平台 学习前言 什么是YOLOV4 代码下载 YOLOV4改进的部分(不完全) YOLOV4结构解析 1.主干特征提取网络Backbone ...

最新文章

  1. 高德技术评测建设之路
  2. python函数能否增强代码可读性_总结的几个Python函数方法设计原则
  3. 原码、补码、反码详解
  4. websocket实现单聊
  5. gbk 转 UTF-8
  6. Idea 封装Java代码片段 快速提示
  7. strstrsubstr、AfxGetApp
  8. html 字体图标 颜色怎么改,关于css:如何设置Font Awesome Icons的图标颜色,大小和阴影的样式...
  9. mysql for centos_CentOs中mysql的安装与配置
  10. java中重试的使用工具
  11. 刘明计算机学院,西南大学计算机与信息科学学院研究生导师简介-刘明
  12. RGB 与 RGBA 与 16进制 与 HSL 之间的简单转换
  13. 什么是大数据以及大数据的相关技术?
  14. CAT1模块EC200S 4G物联网模块串口透传MQTT协议 快速入门指导资料
  15. css3 简单的动画实现欢乐愉快的小鱼
  16. Eigen实现克罗内克内积
  17. 如何用计算机求特征值特征向量,利用QR算法求解矩阵的特征值和特征向量
  18. Spring实战学习笔记
  19. R时间序列分析|SP500股指的ARIMA模型预测与残差ARCH效应分析
  20. DOM 树的解析渲染

热门文章

  1. 用python写一个文件管理程序下载_Python管理文件神器 os.walk
  2. 未来函数在线检测_嵌入式实时操作系统任务栈溢出检测原理
  3. python种颜色循环_python – 重置Matplotlib中的颜色循环
  4. Unreal Engine 4切换默认Camera实现
  5. UE4手册中文翻译速查表
  6. PYTHON学习0011:enumerate()函数的用法----2019-6-8
  7. vuex其实超简单,只需3步
  8. ROOBO公布A轮1亿美元融资 发布人工智能机器人系统
  9. angularjs的三种注入方式
  10. 统一沟通-技巧-7-Lync 2010-配置信息-EWS未部署