Center Loss的Pytorch实现

  • 开始
  • 结果
  • 在自己的项目中使用中心损失函数

Center Loss的Pytorch实现: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.
这个损失函数也被使用在: deep-person-reid
github项目: https://github.com/KaiyangZhou/pytorch-center-loss

开始

Clone this repo and run the code.

$ git clone https://github.com/KaiyangZhou/pytorch-center-loss
$ cd pytorch-center-loss
$ python main.py --eval-freq 1 --gpu 0 --save-dir log/ --plot

You will see the following info in your terminal.

Currently using GPU: 0
Creating dataset: mnist
Creating model: cnn
==> Epoch 1/100
Batch 50/469     Loss 2.332793 (2.557837) XentLoss 2.332744 (2.388296) CenterLoss 0.000048 (0.169540)
Batch 100/469    Loss 2.354638 (2.463851) XentLoss 2.354637 (2.379078) CenterLoss 0.000001 (0.084773)
Batch 150/469    Loss 2.361732 (2.434477) XentLoss 2.361732 (2.377962) CenterLoss 0.000000 (0.056515)
Batch 200/469    Loss 2.336701 (2.417842) XentLoss 2.336700 (2.375455) CenterLoss 0.000001 (0.042386)
Batch 250/469    Loss 2.404814 (2.407015) XentLoss 2.404813 (2.373106) CenterLoss 0.000001 (0.033909)
Batch 300/469    Loss 2.338753 (2.398546) XentLoss 2.338752 (2.370288) CenterLoss 0.000001 (0.028258)
Batch 350/469    Loss 2.367068 (2.390672) XentLoss 2.367059 (2.366450) CenterLoss 0.000009 (0.024221)
Batch 400/469    Loss 2.344178 (2.384820) XentLoss 2.344142 (2.363620) CenterLoss 0.000036 (0.021199)
Batch 450/469    Loss 2.329708 (2.379460) XentLoss 2.329661 (2.360611) CenterLoss 0.000047 (0.018848)
==> Test
Accuracy (%): 10.32  Error rate (%): 89.68
... ...
==> Epoch 30/100
Batch 50/469     Loss 0.141117 (0.155986) XentLoss 0.084169 (0.091617) CenterLoss 0.056949 (0.064369)
Batch 100/469    Loss 0.138201 (0.151291) XentLoss 0.089146 (0.092839) CenterLoss 0.049055 (0.058452)
Batch 150/469    Loss 0.151055 (0.151985) XentLoss 0.090816 (0.092405) CenterLoss 0.060239 (0.059580)
Batch 200/469    Loss 0.150803 (0.153333) XentLoss 0.092857 (0.092156) CenterLoss 0.057946 (0.061176)
Batch 250/469    Loss 0.162954 (0.154971) XentLoss 0.094889 (0.092099) CenterLoss 0.068065 (0.062872)
Batch 300/469    Loss 0.162895 (0.156038) XentLoss 0.093100 (0.092034) CenterLoss 0.069795 (0.064004)
Batch 350/469    Loss 0.146187 (0.156491) XentLoss 0.082508 (0.091787) CenterLoss 0.063679 (0.064704)
Batch 400/469    Loss 0.171533 (0.157390) XentLoss 0.092526 (0.091674) CenterLoss 0.079007 (0.065716)
Batch 450/469    Loss 0.209196 (0.158371) XentLoss 0.098388 (0.091560) CenterLoss 0.110808 (0.066811)
==> Test
Accuracy (%): 98.51  Error rate (%): 1.49
... ...

Please run python main.py -h for more details regarding input arguments.

结果

We visualize the feature learning process below.
Softmax only. Left: training set. Right: test set.

Softmax + center loss. Left: training set. Right: test set.

在自己的项目中使用中心损失函数

  1. All you need is the center_loss.py file
from center_loss import CenterLoss
  1. Initialize center loss in the main function
center_loss = CenterLoss(num_classes=10, feat_dim=2, use_gpu=True)
  1. Construct an optimizer for center loss
optimizer_centloss = torch.optim.SGD(center_loss.parameters(), lr=0.5)

Alternatively, you can merge optimizers of model and center loss, like

params = list(model.parameters()) + list(center_loss.parameters())
optimizer = torch.optim.SGD(params, lr=0.1) # here lr is the overall learning rate
  1. Update class centers just like how you update a pytorch model
# features (torch tensor): a 2D torch float tensor with shape (batch_size, feat_dim)
# labels (torch long tensor): 1D torch long tensor with shape (batch_size)
# alpha (float): weight for center loss
loss = center_loss(features, labels) * alpha + other_loss
optimizer_centloss.zero_grad()
loss.backward()
# multiple (1./alpha) in order to remove the effect of alpha on updating centers
for param in center_loss.parameters():param.grad.data *= (1./alpha)
optimizer_centloss.step()

