tensorflow分布式训练

博客http://blog.csdn.net/hjimce

微博黄锦池-hjimce   qq:1393852684

情况一、单机单卡

单机单卡是最普通的情况,当然也是最简单的,示例代码如下:

#coding=utf-8
#单机单卡
#对于单机单卡,可以把参数和计算都定义再gpu上,不过如果参数模型比较大,显存不足等情况,就得放在cpu上
import  tensorflow as tfwith tf.device('/cpu:0'):#也可以放在gpu上w=tf.get_variable('w',(2,2),tf.float32,initializer=tf.constant_initializer(2))b=tf.get_variable('b',(2,2),tf.float32,initializer=tf.constant_initializer(5))with tf.device('/gpu:0'):addwb=w+bmutwb=w*bini=tf.initialize_all_variables()
with tf.Session() as sess:sess.run(ini)np1,np2=sess.run([addwb,mutwb])print np1print np2

情况二、单机多卡

单机多卡,只要用device直接指定设备,就可以进行训练,SGD采用各个卡的平均值,示例代码如下:

#coding=utf-8
#单机多卡:
#一般采用共享操作定义在cpu上,然后并行操作定义在各自的gpu上,比如对于深度学习来说,我们一把把参数定义、参数梯度更新统一放在cpu上
#各个gpu通过各自计算各自batch 数据的梯度值,然后统一传到cpu上,由cpu计算求取平均值,cpu更新参数。
#具体的深度学习多卡训练代码,请参考:https://github.com/tensorflow/models/blob/master/inception/inception/inception_train.py
import  tensorflow as tfwith tf.device('/cpu:0'):w=tf.get_variable('w',(2,2),tf.float32,initializer=tf.constant_initializer(2))b=tf.get_variable('b',(2,2),tf.float32,initializer=tf.constant_initializer(5))with tf.device('/gpu:0'):addwb=w+b
with tf.device('/gpu:1'):mutwb=w*bini=tf.initialize_all_variables()
with tf.Session() as sess:sess.run(ini)while 1:print sess.run([addwb,mutwb])



情况三、多机多卡

一、基本概念

Cluster、Job、task概念:三者可以简单的看成是层次关系,task可以看成每台机器上的一个进程,多个task组成job;job又有:ps、worker两种,分别用于参数服务、计算服务,组成cluster。

二、同步SGD与异步SGD

1、所谓的同步更新指的是:各个用于并行计算的电脑,计算完各自的batch 后,求取梯度值,把梯度值统一送到ps服务机器中,由ps服务机器求取梯度平均值,更新ps服务器上的参数。

如下图所示,可以看成有四台电脑,第一台电脑用于存储参数、共享参数、共享计算,可以简单的理解成内存、计算共享专用的区域,也就是ps job;另外三台电脑用于并行计算的,也就是worker task。

这种计算方法存在的缺陷是:每一轮的梯度更新,都要等到A、B、C三台电脑都计算完毕后,才能更新参数,也就是迭代更新速度取决与A、B、C三台中,最慢的那一台电脑,所以采用同步更新的方法,建议A、B、C三台的计算能力都不想。

2、所谓的异步更新指的是:ps服务器收到只要收到一台机器的梯度值,就直接进行参数更新,无需等待其它机器。这种迭代方法比较不稳定,收敛曲线震动比较厉害,因为当A机器计算完更新了ps中的参数,可能B机器还是在用上一次迭代的旧版参数值。


三、代码编写

1、定义集群

比如假设上面的图所示,我们有四台电脑,四台电脑的名字假设为:A、B、C、D,那么集群可以定义如下

