基于BoF算法的图像分类

图像分类一直是计算机视觉中的一个重要问题,BoF(Bag of features)算法在图像分类中具有着重要的作用。本文旨在介绍BoF算法的基本原理和过程并且给出Python代码的实现:用于解决在Caltech 101数据库上的多分类问题。

算法起源

起源1:纹理识别

纹理(texture)是由一些重复的纹理单元(texton)组成的,如图1所示。

我们想要进行纹理的识别,应该关注组成这些纹理的纹理单元的类型,而不是空间的分布。一副纹理图像包含很多种的纹理单元,我们可以将所有可能出现的纹理单元组成一个集合或者说叫做纹理单元字典(texton dictionary),然后统计对于某图像中某纹理单元出现的个数,就可以得到该图像对应的直方图,如图2所示。

显然,这些直方图可以很好地表示原始的纹理图像。假如我们有一堆纹理图像,可以得到一堆这样的直方图,送入某种分类器中进行训练,然后就可以进行纹理的分类了。

起源2:Bag-of-Words模型
Bag-of-Words模型的思想很简单:我们想要了解一段本文的核心内容,最简单直接的方式是找出其中的关键词,然后根据关键词出现的频率来确定该段文本要想表述的意思。

从上图中,我们知道关键词是iraq和terrorists,由此可以推荐该文本的主题与伊拉克的恐怖主义有关。这里所说的关键词,就是Bag-of-Words中的words,它们是区分度较高的单词。根据这些 words,
我们可以很快地识别出文章的内容,并快速地对文章进行分类。

Bags of Features算法

Bag of features算法分为四步:

  1. 提取图像特征;
  2. 对特征进行聚类,得到可视化字典(visual vocabulary);
  3. 根据字典将图片表示成向量,即直方图;
  4. 使用得到的直方图表示的特征进行分类器的训练。

    特征提取
    首先我们从原始图像中提取特征,如图4所示。常用的特征提取方法有SIFT,SURF。SIFT得到的特征描述是128维度的向量,相比SISF,SURF计算量更小些,得到的特征是64维的向量。也有使用HoG和LBP来进行特征提取的。注意特征提取的方法要满足旋转不变性以及尺寸不变性。

字典生成
对所有的图片提取完特征后,将所有的特征进行聚类,比如使用K-Means聚类,得到K类,每个类别看作一个word,这样我们就得到了字典,如下图所示。

直方图表示
上一步训练得到的字典,是为了这一步对图像特征进行量化。对于一幅图像而言,我们可以提取出大量的特征,但这些特征(如SIFT提取的特征)仍然属于一种浅层的表示,缺乏代表性。因此,这一步的目标,是根据字典重新提取图像的高层特征。具体做法是,对于每一张图片得到的每一个特征(如SIFT提取的特征),都可以在字典中找到一个最相似的word(实际上就是将特征输入到得到的聚类模型,得到类别),统计相似的每种word的数量,于是就得到一个K维的直方图。如下图所示。

训练分类器
对于每张图片,我们得到了其对应的直方图向量,当然也知道其对应的属于哪种物品的标记。这样我们就可以构造训练集来训练某种分类器。当需要进行预测时,我们先测试集的图片中提取特征,然后利用字典量化得到直方图,输入训练好的分类器,得到预测的类别。

代码实现

下面让我们一起使用Python来实现基于基于BoF算法的图像分类。首先需要下载数据集Caltech-101。解压后进入caltech101(点击进行下载),再进入其子目录,可以看到有102个文件夹,其中每个文件夹对应一种物品。简单起见,我们使用三种物品:bonsai,ferry和laptop。

数据预处理
在进行Bag-of-Features算法的实现之前,首先我们来读取所需要的图片。

import os"""
功能:读取文件夹中的图片
输入:data_dir:某种物品图片所在的文件夹
输出:imgs:某种物品所有的图片路径
"""
def read_imgs(data_dir):imgs = os.listdir(data_dir)imgs = [data_dir + "/" + img for img in imgs]return imgsdata_dir = 'caltech101/101_ObjectCategories/'
catalog = ['bonsai', 'ferry', 'laptop']imgSet = [read_imgs(data_dir + catalog[0]),read_imgs(data_dir + catalog[1]),read_imgs(data_dir + catalog[2]),]

实现输出代码,输出一下每种物品的数量信息。

print ("Label\t\tcount")
print ("---------------------")
for i, item in enumerate(catalog):print ("%s\t\t%s" %(item, len(imgSet[i])))

输出结果如下。

其中第一列表示物体的种类,第二列表示对应的图片的数量。
在上面的代码基础上,我们进行训练集和测试集数据的划分和生成。

