TF:TF下CNN实现mnist数据集预测 96%采用placeholder用法+2层C及其max_pool法+隐藏层dropout法+输出层softmax法+目标函数cross_entropy法+AdamOptimizer算法

目录

输出结果

代码设计


输出结果

后期更新……

代码设计

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# number 1 to 10 data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)def compute_accuracy(v_xs, v_ys):global predictiony_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1})correct_prediction = tf.equal(tf.argmax(y_pre,1), tf.argmax(v_ys,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: 1})return resultdef weight_variable(shape):      initial = tf.truncated_normal(shape, stddev=0.1) return tf.Variable(initial)
def bias_variable(shape):        initial = tf.constant(0.1, shape=shape)      return tf.Variable(initial)def conv2d(x, W):                 return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME') def max_pool_2x2(x):              return tf.nn.max_pool(x, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')xs = tf.placeholder(tf.float32, [None, 784]) # 28x28
ys = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
x_image = tf.reshape(xs, [-1, 28, 28, 1])  ## conv1 layer;
W_conv1 = weight_variable([5,5, 1,32])
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)                        W_conv2 = weight_variable([5,5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)                          W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)              W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
prediction = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2) # the error between prediction and real data
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),reduction_indices=[1]))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)    sess = tf.Session()
# important step
sess.run(tf.global_variables_initializer())for i in range(10):  batch_xs, batch_ys = mnist.train.next_batch(100)sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5}) if i % 50 == 0:print(compute_accuracy(mnist.test.images, mnist.test.labels))

相关文章
TF:TF下CNN实现mnist数据集预测 96%采用placeholder用法+2层C及其max_pool法+隐藏层dropout法+输出层softmax法+目标函数cross_entropy法+AdamOptimizer算法

TF之CNN:CNN实现mnist数据集预测 96%采用placeholder用法+2层C及其max_pool法+隐藏层dropout法+输出层softmax法+目标函数cross_entropy法+相关推荐

  1. Pytorch:手把手教你搭建简单的卷积神经网络(CNN),实现MNIST数据集分类任务

    关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!! 可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行 第一步:基本库的导入 import n ...

  2. DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本

    DL之DCGNN:基于TF利用DCGAN实现在MNIST数据集上训练生成新样本 目录 输出结果 设计思路 实现部分代码 说明:所有图片文件丢失 输出结果 更新-- 设计思路 更新-- 实现部分代码 更 ...

  3. Tensorflow—CNN应用于MNIST数据集分类

    代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_datamnist = input_ ...

  4. 机器学习|卷积神经网络(CNN) 手写体识别 (MNIST)入门

    人工智能,机器学习,监督学习,神经网络,无论哪一个都是非常大的话题,都覆盖到可能就成一本书了,所以这篇文档只会包含在 RT-Thread 物联网操作系统,上面加载 MNIST 手写体识别模型相关的部分 ...

  5. Pytorch 实现全连接神经网络/卷积神经网络训练MNIST数据集,并将训练好的模型在制作自己的手写图片数据集上测试

    使用教程 代码下载地址:点我下载 模型在训练过程中会自动显示训练进度,如果您的pytorch是CPU版本的,代码会自动选择CPU训练,如果有cuda,则会选择GPU训练. 项目目录说明: CNN文件夹 ...

  6. Python实现bp神经网络识别MNIST数据集

    title: "Python实现bp神经网络识别MNIST数据集" date: 2018-06-18T14:01:49+08:00 tags: [""] cat ...

  7. MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试

    文章目录 MATLAB实现基于BP神经网络的手写数字识别+GUI界面+mnist数据集测试 一.题目要求 二.完整的目录结构说明 三.Mnist数据集及数据格式转换 四.BP神经网络相关知识 4.1 ...

  8. Python神经网络识别手写数字-MNIST数据集

    Python神经网络识别手写数字-MNIST数据集 一.手写数字集-MNIST 二.数据预处理 输入数据处理 输出数据处理 三.神经网络的结构选择 四.训练网络 测试网络 测试正确率的函数 五.完整的 ...

  9. 【深度学习的数学】接“2×3×1层带sigmoid激活函数的神经网络感知机对三角形平面的分类训练预测”,输出层加偏置b

    文章目录 代码 接:[深度学习的数学]2×3×1层带sigmoid激活函数的神经网络感知机对三角形平面的分类训练预测(绘制出模型结果三维图展示效果)(梯度下降法+最小二乘法+激活函数sigmoid+误 ...

最新文章

  1. 如何从 100 亿 URL 中找出相同的 URL?
  2. Python 查重,统计重复 排序
  3. 图像混合模式:Android Paint Xfermode 使用和demo
  4. oracle instr函数 收藏
  5. editor修改样式 vue_手摸手Electron + Vue实战教程(三)
  6. 卡通驱动项目ThreeDPoseTracker——关键点平滑方案解析
  7. Python办公自动化,对文件进行自由操作
  8. String.valueOf()方法的使用总结
  9. Kaggle TMDB电影数据分析项目实战
  10. 写Java要用什么编译器最好?
  11. Linux 计划任务crontab详解,含笔试题讲解
  12. 爬虫实现对于百度文库内容的爬取
  13. centos7创建asm磁盘_centos7下安装oracle rac使用udev绑定磁盘方法
  14. ResponseEntity和ResponseBody比较
  15. 如何零基础转行成为一个自信的前端达人
  16. php 中 href,html中href什么意思
  17. QEMU(3) 参数解析
  18. python networkx 边权重_Python/NetworkX:动态计算边权重
  19. 微信小程序案例学习笔记
  20. PHP 短信验证码验证(短信宝)

热门文章

  1. wiki----为用户设置管理员权限
  2. 如何为从1到10万用户的应用程序,设计不同的扩展方案?
  3. 图解:从单个服务器扩展到百万用户的系统
  4. 搭建rabbitmq的docker集群
  5. SpringBoot应用之消息队列rabbitmq
  6. Android --- Retrofit 上传/下载文件扩展实现进度的监听
  7. org.hibernate.transientobjectexception:The given object has a null identifier: com.gxuwz.check.entit
  8. ui和android有联系,Android单位换算与UI适配
  9. springboot 整合mybatis_SpringBoot整合MyBatis框架快速入门
  10. zookeeper是做什么用的_做橱柜用什么门板好 选对很关键