基础Estimator

#--coding:utf-8--
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_datatf.logging.set_verbosity(tf.logging.INFO)
mnist = input_data.read_data_sets()#指定网络输入,所有这里指定的输入都会拼接起来作为整个网络的输入
feature_cloumns = [tf.feature_column.numeric_column("image", shape=[784])]"""
#通过Tensorflow提供的封装好的Estimator定义网络模型。Arguments:features_cloumns:神经网络输入层需要的数据hidden_units:神经网络的结构 注意 DNNClassifier只能定义多层全连接神经网络 而hidden则给出了每一层隐藏层的节点个数n_classes:总共类目的数目optimizer:所使用的优化函数model_dir:将训练过程中loss的变化以及一些其他指标保存到此目录,通过TensorBoard可以可视化
"""
estimator = tf.estimator.DNNClassifier(feature_columns=feature_cloumns,hidden_units=[500],n_classes=10,optimizer=tf.train.AdamOptimizer(),model_dir="~~"
)train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image":mnist.train.images},y=mnist.train.labels.astype(np.int32),num_epochs=None,batch_size=128,shuffle=True
)#训练模型 注意 此处没有定义损失函数 ,通过DNN定义的模型会使用交叉上作为损失函数
estimator.train(input_fn=train_input_fn, steps=10000)#定义测试时的数据输入
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image":mnist.train.images},y=mnist.train.labels.astype(np.int32),num_epochs=1,batch_size=128,shuffle=False
)accuracy_score = estimator.evaluate(input_fn=test_input_fn)["accuracy"]
print("\nTest accuracy: %g %%" %(accuracy_score * 100))

自定义Estimator

# --coding:utf-8--
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_datatf.logging.set_verbosity(tf.logging.INFO)# 通过tf.layers来定义模型结构。可以使用原生态tf api或者其他高层封装。
def lenet(x, is_training):x = tf.reshape(x, shape=[-1, 28, 28, 1])net = tf.layers.conv2d(x, 32, 5, activation=tf.nn.relu)net = tf.layers.max_pooling2d(net, 2, 2)net = tf.layers.conv2d(net, 64, 3, activation=tf.nn.relu)net = tf.layers.max_pooling2d(net, 2, 2)net = tf.contrib.layers.flatten(net)net = tf.layers.dense(net, 1024)net = tf.layers.dropout(net, rate=0.4, training=is_training)return tf.layers.dense(net, 10)"""
#自定义estimator中使用的模型。Arguments:features:输入函数中会提供的输入层张亮。这是一个字典,字典里的内容是通过tf.estimator.inputs.numpy_input_fn中x参数的内容指定的。label:正确分类标签,这个字段的内容是通过numpy_input_fn中y参数给出,mode:train/evaluate/predictparams:字典  超参数
"""
def model_fn(featuers, labels, mode, params):predict = lenet(featuers["image"], mode == tf.estimator.ModeKeys.TRAIN)#如果在预测模式 只需要将结果返回if mode == tf.estimator.ModeKeys.PREDICT:#使用EstimatorSpec传递返回值,并通过predictions参数指定返回的结果return tf.estimator.EstimatorSpec(mode = mode, predictions={"result":tf.argmax(predict, 1)})#定义损失loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=predict, labels=labels))optimizer = tf.train.GradientDescentOptimizer(learning_rate=params["learning_rate"])#定义训练过程train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())#定义评测标准eval_metric_ops = {"my_metric": tf.metrics.accuracy(tf.argmax(predict, 1), labels)}#返回模型训练过程需要使用的损失函数、训练过程和评测方法return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops)mnist = input_data.read_data_sets("/path/to/MNIST_data", one_hot=False)#通过自定义的方式生成Esttimator类,这里需要提供模型定义的函数并通过params参数指定模型定义时使用的超参数
model_params = {"learning_rate": 0.01}
estimator = tf.estimator.Estimator(model_fn=model_fn, params=model_params)#训练
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.train.images},y=mnist.train.labels.astype(np.int32),num_epochs=None,batch_size=128,shuffle=True
)
estimator.train(input_fn=train_input_fn, steps=30000)
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images},y=mnist.test.labels.astype(np.int32),num_epochs=1,batch_size=128,shuffle=False
)
test_results = estimator.evaluate(input_fn=test_input_fn)#这里的my_metric中的内容就是model_fn中eval_metric_ops定义的评测指标
accuracy_score = test_results["my_metric"]
print("\nTest accuracy: %g %%" % (accuracy_score * 100))predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"image": mnist.test.images[:10]},num_epochs=1,shuffle=False
)
predictions = estimator.predict(input_fn=predict_input_fn)for i, p in enumerate(predictions):print("Prediction %s: %s" % (i + 1, p["result"]))

