简介

Tensorflow API提供了Cluster、Server以及Supervisor来支持模型的分布式训练。

关于Tensorflow的分布式训练介绍可以参考Distributed Tensorflow。简单的概括说明如下:

  • Tensorflow分布式Cluster由多个Task组成,每个Task对应一个tf.train.Server实例,作为Cluster的一个单独节点;
  • 多个相同作用的Task可以被划分为一个job,例如ps job作为参数服务器只保存Tensorflow model的参数,而worker job则作为计算节点只执行计算密集型的Graph计算。
  • Cluster中的Task会相对进行通信,以便进行状态同步、参数更新等操作。

Tensorflow分布式集群的所有节点执行的代码是相同的。分布式任务代码具有固定的模式:

# 第1步:命令行参数解析,获取集群的信息ps_hosts和worker_hosts,以及当前节点的角色信息job_name和task_index# 第2步:创建当前task结点的Server
cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})
server = tf.train.Server(cluster, job_name=FLAGS.job_name, task_index=FLAGS.task_index)# 第3步:如果当前节点是ps,则调用server.join()无休止等待;如果是worker,则执行第4步。
if FLAGS.job_name == "ps":server.join()# 第4步:则构建要训练的模型
# build tensorflow graph model# 第5步:创建tf.train.Supervisor来管理模型的训练过程
# Create a "supervisor", which oversees the training process.
sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), logdir="/tmp/train_logs")
# The supervisor takes care of session initialization and restoring from a checkpoint.
sess = sv.prepare_or_wait_for_session(server.target)
# Loop until the supervisor shuts down
while not sv.should_stop()# train model

Tensorflow分布式训练代码框架

根据上面说到的Tensorflow分布式训练代码固定模式,如果要编写一个分布式的Tensorlfow代码,其框架如下所示。

import tensorflow as tf# Flags for defining the tf.train.ClusterSpec
tf.app.flags.DEFINE_string("ps_hosts", "","Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", "","Comma-separated list of hostname:port pairs")# Flags for defining the tf.train.Server
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")FLAGS = tf.app.flags.FLAGSdef main(_):ps_hosts = FLAGS.ps_hosts.split(",")worker_hosts = FLAGS.worker_hosts(",")# Create a cluster from the parameter server and worker hosts.cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})# Create and start a server for the local task.server = tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index)if FLAGS.job_name == "ps":server.join()elif FLAGS.job_name == "worker":# Assigns ops to the local worker by default.
    with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % FLAGS.task_index,cluster=cluster)):# Build model...loss = ...global_step = tf.Variable(0)train_op = tf.train.AdagradOptimizer(0.01).minimize(loss, global_step=global_step)saver = tf.train.Saver()summary_op = tf.merge_all_summaries()init_op = tf.initialize_all_variables()# Create a "supervisor", which oversees the training process.sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),logdir="/tmp/train_logs",init_op=init_op,summary_op=summary_op,saver=saver,global_step=global_step,save_model_secs=600)# The supervisor takes care of session initialization and restoring from# a checkpoint.sess = sv.prepare_or_wait_for_session(server.target)# Start queue runners for the input pipelines (if any).
    sv.start_queue_runners(sess)# Loop until the supervisor shuts down (or 1000000 steps have completed).step = 0while not sv.should_stop() and step < 1000000:# Run a training step asynchronously.# See `tf.train.SyncReplicasOptimizer` for additional details on how to# perform *synchronous* training._, step = sess.run([train_op, global_step])if __name__ == "__main__":tf.app.run()

对于所有Tensorflow分布式代码,可变的只有两点:

  1. 构建tensorflow graph模型代码;
  2. 每一步执行训练的代码

分布式MNIST任务

我们通过修改tensorflow/tensorflow提供的mnist_softmax.py来构造分布式的MNIST样例来进行验证。修改后的代码请参考mnist_dist.py。

我们同样通过tensorlfow的Docker image来启动一个容器来进行验证。

$ docker run -d -v /path/to/your/code:/tensorflow/mnist --name tensorflow tensorflow/tensorflow

启动tensorflow之后,启动4个Terminal,然后通过下面命令进入tensorflow容器,切换到/tensorflow/mnist目录下

$ docker exec -ti tensorflow /bin/bash
$ cd /tensorflow/mnist

然后在四个Terminal中分别执行下面一个命令来启动Tensorflow cluster的一个task节点,

# Start ps 0
python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=ps --task_index=0# Start ps 1
python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=ps --task_index=1# Start worker 0
python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=worker --task_index=0# Start worker 1
python mnist_dist.py --ps_hosts=localhost:2221,localhost:2222 --worker_hosts=localhost:2223,localhost:2224 --job_name=worker --task_index=1

具体效果自己验证哈。

