改进后的MCWD算法,让你的弱标注多标签数据赢在起跑线上

  • 前言
  • MCWD算法
    • 算法展示
      • 算法改进
      • 实现代码
  • 实验结果
  • 总结

前言

最近刷完了李航老师的《统计学习与方法》,手痒到又想复现几个算法,正好碰上在云音乐的云村视频标签运维标注不完全问题,也算是弱标注数据吧,之前这比数据作了多标签分类,尽管特征上线后各项数据都有所提升,但总感觉用神经网络直接对弱标签数据进行多标签分类很不舒服。
基于以下两个思考点:

  1. 存在标签缺失的问题,神经网络的意识在于我竟可能相信你给我的数据都是准确的,某个样本有某个标签是准确的,没有某个标签也是准确的。这会导致对于一些有缺失的标签,尽管网络见过相似的特征,但拟合的标签组却大不相同,无法有效学习到标签与特征的关系。
  2. 当标签缺失情况较少,而样本数据量足够多时,神经网络确实能比较好的应付脏数据带来的错误信息,但是对于那些本来就很少被标注,而且缺失情况较多的标签,最后模型在推理时,该类标签的产出也会非常稀有,预测不准确。尽管我们在最后加入了先验:即有效标注越少的标签越有可能被遗漏,对每个标签结果都乘以其IDF来纠正,但效果不是非常明显。这就导致最后输出的样本标签区分度较低,作为特征效果较差。
  3. 比较明显的问题:无法考虑标签之间的关联性,在MLOG方面,举例来说:风景标签与拍摄标签具有较强的关联性,女团标签与舞蹈、演唱等关联性较强,而神经网络无法利用这一信息。这个问题对于文本多标签分类任务已经有了一些解决方案,详情见论文:A Deep Reinforced Sequence-to-Set Model for Multi-Label Text Classification

MCWD算法

其实在寻找办法的过程中也看到了周志华团队的《Learning from Semi-Supervised Weak-Label Data》无奈其中数学公式复杂,尽管看懂了原理,但用python实现起来还是有一定的难度。

最后将目标锁定在了《针对弱标记数据的多标签分类算法(王晶晶,杨有龙)》(原文请大家自行搜索)文章利用了样本间的加权KNN + 标签相关性的二阶策略 对弱标注数据的标签矩阵进行恢复,再通过多标签分类算法进行标注。

算法展示

原论文中算法具体步骤如下:

算法改进

  1. 引入先验,适当增加随机的不相关标签数量: 在算法初始化第二步中,论文对于初始标签矩阵 C 中的每个标签 j ,在 C中随机选择 pj 个 Cij = 0 的实例,同时将选定实例的 Cij 值由原来的0变为−1,其中 pj 是每个标签 j 中所有相关标签的总数目。目的是为了,在训练数据的标签信息中添加一些不相关的标签信息,从而将缺失标签和不相关标签进行有效区分。但在复现过程中,发现对于稀疏的标签,往往由于初始的不相关标签信息不足,容易导致在迭代过程中过分补全。根据上述即有效标注越少的标签其本身先验概率也越低,我将该Pj值调整为—不包含该标签的样本数量 / n ( n为超参数,在实验过程中以 4 为佳)
  2. 在整体迭代结束后根据权重矩阵Wij来有选择的对Cij进行恢复,而是不论文中直接取sign(Cij),Wij * Cij 值高说明该标签越置信,而对于 Wij*Cij 值较低的,倾向于重置为0,让之后的标签相关矩阵去学习。较好的抑制算法过拟合。此处加入 阈值w(超参)
  3. 以及部分对原论文不合理之处的修改,详情见代码。

