基于tensorflow框架的神经网络结构处理mnist数据集
一、构建计算图
- 准备训练数据
- 定义前向计算过程 Inference
- 定义loss(loss,accuracy等scalar用tensorboard展示)
- 定义训练方法
- 变量初始化
- 保存计算图
二、创建会话
- summary对象处理
- 喂入数据,得到观测的loss,accuracy等
- 用测试数据测试模型
import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
os.environ['TF_CPP_MIN_LOG_LEVEl']='3'
#
tf.reset_default_graph()
mnist = input_data.read_data_sets('D:\MyData\zengxf\.keras\datasets\MNIST_data',one_hot=True)
xq,yq = mnist.train.next_batch(2)# (2, 784),(2,10)# Inputh1 = 100
h2 = 10with tf.name_scope("Input"):X = tf.placeholder("float",[None,784],name='X')Y_true = tf.placeholder("float",[None,10],name='Y_true')
with tf.name_scope("Inference"):with tf.name_scope("hidden1"):W1 = tf.Variable(tf.random_normal([784, h1])*0.1, name='W1')# W1 = tf.Variable(tf.zeros([784, h1]), name='W1')b1 = tf.Variable(tf.zeros([h1]), name='b1')y_1 = tf.nn.sigmoid(tf.matmul(X, W1)+b1)# (None,h1)with tf.name_scope("hidden2"):W2 = tf.Variable(tf.random_normal([h1, h2])*0.1, name='W2')# W2 = tf.Variable(tf.zeros([h1, h2]), name='W2')b2 = tf.Variable(tf.zeros([h2]), name='b2')y_2 = tf.nn.sigmoid(tf.matmul(y_1, W2)+b2)# (h1,h2)with tf.name_scope("Output"):W3 = tf.Variable(tf.truncated_normal([h2, 10])*0.1, name='W3')# W3 = tf.Variable(tf.zeros([h2, 10]), name='W3')b3 = tf.Variable(tf.zeros([10]), name='b3')y = tf.nn.softmax(tf.matmul(y_2, W3)+ b3)# (None,10)
with tf.name_scope("Loss"):loss = tf.reduce_mean(-tf.reduce_sum(tf.multiply(Y_true,tf.log(y))))loss_scalar = tf.summary.scalar('loss',loss)accuracy = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y,1),tf.argmax(Y_true,1)),tf.float32))accuracy_scalar = tf.summary.scalar('accuracy', accuracy)# loss = tf.reduce_mean(-tf.reduce_sum(Y_true * tf.log(y)))# l = tf.multiply(Y_true,tf.log(y))with tf.name_scope("Trian"):# optimizer = tf.train.AdamOptimizer(learning_rate=0.05)optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)train_op = optimizer.minimize(loss)init = tf.global_variables_initializer()
merge_summary_op = tf.summary.merge_all()
# tf.merge_all_summaries()
writer = tf.summary.FileWriter('logs', tf.get_default_graph())
sess = tf.Session()
sess.run(init)#
for step in range(5000):# train_summary = sess.run(merge_summary_op,feed_dict = {...})#调用sess.run运行图,生成一步的训练过程数据train_x,train_y = mnist.train.next_batch(500)_,summary_op,train_loss,acc= sess.run([train_op,merge_summary_op,loss,accuracy],feed_dict={X:train_x,Y_true:train_y})if step%100==99:print('loss=',train_loss)# print(sess.run(y,feed_dict={X:train_x}))# summary_op = sess.run(merge_summary_op)writer.add_summary(summary_op,step)
#测试集上预测
print(sess.run(accuracy, feed_dict={X: mnist.test.images, Y_true: mnist.test.labels})) # 0.9185
writer.close()
# #正确的预测结果
# correct_prediction = tf.equal(tf.argmax(Y_true, 1), tf.argmax(y, 1))
# # 计算预测准确率,它们都是Tensor
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# # 在Session中运行Tensor可以得到Tensor的值
# # 这里是获取最终模型的正确率
基于tensorflow框架的神经网络结构处理mnist数据集相关推荐
- TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%)
TF之GD:基于tensorflow框架搭建GD算法利用Fashion-MNIST数据集实现多分类预测(92%) 目录 输出结果 实现代码 输出结果 Successfully downloaded t ...
- DL之LSTM:基于tensorflow框架利用LSTM算法对气温数据集训练并回归预测
DL之LSTM:基于tensorflow框架利用LSTM算法对气温数据集训练并回归预测 目录 输出结果 核心代码 输出结果 数据集 tensorboard可视化 iter: 0 loss: 0.010 ...
- 低耗时、高精度,微软提基于半监督学习的神经网络结构搜索算法
作者 | 罗人千.谭旭.王蕊.秦涛.陈恩红.刘铁岩 来源 | 微软研究院AI头条(ID:MSRAsia) 编者按:近年来,神经网络结构搜索(Neural Architecture Search, NA ...
- 低耗时、高精度,微软提出基于半监督学习的神经网络结构搜索算法 SemiNAS
编者按:近年来,神经网络结构搜索(Neural Architecture Search, NAS)取得了较大的突破,但仍然面临搜索耗时及搜索结果不稳定的挑战.为此,微软亚洲研究院机器学习组提出了基于半 ...
- CV之CNN:基于tensorflow框架采用CNN(改进的AlexNet,训练/评估/推理)卷积神经网络算法实现猫狗图像分类识别
CV之CNN:基于tensorflow框架采用CNN(改进的AlexNet,训练/评估/推理)卷积神经网络算法实现猫狗图像分类识别 目录 基于tensorflow框架采用CNN(改进的AlexNet, ...
- DL之DNN:基于Tensorflow框架对神经网络算法进行参数初始化的常用九大函数及其使用案例
DL之DNN:基于Tensorflow框架对神经网络算法进行参数初始化的常用九大函数及其使用案例 目录 基于Tensorflow框架对神经网络算法进行初始化的常用函数及其使用案例 1.初始化的常用函数
- TF之LSTM:基于tensorflow框架自定义LSTM算法实现股票历史(1990~2015数据集,6112预测后100+单变量最高)行情回归预测
TF之LSTM:基于tensorflow框架自定义LSTM算法实现股票历史(1990~2015数据集,6112预测后100+单变量最高)行情回归预测 目录 输出结果 LSTM代码 输出结果 数据集 L ...
- CV之YOLOv3:基于Tensorflow框架利用YOLOv3算法对热播新剧《庆余年》实现目标检测
CV之YOLOv3:基于Tensorflow框架利用YOLOv3算法对热播新剧<庆余年>实现目标检测 目录 搭建 1.下载代码 2.安装依赖库 3.导出COCO权重解压到checkpoin ...
- TF之TFOD-API:基于tensorflow框架利用TFOD-API脚本文件将YoloV3训练好的.ckpt模型文件转换为推理时采用的.pb文件
TF之TFOD-API:基于tensorflow框架利用TFOD-API脚本文件将YoloV3训练好的.ckpt模型文件转换为推理时采用的frozen_inference_graph.pb文件 目录 ...
最新文章
- 如何在JavaScript中实现链接列表
- 「技术综述」一文道尽传统图像降噪方法
- 网络营销外包——网络营销外包专员如何做好网站锚文本优化?
- 一篇文章带你详解 TCP/IP 协议(下)
- 运维少年系列 python and cisco (1)
- 说一下对象或数组转JSON怎么转【fastjson】
- C#的多线程机制探索6
- linux中cat监控,Linux基本命令——cat、rev、head、tail
- php获取cpu编码,PHP下通过exec获得计算机的唯一标识[CPU,网卡 MAC地址]
- Linux安装mysql(解决E: Package ‘mysql-server‘ has no installation candidate与ERROR 1698 (28000))
- 【Luogu1182】数列分段Section II(二分)
- RabbitMQ延迟消息队列实现定时任务完整代码示例
- VBA连接MySQL数据库以及ODBC的配置(ODBC版本和MySQL版本如果不匹配会出现驱动和应用程序的错误)...
- 服务器2003设置共享文件夹共享文件夹,WinServer2003 文件夹共享 方法设置
- 汇总 | 嵌入式软硬件领域各种“黑科技”
- OpenStack 2015年度总结
- 计算机提示无法验证发布者,win10 ie11提示由于无法验证发布者所以windows已经阻止此软件怎么办...
- 微商城分销系统开发方式需求与价格开发周期评估
- CAN总线通信学习笔记
- 快递单号查询免费api接口(PHP示例)
热门文章
- SQLite中的高级SQL
- 数据结构与算法(C++)– 动态规划(Dynamic Programming)
- 【Python】Python“表情包”工具包真好用
- 【CV论文解读】AAAI2021 | 在图卷积网络中超越低频信息
- 现在的计算机专业(比如机器学习)已经沦为调包专业了吗?
- 【职场】程序员摆地摊都能月入过万,是真的吗?
- Yoshua Bengio等图神经网络的新基准Benchmarking Graph Neural Networks(代码已开源)
- 复现经典:《统计学习方法》第 7 章 支持向量机
- ML 自学者周刊:第 3 期
- 推荐一个python学习的宝库(github的star数71000+)