用keras搭好模型架构之后的下一步,就是执行编译操作。在编译时,经常需要指定三个参数

  • loss
  • optimizer
  • metrics

这三个参数有两类选择:

  • 使用字符串
  • 使用标识符,如keras.losses,keras.optimizers,metrics包下面的函数

例如:

sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(loss='categorical_crossentropy',optimizer=sgd,metrics=['accuracy'])

因为有时可以使用字符串,有时可以使用标识符,令人很想知道背后是如何操作的。下面分别针对optimizer,loss,metrics三种对象的获取进行研究。

optimizer

一个模型只能有一个optimizer,在执行编译的时候只能指定一个optimizer。
在keras.optimizers.py中,有一个get函数,用于根据用户传进来的optimizer参数获取优化器的实例:

def get(identifier):# 如果后端是tensorflow并且使用的是tensorflow自带的优化器实例,可以直接使用tensorflow原生的优化器 if K.backend() == 'tensorflow':# Wrap TF optimizer instancesif isinstance(identifier, tf.train.Optimizer):return TFOptimizer(identifier)# 如果以json串的形式定义optimizer并进行参数配置if isinstance(identifier, dict):return deserialize(identifier)elif isinstance(identifier, six.string_types):# 如果以字符串形式指定optimizer,那么使用优化器的默认配置参数config = {'class_name': str(identifier), 'config': {}}return deserialize(config)if isinstance(identifier, Optimizer):# 如果使用keras封装的Optimizer的实例return identifierelse:raise ValueError('Could not interpret optimizer identifier: ' +str(identifier))

其中,deserilize(config)函数的作用就是把optimizer反序列化制造一个实例。

loss

keras.losses函数也有一个get(identifier)方法。其中需要注意以下一点:

如果identifier是可调用的一个函数名,也就是一个自定义的损失函数,这个损失函数返回值是一个张量。这样就轻而易举的实现了自定义损失函数。除了使用str和dict类型的identifier,我们也可以直接使用keras.losses包下面的损失函数。

def get(identifier):if identifier is None:return Noneif isinstance(identifier, six.string_types):identifier = str(identifier)return deserialize(identifier)if isinstance(identifier, dict):return deserialize(identifier)elif callable(identifier):return identifierelse:raise ValueError('Could not interpret ''loss function identifier:', identifier)

metrics

在model.compile()函数中,optimizer和loss都是单数形式,只有metrics是复数形式。因为一个模型只能指明一个optimizer和loss,却可以指明多个metrics。metrics也是三者中处理逻辑最为复杂的一个。

在keras最核心的地方keras.engine.train.py中有如下处理metrics的函数。这个函数其实就做了两件事:

  • 根据输入的metric找到具体的metric对应的函数
  • 计算metric张量

在寻找metric对应函数时,有两种步骤:

  • 使用字符串形式指明准确率和交叉熵
  • 使用keras.metrics.py中的函数
def handle_metrics(metrics, weights=None):metric_name_prefix = 'weighted_' if weights is not None else ''for metric in metrics:# 如果metrics是最常见的那种:accuracy,交叉熵if metric in ('accuracy', 'acc', 'crossentropy', 'ce'):# custom handling of accuracy/crossentropy# (because of class mode duality)output_shape = K.int_shape(self.outputs[i])# 如果输出维度是1或者损失函数是二分类损失函数,那么说明是个二分类问题,应该使用二分类的accuracy和二分类的的交叉熵if (output_shape[-1] == 1 orself.loss_functions[i] == losses.binary_crossentropy):# case: binary accuracy/crossentropyif metric in ('accuracy', 'acc'):metric_fn = metrics_module.binary_accuracyelif metric in ('crossentropy', 'ce'):metric_fn = metrics_module.binary_crossentropy# 如果损失函数是sparse_categorical_crossentropy,那么目标y_input就不是one-hot的,所以就需要使用sparse的多类准去率和sparse的多类交叉熵elif self.loss_functions[i] == losses.sparse_categorical_crossentropy:# case: categorical accuracy/crossentropy# with sparse targetsif metric in ('accuracy', 'acc'):metric_fn = metrics_module.sparse_categorical_accuracyelif metric in ('crossentropy', 'ce'):metric_fn = metrics_module.sparse_categorical_crossentropyelse:# case: categorical accuracy/crossentropyif metric in ('accuracy', 'acc'):metric_fn = metrics_module.categorical_accuracyelif metric in ('crossentropy', 'ce'):metric_fn = metrics_module.categorical_crossentropyif metric in ('accuracy', 'acc'):suffix = 'acc'elif metric in ('crossentropy', 'ce'):suffix = 'ce'weighted_metric_fn = weighted_masked_objective(metric_fn)metric_name = metric_name_prefix + suffixelse:# 如果输入的metric不是字符串,那么就调用metrics模块获取metric_fn = metrics_module.get(metric)weighted_metric_fn = weighted_masked_objective(metric_fn)# Get metric name as stringif hasattr(metric_fn, 'name'):metric_name = metric_fn.nameelse:metric_name = metric_fn.__name__metric_name = metric_name_prefix + metric_namewith K.name_scope(metric_name):metric_result = weighted_metric_fn(y_true, y_pred,weights=weights,mask=masks[i])# Append to self.metrics_names, self.metric_tensors,# self.stateful_metric_namesif len(self.output_names) > 1:metric_name = self.output_names[i] + '_' + metric_name# Dedupe namej = 1base_metric_name = metric_namewhile metric_name in self.metrics_names:metric_name = base_metric_name + '_' + str(j)j += 1self.metrics_names.append(metric_name)self.metrics_tensors.append(metric_result)# Keep track of state updates created by# stateful metrics (i.e. metrics layers).if isinstance(metric_fn, Layer) and metric_fn.stateful:self.stateful_metric_names.append(metric_name)self.stateful_metric_functions.append(metric_fn)self.metrics_updates += metric_fn.updates

