什么是零样本学习?最全零样本学习原理解释!

  • 什么是零样本学习
  • 什么时候需要零样本学习
  • 让我们开始吧
    • 实现思想
    • 数据收集
    • 图像特征提取和数据集形成
    • 词嵌入
    • 模型训练
    • 零样本学习模型
    • 零样本模型的评价
  • 总结

什么是零样本学习

零点学习方法的目的是在训练阶段没有收到任何任务的例子的情况下解决一个任务。从给定的图像中识别一个物体的任务,在训练阶段没有任何该物体的例子,可以被视为零点学习任务的一个例子。实际上,它只是让我们能够识别我们以前没有见过的物体。

什么时候需要零样本学习

在传统的物体识别过程中,有必要确定一定数量的物体类别,以便能够以高成功率进行物体识别。也有必要为选定的物体类别收集尽可能多的样本图像。当然,这些样本图像应该包含从不同角度在不同背景/环境下拍摄的物体,这样才是全面的。尽管存在很多我们可以毫不费力地收集样本图像的物体类别,但也存在我们并不总是那么幸运的情况。

想象一下,我们想要识别那些濒临灭绝的动物,或者生活在人类无法随时访问的极端环境(海洋深处/丛林或难以到达的山峰)。要收集这类动物的样本图片并不容易。即使你能收集到足够多的图像,也要记住图像不应该是相似的,它们应该是尽可能独特的 。你需要做出很大的努力来实现这一点。

除了用有限的图像识别不同的物体类别的困难之外,对一些物体类别的标注也不像普通人那样容易做到。在某些情况下,只有在真正掌握了该对象后或在专家在场的情况下才能进行标注。细致的物体识别任务,如鱼种或树种的识别,可以看作是在专家监督下贴标签的例子。一个普通人会把她/他所看到的所有的树称为/标记为树,或把她/他所看到的所有的鱼称为鱼。这些显然是真实的答案,但想象一下,你想训练一个网络,以便识别树或鱼的种类。在这种情况下,所有上述的真实答案都是无用的,你需要一个专家来帮助你完成标签任务。同样,你需要做出大量的努力来实现这一点。

题为《遥感图像中的细粒度物体识别和零点学习》的研究论文是关于该主题的有趣的实际研究之一,在该论文中,仅利用其航空或卫星图像就能识别树木并将其归类为物种,与在大片区域内行走拍摄树木的照片并给它们贴上标签相比,这些图像很难有意义,但很容易收集。

让我们开始吧

在提到什么是零点学习之后,让我们一步一步地实现一个零点学习模型。但在这之前,让我们详细说明一下我们的方法。

实现思想

我们有训练类和零拍类。请记住,在训练过程中不会使用零照类的样本。那么,用训练对象训练的模型究竟如何对零照对象进行识别?简单地说,怎么可能识别出以前从未见过的物体呢?

我们都知道,为了能够应用任何机器学习技术,我们应该用合理的特征表示数据。我们应该使用两个数据表征,其中一个表征应该起到辅助的作用。因此,我们想出了图像嵌入和类嵌入–作为辅助表示–作为我们的两种表示。

图像嵌入并不特别。它是一个使用卷积网络从图像中提取的特征向量。卷积网络可以从头开始实施,也可以使用已经证明其成功的预训练的卷积网络。我们将使用预先训练好的卷积模型–VGG16–用于图像特征提取过程。

记住,我们有训练班和零点班。我们收集训练类的图像样本,自然可以得到所有这些图像样本的图像嵌入。然而,我们没有任何零照类的图像样本–我们不知道它们是什么样子的–所以不可能得到零照类的图像嵌入。这就是零照学习方法与传统方法的不同之处。在这一点上,我们需要另一个数据表示,它将作为训练和零照类之间的桥梁。这个数据表示应该从所有的数据样本中提取出来,而不考虑它们是属于训练类还是零点类。正因为如此,我们不应该关注图像本身,而应该关注类标签,这是所有数据样本的共同属性。

类嵌入是一个类(类标签)的向量表示。它是一种表示,我们可以很容易地访问每一类物体的图像表示。我们将使用谷歌的Word2Vecs作为类嵌入,这将允许我们将单词–类标签–作为向量表示。在Word2Vec空间中,如果两个词–用两个提到的向量表示–倾向于一起出现在同一个文档中或有语义关系,那么两个向量最有可能被紧密地定位。

在上面的例子中,我们可以很容易地观察到,与可食用物体相关的类/词的向量(用白色和绿松石色的盒子表示)在位置上倾向于出现在一起。然而,它们往往与与身体部位相关的类/词的向量(用亮绿色的方框表示)出现在一起,显得很遥远。

