模型介绍

TransE模型的基本思想是使head向量和relation向量的和尽可能靠近tail向量。这里我们用L1或L2范数来衡量它们的靠近程度。

损失函数是使用了负抽样的max-margin函数。

L(y, y’) = max(0, margin - y + y’)

y是正样本的得分,y'是负样本的得分。然后使损失函数值最小化,当这两个分数之间的差距大于margin的时候就可以了(我们会设置这个值,通常是1)。

由于我们使用距离来表示得分,所以我们在公式中加上一个减号,知识表示的损失函数为:

其中,d是:

这是L1或L2范数。至于如何得到负样本,则是将head实体或tail实体替换为三元组中的随机实体。


代码实现:

具体的代码和数据集(YAGO、umls、FB15K、WN18)请见Github:
https://github.com/Colinasda/TransE.git

import codecs
import numpy as np
import copy
import time
import randomentities2id = {}
relations2id = {}def dataloader(file1, file2, file3):print("load file...")entity = []relation = []with open(file2, 'r') as f1, open(file3, 'r') as f2:lines1 = f1.readlines()lines2 = f2.readlines()for line in lines1:line = line.strip().split('\t')if len(line) != 2:continueentities2id[line[0]] = line[1]entity.append(line[1])for line in lines2:line = line.strip().split('\t')if len(line) != 2:continuerelations2id[line[0]] = line[1]relation.append(line[1])triple_list = []with codecs.open(file1, 'r') as f:content = f.readlines()for line in content:triple = line.strip().split("\t")if len(triple) != 3:continueh_ = entities2id[triple[0]]r_ = relations2id[triple[1]]t_ = entities2id[triple[2]]triple_list.append([h_, r_, t_])print("Complete load. entity : %d , relation : %d , triple : %d" % (len(entity), len(relation), len(triple_list)))return entity, relation, triple_listdef norm_l1(h, r, t):return np.sum(np.fabs(h + r - t))def norm_l2(h, r, t):return np.sum(np.square(h + r - t))class TransE:def __init__(self, entity, relation, triple_list, embedding_dim=50, lr=0.01, margin=1.0, norm=1):self.entities = entityself.relations = relationself.triples = triple_listself.dimension = embedding_dimself.learning_rate = lrself.margin = marginself.norm = normself.loss = 0.0def data_initialise(self):entityVectorList = {}relationVectorList = {}for entity in self.entities:entity_vector = np.random.uniform(-6.0 / np.sqrt(self.dimension), 6.0 / np.sqrt(self.dimension),self.dimension)entityVectorList[entity] = entity_vectorfor relation in self.relations:relation_vector = np.random.uniform(-6.0 / np.sqrt(self.dimension), 6.0 / np.sqrt(self.dimension),self.dimension)relation_vector = self.normalization(relation_vector)relationVectorList[relation] = relation_vectorself.entities = entityVectorListself.relations = relationVectorListdef normalization(self, vector):return vector / np.linalg.norm(vector)def training_run(self, epochs=1, nbatches=100, out_file_title = ''):batch_size = int(len(self.triples) / nbatches)print("batch size: ", batch_size)for epoch in range(epochs):start = time.time()self.loss = 0.0# Normalise the embedding of the entities to 1for entity in self.entities.keys():self.entities[entity] = self.normalization(self.entities[entity]);for batch in range(nbatches):batch_samples = random.sample(self.triples, batch_size)Tbatch = []for sample in batch_samples:corrupted_sample = copy.deepcopy(sample)pr = np.random.random(1)[0]if pr > 0.5:# change the head entitycorrupted_sample[0] = random.sample(self.entities.keys(), 1)[0]while corrupted_sample[0] == sample[0]:corrupted_sample[0] = random.sample(self.entities.keys(), 1)[0]else:# change the tail entitycorrupted_sample[2] = random.sample(self.entities.keys(), 1)[0]while corrupted_sample[2] == sample[2]:corrupted_sample[2] = random.sample(self.entities.keys(), 1)[0]if (sample, corrupted_sample) not in Tbatch:Tbatch.append((sample, corrupted_sample))self.update_triple_embedding(Tbatch)end = time.time()print("epoch: ", epoch, "cost time: %s" % (round((end - start), 3)))print("running loss: ", self.loss)with codecs.open(out_file_title +"TransE_entity_" + str(self.dimension) + "dim_batch" + str(batch_size), "w") as f1:for e in self.entities.keys():# f1.write("\t")# f1.write(e + "\t")f1.write(str(list(self.entities[e])))f1.write("\n")with codecs.open(out_file_title +"TransE_relation_" + str(self.dimension) + "dim_batch" + str(batch_size), "w") as f2:for r in self.relations.keys():# f2.write("\t")# f2.write(r + "\t")f2.write(str(list(self.relations[r])))f2.write("\n")def update_triple_embedding(self, Tbatch):# deepcopy 可以保证,即使list嵌套list也能让各层的地址不同, 即这里copy_entity 和# entitles中所有的elements都不同copy_entity = copy.deepcopy(self.entities)copy_relation = copy.deepcopy(self.relations)for correct_sample, corrupted_sample in Tbatch:correct_copy_head = copy_entity[correct_sample[0]]correct_copy_tail = copy_entity[correct_sample[2]]relation_copy = copy_relation[correct_sample[1]]corrupted_copy_head = copy_entity[corrupted_sample[0]]corrupted_copy_tail = copy_entity[corrupted_sample[2]]correct_head = self.entities[correct_sample[0]]correct_tail = self.entities[correct_sample[2]]relation = self.relations[correct_sample[1]]corrupted_head = self.entities[corrupted_sample[0]]corrupted_tail = self.entities[corrupted_sample[2]]# calculate the distance of the triplesif self.norm == 1:correct_distance = norm_l1(correct_head, relation, correct_tail)corrupted_distance = norm_l1(corrupted_head, relation, corrupted_tail)else:correct_distance = norm_l2(correct_head, relation, correct_tail)corrupted_distance = norm_l2(corrupted_head, relation, corrupted_tail)loss = self.margin + correct_distance - corrupted_distanceif loss > 0:self.loss += lossprint(loss)correct_gradient = 2 * (correct_head + relation - correct_tail)corrupted_gradient = 2 * (corrupted_head + relation - corrupted_tail)if self.norm == 1:for i in range(len(correct_gradient)):if correct_gradient[i] > 0:correct_gradient[i] = 1else:correct_gradient[i] = -1if corrupted_gradient[i] > 0:corrupted_gradient[i] = 1else:corrupted_gradient[i] = -1correct_copy_head -= self.learning_rate * correct_gradientrelation_copy -= self.learning_rate * correct_gradientcorrect_copy_tail -= -1 * self.learning_rate * correct_gradientrelation_copy -= -1 * self.learning_rate * corrupted_gradientif correct_sample[0] == corrupted_sample[0]:# if corrupted_triples replaces the tail entity, the head entity's embedding need to be updated twicecorrect_copy_head -= -1 * self.learning_rate * corrupted_gradientcorrupted_copy_tail -= self.learning_rate * corrupted_gradientelif correct_sample[2] == corrupted_sample[2]:# if corrupted_triples replaces the head entity, the tail entity's embedding need to be updated twicecorrupted_copy_head -= -1 * self.learning_rate * corrupted_gradientcorrect_copy_tail -= self.learning_rate * corrupted_gradient# normalising these new embedding vector, instead of normalising all the embedding togethercopy_entity[correct_sample[0]] = self.normalization(correct_copy_head)copy_entity[correct_sample[2]] = self.normalization(correct_copy_tail)if correct_sample[0] == corrupted_sample[0]:# if corrupted_triples replace the tail entity, update the tail entity's embeddingcopy_entity[corrupted_sample[2]] = self.normalization(corrupted_copy_tail)elif correct_sample[2] == corrupted_sample[2]:# if corrupted_triples replace the head entity, update the head entity's embeddingcopy_entity[corrupted_sample[0]] = self.normalization(corrupted_copy_head)# the paper mention that the relation's embedding don't need to be normalisedcopy_relation[correct_sample[1]] = relation_copy# copy_relation[correct_sample[1]] = self.normalization(relation_copy)self.entities = copy_entityself.relations = copy_relationif __name__ == '__main__':file1 = "/umls/train.txt"file2 = "/umls/entity2id.txt"file3 = "/umls/relation2id.txt"entity_set, relation_set, triple_list = dataloader(file1, file2, file3)# modify by yourselftransE = TransE(entity_set, relation_set, triple_list, embedding_dim=30, lr=0.01, margin=1.0, norm=2)transE.data_initialise()transE.training_run(out_file_title="umls_")

