一、写在前面的话

最近项目需要做一个针对内容的打标签系统,这里的内容是CSDN网站上面用户创作的内容,例如,博客、问答等,打上CSDN统一标签之后有利于对内容的归类和检索,即知识的结构化。

CSDN统一标签目前大概有400-500个,有大类和小类两个层级,对于python这个大类来说,下面的小类有:python,list,django,virtualenv,tornado,flask等标签。大家都知道每个博客的标签数量是不固定的,有可能是1个也有可能是多个,所以这里是一个多标签分类的场景。

博客数据拥有用户自己打的标签,有一些还是挺准确的,训练数据的获取还是比较容易的。模型也比较简单,下面直接开始介绍吧。

二、模型部分

2.1 框架选择

工业界的话选择tensorflow,开源,其生态比较完善,不论模型的互转,还是推理引擎的支持,其都有相应的项目支持,并且tensorflow2.0也开始支持动态图,习惯用pytorch做模型的小伙伴也不妨尝试一下。

2.2 模型搭建

这里选择在textcnn的基础之上进行改进,不熟悉textcnn的可以先自行百度一下,至于为什么选择textcnn,也是因为在众多可选的分类器中,textcnn应该是性价比最高的一款了,一般来说,其效果好于机器学习分类算法,例如svm,但是比bert等预训练模型又差一些。同时考虑到工程应用要考虑到推理速度,硬件成本等,textcnn就成了首选。

首先来看一下textcnn的基本结构,这里借用论文A Sensitivity Analysis of (and Practitioners’ Guide to) Convolutional Neural Networks for Sentence Classification的结构图来进行说明:

这个图其实已经非常好理解了,前面的几层就不多说,主要是enbedding层和卷积层部分,直接看改造部分,原本的textcnn在全连接层之后经过softmax就可以得到每个类别的概率,取概率最高的一个的话就是一个单标签多分类器,现在让全连接层输出num_class*2,然后reshape成batch_size*num_class*2,即让每个类别单独做一个二分类,计算每个二分类loss。最后两层改造之后的结构为:

这里直接使用tensorflow2的动态图(sub-class model)尝试搭建模型,配置如下:

class BlogTagClassifyConfig(object):# Data loading paramsfile_train_set = "./data/pro/datasets/tags/blog/recommend/train.txt"file_dev_set = "./data/pro/datasets/tags/blog/recommend/dev.txt"out_dir = "./data/pro/models/tag/"# Model Hyperparametersvocab_size = 0embedding_dim = 300dropout_rate = 0.6num_classes = 50regularizers_lambda = 0.2filter_sizes = "2,3,4"num_filters = 128seq_length = 120learning_rate = 3e-4# Training parametersbatch_size = 256num_epochs = 100evaluate_every = 100print_every = 50# 阈值threshold = 0.8

