本笔记系列以斯坦福大学CS231N课程为大纲,海豚浏览器每周组织一次授课和习题答疑。具体时间地点请见微信公众号黑斑马团队(zero_zebra)和QQ群(142961883)发布。同时课程通过腾讯课堂(百纳公开课)进行视频直播.欢迎参与学习.

在CS231N的第二课,k-nearest neighbor这部分中,核心就是计算训练集与测试集之元素之间的欧氏距离。课后作业要求从训练集取5000个图像,测试集取了500个图像,计 算这5000个用于训练的图像与500个用于测试的图像之间的欧氏距离,其结果就是一个5000*500的距离矩阵。

课后作业总共留了三道关于距离矩阵的计算题,分别是由易到难,从使用二重循环到不使用循环。特别是不使用循环的方法,需要一点数学基础,不是特别直观。经过研究《Numpy/Scipy Recipes for Data Science: Computing Nearest Neighbors》之后,决定把其中的方法写出来。

任务定义

给定

阶矩阵
,满足
。这里第
列向量是
维向量。求
矩阵,使得:

计算方法

这里提供4种方法,需要使用到以下Python库:

import numpy as np
import numpy.linalg as la

第一种方法:使用两重循环

def compute_squared_EDM_method(X):# determin dimensions of data matrixm,n = X.shape# initialize squared EDM DD = np.zeros([n, n])# iterate over upper triangle of Dfor i in range(n):for j in range(i+1, n):D[i,j] = la.norm(X[:, i] - X[:, j])**2D[j,i] = D[i,j]    #*1
return D

由于是计算矩阵自身行向量之间的距离,所以结果是一个对称的三角矩阵。注意*1行代码处所做的优化。

在上述方法中我们使用了两层循环,因此代码虽不简洁,但十分易懂。

第二种方法

在第一种方法中,我们使用了numpy的norm这个方法,这个方法从数学上讲,其计算公式是:

然后我们又将这个计算结果平方后赋给

因此,我们实际上是在计算:

上述运算可以使用点积(即矩阵内积)来计算:

D[i,j] = np.dot(X[:,i]-X[:,j],(X[:,i]-X[:,j]).T)

现在代码变化为:

def compute_squared_EDM_method2(X):# determin dimensions of data matrixm,n = X.shape# initialize squared EDM DD = np.zeros([n, n])# iterate over upper triangle of Dfor i in range(n):for j in range(i+1, n):d = X[:,i] - X[:,j]D[i,j] = np.dot(d, d)D[j,i] = D[i,j]return D

第三种方法:避免循环内的点积运算

注意在上面的方法中,dot运算被调用了

次,并且每次进行了
次乘积运算和
次加法运算。尽管numpy底层可能对点积运算做了优化,但这里还是存在可能进行进一步优化。请看下面的数学推导:

这里

属于格拉姆矩阵中的元素,可以通过在循环外计算矩阵,在循环内直接引用元素值即可,从而在循环内我们只需要做两次加(减)法运算:

格拉姆矩阵的求法很简单,只需要:

现在代码变为:

def compute_squared_EDM_method3(X):# determin dimensions of data matrixm,n = X.shape# compute Gram matrixG = np.dot(X.T, X)# initialize squared EDM DD = np.zeros([n, n])# iterate over upper triangle of Dfor i in range(n):for j in range(i+1, n):d = X[:,i] - X[:,j]D[i,j] = G[i,i] - 2 * G[i,j] + G[j,j]D[j,i] = D[i,j]return D

第四种方法:避免循环

假设距离矩阵可以表示为

,与公式
进行对比,有:

这里H中第i行的每一个元素,取值都为

,也就是H的每一列,都对应着格拉姆矩阵的对角阵,因此,我们可以用下面的代码来计算H:
H = np.tile(np.diag(G), (n,1))

此外,由于

,所以最终距离矩阵可以计算为

现在,代码不再需要循环了:

def compute_squared_EDM_method4(x):m,n = X.shapeG = np.dot(X.T, X)H = np.tile(np.diag(G), (n,1))return H + H.T - 2*G

扩展:任意行(列)同秩矩阵间的距离矩阵计算

上述方法解决了矩阵自身(列)向量之间的距离矩阵运算问题。对CS231N第二讲的课程作业来说,需要求解的问题是:

给定训练矩阵A为

