我们来解读一下,中心损失,再来看代码。

链接:https://www.cnblogs.com/carlber/p/10811396.html

我们的重点是分析代码,所以定义部分,大家详情参见上面的博客。

代码:

#coding=gbk
'''
Created on 2020年4月20日@author: DELL
'''
import tensorflow as tf
import numpy as npdata = [[1,1,1,1,1],[1,1,2,1,1],[1,1,3,1,1],[1,1,4,1,1],[2,2,2,1,2],[2,2,2,2,2],[2,2,2,3,2],[3,3,3,3,1],[3,3,3,3,2]]label = [0,0,0,0,1,1,1,2,2]data = np.array(data,dtype = 'float32')
label = np.array(label)data = tf.convert_to_tensor(data)
label = tf.convert_to_tensor(label)def center_loss(features, label, alfa, nrof_classes):"""Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"(http://ydwen.github.io/papers/WenECCV16.pdf)"""nrof_features = features.get_shape()[1]centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32,initializer=tf.constant_initializer(0), trainable=False)#定义一个全零的centers, [nrof_classes, nrof_features]->(类别数,特征维度)#print(sess.run(centers))label = tf.reshape(label, [-1]) #一维向量centers_batch = tf.gather(centers, label) #[batch_size,nrof_features] #按照label将centers归类,形成的新矩阵维度为 [label_size,nrof_features]diff = (1 - alfa) * (centers_batch - features) #乘上我们的因子alfa [label_size,nrof_features]centers = tf.scatter_sub(centers, label, diff) #按照label用centers - diff,产生本次的centerswith tf.control_dependencies([centers]):#注意这个函数的作用,是限制计算顺序的,即先计算centers,在利用计算好的centers去计算centers_batch以求lossloss = tf.reduce_mean(tf.square(features - centers_batch))return loss, centers,features,centers_batch,features - centers_batchloss, cen, fea, cen_bat,a = center_loss(data,label,0.5,3)sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)print(sess.run(cen))
#print(sess.run(loss))
print(sess.run(fea))
#print(sess.run(cen_bat))
print(sess.run(a))
print(sess.run(fea - cen_bat))
print(sess.run(tf.square(fea - cen_bat)))
print(sess.run(loss))'''验证tf.scatter_sub函数
sess = tf.Session()
ref = tf.Variable([1, 2, 3],dtype = tf.int32)
indices = tf.constant([0, 0, 1, 1],dtype = tf.int32)
updates = tf.constant([9, 10, 11, 12],dtype = tf.int32)
sub = tf.scatter_sub(ref, indices, updates)
with tf.Session() as sess:sess.run(tf.global_variables_initializer())print (sess.run(sub))
'''    

结果:

1.centers:
[[2.  2.  5.  2.  2. ][3.  3.  3.  3.  3. ][3.  3.  3.  3.  1.5]]
2.features:
[[1. 1. 1. 1. 1.][1. 1. 2. 1. 1.][1. 1. 3. 1. 1.][1. 1. 4. 1. 1.][2. 2. 2. 1. 2.][2. 2. 2. 2. 2.][2. 2. 2. 3. 2.][3. 3. 3. 3. 1.][3. 3. 3. 3. 2.]]
3.centers_batch
[[2.  2.  5.  2.  2. ][2.  2.  5.  2.  2. ][2.  2.  5.  2.  2. ][2.  2.  5.  2.  2. ][3.  3.  3.  3.  3. ][3.  3.  3.  3.  3. ][3.  3.  3.  3.  3. ][3.  3.  3.  3.  1.5][3.  3.  3.  3.  1.5]]
4.features - centers_batch
[[-1.  -1.  -4.  -1.  -1. ][-1.  -1.  -3.  -1.  -1. ][-1.  -1.  -2.  -1.  -1. ][-1.  -1.  -1.  -1.  -1. ][-1.  -1.  -1.  -2.  -1. ][-1.  -1.  -1.  -1.  -1. ][-1.  -1.  -1.   0.  -1. ][ 0.   0.   0.   0.  -0.5][ 0.   0.   0.   0.   0.5]]
5.loss
1.4111111

主要用到的函数:1.tf.gather(data,labels),将data按labels扩充

2.tf.scatter_sub(data,label,data_1),按label用data - data_

3.with tf.control_dependencies(): ,限制运算顺序

在实验验证时注意的点是:不要多次sess.run()某个张量涉及到带有依赖关系的张量,比如这里的loss,计算loss时 会 主动更新一次值,导致运算结果出错。原理我还没搞清,日后补上

