元学习—高斯原型网络实现(Pytorch)

原理部分可以参考我之前的博客,元学习—高斯原型网络,本文在实现过程中,采用了基于半径的置信度计算,对于逆矩阵采用的softplus的计算方式。以评论文本分类作为基础任务,使用LSTM作为原型网络中的编码函数。下面给出具体的代码实现:

1. 数据处理部分(utils.py)

#encoding=utf-8
'''
以评论文本为例,进行计算。
'''
import torch
import torch.nn as nn
import numpy as np
import jieba
import random#基本文本数据集text_good = ['本片比想象中好剧情紧凑但是演员演绎到位全程紧张','在节奏上把控得很好包括当中的一些神转折戏剧性又不突兀我觉得是上乘之作','精彩甚至超出原版情节更加紧凑如果很爱电影的人看这部电影一定很爱这部电影','不出意外这就是我的年度华语最佳了无懈可击那种','可以说是今年看到最好的华语电影没有之一','情节紧凑节奏干净利落尤其是演员的表演都让人惊喜','剧情无敌演技无敌特别喜欢']
text_bad = ['这部片子就是一部没讲好故事的剧本连故事都没讲好其他的可想而知','个人觉得影片辣鸡电影漏洞真的太多了','剧情也太简单了一眼猜出来到底有多少僵尸被车撞死了特效不错。','难看死了还疯狂营销棒子的基操','真无聊拿着枪打丧尸末日逃生强行煽情','从头到尾都是垃圾煽你妹啊这结尾你妹啊''剧情拖沓演员演技不在线']def dataProcess():'''数据处理过程,整个过程生成词典等等:return:'''word_pos = [[item for item in jieba.cut(text)] for text in text_good]word_bad = [[item for item in jieba.cut(text)] for text in text_bad]word_all = []for item in word_pos:for key in item:word_all.append(key)for item in word_bad:for key in item:word_all.append(key)vocab = list(set(word_all))word2idx = {w:c for c,w in enumerate(vocab)}idx_word_pos = [[word2idx[item] for item in text] for text in word_pos]idx_word_neg = [[word2idx[item] for item in text] for text in word_bad]return vocab,word2idx,idx_word_pos,idx_word_negdef createOneHot(vocab,idx_word_pos,idx_word_neg):feature_pos_list = []feature_bad_list = []for text in idx_word_pos:#构建二维矩阵sequence = torch.zeros(size=[len(text),len(vocab)])for i in range(len(text)):sequence[i,text[i]] = 1.0sequence = torch.unsqueeze(sequence,0)feature_pos_list.append(sequence)for text in idx_word_neg:sequence = torch.zeros(size=[len(text),len(vocab)])for i in range(len(text)):sequence[i,text[i]] = 1.0sequence = torch.unsqueeze(sequence,0)feature_bad_list.append(sequence)return feature_pos_list,feature_bad_listdef RandomCreateSQ(feature_pos_list,feature_bad_list):'''通过随机抽样的方式,构建支持集和查询集:param feature_pos_list: 正类样本:param feature_bad_list: 负类样本:return:'''support_pos_list = list(random.sample(feature_pos_list,6))support_bad_list = list(random.sample(feature_bad_list,6))query_list = [support_pos_list.pop(),support_bad_list.pop()]support_pos_list.extend(support_bad_list)support_list = support_pos_listquery_label = [1,0]query_label = torch.LongTensor(query_label)return support_list,query_list,query_labeldef getData():vocab, word2idx, idx_word_pos, idx_word_neg = dataProcess()feature_pos_list, feature_bad_list = createOneHot(vocab, idx_word_pos, idx_word_neg)support_list, query_list, query_label = RandomCreateSQ(feature_pos_list, feature_bad_list)return len(vocab),support_list,query_list,query_label

2 模型搭建部分