总而言之,对于训练类,我们有他们的图像样本和类标签,因此我们有他们的图像嵌入和类嵌入。然而,对于零拍类,我们只有他们的类标签–我们从未见过任何图像样本–因此我们只有他们的类嵌入。通过观察左边的图,可以更清楚地看到这一点。

最后,我们简单地想做的是;我们将使用图像嵌入(图像特征向量)和它们相关的类嵌入(单词Word2Vecs)来训练类。这样,网络基本上会学习如何将一个给定的输入图像映射到位于Word2Vec空间的向量上。训练完成后,当一个属于零拍类的物体的图像被赋予网络时,我们将能够获得一个矢量作为输出。然后,通过使用这个输出向量(测量它与我们拥有的所有类向量的距离–包括训练和零照–),我们将能够进行分类。

数据收集

作为第一项工作,我们需要收集图像数据,这是在训练阶段和评估阶段所需要的,以衡量训练后的Zero-Shot性能。我从Visual Genome收集数据,并决定总共使用20个类,其中有15个类被选为训练类,5个类被选为零点射击类。

然后,我们应该确定哪些对象类将被选为训练类,哪些将被选为Zero-Shot类。为了便于说明,它将更适合于识别日常物体,而不是为细粒度的物体识别任务预演和选择适当的类。

下面展示一些 内联代码片

Train Classes

### SELECTED TRAIN CLASSES FOR ZERO-SHOT LEARNING
arm
boy
bread
chicken
child
computer
ear
house
leg
sandwich
television
truck
vehicle
watch
woman

Zero-Shot Classes

### SELECTED ZERO-SHOT CLASSES FOR ZERO-SHOT LEARNING
car
food
hand
man
neck

图像特征提取和数据集形成

在收集了足够的图像样本(每类400个样本,包括训练和零照)后,现在是时候从这些样本图像中提取特征了。对于这项任务,我们倾向于使用一个预先训练好的图像分类网络–VGG16–它是在ImageNet上训练的。

从视觉基因组中,我们获得了图像及其相应的注释,这些注释指出了图像中物体的位置。对于每张图像,从相应的图像注释中获得图像中出现的物体的坐标,并对物体进行裁剪。然后,使用预先训练好的模型(在我们的例子中是VGG16)从这些裁剪过的图像中提取图像特征。特征提取器类可以看到如下。

def get_model():vgg_model = keras.applications.VGG16(include_top=True, weights='imagenet')vgg_model.layers.pop()vgg_model.layers.pop()inp = vgg_model.inputout = vgg_model.layers[-1].outputmodel = Model(inp, out)return modeldef get_features(model, cropped_image):x = image.img_to_array(cropped_image)x = np.expand_dims(x, axis=0)x = keras.applications.vgg16.preprocess_input(x)features = model.predict(x)return features

词嵌入

在我们提取了图像特征并形成了数据集之后,现在我们应该收集其他类别的表示,即单词嵌入。我们将使用在谷歌新闻文件上训练的谷歌Word2Vec表示。我们将为我们指定的20个对象类别中的每一个得到一个300维的Wor2Vec。

class Word2vec():def __init__(self, model_path=WORD2VECPATH):self.model_path  = model_pathself.model       = gensim.models.KeyedVectors.load_word2vec_format(self.model_path, binary=True)self.vocab       = self.model.vocab.keys()def generate_word2vec(self, words):print('Number of words  : ' + str(len(words)))vectors = list()for i,word in enumerate(words):print(str(i+1) + "\t" + str(word))if ' ' in word:word1, word2 = word.split(' ')if word1 in self.vocab:vec1 = self.model[word1]else:print("Word {} not in vocab".format(word))vectors.append([0])continueif word2 in self.vocab:vec2 = self.model[word2]else:print("Word {} not in vocab".format(word))vectors.append([0])continuevec3 = (vec1 + vec2)/2.vectors.append(vec3)continueif word in self.vocab:vectors.append(self.model[word])else:print("Word {} not in vocab".format(word))vectors.append([0])return vectors

模型训练

模型的结构必须设计成给定的输入(图像特征)应该映射到相应的输出(Word2Vecs)。由于我们已经使用了一个预训练的卷积模型来获得图像特征,现在,我们需要创建一个小型的后续全连接模型。

这里重要的一点是创建一个自定义层,它将是模型的最后一层。这个层的权重必须使用训练类的Word2Vecs来确定,而且这个层必须是不可训练的,这意味着它在训练期间不应该受到梯度更新的影响,保持不变。它将是一个简单的矩阵乘法,放在网络的最后。

