六年的大学生涯结束了,目前在搜索推荐岗位上继续进阶,近期正好在做类目预测多标签分类的项目,因此把相关的模型记录总结一下,便于后续查阅总结。

一、理论篇:
在我们的场景中,文本数据量比较大,因此直接采用深度学习模型来预测文本类目的多标签,而TextCNN向来以速度快,准确率高著称。TextCNN的核心思想是抓取文本的局部特征:通过不同的卷积核尺寸(确切的说是卷积核高度)来提取文本的N-gram信息,然后通过最大池化操作来突出各个卷积操作提取的最关键信息(颇有一番Attention的味道),拼接后通过全连接层对特征进行组合,最后通过交叉熵损失函数来训练模型。

模型的第一层就是Embedding层,预训练的词嵌入可以利用其它语料库得到更多的先验知识,经过模型训练后能够抓住与当前任务最相关的文本特征。在我们的应用场景中,使用预训练的Embedding比随机初始化带来的效果不是特别显著。第二层为卷积层,CV中常见的卷积尺寸通常是正方形,而本文的卷积尺寸与之不同,本文的卷积宽度等于文本Embedding后的维度,保持不变,因为每个词或字相当于一个最小的单元,不可进一步分割。而卷积核的高度可以自定义,在向下滑动的过程中,通过定义不同的窗口来提取不同的特征向量,有点类似于N-gram过程。这样不同的kernel可以获取不同范围内词的关系,获得的是纵向的差异信息,也就是在一个句子中不同范围的词出现会带来什么信息。比如可以使用3,4,5个词数分别作为卷积核的大小),每个卷积尺寸下又有多个相同的卷积核(原因是卷积神经网络学习的是卷积核中的参数,每个filter都有自己的关注点,它们可以从同一个窗口学习相互之间互补的特征,这样多个卷积核就能学习到不同的信息,比如可以设置size为3的filter有4个卷积核)。

第三层是最大池化层,即为从每个滑动窗口产生的特征向量中筛选出一个最大的特征,然后将这些特征拼接起来构成向量表示。也可以选用K-Max池化(选出每个特征向量中最大的K个特征),或者平均池化(将特征向量中的每一维取平均)等,达到的效果都是将不同长度的句子通过池化得到一个定长的向量表示。在短文本分类场景中,每条文本中都会有一些对分类无用的信息,而最大池化可以突出最重要的关键词以帮助模型更容易找到对应的类目。

接下来的几层就跟具体的任务相关了,一般都会拼接特征,在通过全连接层自由组合提取出来的特征实现分类。在损失函数上,二分类和多标签分类可以采用基于Sigmoid函数的交叉熵损失函数binary_crossentropy多分类任务可以采用基于Softmax的多类别交叉熵损失函数(categorical_crossentropy)。

二、代码部分:

def textcnn(hyper_parameters):input = Input(shape=(hyper_parameters.max_len,))if hyper_parameters.embedding_matrix is None:embedding = Embedding(input_dim=hyper_parameters.vocab + 1,output_dim=hyper_parameters.emd_dim,input_length=hyper_parameters.MAX_LEN,trainable=True)(input)else:  # 使用预训练矩阵初始化Embeddingembedding = Embedding(input_dim=hyper_parameters.vocab + 1,output_dim=hyper_parameters.emd_dim,weights=[hyper_parameters.embedding_matrix],input_length=hyper_parameters.MAX_LEN,trainable=False)(input)convs = []for kernel_size in hyper_parameters.kernel_size:conv = Conv1D(hyper_parameters.conv_code, kernel_size,activation=hyper_parameters.relu)(embedding)pool = MaxPooling1D()(conv)convs.append(pool)concat = Concatenate()(convs)flattern = Flatten()(concat)dropout = Dropout(hyper_parameters.dropout)(flattern)output = Dense(hyper_parameters.classes, activation=hyper_parameters.sigmoid)(dropout)model = Model(input, output)model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])return model

在Embedding部分,如果有条件可以使用自己预训练的文本信息来初始化Embedding矩阵,效果可能会比随机初始化Embedding提升一点。

三、几点思考:

1.TextCNN能用于文本分类的主要原因是什么?

除了预训练文本外,TextCNN通过利用不同的卷积核尺寸并行提取文本的信息(类似N-gram),并通过最大池化来突出最重要的关键词来实现分类。

2.TextCNN的缺点:

2.1. TextCNN的卷积和池化操作会丢失文本序列中的词汇顺序和位置信息等内容,但也可利用这一点来增强文本,例如白色旅游鞋,可以添加旅游鞋白色数据等,分词后白色和旅游鞋位置就可以互换来丰富语料 。

2.2. 在长文本使用TextCNN效果可能没有在短文本中效果好(具体还需要实践确认),原因同上,因此可以尝试使用Top k池化替代最大池化提取更多的文本信息。