#encoding=utf-8
'''高斯原型网络实现
核心:编码函数同时输出协方差矩阵
'''
import torch
import torch.nn as nn
from math import log,exp
import torch.nn.functional as F'''
pytorch中对于对于LSTM的使用:
1,输入参数: input_size, hidden_size,num_layers,batch_first 主要用于控制输入的维度,隐层维度,LSTM的层数,是否在batch放在第一个维度bidirectional: 是否是双向
2. 输出结果: output,hn,cn 分别表示每一个单元的输出结果整理成一个矩阵,维度和输入一致,hn表示最后一个隐层单元的结果,cn表示最后一个隐层的细胞状态hn:维度:第一个维度:表示方向*层数,第二个维度 batch的数量,第三个维度 隐层节点维度cn维度:第一个维度:表示方向*层数,第二个维度 batch的数量,第三个维度 隐层维度注意 1.如果LSTM是双向的,那么LSTM的输出是两个方向拼接的结果2. 可以初始化一个元组(h0,c0),表示初始的隐层状态和细胞状态
'''def softplus(Sraw):return log((1+exp(Sraw)),2)class model(nn.Module):def __init__(self,input_dim,embedding_dim):super(model, self).__init__()self.input_dim = input_dimself.embedding_dim = embedding_dimself.LstmEmbedding = nn.LSTM(input_size=input_dim,hidden_size=embedding_dim,num_layers=1,batch_first=True)def embedding_function(self,x):'''这里由于文章长度不一致,采用batch=1的训练方式,并且这里选选择LSTM最后一个时刻的输出结果作为整片文章的表示结果:param x: 表示输入的文章:return:'''embedding_result,(hn,cn) = self.LstmEmbedding(x)embedding_result = embedding_result.squeeze(0)return embedding_result[-1,:]def forward(self,support_list,query_list):'''在整个模型前向传播的过程中,主要分成两个部分,为了便于计算,这里选择半径,即编码输出的最后一个维度,表示当前向量的置信度第一个部分: 计算原型向量,高斯协方差矩阵结果第二个部分:对于查询集中的样本进行距离计算,划定类别注意: 对于逆矩阵的计算策略,这里采用的是 Sc = 1 + softplus(Sraw):param support_list: 支持集数据,两个类别,每一个类别5个数据:param query_list: 查询集数据,两个类别,每个类别1个数据:return:'''#对于支持集进行编码,最后一个维度作为置信度support_embedding = torch.zeros(size=[len(support_list),self.embedding_dim-1])gcsMatrix = []for i in range(len(support_list)):embedding_result = self.embedding_function(support_list[i]).reshape(1,-1)support_embedding[i,:] = embedding_result[0,:-1]gcsMatrix.append(embedding_result[0,-1].data.item())#计算原型向量first_class_e = [gcsMatrix[i] * support_embedding[i,:] for i in range(0,5)]first_class_m = [gcsMatrix[i] for i in range(0,5)]first_class_p = sum(first_class_e) / sum(first_class_m)sencond_class_e = [gcsMatrix[i] * support_embedding[i,:] for i in range(5,10)]sencond_class_m = [gcsMatrix[i] for i in range(5,10)]sencond_class_p = sum(sencond_class_e) / sum(sencond_class_m)support_class_p = torch.cat([first_class_p.reshape(1,-1),sencond_class_p.reshape(1,-1)],dim=0)#同过原型,计算协方差矩阵,及其逆矩阵temp_result = 0.0corfMatrix = [gcsMatrix[i] for i in range(0,5)]corsMatrix = [gcsMatrix[i] for i in range(5,10)]fSraw = sum(corfMatrix)SSraw = sum(corsMatrix)fSc = softplus(fSraw)SSc = softplus(SSraw)Sc = torch.tensor([[fSc],[SSc]])# 对于查询集中的数据进行编码,然后计算其对于原型向量的距离query_embedding = torch.zeros(size=[len(query_list),self.embedding_dim-1])gcqMatrix = []for i in range(len(query_list)):embedding_result = self.embedding_function(query_list[i]).reshape(1,-1)query_embedding[i,:] = embedding_result[0,:-1]gcqMatrix.append(embedding_result[0,-1].data.item())# 计算距离,分别对于查询集中的向量,利用协方差逆矩阵,计算距离# 构建距离矩阵,两个查询,两个别向量,即2*2的距离结果distence_matrix = torch.zeros(size=[2,2])for i in range(query_embedding.shape[0]):temp_query = query_embedding[i,:]temp_query = temp_query.repeat(2,1)result_left = temp_query - support_class_presult_right = torch.mul(Sc,result_left)dc2 = torch.mm(query_embedding[i,:].reshape(1,-1),result_right.T)distence_matrix[i,:] = dc2#将计算的距离结果转换成矩阵的形式,进行返回distence_soft = F.log_softmax(distence_matrix,dim=1)return distence_soft

3 模型训练部分

#encoding=utf-8
'''
模型训练部分:
1. 导入数据
2. 导入模型
3. 超参设置,学习率0.01,输入维度128,编码维度65(最后一个维度是置信度)
'''
from  model import model
from utils import getData
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as Flr = 0.01
embedding_dim = 65
epochs = 100input_dim,support_list,query_list,query_label = getData()
train_model = model(input_dim,embedding_dim)
optimer = optim.Adam(train_model.parameters(),lr=lr,weight_decay=5e-4)def train(epoch,support_list,query_list,query_label):optimer.zero_grad()output = train_model(support_list,query_list)loss = F.nll_loss(output,query_label)loss.backward()optimer.step()print("Epoch:{:04d}".format(epoch),"loss:{:.4f}".format(loss))if __name__ == "__main__":for i in range(epochs):input_dim, support_list, query_list, query_label = getData()train(i,support_list,query_list,query_label)

4 总结

高斯原型网络的实现难度不同,与传统的原型网络的实现难度类似,有兴趣的读者可以参考为之前对于一般原型网络的原理描述和基本实现。元学习——原型网络(Prototypical Networks)