阶矩阵。这里5000 代表5000幅带标签的图,3072是其各像素在RGB三个通道下的取值数。给定测试集矩阵B为
阶矩阵。求矩阵B的各行与矩阵A的各行的距离(即两幅图的差异)矩阵,这个矩阵是一个
的矩阵。

更一般地,这个问题可以描述如下:

给定矩阵A为

阶矩阵,矩阵B为
阶矩阵,求矩阵B的任意行向量与矩阵A的任意行向量的距离矩阵
。这个矩阵的数学表达式为(a, b均为行向量):

为方便讨论,我们将上述各项分别记为H, M, N, K,即:

显然上述公式是无法进行运算的,因为除了M与D外,其它矩阵的秩各不相同。所以我们要回到前一个数学表达式上。

  1. H对D的贡献是对于D的每一行,都加上

  2. K对于D的贡献是对于D的每一列,都加上
  3. M和N互为转置矩阵。即对
    ,要减去矩阵
    元素,而这个元素就是

因此,可以在numpy中运用broadcasting机制,通过矩阵与行向量、列向量的运算传播机制(broadcasting)来完成计算

def compute_distances_no_loops(A, B):m = np.shape(A)[0]n = np.shape(B)[0]dists = np.zeros((m, n)) # 求得矩阵M为 m*n维M = np.dot(A, B.T)# 对于H,我们只需要A.A^T的对角线元素,下面的方法高效求解(只计算对角线元素)# 结果H为m维行向量H = np.square(A).sum(axis = 1)#结果K为n维行向量.要将其元素运用到矩阵M的每一列,需要将其转置为行向量K = np.square(B).sum(axis = 1)#H对M在y轴方向上传播,即H加和到M上的第一行,K对M在x轴方向上传播,即K加和到M上的每一列D = np.sqrt(-2*M+H+np.matrix(K).T)return D

谈谈Numpy的broadcasting

在numpy中,当一个数与矩阵相加时,实际上是将矩阵中的每一个元素都加上这个数。注意在矩阵代数里,这种相加是不允许的。但实际应用中又很常见,所以Numpy就扩展了这个定义,允许一个实数矩阵相加,这就是broadcasting。

broadcasting在工程中是非常实用的。在第四种方法中,没有使用 broadcasting机制,它是先取格拉姆矩阵的对角线元素(是n维列向量),再通过np.tile运算将其扩展为一个

的矩阵,然后才能完成相加的操作。而在扩展一节中,我们引入了行方向上(axis = 1)的传播(与K的相加),也引入了列方向上(axis = 0)的传播(与H的相加)。正因为有了这个传播,我们无须象第四种方法那样,显式地生成一个
矩阵

下面是broadcasting的规则:

  1. 让所有输入数组都向其中shape最长的数组看齐,shape中不足的部分都通过在前面加1补齐
  2. 输出数组的shape是输入数组shape的各个轴上的最大值
  3. 如果输入数组的某个轴和输出数组的对应轴的长度相同或者其长度为1时,这个数组能够用来计算,否则出错
  4. 当输入数组的某个轴的长度为1时,沿着此轴运算时都用此轴上的第一组值。

这里看一个例子:

C = np.arange(0,3)
D = np.arange(0, 40, 10).reshape(-1,1)

这样生成的C和D分别是3维行向量和4*1阶矩阵。

>>> C
array([0, 1, 2])
>>> D
array([[ 0],[10],[20],[30]])

如果计算C+D,结果如何?

  1. 让输入数组向shape最长(维度最高)的数组看齐,shape中不足的部分都通过在前面加1补齐。这里D是二维数组,所以C被reshape为 array([[0,1,2]])。
  2. 根据规则2,确定运算的输出数组维度为4*3。
  3. C与D都有一个轴同秩,所以可以计算。
  4. 将D沿 axis = 0的方向(即列方向)从上到下,逐一加上C的元素。依次得到 [0,1,2],[10,11,12]...[30,31,32]

课程导航

上一课:斯坦福CS231N课程学习笔记(二).理解CIFAR-10图像数据库

下一课