开始训练的脚本

### SET HYPERPARAMETERS
global NUM_CLASS, NUM_ATTR, EPOCH, BATCH_SIZE
NUM_CLASS = 15
NUM_ATTR = 300
BATCH_SIZE = 128
EPOCH = 65### TRAINING PHASE
(x_train, x_valid, x_zsl), (y_train, y_valid, y_zsl) = load_data()
model = build_model()
train_model(model, (x_train, y_train), (x_valid, y_valid))

加载数据的函数

def load_data():"""read data, create datasets"""# READ DATAwith gzip.GzipFile(DATAPATH, 'rb') as infile:data = cPickle.load(infile)# ONE-HOT-ENCODE DATAlabel_encoder   = LabelEncoder()label_encoder.fit(train_classes)training_data = [instance for instance in data if instance[0] in train_classes]zero_shot_data = [instance for instance in data if instance[0] not in train_classes]# SHUFFLE TRAINING DATAnp.random.shuffle(training_data)### SPLIT DATA FOR TRAININGtrain_size  = 300train_data  = list()valid_data  = list()for class_label in train_classes:ct = 0for instance in training_data:if instance[0] == class_label:if ct < train_size:train_data.append(instance)ct+=1continuevalid_data.append(instance)# SHUFFLE TRAINING AND VALIDATION DATAnp.random.shuffle(train_data)np.random.shuffle(valid_data)train_data = [(instance[1], to_categorical(label_encoder.transform([instance[0]]), num_classes=15))for instance in train_data]valid_data = [(instance[1], to_categorical(label_encoder.transform([instance[0]]), num_classes=15)) for instance in valid_data]# FORM X_TRAIN AND Y_TRAINx_train, y_train    = zip(*train_data)x_train, y_train    = np.squeeze(np.asarray(x_train)), np.squeeze(np.asarray(y_train))# L2 NORMALIZE X_TRAINx_train = normalize(x_train, norm='l2')# FORM X_VALID AND Y_VALIDx_valid, y_valid = zip(*valid_data)x_valid, y_valid = np.squeeze(np.asarray(x_valid)), np.squeeze(np.asarray(y_valid))# L2 NORMALIZE X_VALIDx_valid = normalize(x_valid, norm='l2')# FORM X_ZSL AND Y_ZSLy_zsl, x_zsl = zip(*zero_shot_data)x_zsl, y_zsl = np.squeeze(np.asarray(x_zsl)), np.squeeze(np.asarray(y_zsl))# L2 NORMALIZE X_ZSLx_zsl = normalize(x_zsl, norm='l2')

用于模型构建和自定义层初始化的功能

def custom_kernel_init(shape):class_vectors       = np.load(WORD2VECPATH)training_vectors    = sorted([(label, vec) for (label, vec) in class_vectors if label in train_classes], key=lambda x: x[0])classnames, vectors = zip(*training_vectors)vectors             = np.asarray(vectors, dtype=np.float)vectors             = vectors.Treturn vectorsdef  build_model():model = Sequential()model.add(Dense(1024, input_shape=(4096,), activation='relu'))model.add(BatchNormalization())model.add(Dropout(0.8))model.add(Dense(512, activation='relu'))model.add(Dropout(0.5))model.add(Dense(256, activation='relu'))model.add(Dense(NUM_ATTR, activation='relu'))model.add(Dense(NUM_CLASS, activation='softmax', trainable=False, kernel_initializer=custom_kernel_init))print("-> model building is completed.")return model

对于分类任务,每一层都使用ReLU作为激活函数,但输出层则使用概率分布–softmax函数。

模型训练功能

def train_model(model, train_data, valid_data):x_train, y_train = train_datax_valid, y_valid = valid_dataadam = Adam(lr=5e-5)model.compile(loss      = 'categorical_crossentropy',optimizer = adam,metrics   = ['categorical_accuracy', 'top_k_categorical_accuracy'])history = model.fit(x_train, y_train,validation_data = (x_valid, y_valid),verbose         = 2,epochs          = EPOCH,batch_size      = BATCH_SIZE,shuffle         = True)print("model training is completed.")return history

下面,可以看到训练阶段的最后一个 epoch 信息。

我们通过使用15个训练班进行训练取得了足够好的分数。

零样本学习模型

请记住,我们计划用Word2Vecs作为桥梁来识别我们以前从未见过的物体类别,我们说模型应该为每个输入图像提供一个矢量输出。为了能够实现这一点,我们需要去掉模型的最后一层,这一层是不可训练的和自定义的。

