every blog every motto:

0. 前言

以fashion_mnist 为例,estimator分布式实战,针对一机多卡的情况。

1. 代码部分

1. 导入模块

import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras# os.environ['CUDA_VISIBLE_DEVICES'] = '/gpu:0'
print(tf.__version__)
print(sys.version_info)
for module in mpl,np,pd,sklearn,tf,keras:print(module.__name__,module.__version__)

2. GPU设置

查看GPU是否可用

tf.test.is_gpu_available()
tf.debugging.set_log_device_placement(True) # 查看变量分布在哪个GPU上
gpus = tf.config.experimental.list_physical_devices('GPU') # 获取物理GPU
print(gpus)# 设置选中的(最后一个)GPU可见
tf.config.experimental.set_visible_devices(gpus[-1],'GPU')for gpu in gpus: # 物理GPU 设置成自增长tf.config.experimental.set_memory_growth(gpu,True)print(len(gpus))
print('='*10)
logical_gpus = tf.config.experimental.list_logical_devices('GPU') # 获取逻辑GPU
print(len(logical_gpus))

3. 数据读取与处理

3.1 读取

fashion_mnist = keras.datasets.fashion_mnist
# print(fashion_mnist)
(x_train_all,y_train_all),(x_test,y_test) = fashion_mnist.load_data()
x_valid,x_train = x_train_all[:5000],x_train_all[5000:]
y_valid,y_train = y_train_all[:5000],y_train_all[5000:]
# 打印格式
print(x_valid.shape,y_valid.shape)
print(x_train.shape,y_train.shape)
print(x_test.shape,y_test.shape)

3.2 数据归一化

# 数据归一化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
# x_train:[None,28,28] -> [None,784]
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)
x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)

3.3 生成dataset

# 生成dataset
def make_dataset(images,labels,epochs,batch_size,shuffle=True):dataset = tf.data.Dataset.from_tensor_slices((images,labels))if shuffle:dataset = dataset.shuffle(10000)dataset = dataset.repeat(epochs).batch(batch_size).prefetch(50)return datasetbatch_size = 256
epochs = 100
train_dataset = make_dataset(x_train_scaled,y_train,epochs,batch_size)

4. 构建模型与分布式

estimator分布式

# tf.keras.models.Sequential()
# 构建模型
model = keras.models.Sequential()# 卷积神经网络
model.add(keras.layers.Conv2D(filters=128,kernel_size=3,padding="same",activation='relu',input_shape=(28,28,1)))
model.add(keras.layers.Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))model.add(keras.layers.Conv2D(filters=256,kernel_size=3,padding="same",activation='relu'))
model.add(keras.layers.Conv2D(filters=256,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))model.add(keras.layers.Conv2D(filters=512,kernel_size=3,padding="same",activation='relu'))
model.add(keras.layers.Conv2D(filters=512,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))# 展平
model.add(keras.layers.Flatten())# 全连接层
model.add(keras.layers.Dense(512,activation='relu'))# 输出层
model.add(keras.layers.Dense(10,activation="softmax"))#
model.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])# 分布式
strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)# 转换成estimator
estimator = keras.estimator.model_to_estimator(model,config=config)
model.summary()

5. 训练

estimator.train(input_fn = lambda : make_dataset(x_train_scaled,y_train,epochs,batch_size),max_steps=5000)

6. 学习曲线

# 画图
def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize=(8,5))plt.grid(True)plt.gca().set_ylim(0,1)plt.show()
plot_learning_curves(history)# 损失函数,刚开始下降慢的原因
# 1. 参数众多,训练不充分
# 2. 梯度消失 -》 链式法则中
# 解决: selu缓解梯度消失

7. 测试集上

model.evaluate(x_test_scaled,y_test)

8. Tensorboard中查看

1. 切换目录

上面第四步结束以后,会显示

在终端(win+r -> cmd)中切换到对应目录,如上图所示。

8.2 查看

在终端输入

tensorboard --logdir=tmph92v32wg


将上面显示的网址复制到浏览器打开

