在TensorFLow2中进行神经网络模型的训练主要包括以下几个主要的步骤:

  • 导入相关模块import
  • 准备数据,拆分训练集train、测试集test
  • 搭建神经网络模型model (两种方法:Sequential或自定义模型class)
  • 模型编译model.compile()
  • 模型训练model.fit()
  • 查看模型model.summary()
  • 模型评价
  • 模型预测model.predict()

model.compile()的作用就是为搭建好的神经网络模型设置损失函数loss、优化器optimizer、准确性评价函数metrics。

这些方法的作用分别是:

  • 损失函数和优化器用在反向传播的时候,我们会求损失函数对训练变量的导数,即梯度,然后根据选择的优化器来确定参数更新公式,根据公式对可训练参数进行更新。
  • 准确性评价函数用在评估模型预测的准确性。在模型训练的过程中,我们会记录模型在训练集、验证集上的预测准确性,之后会据此绘制准确率随着训练次数的变化曲线。通过查看和对比训练集、测试集随着训练次数的准确率曲线,我们能发现模型是否是过拟合、欠拟合,或者也可以发现多少轮后可以停止模型训练了。

由上可以看出,神经网络模型建模训练的过程中,核心的灵魂环节就是搭建模型和编译compile了。所以,这是非常非常重要的一个模块。

1、首先,上代码,直观看下model.compile()在神经网络建模中的使用示例

#model.compile()配置模型训练方法
model.compile(  optimizer = tf.keras.optimizers.SGD(lr = 0.1), #使用SGD优化器,学习率为0.1loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = False), #配置损失函数metrics = ['sparse_categorical_accuracy'] #标注准确性评价指标
)

2、解读model.compile()中配置方法

compile(optimizer,  #优化器loss=None,  #损失函数metrics=None,   # ["准确率”]loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None)

2.1  loss可以是字符串形式给出的损失函数的名字,也可以是函数形式

例如:”mse" 或者 tf.keras.losses.MeanSquaredError()

"sparse_categorical_crossentropy"  或者  tf.keras.losses.SparseCatagoricalCrossentropy(from_logits = False)

损失函数经常需要使用softmax函数来将输出转化为概率分布的形式,在这里from_logits代表是否将输出转为概率分布的形式,为False时表示转换为概率分布,为True时表示不转换,直接输出

2.2  optimizer可以是字符串形式给出的优化器名字,也可以是函数形式,使用函数形式可以设置学习率、动量和超参数

例如:“sgd”   或者   tf.optimizers.SGD(lr = 学习率,

decay = 学习率衰减率,

momentum = 动量参数)

“adagrad"  或者  tf.keras.optimizers.Adagrad(lr = 学习率,

decay = 学习率衰减率)

”adadelta"  或者  tf.keras.optimizers.Adadelta(lr = 学习率,

decay = 学习率衰减率)

“adam"  或者  tf.keras.optimizers.Adam(lr = 学习率,

decay = 学习率衰减率)

2.3 Metrics神经网络模型的准确性评价指标

例如:

"accuracy" : y_ 和 y 都是数值,如y_ = [1] y = [1]  #y_为真实值,y为预测值

“sparse_accuracy":y_和y都是以独热码 和概率分布表示,如y_ = [0, 1, 0], y = [0.256, 0.695, 0.048]

"sparse_categorical_accuracy" :y_是以数值形式给出,y是以 独热码给出,如y_ = [1], y = [0.256 0.695, 0.048]