### CREATE AND SAVE ZSL MODEL
inp         = model.input
out         = model.layers[-2].output
zsl_model   = Model(inp, out)
print(zsl_model.summary())
save_keras_model(zsl_model, model_path=MODELPATH)

现在,在我们删除模型的最后一层后,我们现在可以得到一个300维的向量,它表示的是向量空间中的一个坐标,对于每一个图像输入。我们将把这个输出向量与我们已经拥有的所有20个类的向量进行比较,从而把它映射到最近的一个。

零样本模型的评价

我们创建了 "零点射击 "模型。现在是时候衡量它的性能了。我们将使用我们已经确定的Zero-Shot类别的样本。我们在模型训练的任何阶段都没有使用这些样本。让我们记住这些类别:汽车、食物、手、人和脖子。我们为每个类别收集了400张图片(总共2000张),仅仅通过执行零点分类–使用300维的输出向量–我们就可以衡量模型的性能。

在我们为每个图像样本获得一个向量(Word2Vec)后,我们将这个向量与代表我们每个类别的20个向量进行比较。我们使用欧几里得距离指标来进行比较。然后,我们宣布,属于与我们的输出向量最接近的向量的类别是我们预测的类别。

# EVALUATION OF ZERO-SHOT LEARNING PERFORMANCE
class_vectors       = sorted(np.load(WORD2VECPATH), key=lambda x: x[0])
classnames, vectors = zip(*class_vectors)
classnames          = list(classnames)
vectors             = np.asarray(vectors, dtype=np.float)tree        = KDTree(vectors)
pred_zsl    = zsl_model.predict(x_zsl)top5, top3, top1 = 0, 0, 0
for i, pred in enumerate(pred_zsl):pred            = np.expand_dims(pred, axis=0)dist_5, index_5 = tree.query(pred, k=5)pred_labels     = [classnames[index] for index in index_5[0]]true_label      = y_zsl[i]if true_label in pred_labels:top5 += 1if true_label in pred_labels[:3]:top3 += 1if true_label in pred_labels[0]:top1 += 1print()
print("ZERO SHOT LEARNING SCORE")
print("-> Top-5 Accuracy: %.2f" % (top5 / float(len(x_zsl))))
print("-> Top-3 Accuracy: %.2f" % (top3 / float(len(x_zsl))))
print("-> Top-1 Accuracy: %.2f" % (top1 / float(len(x_zsl))))

现在我们已经执行了评估,让我们看一下零样本模型的性能。

前五名的准确率几乎为79%。请记住,我们能够对我们以前从未见过的图像进行分类–模型不知道属于这些类别的物体是什么样子的–有了这些准确率,这一点都不差 请记住,模型只得到了这些类别的词向量在词向量空间中的位置信息。这比随机分类要好得多。

准确率并不总是那么高(79%),因为要对细粒度的对象类别进行高准确率的分类真的很难。在我们的案例中,我选择了相对离散的日常对象类别,以表达零点学习算法是如何工作和实现的。这应该可以解释高精确度。如果你还记得介绍部分,我举了一些关于细粒度物体问题的例子,比如树/鱼种的分类。这些都是尝试零点学习的问题,这将是明智的。

总结

零点学习是一个非常新的研究领域,但毋庸置疑的是,它具有非常大的潜力,是计算机视觉领域的领先研究课题之一。

它可以作为未来许多项目的基础系统。可以利用Zero-Shot学习为视力障碍者开发一个辅助性的嵌入式系统。自然界中的监控摄像机可以使用零点射击学习来检测和计算它们自己栖息地中的稀有动物。

随着机器人领域的发展,我们正试图生产与我们自己相似的机器人。人类的视觉是使我们成为人类的最重要特征之一,我们希望将这一特征转移到机器人身上。我们能够解释和识别一个物体,即使我们从未见过一个样本,至少我们可以推理出那个东西是什么。零点学习法在很多方面与人类视觉系统相似,因此它可以用于机器人视觉。与其在有限的物体上进行识别,使用零点学习法有可能识别世界上所有的物体。

原文:Zero-Shot Learning

