因为有打算想要写一组关于零样本学习算法的博客,需要用到AWA2数据集作为demo演示

之前想只展示算法部分的代码就好了,但是如果只展示算法部分的代码可能不方便初学者复现,所以这里把我数据预处理的方法也说一下,博客的最后会给一个处理好的数据下载地址,之后的博客都会利用该博客的方法作为数据预处理

我会对AWA2数据集做一个详细的介绍,对数据集有一个好的理解本身也有助于算法的学习和实现

AWA2 图像数据集下载地址:http://cvml.ist.ac.at/AwA2/
数据集比较大有13个G,下载可能得花点时间

这里还有几篇零样本算法的介绍和复现:
DeepLearning | Semantic Autoencoder for Zero Shot Learning 零样本学习 (论文、算法、数据集、代码)
DeepLearning | Relational Knowledge Transfer for Zero Shot Learning 零样本学习(论文、算法、数据集、代码)

目录

  • 一、AWA2数据集简介
    • 1.1 classes.txt
    • 1.2 JPGEImages
    • 1.3 licenses
    • 1.4 predicate-matrix-binary.txt
    • 1.5 predicate-matrix-continuous.txt
    • 1.6 predict-matrix.png
    • 1.7 predicate.txt
    • 1.8 README-attributes.txt和README-images.txt
    • 1.9 testclass.txt
    • 1.10 trainclasses.txt
  • 二、数据集处理
    • 2.1 图片读取
    • 2.2 准备属性标签
    • 2.3 使用预训练的resnet101提取图片特征
  • 三、处理完毕的数据
  • 四、资源下载

一、AWA2数据集简介

该数据集是C. H. Lampert 等人在 Zero-Shot Learning - A Comprehensive Evaluation of the Good, the Bad and the Ugly上公布的动物识别数据集,该数据集一共包含以下几个文件

接下来我会一一对这些文件进行介绍

1.1 classes.txt

该文件记录了数据集所包含的动物种类,共50种,注意,该文件我稍微做了修改,源文件是没有+号的如6这里,这么做是为了写法保持一致。源数据集有的文件里用了加号,有的没用,这里统一了一下

1.2 JPGEImages

该文件夹包含了数据集的所有图片数据,格式如下,每一个子文件夹包含一种动物的图片

1.3 licenses

该文件夹包含每一张图片的授权,这个文件我们在处理时是用不到的

1.4 predicate-matrix-binary.txt

该文件记录了50种动物,每一种动物的85种属性特征情况,是一个50x85的矩阵,1表示有该特征,0表示无,如下

1.5 predicate-matrix-continuous.txt

和 predicate-matrix-binary.txt 文件一样,记录了50种动物,每一种动物的85种属性特征情况,只是该矩阵对属性的描述用的是连续数字

1.6 predict-matrix.png

文件 predicate-matrix-binary.txt 的图形化

1.7 predicate.txt

该文件记录了85种预测的属性分别是什么

1.8 README-attributes.txt和README-images.txt

这两个说明文件对我们也是没有用的

1.9 testclass.txt

该文件说明了哪些动物是测试种类,共10个测试类别

1.10 trainclasses.txt

该文件说明了哪些动物是训练种类,共40个训练类别

这里就介绍完了数据集的全部文件,简而言之,数据集包含50个种类动物的37322张图片,训练集40类30337张图片,测试集10类6985张图片

二、数据集处理

2.1 图片读取

这一步我们需要将图片统一大小为224x224x3,并为数据集制作相应的标签,代码如下

