• 使用预训练的卷积神经网络提取图片中的特征,生成特征向量。
  • 利用图片库中所有图片数据构建 <id, feature vector> 数据。
  • 使用 Faiss 创建 Index ,利用 <id, feature vector> 数据生成索引。
  • 针对待检索图片,使用模型提取图片特征向量,然后使用 Index 检索 TopK 相似图片的 id。
  • 可视化检索结果

1. 导包

import os
import time
import torch
import faiss
import numpy as np
import matplotlib.pyplot as pltfrom PIL import Image
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset%matplotlib inline

GPU 加速

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
# cuda

2.自定义数据集

transform = transforms.Compose([transforms.Resize((256, 256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])class MyDataset(Dataset):def __init__(self, data_path, transform=None):super().__init__()self.transform = transformself.data_path = data_pathself.data = []img_path = os.path.join(data_path, 'img.txt')with open(img_path, 'r', encoding='utf-8') as f:for line in f.readlines():line = line.strip()img_name = os.path.join(data_path, line)img = Image.open(img_name)if img.mode == 'RGB':self.data.append(line)def __getitem__(self, idx):# take the data sample by it's indeximg_path = os.path.join(self.data_path, self.data[idx])# read imageimg = Image.open(img_path)# apply the transformif self.transform:img = self.transform(img)# return the image and indexdict_data = {'index': idx,'img': img}return dict_datadef __len__(self):return len(self.data)
img_folder = 'JPEGImages'
val_dataset = MyDataset(img_folder, transform=transform)
batch_size = 64
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
print('Val_dataset: ', val_dataset.__len__())
print('iter: ', int(val_dataset.__len__()/batch_size)+1)
Val_dataset:  17125
iter:  268

3.预训练模型+自定义特征值提取器

# 加载预训练模型
def load_model():model = models.resnet18(pretrained=True)model.to(device)model.eval()return model# 定义 特征提取器
def feature_extract(model, x):x = model.conv1(x)x = model.bn1(x)x = model.relu(x)x = model.maxpool(x)x = model.layer1(x)x = model.layer2(x)x = model.layer3(x)x = model.layer4(x)x = model.avgpool(x)x = torch.flatten(x, 1)return x
model = load_model()for idx, batch in enumerate(val_dataloader):img = batch['img']  # 图片数据表示 --> 图片特征index = batch['index']img = img.to(device)feature = feature_extract(model, img)feature = feature.data.cpu().numpy()imgs_path = [os.path.join(img_folder, val_dataset.data[i] + '.txt') for i in index]assert len(feature) == len(imgs_path)for i in range(len(imgs_path)):feature_list = [str(f) for f in feature[i]]img_path = imgs_path[i]with open(img_path, 'w', encoding='utf-8') as f:f.write(" ".join(feature_list))print('*' * 60)print(idx * batch_size)

4.图片向量化

# 获取图片特征¶
def img2feat(pic_file):feat = []with open(pic_file, 'r', encoding='utf-8') as f:lines = f.readlines()feat = [float(f) for f in lines[0].split()]return feat
ids = []
data = []img_folder = 'VOC2012'#'VOC2012_small/'
img_path = os.path.join(img_folder,'img.txt')
with open(img_path,'r',encoding='utf-8') as f:for line in f.readlines():img_name = line.strip()img_id = img_name.split('.')[0]pic_txt_file = os.path.join( img_folder,"{}.txt".format(img_name) )if not os.path.exists(pic_txt_file):continuefeat = img2feat(pic_txt_file)ids.append(int(img_id))data.append(np.array(feat))# 构建数据<id,data>
ids = np.array(ids)
data = np.array(data).astype('float32')
d = 512 # feature 特征长度(模型的结果)
print(" 特征向量记录数: ",data.shape)
print(" 特征向量ID的记录数:",ids.shape)特征向量记录数:  (17125, 512)特征向量ID的记录数: (17125,)

5.创建 Faiss 索引 Index

# 创建图片特征索引 - 方案1
# index = faiss.index_factory(d,"IDMap,Flat")
# index.add_with_ids(data,ids)# 创建图片特征索引-方案2(  资源有限,效果更好 )
###IDMap 支持add_with_ids
###如果很在意,使用”PCARx,...,SQ8“ 如果保存全部原始数据的开销太大,可以用这个索引方式。包含三个部分,
# 1.降维
# 2.聚类
# 3.scalar 量化,每个向量编码为8bit 不支持GPU
index = faiss.index_factory(d, "IDMap,PCAR16,IVF50,SQ8")
index.train(data)
index.add_with_ids(data, ids)# 索引文件保存磁盘
faiss.write_index(index,'index_file.index') # 讲index保存index_file.index 的文件
# index = faiss.read_index("index_file.index")
# print(index.ntotal) # 查看索引库大小

加载 Faiss Index 索引文件

index = faiss.read_index('index_file.index')
print('索引记录数:', index.ntotal)
# 索引记录数: 17125

6.Faiss 相似 TopK 检索

def index_search(feat,topK ):"""feat: 检索的图片特征topK: 返回最高topK相似的图片"""feat = np.expand_dims( np.array(feat),axis=0 )feat = feat.astype('float32')start_time = time.time()dis,ind = index.search( feat,topK )end_time = time.time()print( 'index_search consume time:{}ms'.format(  int(end_time - start_time) * 1000  ) )return dis,ind # 距离,相似图片id

7.可视化检索结果

def visual_plot(ind,dis,topK,query_img = None):       # 相似照片cols = 4rows = int(topK / cols)idx = 0fig,axes = plt.subplots(rows,cols,figsize=(20 ,5*rows),tight_layout=True)#axes[0,0].imshow(query_img)for row in range(rows):for col in range(cols):_id = ind[0][idx]_dis = dis[0][idx]img_path = os.path.join(img_folder,'{}.jpg'.format(_id))#print(img_path)if query_img is not None and idx == 0:axes[row,col].imshow(query_img)axes[row,col].set_title( 'query',fontsize = 20  )else:img = plt.imread(  img_path   )axes[row,col].imshow(img)axes[row,col].set_title( 'matched_-{}_{}'.format(_id,int(_dis)) ,fontsize = 20  )idx+=1plt.savefig('pic')
img_folder = 'VOC2012/'
# img_id = '100211.jpg'
img_id = '100002.jpg'
topK = 20
img_path = os.path.join( img_folder,img_id)
print(img_path) # 查看  这个img_path 的相似图片img = Image.open(img_path)
img = transform(img) # torch.Size([3, 224, 224])
img = img.unsqueeze(0) # torch.Size([1, 3, 224, 224])
img = img.to(device)# 对我们的图片进行预测
with torch.no_grad():# 图片-> 图片特征print('1.图片特征提取')feature = feature_extract( model,img )# 特征-> 检索feature_list = feature.data.cpu().tolist()[0]print('2.基于特征的检索,从faiss获取相似度图片')# 相似图片可视化dis,ind = index_search( feature_list,topK=topK )print('ind = ',ind)print('3.图片可视化展示')# 当前图片query_img = plt.imread( img_path )visual_plot( ind,dis,topK,query_img)
VOC2012/100002.jpg
1.图片特征提取
2.基于特征的检索,从faiss获取相似度图片
index_search consume time:0ms
ind =  [[100002 101430 116500 101585 116528 100507 104768 107651 112514 102820112416 116458 106167 111781 116247 103299 103154 106012 115086 111156]]
3.图片可视化展示

基于深度学习的以图搜图相关推荐

  1. 读《基于深度学习的以图搜图技术在照片档案管理中的应用研究_赵学敏》

    论文名称:<基于深度学习的以图搜图技术在照片档案管理中的应用研究_赵学敏> 发表时间:2020年4月 发表期刊:档案学研究(北大核心.CSSCI) 发表单位:云南大学档案馆 愚见 是一个叙 ...

  2. 基于深度学习实现以图搜图功能

    前记: 深度学习的发展使得在此之前以机器学习为主流算法的相关实现变得简单,而且准确率更高,效果更好,在图像检索这一块儿,目前有谷歌的以图搜图,百度的以图搜图,而百度以图搜图的关键技术叫做"感 ...

  3. 以图搜图 图像匹配_图像匹配,基于深度学习DenseNet实现以图搜图功能

    原标题:图像匹配,基于深度学习DenseNet实现以图搜图功能 度学习的发展使得在此之前以机器学习为主流算法的相关实现变得简单,而且准确率更高,效果更好,在图像检索这一块儿,目前有谷歌的以图搜图,百度 ...

  4. 图像匹配,基于深度学习DenseNet实现以图搜图功能

    向AI转型的程序员都关注了这个号???????????? 机器学习AI算法工程   公众号:datayx 度学习的发展使得在此之前以机器学习为主流算法的相关实现变得简单,而且准确率更高,效果更好,在图 ...

  5. python以图搜图api_Python深度学习,手把手教你实现「以图搜图」

    随着深度学习的崛起,极大的推动了图像领域的发展,在提取特征这方面而言,神经网络目前有着不可替代的优势.之前文章中我们也介绍了图像检索往往是基于图像的特征比较,看特征匹配的程度有多少,从而检索出相似度高 ...

  6. python如何实现找图_Python深度学习,手把手教你实现「以图搜图」

    随着深度学习的崛起,极大的推动了图像领域的发展,在提取特征这方面而言,神经网络目前有着不可替代的优势.之前文章中我们也介绍了图像检索往往是基于图像的特征比较,看特征匹配的程度有多少,从而检索出相似度高 ...

  7. 零基础实战行人重识别ReID项目-基于Milvus的以图搜图

    目录 第一阶段,ReID的基本概念 1.1 ReID定义 1.2 技术难点 1.3 常用数据集 1.4 评价指标 1.5 实现思路 1.6 具体方案 第二阶段:复现算法 2.1 PCB的骨干网络 2. ...

  8. 基于内容的图像检索引擎(以图搜图)

    基于内容的图像检索引擎(以图搜图) 本文介绍一些基于内容的图像检索技术(Content-Based Image Retrieval,CBIR)的搜索引擎(即以图搜图),这类搜索引擎基本上代表了图像检索 ...

  9. 如何基于深度学习实现商品识别技术|图普科技

    目前实时客流检测.商品识别.货架识别等人工智能技术可以帮助越来越多的零售门店实现智慧零售数字化转型.随着人工智能技术的发展,图普科技在深度学习在实现商品识别的应用上越发成熟,从技术层面来说,具体包含以 ...

  10. 基于深度学习的目标检测算法思维导图

    在计算机视觉领域,目标检测一直是一种处于非常火热 的状态,尤其是卷积神经网络CNN出现后,出现了各种基于CNN的目标检测算法,在此根据所看到或者所了解的论文进行个人汇总,以思维导图的模式进行记录整理, ...

最新文章

  1. Java,开源,分享
  2. 可能是堆被损坏,这也说明 XX.exe 中或它所加载的任何 DLL 中有 bug
  3. 用费曼技巧自学编程,香不香?
  4. itextpdf 生成word显示不全_Word经常遇到这些偏僻小问题,值的收藏
  5. C++11:继承构造函数
  6. Android基于Docker容器的双系统多开实现和自动化部署
  7. 尽力去帮助一个陌生人
  8. 把iconfront的资源放cdn访问_详解mpvue小程序中怎么引入iconfont字体图标
  9. fastdfs java token_fastdfs-client-java操作fastdfs
  10. 华为P30现身华为新加坡官网:坐实水滴屏
  11. CondaHTTPError问题的解决
  12. oracle 11g 映像文件有效 但不适用于此计算机类型,《计算机应用基础》期末考试模拟练习题(含答案)...
  13. 华为MH5000模块知识应用简介
  14. apache评分表的意义_APACHE评分系统及评分表 -
  15. windows的特殊对话框
  16. big endian和little endian 的区别 ,BOOST_BIG_ENDIAN
  17. python实现求解完美立方等式
  18. 小游戏:HelloColor
  19. 分享6个教师常用的网站,再也不用到处找资源了
  20. python 中的MQTT模块 mqtt-paho的使用

热门文章

  1. SpringCloud项目 CICD 部署
  2. [UPC] 2021秋组队17
  3. 手机流量卡代理第一次做,要做好哪些准备?
  4. [精简]托福核心词汇62
  5. Assembler - 数据段与代码段
  6. 格式化U盘为FAT32
  7. 如何实现用户名或密码错误,弹出重新登录的提示
  8. python excel行列转置_Excel 行列转换的最简方法
  9. linux c蜂鸣器驱动程序,〖Linux〗OK6410a蜂鸣器的驱动程序编写全程实录
  10. python是爬虫的意思吗_python跟爬虫的区别