元学习—高斯原型网络实现(Pytorch)相关推荐

  1. 高斯原型网络原论文高质量翻译

    论文地址:Gaussian Prototypical Networks for Few-Shot Learning on Omniglot 文章目录 摘要 1 引言 1.1 Few-shot lear ...

  2. 元学习 迁移学习_元学习就是您所需要的

    元学习 迁移学习 Update: This post is part of a blog series on Meta-Learning that I'm working on. Check out ...

  3. 繁凡的对抗攻击论文精读(二)CVPR 2021 元学习训练模拟器进行超高效黑盒攻击(清华)

    点我轻松弄懂深度学习所有基础和各大主流研究方向入门综述! <繁凡的深度学习笔记>,包含深度学习基础和 TensorFlow2.0,PyTorch 详解,以及 CNN,RNN,GNN,AE, ...

  4. Python 元学习实用指南:1~5

    原文:Hands-On Meta Learning with Python 协议:CC BY-NC-SA 4.0 译者:飞龙 本文来自[ApacheCN 深度学习 译文集],采用译后编辑(MTPE)流 ...

  5. 论文浅尝 - ICML2020 | 拆解元学习:理解 Few-Shots 任务中的特征表示

    论文笔记整理:申时荣,东南大学博士生. 来源:ICML2020 链接:http://arxiv.org/abs/2002.06753 元学习算法会生成特征提取器,这些特征提取器在进行few-shot分 ...

  6. 元学习—Meta Learning的兴起

    来源:专知 [导读]元学习描述了训练深度神经网络相关的更高级别的元素.在深度学习文献中,"元学习"一词经常表示神经网络架构的自动化设计,经常引用" AutoML" ...

  7. WSDM 2022最佳论文候选:港大提出多行为对比元学习的推荐系统

    ©PaperWeekly 原创 · 作者 | 韦玮 单位 | 香港大学 研究方向 | 推荐系统 论文标题: Contrastive Meta Learning with Behavior Multip ...

  8. 论文浅尝 | 基于动态记忆的原型网络进行元学习以实现少样本事件探测

    本文转载自公众号:浙大KG. 论文题目:Meta-Learning with Dynamic-Memory-Based Prototypical Network for Few-Shot Event ...

  9. 今日 Paper | 虚拟试穿网络;人群计数基准;联邦元学习;目标检测等

    2020-01-15 05:41:40 为了帮助各位学术青年更好地学习前沿研究成果和技术,AI科技评论联合Paper 研习社(paper.yanxishe.com),推出[今日 Paper]栏目, 每 ...

最新文章

  1. 一张心酸得不想起名字的照片,人艰就别拆了好吗 | 每日趣闻
  2. 深入理解 Java 线程池:ThreadPoolExecutor
  3. IDEA 代码生成插件 CodeMaker
  4. 4怎么放大字体_Word字体怎么放大?简单教你几招轻松搞定
  5. 导致大量kworker的原因_高尿酸与生活习惯有关?导致高尿酸的8个坏习惯,现在改还来得及...
  6. 女士怎么就不适合PhP呢,女人可以不美丽,但不能不智慧
  7. (转)静态变量和全局变量的区别
  8. 折线图_手把手教你用ECharts画折线图
  9. 利用canvas的getImageData()方法制作《在线取色器》
  10. 万恶淫为首,你想知道的真相!
  11. 家用智能摄像头横评:小米、华为海雀、TP-LINK、智汀
  12. CTF Alice与Bob
  13. 微信小程序 基础 - 19 (登录后用户头像的更新)
  14. Linux上C语言程序编译过程详解
  15. 手机辐射危害盘点:可降低男性精子活性
  16. 技术人总有想写文章的冲动却无疾而终?4个小Tips帮你快速上手~
  17. 押注汽车操作系统,手机厂商就能借无人驾驶弯道超车?
  18. 谷歌中国发布年度热榜 iPhone成全球最流行词
  19. [小故事大道理] -- 木桶原理的逆向思考
  20. mysql命令行远程登录时,用户名密码等连接信息配置正确,出现:ERROR 1045 (HY000): Access denied for user ‘xxx'

热门文章

  1. Devexpress GridControl GridView表中列增加按钮
  2. 利用css修饰个人主页,利用html/css设计一个简单个人主页
  3. docker registry存储镜像文件的组织结构
  4. flink exactly once和at least once的理解
  5. 浅谈JSP中include指令与include动作标识的区别
  6. 华为三层交换机路由配置案例_华为三层交换机配置实例
  7. 【踩坑系列】uniapp之h5 跨域的问题
  8. global-forwards和forward的区别
  9. phpexcel 设置批注_phpexcel中文教程-设置表格字体颜色背景样式、数据格式、对齐方式、添加图片、批注、文字块、合并拆分单元格、单元格密码保护...
  10. 12bit sar adc电路,可直接仿真,逻辑模块也是实际电路,可指导利用cadence或者matlab进行频谱分析