import random
"""
功能:产生训练集和测试集
输入:imgSet:包含所有物品种类的图片路径split:根据split进行划分训练集和测试集,表示训练集的比例
输出:train_datas:训练集数据,列表类型test_datas:测试集数据,列表类型train_labels:训练集标签,列表类型test_labels:测试集标签,列表类型
"""
def make_dataset(imgSet, split):train_datas=[]test_datas = []train_labels = []test_labels = []#用index来表示label,即三种类型物体标签如下:# bonsai --- 0# ferry ---- 1# laptop --- 2for index, item in enumerate(imgSet):random.shuffle(item) #将某种物品数据打乱interval = int(len(item) * split)train_item = item[:interval]test_item = item[interval:]train_datas += train_itemtest_datas += test_itemtrain_labels += [index for _ in range(len(train_item))]test_labels += [index for _ in range(len(test_item))]return train_datas, test_datas, train_labels, test_labelstrain_datas, test_datas ,train_labels, test_labels = make_dataset(imgSet, 0.7)

特征提取
首先我们用一个函数将原始的RGB图转换为灰度图,然后使用OpenCV的SURF算法来进行特征的提取,最后使用几行代码来测试下效果。

import cv2
"""
功能:将一张RGB图转换为灰度图
输入:color_img:RGB图
输出:gray:灰度图
"""
def to_gray(color_img):gray = cv2.cvtColor(color_img, cv2.COLOR_RGB2GRAY)return gray
"""
功能:提取一张灰度图的SURF特征
输入:gray_img:要提取特征的灰度图
输出:key_query:兴趣点desc_query:描述符,即我们最终需要的特征
"""
def gen_surf_features(gray_img):#400表示hessian阈值,一般使用300-500,表征了提取的特征的数量,#值越大得到的特征数量越少,但也越突出。surf = cv2.xfeatures2d.SURF_create(400)key_query, desc_query = surf.detectAndCompute(gray_img, None)return key_query, desc_query#测试gen_surf_features的结果
import matplotlib.pyplot as plt
img = cv2.imread(train_datas[0])
img = to_gray(img)
key_query, desc_query = gen_surf_features(img)
imgOut = cv2.drawKeypoints(img, key_query, None, (255, 0, 0), 4)
plt.imshow(imgOut)
plt.show()

为了展示该阈值的影响,这里我们使用两种不同的Hessian阈值(400和3000)得到两张结果的图示。因为代码中在划分训练集和测试集时进行过随机处理,所以这两张图并不一定是同一物体。

接下来我们来实现一个函数,它可以利用上面已经实验的函数来提取所有的特征。

"""
功能:提取所有图像的SURF特征
输入:imgs:要提取特征的所有图像
输出:img_descs:提取的SURF特征
"""
def gen_all_surf_features(imgs):img_descs = []for item in imgs:img = cv2.imread(item)img = to_gray(img)key_query, desc_query = gen_surf_features(img)img_descs.append(desc_query)return img_descsimg_descs = gen_all_surf_features(train_datas)

至此我们已经完成了特征提取的部分,得到了提取到的SURF特征。接下来进行字典的生成。

字典生成
我们先再来回顾下生成字典的流程,对训练集的所有图片进行特征提取,将提取的所有的特征向量进行聚类,从而得到字典。如下图所示。

import numpy as np
from sklearn.cluster import MiniBatchKMeans"""
功能:提取所有图像的SURF特征
输入:img_descs:提取的SURF特征
输出:img_bow_hist:条形图,即最终的特征cluster_model:训练好的聚类模型
"""
def cluster_features(img_descs, cluster_model):n_clusters = cluster_model.n_clusters #要聚类的种类数#将所有的特征排列成N*D的形式,其中N表示特征数,#D表示特征维度,这里特征维度D=64train_descs = [desc for desc_list in img_descsfor desc in desc_list]train_descs = np.array(train_descs)#转换为numpy的格式#判断D是否为64if train_descs.shape[1] != 64: raise ValueError('期望的SURF特征维度应为64, 实际为', train_descs.shape[1])#训练聚类模型,得到n_clusters个word的字典cluster_model.fit(train_descs)#raw_words是每张图片的SURF特征向量集合,#对每个特征向量得到字典距离最近的wordimg_clustered_words = [cluster_model.predict(raw_words)for raw_words in img_descs]#对每张图得到word数目条形图(即字典中每个word的数量)#即得到我们最终需要的特征img_bow_hist = np.array([np.bincount(clustered_words, minlength=n_clusters)for clustered_words in img_clustered_words])return img_bow_hist, cluster_modelK = 500 #要聚类的数量,即字典的大小(包含的单词数)
cluster_model=MiniBatchKMeans(n_clusters=K, init_size=3*K)
train_datas, cluster_model = cluster_features(img_descs,cluster_model)