import pandas as pd
import os
import numpy as np
import cv2
from PIL import Imageimage_size = 224          # 指定图片大小
path = '/Users/zhuxiaoxiansheng/Desktop/Animals_with_Attributes2/'   #文件读取路径classname = pd.read_csv(path+'classes.txt',header=None,sep = '\t')
dic_class2name = {classname.index[i]:classname.loc[i][1] for i in range(classname.shape[0])}
dic_name2class = {classname.loc[i][1]:classname.index[i] for i in range(classname.shape[0])}
# 两个字典,记录标签信息,分别是数字对应到文字,文字对应到数字#根据目录读取一类图像,read_num指定每一类读取多少图片,图片大小统一为image_size
def load_Img(imgDir,read_num = 'max'):imgs = os.listdir(imgDir)imgs = np.ravel(pd.DataFrame(imgs).sort_values(by=0).values)if read_num == 'max':imgNum = len(imgs)else:imgNum = read_numdata = np.empty((imgNum,image_size,image_size,3),dtype="float32")print(imgNum)for i in range (imgNum):img = Image.open(imgDir+"/"+imgs[i])arr = np.asarray(img,dtype="float32")if arr.shape[1] > arr.shape[0]:arr = cv2.copyMakeBorder(arr,int((arr.shape[1]-arr.shape[0])/2),int((arr.shape[1]-arr.shape[0])/2),0,0,cv2.BORDER_CONSTANT,value=0)else:arr = cv2.copyMakeBorder(arr,0,0,int((arr.shape[0]-arr.shape[1])/2),int((arr.shape[0]-arr.shape[1])/2),cv2.BORDER_CONSTANT,value=0)       #长宽不一致时,用padding使长宽一致arr = cv2.resize(arr,(image_size,image_size))if len(arr.shape) == 2:temp = np.empty((image_size,image_size,3))temp[:,:,0] = arrtemp[:,:,1] = arrtemp[:,:,2] = arrarr = temp        data[i,:,:,:] = arrreturn data,imgNum  #读取数据
def load_data(train_classes,test_classes,num):read_num = numtraindata_list = []trainlabel_list = []testdata_list = []testlabel_list = []    for item in train_classes.iloc[:,0].values.tolist():tup = load_Img(path+'JPEGImages/'+item,read_num=read_num)traindata_list.append(tup[0])trainlabel_list += [dic_name2class[item]]*tup[1]for item in test_classes.iloc[:,0].values.tolist():tup = load_Img(path+'JPEGImages/'+item,read_num=read_num)testdata_list.append(tup[0])testlabel_list += [dic_name2class[item]]*tup[1]      return np.row_stack(traindata_list),np.array(trainlabel_list),np.row_stack(testdata_list),np.array(testlabel_list)train_classes = pd.read_csv(path+'trainclasses.txt',header=None)
test_classes = pd.read_csv(path+'testclasses.txt',header=None)traindata,trainlabel,testdata,testlabel = load_data(train_classes,test_classes,num='max')print(traindata.shape,trainlabel.shape,testdata.shape,testlabel.shape)#降图像和标签保存为numpy数组,下次可以直接读取
np.save(path+'AWA2_224_traindata.npy',traindata)
np.save(path+'AWA2_224_testdata.npy',testdata)np.save(path+'AWA2_trainlabel.npy',trainlabel)
np.save(path+'AWA2_testlabel.npy',testlabel)

2.2 准备属性标签

刚刚我们读取了数据并制作了0-49的数字标签,但光是数字标签在零样本学习中是不足的,我们还需要每一张图片与其对应的属性标签
下面制作了连续属性的标签,同样的方法还可以制作离散(01)属性的标签,还可以将连续属性规范到0-1范围内作为标签,这些代码不再重复,处理好的标签会在最后的链接中统一给出

import pandas as pd
import numpy as nppath = '/Users/zhuxiaoxiansheng/Desktop/Animals_with_Attributes2/'def make_attribute_label(trainlabel,testlabel):  attribut_bmatrix = pd.read_csv(path+'predicate-matrix-continuous.txt',header=None,sep = ',')trainlabel = pd.DataFrame(trainlabel).set_index(0)testlabel = pd.DataFrame(testlabel).set_index(0)return trainlabel.join(attribut_bmatrix),testlabel.join(attribut_bmatrix)trainlabel = np.load(path+'AWA2_trainlabel.npy')
testlabel = np.load(path+'AWA2_testlabel.npy')train_attributelabel,test_attributelabel = make_attribute_label(trainlabel,testlabel)np.save(path+'AWA2_train_continuous_attributelabel.npy',train_attributelabel.values)
np.save(path+'AWA2_test_continuous_attributelabel.npy',test_attributelabel.values)

2.3 使用预训练的resnet101提取图片特征