什么是零样本学习?最全零样本学习原理解释!相关推荐

  1. 《零基础学JavaScript(全彩版)》学习笔记

    <零基础学JavaScript(全彩版)>学习笔记 二〇一九年二月九日星期六0时9分 前期: 刚刚学完<零基础学HTML5+CSS3(全彩版)>,准备开始学习JavaScrip ...

  2. 【区块链学习最全教程】学习 Solidity,全栈 Web3,Javascript 和区块链开发

    Chainlink 开发者社区发布了一个关于全栈 web3,solidity 和区块链开发的完整视频教程.本视频教程由 Chainlink 开发者大使 Patrick Collins 讲解.教程由浅入 ...

  3. 【深度学习】语音识别之CTC算法原理解释与公式推导

    不搞语音识别得人开这个论文确实有点费劲,结合上图,思考一下语音识别的场景,输入是一段录音,输出是识别的音素, 输入的语音文件的长度和输出的音素个数之间没有一一对应关系,通常将语音文件「分片」之后,会出 ...

  4. python零基础入门教程视频下载-Python零基础入门学习视频教程全42集,资源教程下载...

    课程名称 Python零基础入门学习视频教程全42集,资源教程下载 课程目录 001我和Python的第一次亲密接触 002用Python设计第一个游戏 003小插曲之变量和字符串 004改进我们的小 ...

  5. python基础教程视频教程百度云-Python零基础入门学习视频教程全42集百度云网盘下载...

    课程简介 Python零基础入门学习视频教程全42集百度云网盘下载 课程目录 042魔法方法:算术运算 041魔法方法:构造和析构 040类和对象:一些相关的BIF 039类和对象拾遗 038类和对象 ...

  6. 慕课学习史上最全零基础入门HTML5和CSS笔记

    慕课学习史上最全零基础入门HTML5和CSS笔记 Html和CSS的关系 学习web前端开发基础技术需要掌握:HTML.CSS.JavaScript语言.下面我们就来了解下这三门技术都是用来实现什么的 ...

  7. python基础教程百度云-Python零基础入门学习视频教程全42集百度云网盘下载

    课程简介 Python零基础入门学习视频教程全42集百度云网盘下载 课程目录 042魔法方法:算术运算 041魔法方法:构造和析构 040类和对象:一些相关的BIF 039类和对象拾遗 038类和对象 ...

  8. 2023年最新最全uniCloud入门学习,零基础入门到实战项目 uni-admin打造uniapp网页后端 微信支付宝抖音小程序后端 unicloud数据后台快速打造uniapp小程序项目

    今天开始带着大家一起零基础学习uniCloud,在下面的课程中我们就简称uniCloud为cloud吧.我这里从零基础开始教大家,后面可以带大家简单的做一个实战项目.所以不用担心自己没有基础,跟着石头 ...

  9. 零基础自学python教程-零基础人员可以学习python吗?|Python培训基础教程

    python是目前市场上比较流行的一种语言,个人认为也是比较有发展前途的编程语言,如果你既然决定想要好好去学习python,那么一定要做好准备,下定决心,制定合适的计算. 相信对于很多基础的小白来说, ...

最新文章

  1. 对端边缘云网络计算模式:透明计算、移动边缘计算、雾计算和Cloudlet
  2. 字符串的html语言,html语言解析为属性字符串NSMutableAttributedString
  3. AutoCAD安装失败怎样卸载重新安装AutoCAD,解决AutoCAD安装失败的方法总结
  4. python基础知识面试题-python基础知识的重点面试题
  5. 关于网络安全几个问题的整理
  6. 陌生的是人心,是人性,是社会,是世道
  7. 还在用 Win?教你从零把 Mac 打造成开发利器
  8. Oracle中大批量删除数据的方法
  9. Javascript中数组去重的六种方法
  10. 记得把每一次面试当做经验积累,深夜思考
  11. 设计趋势|几何元素增加Banner版面率
  12. 【函数计算月报】2018年12月刊
  13. 【蓝桥杯历年题】2020蓝桥杯A组省赛第二场(10.17)【含蓝桥杯官网提交地址】
  14. 生信技能树课程记录笔记(一)20220523
  15. mysql删除图书信息,图书管理系统(一):出版社列表增加、删除和编辑
  16. mysql左连接查询慢
  17. matlab两条曲线方程求交点_帮忙matlab求两条曲线交点程序,不知问题出在哪里。...
  18. 继续分享最新版本的autohotkey自己编写的快捷键
  19. php开发公众号素材管理总结
  20. 经典进程同步问题(十)

热门文章

  1. 9种网页动画常用实现方式总结
  2. android view背景颜色,Android - ViewPager进阶篇之渐变背景色
  3. 【VB与数据库】——数据库连接
  4. js垃圾回收机制的优化
  5. MetaSpace浅析
  6. 包装类(装箱与拆箱)
  7. 百度滤网行动“关于“窃取用户隐私行为”的算法升级公告”
  8. 间接比较各种生物制剂治疗传统DMARDs难治性RA
  9. 如何避免Java死锁
  10. Spring知识点总结归纳。