无论怎么使用metric,最终都会变成metrics包下面的函数。当使用字符串形式指明accuracy和crossentropy时,keras会非常智能地确定应该使用metrics包下面的哪个函数。因为metrics包下的那些metric函数有不同的使用场景,例如:

  • 有的处理的是one-hot形式的y_input(数据的类别),有的处理的是非one-hot形式的y_input
  • 有的处理的是二分类问题的metric,有的处理的是多分类问题的metric

当使用字符串“accuracy”和“crossentropy”指明metric时,keras会根据损失函数、输出层的shape来确定具体应该使用哪个metric函数。在任何情况下,直接使用metrics下面的函数名是总不会出错的。

keras.metrics.py文件中也有一个get(identifier)函数用于获取metric函数。

def get(identifier):if isinstance(identifier, dict):config = {'class_name': str(identifier), 'config': {}}return deserialize(config)elif isinstance(identifier, six.string_types):return deserialize(str(identifier))elif callable(identifier):return identifierelse:raise ValueError('Could not interpret ''metric function identifier:', identifier)

如果identifier是字符串或者字典,那么会根据identifier反序列化出一个metric函数。
如果identifier本身就是一个函数名,那么就直接返回这个函数名。这种方式就为自定义metric提供了巨大便利。

keras中的设计哲学堪称完美。

转载于:https://www.cnblogs.com/weiyinfu/p/9783776.html

keras中的loss、optimizer、metrics相关推荐

  1. keras中构造loss曲线图像

  2. 【tf.keras】tf.keras使用tensorflow中定义的optimizer

    我的 tensorflow+keras 版本: print(tf.VERSION) # '1.10.0' print(tf.keras.__version__) # '2.1.6-tf' tf.ker ...

  3. CNN在Keras中的实践|机器学习你会遇到的“坑”

    2018-12-16 23:43:37 本文作为上一节<卷积之上的新操作>的补充篇,将会关注一些读者关心的问题,和一些已经提到但并未解决的问题: 到底该如何理解padding中的valid ...

  4. 如何在Keras中训练大型数据集

    https://www.toutiao.com/a6670173759829180936/ 在本文中,我们将讨论如何使用Keras在不适合内存的大数据集上训练我们的深度学习网络. 介绍 深度学习算法优 ...

  5. Python机器学习笔记:深入理解Keras中序贯模型和函数模型

     先从sklearn说起吧,如果学习了sklearn的话,那么学习Keras相对来说比较容易.为什么这样说呢? 我们首先比较一下sklearn的机器学习大致使用流程和Keras的大致使用流程: skl ...

  6. Keras中Callback函数的使用

    回调函数是一组在训练的特定阶段被调用的函数集,你可以使用回调函数来观察训练过程中网络内部的状态和统计信息.通过传递回调函数列表到模型的.fit()中,即可在给定的训练阶段调用该函数集中的函数. [Ti ...

  7. 【小白学习keras教程】十一、Keras中文本处理Text preprocessing

    @Author:Runsen 文章目录 Text preprocessing Tokenization of a sentence One-hot encoding Padding sequences ...

  8. keras中的回调函数

    keras训练 fit(self, x, y, batch_size=32, nb_epoch=10, verbose=1, callbacks=[], validation_split=0.0, v ...

  9. 使用keras进行深度学习_如何在Keras中通过深度学习对蝴蝶进行分类

    使用keras进行深度学习 A while ago I read an interesting blog post on the website of the Dutch organization V ...

最新文章

  1. 车载微信要来了?马化腾:正研发纯语音交互接口
  2. 如何破解安卓手机上的图形锁(九宫格锁)
  3. wgan 不理解 损失函数_AI初识:深度学习中常用的损失函数有哪些?
  4. InstallShield 2012 Spring新功能试用(17): Suite/Advanced UI 和 Advanced UI安装程序支持命令行Log参数...
  5. SAP Spartacus storefront 模块的实现位置
  6. jdk12源代码文件_在JDK 11中启动单文件源代码程序
  7. 首次使用mysql_mysql-8.0.20-winx64_初次使用过程(Win7x64)
  8. 派生类对基类成员的访问控制之公有继承
  9. Android音量设置流程干货版
  10. 选择排序(C++/Java实现)
  11. easyexcell导出专题
  12. vue json对象转数组_Vue优秀表单组件,用Vue构建表单的最简单方法——Vue Formulate
  13. java 方法重载 应用举例,Java中的方法重载应用
  14. oa项目经验描述_项目执行简历中的项目经验怎么写
  15. 车辆颜色识别opencv
  16. cURL – POST请求示例
  17. sql server触发器写法
  18. 考PMP试题的经验和对策
  19. [转载]Eclipse开发工具简介
  20. (转)当AI变成宣传武器:继续深扒大数据公司Cambrige Analytica

热门文章

  1. ThinkPHP3.2.3 语言包切换中英文切换
  2. C#开发MySQL数据库程序时需要注意的几点
  3. highcharts 折线图 和柱状图读取 json值
  4. 【C++】智能指针(auto_ptr,shared_ptr,unique_ptr)及 shared_ptr 强引用原理
  5. Python中threading的join和setDaemon的区别及用法[例子]
  6. [转]2020年2月份Github上最热门的开源项目,速来围观
  7. 从中台、数仓与元数据不为人知的3个角度,看数据管理的生与死
  8. oracle+执行变量语句,ORACLE sql 语句的执行过程(SQL性能调整)
  9. mysql a锁_MYSQL中的锁
  10. python基础语言测试题(10分钟内背熟)