摘要

本文将低级api实现的tensorflow网络移植到高级api上遇到的loss值不变和训练结果不收敛问题

引言

tensorflow版本更新很快,猛一回头发现已经推出更高级的api了

主题

tensorflow高级api

上图是tensorflow软件栈图,我之前学习和实现的网络模型(0.12a)使用的是 低级api, 现在的新版本(1.10)对低级api进行了封装,形成了高级api(estimator keras),所以对原有模型进行了一次api替换

移植

整个代码的移植过程还是比较顺利的,移植完成以后,代码量比以前减少了很多,但移植完成,在运行时却发现了一些奇怪的现象

问题现象

我这里有两个自己的数据集,其中一个数据集(单个数据量130)在移植后的代码上运行正常,另一个数据集(单个数据集60000)在移植后的代码上运行却出现以下问题:

  • 二分类问题的准确率在50%左右
  • 训练过程中loss值会变化到一个固定值,然后就不再变化了

epoch:0Evaluation results:   {'true_negatives': 48.0, 'global_step': 0, 'loss': 0.7953733, 'true_positives': 0.0, 'accuracy': 0.48, 'false_negatives': 52.0, 'false_positives': 0.0}epoch:1Evaluation results: {'true_negatives': 0.0, 'global_step': 1, 'loss': 0.79326165, 'true_positives': 52.0, 'accuracy': 0.52, 'false_negatives': 0.0, 'false_positives': 48.0}epoch:2Evaluation results:    {'true_negatives': 0.0, 'global_step': 2, 'loss': 0.79326165, 'true_positives': 52.0, 'accuracy': 0.52, 'false_negatives': 0.0, 'false_positives': 48.0}epoch:3Evaluation results:    {'true_negatives': 0.0, 'global_step': 3, 'loss': 0.79326165, 'true_positives': 52.0, 'accuracy': 0.52, 'false_negatives': 0.0, 'false_positives': 48.0}epoch:4Evaluation results:    {'true_negatives': 0.0, 'global_step': 4, 'loss': 0.79326165, 'true_positives': 52.0, 'accuracy': 0.52, 'false_negatives': 0.0, 'false_positives': 48.0}epoch:5Evaluation results:    {'true_negatives': 0.0, 'global_step': 5, 'loss': 0.79326165, 'true_positives': 52.0, 'accuracy': 0.52, 'false_negatives': 0.0, 'false_positives': 48.0}

问题分析

因为不同数据集对应的结论不同,所以问题的排查就主要集中在对比两份代码的差异上了。训练正常的代码简称为T代码,训练异常的代码简称为F代码

  • 网络排查

    怀疑模型代码有问题。将F代码尽量用T代码代替,最终替换后,只剩下tfrecord文件读取和解析、网络输入层不一样,然后再训练更新后的F代码,现象依然存在

  • 数据排查

    通过网络排查基本排除网络问题,唯一不同在于数据,于是进行数据验证。将F代码训练过程中的输入数据记录到文件,然后对比F代码读到的数据和制作tfrecord的数据。排查数据编号、数据内容是否一致。最终发现数据一致。

    class _LoggerHook(tf.train.SessionRunHook):"""Logs loss and runtime."""def begin(self):# print('begin')self._step = -1def before_run(self, run_context):# print('before_run')self._step += 1return tf.train.SessionRunArgs(features)  # Asks for loss value.def after_run(self, run_context, run_values):if self._step == 2:logit_value = run_values.resultsprint('step ' + str(self._step) + ', features = ' + str(logit_value))f1.write(logit_value['data'])  # 训练准确率写入文件f1.flush()numpy.savetxt(r"/home/zq537/ckpt/ecg_data.txt", logit_value['data'])numpy.savetxt(r"/home/zq537/ckpt/index.txt", logit_value['name'])
    
  • 内部训练过程排查

    如果网络和数据都没问题,那么问题排查起来就比较困难了。接下来的方向可能需要深入模型内部的训练过程,看哪些步骤导致loss和准确率不变,将所有操作的输出和反向传播的梯度都记录到tensorboard中进行查看,发现进行少量的训练后,反向传播的梯度值分布都在0附近,这样网络权重基本就不会更新了,网络参数没有变化,自然准确率和loss也不会变化了。
    是什么原因导致梯度为0?梯度是从loss开始,一层一层往前传的;而loss是由预测值和实际标签共同决定的。于是我开始查看预测值和实际标签的数值。先把数据集进行简化,生成10个数据的数据集,batch-size设为10,然后在训练回调中把稳定(loss不变)时的预测值和实际标签打印出来,发现了问题:稳定状态下预测值大部分两分类的准确率都是1,这很明显是给网络的评判标准(loss函数)有问题

    labels = [[1 0][1 0][0 1][1 0][1 0][0 1][0 1][0 1][0 1][0 1]]logits = [[1.0000000e+00 1.0000000e+00][1.0000000e+00 1.0000000e+00][1.0000000e+00 1.0000000e+00][1.0000000e+00 1.4896728e-36][1.0000000e+00 5.1295980e-18][1.0000000e+00 1.0000000e+00][1.0000000e+00 1.0000000e+00][1.0000000e+00 1.0000000e+00][1.0000000e+00 1.0000000e+00][1.0000000e+00 1.0000000e+00]]
    

    新代码中使用的损失函数是tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits),搜索并替换为tensorflow官方module库中的损失函数tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)再进行训练,发现一切正常