模型:

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@Time    :   2021/07/07
@Author  :   clong
@Descript:   博客标签分类模型
'''import tensorflow as tf
from tensorflow import kerasclass TextCNN(tf.keras.Model):"""TextCNN模型"""def __init__(self, config):super(TextCNN, self).__init__()self.config = configself.embedding = tf.keras.layers.Embedding(self.config.vocab_size+1, self.config.embedding_dim,mask_zero=True,input_length=self.config.seq_length,name='embedding')self.add_channel = tf.keras.layers.Reshape((self.config.seq_length, self.config.embedding_dim, 1), name='add_channel')self.conv_pool = self.build_conv_pool()self.dropout = tf.keras.layers.Dropout(self.config.dropout_rate, name='dropout')self.dense = tf.keras.layers.Dense(self.config.num_classes*2,kernel_regularizer=tf.keras.regularizers.l2(self.config.regularizers_lambda),bias_regularizer=tf.keras.regularizers.l2(self.config.regularizers_lambda),name='dense')self.flatten = tf.keras.layers.Flatten(data_format='channels_last', name='flatten')self.reshape = tf.keras.layers.Reshape((self.config.num_classes, 2), name='reshape')self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.config.learning_rate)def build_conv_pool(self):def conv_pool(embed):pool_outputs = []for filter_size in list(map(int, self.config.filter_sizes.split(','))):filter_shape = (self.config.filter_size, self.config.embedding_dim)conv = keras.layers.Conv2D(self.config.num_filters, filter_shape, strides=(1, 1), padding='valid',data_format='channels_last', activation='relu',kernel_initializer='glorot_normal',bias_initializer=keras.initializers.constant(0.1),name='convolution_{:d}'.format(filter_size))(embed)max_pool_shape = (self.config.seq_length - filter_size + 1, 1)pool = keras.layers.MaxPool2D(pool_size=max_pool_shape,strides=(1, 1), padding='valid',data_format='channels_last',name='max_pooling_{:d}'.format(filter_size))(conv)pool_outputs.append(pool)return pool_outputsreturn conv_pool@tf.functiondef call(self, x, training=None):x = self.embedding(x)x = self.add_channel(x)x = self.conv_pool(x)x = tf.keras.layers.concatenate(x, axis=-1, name='concatenate')x = self.flatten(x)x = self.dropout(x, training)x = self.dense(x)x = self.reshape(x)x = tf.nn.softmax(x, axis=-1, name="softmax")return x

三、数据部分

3.1 数据获取

数据采用博客数据,由于数据的不平衡性,现在取50个类别进行验证,每个类别取2000条数据。

3.2 词典构造

如果有领域词典的话,最好将领域词典和标签加入分词自定义词典,对分词后的博客数据统计词频,按照词频进行排序,这里取词频大于10的60000个词。为了防止漏掉标签词特征,可以将标签也加入到词典中。

这里也可以采用两外两种方式进行尝试:分字,采用字符级的词典,这样词典就会比较小,但是可能训练收敛的速度就会慢一些。另外一种就是特征选择算法挑选特征构建词典,例如卡方验证,信息增益等等(计算量大)。

对于词典中未登录的词(unknown words),网上一般做法都是在词典添加[UNK]词或者随机选择一个未知词,我们的场景只是一个分类模型,不会涉及太多的语义,直接去掉未登录的词,做简化处理。

四、附加策略

由于数据的原因,分类器也不可能百分百达正确,有些时候,也有一些漏掉的情况,例如标题出现了标签的名称,但是分类器却没有打上相关的标签,由于IT行业词汇没有什么歧义,这里可以利用标题和用户自定义标签,对它们进行分词,如果这里面出现了标签或者标签的同义词,则打上相关标签。

五、写在最后

多标签分类器其实还要涉及到分类结果的评估,也可以使用编辑距离等来计算相似度,但我这里更加重视模型打上标签的情况,故没采用通用的方法进行评估。

从结果来看,打标签的效果还是可以的,高阈值的标签都是有着很高的相关性。

================2021/12/13================

增加了demo代码:行走的人偶 / textcnn_demo · GIT CODE

基于博客标签的多标签分类器(multi-label classification)相关推荐

  1. 基于博客系统的访客日志记录

    当我们做的一些应用需要记录一些接口被访问时用户的信息时,我们就需要用到一些记录请求的技术,并记录日志到数据库.本文章使用的方法:注解+AOP 原理:事先在数据库中建立一个记录访客日志的一张表.先自定义 ...

  2. (附源码)计算机毕业设计ssm基于博客系统的UI手机界面展示

    项目运行 环境配置: Jdk1.8 + Tomcat7.0 + Mysql + HBuilderX(Webstorm也行)+ Eclispe(IntelliJ IDEA,Eclispe,MyEclis ...

  3. 基于博客系统的访客日志记录----代码合集

    本文章是基于我的另一篇博客所写的相关代码,如果还没看过的可以先看看我这篇文章: https://blog.csdn.net/qq_56769991/article/details/123915587 ...

  4. iOS_CNBlog项目开发 (基于博客园api开发) 上篇

    按照惯例, 先上效果图 前言 做这个项目是因为刚好在逛博客园的时候看到一篇文章 博客园第三方客户端-i博客园正式发布App Store, 这里就帮忙贴下链接吧. 整个项目做下来大概做了半个月, 今天算 ...

  5. github的博客搭建以及标签的自动化

    github博客搭建以及标签的自动化 引子 没有github的程序员,不是好程序员! BUT 如果有一个*.github.io的blog,会不会更酷? 基于以上原因,本拐也折腾了一下自己的github ...

  6. Django开发个人博客网站——12、实现不同大小的标签云样式

    1.创建标签页面 与上一节中创建归档页面一样,这里就不再赘述了,直接给出程序代码. tags.html {% extends 'base.html' %}{% block title %} 标签云 { ...

  7. 基于SpringBoot的个人博客系统

    项目编号:BS-PT-042 该博客是基于SpringBoot + Mybatis + Thymeleaf 等技术实现的 Java 博客系统,页面美观.功能齐全.部署简单及完善的代码,做为毕设项目的话 ...

  8. 【java毕业设计】基于javaEE+SSM+MySql的个人博客系统设计与实现(毕业论文+程序源码)——个人博客系统

    基于javaEE+SSM+MySql的个人博客系统设计与实现(毕业论文+程序源码) 大家好,今天给大家介绍基于javaEE+SSM+MySql的个人博客系统设计与实现,文章末尾附有本毕业设计的论文和源 ...

  9. 基于ssm的个人博客系统的设计与实现(含源文件)

    欢迎添加微信互相交流学习哦! 项目源码:https://gitee.com/oklongmm/biye 进入二十一世纪,以Internet为核心的现代网络积水和通信技术已经得到了飞速的发展和广泛的应用 ...

最新文章

  1. GitHub 上有什么好玩的项目?(附地址)
  2. js怎么获取访问页数记录(知道的能不能告诉我一下)
  3. Coding and Paper Letter(六)
  4. 阿里面试题:使用dubbo过程中遇到过哪些坑?
  5. 嵌入式论文3000字_SCI英文论文一般多少字
  6. eclipse中对单独JS文件取消报错的处理
  7. 设计模式学习笔记——策略(Strategy)模式
  8. html下拉框传递参数,HTML通过事件传递参数到js详解及实例
  9. scala的静态属性和静态方法
  10. pygame模块_pygame模块方法和事件
  11. iOS Crash常规跟踪方法及Bugly集成运用
  12. 微信小程序获取openId,传参问题导致参数无效(errcode: 40013、errcode:40125、errcode“:40029)
  13. eslint 如何关闭检查变量名规范或者大小写检查
  14. 自动删除QQ空间指定好友的留言
  15. Linux的 ls 和 ll 的使用发放、基本区别
  16. Android开源实战:手把手教你实现一个简单 好用的搜索框(含历史搜索记录
  17. 完美解决composer提示输入用户名和密码
  18. vue安装及创建运行
  19. IDEA设置todo快捷键
  20. Android 11 中文件存储(FileNotFoundException open failed: EPERM (Operation not permitted))

热门文章

  1. css 各个方向渐变(从左到右、从上到下、从左上角到右下角)
  2. 用PHP访问JasperReport
  3. 最全超实用的网站SEO优化方案步骤解析
  4. 什么是超视频时代的用户体验法则?
  5. 读书笔记-《赢在用户:Web人物角色创建和应用实践指南》
  6. SpringMVC的核心架构示意图<搬代码>
  7. (转)很暧昧的话 最暧昧的话 男女间那些玩火暧昧话
  8. 写在前面的一些话:《Learning OpenCV》中文版 .
  9. 5.2 node实现简单登录功能
  10. 2019最新android实例开发视频教程