上一篇我们分析了XDL的framework的架构设计,了解了XDL的模型构建和运行机制,以及XDL如何将tensorflow作为自己的backend。

本篇继续分析XDL的架构设计,重点关注XDL的参数训练过程。

样例

还是以XDL自带的deepctr(代码位置:xdl/xdl/examples/deepctr/deepctr.py)为例:

def train():...emb2 = xdl.embedding('emb2', batch['sparse1'], xdl.TruncatedNormal(stddev=0.001), 8, 1024, vtype='hash')## 可以看到,loss op并没有传入train_op,那么train_op的优化目标是什么呢?loss = model(batch['deep0'], [emb1, emb2], batch['label'])train_op = xdl.SGD(0.5).optimize()log_hook = xdl.LoggerHook(loss, "loss:{0}", 10)sess = xdl.TrainSession(hooks=[log_hook])while not sess.should_stop():sess.run(train_op)@xdl.tf_wrapper()
def model(deep, embeddings, labels):input = tf.concat([deep] + embeddings, 1)fc1 = tf.layers.dense(input, 128, kernel_initializer=tf.truncated_normal_initializer(stddev=0.001, dtype=tf.float32))fc2 = tf.layers.dense(fc1, 64, kernel_initializer=tf.truncated_normal_initializer(stddev=0.001, dtype=tf.float32))fc3 = tf.layers.dense(fc2, 32, kernel_initializer=tf.truncated_normal_initializer(stddev=0.001, dtype=tf.float32))y = tf.layers.dense(fc3, 1, kernel_initializer=tf.truncated_normal_initializer(stddev=0.001, dtype=tf.float32))loss = tf.losses.sigmoid_cross_entropy(labels, y)## model函数返回的是模型的目标loss op, 而不是train op,为什么呢?return loss

细心的读者发现,XDL采用的优化器是自己实现的优化器xdl.SGD,没有采用tensorflow的优化器,而且并没有将loss op传入优化器xdl.SGD,再接下来就直接sess.run(train_op)了,那么优化器的优化目标是什么呢?xdl是如何实现参数更新的呢?

关键在于三个地方:

  • 装饰器xdl.tf_wrapper
  • backend op: TFBackendOp
  • 优化器xdl.SGD,继承自xdl.Optimizer

参数更新的步骤

  • 装饰器xdl.tf_wrapper与被装饰函数约定,被装饰函数的第一个返回值必须是模型的loss
  • 由前篇分析,我们知道xdl的sparse参数,会赋值给对应的tensorflow placeholder中然后参与tensorflow graph的运行,xdl的dense类型参数,会赋值给对应的tensorflow variable然后参与tensroflow graph的运行
  • 赋值的方式是通过TFBackendOp的Input
  • 装饰器xdl.tf_wrapper调用函数add_backprop_ops,为之前添加的tensorflow placeholder或则tensorflow variable添加梯度计算op。注意这些梯度计算op都是添加到了tensorflow graph中。
  • TFBackendOp的输出包括了以上的梯度计算Op的输出
  • TFBackendOp传输的梯度是loss关于指定的tensorflow placeholder和tensorflow variable的梯度,实际上,也就是loss关于xdl sparse和dense参数的梯度
  • 有了loss关于xdl参数的梯度后,接下来就是通过几类学习算法,更新XDL参数

总结

我们知道,参数学习的过程,可以大致分为前向计算、梯度后向传递、参数更新三个阶段,从这个角度出发,我们分析一下XDL的参数学习过程:

  • xdl参数由参数服务器存储与维护
  • 参数值经由worker中的PsDensePullOp或PsSparsePullOp,传至worker
  • 参数作为TFBackendOp的Input,feed于相应的tensorflow placeholder或variable
  • 运行TFBackendOp,实际上也就是调用tensorflow runtime驱动tensorflow graph运行
  • tensorflow graph的执行结果,通过TFBackendOp的output输出
  • TFBackendOp的output输出主要是指定的Op以及用户定义的模型loss关于模型参数的梯度
  • 梯度传入特定的xdl update op,具体的update op取决于用户采用的优化器,比如xdl.SGD,xdl.Momentum,xdl.Adagrad,xdl.Ftrl等等
  • xdl update op通过DensePush或SparsePush API,更新xdl参数
  • 至此一次完整的参数学习过程结束

