import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import keras.backend.tensorflow_backend as KTFdef add_layer(inputs,in_size,out_size,activation_function=None):#Weights是一个矩阵,[行,列]为[in_size,out_size]Weights=tf.Variable(tf.random_normal([in_size,out_size]))#正态分布#初始值推荐不为0,所以加上0.1,一行,out_size列biases=tf.Variable(tf.zeros([1,out_size])+0.1)#Weights*x+b的初始化的值,也就是未激活的值Wx_plus_b=tf.matmul(inputs,Weights)+biases#激活if activation_function is None:#激活函数为None,也就是线性函数outputs=Wx_plus_belse:outputs=activation_function(Wx_plus_b)return outputsdef compute_accuracy(prediction, xs, ys, sess, v_xs,v_ys):y_pre=sess.run(prediction,feed_dict={xs:v_xs})correct_prediction=tf.equal(tf.arg_max(y_pre,1),tf.arg_max(v_ys,1))accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))result=sess.run(accuracy,feed_dict={xs:v_xs,ys:v_ys})return resultdef train_test_mnist():mnist=input_data.read_data_sets('MNIST_data',one_hot=True)# define placeholder for inputs to networks# 不规定有多少个sample,但是每个sample大小为784(28*28)xs=tf.placeholder(tf.float32,[None,784])ys=tf.placeholder(tf.float32,[None,10])#add output layerprediction=add_layer(xs,784,10,activation_function=tf.nn.softmax)#the error between prediction and real datacross_entropy=tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1]))train_strp=tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)config = tf.ConfigProto()config.gpu_options.allow_growth = True  # 不全部占满显存, 按需分配config.gpu_options.per_process_gpu_memory_fraction = 0.6  #限制GPU内存占用率init=tf.global_variables_initializer()sess = tf.Session(config=config)KTF.set_session(sess)  # 设置sessionif True:#with tf.Session() as sess:sess.run(init)for i in range(2000):batch_xs,batch_ys=mnist.train.next_batch(100)sess.run(train_strp,feed_dict={xs:batch_xs,ys:batch_ys})if i%20==0:print("accuracy:", compute_accuracy(prediction, xs, ys, sess, mnist.test.images, mnist.test.labels))def train_test_mnist_visual():#define placeholder for inputs to networkxs=tf.placeholder(tf.float32,[None,64])ys=tf.placeholder(tf.float32,[None,10])#add output layer# l1为隐藏层,为了更加看出overfitting,所以输出给了100l1=add_layer(xs,64,100,'l1',activation_function=tf.nn.tanh)prediction=add_layer(l1,100,10,'l2',activation_function=tf.nn.softmax)def main():train_test_mnist()if __name__ == '__main__':main()

tensorflow mnist 1相关推荐

  1. 逻辑回归实现多分类任务(python+TensorFlow+mnist)

    逻辑回归实现多分类任务(python+TensorFlow+mnist) 逻辑回归是统计学中的一种经典方法,虽然叫回归,但在机器学习领域,逻辑回归通常情况下当成一个分类任务,softmax就是由其演变 ...

  2. Tensorflow mnist 数据集测试代码 + 自己下载数据

    https://blog.csdn.net/weixin_39673686/article/details/81068582 import tensorflow as tf from tensorfl ...

  3. Tensorflow— MNIST数据集分类简单版本

    代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data#载入数据集 #当前路径 m ...

  4. TensorFlow MNIST最佳实践

    之前通过CNN进行的MNIST训练识别成功率已经很高了,不过每次运行都需要消耗很多的时间.在实际使用的时候,每次都要选经过训练后在进行识别那就太不方便了. 所以我们学习一下如何将训练习得的参数保存起来 ...

  5. Tensorflow MNIST浅层神经网络的解释和答复

    本系列文章由 @yhl_leo 出品,转载请注明出处. 文章链接: http://blog.csdn.net/yhl_leo/article/details/51416540 看到之前的一篇博文:深入 ...

  6. TensorFlow MNIST初级学习

    2019独角兽企业重金招聘Python工程师标准>>> MNIST MNIST 是一个入门级计算机视觉数据集,包含了很多手写数字图片,如图所示: 数据集中包含了图片和对应的标注,在 ...

  7. TensorFlow MNIST LeNet 模型持久化

    前向传播过程mnist_inference.py import tensorflow as tf# 定义神经网络相关的参数 INPUT_NODE = 784 OUTPUT_NODE = 10def i ...

  8. TensorFlow MNIST AlexNet

    原始的AlexNet用来处理277*277*3的数据集, 并且采用5层卷积,3层全连接层来处理图像分类. 具体结构和参数信息见 http://blog.csdn.net/chenhaifeng2016 ...

  9. TensorFlow MNIST TensorBoard版本

    代码 import tensorflow as tf import numpy as np from tensorflow.examples.tutorials.mnist import input_ ...

最新文章

  1. 一个爬虫的故事:这是人干的事儿?
  2. React学习实例总结,包含yeoman安装、webpack构建
  3. 用NiceTool在微信浏览器中下载APP
  4. Linux系统编程之进程与线程控制原语对比
  5. 在docker镜像中加入环境变量
  6. python基于web可视化_Python Selenium实现无可视化界面
  7. springmvc mybatis 做分页sql 语句
  8. [转]Groovy和Grails简介
  9. ROS学习手记 - 5 理解ROS中的基本概念_Services and Parameters
  10. 【Linux】SecureCRT中按退格键出现^H
  11. [转载] 【C/C++】Vector的创建与初始化方法
  12. Atitit 编程语言语言规范总结 目录 1. 语言规范 3 2. Types 3 2.1.1. Primitive types 3 2.1.2. Compound types 4 3. State
  13. C语言程序设计第三版微课版,C语言程序设计(第3版 微课版)
  14. 草根站长的创业路:说说这两年的创业经历
  15. cp105b linux 驱动,cp105b驱动下载-富士施乐cp105b驱动下载v2.6.15.0 官方最新版-西西软件下载...
  16. matlab欧式期权定价公式,[转载]期权定价的Matlab实现(以欧式看涨期权为例)
  17. ScrollView 吸顶效果
  18. 在 Microsoft Visual Studio Team System 和 Microsoft Visual SourceSafe 之间选择
  19. Windows10系统变成英文如何切换回中文,Ctrl+Shift无法切换输入法
  20. 手把手带你做一个Python打飞机游戏

热门文章

  1. mysql 表的继承,MySQL是否支持表继承?
  2. Java培训深度学习都要学什么
  3. 软件测试培训需要学习什么
  4. session,cookie,sessionStorage,localStorage的区别及应用场景
  5. AI时代:推荐引擎正在塑造人类
  6. Robotium todolist.test.elements
  7. Dokku和Docker的完美配合
  8. Android系统移植与调试之-------如何修改Android设备添加重启、飞行模式、静音模式等功能(一)...
  9. 39个超实用jQuery实例应用特效
  10. java中List深拷贝的简单实例