在 Kaggle 上面的 Notebook 给可爱的学弟学妹们用于参考... 代码这个东西一定要自己多写,我一边听着林宥嘉的《想自由》,一边写出了大致的实现。K 近邻算法大概做的是一件什么事情呢?你去商店买衣服的时候,突然忘记了自己要买的衣服多大尺码比较合适(S/M/L/XL 这种)。这个时候你就要找几个身材和你差不多的几个店内顾客问一问了,结果你发现你这样的身材的人大多买的是 XL 的衣服,所以你最后告诉老板你也买 XL 的衣服,果然是机智聪明啊。

下面是解决该问题的思路重点:

  • 身材和你相似的人的建议比较有参考价值。(排序取前 K 个最相似的)
  • 你是怎么判断其它人身材和你的相似程度的?(距离度量方式)
  • 你最终参考的是最多被购买的那一类尺码。(毕竟不同的人建议可能不同)

关于数据集的读入

MNIST 数据集可以在这里获取:THE MNIST DATABASE of handwritten digits . 你一定很好奇——为什么在 Kaggle 里面数据集是 CSV 格式,而在数据集官网提供的是四个压缩文件?这没什么好稀奇的,你只需要根据不同的数据格式采用不同的数据读取套路就好了,只要最后的数据格式的维度一致即可。

数据集已经解压

def load_mnist(path, kind='train'):"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte'% kind)with open(labels_path,'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with open(images_path,'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels

数据集未解压

def load_mnist(path, kind='train'):"""Load MNIST data from `path`"""labels_path = os.path.join(path,'%s-labels-idx1-ubyte.gz'% kind)images_path = os.path.join(path,'%s-images-idx3-ubyte.gz'% kind)with gzip.open(labels_path, 'rb') as lbpath:labels = np.frombuffer(lbpath.read(), dtype=np.uint8,offset=8)with gzip.open(images_path, 'rb') as imgpath:images = np.frombuffer(imgpath.read(), dtype=np.uint8,offset=16).reshape(len(labels), 784)return images, labels

数据集已经被转为 CSV 格式且无测试集标签(Kaggle)

# It will takes about 1 ~ 2 minutes (depends on CPU)
train_data = np.genfromtxt('../input/train.csv', delimiter=',',skip_header=1).astype(np.dtype('uint8'))
X_train = train_data[:,1:]
y_train = train_data[:,:1]X_test = np.genfromtxt('../input/test.csv', delimiter=',',skip_header=1).astype(np.dtype('uint8'))

检查数据导入是否顺利

indices 即在训练集索引列表中随机取了 9 个值,用于预览图片的导入效果。这个 list 你也可以自己指定。

np.random.seed(0);
indices = list(np.random.randint(m_train, size=9))
for i in range(9):plt.subplot(3,3,i + 1)plt.imshow(X_train[indices[i]].reshape(28,28), cmap='gray', interpolation='none')plt.title("Index {} Class {}".format(indices[i], y_train[indices[i]]))plt.tight_layout()

定义距离度量

之前我们说了,你需要找几个身材和你相似的人,那么你就可以根据三维的属性定义一个距离度量公式。而对于图片来说,图片的相似度计算方法有很多,下面只给出了两种常见的距离度量公式。

def euclidean_distance(vector1, vector2):return np.sqrt(np.sum(np.power(vector1 - vector2, 2)))
def absolute_distance(vector1, vector2):return np.sum(np.absolute(vector1 - vector2))

找相似邻居(K Neighbours)

和你身材相似的人的尺码参考价值比较高,但是也不可轻信一家之言,所以你需要找和你身材最相似的 K 个人,以避免出现意外。

import operator
def get_neighbours(X_train, test_instance, k):distances = []neighbors = []for i in range(0, X_train.shape[0]):dist = euclidean_distance(X_train[i], test_instance)distances.append((i, dist))distances.sort(key=operator.itemgetter(1))for x in range(k):# print(distances[x])neighbors.append(distances[x][0])return neighbors

得到投票最多的建议

大家众说纷纭,最好的决策方法自然就是看最多数的人给出的建议了。

def predictkNNClass(output, y_train):classVotes = {}for i in range(len(output)):# print(output[i], y_train[output[i]])if y_train[output[i]][0] in classVotes:classVotes[y_train[output[i]][0]] += 1else:classVotes[y_train[output[i]][0]] = 1sortedVotes = sorted(classVotes.items(), key=operator.itemgetter(1), reverse=True)# print(sortedVotes)return sortedVotes[0][0]

拿实例来进行测试

从测试集中取出第 667 张图片来测试,找出训练集中和它最相似的 9 (K=9)张图片,其中 6 张图片的类别为 “1” ,那么这张图片分类为 “1” 的概率最大。

instance_num = 666
k = 9
plt.imshow(X_test[instance_num].reshape(28,28), cmap='gray', interpolation='none')
instance_neighbours = get_neighbours(X_train, X_test[instance_num], 9)
indices = instance_neighbours
for i in range(9):plt.subplot(3,3,i + 1)plt.imshow(X_train[indices[i]].reshape(28,28), cmap='gray', interpolation='none')plt.title("Index {} Class {}".format(indices[i], y_train[indices[i]]))plt.tight_layout()
predictkNNClass(instance_neighbours, y_train)


CSV 的导出(Kaggle)

import csv
submit = pd.DataFrame(columns=('ImageId', 'Label'))
for i in range(5):  # change 5 to X_test.shape[0] will takes a long long long ... TIME!neighbours = get_neighbours(X_train, X_test[i], 20)label = predictkNNClass(neighbours, y_train)submit.loc[i]={'ImageId': i + 1,'Label': label}
submit.to_csv('csv_to_submit.csv', index = False)