经过上述代码(主要是进行聚类分析),对于每张原始图片,我们得到了其对应的最终的特征(直方图)。接下来我们来学习如何进行分类器的训练以及进行结果的预测,得到最终的Accuracy值。

from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC"""
功能:分类
输入:train_datas:训练集,即最终的特征(所有图像的直方图集合),要求是numpy.array类型train_labels:训练集的label,要求是numpy.array类型
输出:classifier:训练好的分类器
"""
def run_svm(train_datas, train_labels):   classifier = OneVsRestClassifier(LinearSVC(random_state=0)).fit(train_datas, train_labels)return classifier#将训练集label转化为numpy.array类型
train_labels = np.array(train_labels)
classifier = run_svm(train_datas, train_labels)

对于分类器的选择我们也可以使用多层感知机或其他的神经网络:

from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.neural_network import MLPClassifier"""
功能:分类
输入:train_datas:训练集,即最终的特征(所有图像的直方图集合),要求是numpy.array类型train_labels:训练集的label,要求是numpy.array类型
输出:classifier:训练好的分类器
"""
def run_svm(train_datas, train_labels):#注释内容:SVM分类器#classifier = OneVsRestClassifier( #    LinearSVC(random_state=0)).fit(#                    train_datas, train_labels)classifier = MLPClassifier(solver='lbfgs', alpha=1e-10,hidden_layer_sizes=(100,),random_state=1).fit(train_datas, train_labels)return classifier

接下来我们来进行预测并得到最终的Accuracy结果。进行预测的过程如下:

  1. 提取每张测试集图像的SURF特征;
  2. 利用训练好的字典得到每张图片的直方图;
  3. 对每张图片的直方输入分类器得到结果;
  4. 计算Accuracy值。

    首先我们来实现一个函数,用来从一张图片得到对应的直方图向量。

"""
功能:将一张图片转化为直方图的形式
输入:img_path:一张图片cluster_model:已经训练好的聚类模型
输出:img_bow_hist:直方图向量
"""
def img_to_vect(img_path, cluster_model):"""Given an image path and a trained clustering model (eg KMeans),generates a feature vector representing that image.Useful for processing new images for a classifier prediction."""img = cv2.imread(img_path)gray = to_gray(img)kp, desc = gen_surf_features(gray)clustered_desc = cluster_model.predict(desc)img_bow_hist = np.bincount(clustered_desc,minlength=cluster_model.n_clusters)#转化为1*K的形式,K为字典的大小,即聚类的类别数return img_bow_hist.reshape(1,-1)

接下来我们来实现最终的测试函数。

"""
功能:对测试集数据进行预测,得到Accuracy
输入:test_datas:测试集数据,要求是numpy.array类型test_labels:测试集label,要求是numpy.array类型
输出:无返回值,输出Accuracy
"""
def test(test_datas, test_labels, cluster_model, classifier):print ("测试集的数量: ", len(test_datas))preds = []for item in test_datas:vect = img_to_vect(item, cluster_model)pred = classifier.predict(vect)preds.append(pred[0])preds = np.array(preds)idx = preds == test_labelsaccuracy = sum(idx)/len(idx)print ("Accuracy是: ", accuracy)test_labels = np.array(test_labels)
test(test_datas, test_labels, cluster_model, classifier)

得到的结果为。

当然每次运行得到的结果会有所差异。

参考:

https://www.cnblogs.com/jermmyhsu/p/8195727.html
https://ww2.mathworks.cn/help/vision/examples/image-category-classification-using-bag-of-features.html
http://www.cs.unc.edu/~lazebnik/spring09/lec18_bag_of_features.pdf

