一、构建计算图

  1. 准备训练数据
  2. 定义前向计算过程 Inference
  3. 定义loss(loss,accuracy等scalar用tensorboard展示)
  4. 定义训练方法
  5. 变量初始化
  6. 保存计算图

二、创建会话

  1. summary对象处理
  2. 喂入数据,得到观测的loss,accuracy等
  3. 用测试数据测试模型
import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
os.environ['TF_CPP_MIN_LOG_LEVEl']='3'
#
tf.reset_default_graph()
mnist = input_data.read_data_sets('D:\MyData\zengxf\.keras\datasets\MNIST_data',one_hot=True)
xq,yq = mnist.train.next_batch(2)# (2, 784),(2,10)# Inputh1 = 100
h2 = 10with tf.name_scope("Input"):X = tf.placeholder("float",[None,784],name='X')Y_true = tf.placeholder("float",[None,10],name='Y_true')
with tf.name_scope("Inference"):with tf.name_scope("hidden1"):W1 = tf.Variable(tf.random_normal([784, h1])*0.1, name='W1')# W1 = tf.Variable(tf.zeros([784, h1]), name='W1')b1 = tf.Variable(tf.zeros([h1]), name='b1')y_1 = tf.nn.sigmoid(tf.matmul(X, W1)+b1)# (None,h1)with tf.name_scope("hidden2"):W2 = tf.Variable(tf.random_normal([h1, h2])*0.1, name='W2')# W2 = tf.Variable(tf.zeros([h1, h2]), name='W2')b2 = tf.Variable(tf.zeros([h2]), name='b2')y_2 = tf.nn.sigmoid(tf.matmul(y_1, W2)+b2)# (h1,h2)with tf.name_scope("Output"):W3 = tf.Variable(tf.truncated_normal([h2, 10])*0.1, name='W3')# W3 = tf.Variable(tf.zeros([h2, 10]), name='W3')b3 = tf.Variable(tf.zeros([10]), name='b3')y = tf.nn.softmax(tf.matmul(y_2, W3)+ b3)# (None,10)
with tf.name_scope("Loss"):loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(Y_true,tf.log(y))))loss_scalar = tf.summary.scalar('loss',loss)accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y,1),tf.argmax(Y_true,1)),tf.float32))accuracy_scalar = tf.summary.scalar('accuracy', accuracy)# loss = tf.reduce_mean(-tf.reduce_sum(Y_true * tf.log(y)))# l = tf.multiply(Y_true,tf.log(y))with tf.name_scope("Trian"):# optimizer = tf.train.AdamOptimizer(learning_rate=0.05)optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)train_op = optimizer.minimize(loss)init = tf.global_variables_initializer()
merge_summary_op = tf.summary.merge_all()
# tf.merge_all_summaries()
writer = tf.summary.FileWriter('logs', tf.get_default_graph())
sess = tf.Session()
sess.run(init)#
for step in range(5000):# train_summary = sess.run(merge_summary_op,feed_dict =  {...})#调用sess.run运行图,生成一步的训练过程数据train_x,train_y = mnist.train.next_batch(500)_,summary_op,train_loss,acc= sess.run([train_op,merge_summary_op,loss,accuracy],feed_dict={X:train_x,Y_true:train_y})if step%100==99:print('loss=',train_loss)# print(sess.run(y,feed_dict={X:train_x}))# summary_op = sess.run(merge_summary_op)writer.add_summary(summary_op,step)
#测试集上预测
print(sess.run(accuracy, feed_dict={X: mnist.test.images, Y_true: mnist.test.labels}))  # 0.9185
writer.close()
# #正确的预测结果
# correct_prediction = tf.equal(tf.argmax(Y_true, 1), tf.argmax(y, 1))
# # 计算预测准确率,它们都是Tensor
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# # 在Session中运行Tensor可以得到Tensor的值
# # 这里是获取最终模型的正确率