最后算法给出的预测类别也是 “1” ,但是 KNN 算法的准确率其实并不高,不信你就换换上面的 instance_num ,自行测试一下预测结果。另外一个需要注意的问题是 K 的取值,这个可以算作是一种超参数,具体取什么值就见仁见智了。至于我为什么没有对训练集全部运行 KNN 导出 CSV 文件去得到在 MNIST 的预测准确率,请尝试计算一下上面算法所需的时间复杂度... 这个时候你就能体会到 Scikit-Learn 的优点了,毕竟我们的代码实现只能算是一个 demo,这里有一篇 推荐博文 。

转载于:https://www.cnblogs.com/accepteddoge/p/mnist-knn-with-numpy.html

K 近邻算法识别手写数字(Numpy写法)相关推荐

  1. k近邻算法_图穷匕见:K近邻算法与手写数字识别

    机器学习算法是从数据中产生模型,也就是进行学习的算法.我们把经验提供给算法,它就能够根据经验数据产生模型.在面对新的情况时,模型就会为我们提供判断(预测)结果.例如,我们根据"个子高.腿长. ...

  2. 【机器学习实战】k近邻算法实战——手写识别系统

    文章目录 手写识别系统 步骤: 准备数据:将图像转换为测试向量 测试算法:使用k-近邻算法识别手写数字 [完整代码] 手写识别系统 为了简单起见,这里构造的系统只能识别数字0到9,参见图2-6.需要识 ...

  3. python与人工智能:KNN近邻法识别手写数字

    机器学习分类? 1 特征(feature) 数据是区分事物和事物的关键. 举例:不同类型的书,我们用书的内容来对它进行分类 2 标签(label) 数据的标签,显示的分类结果. 举例:书属于的类别,例 ...

  4. OpenCV-Python实战(番外篇)——利用 KNN 算法识别手写数字

    OpenCV-Python实战(番外篇)--利用 KNN 算法识别手写数字 前言 手写数字数据集 MNIST 介绍 基准模型--利用 KNN 算法识别手写数字 改进模型1--参数 K 对识别手写数字精 ...

  5. 机器学习实战之k-近邻算法识别手写数字(含拍照检验步骤详解)

    哈哈,这是我写的第一篇博客,就此拉开了我的程序员生涯的序幕.希望有缘人看见之后,能够解决你所遇见的问题.废话不多说,开始办正事. 本例中使用K-近邻算法识别手写数字,参考书目:Peter Harrin ...

  6. OpenCV-Python实战(番外篇)——利用 SVM 算法识别手写数字

    OpenCV-Python实战(番外篇)--利用 SVM 算法识别手写数字 前言 使用 SVM 进行手写数字识别 参数 C 和 γ 对识别手写数字精确度的影响 完整代码 相关链接 前言 支持向量机 ( ...

  7. 基于K近邻法的手写数字图像识别

                           数字图像处理课程论文                          题目:数字图像识别   摘要 模式识别(PatternRecognition)是一 ...

  8. opencv(python)使用knn最近邻算法识别手写数字

    knn最近邻算法是一种分类以及回归算法,算法原理是一个样本与样本集中k个样本最相似,如果这k个样本的大多数也属于同一个类别,则该样本也属于这一类.关于knn算法的详细原理读者可以在网上找一些资料了解下 ...

  9. opencv(python)使用svm算法识别手写数字

    svm算法是一种使用超平面将数据进行分类的算法. 关于mnist数据的解析,读者可以自己从网上下载相应压缩文件,用python自己编写解析代码,由于这里主要研究knn算法,为了图简单,直接使用Kera ...

最新文章

  1. 面试官:Java中 serialVersionUID 的作用是什么?举个例子说明
  2. dataframe筛选数据根据某一个列的数据在另外的一个数组中
  3. 【转】UGUI研究院之缓存策略让UI打开更快(三十)
  4. c++ static 关键字用法
  5. linux安装apache+mysql+php3.8练习环境
  6. toj 4613 Number of Battlefields
  7. 未能找到文件“\bin\roslyn\csc.exe”
  8. 命名集 —— 绰号昵称篇
  9. Android Studio中快捷键实现try catch等功能包含代码块
  10. vs2005下载,中文版,官方
  11. 2013 Office安装aurora公式编辑器
  12. 算法11 抓住波粒二象性的火星人
  13. word背景图片设置a4纸大小教程
  14. 在c# winform 的 monthCalendar 里粗体凸显有数据的日期
  15. NR-PRACH接受端如何检测出preambleid和TA的
  16. A hard puzzle(HDU1097)(快速幂取模)
  17. 外文翻译原文附在后面_外文翻译及外文原文(参考格式).doc
  18. 图神经网络:GAT学习、理解、入坑
  19. MY资源网址整合记录
  20. 秒懂Excel的三种引用

热门文章

  1. codeforces 1598 A
  2. 简读《SASE安全访问边缘白皮书》| 了解SASE的核心技术及应用场景
  3. Excel数据分析案例三——预测销量
  4. 建筑工地人脸识别门禁通道闸机如何安装 1
  5. 移动IP技术概述(转)
  6. 产品概念之4/4:产品包 —— 升维思考,降维打击
  7. 畅捷通T+Cloud给客户一站式的产品体验
  8. C++程序设计同步实践宝典——前言
  9. 游戏后台生成唯一ID
  10. python用matplotlib画玫瑰_用Python matplotlib 怎么画风向玫瑰图 ,能给出程序的?