三元组法矩阵加法java_计算机视觉学习笔记(2.1)-KNN算法中距离矩阵的计算相关推荐

  1. 学习笔记——Kaggle_Digit Recognizer (KNN算法 Python实现)

    本文是个人学习笔记,该篇主要学习KNN算法理论和应用范围,并应用KNN算法解决Kaggle入门级Digit Recognizer,也是个人入坑ML和Kaggle的开端,希望能够有始有终. KNN算法 ...

  2. 图像坐标:我想和世界坐标谈谈(A) 【计算机视觉学习笔记--双目视觉几何框架系列】

    玉米竭力用轻松具体的描述来讲述双目三维重建中的一些数学问题.希望这样的方式让大家以一个轻松的心态阅读玉米的<计算机视觉学习笔记>双目视觉数学架构系列博客.这个系列博客旨在捋顺一下已标定的双 ...

  3. python矩阵左除_matlab学习笔记

    Matlab学习笔记 运算: 1.     算术运算(在矩阵意义下进行) +:要求矩阵同型,对应元素相加减,如果用标量和矩阵相加减,不同型就凉凉提示错误,那就将矩阵每个元素和数字相加减 -:同上 *: ...

  4. 使用Excel分析数据学习笔记之 二分类与混淆矩阵

    使用Excel分析数据学习笔记之 二分类与混淆矩阵 混淆矩阵的构成: e.g.1:Bombers and seagulls 案例背景 混淆矩阵 如何根据混淆矩阵得到ROC曲线? 如何设定最佳阈值(op ...

  5. Apollo星火计划学习笔记——Apollo开放空间规划算法原理与实践

    文章目录 前言 1. 开放空间规划算法总体介绍 1.1 Task: OPEN_SPACE_ROI_DECIDER 1.2 Task: OPEN_SPACE_TRAJECTORY_PROVIDER 1. ...

  6. 【学习笔记】多项式相关算法

    [学习笔记]多项式相关算法 手动博客搬家: 本文发表于20181125 13:19:28, 原地址https://blog.csdn.net/suncongbo/article/details/844 ...

  7. Boost库学习笔记(二)算法模块-C++11标准

    Boost库学习笔记(二)算法模块-C++11标准 一.综述 Boost.Algorithm是一系列人通用推荐算法的集合,虽然有用的通用算法很多,但是为了保证质量和体积,并不会将太多通用算法通过审查测 ...

  8. 【学习笔记】目标检测算法总结

    [学习笔记]目标检测算法总结 说明 MacOS操作系统. MindNote思维导图软件. B站学习视频+原论文学习. 初学者 笔记 如有问题请多多指教 记录 Overfeat模型.R-CNN.Fast ...

  9. C++ Primer 学习笔记 第十章 泛型算法

    C++ Primer 学习笔记 第十章 泛型算法 336 find函数 #include <iostream> #include <vector> #include <s ...

最新文章

  1. fftw_plan_dft_2d优化
  2. 他奶奶的,我要再不写技术文章,找工作都没有说服力!
  3. ubuntu下载软件安装包
  4. MyBatis 实际使用案例-plugins
  5. html dot标签,html – CSS Dot符号命名约定
  6. Behave用户自定义数据类型
  7. 【待完善】MongoDB - 数据模型
  8. sqlite简单笔记
  9. 防护IOS APP安全的几种方式(详解)
  10. 2013年下半年 系统分析师 案例分析真题
  11. 08Spring Boot自定定义配置
  12. 实战案例:抽屉自动点赞与爬取汽车之家新闻
  13. android横竖屏切换布局闪退,Android 横竖屏切换以及横屏启动闪退问题
  14. 四电极体脂称解决方案——测量原理
  15. 登陆163邮箱 验证邮箱帐号密码是否正确
  16. 百度云 不限速 | 2019 最好用下载工具
  17. Android Studio连接使用第三方模拟器
  18. python中标点符号大全及名字_标点符号大全及名字0919.史上最全标点符号英语读法...
  19. yxc_第一章 基础算法(三)_离散化
  20. 计算机培训班价格多少钱?

热门文章

  1. Android 人脸识别进行实名验证demo
  2. bootstrapselect使用 Bootstrap's dropdowns require Popper.js
  3. centos 生产 ssh-key
  4. 360健康助手文件存储位置 获取图片
  5. hadoop join
  6. 云开发0基础训练营第二期热力来袭!
  7. windows下解决pip安装出错问题
  8. Linux课堂笔记-第二天
  9. oracle数据库查看用户相关语句
  10. clientWidth、clientHeight、offsetWidth、offsetHeight以及scrollWidth、scrollHeight