基于tensorflow框架的神经网络结构处理mnist数据集相关推荐

  1. TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)

    TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%) 目录 输出结果 实现代码 输出结果 Successfully downloaded t ...

  2. DL之LSTM:基于tensorflow框架利用LSTM算法对气温数据集训练并回归预测

    DL之LSTM:基于tensorflow框架利用LSTM算法对气温数据集训练并回归预测 目录 输出结果 核心代码 输出结果 数据集 tensorboard可视化 iter: 0 loss: 0.010 ...

  3. 低耗时、高精度,微软提基于半监督学习的神经网络结构搜索算法

    作者 | 罗人千.谭旭.王蕊.秦涛.陈恩红.刘铁岩 来源 | 微软研究院AI头条(ID:MSRAsia) 编者按:近年来,神经网络结构搜索(Neural Architecture Search, NA ...

  4. 低耗时、高精度,微软提出基于半监督学习的神经网络结构搜索算法 SemiNAS

    编者按:近年来,神经网络结构搜索(Neural Architecture Search, NAS)取得了较大的突破,但仍然面临搜索耗时及搜索结果不稳定的挑战.为此,微软亚洲研究院机器学习组提出了基于半 ...

  5. CV之CNN:基于tensorflow框架采用CNN(改进的AlexNet,训练/评估/推理)卷积神经网络算法实现猫狗图像分类识别

    CV之CNN:基于tensorflow框架采用CNN(改进的AlexNet,训练/评估/推理)卷积神经网络算法实现猫狗图像分类识别 目录 基于tensorflow框架采用CNN(改进的AlexNet, ...

  6. DL之DNN:基于Tensorflow框架对神经网络算法进行参数初始化的常用九大函数及其使用案例

    DL之DNN:基于Tensorflow框架对神经网络算法进行参数初始化的常用九大函数及其使用案例 目录 基于Tensorflow框架对神经网络算法进行初始化的常用函数及其使用案例 1.初始化的常用函数

  7. TF之LSTM:基于tensorflow框架自定义LSTM算法实现股票历史(1990~2015数据集,6112预测后100+单变量最高)行情回归预测

    TF之LSTM:基于tensorflow框架自定义LSTM算法实现股票历史(1990~2015数据集,6112预测后100+单变量最高)行情回归预测 目录 输出结果 LSTM代码 输出结果 数据集 L ...

  8. CV之YOLOv3:基于Tensorflow框架利用YOLOv3算法对热播新剧《庆余年》实现目标检测

    CV之YOLOv3:基于Tensorflow框架利用YOLOv3算法对热播新剧<庆余年>实现目标检测 目录 搭建 1.下载代码 2.安装依赖库 3.导出COCO权重解压到checkpoin ...

  9. TF之TFOD-API:基于tensorflow框架利用TFOD-API脚本文件将YoloV3训练好的.ckpt模型文件转换为推理时采用的.pb文件

    TF之TFOD-API:基于tensorflow框架利用TFOD-API脚本文件将YoloV3训练好的.ckpt模型文件转换为推理时采用的frozen_inference_graph.pb文件 目录 ...

最新文章

  1. 如何在JavaScript中实现链接列表
  2. 「技术综述」一文道尽传统图像降噪方法
  3. 网络营销外包——网络营销外包专员如何做好网站锚文本优化?
  4. 一篇文章带你详解 TCP/IP 协议(下)
  5. 运维少年系列 python and cisco (1)
  6. 说一下对象或数组转JSON怎么转【fastjson】
  7. C#的多线程机制探索6
  8. linux中cat监控,Linux基本命令——cat、rev、head、tail
  9. php获取cpu编码,PHP下通过exec获得计算机的唯一标识[CPU,网卡 MAC地址]
  10. Linux安装mysql(解决E: Package ‘mysql-server‘ has no installation candidate与ERROR 1698 (28000))
  11. 【Luogu1182】数列分段Section II(二分)
  12. RabbitMQ延迟消息队列实现定时任务完整代码示例
  13. VBA连接MySQL数据库以及ODBC的配置(ODBC版本和MySQL版本如果不匹配会出现驱动和应用程序的错误)...
  14. 服务器2003设置共享文件夹共享文件夹,WinServer2003 文件夹共享 方法设置
  15. 汇总 | 嵌入式软硬件领域各种“黑科技”
  16. OpenStack 2015年度总结
  17. 计算机提示无法验证发布者,win10 ie11提示由于无法验证发布者所以windows已经阻止此软件怎么办...
  18. 微商城分销系统开发方式需求与价格开发周期评估
  19. CAN总线通信学习笔记
  20. 快递单号查询免费api接口(PHP示例)

热门文章

  1. SQLite中的高级SQL
  2. 数据结构与算法(C++)– 动态规划(Dynamic Programming)
  3. 【Python】Python“表情包”工具包真好用
  4. 【CV论文解读】AAAI2021 | 在图卷积网络中超越低频信息
  5. 现在的计算机专业(比如机器学习)已经沦为调包专业了吗?
  6. 【职场】程序员摆地摊都能月入过万,是真的吗?
  7. Yoshua Bengio等图神经网络的新基准Benchmarking Graph Neural Networks(代码已开源)
  8. 复现经典:《统计学习方法》第 7 章 支持向量机
  9. ML 自学者周刊:第 3 期
  10. 推荐一个python学习的宝库(github的star数71000+)