TransE模型的简单介绍TransE模型的python代码实现相关推荐

  1. 【论文必用】模糊C均值聚类的简单介绍、复现及Python代码详解、聚类可视化图的绘制过程详解!

    详解模糊C均值聚类 一.聚类 二.模糊C均值聚类 三.模糊C均值聚类的Python实现 四.参考链接 一.聚类 聚类的定义: 将物理或抽象对象的集合分成由类似的对象组成的多个类的过程被称为聚类.由聚类 ...

  2. Python 框架 之 Django MVT 下的 M 的 模型的简单介绍和使用

    Python 框架 之 Django MVT 下的 M 的 模型的简单介绍和使用 目录

  3. VAE 模型基本原理简单介绍

    VAE 模型基本原理简单介绍 1. 编写目的 2. 推荐资料 3. 相关背景 3.1 生成模型(Generative model): 3.2 隐变量模型(Latent Variable Models) ...

  4. 图像去噪简单介绍—并给出示例代码

    文章目录 图像去噪简单介绍-并给出示例代码 去噪的基本原理 常见的噪声类型 高斯噪声 椒盐噪声 马赛克噪声 脉冲噪声 添加噪声的代码 添加高斯噪声 添加椒盐噪声 常用的去噪方法 均值滤波 中值滤波 高 ...

  5. 联邦学习算法介绍-FedAvg详细案例-Python代码获取

    联邦学习算法介绍-FedAvg详细案例-Python代码获取 一.联邦学习系统框架 二.联邦平均算法(FedAvg) 三.联邦随梯度下降算法 (FedSGD) 四.差分隐私随联邦梯度下降算法 (DP- ...

  6. xgboost简单介绍_XGBOOST模型介绍

    描述 前言 这是机器学习系列的第三篇文章,对于住房租金预测比赛的总结这将是最后一篇文章了,比赛持续一个月自己的总结竟然也用了一个月,牵强一点来说机器学习也将会是一个漫长的道路,后续机器学习的文章大多数 ...

  7. 退化过程及模型的简单介绍

    1. 退化研究背景 产品可靠性是指元件.产品.系统等在规定条件下和规定时间内完成规定功能的能力.在可靠性理论中,将产品丧失所规定功能的现象称为失效. 产品的失效主要分成两种类型.第一种是突发型失效,指 ...

  8. Pytorch模型层简单介绍

    模型层layers 深度学习模型一般由各种模型层组合而成. torch.nn中内置了非常丰富的各种模型层.它们都属于nn.Module的子类,具备参数管理功能. 例如: nn.Linear, nn.F ...

  9. ARMA模型时间序列分析全流程(附python代码)

    ARMA模型建模流程 建模流程 1)平稳性检验 原始数据data经过清洗得到data_new,然后进行平稳性检验,非平稳数据无法采用ARMA模型进行预测,ADF检验可以用来确定数据的平稳性,这里导入的 ...

  10. 数据结构之图:有向图的介绍与实现,Python代码实现——25

    有向图的介绍 引入 在实际生活中,很多应用相关的图都是有方向性的,最直观的就是网络,可以从A页面通过链接跳转到B页面,那么a和b连接的方向是a->b,但不能说是b->a,此时我们就需要使用 ...

