文章目录

  • tf.keras.Model类
    • 1. 创建一个tf.keras.Model类实例的方法
      • 1.1 通过指定输入输出进行实例化
      • 1.2 通过继承Model类进行实例化
    • 2. tf.Keras.Model类属性
    • 3. compile方法
    • 4. evaluate方法
    • 5. evaluate_generator方法
    • 6. fit方法
    • 7. fit_generator方法
    • 8. predict方法
    • 9. predict_generator方法
    • 10. train_on_batch方法
    • 11. predict_on_batch方法
    • 12. test_on_batch方法
    • 13. reset_metrics方法
    • 14. reset_states方法
    • 15. save方法
    • 16. save_weights方法
    • 17. load_weights方法
    • 18. summary方法
    • 19. to_json
    • 20. to_yaml方法
    • 21. get_layer方法
    • 20. to_yaml方法
    • 21. get_layer方法

tf.keras.Model类

tf.keras.Model类将定义好的网络结构封装入一个对象,用于训练、测试和预测。在这一块中,有两部分内容目前我还有疑惑,一个是xxx_on_batch三个方法,为什么要单独定义这个方法,而且train_on_batch方法为什么要强调是在单个batch上做梯度更新?第二个疑问是reset_metricsreset_states函数有什么用,重置不重置对训练过程有什么影响?这两个问题在后续的学习中还需要重点理解。

1. 创建一个tf.keras.Model类实例的方法

1.1 通过指定输入输出进行实例化

​ 在这种方法中,首先定义好网络,再将网络的输入和输出部分作为参数定义一个Model类对象。

import tensorflow as tfinputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

1.2 通过继承Model类进行实例化

​ 在这种方法中,我们需要定义自己的__init__并且在call方法中实现网络的前向传播结构(即在这个方法中定义网络结构)。

import tensorflow as tfclass MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)def call(self, inputs):x = self.dense1(inputs)return self.dense2(x)model = MyModel()

如果我们需要在训练和预测时实现不同的行为,可以在call方法中加入一个布尔类型的值training,当训练时这个参数值为True,进行评估或者预测时(非训练行为),这个参数值为False

import tensorflow as tfclass MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)self.dropout = tf.keras.layers.Dropout(0.5)def call(self, inputs, training=False):x = self.dense1(inputs)if training:x = self.dropout(x, training=training)return self.dense2(x)model = MyModel()

2. tf.Keras.Model类属性

​ 这个类中包含的属性包括以下几个:

  • input_spec

  • layers:这个参数为一个列表,每个元素为网络中的一层的定义

  • metrics_names

  • run_eagerly:这里指定模型是否是动态运行的。Tensorflow一开始推出时采用的是静态机制,即首先定义模型结构对应的运算图再执行。新版本的Tensorflow引入了动态图机制(Eager execution),这种机制中Tensorflow不用在建立了完整的运算图之后才进行运算,而是可以在每一步运算定义后立即求出具体的数值,这使得Tensorflow的编码和调试过程变得简单,减少了冗余操作(模板化、公式化操作),但这种动态特性在一些情况下无可避免的会增加程序的运行开销(提前定义好的静态图的执行过程可以被更好地优化)。Tensorflow中默认构建的为静态图,要想用动态特性需要调用方法enable_eager_execution(),并且注意这个方法一定要在程序开头调用,否则会报错。只有调用了这个方法,run_eagerly参数才能被设置为True

    '''这里要特别注意,tf.contrib将在 tensorflow 2.0被弃用,所以这种使用动态图特性的方法在tensorflow后面的版本中并不适用,需要大家再去翻翻官方文档或者自行百度找了
    '''
    import tensorflow as tf
    import tensorflow.contrib.eager as tfe# 调用这个方法,使用tensorflow的动态图特性
    tfe.enable_eager_execution()inputs = tf.keras.Input(shape=(3,))
    x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
    outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs, run_eagerly=True)
  • sample_weights

  • state_updates

  • stateful

