本文主要谈谈自己对step,epoch,batch_size这几个常见参数的理解。
最近在调试模型的时候,发现在使用keras.optimizer.adam时,模型在添加了新的一层2D卷积层后难以收敛,在不调整初始权重矩阵的情况下,想通过衰减学习率来使loss function的收敛性更好。

tf.keras.optimizers.Adam(learning_rate=0.001,beta_1=0.9,beta_2=0.999,epsilon=1e-07,amsgrad=False,name="Adam",**kwargs
)

可以看到,adam这个optimizer在没有其他参数条件的情况下,默认学习率为固定0.001。

为了调整学习率,在keras的文档中找到了下述示例代码,代码的意思很简单,初始学习率为0.01,衰减需要的step为10000,衰减率为0.9,即每次经过10000 steps,学习率就衰减为原来的0.9。

lr_schedule = keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=1e-2,decay_steps=10000,decay_rate=0.9)
optimizer = keras.optimizers.SGD(learning_rate=lr_schedule)

那么这里的step和我们在编译模型时选的epoch与batch_size有什么关系呢。
需要理解的是,在训练模型的过程中,一个step其实指的就是一次梯度更新的过程。例如在每个epoch中有2000个用于训练的图片,我们选取了batch_size=100,那么我们就需要2000 images / 100 (images/step) = 20 steps来完成这个epoch。

换个角度,从神经网络的角度来说,我们都知道机器学习的最终目的,就是最小化Loss function损失函数。L(W)=1K∑t=1Kl(yt,ytetoile)L(W) = \frac{1} {K}\sum_{t=1}^{K}l(y_t,y_{t_{etoile}})L(W)=K1​∑t=1K​l(yt​,ytetoile​​)。 我们会发现这里的loss function 是K组训练数据的平均误差,这里的K其实就是我们在训练模型时选择的batch_size,即将多个训练数据整合到一起,再通过最小化他们的平均误差来优化权重矩阵。那么经过每个batch_size的训练,我们计算梯度,更新权重的过程就称为一个step。

有了对于step的更深刻的认识,我们就可以轻松地根据step自行调整学习率了。

机器学习模型中step与epoch,batch_size之间的关系相关推荐

  1. 一文讲述如何将预测范式引入到机器学习模型中

    作者 | Filip Piekniewski 编译 |ziqi zhang 随着人工智能的持续深入,深度学习技术在多智能体学习.推理系统和推荐系统上取得了很大进展. 对于多智能体来说,预测能力有着关键 ...

  2. 机器学习:贝叶斯和优化方法_Facebook使用贝叶斯优化在机器学习模型中进行更好的实验

    机器学习:贝叶斯和优化方法 I recently started a new newsletter focus on AI education. TheSequence is a no-BS( mea ...

  3. 利用colab保存模型_在Google Colab上训练您的机器学习模型中的“后门”

    利用colab保存模型 Note: This post is for educational purposes only. 注意:此职位仅用于教育目的. In this post, I would f ...

  4. sql语句和java的关系_java中Statement 与 PreparedStatement接口之间的关系和区别

    Statement 和 PreparedStatement之间的关系和区别. 关系:PreparedStatement继承自Statement,都是接口 区别:PreparedStatement可以使 ...

  5. 机器学习模型中的损失函数loss function

    1. 概述 在机器学习算法中,有一个重要的概念就是损失函数(Loss Function).损失函数的作用就是度量模型的预测值f(x)f\left ( \mathbf{x} \right )f(x)与真 ...

  6. 机器学习模型中,偏差与方差的权衡及计算

    衡量一个机器学习模型的性能,可以用偏差和方差作为依据. 一个高偏差的模型,总是会对数据分布做出强假设,比如线性回归.而一个高方差的模型,总是会过度依赖于它的训练集,例如未修剪的决策树.我们希望一个模型 ...

  7. 机器学习模型的衡量指标_在机器学习模型中衡量公平性

    机器学习模型的衡量指标 In our previous article, we gave an in-depth review on how to explain biases in data. Th ...

  8. 机器学习模型中的评价指标

    1.回归模型 1.1 MSE(均方误差) MSE是Mean Square Error的缩写,其计算公式如下: m s e = 1 m ∑ i = 1 m ( y i − y i ^ ) 2 mse=\ ...

  9. 模型中AIC和BIC以及loglikelihood的关系

    目录 1. AIC的解释 2. BIC的解释 3. AIC和BIC的比较 4. 实例演示 4.1 模型1的AIC和BIC 4.2 模型2的AIC和BIC 4.3 模型1和模型2比较 5. LRT似然比 ...

最新文章

  1. ios tableview 滑动到底部
  2. 【字符串操作之】从原字符串中切出一段,返回一个新的字符串→→slice方法...
  3. requestbody接收不到参数_使用Spring MVC解析嵌套参数在三种 ContentType 下的绑定方式...
  4. Xamarin的坑 - 绑定(一) - 拿微信iOS SDK 简单说起
  5. 大数据驱动智能制造 物联网引爆工业革命商机
  6. 解决mysql中表字符集gbk,列字符集Latin1,python查询乱码问题
  7. 中英文对照 —— 英语语法与文法概念
  8. 原生JS大揭秘—原型链
  9. java如何让源码加密还能运行_如何有效防止Java程序源码被人偷窥?
  10. Linux中查看文件夹大小的命令
  11. python语音识别(语音转文字)
  12. JavaWeb 页面跳转方式连接数据库
  13. 高德citycode和国家citycode编码转换
  14. 主机连接服务器的过程
  15. 阿里云实时音视频直播鉴权java代码示例
  16. Quill编辑器介绍及扩展
  17. c++ 3D笔记整理
  18. STM32实现六轴姿态测量陀螺仪模块JY61P(标准库与HAL库实现)
  19. 腾讯管家禁用好压右键进程,影响用户使用,的终极解决办法
  20. HTML页悬浮div的两种方式

热门文章

  1. mysql generaton_Mysql 集成随机唯一id mysql unique number generation
  2. Java ObjectOutputStream flush()方法与示例
  3. ServletContext(核心内容)
  4. web安全---SSRF漏洞
  5. c语言程序设计编程解读,【答题】C语言程序设计问题与解释实验
  6. c语言函数调用数组_第七讲:C语言基础之函数,第二节,实现汉诺塔
  7. 如何在Visual Studio项目中正确添加汇编代码 .
  8. uva 1347——Tour
  9. java毛玻璃_模糊效果(毛玻璃效果)
  10. 软件工程---16.基于构件的软件工程