基础Estimator以及自定义Estimator相关推荐

  1. 创建自定义 Estimator

    ref 本文档介绍了自定义 Estimator.具体而言,本文档介绍了如何创建自定义 Estimator 来模拟预创建的 Estimator DNNClassifier 在解决鸢尾花问题时的行为.要详 ...

  2. 代码实现tan graph model for classification_自定义 Estimator 实现(以BERT为例)

    本文将主要介绍tensorflow 的Estimator 这个高级API,它的主要作用就是提出一个高级范式(paradigm),将模型的训练,验证,预测,以及保存规范起来,免去了tensorflow的 ...

  3. tf.estimator API技术手册(16)——自定义Estimator

    tf.estimator API技术手册(16)--自定义Estimator (一)前 言 (二)自定义estimator的一般步骤 (三)准备训练数据 (四)自定义estimator实践 (1)创建 ...

  4. c语言何编写自定义函数,C语言菜鸟基础教程之自定义函数

    C语言菜鸟基础教程之自定义函数 先动手编写程序: #include int add(int x, int y) { int z = x + y; return z; } int main() { in ...

  5. android arrayadapter自定义,Android零基础入门|自定义ArrayAdapter

    原标题:Android零基础入门|自定义ArrayAdapter ListView用起来还是比较简单的,也是Android应用程序中最重要的一个组件,但其他ListView可以随你所愿,能够完成很多想 ...

  6. bpmn-process-designer基础上进行自定义样式(工具、元素、菜单)

    文章目录 一.自定义工具Palette 例如这里对开始事件工具进行自定义 二.自定义样式Palette和PopupMenu 三.自定义图形元素svg样式(包含了节点和连线) 自定义元素颜色 然后就是元 ...

  7. LTspice基础教程-027.自定义函数;func指令用法

    在LTspice中,我们可以自定义函数.语法如下: .func <name>([args]) {<expression>} func是function的缩写:name是自定义函 ...

  8. Nios II 基础工程和自定义组件

    软件环境:Quartus Prime Standard 18.1 Window 10 硬件环境:小梅哥 AC501 开发板 主要参考: Intel Quartus Prime Standard Edi ...

  9. Python基础教程:自定义迭代器

    本文介绍如何自定义迭代器,涉及到类的运算符重载,包括__getitem__的索引迭代,以及__iter__.__next__和__contains__,如果不了解这些知识可跳过本文. 索引迭代方式 索 ...

最新文章

  1. python操作mongodb进行读写
  2. python真的那么火吗-Python语言为什么这么火?
  3. 基于Lucene查询原理分析Elasticsearch的性能
  4. ubuntu16.04安装ROS
  5. [css] 为什么要使用sass/less?
  6. 什么工作经常出差_商旅人群洞察:什么样的人经常坐飞机出差?
  7. 一次Nginx负载均衡的安装与配置
  8. WebService学习笔记系列(四)
  9. python小白从哪来开始-写给小白的工程师入门 - 从 Python 开始
  10. mysql备份与恢复的一些方法
  11. C#的进度条--progressBar
  12. js返回上一页的实现方法
  13. java小数正负数据类型_Java - day001 - 8种基本数据类型
  14. 菜鸟教程python在线编译器-Python3 教程 | 菜鸟教程
  15. Augustus:真核生物基因结构预测软件-安装篇
  16. 用c语言解参数积分,C语言求定积分的通用函数
  17. cpuid limit_Max CPUID Valut Limit 请懂电脑的解答下 谢谢!
  18. CodeForces 757 E.Bash Plays with Functions(积性函数+dp)
  19. ubuntu20.04基础入门日记V1.0
  20. Android美化menu的小技巧-item菜单项添加标题

热门文章

  1. 面试之ElasticSearch与Solor
  2. 史上最全的mime-type大全
  3. 常用的空间形状相似性计算方法有哪些,它们之间有什么不同
  4. 快速了解靶点预测的方法!
  5. android自定义动画插值器(Interpolator)
  6. 图片存储格式 PNM 以及 PBM/PGM/PPM
  7. 【Python爬虫错误】ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接
  8. Go实战Gin+Vue+微服务打造秒杀商城第五课 gin+vue实战
  9. MogaFX—国际货币基金组织小组与毛里塔尼亚就一项为期三年的延长信贷安排和延长基金安排达成工作人员级协议
  10. 洛谷P3955 [NOIP2017 普及组] 图书管理员