3. compile方法

​ 指定模型训练时的参数,函数定义如下:

compile(optimizer,loss=None,metrics=None,loss_weights=None,sample_weight_mode=None,weighted_metrics=None,target_tensors=None,distribute=None,**kwargs
)

参数定义如下:

  • optimizer:模型的优化方法,可选的有Adadelta, Adagrad, Adam, Adamax, FTRL, NAdam, optimizer, RMSprop, SGD。可以用keras自带的,也可以用tensorflow.train中提供的优化方法。
  • loss:模型的损失函数,keras中包含的损失函数如下表所示。如果存在多个输出,可以通过传递一个字典或者列表为每个输出分别指定不同的损失函数。注意字典的key为输出,value为指定的损失函数。如果传递的是列表,则列表中的损失函数应该和输出部分一一对应。
损失函数 含义
BinaryCrossentropy 计算真实标签和预测标签之间的交叉熵损失
CategoricalCrossentropy 计算标签和预测之间的交叉熵损失
CategoricalHinge 计算y_true和y_pred之间的分类铰链损失
CosineSimilarity 计算y_true和y_pred之间的余弦相似度
Hinge 计算y_true和y_pred之间的铰链损耗
Huber 计算y_true和y_pred之间的Huber损失
KLDivergence 计算y_true和y_pred之间的Kullback Leibler差异损失
LogCosh 计算预测误差的双曲余弦的对数
Loss 损失基类
MeanAbsoluteError 计算标签和预测之间的绝对差异的平均值
MeanAbsolutePercentageError 计算y_true和y_pred之间的平均绝对百分比误差
MeanSquaredError 计算标签和预测之间的误差平方的平均值
MeanSquaredLogarithmicError 计算y_true和y_pred之间的均方对数误差
Poisson 计算y_true和y_pred之间的泊松损失
Reduction 减少损失的类型
SparseCategoricalCrossentropy 计算标签和预测之间的交叉熵损失
SquaredHinge 计算y_true和y_pred之间的平方铰链损耗
  • metrics:训练和评价模型时使用的评价指标,同样可以传递一个全局使用的评价指标或者为每一个输出指定一个独立的评价指标, 可能出现的形式有:metrics=['accuracy']metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}metrics=[['accuracy'], ['accuracy', 'mse']]metrics=['accuracy', ['accuracy', 'mse']]
  • loss_weights:一个可选参数,为列表或者字典类型,指定每个输出在计算各自损失时的权重。模型的最终输出为每个输出对应的损失值的加权平均。这是针对单个样本的损失来说的,指定这个参数,是为了在单个样本输入时,对样本的损失在每个损失指标上进行加权求和。
  • sample_weight_mode:如果要做的是timestep-wise sample weighting(2D weights,时间步长加权),则这个参数设置为temporal,如果要做的是sample-wise weights(1D weights,采样权重),则这个参数设置为None。如果要为每个输出都指定模式,则需要传入一个字典或者对应的列表。这个参数是针对一个batch来说的,指定的是计算每个batch的损失的加权方式。
  • weighted_metrics:一个列表,指定在训练和测试期间由sample_weight或者class_weight(这两个指标前一个在evaluate方法中指定,后一个指标在fit方法中指定)进行计算和加权的评价指标
  • target_tensors
  • distribute

4. evaluate方法

​ 用于在测试集上进行模型评价,并返回模型样本的损失值和评价指标对应的值,计算是在一个batch上进行的。函数定义如下:

evaluate(x=None,y=None,batch_size=None,verbose=1,sample_weight=None,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False
)
  • x:输入数据,可以是Numpy array, a list of array, tensor, a list of tensors, a dict mapping input names to the corresponding array/tensors
  • y:输入标签
  • batch_size:每个批次的大小。如果数据的生成方法已经自动生成了batch,则这个参数不需要指定
  • verbose:0, 1,默认为1。日志显示,批量输出,你可以控制输出的间隔。
  • sample_weight:用于对样本的损失函数的输出。ifcompile函数中的sample_weight_mode=None,则sample_weight的形状应该为(len(input samples), 1),ifsample_weight_mode=temporal,则sample_weight的形状应该为(samples, sequence_length)
  • steps:程序总的执行步数,每次执行的输入为一个batch
  • callbacks:定义在评估的不同阶段的行为,比如在on_batch_begin时自定义一个行为。这个参数涉及的内容比较多,这里先不做详述
  • max_queue_size:用在generatorkeras.utils.Sequence生成输入时,这两个方法分批次将数据加载入内存中,主要用于数据量过大的情况。这个参数指定加载入内存的批次的数量,默认为10。
  • workers:用在generatorkeras.utils.Sequence生成输入时,指定线程数
  • use_multiprocessing:用在generatorkeras.utils.Sequence生成输入时,指定是否使用多处理器。

5. evaluate_generator方法

​ 这个函数主要用在数据量过大,无法同时加载进内存时,所以需要指定一个数据集,每次加载一个批次的数据进入内存,批次的大小由数据生成函数指定,所以在这个函数中无法也不需要指定batch_size。函数定义如下,

evaluate_generator(generator,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False,verbose=0
)

​ 函数的参数中除了generator,其他的都和evaluate方法中的一致,而generator参数也是这一函数的核心,这个参数指定一个数据生成器函数或者keras.uitls.Sequence类的实例。

6. fit方法

​ 这个方法指定相关参数用于训练模型,函数定义如下:

fit(x=None,y=None,batch_size=None,epochs=1,verbose=1,callbacks=None,validation_split=0.0,validation_data=None,shuffle=True,class_weight=None,sample_weight=None,initial_epoch=0,steps_per_epoch=None,validation_steps=None,validation_freq=1,max_queue_size=10,workers=1,use_multiprocessing=False,**kwargs
)

参数意义:

  • x:输入
  • y:标注的标签
  • batch_size:批次大小
  • epochs:数据训练的轮数
  • verbose:显示信息的模式。如果为0不输出,如果为1则用进度条输出,如果为2则每个epoch占一行。
  • callbacks:定义在评估的不同阶段的行为。
  • validation_split:取值范围为[0, 1]内的浮点数,表示输入的训练数据中划分出来作为验证集的比例。这部分数据不用于训练模型,而是在每次每个epoch完成后对模型效果进行评价。由于如果使用了evaluate_generatortrain_generatorpredict_generator方法时,无法同时看到所有数据,所以如果要使用这些方法,这个参数不能指定
  • validation_data:验证集数据,这个参数定义了之后会覆盖validation_split。这个参数的输入一般为(x_val, y_val)或者(x_val, y_val, val_sample_weights)
  • shuffle:每个epoch开始训练之前是否需要打乱数据。
  • class_weight:一个字典定义输出中每个类别的权重,用于加权求的样本的损失值。
  • sample_weight:具体参见evaluate方法中的sample_weight参数。
  • initial_epoch:开始训练的epoch
  • steps_per_epoch:每个epoch中的迭代次数
  • validation_steps:只有定义了validation_data,才能定义这个参数。
  • validation_freq:进行验证的频率,可以是整数或者元组、列表。如果是整数的话, 表示每隔多少个epoch进行一次验证。如果是元组或者列表,则表示在哪几个epoch进行验证。eg. validation_freq=2表示每2个epoch进行一次验证;validation_freq=[1, 2, 10]表示在第1个、第2个和第10个epoch进行验证
  • max_queue_size
  • workers
  • use_multiprocessing

7. fit_generator方法

​ 和evaluate_generator相同,都是应用于数据量过大时批量将数据加载入内存。函数定义如下:

fit_generator(generator,steps_per_epoch=None,epochs=1,verbose=1,callbacks=None,validation_data=None,validation_steps=None,validation_freq=1,class_weight=None,max_queue_size=10,workers=1,use_multiprocessing=False,shuffle=True,initial_epoch=0
)