If you adopt the second way (i.e. use one optimizer for both model and center loss), the update code would look like

loss = center_loss(features, labels) * alpha + other_loss
optimizer.zero_grad()
loss.backward()
for param in center_loss.parameters():# lr_cent is learning rate for center loss, e.g. lr_cent = 0.5param.grad.data *= (lr_cent / (alpha * lr))
optimizer.step()

Center Loss的Pytorch实现相关推荐

  1. Domain Adaptation for Object Detection using SE Adaptors and Center Loss 论文翻译

    - 摘要: 尽管人们对目标检测的兴趣日益浓厚,但很少有工作能够解决跨域健壮性这一极其实际的问题,特别是对于自动化应用而言.为了防止域移位导致的性能下降,我们在faster-RCNN的基础上引入了一种无 ...

  2. CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss、Center Loss)简介、使用方法之详细攻略

    CV之FRec之ME/LF:人脸识别中常用的模型评估指标/损失函数(Triplet Loss.Center Loss)简介.使用方法之详细攻略 目录 T1.Triplet Loss 1.英文原文解释 ...

  3. facenet 中心损失函数(center loss)详解(代码分析)含tf.gather() 和 tf.scatter_sub()函数

    我们来解读一下,中心损失,再来看代码. 链接:https://www.cnblogs.com/carlber/p/10811396.html 我们的重点是分析代码,所以定义部分,大家详情参见上面的博客 ...

  4. 深度学习中的损失函数总结以及Center Loss函数笔记

    北京 上海巡回站 | NVIDIA DLI深度学习培训 2018年1月26/1月12日 NVIDIA 深度学习学院 带你快速进入火热的DL领域 阅读全文                        ...

  5. Focal Loss 的Pytorch 实现以及实验

    Focal Loss 的Pytorch 实现以及实验 Focal loss 是 文章 Focal Loss for Dense Object Detection 中提出对简单样本的进行decay的一种 ...

  6. 【人脸识别】Center Loss详解

    论文题目:<A Discriminative Feature Learning Approach for Deep Face Recognition> 论文地址:http://ydwen. ...

  7. Center Loss

    <A Discriminative Feature Learning Approach for Deep Face Recognition>   可鉴别性的特征学习用于人脸识别,对每个类别 ...

  8. Center Loss层

    雷锋网按:本文作者祝浩(皮搋子狐狸),3M Cogent Beijing R&D 高级算法工程师.本硕分别毕业于哈尔滨工业大学机械专业和北京师范大学计算机专业,并于2012年加入3M.14年拿 ...

  9. 目标检测之Loss:Center Loss梯度更新

    转载:https://blog.csdn.net/u014380165/article/details/76946339 最近几年网络效果的提升除了改变网络结构外,还有一群人在研究损失层的改进,这篇博 ...

最新文章

  1. 阿里重金投数梦工场 布局PaaS动了谁的奶酪
  2. AndroidStudio中安装GsonFormat插件并根据json文件生成JavaBean
  3. 为什么阿里巴巴Java开发手册中不允许用Executors去创建线程池?
  4. Python技术分享:如何同时检测多个人脸?
  5. MongoDB基本应用操作整理
  6. 手机腾讯视频软件如何开启护眼功能
  7. 互联网组织的未来:剖析GitHub员工的任性之源
  8. [.net 面向对象程序设计进阶] (2) 正则表达式 (一) 快速入门
  9. Bitmap简单使用及简单解析
  10. S3C2440 进行微秒级、毫秒级延时函数
  11. WEBQQ登陆综合帖
  12. 读研究生的目的之我见
  13. N多计算机精品免费视频下载,不要别后悔啊
  14. 华为鸿蒙支持APP,华为 WATCH 3 已到线下店:预装鸿蒙 HarmonyOS 2,支持安装 App
  15. python分析股票主力_筹码分布及计算原理
  16. 【Ubuntu】Ubuntu如何实现中文输入?
  17. LeetCode844-比较含退格的字符串
  18. Unity 和 Android Studio的SDK接入(新手心得)
  19. 振南的znFAT FAT32文件系统
  20. android更换导航背景,修改TabHost导航高度和背景颜色,tabhost背景颜色,主要是android使用...

热门文章

  1. SAP批次管理先进先出基本后台逻辑
  2. html文档元素大小相关的单位,网页字体单位px、em、%、rem、pt、vm、vh介绍
  3. 近距离无线通信(NFC)技术介绍
  4. 2014c语言二级考试题,2014年3月计算机二级C语言真题及答案
  5. JAVA潜心修炼五天——第4天
  6. 以前常用的攻击软件源代码
  7. 2021年电工杯B体详细思路分析
  8. 百慕大神秘三角神秘事件视频
  9. 在HTML中制作贪吃蛇游戏
  10. 知名插画师走尺,带你走进“薪”世界