#encoding:utf-8
# -*- coding: utf-8 -*-
#使用说明:1、修改分类数目;2、修改输入图片大小;
# 3、修改是否启用集群; 4、修改batch size大小;5、修改数据路径、模型保存路径
#6、设置是否启用boostrap loss 损失函数
import osimport tensorflow as tf
from input_data import Data_layer
import netnum_class = 2
input_height = 256
input_width = 256
crop_height = 224
crop_width = 224
learning_rate = 0.01
tf.set_random_seed(123)
batch_size = tf.placeholder(tf.int32, [], 'batch_size')
tf.add_to_collection('batch_size', batch_size)
is_training = tf.placeholder(tf.bool, [])
is_boostrap = tf.placeholder(tf.bool, [])
drop_prob = tf.placeholder(tf.float32, [])
tf.add_to_collection('is_training', is_training)
def load_save_model(sess,saver,model_path,is_save):if is_save is False:print "***********restore model from %s***************"%model_pathsaver.restore(sess, model_path)else:saver.save(sess, model_path)
def train_cluster(train_volume_data,valid_volume_data,model_path):tf.flags.DEFINE_string("ps_hosts", "localhost:2222", "Comma-separated list of hostname:port pairs")tf.flags.DEFINE_string("worker_hosts", "localhost:2223,localhost:2224","Comma-separated list of hostname:port pairs")tf.app.flags.DEFINE_string("job_name", "", "Either 'ps' or 'worker'")tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")tf.app.flags.DEFINE_string("volumes", "", "volumes info")FLAGS = tf.app.flags.FLAGSps_hosts = FLAGS.ps_hosts.split(",")print("ps_hosts:", ps_hosts)worker_hosts = FLAGS.worker_hosts.split(",")print("worker_hosts:", worker_hosts)cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts})print("FLAGS.task_index:", FLAGS.task_index)# Create and start a server for the local task.server = tf.train.Server(cluster,job_name=FLAGS.job_name,task_index=FLAGS.task_index)if FLAGS.job_name == "ps":server.join()elif FLAGS.job_name == "worker":with tf.device(tf.train.replica_device_setter(worker_device="/job:worker/task:%d" % FLAGS.task_index,cluster=cluster)):input_data = Data_layer(train_volume_data, valid_volume_data, batch_size=batch_size, image_height=input_height, image_width=input_width, crop_height=crop_height, crop_width=crop_width)images, labels = input_data.get_next_batch(is_training, num_class)#net_worker = net.resnet(images, labels, num_class, 18, is_training, drop_prob)net_worker = net.resnet256(images, labels, num_class, is_training,is_boostrap)saver = tf.train.Saver()init = tf.global_variables_initializer()sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0), init_op=init,saver=saver,global_step=net_worker['global_step'])with sv.prepare_or_wait_for_session(server.target) as session:coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(session, coord=coord)#threads = sv.start_queue_runners(session)load_save_model(session, saver, model_path, False)try:for i in range(400000):if i < 10000:train_dict = {batch_size: 32,drop_prob: 1,is_training: True,is_boostrap: False}else:train_dict = {batch_size: 32,drop_prob: 1,is_training: True,is_boostrap: True}step, _ = session.run([net_worker['global_step'], net_worker['train_op']], feed_dict=train_dict)if i % 500 == 0:train_dict = {batch_size: 32,drop_prob: 1,is_training: True,is_boostrap: False}entropy, train_acc = session.run([net_worker['cross_entropy'], net_worker['accuracy']],feed_dict=train_dict)print('***** {}:{},{} *****'.format(i, entropy, train_acc))if i % 2000 == 0:test_dict = {drop_prob: 1.0,is_training: False,batch_size: 256,is_boostrap:False}acc = session.run(net_worker['accuracy'], feed_dict=test_dict)print('*****locate step {},valid step {}:accuracy {} *****'.format(i,step, acc))if i>3000:print "**************save model***************"load_save_model(session,saver,model_path,True)except Exception, e:coord.request_stop(e)finally:coord.request_stop()coord.join(threads)

