深度学习(七十二)tensorflow 集群训练
#encoding:utf-8
# -*- coding: utf-8 -*-
#使用说明:1、修改分类数目;2、修改输入图片大小;
# 3、修改是否启用集群; 4、修改batch size大小;5、修改数据路径、模型保存路径
#6、设置是否启用boostrap loss 损失函数
import osimport tensorflow as tf
from input_data import Data_layer
import netnum_class = 2
input_height = 256
input_width = 256
crop_height = 224
crop_width = 224
learning_rate = 0.01
tf.set_random_seed(123)
batch_size = tf.placeholder(tf.int32, [], 'batch_size')
tf.add_to_collection('batch_size', batch_size)
is_training = tf.placeholder(tf.bool, [])
is_boostrap = tf.placeholder(tf.bool, [])
drop_prob = tf.placeholder(tf.float32, [])
tf.add_to_collection('is_training', is_training)
def load_save_model(sess,saver,model_path,is_save):if is_save is False:print "***********restore model from %s***************"%model_pathsaver.restore(sess, model_path)else:saver.save(sess, model_path)
def train_cluster(train_volume_data,valid_volume_data,model_path):tf.flags.DEFINE_string("ps_hosts", "localhost:2222", "Comma-separated list of hostname:port pairs")tf.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224","Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("job_name", "", "Either 'ps' or 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")tf.app.flags.DEFINE_string("volumes", "", "volumes info")FLAGS = tf.app.flags.FLAGSps_hosts = FLAGS.ps_hosts.split(",")print("ps_hosts:", ps_hosts)worker_hosts = FLAGS.worker_hosts.split(",")print("worker_hosts:", worker_hosts)cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})print("FLAGS.task_index:", FLAGS.task_index)# Create and start a server for the local task.server = tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index)if FLAGS.job_name == "ps":server.join()elif FLAGS.job_name == "worker":with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % FLAGS.task_index,cluster=cluster)):input_data = Data_layer(train_volume_data, valid_volume_data, batch_size=batch_size, image_height=input_height, image_width=input_width, crop_height=crop_height, crop_width=crop_width)images, labels = input_data.get_next_batch(is_training, num_class)#net_worker = net.resnet(images, labels, num_class, 18, is_training, drop_prob)net_worker = net.resnet256(images, labels, num_class, is_training,is_boostrap)saver = tf.train.Saver()init = tf.global_variables_initializer()sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), init_op=init,saver=saver,global_step=net_worker['global_step'])with sv.prepare_or_wait_for_session(server.target) as session:coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(session, coord=coord)#threads = sv.start_queue_runners(session)load_save_model(session, saver, model_path, False)try:for i in range(400000):if i < 10000:train_dict = {batch_size: 32,drop_prob: 1,is_training: True,is_boostrap: False}else:train_dict = {batch_size: 32,drop_prob: 1,is_training: True,is_boostrap: True}step, _ = session.run([net_worker['global_step'], net_worker['train_op']], feed_dict=train_dict)if i % 500 == 0:train_dict = {batch_size: 32,drop_prob: 1,is_training: True,is_boostrap: False}entropy, train_acc = session.run([net_worker['cross_entropy'], net_worker['accuracy']],feed_dict=train_dict)print('***** {}:{},{} *****'.format(i, entropy, train_acc))if i % 2000 == 0:test_dict = {drop_prob: 1.0,is_training: False,batch_size: 256,is_boostrap:False}acc = session.run(net_worker['accuracy'], feed_dict=test_dict)print('*****locate step {},valid step {}:accuracy {} *****'.format(i,step, acc))if i>3000:print "**************save model***************"load_save_model(session,saver,model_path,True)except Exception, e:coord.request_stop(e)finally:coord.request_stop()coord.join(threads)
深度学习(七十二)tensorflow 集群训练相关推荐
- 深度学习(七十二)ssd物体检测
def ssd_anchor_one_layer(img_shape,feat_shape,sizes,ratios,step,offset=0.5,dtype=np.float32):# 计算每个d ...
- Tensorflow深度学习之十二:基础图像处理之二
Tensorflow深度学习之十二:基础图像处理之二 from:https://blog.csdn.net/davincil/article/details/76598474 首先放出原始图像: ...
- 前几帧预测 深度学习_使用深度学习从十二导联心电图预测心律失常
上集讲到 使用深度学习 从单导联预测房颤 这一集 将继续讨论该问题 单导联心电图 对心律失常的预测作用 非常有限 因为 单导联的信号很有限 临床上需要结合 多导联心电图 判断 心律失常的类型 这一集的 ...
- 花书+吴恩达深度学习(十二)卷积神经网络 CNN 之全连接层
目录 0. 前言 1. 全连接层(fully connected layer) 如果这篇文章对你有一点小小的帮助,请给个关注,点个赞喔~我会非常开心的~ 花书+吴恩达深度学习(十)卷积神经网络 CNN ...
- 深度学习(十二)稀疏自编码
稀疏自编码 原文地址:http://blog.csdn.net/hjimce/article/details/49106869 作者:hjimce 一.相关理论 以前刚开始学CNN的时候,就是通过阅读 ...
- 系统学习redis之二——redis集群搭建
redis单点部署: 安装命令: # cd /usr/local/ # wget http://download.redis.io/releases/redis-4.0.1.tar.gz #下载安装包 ...
- 深度学习(十二)——Winograd(2)
最大公约数和Euclidean algorithm(续) Euclidean algorithm的步骤如下图所示: 1.假设a>ba>ba>b,则令c:=amodbc:=amodbc ...
- 深度学习(十二):Matconvnet小试牛刀与提特征
该节简单介绍一下如何使用Matconvnet的现有的模型进行图像分类实验以及提取图像对应层的特征. 先来看看如何用训练好的imagenet网络模型进行图像的预测,英文版的官网教程就在这里: http: ...
- 系统学习深度学习(十二)--池化
转自:http://blog.csdn.net/danieljianfeng/article/details/42433475 在卷积神经网络中,我们经常会碰到池化操作,而池化层往往在卷积层后面,通过 ...
最新文章
- 简单介绍.Net性能测试框架Crank的使用方法
- 预训练语言模型(PLM)必读论文清单(附论文PDF、源码和模型链接)
- apache rewrite 支持post 数据
- php后台如何连接网口打印机_如何设置斑马网络打印机的网卡IP地址
- 如何用Redlock实现分布式锁
- 面试问题:Spring中Bean 的生命周期
- 人脸识别dlib库 记录
- 历史首次!中国联通、中国电信组队了,只为达成这个目的
- 基于heartbeat v1配置mysql和httpd的高可用双主模型
- Intel altera opencl 入门
- 全国大学生英语竞赛——题型介绍
- 容错性低是什么意思_王者荣耀:在成为高手之前,这4位容错率低的千万别碰!...
- 安利一款免费、开源、实时的服务器监控工具:Netdata
- DOC与DOCX区别【100字】【原创】
- html函数参数数组遍历,JavaScript foreach遍历数组
- html5 矢量图形插件,HTML5画布矢量图形?
- 【安装记录】深度学习电脑配置
- 批处理建立网页快捷方式
- 二鱼和我,武汉,黑客马拉松
- 测试32:chemistry
热门文章
- pythonide的作用_Linux程序员宝典:2020年10款出色的Python IDE!
- linux命令fsck和fcsk,在ubuntu中shutdown和reboot的各参数的作用是什么? | 星尘
- shell 循环删除进程
- All Of ACM
- Teamviewer 手机端怎么使用右键-已解决
- 一些常见的HTTP的请求状态码
- hotspot虚拟机的调试
- Hbuilder开发app实战-识岁06-face++的js实现【完结】
- LED 将为我闪烁: 控帘 j发光二级管
- HTTP 协议演示——HTTP 协议概述(3-5)