实现代码

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score,hamming_lossdef data_get_2(path,p):data = pd.read_csv(path)feature = np.array(data.iloc[:,:103])tag = np.array(data.iloc[:,103:])real_tag = tag.copy()for i in range(len(tag)):if len(np.where(tag[i,:] == 1)[0]) > p:index_list = np.random.choice(np.where(tag[i,:] == 1)[0],len(np.where(tag[i, :] == p)[0])-p,replace=False)tag[i,:][index_list] = 0'''随机drop掉每个样本的标签,至每个样本最多拥有p个标签。'''return feature,tag,real_tagdef data_get_3(path,p):data = pd.read_csv(path)feature = np.array(data.iloc[:,1:1187])tag = np.array(data.iloc[:,1187:])real_tag = tag.copy()for i in range(len(tag)):if len(np.where(tag[i,:] == 1)[0]) > p:index_list = np.random.choice(np.where(tag[i,:] == 1)[0],len(np.where(tag[i, :] == p)[0])-p,replace=False)tag[i,:][index_list] = 0return feature,tag,real_tagclass MCWD:def __init__(self,e=0.8,c=0.2,k_t=10,s=1,w_e=0.3,rate=0.5):self.e = eself.c = cself.k_t = k_tself.s = sself.w_e = w_eself.rate = ratedef standetlize(self):'''涉及到距离计算,需要将特征标准化'''for i in range(self.feature_dim):mean = np.mean(self.feature[:,i])sii = 1/(self.sample_num-1) * np.sum((self.feature[:,i] - mean)**2)self.feature[:,i] = (self.feature[:,i]-mean)/siidef add_nagtive_tag(self):'''训练数据的标签信息中添加一些不相关的标签信息'''for j in range(self.tag_c):zeros_way = np.where(self.old_tag_list[:,j] == 0)[0]pj = int(len(zeros_way) * self.rate)if len(zeros_way) < pj:index_list = zeros_wayelse:index_list = np.random.choice(np.where(self.old_tag_list[:,j] == 0)[0],pj,replace=False)self.tag_list[:,j][index_list] = -1def cauclate_dis(self,x1,x2):return np.sum(abs(x1-x2)**2)**(1/2)def get_dis_matrix(self):# self.dis_matrix = np.zeros((self.sample_num,self.sample_num))# for i in range(self.sample_num):#     print(i)#     for j in range(self.sample_num):#         if i < j:#             dis = self.cauclate_dis(self.feature[i],self.feature[j])#             self.dis_matrix[i][j] = dis#             self.dis_matrix[j][i] = dis'''快速计算样本之间的距离矩阵'''G = np.dot(self.feature, self.feature.T)H = np.tile(np.diag(G), (self.sample_num, 1))  # n rows, 1 for each rowD = H + H.T - G * 2self.dis_matrix = np.sqrt(D)def one_iter(self,t):n_near = t * self.k_tfor index in range(self.sample_num):k_near_index = self.dis_matrix[index].argsort()[1:n_near+1]l_w = np.abs(self.tag_W[k_near_index])t_l = self.tag_list[k_near_index]self.tag_list[index] = np.sum(l_w * t_l,0)/(np.sum(l_w,0)+1e-8) #[-1,1]for index in range(self.sample_num):for j in range(self.tag_c):qij = self.tag_list[index,j]wij = self.tag_W[index,j]if np.sign(qij) == np.sign(wij) and np.abs(qij) > self.e and np.abs(wij) > self.e:self.tag_W[index,j] = np.sign(qij)elif np.sign(qij) == -1 and np.sign(self.old_tag_list[index,j]) == 1:self.tag_W[index, j] = self.c * (qij - np.min(self.tag_list[:,j]))/(np.max(self.tag_list[:,j]) - np.min(self.tag_list[:,j]))else:self.tag_W[index, j] = qijif  qij > 0:self.tag_list[index, j] = 1elif qij < 0:self.tag_list[index, j] = -1else:self.tag_list[index, j] = 0self.tag_list[np.where(self.old_tag_list == 1)] = 1def compute(self):'''迭代结束后,根据W和C矩阵对标签矩阵进行恢复和补充'''for index in range(self.sample_num):for j in range(self.tag_c):if self.tag_list[index, j] * self.tag_W[index, j] > self.w_e:self.tag_list[index, j] = np.sign(self.tag_list[index, j])else:self.tag_list[index, j] = 0self.tag_list[np.where(self.old_tag_list == 1)] = 1def get_L_matricx(self):'''计算标签相关矩阵'''self.L_matricx = np.zeros((self.tag_c,self.tag_c))for i in range(self.tag_c):for j in range(self.tag_c):if i >= j:a = np.sum((self.tag_list[:,i] == 1)*(self.tag_list[:,i] == 1)) + self.sb = np.sum(self.tag_list[:,i] == 1) + 2 * self.sc = np.sum(self.tag_list[:,j] == 1) + 2 * self.sself.L_matricx[i][j] = a/bself.L_matricx[j][i] = a/cdef re_20_p(self):'''根据标签相关矩阵补全剩余标签'''self.get_L_matricx()for i,j in zip(np.where(self.tag_list == 0)[0],np.where(self.tag_list == 0)[1]):qij = self.tag_list[i].T.dot(self.L_matricx[:,j])qij = (qij - np.min(self.tag_list[:, j])) / (np.max(self.tag_list[:, j]) - np.min(self.tag_list[:, j]))if qij > 0.5:self.tag_list[i][j] = 1else:self.tag_list[i][j] = -1def fit(self,feature,tag_list,max_iter=100,if_s=True):self.feature = np.array(feature,dtype=float)self.tag_list = np.array(tag_list,dtype=float)self.old_tag_list = np.array(tag_list)self.feature_dim = self.feature.shape[1]self.sample_num = self.feature.shape[0]self.tag_c = self.tag_list.shape[1]if if_s:print('standetlizing')self.standetlize()print('finish')print('add_nagtive_tag')self.add_nagtive_tag()print('finish')self.tag_W = self.tag_list.copy()print('get_dis_matrix')self.get_dis_matrix()print('finish')for iter in range(max_iter):print(iter)self.one_iter(iter+1)num = len(np.where((self.tag_list * self.tag_W) <= self.w_e)[0])print(num/(self.sample_num * self.tag_c))if num < self.sample_num * self.tag_c * 0.2:self.compute()breakself.re_20_p()def main():# ------Yeast数据集--------------------------------------------------feature,tag,real_tag = data_get_2('./yeast.csv',1)mcwd = MCWD(e=0.8,c=0.2,k_t=10,s=1,w_e=0.5,rate=0.25)mcwd.fit(feature,tag)pred_tag = mcwd.tag_listpred_tag[np.where(pred_tag == -1)] = 0print(np.sum(real_tag) - np.sum(tag))print(np.sum(pred_tag[np.where(real_tag==1)]) - np.sum(tag[np.where(real_tag==1)]))print(np.sum(pred_tag[np.where(real_tag==0)]))print("f1_macro")print(f1_score(real_tag,tag, average='macro'))print(f1_score(real_tag,pred_tag,average='macro'))print("f1_micro")print(f1_score(real_tag,tag, average='micro'))print(f1_score(real_tag,pred_tag,average='micro'))print("hamming_loss")print(hamming_loss(real_tag,tag))print(hamming_loss(real_tag,pred_tag))#------genbase数据集--------------------------------------------------feature, tag, real_tag = data_get_3('./file27546ea66c1.csv', 1)mcwd = MCWD(e=0.8, c=0.2, k_t=5, s=1, w_e=0.5, rate=0.25)mcwd.fit(feature,tag,if_s=False)pred_tag = mcwd.tag_listpred_tag[np.where(pred_tag == -1)] = 0print(np.sum(real_tag) - np.sum(tag))print(np.sum(pred_tag[np.where(real_tag == 1)]) - np.sum(tag[np.where(real_tag == 1)]))print(np.sum(pred_tag[np.where(real_tag == 0)]))print("f1_macro")print(f1_score(real_tag,tag, average='macro'))print(f1_score(real_tag,pred_tag,average='macro'))print("f1_micro")print(f1_score(real_tag,tag, average='micro'))print(f1_score(real_tag,pred_tag,average='micro'))print("hamming_loss")print(hamming_loss(real_tag,tag))print(hamming_loss(real_tag,pred_tag))if __name__ == '__main__':main()