从零基础入门Tensorflow2.0 ----八、41. estimator分布式实战相关推荐

  1. 视频编码零基础入门(0):零基础,史上最通俗视频编码技术入门

    [来源申明]本文引用了微信公众号"鲜枣课堂"的<视频编码零基础入门>文章内容.为了更好的内容呈现,即时通讯网在引用和收录时内容有改动,转载时请注明原文来源信息,尊重原作 ...

  2. SQL零基础入门学习(八)

    SQL零基础入门学习(七) SQL 连接(JOIN) SQL join 用于把来自两个或多个表的行结合起来. 下图展示了 LEFT JOIN.RIGHT JOIN.INNER JOIN.OUTER J ...

  3. 【全套】Android零基础入门教程(知识精讲+强化实战)

    在目前的IT行业中,Android开发相关的人才需求量依旧不减,尤其是高级的Android架构师是非常吃香的. 关于安卓如何学习,如何get正确的学习姿势?这篇文章主要分享的是安卓开发的基础内容和学习 ...

  4. C语言零基础入门习题(八)四则运算

    前言 C语言是大多数小白走上程序员道路的第一步,在了解基础语法后,你就可以来尝试解决以下的题目.放心,本系列的文章都对新手非常友好. Tips:题目是英文的,但我相信你肯定能看懂 一.四则运算 题目 ...

  5. Apache Flink 零基础入门(十八)Flink Table APISQL

    什么是Flink关系型API? 虽然Flink已经支持了DataSet和DataStream API,但是有没有一种更好的方式去编程,而不用关心具体的API实现?不需要去了解Java和Scala的具体 ...

  6. 指针01 - 零基础入门学习C语言41

    第八章:指针01 让编程改变世界 Change the world by program 指针啥玩意?似乎很神秘? 指针是C语言中的一个重要的概念,也是C语言的一个重要特色. 正确而灵活地运用它,可以 ...

  7. Apache Flink 零基础入门(十四)Flink 分布式缓存

    Apache Flink 提供了一个分布式缓存,类似于Hadoop,用户可以并行获取数据. 通过注册一个文件或者文件夹到本地或者远程HDFS等,在getExecutionEnvironment中指定一 ...

  8. 【深度学习时间序列预测案例】零基础入门经典深度学习时间序列预测项目实战(附代码+数据集+原理介绍)

  9. SQL零基础入门学习(九)

    SQL零基础入门学习(八) SQL UNION 操作符 UNION 操作符用于合并两个或多个 SELECT 语句的结果集. 请注意,UNION 内部的每个 SELECT 语句必须拥有相同数量的列.列也 ...

  10. 0基础能学漫画么?漫画零基础入门教程!

    漫画零基础入门教程!很多人都喜欢看动漫,同时也会幻想成为动漫里的主角,与此同时也会诞生学漫画的想法.不论是你真的想学习漫画,又或出于个人爱好,或职业需要,或为了具备一项自己喜欢的看家本领.我们都要先清 ...

最新文章

  1. Form 去掉使用格式掩码带来的多余字符
  2. 菜鸟网络 | 寄件业务的产品逻辑
  3. 用反射写的取属性值和设置属性值得方法
  4. CTF之一次曲折获取Flag的过程
  5. SAP Spartacus 4.0 ng serve 之后,localhost 4200 会后面自动添上 electronics-spa 吗?
  6. .NET Core WEB API中接口参数的模型绑定的理解
  7. CSS链接四种状态注意顺序、UI伪类选择器的顺序
  8. CodeForces 214B Hometask
  9. 【源码阅读】看Spring Boot如何自动装配ActiveMQ收发组件
  10. 简单的LRU Cache设计与实现
  11. 统信UOS家庭版使用体验
  12. 移动硬盘插入提示需要格式化RAW_学会自己判断移动硬盘故障!如何在保数据的情况下进行正确处理!...
  13. 车辆出险保险索赔技巧——让每个车友都能学习
  14. 计算机工作表中按升序排列,计算机文化基础上机指导
  15. SSO(Single Sign On)系列(二)--SSO原理
  16. 听云短信接口安全测试,你的短信接口到底有多危险,可能瞬间损失过万,短信接口防盗刷测试
  17. 5GNR漫谈9:PDSCH和PUSCH资源映射(频域type0/type1和时域typeA/typeB/typeC)
  18. DPU芯片企业中科驭数加入龙蜥社区,构建异构算力生态
  19. DAS、NAS、SAN三种存储架构
  20. Fiddler抓包6-get请求(url详解)

热门文章

  1. java 字符串和整型的相互转换
  2. 01. Django基础:Django介绍
  3. Navicat:navicat数据库始终保持连接的方法
  4. 实战Node—幼教平台项目重构和优化
  5. JavaScript:递归实现深拷贝
  6. 实战Javascript:结合电商主界面实现轮播图和倒计时秒杀
  7. 实战CSS:苏宁商城静态实现
  8. Spring Boot+Vue从零开始搭建博客系统veblog(一):项目前端_vuejs环境搭建
  9. Jmeter之app性能测试(ios,android)
  10. ubuntu18.04安装ros-melodic