facenet 中心损失函数(center loss)详解(代码分析)含tf.gather() 和 tf.scatter_sub()函数相关推荐

  1. 【人脸识别】Center Loss详解

    论文题目:<A Discriminative Feature Learning Approach for Deep Face Recognition> 论文地址:http://ydwen. ...

  2. xvid 详解 代码分析 编译等

    1.   Xvid参数详解 众所周知,Mencoder以其极高的压缩速率和不错的画质赢得了很多朋友的认同! 原来用Mencoder压缩Xvid的AVI都是使用Xvid编码器的默认设置,现在我来给大家冲 ...

  3. Python字符串对齐方法(ljust()、rjust()和center())详解

    Python字符串对齐方法(ljust().rjust()和center())详解 Python str 提供了 3 种可用来进行文本对齐的方法,分别是 ljust().rjust() 和 cente ...

  4. x264 代码重点详解 详细分析

    eg mplayer x264 代码重点详解 详细分析 分类: ffmpeg 2012-02-06 09:19 4229人阅读 评论(1) 收藏 举报 h.264codecflv优化initializ ...

  5. Tensorflow 2.x(keras)源码详解之第十二章:keras中的损失函数之BinaryCrossentropy详解

      大家好,我是爱编程的喵喵.双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中.从事机器学习以及相关的前后端开发工作.曾在阿里云.科大讯飞.CCF等比赛获得多次Top名次.现 ...

  6. 五分钟搞懂后缀数组!后缀数组解析以及应用(附详解代码)

    为什么学后缀数组 后缀数组是一个比较强大的处理字符串的算法,是有关字符串的基础算法,所以必须掌握. 学会后缀自动机(SAM)就不用学后缀数组(SA)了?不,虽然SAM看起来更为强大和全面,但是有些SA ...

  7. 『ML笔记』HOG特征提取原理详解+代码

    HOG特征提取原理详解+代码! 文章目录 一. HOG特征介绍 二. HOG算法具体流程+代码 2.1. 图像灰度化和gamma矫正 2.2. 计算图像像素梯度图 2.3. 在8×8的网格中计算梯度直 ...

  8. Python|SQL详解之DDL|DML|DQL|DCL|索引|视图、函数和过程|JSON类型|窗口函数|接入MySQL|清屏|正则表达式|executemany|语言基础50课:学习(14)

    文章目录 系列目录 原项目地址 第41课:SQL详解之DDL 建库建表 删除表和修改表 第42课:SQL详解之DML insert操作 delete 操作 update 操作 完整的数据 第43课:S ...

  9. c语言 字符串 strncpy,详解c语言中的 strcpy和strncpy字符串函数使用

    详解c语言中的 strcpy和strncpy字符串函数使用 strcpy 和strcnpy函数--字符串复制函数. 1.strcpy函数 函数原型:char *strcpy(char *dst,cha ...

  10. 安卓通知栏管理详解及分析 NotificationListenerService

    NotificationListenerService 安卓通知栏管理详解及分析 一. 方法概述 在api 18前可以通过辅助功能'AccessibilityEvent.TYPE_NOTIFICATI ...

最新文章

  1. 力扣:15三数之和(python)
  2. IOS学习笔记之二十二(文件io)
  3. jtoken判断是否包含键_Redis 数据库、键过期的实现
  4. python怎么背景实现循环_在Python的一段程序中如何使用多次事件循环详解
  5. java this()函数_Java经典面试题之(如何正确的使用this?)
  6. DOS批处理中对含有特殊字符的文件名的处理方法
  7. StyleAI:色调、感情色彩量化、色彩交流API-PCCS颜色体系
  8. win10玩我的世界java_我的世界win10java下载
  9. 别把职场当官斗,聪明人都在自我成长
  10. Win10 下搭建PHP开发环境(自定义方式)
  11. 用AR.js做图片追踪的webAR Demo
  12. 十个经典的Android开源项目
  13. sendgrid html text,包括里面sendgrid鄂麦邮件的内容我的HTML代码
  14. Python遇到过得text和text()
  15. ubuntu 18.04网络图标消失不见解决方法
  16. Lesson 20 Pioneer pilots 内容鉴赏
  17. Linux网卡固件,CentOS下X710网卡升级驱动和固件脚本 | 聂扬帆博客
  18. ORACLE表格操作图文教学二(分组去重、计数、加减、多表)
  19. jQuery 如何得到 scrollHeight 的值
  20. ubuntu安装mysql-python报错

热门文章

  1. C++ sizeof总结
  2. 负数的十进制转二进制
  3. Windows中的权限设置、文件压缩、文件加密、磁盘配额和卷影副本
  4. 计算机考试考前准备,考前必看如何正确准备计算机等级考试 -电脑资料
  5. [python爬虫] Selenium高级篇之窗口移动、弹出对话框自登录
  6. iOS之深入解析CocoaPods的GitLab CI与组件自动化构建与发布
  7. 13.1.2 WEB应用程序
  8. Exp4 恶意代码分析 20164309
  9. Keil 5安装激活教程
  10. 【STM32】STM32F4系统架构