问题原因

  • 网络排查
  • 数据排查
  • 内部训练过程排查

解决方案

总结

  • 在pypi官网上找相关模块信息

最开始在网上搜到的方案是ConcurrentLogHandler,但在13年就停止维护了,
合入代码也无法运行。于是又在网上找其他方案(这里浪费了不少时间),其实
ConcurrentLogHandler的homepage页已经说明了替代的库

附录

参考


  • tensorflow官方文档

tensorflow使用高阶api导致训练不收敛问题相关推荐

  1. TensorFlow高阶 API: keras教程-使用tf.keras搭建mnist手写数字识别网络

    TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字识别网络 目录 TensorFlow高阶 API:keras教程-使用tf.keras搭建mnist手写数字 ...

  2. 华为昇思高阶API套件迎来全新升级!解决无人驾驶疑难杂症真得靠它!

    点击蓝字 MindSpore 关注我们 对于程序员来说,拥有一款低门槛.易操作的深度学习开发工具包,可以说赢在了起跑线!来自华为的全场景AI框架昇思MindSpore在历经短短一年多时间的迭代,为专业 ...

  3. Pytorch高阶API示范——线性回归模型

    本文与<20天吃透Pytorch>有所不同,<20天吃透Pytorch>中是继承之前的模型进行拟合,本文是单独建立网络进行拟合. 代码实现: import torch impo ...

  4. WebDriver高阶API(8)

    17.测试HTML5语言实现的视频播放器 #encoding=utf-8 import unittest import time from selenium import webdriverclass ...

  5. Pytorch高阶API示范——DNN二分类模型

    代码部分: import numpy as np import pandas as pd from matplotlib import pyplot as plt import torch from ...

  6. 【进阶篇】全流程学习《20天掌握Pytorch实战》纪实 | Day10 | 高阶API示范

  7. Tensorflow框架是如何支持分布式训练的?

    参加 2019 Python开发者日,请扫码咨询 ↑↑↑ 作者 | 杨旭东 转载自知乎<算法工程师的自我修养>专栏 Methods that scale with computation ...

  8. TensorFlow 笔记 (五)自定义训练: 演示

    这个教程将利用机器学习的手段来对鸢尾花按照物种进行分类.本教程将利用 TensorFlow 来进行以下操作: 构建一个模型, 用样例数据集对模型进行训练,以及 利用该模型对未知数据进行预测. Tens ...

  9. TensorFlow 2 Object Detection API 教程: model 命名规则

    TensorFlow 2 Object Detection API 教程: model 命名规则 COCO-trained models {#coco-models} TensorFlow 2 Obj ...

最新文章

  1. linux sh for ls,Linux shell for while 循环
  2. ssh暴力破解解决方案
  3. python基础语法手册format-python基础_格式化输出(%用法和format用法)
  4. 如何在一个文件中写多个Vue组件(译-有删改)
  5. 用策略屏蔽135 139 445 3389端口+网络端口安全防护技
  6. JavaScript实现闭式函数计算特定位置的斐波那契数fibonacciNthClosedForm算法(附完整源码)
  7. 使用PM2搭建在线vue.js开发环境(以守护进程方式热启动)
  8. c语言对中文字符串编码_Python || 学习笔记(1):数据类型字符串变量和编码
  9. 河南理工大学c语言报告封面,河南理工大学图书信息管理系统设计_纯c语言课程设计.doc...
  10. 玩玩机器学习5——构造单层神经网络解决非线性函数(三次函数)的曲线拟合
  11. 一行或多行文本内容溢出显示省略号
  12. mysql+存储器_mysql内存储器计算公式_mysql
  13. 使用 PlantUML 绘制时序图
  14. Linux命令解释之useradd,userdel,usermod
  15. Javaweb安全——Java类加载机制
  16. python 爬阳光高考高校数据
  17. 3Dmax_三维模型无法处理平滑解决方案
  18. uni-app 超好用的时间选择器组件(起止时间)
  19. 马斯洛的需要层次理论
  20. 苹果隐藏app_iOS14隐藏功能,很实用!附部分BUG解决方案

热门文章

  1. C语言中的restrict限定符
  2. GNN金融应用之Classifying and Understanding Financial Data Using Graph Neural Network学习笔记
  3. 自学IOS开发第2天·学习基础SwiftUI
  4. 【2020年保研记】浙大软院+中科院信工所+北师大人工智能学院+华中科技网安学院+四川大学网安学院+中山大学系统科学与工程学院
  5. cfa专题突破网课资源
  6. 在Linux终端命令行下播放音乐的命令
  7. PyQt5入门学习(一)【PyQt5及PyQt5-tools的安装】
  8. 重塑矩阵(一个矩阵转化成另一个矩阵)
  9. 【2024】末两位数
  10. 关于解决虚拟机不能挂起的问题