基于BoF算法的图像分类相关推荐

  1. [Python图像处理] 二十六.图像分类原理及基于KNN、朴素贝叶斯算法的图像分类案例

    该系列文章是讲解Python OpenCV图像处理知识,前期主要讲解图像入门.OpenCV基础用法,中期讲解图像处理的各种算法,包括图像锐化算子.图像增强技术.图像分割等,后期结合深度学习研究图像识别 ...

  2. 基于深度残差网络图像分类算法研究综述

    文章从残差网络的设计出发,分析了不同残差单元的构造方式,介绍了深度残差网络不同的变体.从不同角度比较了不同网络之间的差异以及这些网络架构常用图像分类数据集上的性能表现.最后对各种网络进行l总结,并讨论 ...

  3. 使用OpenCV与sklearn实现基于词袋模型的图像分类预测与搜索

    点击上方"小白学视觉",选择加"星标"或"置顶" 重磅干货,第一时间送达 基于OpenCV实现SIFT特征提取与BOW(Bag of Wor ...

  4. python计算机视觉编程——基于BOF的图像检索(附代码)

    图像检索 一.图像检索 1.1基本原理 1.2 BOW原理简述(同BOF原理) 1.3 基于BOF的图像检索基本流程 1.3.1 sift特征提取 1.3.2 建立视觉单词 1.3.3 K-Means ...

  5. 【图像分类】 基于Pytorch的细粒度图像分类实战

    欢迎大家来到<图像分类>专栏,今天讲述基于pytorch的细粒度图像分类实战! 作者&编辑 | 郭冰洋 1 简介 针对传统的多类别图像分类任务,经典的CNN网络已经取得了非常优异的 ...

  6. 基于小样本学习的图像分类技术综述

    基于小样本学习的图像分类技术综述 人工智能技术与咨询 昨天 本文来自<自动化学报>,作者李颖等 关注微信公众号:人工智能技术与咨询.了解更多咨询! 图像分类是一个经典的研究课题, 典型的图 ...

  7. 基于深度神经网络的图像分类与训练系统(MATLAB GUI版,代码+图文详解)

    摘要:本博客详细介绍了基于深度神经网络的图像分类与训练系统的MATLAB实现代码,包括GUI界面和数据集,可选择模型进行图片分类,支持一键训练神经网络.首先介绍了基于GoogleNet.ResNet进 ...

  8. 【Pytorch进阶一】基于LeNet的CIFAR10图像分类

    [Pytorch进阶一]基于LeNet的CIFAR10图像分类 一.LeNet网络介绍 二.CIFAR10数据集介绍 三.程序架构介绍 3.1 LeNet模型(model.py) 3.2 训练(tra ...

  9. 用MindStudio完成基于CTC算法的语音热词唤醒

    Bilibili视频链接: 使用MindStudio完成基于CTC算法的语音热词唤醒_哔哩哔哩_bilibili 一. MindStudio 介绍与安装 相关课程:昇腾全流程开发工具链(MindStu ...

  10. LIME算法:图像分类解释器(代码实现)

    在上一篇博客LIME算法:模型的可解释性(代码实现)中,我整理了LIME算法的原理及在文本分类模型中的应用.在这篇笔记中,我记录了LIME算法在图像分类模型中的应用及过程中遇到的问题和解决方法. 一. ...

最新文章

  1. linux运行脚本报错:/bin/bash^M: bad interpreter: No such file or directory(dos2unix )(/bin/sh^M)(回车符、换行符)
  2. iOS Winding Rules 缠绕规则
  3. sql安装目录下log文件夹_Linux安装Hive数据仓库工具
  4. 手风琴html例子,jquery实现简单手风琴菜单效果实例
  5. 基于360搜图爬取图片
  6. php同时抢购 代码,浅谈PHP实现大流量下抢购方案
  7. 安装MySql报错(This application requires .NET Framework x.x.x)
  8. 计算机视觉图像去噪原理,图像去噪方法研究进展
  9. 漫画 | 硬核技术预测你有没有女朋友
  10. Linux-HA 高可用开源方案 Keepalived VS Heartbeat 的选择
  11. Linux安装GIMP
  12. Altium Designer 17 安装破解版详细教程
  13. python爬淘宝评论源代码_python3爬取淘宝信息代码分析
  14. 从零搭建与好友“一起看王心凌《爱你》MV”功能
  15. css设置div垂直居中
  16. java和vue实现拖拽可视化_可视化拖拽页面编辑器 一__Vue.js
  17. (转载)各类指数基金标的指数比较
  18. mantis apache mysql_Nginx、Apache、PHP、Mantis上传文件和附件大小设置
  19. org.postgresql.util.PSQLException: 不支援 10 验证类型
  20. 搭建超级实用的免费机器翻译api

热门文章

  1. 如何发布百度离线地图及二次开发API
  2. P4556 [Vani有约会]雨天的尾巴 树链剖分 线段树合并
  3. 程序开发,也要匠心独运
  4. mcafee 8.5i杀毒软件规则配置
  5. linux嵌入式reboot不生效,Embeded linux之reboot
  6. wtg linux双系统,Windows和Linux同时装入移动硬盘,实现可移动专属双系统
  7. xpadder教程:自定义设置游戏手柄的图片
  8. 【javafx】如何java查询12306火车票剩余数量
  9. MSE = Bias² + Variance?什么是“好的”统计估计器
  10. 51单片机程序存储器扩展