#coding=utf-8
#多台机器,每台机器有一个显卡、或者多个显卡,这种训练叫做分布式训练
import  tensorflow as tf
#现在假设我们有A、B、C、D四台机器,首先需要在各台机器上写一份代码,并跑起来,各机器上的代码内容大部分相同
# ,除了开始定义的时候,需要各自指定该台机器的task之外。以机器A为例子,A机器上的代码如下:
cluster=tf.train.ClusterSpec({"worker": ["A_IP:2222",#格式 IP地址:端口号,第一台机器A的IP地址 ,在代码中需要用这台机器计算的时候,就要定义:/job:worker/task:0"B_IP:1234"#第二台机器的IP地址 /job:worker/task:1"C_IP:2222"#第三台机器的IP地址 /job:worker/task:2],"ps": ["D_IP:2222",#第四台机器的IP地址 对应到代码块:/job:ps/task:0]})

然后我们需要写四分代码,这四分代码文件大部分相同,但是有几行代码是各不相同的。

2、在各台机器上,定义server

比如A机器上的代码server要定义如下:

server=tf.train.Server(cluster,job_name='worker',task_index=0)#找到‘worker’名字下的,task0,也就是机器A

3、在代码中,指定device

with tf.device('/job:ps/task:0'):#参数定义在机器D上w=tf.get_variable('w',(2,2),tf.float32,initializer=tf.constant_initializer(2))b=tf.get_variable('b',(2,2),tf.float32,initializer=tf.constant_initializer(5))with tf.device('/job:worker/task:0/cpu:0'):#在机器A cpu上运行addwb=w+b
with tf.device('/job:worker/task:1/cpu:0'):#在机器B cpu上运行mutwb=w*b
with tf.device('/job:worker/task:2/cpu:0'):#在机器C cpu上运行divwb=w/b

在深度学习训练中,一般图的计算,对于每个worker task来说,都是相同的,所以我们会把所有图计算、变量定义等代码,都写到下面这个语句下:

with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:indexi',cluster=cluster)):

函数replica_deviec_setter会自动把变量参数定义部分定义到ps服务中(如果ps有多个任务,那么自动分配)。下面举个例子,假设现在有两台机器A、B,A用于计算服务,B用于参数服务,那么代码如下:

#coding=utf-8
#上面是因为worker计算内容各不相同,不过再深度学习中,一般每个worker的计算内容是一样的,
# 以为都是计算神经网络的每个batch 前向传导,所以一般代码是重用的
import  tensorflow as tf
#现在假设我们有A、B台机器,首先需要在各台机器上写一份代码,并跑起来,各机器上的代码内容大部分相同
# ,除了开始定义的时候,需要各自指定该台机器的task之外。以机器A为例子,A机器上的代码如下:
cluster=tf.train.ClusterSpec({"worker": ["192.168.11.105:1234",#格式 IP地址:端口号,第一台机器A的IP地址 ,在代码中需要用这台机器计算的时候,就要定义:/job:worker/task:0],"ps": ["192.168.11.130:2223"#第四台机器的IP地址 对应到代码块:/job:ps/task:0]})#不同的机器,下面这一行代码各不相同,server可以根据job_name、task_index两个参数,查找到集群cluster中对应的机器isps=False
if isps:server=tf.train.Server(cluster,job_name='ps',task_index=0)#找到‘worker’名字下的,task0,也就是机器Aserver.join()
else:server=tf.train.Server(cluster,job_name='worker',task_index=0)#找到‘worker’名字下的,task0,也就是机器Awith tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:0',cluster=cluster)):w=tf.get_variable('w',(2,2),tf.float32,initializer=tf.constant_initializer(2))b=tf.get_variable('b',(2,2),tf.float32,initializer=tf.constant_initializer(5))addwb=w+bmutwb=w*bdivwb=w/bsaver = tf.train.Saver()
summary_op = tf.merge_all_summaries()
init_op = tf.initialize_all_variables()
sv = tf.train.Supervisor(init_op=init_op, summary_op=summary_op, saver=saver)
with sv.managed_session(server.target) as sess:while 1:print sess.run([addwb,mutwb,divwb])

把该代码在机器A上运行,你会发现,程序会进入等候状态,等候用于ps参数服务的机器启动,才会运行。因此接着我们在机器B上运行如下代码:

#coding=utf-8
#上面是因为worker计算内容各不相同,不过再深度学习中,一般每个worker的计算内容是一样的,
# 以为都是计算神经网络的每个batch 前向传导,所以一般代码是重用的
#coding=utf-8
#多台机器,每台机器有一个显卡、或者多个显卡,这种训练叫做分布式训练
import  tensorflow as tf
#现在假设我们有A、B、C、D四台机器,首先需要在各台机器上写一份代码,并跑起来,各机器上的代码内容大部分相同
# ,除了开始定义的时候,需要各自指定该台机器的task之外。以机器A为例子,A机器上的代码如下:
cluster=tf.train.ClusterSpec({"worker": ["192.168.11.105:1234",#格式 IP地址:端口号,第一台机器A的IP地址 ,在代码中需要用这台机器计算的时候,就要定义:/job:worker/task:0],"ps": ["192.168.11.130:2223"#第四台机器的IP地址 对应到代码块:/job:ps/task:0]})#不同的机器,下面这一行代码各不相同,server可以根据job_name、task_index两个参数,查找到集群cluster中对应的机器isps=True
if isps:server=tf.train.Server(cluster,job_name='ps',task_index=0)#找到‘worker’名字下的,task0,也就是机器Aserver.join()
else:server=tf.train.Server(cluster,job_name='worker',task_index=0)#找到‘worker’名字下的,task0,也就是机器Awith tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:0',cluster=cluster)):w=tf.get_variable('w',(2,2),tf.float32,initializer=tf.constant_initializer(2))b=tf.get_variable('b',(2,2),tf.float32,initializer=tf.constant_initializer(5))addwb=w+bmutwb=w*bdivwb=w/bsaver = tf.train.Saver()
summary_op = tf.merge_all_summaries()
init_op = tf.initialize_all_variables()
sv = tf.train.Supervisor(init_op=init_op, summary_op=summary_op, saver=saver)
with sv.managed_session(server.target) as sess:while 1:print sess.run([addwb,mutwb,divwb])

分布式训练需要熟悉的函数:

  • tf.train.Server
  • tf.train.Supervisor
  • tf.train.SessionManager
  • tf.train.ClusterSpec
  • tf.train.replica_device_setter
  • tf.train.MonitoredTrainingSession
  • tf.train.MonitoredSession
  • tf.train.SingularMonitoredSession
  • tf.train.Scaffold
  • tf.train.SessionCreator
  • tf.train.ChiefSessionCreator
  • tf.train.WorkerSessionCreator

参考文献:

https://www.tensorflow.org/versions/master/how_tos/distributed/index.html

深度学习(五十五)tensorflow分布式训练相关推荐

  1. 花书+吴恩达深度学习(十五)序列模型之循环神经网络 RNN

    目录 0. 前言 1. RNN 计算图 2. RNN 前向传播 3. RNN 反向传播 4. 导师驱动过程(teacher forcing) 5. 不同序列长度的 RNN 如果这篇文章对你有一点小小的 ...

  2. 深度学习(十五)基于级联卷积神经网络的人脸特征点定位

    基于级联卷积神经网络的人脸特征点定位 原文地址:http://blog.csdn.net/hjimce/article/details/49955149 作者:hjimce 一.相关理论 本篇博文主要 ...

  3. 系统学习深度学习(十五)--AlexNet译文

    转自:http://www.aichengxu.com/other/2557713.htm http://blog.csdn.net/maweifei/article/details/53117830 ...

  4. 深度学习(十五)——SPPNet, Fast R-CNN

    https://antkillerfarm.github.io/ RCNN(续) RCNN算法的基本流程 RCNN算法分为4个步骤: Step 1:候选区域生成.一张图像生成1K~2K个候选区域(采用 ...

  5. 深度学习(七十二)tensorflow 集群训练

    #encoding:utf-8 # -*- coding: utf-8 -*- #使用说明:1.修改分类数目;2.修改输入图片大小: # 3.修改是否启用集群: 4.修改batch size大小:5. ...

  6. tensowflow 训练 远程提交_一文说清楚Tensorflow分布式训练必备知识

    Note: 原文发表于我的知乎专栏:算法工程师的自我修养,欢迎关注! Methods that scale with computation are the future of AI. -Rich S ...

  7. 花书+吴恩达深度学习(十六)序列模型之双向循环网络 BRNN 和深度循环网络 Deep RNN

    目录 0. 前言 1. 双向循环网络 BRNN(Bidirectional RNN) 2. 深度循环网络 Deep RNN 如果这篇文章对你有一点小小的帮助,请给个关注,点个赞喔~我会非常开心的~ 花 ...

  8. 深度学习(16)TensorFlow高阶操作五: 张量限幅

    深度学习(16)TensorFlow高阶操作五: 张量限幅 1. clip_by_value 2. relu 3. clip_by_norm 4. Gradient clipping 5. 梯度爆炸实 ...

  9. 深度学习入门(五十二)计算机视觉——风格迁移

    深度学习入门(五十二)计算机视觉--风格迁移 前言 计算机视觉--风格迁移 课件 样式迁移 易于CNN的样式迁移 教材 1 方法 2 阅读内容和风格图像 3 预处理和后处理 4 抽取图像特征 5 定义 ...

最新文章

  1. tkinter项目实战_Python GUI项目实战(二)主窗体的界面设计与实现
  2. 字符集 ISO-8859-1(1)
  3. CentOS7安装Nginx及其相关
  4. android学习笔记34——ClipDrawable资源
  5. bzoj1089 [SCOI2003]严格n元树(dp+高精)
  6. [Python]--Anaconda Resources Collection
  7. asp服务器_200行代码,7个对象——让你了解ASP.NET Core框架的本质「3.x版」
  8. matlab能输入铁心参数,基于MATLAB的电力机车110伏直流稳压电源仿真研究
  9. matlab能控型模型,级倒立摆MATLAB仿真、能控能观性分析、数学模型、极点配置
  10. linux命令无视错误,llinux 的一些命令和错误
  11. Linux内核4.14 LTS发布:那些最新最好的功能特性
  12. Android应用神器:高级终端Termux
  13. 计算机电源管理器怎么打开,联想电池管理如何使用_联想电源管理软件在哪里打开-win7之家...
  14. Python之文本去重(基础版)
  15. python flag用法_python flag什么意思
  16. 30个免费的CSS3动画片段代码
  17. 软件工程个人项目— 数独
  18. 深信服php面经,深信服面经
  19. 数据库并发抢红包_微信高并发抢红包秒杀实战案例
  20. 机器学习流程是什么?简述机器学习流程!

热门文章

  1. 优先级调度算法实现_React17新特性:启发式更新算法
  2. php ascii art,ASCII art (简体中文)
  3. php pdo mysql类源码_php pdo数据库类(提取自微擎的pdo方式处理数据库类库)
  4. sqlserver备份和恢复
  5. java8 stream遍历_Java8中用法优雅的 Stream,性能也优雅吗?
  6. java输入输出及文件_(java基础)Java输入输出流及文件相关
  7. 疯狂软件mysql视频_疯狂软件MySql视频
  8. linux系统怎么安装cas,CAS 在Linux中安装与配置
  9. tc/traffic control 网络控制工具
  10. 《Linux命令行与shell脚本编程大全 第3版》高级Shell脚本编程---24