Tensorflow是一个非常好用的deep learning框架

学完了cs231n,大概就可以写一个CNN做一下MNIST了

tensorflow具体原理可以参见它的官方文档

然后CNN的原理可以直接学习cs231n的课程。

另外这份代码本地跑得奇慢。。估计用gpu会快很多。

import loaddata
import tensorflow as tf#生成指定大小符合标准差为0.1的正态分布的矩阵
def weight_variable(shape):initial = tf.truncated_normal(shape, stddev = 0.1)return tf.Variable(initial)#生成偏移变量
def bias_variable(shape):initial = tf.constant(0.1, shape=shape)return tf.Variable(initial)#做W与x的卷积运算,跨度为1,zero-padding补全边界(使得最后结果大小一致)
def conv2d(x, W):return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')#做2x2的max池化运算,使结果缩小4倍(面积上)
def max_pool_2x2(x):return tf.nn.max_pool(x, ksize = [1, 2, 2, 1],strides=[1, 2, 2, 1], padding = 'SAME')#导入数据
mnist = loaddata.read_data_sets('MNIST_data', one_hot=True)x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])#filter取5x5的范围,因为mnist为单色,所以第三维是1,卷积层的深度为32
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])#将输入图像变成28*28*1的形式,来进行卷积
x_image = tf.reshape(x, [-1, 28, 28, 1])#卷积运算,activation为relu
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)#池化运算
h_pool1 = max_pool_2x2(h_conv1)#第二个卷积层,深度为64,filter仍然取5x5
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])#做同样的运算
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)#full-connected层,将7*7*64个神经元fc到1024个神经元上去
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])#将h_pool2(池化后的结果)打平后,进行fc运算
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)#防止过拟合,fc层进行dropout处理,参数为0.5
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)#第二个fc层,将1024个神经元fc到10个最终结果上去(分别对应0~9)
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])#最后结果
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)#误差函数使用交叉熵
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))#梯度下降使用adam算法
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)#正确率处理
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))#初始化
sess = tf.Session()
sess.run(tf.initialize_all_variables())#进行训练
for i in range(20000):batch = mnist.train.next_batch(50)if i%100 == 0:train_accuracy = sess.run(accuracy, feed_dict = {x:batch[0], y_:batch[1], keep_prob : 1.0})print("step %d, accuracy %g" % (i, train_accuracy))sess.run(train_step, feed_dict={x:batch[0], y_:batch[1], keep_prob:0.5})#输出最终结果
print(sess.run(accuracy, feed_dict={x:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))

转载于:https://www.cnblogs.com/Saurus/p/7487720.html

Tensorflow框架初尝试————搭建卷积神经网络做MNIST问题相关推荐

  1. (二)Tensorflow搭建卷积神经网络实现MNIST手写字体识别及预测

    1 搭建卷积神经网络 1.0 网络结构 图1.0 卷积网络结构 1.2 网络分析 序号 网络层 描述 1 卷积层 一张原始图像(28, 28, 1),batch=1,经过卷积处理,得到图像特征(28, ...

  2. [转载] 卷积神经网络做mnist数据集识别

    参考链接: 卷积神经网络在mnist数据集上的应用 Python TensorFlow是一个非常强大的用来做大规模数值计算的库.其所擅长的任务之一就是实现以及训练深度神经网络. 在本教程中,我们将学到 ...

  3. 【深度学习】Tensorflow搭建卷积神经网络实现情绪识别

    [深度学习]Tensorflow搭建卷积神经网络实现情绪识别 文章目录 1 Tensorflow的基本使用方法1.1 计算图1.2 Feed1.3 Fetch1.4 其他解释 2 训练一个Tensor ...

  4. 从零开始用TensorFlow搭建卷积神经网络

     https://www.jiqizhixin.com/articles/2017-08-29-14 机器之心GitHub项目:从零开始用TensorFlow搭建卷积神经网络 By 蒋思源2017 ...

  5. Tensorflow入门到实战五(卷积神经网络)

    方法定义 tf.nn.conv2d (input, filter, strides, padding, use_cudnn_on_gpu=None, data_format=None, name=No ...

  6. 用PyTorch搭建卷积神经网络

    用PyTorch搭建卷积神经网络 本篇是加拿大McGill本科,Waterloo硕士林羿实习时所作的工作.发文共享,主要是面对PyTorch的初学者. 本篇文章是一篇基础向的PyTorch教程,适合有 ...

  7. Python图像识别实战(四):搭建卷积神经网络进行图像二分类(附源码和实现效果)

    前面我介绍了可视化的一些方法以及机器学习在预测方面的应用,分为分类问题(预测值是离散型)和回归问题(预测值是连续型)(具体见之前的文章). 从本期开始,我将做一个关于图像识别的系列文章,让读者慢慢理解 ...

  8. Python-深度学习-学习笔记(13):keras搭建卷积神经网络(对二维数据进行一维卷积)

    Python-深度学习-学习笔记(13):keras搭建卷积神经网络(对二维数据进行一维卷积) 卷积神经网络进行图像分类是深度学习关于图像处理的一个应用,卷积神经网络的优点是能够直接与图像像素进行卷积 ...

  9. [转]Theano下用CNN(卷积神经网络)做车牌中文字符OCR

    Theano下用CNN(卷积神经网络)做车牌中文字符OCR 原文地址:http://m.blog.csdn.net/article/details?id=50989742 之前时间一直在看 Micha ...

最新文章

  1. 面试 -- Java基础(一)
  2. 《直播疑难杂症排查》之四:延时高
  3. matlab时频分析工具箱安装_EEG时频分析介绍与实现(基于EEGLAB、NetStation与Analyzer2软件)...
  4. opensource项目_最佳Opensource.com:艺术与设计
  5. Vrep中将物体变得透明的方法
  6. ctfshow-萌新-web11( 利用命令执行漏洞获取网站敏感文件)
  7. mysql 8.0 ~ 安装
  8. HDU1875 畅通工程再续【Kruskal算法+并查集】
  9. 【NOIP2005】【Luogu1046】陶陶摘苹果
  10. maven 项目 spring mvc + jdbc 配置文件
  11. 小米笔记本 镜像_小米笔记本Pro Windows 10 原装系统镜像
  12. 计算机关机怎么按,按什么键电脑关机
  13. 计算机控制器如何调用打印机,怎样设置打印机的虚拟usb端口
  14. UE4 InputMode无法锁定编辑器视口鼠标解决方案
  15. mysql 分区 线性hash_MySQL表分区(3)哈希分区-hash
  16. 小程序高级电商前端第1周走进Web全栈工程师一----小程序注册、开发工具推荐、《风袖》首页布局详尽分析、Webstorm开发小程序必配配置、mock数据...
  17. Logit-Probit:非线性模型中交互项的边际效应解读
  18. C/C++图书管理系统[2023-02-04]
  19. vue中下载图片到本地
  20. 微信小程序-编写图标的方法

热门文章

  1. python列表反向取值_Python列表的反向遍历,python,逆序
  2. hcia是什么等级的证书_华为的HCNA,HCNP,HCIE认证证书都有什么用?
  3. 爬虫模拟登陆手机验证码_Python+scrapy爬虫之模拟登陆
  4. 管理服务器一般的作用,管理服务器作用
  5. python3 lambda函数字典排序_排序字典表理解中的lambda函数
  6. linux本地时间与utc不一致_Linux Windows 双系统时间不一致
  7. 手电筒java_Java鼠标“手电筒”效果如何?
  8. java throwable用法_java – ExceptionHandler不能与Throwable一起使用
  9. java连接摄像头_Java实现 海康摄像头抓拍图像(示例代码)
  10. 用c实现部分java数组功能,很烂,留个参考吧