一个 epoch(代)是指整个数据集正向反向训练一次。它被用来提示模型的准确率并且不需要额外数据。本节我们将讲解 TensorFlow 里的 epochs,以及如何选择正确的 epochs。

下面是训练一个模型 10 代的 TensorFlow 代码:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import numpy as np
from helper import batches  # Helper function created in Mini-batching sectiondef print_epoch_stats(epoch_i, sess, last_features, last_labels):"""Print cost and validation accuracy of an epoch"""current_cost = sess.run(cost,feed_dict={features: last_features, labels: last_labels})valid_accuracy = sess.run(accuracy,feed_dict={features: valid_features, labels: valid_labels})print('Epoch: {:<4} - Cost: {:<8.3} Valid Accuracy: {:<5.3}'.format(epoch_i,current_cost,valid_accuracy))n_input = 784  # MNIST data input (img shape: 28*28)
n_classes = 10  # MNIST total classes (0-9 digits)# Import MNIST data
mnist = input_data.read_data_sets('/datasets/ud730/mnist', one_hot=True)# The features are already scaled and the data is shuffled
train_features = mnist.train.images
valid_features = mnist.validation.images
test_features = mnist.test.imagestrain_labels = mnist.train.labels.astype(np.float32)
valid_labels = mnist.validation.labels.astype(np.float32)
test_labels = mnist.test.labels.astype(np.float32)# Features and Labels
features = tf.placeholder(tf.float32, [None, n_input])
labels = tf.placeholder(tf.float32, [None, n_classes])# Weights & bias
weights = tf.Variable(tf.random_normal([n_input, n_classes]))
bias = tf.Variable(tf.random_normal([n_classes]))# Logits - xW + b
logits = tf.add(tf.matmul(features, weights), bias)# Define loss and optimizer
learning_rate = tf.placeholder(tf.float32)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=labels))
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate).minimize(cost)# Calculate accuracy
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))init = tf.global_variables_initializer()batch_size = 128
epochs = 10
learn_rate = 0.001train_batches = batches(batch_size, train_features, train_labels)with tf.Session() as sess:sess.run(init)# Training cyclefor epoch_i in range(epochs):# Loop over all batchesfor batch_features, batch_labels in train_batches:train_feed_dict = {features: batch_features,labels: batch_labels,learning_rate: learn_rate}sess.run(optimizer, feed_dict=train_feed_dict)# Print cost and validation accuracy of an epochprint_epoch_stats(epoch_i, sess, batch_features, batch_labels)# Calculate accuracy for test datasettest_accuracy = sess.run(accuracy,feed_dict={features: test_features, labels: test_labels})print('Test Accuracy: {}'.format(test_accuracy))

Running the code will output the following:

Epoch: 0    - Cost: 11.0     Valid Accuracy: 0.204
Epoch: 1    - Cost: 9.95     Valid Accuracy: 0.229
Epoch: 2    - Cost: 9.18     Valid Accuracy: 0.246
Epoch: 3    - Cost: 8.59     Valid Accuracy: 0.264
Epoch: 4    - Cost: 8.13     Valid Accuracy: 0.283
Epoch: 5    - Cost: 7.77     Valid Accuracy: 0.301
Epoch: 6    - Cost: 7.47     Valid Accuracy: 0.316
Epoch: 7    - Cost: 7.2      Valid Accuracy: 0.328
Epoch: 8    - Cost: 6.96     Valid Accuracy: 0.342
Epoch: 9    - Cost: 6.73     Valid Accuracy: 0.36
Test Accuracy: 0.3801000118255615

每个 epoch 都试图走向一个低 cost,得到一个更好的准确率。

模型直到 Epoch 9 准确率都一直有提升,让我们把 epochs 的数字提高到 100。