深入浅出XDL(四):模型训练相关推荐

  1. Qt5.7+Opencv2.4.9人脸识别(四)模型训练

    [注意]本博文的档次适合OpenCV初学者,和要做本科生毕业设计这类档次. 源码的下载地址和原理理论部分请走下面连接 http://blog.csdn.net/qq78442761/article/d ...

  2. 用深度学习做命名实体识别(四)——模型训练

    通过本文你将了解如何训练一个人名.地址.组织.公司.产品.时间,共6个实体的命名实体识别模型. 准备训练样本 下面的链接中提供了已经用brat标注好的数据文件以及brat的配置文件,因为标注内容较多放 ...

  3. pytorch使用detectron2模型库模型训练自己的数据

    一 应用场景 在x86 (Ubuntu18.04)cpu,在pytorch1.10框架下,使用detectron2模型库模型训练自己的数据集,并进行目标检测推理. 二 环境配置 我的环境是: pyto ...

  4. 【深度学习】深入浅出数字图像处理基础(模型训练的先修课)

    [深度学习]深入浅出数字图像处理基础(模型训练的先修课) 文章目录 1 图像的表示 2 图像像素运算 3 采样与量化3.1 采样3.2 量化3.3 图像上采样与下采样 4 插值算法分类 5 什么是池化 ...

  5. 训练softmax分类器实例_第四章.模型训练

    迄今为止,我们只是把机器学习模型及其大多数训练算法视为黑盒.但是如果你做了前面几章的一些练习,你可能会惊讶于你可以在不知道任何关于背后原理的情况下完成很多工作:优化一个回归系统,改进一个数字图像分类器 ...

  6. 《OpenCv视觉之眼》Python图像处理十九:Opencv图像处理实战四之通过OpenCV进行人脸口罩模型训练并进行口罩检测

    本专栏主要介绍如果通过OpenCv-Python进行图像处理,通过原理理解OpenCv-Python的函数处理原型,在具体情况中,针对不同的图像进行不同等级的.不同方法的处理,以达到对图像进行去噪.锐 ...

  7. 强化学习技巧四:模型训练速度过慢、GPU利用率较低,CPU利用率很低问题总结与分析。

    1.PyTorchGPU利用率较低问题原因: 在服务器端或者本地pc端, 输入nvidia-smi 来观察显卡的GPU内存占用率(Memory-Usage),显卡的GPU利用率(GPU-util),然 ...

  8. Pytorch模型训练实用教程学习笔记:四、优化器与学习率调整

    前言 最近在重温Pytorch基础,然而Pytorch官方文档的各种API是根据字母排列的,并不适合学习阅读. 于是在gayhub上找到了这样一份教程<Pytorch模型训练实用教程>,写 ...

  9. NVIDIA 7th SkyHackathon(四)Nemo ASR 模型训练与评估

    1.模型加载 1.1 导入 NeMo import nemo import nemo.collection.asr as nemo_asr import torch# 检查 nemo 版本 '1.4. ...

最新文章

  1. ubuntu共享无线链接
  2. noip 2017棋盘
  3. iPhone 13发售日期偷跑:9月17日全系开售、共4款
  4. jQuery获取隐藏域和radio单项框的值
  5. 基于CSS+dIV的网页层,点击后隐藏或显示
  6. uniapp中获取元素页面信息的方法
  7. NodeJs快速入门
  8. 计算机页面添加文字水印在哪,轻松学会给office2013 word文档添加图片/文字背景水印以及让水印铺满整个页面-网络教程与技术 -亦是美网络...
  9. stata中计算公式命令_stata 计算命令:
  10. ctf 实验吧 围在栅栏中的爱 (最近一直在好奇一个问题,QWE到底等不等于ABC? )
  11. 无线网首选dns服务器怎么设置,怎么设置无线路由器dns
  12. azw3文件怎么打开?
  13. 手绘vs码绘1——Q版小人
  14. 百度2019Q3财报和战略分析
  15. 这100 个网络基础知识,看完成半个网络高手
  16. 快速批量把jpg转换成pdf的方法
  17. 各样本观察值均加同一常数_医药数理统计学试题及答案
  18. 期货开户手续费的收取方式是什么?
  19. 【MindMapper2008】选中文字自动生成节点
  20. 那些年,你看过有哪些让你记忆犹新的书

热门文章

  1. echarts市级区域地图数据展示
  2. Java 两个中文字符串异或问题
  3. 北大光华管理学院-宏观经济学
  4. 软件测试中测试版本的质量状况,测试结果分析和质量报告
  5. 【蚂蚁金服6面】成功进入核心拿了36K,突然感觉貌似不太难!
  6. 【Meetup预告】OpenMLDB+37手游:一键查收实时特征计算场景案例及进阶使用攻略
  7. 开源数据库MySQL DBA运维实战 第2章 SQL1
  8. Android 面试真题收录~
  9. 祝贺一个逃离科研的博士
  10. 画像ToB独角兽,怎么做风口下能飞的猪?