最新文章

  1. ASP.NET 2.0中GRIDVIEW排序
  2. 【计算机网络】网络安全 : 数据加密模型 ( 加密模型 | 密钥 | 密码学 | 密码安全 )
  3. 这个陶瓷电阻烙铁架不错哦,最新一期的电子趣事分享给大家
  4. wikioi 1034 家 实时动态的网络流量(费用流)
  5. 【jzoj】2018.2.3NOIP普及组——D组模拟赛
  6. json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)
  7. 12-Flutter移动电商实战-首页导航区域编写
  8. nlp 财务提取_RPA,智慧财务时代的“珍妮纺织机”来了?
  9. Python获取2019-nCoV疫情实时追踪数据
  10. vue 表单 input text
  11. AddLinkedServer
  12. 用python画圣诞树-python圣诞树
  13. java 文件上传 乱码_java中文传值乱码问题的解决方法
  14. Mybatis-实现逆向代理
  15. Atitit 安全登录退出管理法v3 tbb.docx 目录 1.1. 未持有效出入边界票据。。校验票据有效性 1 1.2. 从事与票据种类不符的活动的 2 1.3. 数据为数字的应该校验数字 2
  16. 调用阿里API获取城市天气信息
  17. 使用Eclipse编译运行MapReduce程序
  18. Tableau 中国最美八条骑行线路(二)海拔和气温
  19. __stack_chk_fail之栈帧溢出检测技术
  20. Enolsoft PDF Converter with OCR mac(带有OCR的PDF转换器)

热门文章

  1. GBK汉字的索引方法
  2. 企业文化是数字化转型最大障碍-解读《2022年首席数据官调查报告》
  3. 巧用 10分钟邮箱 申请小红伞 免费KEY 92天
  4. linux端口快速释放,Linux 快速释放端口与释放内存缓存,linux释放端口缓存
  5. 初尝Mcafee之CEE企业版概述【01】
  6. 资源---2020考研---英语网站---资料3(考研英语,英语学习。宣言:自从用了这个英语网站,七大姑八大姨开始担心他家孩子比不过我了~~~~~~~~~FT中文网)
  7. 信息论基础知识:【信息熵 不确定性】
  8. mcal rtm_RTM的完整形式是什么?
  9. Ubuntu下安装QQ(wineQQ)
  10. 菜鸟学习C++之Console Application