在零样本学习中,很多情况下,我们不会直接使用图片本身,使用卷积网络提取出的特征会更加方便

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
import torch
import torchvision
from torchvision import datasets, models,transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader,Dataset
from tqdm import tqdm
from torch import nn,optim
import lightgbm as lgb
import warnings
warnings.filterwarnings("ignore")
from sklearn.linear_model import LogisticRegressionpath = '/Users/zhuxiaoxiansheng/Desktop/Animals_with_Attributes2/'classname = pd.read_csv(path+'classes.txt',header=None,sep = '\t')
dic_class2name = {classname.index[i]:classname.loc[i][1] for i in range(classname.shape[0])}
dic_name2class = {classname.loc[i][1]:classname.index[i] for i in range(classname.shape[0])}def make_test_attributetable():    #制作测试10类的属性表attribut_bmatrix = pd.read_csv(path+'predicate-matrix-binary.txt',header=None,sep = ' ')test_classes = pd.read_csv(path+'testclasses.txt',header=None)test_classes_flag = []for item in test_classes.iloc[:,0].values.tolist():test_classes_flag.append(dic_name2class[item])return attribut_bmatrix.iloc[test_classes_flag,:]class dataset(Dataset):def __init__(self,data,label,transform):super().__init__()self.data = dataself.label = labelself.transform = transformdef __getitem__(self,index):return self.transform(self.data[index]),self.label[index]def __len__(self):return self.data.shape[0] class FeatureExtractor(nn.Module):def __init__(self, submodule, extracted_layers):super(FeatureExtractor,self).__init__()self.submodule = submoduleself.extracted_layers= extracted_layersdef forward(self, x):outputs = []for name, module in self.submodule._modules.items():if name is "fc": x = x.view(x.size(0), -1)x = module(x)if name in self.extracted_layers:outputs.append(x)return outputstraindata = np.load(path+'AWA2_224_traindata.npy')
trainlabel = np.load(path+'AWA2_trainlabel.npy')
train_attributelabel = np.load(path+'AWA2_train_attributelabel.npy')testdata = np.load(path+'AWA2_224_testdata.npy')
testlabel = np.load(path+'AWA2_testlabel.npy')
test_attributelabel = np.load(path+'AWA2_test_attributelabel.npy')print(traindata.shape,trainlabel.shape,train_attributelabel.shape)
print(testdata.shape,testlabel.shape,test_attributelabel.shape)data_tf = transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])])train_dataset = dataset(traindata,trainlabel,data_tf)
test_dataset = dataset(testdata,testlabel,data_tf)train_loader = DataLoader(train_dataset,batch_size=1,shuffle=False)
test_loader = DataLoader(test_dataset,batch_size=1,shuffle=False)model = models.resnet101(pretrained=True)     #使用训练好的resnet101(if torch.cuda.is_available():model=model.cuda()model.eval()exact_list = ['avgpool']    #提取最后一层池化层的输出作为图像特征
myexactor = FeatureExtractor(model,exact_list)train_feature_list = []
for data in tqdm(train_loader):img,label = data if torch.cuda.is_available():with torch.no_grad():img = Variable(img).cuda()with torch.no_grad():     label = Variable(label).cuda()else:with torch.no_grad():img = Variable(img)with torch.no_grad():label = Variable(label)  feature = myexactor(img)[0]feature = feature.resize(feature.shape[0],feature.shape[1])train_feature_list.append(feature.detach().cpu().numpy()) trainfeatures = np.row_stack(train_feature_list) test_feature_list = []
for data in tqdm(test_loader):img,label = data if torch.cuda.is_available():with torch.no_grad():img = Variable(img).cuda()with torch.no_grad():     label = Variable(label).cuda()else:with torch.no_grad():img = Variable(img)with torch.no_grad():label = Variable(label)  feature = myexactor(img)[0]feature = feature.resize(feature.shape[0],feature.shape[1])test_feature_list.append(feature.detach().cpu().numpy()) testfeatures = np.row_stack(test_feature_list)  print(trainfeatures.shape,testfeatures.shape)

三、处理完毕的数据

上面已经介绍了一些基本的处理方法和数据,在之后介绍算法的过程中,数据会直接拿来使用,处理好的数据下载链接如下:

AWA2_trainlabel https://pan.baidu.com/s/1d08IninWz7FATJrDL6DsDA
AWA2_testlabel https://pan.baidu.com/s/1j-GOTYMB2DfaLPH_FziRxQ
resnet101_trainfeatures https://pan.baidu.com/s/10OwVXFVDJMneNFNZlYygew
resnet101_testfeatures https://pan.baidu.com/s/1UT5roIJm9dGb3BMr1mVyQQ
AWA2_train_attributelabel.npy https://pan.baidu.com/s/1xgzJBwCRiOjOKSm13IY3kQ
AWA2_test_attributelabel.npy https://pan.baidu.com/s/1UwtQmDlFJTLvFc71xkFZ6A
AWA2_train_continuous_01_attributelabel.npy https://pan.baidu.com/s/1_31wEQZO81-8kJjANFwdeA
AWA2_test_continuous_01_attributelabel.npy https://pan.baidu.com/s/1at2El02-JCmD-1SrKhQMeA

四、资源下载

微信搜索“老和山算法指南”获取更多下载链接与技术交流群

有问题可以私信博主,点赞关注的一般都会回复,一起努力,谢谢支持。