https://arxiv.org/pdf/1408.5882.pdf​arxiv.org小占同学:深入TextCNN(一)详述CNN及TextCNN原理​zhuanlan.zhihu.com

文本分类模型_文本分类模型之TextCNN相关推荐

  1. python 多分类情感_文本情感分类(一):传统模型

    前言:四五月份的时候,我参加了两个数据挖掘相关的竞赛,分别是物电学院举办的"亮剑杯",以及第三届 "泰迪杯"全国大学生数据挖掘竞赛.很碰巧的是,两个比赛中,都有 ...

  2. textcnn文本词向量_文本分类模型之TextCNN

    六年的大学生涯结束了,目前在搜索推荐岗位上继续进阶,近期正好在做类目预测多标签分类的项目,因此把相关的模型记录总结一下,便于后续查阅总结. 一.理论篇: 在我们的场景中,文本数据量比较大,因此直接采用 ...

  3. 人口预测和阻尼-增长模型_使用分类模型预测利率-第2部分

    人口预测和阻尼-增长模型 We are back! This post is a continuation of the series "Predicting Interest Rate w ...

  4. python分类流程_文本分类指南:你真的要错过 Python 吗?

    雷锋网按:本文为雷锋字幕组编译的技术博客,原标题 A Comprehensive Guide to Understand and Implement Text Classification in Py ...

  5. lstm模型_基于LSTM模型的学生反馈文本学业情绪识别方法

    | 全文共8155字,建议阅读时长8分钟 | 本文由<开放教育研究>授权发布 作者:冯翔 邱龙辉 郭晓然 摘要 分析学生学习过程产生的反馈文本,是发现其学业情绪的重要方式.传统的学业情绪 ...

  6. svm多分类代码_监督学习——分类算法I

    本文是监督学习分类算法的第一部分,简单介绍对样本进行分类的算法,包括 判别分析(DA) 支持向量机(SVM) 随机梯度下降分类(SGD) K近邻分类(KNN) 朴素贝叶斯分类(NaiveBayes) ...

  7. 决策树模型 朴素贝叶斯模型_有关决策树模型的概述

    决策树模型 朴素贝叶斯模型 Decision Trees are one of the highly interpretable models and can perform both classif ...

  8. Scikit-learn_分类算法_支持向量机分类

    一.描述 支持向量机的基本原理是找到一个将所有数据样本分隔成两部分的超平面,使所有样本到这个超平面的累计距离最短. 超平面是指n维线性空间中维度为n-1的子空间.例如,在二维平面中,一维的直线可以将二 ...

  9. 时间序列分类算法_时间序列分类算法简介

    时间序列分类算法 A common task for time series machine learning is classification. Given a set of time serie ...

最新文章

  1. 计算机协会成立活动简报,“中国计算机学会CCF走进高校”活动在我校举行
  2. 这次被问懵了!搞定了这些SQL优化技巧,下次横着走
  3. ML之Clustering之普聚类算法:普聚类算法的相关论文、主要思路、关键步骤、代码实现等相关配图之详细攻略
  4. 锻炼能降低13种癌症风险
  5. SQL语言之DQL语言学习(九)多表查询/链接查询 SQL99学习
  6. 出现了奇数次的数字的算法
  7. FewRel 2.0数据集:以近知远,以一知万,少次学习新挑战
  8. VTK:Points之PointOccupancy
  9. Radar Installation
  10. 线性代数四之动态DP(广义矩阵加速)——Can you answer these queries III,保卫王国
  11. 使用supervisord 来守护 nginx进程
  12. Flutter教程app
  13. Storm实验 -- 单词计数4
  14. 云服务器网站301重定向跳转有什么作用?
  15. CSS3实现轮播图效果
  16. (1)python基础语法
  17. psp2000 M33 自制固件---恢复模式说明(基本所有版本都适用)
  18. windows xp虚拟机安装教程
  19. linux img工具,线刷包img提取工具(simg2img win)
  20. mysql被删库如何恢复_mysql整个数据库被删除了怎么恢复

热门文章

  1. c语言总是说有一个错误,我的电脑上的c语言为何老有一个错误
  2. 单纯形 c语言 程序,单纯形法完全c语言程序
  3. oracle如何查询系统变量数据,Oracle如何对IN子查询使用绑定变量(转)
  4. Python中异常处理的用法
  5. python新手遇到的5大坑
  6. Python-list中的append()和extend()方法区别
  7. Python字符串删除指定符号(不限位置)
  8. python flask解决上传下载的问题
  9. Intel Realsense D435 奇怪的现象记录:帧卡住,但wait_for_frame()不报错
  10. python 文件操作 os.path.join(path, *paths) 路径合成(追加)