​ 其中的参数和fit方法相同,只是多了一个generator参数用于定义数据加载方法。

8. predict方法

​ 用于对输入做预测,并返回预测的结果。函数定义如下:

predict(x,batch_size=None,verbose=0,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False
)

9. predict_generator方法

函数定义:

predict_generator(generator,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False,verbose=0
)

10. train_on_batch方法

​ 在一个单一的batch上对模型的参数进行梯度更新。函数定义如下:

train_on_batch(x,y=None,sample_weight=None,class_weight=None,reset_metrics=True
)

11. predict_on_batch方法

​ 在一个batch上进行预测,返回预测的结果。函数定义如下:

predict_on_batch(x)

12. test_on_batch方法

​ 在一个单一的batch上测试模型。函数定义如下:

test_on_batch(x,y=None,sample_weight=None,reset_metrics=True
)

13. reset_metrics方法

​ 用于重置评价指标。函数形式为:

reset_metrics()

14. reset_states方法

​ 用于重置状态。函数形式为:

reset_states()

15. save方法

​ 用于保存当前模型,包括模型结构、模型参数和优化器的状态(保存下来之后允许模型在中断点重新训练)。从已经保存的模型文件中重新构建模型的方法在tf.keras.models中定义了。函数定义如下:

save(filepath,overwrite=True,include_optimizer=True,save_format=None
)

参数意义:

  • filepath:保存模型的文件路径
  • overwrite:是否默认覆盖文件路径中已有的文件
  • include_optimizer:保存的文件中是否包含优化器的状态
  • save_format:参数值可以为tf或者h5,分别对应Tensorflow SavedModel和HDF5格式

16. save_weights方法

​ 用于保存模型参数。这里需要注意的是,用tf.train.Checkpoint.save保存的模型必须用tf.train.Checkpoint.restore方法重新加载,用tf.keraas.Model.save_weights保存的模型必须用tf.keras.Model.load_weights来重新加载,两者不能混用,因为两者保存的文件的格式并不相同 。函数定义如下,参数的意义和前一个函数相同,不再赘述。

save_weights(filepath,overwrite=True,save_format=None
)

17. load_weights方法

​ 从TensorFlow或HDF5文件加载所有图层权重。函数定义如下:

load_weights(filepath,by_name=False
)

18. summary方法

​ 这个方法用于输出当前构建的网络的情况,包括了每一层的结构和参数数目。函数定义如下:

summary(line_length=None,positions=None,print_fn=None
)

19. to_json

​ 将网络的配置信息保存到一个JSON字符串中并返回,从JSON字符串中加载配置信息的函数在tf.keras.models中定义。函数定义如下:

to_json(**kwargs)

20. to_yaml方法

​ 将网络的配置信息保存到一个yaml字符串中并返回,从yaml字符串中加载配置信息的函数在tf.keras.models中定义。函数定义如下:

to_yaml(**kwargs)

21. get_layer方法

​ 根据名字或者索引得到网络的某层的实例。函数定义如下:

get_layer(name=None,index=None
)

络的配置信息保存到一个JSON字符串中并返回,从JSON字符串中加载配置信息的函数在tf.keras.models中定义。函数定义如下:

to_json(**kwargs)

20. to_yaml方法

​ 将网络的配置信息保存到一个yaml字符串中并返回,从yaml字符串中加载配置信息的函数在tf.keras.models中定义。函数定义如下:

to_yaml(**kwargs)

21. get_layer方法

​ 根据名字或者索引得到网络的某层的实例。函数定义如下:

get_layer(name=None,index=None
)