实验结果


总结

根据实验结果可以发现,经过算法的补全,与补全前的标签矩阵比,F1值、hamming_loss等指标都有不同程度提升,证明了该算法的有效性。
特别的:

  1. 在Genbase数据集上,算法补全了近2/3的标签,而仅产生了3个错误标签,表现惊人,且各项评估指标都有非常大的提升
  2. 对于标签缺失严重的数据,该算法能控制错误生成在可以接受的程度内,尽可能的补全标签。
  3. 但必须要考虑到错误生成带来的代价,有可能会影响之后的模型训练。但对于像短视频等内容,其标签多而杂,一个短视频所带的标签本没有一个真正准确的标准,因此,这种在内容相关性的逻辑下导致的错误生成,对于问题的解决其实并没有那么严重。
  4. 该算法改进,纯属本人摸索而出,能否应用于工作实践,还请各位斟酌尝试。

【论文复现与改进】针对弱标注数据多标签矩阵恢复问题,改进后的MCWD算法,让你的弱标注多标签数据赢在起跑线上相关推荐

  1. 聊聊数据中心备份和恢复解决方案厂商和产品(附Gartner报告)

    聊聊数据中心备份和恢复解决方案厂商和产品(附Gartner报告) https://www.toutiao.com/i6451888861759930894/?tt_from=weixin&ut ...

  2. 求解答:苹果手机恢复出厂设置后还能还原数据吗?

    案例:手机不小心恢复出厂设置了怎么还原? [友友们,iPhone手机不小心恢复了出厂设置,里面的数据还可以还原吗?要哭了,求帮助!] 随着苹果手机的普及,越来越多的人开始使用苹果手机.但是,有时候我们 ...

  3. 怎样把电脑恢复出厂设置_数据蛙:苹果恢复出厂设置,彻底释放手机内存

    点击蓝字关注我们 苹果手机有时候用久了出现卡顿的现象,这个时候我们可以通过苹果恢复出厂设置,以让手机彻底释放内存.恢复出厂设置在哪里?恢复出厂设置后会怎样?如果您还不知道的话,在本文中数据蛙就告诉您答 ...

  4. 经典论文复现 | 基于标注策略的实体和关系联合抽取

    过去几年发表于各大 AI 顶会论文提出的 400 多种算法中,公开算法代码的仅占 6%,其中三分之一的论文作者分享了测试数据,约 54% 的分享包含"伪代码".这是今年 AAAI ...

  5. 26篇计量经济经典论文复现数据和Stata或R代码

    26篇文章的复现数据.Stata或R复制程序.各位学者可以阅读这些文章,并根据Stata和R代码对原文中的图表进行一一复制,只有这样才能成长更快. 以其中一篇文章为例,包含了以下内容: [26篇论文目 ...

  6. 【项目总结】论文复现与改进:一般选择模型的产品组合优化算法(Research@收益管理)

    论文标题:Assortment Optimization Under General Choice 中文标题:一般选择下的产品组合优化 论文下载链接:SSRN 前情提要 本文是基于笔者之前对上述论文做 ...

  7. 经典论文复现 | LSGAN:最小二乘生成对抗网络

    来源:PaperWeekly 本文约2500字,建议阅读10分钟. 本文介绍了经典AI论文--LSGAN,它比常规GAN更稳定,比WGAN收敛更迅速. 笔者这次选择复现的是 Least Squares ...

  8. 《融合视觉显著性和局部熵的红外弱小目标检测》论文复现

    1.复现论文概要 复现的论文为<融合视觉显著性和局部熵的红外弱小目标检测>(赵鹏鹏,李庶中等,中国光学2022,http://www.chineseoptics.net.cn/cn/art ...

  9. 信息抽取(四)【NLP论文复现】Multi-head Selection和Deep Biaffine Attention在关系抽取中的实现和效果

    Multi-head Selection和Deep Biaffine Attention在关系抽取中的应用 前言 Multi-head Selection 一.Joint entity recogni ...

最新文章

  1. 执行eclipse,迅速failed to create the java virtual machine。
  2. [leetcode-JavaScript]---23、合并K个排序链表
  3. 用 DomIt! XML 处理工作
  4. hdu 1524 A Chess Game
  5. mysql连接查询优点_1105 ROM优缺点,MySQL连接类,查插更方法
  6. GitHub 热点速览:不可思议的浏览器 Browser-2020 周涨 Star 超 3 千
  7. 关于邮件服务器应用系统安全SSL ×××(强身份认证)方案
  8. softmax回归的从零开始实现-09-p4
  9. socket编程:简单的TCP客户端
  10. HmacSHA256算法实现消息认证
  11. 基于VHDL的vivado2017.4使用教程
  12. MIKE水动力笔记1_岸线及水深数据之依靠全球数据库资源提取的方法
  13. linux qt触摸屏配置,QT触摸屏的实现
  14. 如何使用MATLAB绘制平滑曲线
  15. 阿里巴巴(容器镜像服务)docker+springboot实践
  16. IT人才异军突起 有招网引领业界精英
  17. IPV6地址数据库导出
  18. TBC声望 恢复萨满 炼金 宏 附魔300-375
  19. BOC保护的色氨酸锌卟啉(Zn·TAPP-Trp-BOC)/铜卟啉(Cu·TAPP-Trp-BOC)/钴卟啉(钴·TAPP-Trp-BOC)/铁卟啉(Fe·TAPP-Trp-BOC)/齐岳供应
  20. ACM进阶大一到大三

热门文章

  1. spring4.2更好的应用事件
  2. c3p0获取连接Connection后的Close()---释疑
  3. NCspider项目总结
  4. psd页面切割成html技巧总结
  5. 每天看一片代码系列(二):WebSocket-Node
  6. 博客美化20150418
  7. Modulus 正式开放 —— Node.js 应用托管平台
  8. DNN模块开发入门指导
  9. Mime类型与文件后缀对照表及探测文件MIME的方法
  10. 华测服务器进不去系统,华测rtk单点到固定怎么操作步骤