度量学习(Metric learning)—— 基于分类损失函数(softmax、交叉熵、cosface、arcface)
class Linear(nn.Module):def __init__(self):super(Linear, self).__init__()self.weight = nn.Parameter(torch.Tensor(2, 10)) # (input,output)nn.init.xavier_uniform_(self.weight)def forward(self, x, label):out = x.mm(self.weight) # 分类器的全连接层loss = F.cross_entropy(out, label) # 合并了softmax 和 交叉熵return out, loss
class ArcMarginProduct(nn.Module):r"""Implement of large margin arc distance: :Args:in_features: size of each input sampleout_features: size of each output samples: norm of input featurem: margincos(theta + m)"""def __init__(self, in_features, out_features, s=30.0, m=0.50, easy_margin=False):super(ArcMarginProduct, self).__init__()self.in_features = in_features #输入特征维度,一般是512self.out_features = out_features #输出维度,是类别数目self.s = s #re-scaleself.m = m #角度惩罚项self.weight = Parameter(torch.FloatTensor(out_features, in_features)) #权重矩阵nn.init.xavier_uniform_(self.weight) #权重矩阵初始化self.easy_margin = easy_marginself.cos_m = math.cos(m)self.sin_m = math.sin(m)self.th = math.cos(math.pi - m)self.mm = math.sin(math.pi - m) * mdef forward(self, input, label):# --------------------------- cos(theta) & phi(theta) ---------------------------# 对应伪代码中的1、2、3行:输入x标准化、输入W标准化和它们之间进行FC层得到cos(theta)cosine = F.linear(F.normalize(input), F.normalize(self.weight))# 计算sin(theta)sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))# 对应伪代码中的5、6行:计算cos(theta+m) = cos(theta)cos(m) - sin(theta)sin(m)phi = cosine * self.cos_m - sine * self.sin_mif self.easy_margin:phi = torch.where(cosine > 0, phi, cosine)else:# 当cos(theta)>cos(pi-m)时,phi=cos(theta)-sin(pi-m)*mphi = torch.where(cosine > self.th, phi, cosine - self.mm)# --------------------------- convert label to one-hot ---------------------------# 对应伪代码中的7行:对label形式进行转换,假设batch为2、有3类的话,即将label从[1,2]转换成[[0,1,0],[0,0,1]]one_hot = torch.zeros(cosine.size(), device='cuda')one_hot.scatter_(1, label.view(-1, 1).long(), 1)# 对应伪代码中的8行:计算公式(6)# -------------torch.where(out_i = {x_i if condition_i else y_i) -------------output = (one_hot * phi) + ((1.0 - one_hot) * cosine) # you can use torch.where if your torch.__version__ is 0.4# 对应伪代码中的9行,进行re-scaleoutput *= self.sreturn output
其实通过归一化+乘固定尺度因子,一方面可以使用余弦距离,一方面先增加难度(归一化),再缓和难度(乘尺度因子),其实特征已经比较好,如上图,实际性能也确实是加margin有提升,但是并没有很多。下面来说明原因。
为什么有提升。我们仍然以四分类,已经归一化且乘尺度因子的情况讨论(其余情况原理类似):输出{x1, x2, x3, x4} 等价于{s * cosθ, x2, x3, x4} 。原始在输出x = {5, 1, 1, 1}时就接近收敛,训练停止,此时改用large margin softmax,第一列的 cosθ 强制变成cos(mθ) (即SphereFace)、 cosθ-m(即cosface)或 cos(θ+m)(即arcface),会使输出减小,其他列保持不变,此时输出可能变成了x = {4, 1, 1, 1},网络又可以继续训练了,也就是增加训练难度,使训练得到的特征映射更好。不同loss的曲线对比,下图来自ArcFace,所有loss都是单调递减的。对比Softmax的 cosθ 曲线,乘性margin的SphereFace对应cos(mθ) 曲线下降最多,训练难度剧增,退火技术也难以收敛,反观加性margin的CosineFace和ArcFace下降较少,训练难度稍微增加,所以更容易收敛。
度量学习(Metric learning)—— 基于分类损失函数(softmax、交叉熵、cosface、arcface)相关推荐
- 度量学习————Metric Learning
度量学习的概念 度量学习 (Metric Learning) == 距离度量学习 (Distance Metric Learning,DML) == 相似度学习 度量学习 是指 距离度量学习,Dist ...
- 度量学习 (Metric Learning) 解读
本文转载于以下博客地址:https://blog.csdn.net/jningwei/article/details/80641184 如有冒犯,还望谅解! Introduction 度量学习 (Me ...
- 深度度量学习 (metric learning deep metric learning )度量函数总结
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/qq_16234613/article/ ...
- 度量学习 (Metric Learning)(一)
度量学习(Metric Learning) 度量(Metric)的定义 在数学中,一个度量(或距离函数)是一个定义集合中元素之间距离的函数.一个具有度量的集合被称为度量空间. 1 为什么要用度量学习 ...
- Pytorch深度学习笔记(02)--损失函数、交叉熵、过拟合与欠拟合
目录 一.损失函数 二.交叉熵损失函数详解 1.交叉熵 2.信息量 3.信息熵 4.相对熵(KL散度) 5.交叉熵 6.交叉熵在单分类问题中的应用 7.总结: 三.过拟合和欠拟合通俗解释 1.过拟合 ...
- softmax交叉熵损失函数深入理解(二)
0.前言 前期博文提到经过两步smooth化之后,我们将一个难以收敛的函数逐步改造成了softmax交叉熵损失函数,解决了原始的目标函数难以优化的问题.Softmax 交叉熵损失函数是目前最常用的分类 ...
- 深度学习自学(二十):SmoothL1 和 Softmax交叉熵
整理的人脸系列学习经验:包括人脸检测.人脸关键点检测.人脸优选.人脸对齐.人脸特征提取等过程总结,有需要的可以参考,仅供学习,请勿盗用.https://blog.csdn.net/TheDayIn_C ...
- 【深度学习】——分类损失函数、回归损失函数、交叉熵损失函数、均方差损失函数、损失函数曲线、
目录 代码 回归问题的损失函数 分类问题的损失函数 1. 0-1损失 (zero-one loss) 2.Logistic loss 3.Hinge loss 4.指数损失(Exponential l ...
- 将“softmax+交叉熵”推广到多标签分类问题
©PaperWeekly 原创 · 作者|苏剑林 单位|追一科技 研究方向|NLP.神经网络 一般来说,在处理常规的多分类问题时,我们会在模型的最后用一个全连接层输出每个类的分数,然后用 softma ...
最新文章
- 渣渣菜鸡的 ElasticSearch 源码解析 —— 启动流程(上)
- 五大因素推动中国AI崛起,生态报告概览中国AI产业 By 机器之心2017年7月17日 12:51 中国的人工智能将会在全世界扮演什么样的角色?最近,风险投资机构Vertex发表了一份生态研究报告
- jquery-基础事件[下]
- It's not a Bug, it's a Feature! UVA - 658 (最短路)
- 极简桌面 android 2.3,极简桌面(手机桌面)V3.1 for android 免费版
- 大数据分析中的四大数据类型
- c盘java文件误删_C盘的文件被误删如何恢复
- IE(11)浏览器清理缓存方法
- Android JNI开发笔记二:动态库和静态库
- init: wait for '/dev/block/bootdevice/by-name/cache' timed out and took 5007ms【学习笔记】
- 修改Win10 C盘用户文件夹名称
- 7z解压crc错误_rar文件解压缩失败解压末端出现错误的解决方法
- Unhandled kernel unaligned access问题记录
- 爱学术,让论文写作不再难!
- matlab中circle函数_MATLAB如何用自带函数画圆
- 凯恩帝1000对刀图解_凯恩帝数控机床对刀方法
- HTML粒子旋涡特效代码
- vue-cli+mock.js+axios模拟前后台数据交互
- css compressor java_javascript/css压缩工具---yuicompressor使用方法
- 6-RabbitMQ实战
热门文章
- zip解压多个分卷.z0...文件
- 计算机小学有趣课程,小学计算机教案课程全-20210723192716.docx-原创力文档
- 【网页设计】期末大作业html+css (个人生活记录介绍网站)
- Python编程——shelve模块的使用详解(附实例)
- 全局直方图均衡处理和局部直方图均衡化处理的比较
- angular n'g-zorro走马灯划过时如何停止切换
- Unity3D-在Android平台快速验证功能的更新
- ECS_搭建个人Leanote云笔记本
- linux - resize2fs:新大小太大,无法用32位表示
- 4、MyBatis 框架适用场合: