Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估

  • 前言
  • 二分类 focal loss
  • 多分类 focal loss
  • 测试结果
    • 二分类focal_loss结果
    • 多分类focal_loss结果
  • 总结

前言

最近看了focal loss的文章,正好在做文本分类的项目,一个是Sentence Bert句子匹配,一个是网易云音乐评论的情绪分类。本人用的框架是tensorflow2.0,所以想尝试实践一下focal loss,但是翻遍了网上的文章,不是代码报错就是错误实现。最后就自己根据focal loss的公式写了一个,试跑了代码确认无误。

tensorflow :2.0.0(GPU上跑)
transformers :3.1


二分类 focal loss

from tensorflow.python.ops import array_ops
def binary_focal_loss(target_tensor,prediction_tensor, alpha=0.25, gamma=2):zeros = array_ops.zeros_like(prediction_tensor, dtype=prediction_tensor.dtype)target_tensor = tf.cast(target_tensor,prediction_tensor.dtype)pos_p_sub = array_ops.where(target_tensor > zeros, target_tensor - prediction_tensor, zeros)neg_p_sub = array_ops.where(target_tensor > zeros, zeros, prediction_tensor)per_entry_cross_ent = - alpha * (pos_p_sub ** gamma) * tf.math.log(tf.clip_by_value(prediction_tensor, 1e-8, 1.0)) \- (1 - alpha) * (neg_p_sub ** gamma) * tf.math.log(tf.clip_by_value(1.0 - prediction_tensor, 1e-8, 1.0))return tf.math.reduce_sum(per_entry_cross_ent)

使用方法:

model.compile(optimizer=optimizer,loss=binary_focal_loss,metrics=['acc'])

几个注意的点:

  1. tensorflow2.0 自定义损失函数 默认传入的是y_true,y_pred的格式,所以在自定义损失函数的时候必须是这个顺序,否则会报错没有梯度传入来更新参数。
  2. tf2.0 所有的张量运算放在了tf.math.里,所以要用tf.math.log
  3. 这里模型最后一层用的是sigmiod激活,所以在损失函数里不需要再使用tf.nn.sigmoid来转化logits

多分类 focal loss

def softmax_focal_loss(label,pred,class_num=6, gamma=2):label = tf.squeeze(tf.cast(tf.one_hot(tf.cast(label,tf.int32),class_num),pred.dtype)) pred = tf.clip_by_value(pred, 1e-8, 1.0)w1 = tf.math.pow((1.0-pred),gamma)L =  - tf.math.reduce_sum(w1 * label * tf.math.log(pred))return L

使用方法