Tensorflow2 model.compile()理解相关推荐

  1. tensorflow model.compile() 示例

    model.compile()方法用于在配置训练方法时,告知训练时用的优化器.损失函数和准确率评测标准 model.compile(optimizer=tf.keras.optimizers.Adam ...

  2. model.compile中metrics的参数accuracy

    知乎大佬链接 model.compile(optimizer = tf.keras.optimizers.Adam(0.01),loss = tf.keras.losses.SparseCategor ...

  3. tf.keras.Model之model.compile

    目录 model.compile的作用 model.compile的示例 tf.keras.Model类可能属于tf中拥有最多方法的类了,也最为常用.为啥?tensorflow就是一种机器学习框架,用 ...

  4. tensorflow中model.compile()

    model.compile()用来配置模型的优化器.损失函数,评估指标等 里面的具体参数有: compile(optimizer='rmsprop',loss=None,metrics=None,lo ...

  5. STAR: A Structure and Texture Aware Retinex Model论文理解

    STAR:A Structure and Texture Aware Retinex Model论文理解,2020,TIP 引言:在Retinex理论中,较大的导数是由于反射率的变化,而较小的导数出现 ...

  6. 【tensorflow】Sequential 模型方法 compile, model.compile

    Sequential 顺序模型 API - Keras 中文文档 https://keras.io/zh/models/sequential/ Sequential 序贯模型 序贯模型是函数式模型的简 ...

  7. Author Topic Model[ATM理解及公式推导]

    参考论文 Modeling documents with topics Modeling authors with words The author-topic model Gibbs samplin ...

  8. TensorFlow2.0(十一)--理解LSTM网络

    理解LSTM网络 前言 1. 循环神经网络 2. 长期依赖问题 3. LSTM网络 4. LSTM背后的核心思想 5. 单步解析LSTM网络结构 5.1 遗忘门结构 5.2 输入门结构 5.3 输出门 ...

  9. IPMI channel model的理解

    IPMI 1.5V版本后提出了"Channel Model"的概念.通道模型是IPMI标准中IPMI消息传递的公共通道,所有IPMI消息(请求和响应)都是通过通道传递的.IPMI一 ...

  10. springmvc中Model的理解

    spring的Model相当于前端的一个数据库,就好比后端中的user实体类所对应的数据库User,从Model中获取数据比从后端的User实体类中获取数据更加方便. 如下图是一个简单的实例,简单展示 ...

最新文章

  1. 「NLP」用于序列标注问题的条件随机场
  2. python在excel中的应用-python中的excel操作
  3. python微控制器编程从零开始-单片机可以使用Python语言来控制了!
  4. 解决Android Device Chooser 找不到设备问题
  5. 深刻理解Python中的元类(metaclass)以及元类实现单例模式
  6. matlab绘制烟花,[原创]利用MATLAB燃放烟花(礼花)
  7. java弹窗 触发事件_关于ElementUI中MessageBox弹框的取消键盘触发事件(enter,esc)关闭弹窗(执行事件)的解决方法...
  8. python计算最大回撤_最大回撤线性算法实现
  9. 运用Mono.Cecil 反射读取.NET程序集元数据
  10. 视频剪辑计算机配置要求,视频剪辑需要电脑的什么配置
  11. 天若OCR文字识别本地版
  12. 《合约星期五》OKEx BTC季度合约 0726周报
  13. QQ自由幻想刺客的属性点
  14. 解决在串口调试助手中每次复位后只能发送一次数据的问题
  15. 微软搜购诺基亚是这样的吗
  16. php 导出excel (html),php两种导出excel的方法
  17. css实现标题左右横线
  18. ABAP中时间戳的处理
  19. 马晓:Serverless SSR 在人人视频的落地探索
  20. ubuntu16.04安装网易云音乐方法出现问题及解决方法(桌面图标打不开、不能输入中文等问题)

热门文章

  1. 张晓龙2018微信公开课
  2. [RK3399][Android7.1] DDR动态频率调节驱动小结
  3. mysql 内联 外联_sql中的内联和外联(简单用法)
  4. 职场技巧:高效实用的四象限法则
  5. JEECMS——源码下载和安转运行
  6. Android 系列 5.13添加简单光栅动画
  7. java 苹果cms 萌果_苹果maccms8x最新程序会员中心全新美化171模板分享
  8. scrapy技术进阶-URL路径依赖
  9. 使用curl完成POST数据给飞信接口
  10. 聚石塔服务器 微信,聚石塔云服务器