...
Epoch: 79   - Cost: 0.111    Valid Accuracy: 0.86
Epoch: 80   - Cost: 0.11     Valid Accuracy: 0.869
Epoch: 81   - Cost: 0.109    Valid Accuracy: 0.869
....
Epoch: 85   - Cost: 0.107    Valid Accuracy: 0.869
Epoch: 86   - Cost: 0.107    Valid Accuracy: 0.869
Epoch: 87   - Cost: 0.106    Valid Accuracy: 0.869
Epoch: 88   - Cost: 0.106    Valid Accuracy: 0.869
Epoch: 89   - Cost: 0.105    Valid Accuracy: 0.869
Epoch: 90   - Cost: 0.105    Valid Accuracy: 0.869
Epoch: 91   - Cost: 0.104    Valid Accuracy: 0.869
Epoch: 92   - Cost: 0.103    Valid Accuracy: 0.869
Epoch: 93   - Cost: 0.103    Valid Accuracy: 0.869
Epoch: 94   - Cost: 0.102    Valid Accuracy: 0.869
Epoch: 95   - Cost: 0.102    Valid Accuracy: 0.869
Epoch: 96   - Cost: 0.101    Valid Accuracy: 0.869
Epoch: 97   - Cost: 0.101    Valid Accuracy: 0.869
Epoch: 98   - Cost: 0.1      Valid Accuracy: 0.869
Epoch: 99   - Cost: 0.1      Valid Accuracy: 0.869
Test Accuracy: 0.8696000006198883

从上述输出来看,在 epoch 80 的时候,模型的验证准确率就不提升了。让我们看看提升学习率会怎样。

learn_rate = 0.1

Epoch: 76   - Cost: 0.214    Valid Accuracy: 0.752
Epoch: 77   - Cost: 0.21     Valid Accuracy: 0.756
Epoch: 78   - Cost: 0.21     Valid Accuracy: 0.756
...
Epoch: 85   - Cost: 0.207    Valid Accuracy: 0.756
Epoch: 86   - Cost: 0.209    Valid Accuracy: 0.756
Epoch: 87   - Cost: 0.205    Valid Accuracy: 0.756
Epoch: 88   - Cost: 0.208    Valid Accuracy: 0.756
Epoch: 89   - Cost: 0.205    Valid Accuracy: 0.756
Epoch: 90   - Cost: 0.202    Valid Accuracy: 0.756
Epoch: 91   - Cost: 0.207    Valid Accuracy: 0.756
Epoch: 92   - Cost: 0.204    Valid Accuracy: 0.756
Epoch: 93   - Cost: 0.206    Valid Accuracy: 0.756
Epoch: 94   - Cost: 0.202    Valid Accuracy: 0.756
Epoch: 95   - Cost: 0.2974   Valid Accuracy: 0.756
Epoch: 96   - Cost: 0.202    Valid Accuracy: 0.756
Epoch: 97   - Cost: 0.2996   Valid Accuracy: 0.756
Epoch: 98   - Cost: 0.203    Valid Accuracy: 0.756
Epoch: 99   - Cost: 0.2987   Valid Accuracy: 0.756
Test Accuracy: 0.7556000053882599

看来学习率提升的太多了,最终准确率更低了。准确率也更早的停止了改进。我们还是用之前的学习率,把 epochs 改成 80

Epoch: 65   - Cost: 0.122    Valid Accuracy: 0.868
Epoch: 66   - Cost: 0.121    Valid Accuracy: 0.868
Epoch: 67   - Cost: 0.12     Valid Accuracy: 0.868
Epoch: 68   - Cost: 0.119    Valid Accuracy: 0.868
Epoch: 69   - Cost: 0.118    Valid Accuracy: 0.868
Epoch: 70   - Cost: 0.118    Valid Accuracy: 0.868
Epoch: 71   - Cost: 0.117    Valid Accuracy: 0.868
Epoch: 72   - Cost: 0.116    Valid Accuracy: 0.868
Epoch: 73   - Cost: 0.115    Valid Accuracy: 0.868
Epoch: 74   - Cost: 0.115    Valid Accuracy: 0.868
Epoch: 75   - Cost: 0.114    Valid Accuracy: 0.868
Epoch: 76   - Cost: 0.113    Valid Accuracy: 0.868
Epoch: 77   - Cost: 0.113    Valid Accuracy: 0.868
Epoch: 78   - Cost: 0.112    Valid Accuracy: 0.868
Epoch: 79   - Cost: 0.111    Valid Accuracy: 0.868
Epoch: 80   - Cost: 0.111    Valid Accuracy: 0.869
Test Accuracy: 0.86909999418258667

准确率只到 0.86。这有可能是学习率太高造成的。降低学习率需要更多的 epoch,但是可以最终得到更好的准确率。

