TensorFlow学习笔记之五——源码分析之最近算法
- import numpy as np
- import tensorflow as tf
- # Import MINST data
- import input_data
- mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
- #这里主要是导入数据,数据通过input_data.py已经下载到/tmp/data/目录之下了,这里下载数据的时候,需要提前用浏览器尝试是否可以打开
- #http://yann.lecun.com/exdb/mnist/,如果打不开,下载数据阶段会报错。而且一旦数据下载中断,需要将之前下载的未完成的数据清空,重新
- #进行下载,否则会出现CRC Check错误。read_data_sets是input_data.py里面的一个函数,主要是将数据解压之后,放到对应的位置。
- # In this example, we limit mnist data
- Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)
- Xte, Yte = mnist.test.next_batch(200) #200 for testing
- #mnist.train.next_batch,其中train和next_batch都是在input_data.py里定义好的数据项和函数。此处主要是取得一定数量的数据。
- # Reshape images to 1D
- Xtr = np.reshape(Xtr, newshape=(-1, 28*28))
- Xte = np.reshape(Xte, newshape=(-1, 28*28))
- #将二维的图像数据一维化,利于后面的相加操作。
- # tf Graph Input
- xtr = tf.placeholder("float", [None, 784])
- xte = tf.placeholder("float", [784])
- #设立两个空的类型,并没有给具体的数据。这也是为了基于这两个类型,去实现部分的graph。
- # Nearest Neighbor calculation using L1 Distance
- # Calculate L1 Distance
- distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)
- # Predict: Get min distance index (Nearest neighbor)
- pred = tf.arg_min(distance, 0)
- #最近邻居算法,算最近的距离的邻居,并且获取该邻居的下标,这里只是基于空的类型,实现的graph,并未进行真实的计算。
- accuracy = 0.
- # Initializing the variables
- init = tf.initialize_all_variables()
- #初始化所有的变量和未分配数值的占位符,这个过程是所有程序中必须做的,否则可能会读出随机数值。
- # Launch the graph
- with tf.Session() as sess:
- sess.run(init)
- # loop over test data
- for i in range(len(Xte)):
- # Get nearest neighbor
- nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i,:]})
- # Get nearest neighbor class label and compare it to its true label
- print "Test", i, "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i])
- # Calculate accuracy
- if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
- accuracy += 1./len(Xte)
- print "Done!"
- print "Accuracy:", accuracy
- #for循环迭代计算每一个测试数据的预测值,并且和真正的值进行对比,并计算精确度。该算法比较经典的是不需要提前训练,直接在测试阶段进行识别。
源代码地址:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2%20-%20Basic%20Classifiers/nearest_neighbor.py
相关API:
tf.reduce_sum(input_tensor, reduction_indices=None, keep_dims=False, name=None)
Computes the sum of elements across dimensions of a tensor.
Reduces input_tensor
along the dimensions given in reduction_indices
. Unless keep_dims
is true, the rank of the tensor is reduced by 1 for each entry in reduction_indices
. If keep_dims
is true, the reduced dimensions are retained with length 1.
If reduction_indices
has no entries, all dimensions are reduced, and a tensor with a single element is returned.
For example:
# 'x' is [[1, 1, 1]
# [1, 1, 1]]
tf.reduce_sum(x) ==> 6
tf.reduce_sum(x, 0) ==> [2, 2, 2]
tf.reduce_sum(x, 1) ==> [3, 3]
tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]]
tf.reduce_sum(x, [0, 1]) ==> 6
Args:
input_tensor
: The tensor to reduce. Should have numeric type.reduction_indices
: The dimensions to reduce. IfNone
(the default), reduces all dimensions.keep_dims
: If true, retains reduced dimensions with length 1.name
: A name for the operation (optional).
Returns:
The reduced tensor.
点评:这个API主要是降维使用,在这个例子中,将测试图片和所有图片相加后的二维矩阵,降为每个图片只有一个最终结果的一维矩阵。
TensorFlow学习笔记之五——源码分析之最近算法相关推荐
- TensorFlow学习笔记之源码分析(3)---- retrain.py
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py ...
- TensorFlow学习笔记之源码分析(3)---- retrain.py(转)
原文地址:http://blog.csdn.net/daydayup_668819/article/details/68060483 https://github.com/tensorflow/ten ...
- sheng的学习笔记-Vector源码分析
概述 Vector底层也是数组,跟ArrayList很像(先看下ArrayList,再看Vector会很轻松),ArrayList可参考下文,并且由于效率低,已经被淘汰了,大概瞅瞅得了 sheng的学 ...
- 循环神经网络(RNN, Recurrent Neural Networks)学习笔记:源码分析(一)
前面帖子给出了RNN的基础理论,里面也提到了神牛Mikolov,这个帖子就基于此牛开源出的一个语言建模工具箱(RNN Language Modeling Tookit)进行代码走读,会加速理解RNN算 ...
- c++imread 函数_OpenCV学习笔记与源码分析: imread( )函数
引言 imread()函数在opencv使用比较. imread()函数 声明: Mat imread(const string& filename, int flags); 这很标准的写法, ...
- FreeRTOS学习笔记---任务控制块源码分析
#define portSTACK_TYPE uint32_t #define portBASE_TYPE long typedef portSTACK_TYPE StackType_t; typed ...
- kademlia java_死磕以太坊源码分析之Kademlia算法
死磕以太坊源码分析之Kademlia算法 KAD 算法概述 Kademlia是一种点对点分布式哈希表(DHT),它在容易出错的环境中也具有可证明的一致性和性能.使用一种基于异或指标的拓扑结构来路由查询 ...
- Java设计模式学习以及底层源码分析
源码在分支master 工厂模式 把具体创建产品的细节封装起来,你要什么产品,我给你什么产品即可. 简单工厂模式 工厂方法模式 缓存层:抽象类 抽象工厂模式 缓存层是:接口 原型模式 问题: 原型模式 ...
- Licode入门学习:MediaStream源码分析(二)
Licode服务与启动过程分析 MediaStream源码分析(一) MediaStream源码分析(二) MediaStream源码分析(三) WebRtcConnection源码分析(一) Web ...
最新文章
- 快速设置戴尔latitude笔记本的触摸板和指点杆
- centos 6.5 安装dotnet core 2.2
- 如何使用Python制作一个会动的地球仪?
- DOM编程以及domReady加载的几种方式
- 论文学习8-How Question Generation Can Help Question Answering over Knowledge Base(KBQA-知识问答)
- c# combobox集合数据不显示_excel打开数据时显示乱码/问号amp;看起来一样却v不出来怎么办...
- 忘记mysql数据库名称_忘记MySQL数据库密码的解决办法
- Spring Cloud微服务之子模块的创建(二)
- 【Java从0到架构师】项目实战 - 前后端分离、后端校验、Swagger、全局异常处理
- C# Winform代码片段-大二下学期的垃圾代码
- 软考 2015 年上半年 网络管理员 上午试卷
- android机器人方向,Android横版过关类游戏推荐《机器人大挑战》
- Android性能优化 _ 大图做帧动画卡?优化帧动画之 SurfaceView滑动窗口式帧复用
- 微信小程序之获取百度天气
- 数字体验词汇表:您需要了解的最重要术语
- LWN: 名为 Sequoia 的 seq_file 漏洞!
- 任天堂超级玛丽(SuperMario)改编的超级企鹅(java)搞笑版,绝对给力
- 3D软件开发工具HOOPS全套产品开发介绍 | HOOPS Visualize、HOOPS Publish
- c语言任伟,任 伟
- 在太平洋人寿保险公司工作好不好?
热门文章
- android tab 悬停效果代码,Android 仿腾讯应用宝 之 Toolbar +Scroolview +tab滑动悬停效果...
- Java查询spark中生成的文件_java+spark-sql查询excel
- 迷宫搜索问题最短路_[源码和文档分享]基于C语言实现的勇闯迷宫游戏
- 「小程序JAVA实战」微信开发者工具helloworld(三)
- MyEclipse图表工具Birt的使用技巧(三)--连接webservice数据源
- AC日记——中庸之道 codevs 2021
- React编写一个简易的评论区组件
- win8/8.1 免密码登录设置
- OCP读书笔记(5) - 使用RMAN创建备份
- 理解Linux系统负荷