深度学习(七十二)tensorflow 集群训练相关推荐

  1. 深度学习(七十二)ssd物体检测

    def ssd_anchor_one_layer(img_shape,feat_shape,sizes,ratios,step,offset=0.5,dtype=np.float32):# 计算每个d ...

  2. Tensorflow深度学习之十二:基础图像处理之二

    Tensorflow深度学习之十二:基础图像处理之二 from:https://blog.csdn.net/davincil/article/details/76598474   首先放出原始图像: ...

  3. 前几帧预测 深度学习_使用深度学习从十二导联心电图预测心律失常

    上集讲到 使用深度学习 从单导联预测房颤 这一集 将继续讨论该问题 单导联心电图 对心律失常的预测作用 非常有限 因为 单导联的信号很有限 临床上需要结合 多导联心电图 判断 心律失常的类型 这一集的 ...

  4. 花书+吴恩达深度学习(十二)卷积神经网络 CNN 之全连接层

    目录 0. 前言 1. 全连接层(fully connected layer) 如果这篇文章对你有一点小小的帮助,请给个关注,点个赞喔~我会非常开心的~ 花书+吴恩达深度学习(十)卷积神经网络 CNN ...

  5. 深度学习(十二)稀疏自编码

    稀疏自编码 原文地址:http://blog.csdn.net/hjimce/article/details/49106869 作者:hjimce 一.相关理论 以前刚开始学CNN的时候,就是通过阅读 ...

  6. 系统学习redis之二——redis集群搭建

    redis单点部署: 安装命令: # cd /usr/local/ # wget http://download.redis.io/releases/redis-4.0.1.tar.gz #下载安装包 ...

  7. 深度学习(十二)——Winograd(2)

    最大公约数和Euclidean algorithm(续) Euclidean algorithm的步骤如下图所示: 1.假设a>ba>ba>b,则令c:=amodbc:=amodbc ...

  8. 深度学习(十二):Matconvnet小试牛刀与提特征

    该节简单介绍一下如何使用Matconvnet的现有的模型进行图像分类实验以及提取图像对应层的特征. 先来看看如何用训练好的imagenet网络模型进行图像的预测,英文版的官网教程就在这里: http: ...

  9. 系统学习深度学习(十二)--池化

    转自:http://blog.csdn.net/danieljianfeng/article/details/42433475 在卷积神经网络中,我们经常会碰到池化操作,而池化层往往在卷积层后面,通过 ...

最新文章

  1. 简单介绍.Net性能测试框架Crank的使用方法
  2. 预训练语言模型(PLM)必读论文清单(附论文PDF、源码和模型链接)
  3. apache rewrite 支持post 数据
  4. php后台如何连接网口打印机_如何设置斑马网络打印机的网卡IP地址
  5. 如何用Redlock实现分布式锁
  6. 面试问题:Spring中Bean 的生命周期
  7. 人脸识别dlib库 记录
  8. 历史首次!中国联通、中国电信组队了,只为达成这个目的
  9. 基于heartbeat v1配置mysql和httpd的高可用双主模型
  10. Intel altera opencl 入门
  11. 全国大学生英语竞赛——题型介绍
  12. 容错性低是什么意思_王者荣耀:在成为高手之前,这4位容错率低的千万别碰!...
  13. 安利一款免费、开源、实时的服务器监控工具:Netdata
  14. DOC与DOCX区别【100字】【原创】
  15. html函数参数数组遍历,JavaScript foreach遍历数组
  16. html5 矢量图形插件,HTML5画布矢量图形?
  17. 【安装记录】深度学习电脑配置
  18. 批处理建立网页快捷方式
  19. 二鱼和我,武汉,黑客马拉松
  20. 测试32:chemistry

热门文章

  1. pythonide的作用_Linux程序员宝典:2020年10款出色的Python IDE!
  2. linux命令fsck和fcsk,在ubuntu中shutdown和reboot的各参数的作用是什么? | 星尘
  3. shell 循环删除进程
  4. All Of ACM
  5. Teamviewer 手机端怎么使用右键-已解决
  6. 一些常见的HTTP的请求状态码
  7. hotspot虚拟机的调试
  8. Hbuilder开发app实战-识岁06-face++的js实现【完结】
  9. LED 将为我闪烁: 控帘 j发光二级管
  10. HTTP 协议演示——HTTP 协议概述(3-5)