bert_ner_model.compile(optimizer=optimizer, loss=softmax_focal_loss,metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

几个注意的点:

  1. 与二分类不同这里没有使用a的系数,我也见过带有a调整系数的多分类公式,但是要么是只在最后起到缩放loss的功能,要么是需要人工更具类别比例传入一个N维列表,考虑到人为加入太多超参数,可能会让调试很不稳定,所以这里直接舍弃了a。
  2. 在使用tf.one_hot转换label的时候(因为这边传入的是sparse的label[1,2,4,1,2,3,5])总是提示是float32类型,类型错误,所这里用tf.cast转化成int类型。经过tf.one_hot转换之后是(None,1,N)(N为你的类别数)所以用tf.squeeze压缩成(None,N)的shape,方便后面的乘法运算。
  3. model.compile时,评估函数如果要看acc,一定要指定SparseCategoricalAccuracy,之前用[‘acc’]发现accu一直上不去,可能是架构无法识别当前的y_pred类型,所以最好指定一下。
  4. 这里模型最后一层用的是softmax激活,所以在损失函数里不需要再使用tf.nn.softmax来转化logits

测试结果

二分类focal_loss结果

此处是在Sentence Bert模型上的测试结果

binary_crossentropy 结果

可以看到用tf自带的binary_crossentropy,训练一轮就已经有过拟合的趋势了,该用focal_loss可以很好的抑制模型过拟合且模型效果也有1个多点的提升。


多分类focal_loss结果

此处是在Roberta评论情绪多分类数据上的结果

发现loss会下降的越来越慢,这是正常的,需要训练的轮次也变多,因为这里对loss乘了(1-pred)**gama的系数所以整体更新速度会变慢。

对比下使用sparse_cross_entropy结果:

发现并没有提升,这可能与我的数据集类别分布比较平衡有关。所以focal_loss的使用场景还是要看自己的数据集情况。


总结

focal loss的使用还需要根据自己的数据集情况来判断,当样本不平衡性较强时使用focal loss会有较好的提升,在多分类上使用focal loss得到的效果目前无法很好的评估。


完整的模型代码之后会专门写一个博客来讲,用 tf2.0.0 + transformers 搭一个Sentence Bert也借鉴了很多pytroch的代码,tf实现比较少,也是自己慢慢摸索出来的。

Tensorlfow2.0 二分类和多分类focal loss实现和在文本分类任务效果评估相关推荐

  1. 二分类交叉熵,多分类交叉熵,focal loss

    1:二分类交叉熵 a) 公式:  ,其中表示网络预测结果,是一个属于(0到1)的值,我们当然希望它们的值很接近1.是真实标签,因为是二分类,所以,的值为0或者1.网络最后一层一般为sigmoid.比如 ...

  2. 天池零基础入门NLP竞赛实战:Task4-基于深度学习的文本分类3-基于Bert预训练和微调进行文本分类

    Task4-基于深度学习的文本分类3-基于Bert预训练和微调进行文本分类 因为天池这个比赛的数据集是脱敏的,无法利用其它已经预训练好的模型,所以需要针对这个数据集自己从头预训练一个模型. 我们利用H ...

  3. 【CV】RetinaNet:使用二分类类别不平衡损失 Focal Loss 实现更好的目标检测

    论文名称:Focal Loss for Dense Object Detection 论文下载:https://arxiv.org/abs/1610.02357 论文年份:ICCV 2017 论文被引 ...

  4. python3spark文本分类_如何用Spark深度集成Tensorflow实现文本分类?

    本篇知识点:Tensorflow编程 CNN相关知识 PySpark相关知识 因为例子较为复杂,我们会假设你不但学习了[Tensorflow基础],而且还自己主动扩展了TF相关的知识,并且根据里面的推 ...

  5. 文本分类 决策树 python_NLTK学习笔记(六):利用机器学习进行文本分类

    关于分类文本,有三个问题 怎么识别出文本中用于明显分类的特征 怎么构建自动分类文本的模型 相关的语言知识 按照这个思路,博主进行了艰苦学习(手动捂脸..) 一.监督式分类:建立在训练语料基础上的分类 ...

  6. 【多标签文本分类】融合CNN-SAM与GAT的多标签文本分类模型

    ·阅读摘要:   在本文中,作者基于CNN.Attention.GAT提出CS-GAT模型,在一些通用数据集上,表现良好. ·参考文献:   [1] 融合CNN-SAM与GAT的多标签文本分类模型   ...

  7. 【文本分类】基于BERT预训练模型的灾害推文分类方法、基于BERT和RNN的新闻文本分类对比

    ·阅读摘要: 两篇论文,第一篇发表于<图学学报>,<图学学报>是核心期刊:第二篇发表于<北京印刷学院学报>,<北京印刷学院学报>没有任何标签. ·参考文 ...

  8. html文本分类输出,构建中文网页分类器对网页进行文本分类

    网络原指用一个巨大的虚拟画面,把所有东西连接起来,也可以作为动词使用.在计算机领域中,网络就是用物理链路将各个孤立的工作站或主机相连在一起,组成数据链路,从而达到资源共享和通信的目的.凡将地理位置不同 ...

  9. python文本分类_手把手教你在Python中实现文本分类.pdf

    手把手教你在Python 中实现文本分类(附代码.数 据集) 引言 文本分类是商业问题中常见的自然语言处理任务,目标是自动将文本文件分到一个 或多个已定义好的类别中.文本分类的一些例子如下: • 分析 ...

最新文章

  1. tomcat在服务器上改了8080的端口之后所带来的问题
  2. TP-GAN 让图像生成再获突破,根据单一侧脸生成正面逼真人脸
  3. xStream完美转换XML、JSON
  4. mysql长连接与短连接
  5. 谈谈“学习”这件事儿
  6. 照猫画虎owin oauth for qq and sina
  7. crmeb java单商户源码java二开文档部署文档H5商城部署文档【5】
  8. docker视频教程 百度云网盘
  9. linux设置ipsan_linux 配置SAN存储-IPSAN
  10. 7-112 约分最简分式
  11. 【图解CDD】利用CANdelaStudio编辑诊断描述CDD文件带你入门到精通
  12. 第一章:J2EE高级软件工程师面试题集
  13. 有没有好用的视频压缩软件?分享几个好用的压缩视频软件
  14. 汉罗塔汉洛塔c++,看不懂ni打我
  15. 笔记本加装固态硬盘,安装Ubuntu
  16. Hie with the Pie
  17. android微信分享长图功能,安卓分享9宫格图片到微信
  18. android分享图片到qq,Android实现截图分享qq,微信
  19. sudo,,sudo-i ,,su的区别
  20. mysql暴力撞库与弱密码检测

热门文章

  1. Reinforcement Learning
  2. python报错 TypeError: an integer is required
  3. nyoj1237 最大岛屿(河南省第八届acm程序设计大赛)
  4. return 和 exit
  5. Bootstrap系列 -- 37. 基础导航样式
  6. flume与Mosquitto的集成
  7. getdc 与getwindowDc的区别,loadbitmap 与loadimage的区别
  8. python 动态类型_python学习--动态类型
  9. 用python计算准确率_Python中计算模型精度的几种方法,Pytorch,中求,准确率
  10. Dev-C++ 5.11安装教程