从零基础入门Tensorflow2.0 ----八、41. estimator分布式实战
every blog every motto:
0. 前言
以fashion_mnist 为例,estimator分布式实战,针对一机多卡的情况。
1. 代码部分
1. 导入模块
import matplotlib as mpl
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import sklearn
import pandas as pd
import os
import sys
import time
import tensorflow as tf
from tensorflow import keras# os.environ['CUDA_VISIBLE_DEVICES'] = '/gpu:0'
print(tf.__version__)
print(sys.version_info)
for module in mpl,np,pd,sklearn,tf,keras:print(module.__name__,module.__version__)
2. GPU设置
查看GPU是否可用
tf.test.is_gpu_available()
tf.debugging.set_log_device_placement(True) # 查看变量分布在哪个GPU上
gpus = tf.config.experimental.list_physical_devices('GPU') # 获取物理GPU
print(gpus)# 设置选中的(最后一个)GPU可见
tf.config.experimental.set_visible_devices(gpus[-1],'GPU')for gpu in gpus: # 物理GPU 设置成自增长tf.config.experimental.set_memory_growth(gpu,True)print(len(gpus))
print('='*10)
logical_gpus = tf.config.experimental.list_logical_devices('GPU') # 获取逻辑GPU
print(len(logical_gpus))
3. 数据读取与处理
3.1 读取
fashion_mnist = keras.datasets.fashion_mnist
# print(fashion_mnist)
(x_train_all,y_train_all),(x_test,y_test) = fashion_mnist.load_data()
x_valid,x_train = x_train_all[:5000],x_train_all[5000:]
y_valid,y_train = y_train_all[:5000],y_train_all[5000:]
# 打印格式
print(x_valid.shape,y_valid.shape)
print(x_train.shape,y_train.shape)
print(x_test.shape,y_test.shape)
3.2 数据归一化
# 数据归一化
from sklearn.preprocessing import StandardScalerscaler = StandardScaler()
# x_train:[None,28,28] -> [None,784]
x_train_scaled = scaler.fit_transform(x_train.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)
x_valid_scaled = scaler.transform(x_valid.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)
x_test_scaled = scaler.transform(x_test.astype(np.float32).reshape(-1,1)).reshape(-1,28,28,1)
3.3 生成dataset
# 生成dataset
def make_dataset(images,labels,epochs,batch_size,shuffle=True):dataset = tf.data.Dataset.from_tensor_slices((images,labels))if shuffle:dataset = dataset.shuffle(10000)dataset = dataset.repeat(epochs).batch(batch_size).prefetch(50)return datasetbatch_size = 256
epochs = 100
train_dataset = make_dataset(x_train_scaled,y_train,epochs,batch_size)
4. 构建模型与分布式
estimator分布式
# tf.keras.models.Sequential()
# 构建模型
model = keras.models.Sequential()# 卷积神经网络
model.add(keras.layers.Conv2D(filters=128,kernel_size=3,padding="same",activation='relu',input_shape=(28,28,1)))
model.add(keras.layers.Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))model.add(keras.layers.Conv2D(filters=256,kernel_size=3,padding="same",activation='relu'))
model.add(keras.layers.Conv2D(filters=256,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))model.add(keras.layers.Conv2D(filters=512,kernel_size=3,padding="same",activation='relu'))
model.add(keras.layers.Conv2D(filters=512,kernel_size=3,padding='same',activation='relu'))
model.add(keras.layers.MaxPool2D(pool_size=2))# 展平
model.add(keras.layers.Flatten())# 全连接层
model.add(keras.layers.Dense(512,activation='relu'))# 输出层
model.add(keras.layers.Dense(10,activation="softmax"))#
model.compile(loss='sparse_categorical_crossentropy',optimizer='sgd',metrics=['accuracy'])# 分布式
strategy = tf.distribute.MirroredStrategy()
config = tf.estimator.RunConfig(train_distribute=strategy)# 转换成estimator
estimator = keras.estimator.model_to_estimator(model,config=config)
model.summary()
5. 训练
estimator.train(input_fn = lambda : make_dataset(x_train_scaled,y_train,epochs,batch_size),max_steps=5000)
6. 学习曲线
# 画图
def plot_learning_curves(history):pd.DataFrame(history.history).plot(figsize=(8,5))plt.grid(True)plt.gca().set_ylim(0,1)plt.show()
plot_learning_curves(history)# 损失函数,刚开始下降慢的原因
# 1. 参数众多,训练不充分
# 2. 梯度消失 -》 链式法则中
# 解决: selu缓解梯度消失
7. 测试集上
model.evaluate(x_test_scaled,y_test)
8. Tensorboard中查看
1. 切换目录
上面第四步结束以后,会显示
在终端(win+r -> cmd)中切换到对应目录,如上图所示。
8.2 查看
在终端输入
tensorboard --logdir=tmph92v32wg
将上面显示的网址复制到浏览器打开
从零基础入门Tensorflow2.0 ----八、41. estimator分布式实战相关推荐
- 视频编码零基础入门(0):零基础,史上最通俗视频编码技术入门
[来源申明]本文引用了微信公众号"鲜枣课堂"的<视频编码零基础入门>文章内容.为了更好的内容呈现,即时通讯网在引用和收录时内容有改动,转载时请注明原文来源信息,尊重原作 ...
- SQL零基础入门学习(八)
SQL零基础入门学习(七) SQL 连接(JOIN) SQL join 用于把来自两个或多个表的行结合起来. 下图展示了 LEFT JOIN.RIGHT JOIN.INNER JOIN.OUTER J ...
- 【全套】Android零基础入门教程(知识精讲+强化实战)
在目前的IT行业中,Android开发相关的人才需求量依旧不减,尤其是高级的Android架构师是非常吃香的. 关于安卓如何学习,如何get正确的学习姿势?这篇文章主要分享的是安卓开发的基础内容和学习 ...
- C语言零基础入门习题(八)四则运算
前言 C语言是大多数小白走上程序员道路的第一步,在了解基础语法后,你就可以来尝试解决以下的题目.放心,本系列的文章都对新手非常友好. Tips:题目是英文的,但我相信你肯定能看懂 一.四则运算 题目 ...
- Apache Flink 零基础入门(十八)Flink Table APISQL
什么是Flink关系型API? 虽然Flink已经支持了DataSet和DataStream API,但是有没有一种更好的方式去编程,而不用关心具体的API实现?不需要去了解Java和Scala的具体 ...
- 指针01 - 零基础入门学习C语言41
第八章:指针01 让编程改变世界 Change the world by program 指针啥玩意?似乎很神秘? 指针是C语言中的一个重要的概念,也是C语言的一个重要特色. 正确而灵活地运用它,可以 ...
- Apache Flink 零基础入门(十四)Flink 分布式缓存
Apache Flink 提供了一个分布式缓存,类似于Hadoop,用户可以并行获取数据. 通过注册一个文件或者文件夹到本地或者远程HDFS等,在getExecutionEnvironment中指定一 ...
- 【深度学习时间序列预测案例】零基础入门经典深度学习时间序列预测项目实战(附代码+数据集+原理介绍)
- SQL零基础入门学习(九)
SQL零基础入门学习(八) SQL UNION 操作符 UNION 操作符用于合并两个或多个 SELECT 语句的结果集. 请注意,UNION 内部的每个 SELECT 语句必须拥有相同数量的列.列也 ...
- 0基础能学漫画么?漫画零基础入门教程!
漫画零基础入门教程!很多人都喜欢看动漫,同时也会幻想成为动漫里的主角,与此同时也会诞生学漫画的想法.不论是你真的想学习漫画,又或出于个人爱好,或职业需要,或为了具备一项自己喜欢的看家本领.我们都要先清 ...
最新文章
- Form 去掉使用格式掩码带来的多余字符
- 菜鸟网络 | 寄件业务的产品逻辑
- 用反射写的取属性值和设置属性值得方法
- CTF之一次曲折获取Flag的过程
- SAP Spartacus 4.0 ng serve 之后,localhost 4200 会后面自动添上 electronics-spa 吗?
- .NET Core WEB API中接口参数的模型绑定的理解
- CSS链接四种状态注意顺序、UI伪类选择器的顺序
- CodeForces 214B Hometask
- 【源码阅读】看Spring Boot如何自动装配ActiveMQ收发组件
- 简单的LRU Cache设计与实现
- 统信UOS家庭版使用体验
- 移动硬盘插入提示需要格式化RAW_学会自己判断移动硬盘故障!如何在保数据的情况下进行正确处理!...
- 车辆出险保险索赔技巧——让每个车友都能学习
- 计算机工作表中按升序排列,计算机文化基础上机指导
- SSO(Single Sign On)系列(二)--SSO原理
- 听云短信接口安全测试,你的短信接口到底有多危险,可能瞬间损失过万,短信接口防盗刷测试
- 5GNR漫谈9:PDSCH和PUSCH资源映射(频域type0/type1和时域typeA/typeB/typeC)
- DPU芯片企业中科驭数加入龙蜥社区,构建异构算力生态
- DAS、NAS、SAN三种存储架构
- Fiddler抓包6-get请求(url详解)