DeepLearning | Zero shot learning 零样本学习AWA2 图像数据集预处理相关推荐

  1. DeepLearning | Zero Shot Learning 零样本学习(扩展内容、模型、数据集)

    之前写过一篇关于零样本学习的博客,当时写的比较浅.后来导师让我弄个ppt去给本科生做一个关于Zero Shot Learning 的报告,我重新总结了一下,添加了一些新的内容,讲课的效果应该还不错,这 ...

  2. Zero-shot Learning零样本学习 论文阅读(五)——DeViSE:A Deep Visual-Semantic Embedding Model

    Zero-shot Learning零样本学习 论文阅读(五)--DeViSE:A Deep Visual-Semantic Embedding Model 背景 Skip-gram 算法 算法思路 ...

  3. Zero-shot Learning零样本学习 论文阅读(一)——Learning to detect unseen object classes by between-class attribute

    Zero-shot Learning零样本学习 论文阅读(一)--Learning to detect unseen object classes by between-class attribute ...

  4. few-shot learning 1.1——零样本学习

    few-shot learning 1.1--初识零样本学习 1. 什么是few-shot learning 小样本学习问题是指只给定目标少量训练样本的条件下,如何训练一个可以有效地识别这些目标的机器 ...

  5. 近期必读的6篇 NeurIPS 2019 的零样本学习(Zero-Shot Learning)论文

    近期必读的6篇 NeurIPS 2019 的零样本学习(Zero-Shot Learning)论文 PS:转发自"专知"公众号 [导读]NeurIPS 是全球最受瞩目的AI.机器学 ...

  6. ChatGPT基础知识系列之零样本学习( Zero-Short learning)

    ChatGPT基础知识系列之零次学习( Zero-Short learning) 顾名思义,在训练分类器的时候可以不需要A类物体样本就能在测试时识别A类物体,咋一看,很玄乎,其实并没有.在具体解释思路 ...

  7. 论文浅尝 | 当知识图谱遇上零样本学习——零样本学习综述

    随着监督学习在机器学习领域取得的巨大发展,如何减少人工在样本方面的处理工作,以及如何使模型快速适应层出不穷的新样本,成为亟待解决的问题.零样本学习(Zero-Shot Learning, ZSL)的提 ...

  8. 现代NLP中的零样本学习

    2020-07-01 11:19:35 作者:Joe Davison 编译:ronghuaiyang 导读 使用最新的NLP技术来进行零样本学习的一些进展和工作. 自然语言处理现在是一个非常令人兴奋的 ...

  9. 干货!基于层次适应的零样本学习

    点击蓝字 关注我们 AI TIME欢迎每一位AI爱好者的加入! 针对视觉-语义异构特征难对准的问题,我们提出一种层次视觉-语义层次适应的学习模型,通过同时进行结构对准和分布对准,学习一个具有结构和分布 ...

最新文章

  1. Spring(五):SpringStruts2Hibernate整合后,实现查询Employee信息
  2. Nginx防盗链详细设置
  3. boost::hana::cartesian_product用法的测试程序
  4. 【附10】kibana创建新的index patterns
  5. hmailserver mysql密码_第二步:点晴MIS系统Email模块hMailServer数据库连接配置指引
  6. task2 EDA数据分析
  7. oracle数据库服务器启动后需执行的命令(SecureCRT中执行)
  8. Oracle实现数据不存在则插入,数据存在则更新(insert or update)
  9. c#截取两个指定字符串中间的字符串
  10. git SSH 公钥拉取代码(使用及配置)
  11. 天思经理人ERP塑胶玩具行业应用方案
  12. Speedpdf——无需下载软件,就可以免费轻松搞定CAJ转word
  13. 计算机中什么不可打印,电脑不能打印怎么办
  14. word 添加批注 标题向右移动 解决方法
  15. bit(比特)和byte(字节的关系)以及现实应用场景(存储单位和网速单位)
  16. 缓动动画_核心动画概念:缓入缓出
  17. 最新小月云匿名短信系统源码V2.0+全新UI的
  18. 开源分布式存储系统的对比
  19. 桌面多出一个IE图标无法删除的解决办法
  20. maven jar包瘦身

热门文章

  1. springcloud 微服务鉴权_springcloud 微服务权限校验JWT模式获取 token 实战(十二)...
  2. PHP采集程序原理分析篇
  3. a c++ bloger
  4. 转载---编写高质量代码:改善Java程序的151个建议(第3章:类、对象及方法___建议41~46)...
  5. OpenFOAM动网格技术介绍【转载】
  6. 关于深拷贝和浅拷贝的一些思考
  7. Matlab 多项式拟合
  8. 【爬虫+算法】爬取成都地铁所有站点信息,并基于迪杰特斯拉算法计算最优地铁路线
  9. PAT (Basic Level) Practice (中文)1070 结绳 (25 分) 凌宸1642
  10. 3分钟告诉你什么是CDP系统