首先辨析一下概念:

1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的

2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程

在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:

# 方式一

def vae_loss(x, x_decoded_mean):

xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)

kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)

return xent_loss + kl_loss

vae.compile(optimizer='rmsprop', loss=vae_loss)

或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:

# 方式二

# Custom loss layer

class CustomVariationalLayer(Layer):

def __init__(self, **kwargs):

self.is_placeholder = True

super(CustomVariationalLayer, self).__init__(**kwargs)

def vae_loss(self, x, x_decoded_mean_squash):

x = K.flatten(x)

x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)

xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)

kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)

return K.mean(xent_loss + kl_loss)

def call(self, inputs):

x = inputs[0]

x_decoded_mean_squash = inputs[1]

loss = self.vae_loss(x, x_decoded_mean_squash)

self.add_loss(loss, inputs=inputs)

# We don't use this output.

return x

y = CustomVariationalLayer()([x, x_decoded_mean_squash])

vae = Model(x, y)

vae.compile(optimizer='rmsprop', loss=None)

在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数   点击查看metric的设置

注意事项:

1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar

2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错

有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如

# Class weights:

# To balance the difference in occurences of digit class labels.

# 50% of labels that the discriminator trains on are 'fake'.

# Weight = 1 / frequency

cw1 = {0: 1, 1: 1}

cw2 = {i: self.num_classes / half_batch for i in range(self.num_classes)}

cw2[self.num_classes] = 1 / half_batch

class_weights = [cw1, cw2]  # 使得两种loss能够一样重要

discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)

---------------------

作者:ustc_lijia

来源:CSDN

原文:https://blog.csdn.net/xiaojiajia007/article/details/73274669

版权声明:本文为博主原创文章,转载请附上博文链接!

loss 加权_【转载】keras 自定义 loss损失函数, sample在loss上的加权 和 metric相关推荐

  1. navicat mysql 百度云_转载:用navicat连接百度云服务器上的mysql数据库

    原文连接:http://blog.sciencenet.cn/home.php?mod=space&uid=853805&do=blog&quickforward=1& ...

  2. keras 自定义评估函数和损失函数loss训练模型后加载模型出现ValueError: Unknown metric function:fbeta_score

    keras分类回归的损失函数与评价指标 目标函数 (1)mean_squared_error / mse 均方误差,常用的目标函数,公式为((y_pred-y_true)**2).mean() (2) ...

  3. Keras自定义损失函数出现:ValueError: Unknown loss function: focal_loss

    Keras自定义损失函数出现:ValueError: Unknown loss function: focal_loss 1.软件环境 2.问题描述 3.解决方法 4.结果预览 1.软件环境 Wind ...

  4. keras自定义loss

    loss是model.compile编译时所需的参数之一,可以用损失函数名或者 TensorFlow 符号函数: #损失函数名 model.compile(loss='mean_squared_err ...

  5. 多分类svm的hinge loss公式推导_损失函数—深度学习常见损失函数总结【图像分类|下】...

    点击蓝字关注我们 AI研习图书馆,发现不一样的精彩世界 学习 笔记 常见损失函数总结-图像分类下篇 一.前言 在深度学习中,损失函数扮演着至关重要的角色.通过最小化损失函数,使模型达到收敛状态,减少模 ...

  6. python中dice常见问题_【Pytorch】 Dice系数与Dice Loss损失函数实现

    由于 Dice系数是图像分割中常用的指标,而在Pytoch中没有官方的实现,下面结合网上的教程进行详细实现. 先来看一个我在网上经常看到的一个版本. def diceCoeff(pred, gt, s ...

  7. 从L1 loss到EIoU loss,目标检测边框回归的损失函数一览

    本文转载自知乎,已获作者授权转载. 链接:https://zhuanlan.zhihu.com/p/342991797 目标检测任务的损失函数由Classificition Loss和BBox Reg ...

  8. [转载] python实现语义分割_使用Keras实现深度学习中的一些语义分割模型

    参考链接: Keras中的深度学习-数据预处理 Keras-Sematic-Segmentation 使用Keras实现深度学习中的一些语义分割模型. 配置 tensorflow 1.13.1+ten ...

  9. 在Keras使用center-losss损失函数\Keras自定义损失函数

    目录 1.站在巨人的肩膀上 2.Keras的损失函数 3.在Keras实现center-loss损失函数 3.1.导入库和定义常量 3.2.实现多元分类softmax损失函数 3.3.实现center ...

最新文章

  1. matlab智能算法30个案例分析_赞!继电保护25个事故案例分析总结,值得收藏!...
  2. 腾讯数据中心负责人揭秘:半年时间如何搭好“山洞鹅厂”
  3. java面试题二十七 多线程考题2
  4. c# npoi 公式不计算_建筑行业计算公式大全,钢筋重量计算公式,不收藏吃亏的是你自己...
  5. go设置后端启动_Go语言基础(十四)
  6. C#实现二维码功能,winform 以及 asp.net均可以用
  7. 5G的基站覆盖范围300米,今后边远地区的手机通话怎样保证?
  8. Talking Data副总裁高铎:我们如何赋予大数据生命力
  9. [课堂实践与项目]IOS优先级的计算器
  10. java中AWT如何关闭窗口_java 窗口关闭的六种方法
  11. 手机号段199/198/166,横空出世
  12. Ubuntu 20.04 美化教程
  13. prettier和beautify哪个好用
  14. Spring源码分析总结(二)-Spring AOP 解析aop:aspectj-autoproxy
  15. Spring配置文件中的parent与abstract
  16. 窗口的创建CreateWindow/CreateWindowEx函数使用说明
  17. Eigrp恶意插入路由和致瘫攻击测试(一)
  18. 可编程逻辑控制器(PLC) : 基础、类型和应用
  19. CMPedometer 计步器的使用—— 基于API分析
  20. 爱心宠物诊所系统(实训)

热门文章

  1. 【JAVA 第五章 】课后习题 随机数统计
  2. 【网络编程】中文字符、时间等编码转换
  3. 【Day13】说一下 Vue 组件的通信方式都有哪些?(父子组件,兄弟组件,多级嵌套组件等等)
  4. 【C语言】逗号运算符的使用举例
  5. 数据结构链表代码_代码简介:链表数据结构如何工作
  6. Matlab对图像进行鼠标取点操作及K值聚类分析
  7. 深度学习-超参数和交叉验证
  8. asp.net学习之再论sqlDataSource 1
  9. 计算机项目教学法探讨,【计算机教学论文】项目教学法在计算机教学中的应用(共3594字)...
  10. PyTorch搜索Tensor指定维度的前K大个(K小个)元素--------(torch.topk)命令参数详解及举例