Tensorflow学习笔记4:分布式Tensorflow相关推荐

  1. 学习笔记TF061:分布式TensorFlow,分布式原理、最佳实践

    分布式TensorFlow由高性能gRPC库底层技术支持.Martin Abadi.Ashish Agarwal.Paul Barham论文<TensorFlow:Large-Scale Mac ...

  2. 【学习笔记】分布式Tensorflow

    https://www.cnblogs.com/zhangfengxian/p/10690218.html 目录 分布式原理 单机多卡 多机多卡(分布式) 分布式的架构 节点之间的关系 分布式的模式 ...

  3. tensorflow学习笔记:查看tensorflow可配置运算资源以及配置使用GPU运算

    查看tensorflow可配置运算资源以及配置使用GPU运算 因为还用不到分布式的tensorflow,自己没有尝试过所以就不写分布式tensorflow的使用了(等自己用上了再说),这里记录一下在跑 ...

  4. TensorFlow学习笔记之一(TensorFlow基本介绍)

    文章目录 TensorFlow计算模型---计算图 计算图的使用 TensorFlow数据模型---张量 TensorFlow运算模型---会话 使用tf.InteractiveSession在交互式 ...

  5. TensorFlow学习笔记01:TensorFlow入门

    文章目录 一.TensorFlow基本概念 1.TensorFlow的Hello World 2.TensorFlow的概念 3.计算图&#

  6. TensorFlow学习笔记Day01-安装TensorFlow

    知识经济的时代,数据为王的时代,互联网的世界,什么东西都在不断的更新中,为此,我们自己也必须前行,不前行就会遭到淘汰.TensorFlow作为Google推出的便捷框架,已经受到了许多技术开发者的使用 ...

  7. TensorFlow学习笔记--第三节张量(tensor)及其定义方法

    目录 在TensorFlow中,所有的数据通过张量的形式来表示 1张量及属性: 1.1维数(阶) 1.2 形状 1.3数据类型 TensorFlow 支持以下三种类型的张量: **1.常量** **2 ...

  8. tensorflow学习笔记

    tensorflow学习笔记 按照<TensorFlow:实战Google深度学习框架>一书学习的tensorflow,书中使用的是0.9.0版本,而我安装的是1.2.1,出现了一些问题: ...

  9. Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题

    Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 参考文章: (1)Tensorflow学习笔记6:解决tensorflow训练过程中GPU未调用问题 (2)http ...

  10. tensorflow学习笔记(三十二):conv2d_transpose (解卷积)

    tensorflow学习笔记(三十二):conv2d_transpose ("解卷积") deconv解卷积,实际是叫做conv_transpose, conv_transpose ...

最新文章

  1. 1.1 基本图像导入、处理和导出
  2. NameNode与DataNode的工作原理剖析
  3. 性能媲美BERT却只有其1/10参数量? | 近期最火模型ELECTRA解析
  4. linux 识别文件类型,技术|Linux 中 7 个判断文件系统类型的方法
  5. 程序防止SqlServer使用SqlServer Profiler跟踪
  6. CCF201903-5 317号子任务(100分题解链接)
  7. erl_0016 《硝烟中的erlang》 读书笔记003 “error_logger 爆炸”
  8. 项目后台运行关闭_iOS到底有没有必要上滑强制关闭APP?
  9. 常用模块以及常用方法
  10. 多媒体计算机网络解释,多媒体-名词解释及填空解读.doc
  11. python超链接格式_用Python在本地文件夹中插入超链接
  12. linux基础的基础命令操作
  13. SQL server修改字段名,属性
  14. 调用百度地图进行路线规划
  15. 方舟非主机服务器无限距离,方舟生存进化怎么调主机距离
  16. 关于VA过期的解决办法
  17. 网络篇 OSPF的DR与BDR的选举-48
  18. VideoCodec 入门篇 - 00 (编解码简介)
  19. html5 任务列表,《怪物猎人 世界:冰原》每周活动任务列表(不断更新中)
  20. 验证银行卡卡号是否符合规则

热门文章

  1. hdc mfc 画扇形图_MFC画图总结-DIB图形绘制
  2. python vbs库_Python语言之requests库
  3. rust队友开挂_腐蚀RUST开挂玩家识别方法 如何识别玩家开挂
  4. java两个长度不同数组_两组数组,长度不一样,如果其中一个数组的值在另一个中不存在,则不符合要求.怎么算?...
  5. 支付宝 android 2.3,app被拒记录-2.3-包含支付宝
  6. html安卓手机打开后只有半屏,宽度设置100%在移动端时变成一半
  7. 一个报文的路由器之旅_一个报文的路由器之旅
  8. mysql的增_MySQL之增_insert-replace
  9. 【JAVA中级篇】线程池
  10. SQLExecption:Operation not allowed after ResultSet closed解决办法