tf.Keras.Model类总结
文章目录
- tf.keras.Model类
- 1. 创建一个tf.keras.Model类实例的方法
- 1.1 通过指定输入输出进行实例化
- 1.2 通过继承Model类进行实例化
- 2. tf.Keras.Model类属性
- 3. compile方法
- 4. evaluate方法
- 5. evaluate_generator方法
- 6. fit方法
- 7. fit_generator方法
- 8. predict方法
- 9. predict_generator方法
- 10. train_on_batch方法
- 11. predict_on_batch方法
- 12. test_on_batch方法
- 13. reset_metrics方法
- 14. reset_states方法
- 15. save方法
- 16. save_weights方法
- 17. load_weights方法
- 18. summary方法
- 19. to_json
- 20. to_yaml方法
- 21. get_layer方法
- 20. to_yaml方法
- 21. get_layer方法
tf.keras.Model类
tf.keras.Model
类将定义好的网络结构封装入一个对象,用于训练、测试和预测。在这一块中,有两部分内容目前我还有疑惑,一个是xxx_on_batch
三个方法,为什么要单独定义这个方法,而且train_on_batch
方法为什么要强调是在单个batch上做梯度更新?第二个疑问是reset_metrics
和reset_states
函数有什么用,重置不重置对训练过程有什么影响?这两个问题在后续的学习中还需要重点理解。
1. 创建一个tf.keras.Model类实例的方法
1.1 通过指定输入输出进行实例化
在这种方法中,首先定义好网络,再将网络的输入和输出部分作为参数定义一个Model
类对象。
import tensorflow as tfinputs = tf.keras.Input(shape=(3,))
x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
1.2 通过继承Model类进行实例化
在这种方法中,我们需要定义自己的__init__
并且在call
方法中实现网络的前向传播结构(即在这个方法中定义网络结构)。
import tensorflow as tfclass MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)def call(self, inputs):x = self.dense1(inputs)return self.dense2(x)model = MyModel()
如果我们需要在训练和预测时实现不同的行为,可以在call
方法中加入一个布尔类型的值training
,当训练时这个参数值为True
,进行评估或者预测时(非训练行为),这个参数值为False
。
import tensorflow as tfclass MyModel(tf.keras.Model):def __init__(self):super(MyModel, self).__init__()self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)self.dropout = tf.keras.layers.Dropout(0.5)def call(self, inputs, training=False):x = self.dense1(inputs)if training:x = self.dropout(x, training=training)return self.dense2(x)model = MyModel()
2. tf.Keras.Model类属性
这个类中包含的属性包括以下几个:
input_spec
layers
:这个参数为一个列表,每个元素为网络中的一层的定义metrics_names
run_eagerly
:这里指定模型是否是动态运行的。Tensorflow一开始推出时采用的是静态机制,即首先定义模型结构对应的运算图再执行。新版本的Tensorflow引入了动态图机制(Eager execution),这种机制中Tensorflow不用在建立了完整的运算图之后才进行运算,而是可以在每一步运算定义后立即求出具体的数值,这使得Tensorflow的编码和调试过程变得简单,减少了冗余操作(模板化、公式化操作),但这种动态特性在一些情况下无可避免的会增加程序的运行开销(提前定义好的静态图的执行过程可以被更好地优化)。Tensorflow中默认构建的为静态图,要想用动态特性需要调用方法enable_eager_execution()
,并且注意这个方法一定要在程序开头调用,否则会报错。只有调用了这个方法,run_eagerly
参数才能被设置为True
。'''这里要特别注意,tf.contrib将在 tensorflow 2.0被弃用,所以这种使用动态图特性的方法在tensorflow后面的版本中并不适用,需要大家再去翻翻官方文档或者自行百度找了 ''' import tensorflow as tf import tensorflow.contrib.eager as tfe# 调用这个方法,使用tensorflow的动态图特性 tfe.enable_eager_execution()inputs = tf.keras.Input(shape=(3,)) x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs) outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x) model = tf.keras.Model(inputs=inputs, outputs=outputs, run_eagerly=True)
sample_weights
state_updates
stateful
3. compile方法
指定模型训练时的参数,函数定义如下:
compile(optimizer,loss=None,metrics=None,loss_weights=None,sample_weight_mode=None,weighted_metrics=None,target_tensors=None,distribute=None,**kwargs
)
参数定义如下:
optimizer
:模型的优化方法,可选的有Adadelta
,Adagrad
,Adam
,Adamax
,FTRL
,NAdam
,optimizer
,RMSprop
,SGD
。可以用keras自带的,也可以用tensorflow.train中提供的优化方法。loss
:模型的损失函数,keras中包含的损失函数如下表所示。如果存在多个输出,可以通过传递一个字典或者列表为每个输出分别指定不同的损失函数。注意字典的key为输出,value为指定的损失函数。如果传递的是列表,则列表中的损失函数应该和输出部分一一对应。
损失函数 | 含义 |
---|---|
BinaryCrossentropy | 计算真实标签和预测标签之间的交叉熵损失 |
CategoricalCrossentropy | 计算标签和预测之间的交叉熵损失 |
CategoricalHinge | 计算y_true和y_pred之间的分类铰链损失 |
CosineSimilarity | 计算y_true和y_pred之间的余弦相似度 |
Hinge | 计算y_true和y_pred之间的铰链损耗 |
Huber | 计算y_true和y_pred之间的Huber损失 |
KLDivergence | 计算y_true和y_pred之间的Kullback Leibler差异损失 |
LogCosh | 计算预测误差的双曲余弦的对数 |
Loss | 损失基类 |
MeanAbsoluteError | 计算标签和预测之间的绝对差异的平均值 |
MeanAbsolutePercentageError | 计算y_true和y_pred之间的平均绝对百分比误差 |
MeanSquaredError | 计算标签和预测之间的误差平方的平均值 |
MeanSquaredLogarithmicError | 计算y_true和y_pred之间的均方对数误差 |
Poisson | 计算y_true和y_pred之间的泊松损失 |
Reduction | 减少损失的类型 |
SparseCategoricalCrossentropy | 计算标签和预测之间的交叉熵损失 |
SquaredHinge | 计算y_true和y_pred之间的平方铰链损耗 |
metrics
:训练和评价模型时使用的评价指标,同样可以传递一个全局使用的评价指标或者为每一个输出指定一个独立的评价指标, 可能出现的形式有:metrics=['accuracy']
,metrics={'output_a': 'accuracy', 'output_b': ['accuracy', 'mse']}
,metrics=[['accuracy'], ['accuracy', 'mse']]
,metrics=['accuracy', ['accuracy', 'mse']]
loss_weights
:一个可选参数,为列表或者字典类型,指定每个输出在计算各自损失时的权重。模型的最终输出为每个输出对应的损失值的加权平均。这是针对单个样本的损失来说的,指定这个参数,是为了在单个样本输入时,对样本的损失在每个损失指标上进行加权求和。sample_weight_mode:如果要做的是timestep-wise sample weighting(2D weights,时间步长加权),则这个参数设置为
temporal,如果要做的是sample-wise weights
(1D weights,采样权重),则这个参数设置为None
。如果要为每个输出都指定模式,则需要传入一个字典或者对应的列表。这个参数是针对一个batch来说的,指定的是计算每个batch的损失的加权方式。weighted_metrics
:一个列表,指定在训练和测试期间由sample_weight
或者class_weight
(这两个指标前一个在evaluate
方法中指定,后一个指标在fit
方法中指定)进行计算和加权的评价指标target_tensors
distribute
:
4. evaluate方法
用于在测试集上进行模型评价,并返回模型样本的损失值和评价指标对应的值,计算是在一个batch上进行的。函数定义如下:
evaluate(x=None,y=None,batch_size=None,verbose=1,sample_weight=None,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False
)
x
:输入数据,可以是Numpy array, a list of array, tensor, a list of tensors, a dict mapping input names to the corresponding array/tensorsy
:输入标签batch_size
:每个批次的大小。如果数据的生成方法已经自动生成了batch,则这个参数不需要指定verbose
:0, 1,默认为1。日志显示,批量输出,你可以控制输出的间隔。sample_weight
:用于对样本的损失函数的输出。ifcompile
函数中的sample_weight_mode=None
,则sample_weight
的形状应该为(len(input samples), 1)
,ifsample_weight_mode=temporal
,则sample_weight
的形状应该为(samples, sequence_length)
。steps
:程序总的执行步数,每次执行的输入为一个batchcallbacks
:定义在评估的不同阶段的行为,比如在on_batch_begin
时自定义一个行为。这个参数涉及的内容比较多,这里先不做详述max_queue_size
:用在generator
或keras.utils.Sequence
生成输入时,这两个方法分批次将数据加载入内存中,主要用于数据量过大的情况。这个参数指定加载入内存的批次的数量,默认为10。workers
:用在generator
或keras.utils.Sequence
生成输入时,指定线程数use_multiprocessing
:用在generator
或keras.utils.Sequence
生成输入时,指定是否使用多处理器。
5. evaluate_generator方法
这个函数主要用在数据量过大,无法同时加载进内存时,所以需要指定一个数据集,每次加载一个批次的数据进入内存,批次的大小由数据生成函数指定,所以在这个函数中无法也不需要指定batch_size
。函数定义如下,
evaluate_generator(generator,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False,verbose=0
)
函数的参数中除了generator
,其他的都和evaluate
方法中的一致,而generator
参数也是这一函数的核心,这个参数指定一个数据生成器函数或者keras.uitls.Sequence
类的实例。
6. fit方法
这个方法指定相关参数用于训练模型,函数定义如下:
fit(x=None,y=None,batch_size=None,epochs=1,verbose=1,callbacks=None,validation_split=0.0,validation_data=None,shuffle=True,class_weight=None,sample_weight=None,initial_epoch=0,steps_per_epoch=None,validation_steps=None,validation_freq=1,max_queue_size=10,workers=1,use_multiprocessing=False,**kwargs
)
参数意义:
x
:输入y
:标注的标签batch_size
:批次大小epochs
:数据训练的轮数verbose
:显示信息的模式。如果为0不输出,如果为1则用进度条输出,如果为2则每个epoch占一行。callbacks
:定义在评估的不同阶段的行为。validation_split
:取值范围为[0, 1]
内的浮点数,表示输入的训练数据中划分出来作为验证集的比例。这部分数据不用于训练模型,而是在每次每个epoch完成后对模型效果进行评价。由于如果使用了evaluate_generator
、train_generator
、predict_generator
方法时,无法同时看到所有数据,所以如果要使用这些方法,这个参数不能指定validation_data
:验证集数据,这个参数定义了之后会覆盖validation_split
。这个参数的输入一般为(x_val, y_val)
或者(x_val, y_val, val_sample_weights)
shuffle
:每个epoch开始训练之前是否需要打乱数据。class_weight
:一个字典定义输出中每个类别的权重,用于加权求的样本的损失值。sample_weight
:具体参见evaluate
方法中的sample_weight
参数。initial_epoch
:开始训练的epochsteps_per_epoch
:每个epoch中的迭代次数validation_steps
:只有定义了validation_data
,才能定义这个参数。validation_freq
:进行验证的频率,可以是整数或者元组、列表。如果是整数的话, 表示每隔多少个epoch进行一次验证。如果是元组或者列表,则表示在哪几个epoch进行验证。eg.validation_freq=2
表示每2个epoch进行一次验证;validation_freq=[1, 2, 10]
表示在第1个、第2个和第10个epoch进行验证max_queue_size
workers
use_multiprocessing
7. fit_generator方法
和evaluate_generator
相同,都是应用于数据量过大时批量将数据加载入内存。函数定义如下:
fit_generator(generator,steps_per_epoch=None,epochs=1,verbose=1,callbacks=None,validation_data=None,validation_steps=None,validation_freq=1,class_weight=None,max_queue_size=10,workers=1,use_multiprocessing=False,shuffle=True,initial_epoch=0
)
其中的参数和fit
方法相同,只是多了一个generator
参数用于定义数据加载方法。
8. predict方法
用于对输入做预测,并返回预测的结果。函数定义如下:
predict(x,batch_size=None,verbose=0,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False
)
9. predict_generator方法
函数定义:
predict_generator(generator,steps=None,callbacks=None,max_queue_size=10,workers=1,use_multiprocessing=False,verbose=0
)
10. train_on_batch方法
在一个单一的batch上对模型的参数进行梯度更新。函数定义如下:
train_on_batch(x,y=None,sample_weight=None,class_weight=None,reset_metrics=True
)
11. predict_on_batch方法
在一个batch上进行预测,返回预测的结果。函数定义如下:
predict_on_batch(x)
12. test_on_batch方法
在一个单一的batch上测试模型。函数定义如下:
test_on_batch(x,y=None,sample_weight=None,reset_metrics=True
)
13. reset_metrics方法
用于重置评价指标。函数形式为:
reset_metrics()
14. reset_states方法
用于重置状态。函数形式为:
reset_states()
15. save方法
用于保存当前模型,包括模型结构、模型参数和优化器的状态(保存下来之后允许模型在中断点重新训练)。从已经保存的模型文件中重新构建模型的方法在tf.keras.models
中定义了。函数定义如下:
save(filepath,overwrite=True,include_optimizer=True,save_format=None
)
参数意义:
filepath
:保存模型的文件路径overwrite
:是否默认覆盖文件路径中已有的文件include_optimizer
:保存的文件中是否包含优化器的状态save_format
:参数值可以为tf
或者h5
,分别对应Tensorflow SavedModel和HDF5格式
16. save_weights方法
用于保存模型参数。这里需要注意的是,用tf.train.Checkpoint.save
保存的模型必须用tf.train.Checkpoint.restore
方法重新加载,用tf.keraas.Model.save_weights
保存的模型必须用tf.keras.Model.load_weights
来重新加载,两者不能混用,因为两者保存的文件的格式并不相同 。函数定义如下,参数的意义和前一个函数相同,不再赘述。
save_weights(filepath,overwrite=True,save_format=None
)
17. load_weights方法
从TensorFlow或HDF5文件加载所有图层权重。函数定义如下:
load_weights(filepath,by_name=False
)
18. summary方法
这个方法用于输出当前构建的网络的情况,包括了每一层的结构和参数数目。函数定义如下:
summary(line_length=None,positions=None,print_fn=None
)
19. to_json
将网络的配置信息保存到一个JSON字符串中并返回,从JSON字符串中加载配置信息的函数在tf.keras.models
中定义。函数定义如下:
to_json(**kwargs)
20. to_yaml方法
将网络的配置信息保存到一个yaml字符串中并返回,从yaml字符串中加载配置信息的函数在tf.keras.models
中定义。函数定义如下:
to_yaml(**kwargs)
21. get_layer方法
根据名字或者索引得到网络的某层的实例。函数定义如下:
get_layer(name=None,index=None
)
络的配置信息保存到一个JSON字符串中并返回,从JSON字符串中加载配置信息的函数在tf.keras.models
中定义。函数定义如下:
to_json(**kwargs)
20. to_yaml方法
将网络的配置信息保存到一个yaml字符串中并返回,从yaml字符串中加载配置信息的函数在tf.keras.models
中定义。函数定义如下:
to_yaml(**kwargs)
21. get_layer方法
根据名字或者索引得到网络的某层的实例。函数定义如下:
get_layer(name=None,index=None
)
tf.Keras.Model类总结相关推荐
- tf.keras.Model之model.compile
目录 model.compile的作用 model.compile的示例 tf.keras.Model类可能属于tf中拥有最多方法的类了,也最为常用.为啥?tensorflow就是一种机器学习框架,用 ...
- 解决Tensorflow2.0 tf.keras.Model.load_weights() 报错处理问题
错误描述: 1.保存模型:model.save_weights('./model.h5') 2.脚本重启 3.加载模型:model.load_weights('./model.h5') 4.模型报错: ...
- tf.keras.Model之model.fit
目录 model.fit的作用 model.fit的示例 model.fit的作用 model.fit可用于以指定的迭代次数训练模型.可以设置的参数很多,重点理解黄色标注的参数,这些比较常用. x=N ...
- yolov3从头实现(四)-- darknet53网络tf.keras搭建
darknet53网络tf.keras搭建 一.定义darknet块类 1 .darknet块网络结构 2.darknet块实现 # 定义darknet块类 class _ResidualBlock( ...
- Tensorflow学习之tf.keras(一) tf.keras.layers.Model(另附compile,fit)
模型将层分组为具有训练和推理特征的对象. 继承自:Layer, Module tf.keras.Model(*args, **kwargs ) 参数 inputs 模型的输入:keras.Input ...
- 深度学习框架 TensorFlow:张量、自动求导机制、tf.keras模块(Model、layers、losses、optimizer、metrics)、多层感知机(即多层全连接神经网络 MLP)
日萌社 人工智能AI:Keras PyTorch MXNet TensorFlow PaddlePaddle 深度学习实战(不定时更新) 安装 TensorFlow2.CUDA10.cuDNN7.6. ...
- TensorFlow 2.7 正式版上线,改进 TF/Keras 调试,支持 Jax 模型到 TensorFlow Lite转换
点击上方"AI遇见机器学习",选择"星标"公众号 重磅干货,第一时间送达 转自 | 机器之心 TensorFlow2.7 正式发布,新版本包括对 tf.kera ...
- 使用估算器、tf.keras 和 tf.data 进行多 GPU 训练
文 / Zalando Research 研究科学家 Kashif Rasul 来源 | TensorFlow 公众号 与大多数 AI 研究部门一样,Zalando Research 也意识到了对创意 ...
- Keras vs tf.keras: 在TensorFlow 2.0中有什么区别?
导读 在本文中,您将发现Keras和tf.keras之间的区别,包括TensorFlow 2.0中的新增功能. 万众期待的TensorFlow 2.0于9月30日正式发布. 虽然肯定是值得庆祝的时刻, ...
最新文章
- 2015 年最受 Linux 爱好者欢迎的软硬件大盘点
- VC中使用Matlab Engine出现无法找到libeng.dll的问题
- 未来,大脑扫描背包将神经科学带入现实世界
- 阿里云服务器(BT面板)Vue+Node(Egg)部署流程
- python自动客服排班_使用或工具的护士排班问题,在某些日子增加不同的轮班时间...
- 19.复习:一般过去时、过去进行时和过去完成时
- Java 字符串替换String.replaceAll需注意
- 微信小程序 购物车代码
- win10下安装deepin双系统教程
- 使用计算机组成原理全加器设计,杭电计算机组成原理全加器设计实验1
- QQ空间抢车位刷钱方法汇总
- iptables屏蔽QQ与MSN
- 技能梳理7@stm32+OLED+flash掉电保存+按键
- 集五福主题的微信图文排版攻略已到!
- Diffusion Models专栏文章汇总:入门与实战
- IDEA快速移动光标到行首或行尾;
- php5.0 cms安装教程,MySQL_KingCMS5.0从安装到设置使用教程,1.首先到KingCMS官方下载KingCMS5.0 - phpStudy...
- 如何批量获取企业工商信息?
- 性格内向的人,是否适合做产品经理 ?
- Android 跳转到新浪微博