深度学习之epoch相关推荐

  1. 深度学习中epoch,batch的概念--笔记

    深度学习中epoch,batch的概念 batch.epoch和iteration是深度学习中几个常见的超参数. (1) batch_ size: 每批数据量的大小.DL通常用SGD的优化算法进行训练 ...

  2. 深度学习中 epoch,[batch size], iterations概念解释

    one epoch:所有的训练样本完成一次Forword运算以及一次BP运算 batch size:一次Forword运算以及BP运算中所需要的训练样本数目,其实深度学习每一次参数的更新所需要损失函数 ...

  3. 【CV】深度学习中Epoch, Batch, Iteration的含义

    Epoch 使用训练集的全部数据样本进行一次训练,称为一次epoch,即所有训练集的样本都在神经网络中进行了一次正向传播和一次反向传播 神经网络中需要有多次epoch,每次epoch中会进行一次更新权 ...

  4. 深度学习之 epoch batch iteration

    知识点 无论是使用yolo3,4 都是一样的过程,例如使用yolo3 去训练的时候,使用参数tran来训练,darknet的好处是可以使用opencv直接来进行模型推理,但是在训练过程中,我们经常会遇 ...

  5. 深度学习概念——Epoch, Batch, Iteration

    目录 定义 示例 Epoch数量多少合适? 定义 Epoch(时期) 所有训练样本在神经网络中都进行了一次正向传播和一次反向传播的过程,称为1个Epoch Batch(批) 将训练样本分为若干个Bat ...

  6. 深度学习(二)——深度学习常用术语解释, Neural Network Zoo, CNN, Autoencoder

    Dropout(续) 除了Dropout之外,还有DropConnect.两者原理上类似,后者只隐藏神经元之间的连接. 总的来说,Dropout类似于机器学习中的L1.L2规则化等增加稀疏性的算法,也 ...

  7. 【深度学习笔记】深度学习中关于epoch

    (1)iteration:表示1次迭代,每次迭代更新1次网络结构的参数: (2)batch_size:1次迭代所使用的样本量: (3)epoch:1个epoch表示过了1遍训练集中的所有样本. 需要补 ...

  8. DL-4 深度学习中的batch_size、epoch、iteration的区别

    (1)batchsize:批大小.在深度学习中,一般采用SGD训练,即每次训练在训练集中取batchsize个样本训练: (2)iteration:1个iteration等于使用batchsize个样 ...

  9. 深度学习中batch_size、epoch和iteration的含义

    iteration:1个iteration等于使用batchsize个样本训练一次: epoch:1个epoch等于使用训练集中的全部样本训练一次,通俗的讲epoch的值就是整个数据集被轮几次. 比如 ...

  10. 深度学习中Batch、Iteration、Epoch的概念与区别

    在神经网络训练中,一般采用小批量梯度下降的方式. Batch Epoch Iteration 就是其中的重要的概念.我们要理解懂得它们都是什么以及它们之间的区别. 1.Batch 每次迭代时使用的一批 ...

最新文章

  1. 收缩 tempdb 数据库
  2. C++volatile
  3. IAR下STM32进入HardFault_Handler
  4. es6 字符串模板 随手记
  5. MySQL(6)-----数据类型
  6. TeamViewer(TV)锁屏后黑屏无法远程的解决方法
  7. 微型计算机通常是由控制器等几部分组成,计算机基础试题及答案
  8. 异名一文带你读懂Chrome小恐龙跑酷!
  9. canvas绘制五角星
  10. android启动过程中cpu降频,android省电开发之cpu降频
  11. js函数提升和变量提升_关于在js中提升的真相
  12. matlab成功安装libsvm后,运行程序仍报错“svmtrain has been removed”解决方法记录
  13. 基金入门-基金的分类
  14. EMC测试仪器_EMC测试整改流程及常见问题
  15. gradle-5.4.1-all gradle-6.1.1.all.zip下载包
  16. 操作系统实验之掌握基本SHELL命令(一)
  17. 网页版google语音识别
  18. 初学STM32之看门狗
  19. Ubuntu下快捷键操作
  20. C语言数据结构-2020级ICODING答案分享

热门文章

  1. python基础之模块
  2. php的parent_php中parent::是如何使用的?
  3. android edittext过滤空格,关于android:在EditText中拦截空格键的问题
  4. Redhat8认证考试(第三题)
  5. 计算机英语作文150字,作文试题_150字_英语作文
  6. 车联网智能终端GB/T 32960国标协议规范 、国标新能源车联网终端GB/T32960标准T-BOX应用
  7. VScode透明主题
  8. YOCTO开机画面修改
  9. Flash制作雾效果
  10. 萝卜青菜各有所爱------深谈React和Vue