tf.Keras.Model类总结相关推荐

  1. tf.keras.Model之model.compile

    目录 model.compile的作用 model.compile的示例 tf.keras.Model类可能属于tf中拥有最多方法的类了,也最为常用.为啥?tensorflow就是一种机器学习框架,用 ...

  2. 解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题

    错误描述: 1.保存模型:model.save_weights('./model.h5') 2.脚本重启 3.加载模型:model.load_weights('./model.h5') 4.模型报错: ...

  3. tf.keras.Model之model.fit

    目录 model.fit的作用 model.fit的示例 model.fit的作用 model.fit可用于以指定的迭代次数训练模型.可以设置的参数很多,重点理解黄色标注的参数,这些比较常用. x=N ...

  4. yolov3从头实现(四)-- darknet53网络tf.keras搭建

    darknet53网络tf.keras搭建 一.定义darknet块类 1 .darknet块网络结构 2.darknet块实现 # 定义darknet块类 class _ResidualBlock( ...

  5. Tensorflow学习之tf.keras(一) tf.keras.layers.Model(另附compile,fit)

    模型将层分组为具有训练和推理特征的对象. 继承自:Layer, Module tf.keras.Model(*args, **kwargs ) 参数 inputs 模型的输入:keras.Input ...

  6. 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)

    日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...

  7. TensorFlow 2.7 正式版上线,改进 TF/Keras 调试,支持 Jax 模型到 TensorFlow Lite转换

    点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 转自 | 机器之心 TensorFlow2.7 正式发布,新版本包括对 tf.kera ...

  8. 使用估算器、tf.keras 和 tf.data 进行多 GPU 训练

    文 / Zalando Research 研究科学家 Kashif Rasul 来源 | TensorFlow 公众号 与大多数 AI 研究部门一样,Zalando Research 也意识到了对创意 ...

  9. Keras vs tf.keras: 在TensorFlow 2.0中有什么区别?

    导读 在本文中,您将发现Keras和tf.keras之间的区别,包括TensorFlow 2.0中的新增功能. 万众期待的TensorFlow 2.0于9月30日正式发布. 虽然肯定是值得庆祝的时刻, ...

最新文章

  1. 2015 年最受 Linux 爱好者欢迎的软硬件大盘点
  2. VC中使用Matlab Engine出现无法找到libeng.dll的问题
  3. 未来,大脑扫描背包将神经科学带入现实世界
  4. 阿里云服务器(BT面板)Vue+Node(Egg)部署流程
  5. python自动客服排班_使用或工具的护士排班问题,在某些日子增加不同的轮班时间...
  6. 19.复习:一般过去时、过去进行时和过去完成时
  7. Java 字符串替换String.replaceAll需注意
  8. 微信小程序 购物车代码
  9. win10下安装deepin双系统教程
  10. 使用计算机组成原理全加器设计,杭电计算机组成原理全加器设计实验1
  11. QQ空间抢车位刷钱方法汇总
  12. iptables屏蔽QQ与MSN
  13. 技能梳理7@stm32+OLED+flash掉电保存+按键
  14. 集五福主题的微信图文排版攻略已到!
  15. Diffusion Models专栏文章汇总:入门与实战
  16. IDEA快速移动光标到行首或行尾;
  17. php5.0 cms安装教程,MySQL_KingCMS5.0从安装到设置使用教程,1.首先到KingCMS官方下载KingCMS5.0 - phpStudy...
  18. 如何批量获取企业工商信息?
  19. 性格内向的人,是否适合做产品经理 ?
  20. Android 跳转到新浪微博

热门文章

  1. shopify二次开发 产品详情页面的开发一(结构布局)
  2. jvav是什么梗?jvav是什么?jvav史上最牛语言
  3. 测试中的Right-BICEP
  4. 都2022年你还不会安装系统?看我三分钟完事PE制作并进行Win11系统安装实践
  5. jsf服务_JSF ManagedBean ManagedProperty
  6. 微信默认表情符号的代码对照表(微信公众号使用到)
  7. 一键加速索尼相机SD卡文件的复制操作,文件操作批处理教程
  8. 使用ls筛选某一天的文件
  9. 解决MAC系统升级后虚拟机黑屏问题
  10. 集结Android开发里的各种大神