Table of contents

  • TransE
    • Knowledge graph basics
    • Knowledge representation
    • Algorithm description
    • Code walkthrough
    • Data

TransE

Knowledge graph basics

A knowledge graph stores facts as triples (h, r, t): a head entity h, a relation r and a tail entity t.
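As a minimal illustration (the entities and relations are made up for the example), a knowledge graph can be held as a list of (head, relation, tail) tuples, from which id tables like the entity2id/relation2id files used later can be built:

```python
# Hypothetical toy knowledge graph: each fact is a (head, relation, tail) triple.
triples = [
    ("Beijing", "capital_of", "China"),
    ("Paris", "capital_of", "France"),
    ("China", "located_in", "Asia"),
]

# Assign consecutive integer ids in first-seen order, like an entity2id / relation2id table.
entity2id = {}
relation2id = {}
for h, r, t in triples:
    for e in (h, t):
        entity2id.setdefault(e, len(entity2id))  # keeps the existing id if e was seen before
    relation2id.setdefault(r, len(relation2id))

print(entity2id)
print(relation2id)
```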

Knowledge representation

That is, representing entities and relations as vectors (embeddings).

Algorithm description

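Idea: for a correct triple, the embeddings should satisfy h + r ≈ t, i.e. the relation vector acts as a translation from the head to the tail.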
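A toy check of the translation idea, using hand-picked 2-D vectors that are purely illustrative: when h + r lands exactly on t the distance is 0, and a wrong tail gives a positive distance. score_l2 returns the squared L2 distance, the same quantity the article's distanceL2 computes.

```python
import numpy as np


def score_l1(h, r, t):
    """L1 distance ||h + r - t||_1; 0 means the translation holds exactly."""
    return float(np.abs(h + r - t).sum())


def score_l2(h, r, t):
    """Squared L2 distance ||h + r - t||_2^2."""
    d = h + r - t
    return float(d @ d)


h = np.array([1.0, 0.0])       # head
r = np.array([0.0, 1.0])       # relation
t = np.array([1.0, 1.0])       # correct tail: h + r == t
t_bad = np.array([-1.0, 0.0])  # wrong tail

print(score_l1(h, r, t), score_l2(h, r, t))
print(score_l1(h, r, t_bad), score_l2(h, r, t_bad))
```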

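Let d be the distance between vectors, usually the L1 or the L2 distance. We want the distance of a correct triple to be as small as possible and the distance of an incorrect (corrupted) triple to be as large as possible. The objective function is therefore:

$$L = \sum_{(h,r,t)\in S}\ \sum_{(h',r,t')\in S'_{(h,r,t)}} \left[\gamma + d(h + r, t) - d(h' + r, t')\right]_+$$

where γ > 0 is the margin, S is the set of correct triples, S'(h,r,t) contains the corrupted triples obtained by replacing the head or the tail of (h, r, t), and [x]_+ = max(0, x).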
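A minimal sketch of the margin-based objective, the sum of [γ + d(h + r, t) − d(h' + r, t')]_+ over (correct, corrupted) pairs, assuming the squared L2 distance for d. The function names and vectors are illustrative only:

```python
import numpy as np


def distance_l2(h, r, t):
    """Squared L2 distance ||h + r - t||^2."""
    d = h + r - t
    return float(d @ d)


def margin_loss(pairs, margin=1.0):
    """Sum of [margin + d(h + r, t) - d(h' + r', t')]_+ over (correct, corrupted) pairs.

    Each triple is a tuple of vectors (h, r, t); in TransE the corrupted triple
    keeps the same relation, so r' == r.
    """
    total = 0.0
    for (h, r, t), (hc, rc, tc) in pairs:
        total += max(0.0, margin + distance_l2(h, r, t) - distance_l2(hc, rc, tc))
    return total


h, r = np.array([1.0, 0.0]), np.array([0.0, 1.0])
pos = (h, r, np.array([1.0, 1.0]))    # d = 0
far = (h, r, np.array([-1.0, 0.0]))   # d = 5: already beyond the margin, contributes 0
near = (h, r, np.array([1.0, 0.5]))   # d = 0.25: inside the margin, contributes 0.75
print(margin_loss([(pos, far), (pos, near)]))
```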


Computing the gradient
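The gradient figure is not reproduced in this text. Reconstructed from the objective and from the update rules in the code, a per-pair derivation looks like this (x, x' and λ are shorthand introduced here for the two residuals and the learning rate). The update lines are for the squared L2 distance; for L1, replace 2x and 2x' with sign(x) and sign(x'). Pairs whose hinge term is 0 produce no update.

```latex
\begin{aligned}
& x = h + r - t, \qquad x' = h' + r - t' \\
& \text{L2 (squared): } \nabla_h d = 2x,\quad \nabla_r d = 2x,\quad \nabla_t d = -2x \\
& \text{L1: } \nabla_h d = \operatorname{sign}(x),\quad \nabla_r d = \operatorname{sign}(x),\quad \nabla_t d = -\operatorname{sign}(x) \\
& \text{if } \gamma + d(h + r, t) - d(h' + r, t') > 0: \\
& \quad h \leftarrow h - 2\lambda x,\qquad t \leftarrow t + 2\lambda x,\qquad r \leftarrow r - 2\lambda (x - x') \\
& \quad h' \leftarrow h' + 2\lambda x',\qquad t' \leftarrow t' - 2\lambda x'
\end{aligned}
```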

Code walkthrough

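  • Define the class. Parameters:
      margin: the constant (margin γ) in the objective function
      learningRate: the learning rate
      dim: the vector dimension
      entityList: list of entities (read from a text file: entity + id)
      relationList: list of relations (read from a text file: relation + id)
      tripleList: list of triples (read from a text file: entity + entity + relation)
      loss: the loss value
      L1: the distance formula (True for L1, False for L2)
  • Vector initialization

Set the vector dimension and the range of the initial values (the range given in the TransE paper, [-6/√k, 6/√k] for dimension k).
Functions involved:

    init: draws a random value in that range
    norm: normalizes a vector to unit length
  • Training the vectors
    getSample: randomly selects part of the triples as the batch Sbatch
    getCorruptedTriplet(sbatch): randomly replaces one entity of the triple, either h or t but never both at the same time
    update: updates the vectors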
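A minimal sketch of the sampling steps, assuming toy data (entities and triples are made up) in the (head, tail, relation) order used by the article's train.txt; corrupt is a hypothetical stand-in for getCorruptedTriplet. It picks from a list with random.choice because random.sample no longer accepts dict views such as dict.keys() as of Python 3.11.

```python
import random

# Hypothetical toy data; triples are (head, tail, relation).
entities = ["Beijing", "China", "Paris", "France"]
triples = [("Beijing", "China", "capital_of"), ("Paris", "France", "capital_of")]


def corrupt(triple, entities, rng):
    """Replace either the head or the tail (never both) with a different random entity."""
    head, tail, relation = triple
    if rng.random() < 0.5:
        new_head = rng.choice([e for e in entities if e != head])
        return (new_head, tail, relation)
    new_tail = rng.choice([e for e in entities if e != tail])
    return (head, new_tail, relation)


rng = random.Random(0)            # seeded so the run is reproducible
batch = rng.sample(triples, 2)    # like getSample: a random minibatch of triples
pairs = [(t, corrupt(t, entities, rng)) for t in batch]
print(pairs)
```

Excluding the current entity from the candidate list avoids the retry loop of the original, at the cost of building the list on each call.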

Derivation of the vector update for the L2 distance:
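The derivation image is not reproduced in this text. As a hedged stand-in, this sketch (function names are hypothetical) performs one SGD step for the squared L2 distance, the same update the update() method applies in its L2 branch but without the renormalization, and checks on a toy example that the pair's loss drops:

```python
import numpy as np


def sgd_step_l2(h, r, t, hc, tc, lr=0.01, margin=1.0):
    """One SGD step on [margin + ||h + r - t||^2 - ||hc + r - tc||^2]_+.

    hc/tc are the head/tail of the corrupted triple (the relation r is shared).
    Returns the updated (h, r, t, hc, tc); nothing changes if the hinge is inactive.
    """
    x = h + r - t        # residual of the correct triple
    xc = hc + r - tc     # residual of the corrupted triple
    if margin + x @ x - xc @ xc <= 0:
        return h, r, t, hc, tc
    return (h - 2 * lr * x,            # dL/dh  =  2x
            r - 2 * lr * (x - xc),     # dL/dr  =  2x - 2xc
            t + 2 * lr * x,            # dL/dt  = -2x
            hc + 2 * lr * xc,          # dL/dh' = -2xc
            tc - 2 * lr * xc)          # dL/dt' =  2xc


def pair_loss(h, r, t, hc, tc, margin=1.0):
    x, xc = h + r - t, hc + r - tc
    return max(0.0, margin + x @ x - xc @ xc)


h, r, t = np.array([1.0, 0.0]), np.zeros(2), np.zeros(2)
hc, tc = np.zeros(2), np.zeros(2)
before = pair_loss(h, r, t, hc, tc)
after = pair_loss(*sgd_step_l2(h, r, t, hc, tc, lr=0.1))
print(before, after)
```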

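Python functions used:
uniform(a, b)  # from the random module: a random float N with a <= N <= b (whether b itself can occur depends on floating-point rounding)
Length (norm) of a vector: var = linalg.norm(list)  # numpy; the L2 norm by default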
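For comparison, a vectorized NumPy version of init() and norm() (init_embeddings is a hypothetical helper, not part of the article's code): it draws every component uniformly from [-6/√dim, 6/√dim) with numpy.random.Generator.uniform and divides each vector by its L2 norm from numpy.linalg.norm.

```python
import numpy as np


def init_embeddings(names, dim, rng):
    """Hypothetical helper: uniform init in [-6/sqrt(dim), 6/sqrt(dim)), then unit L2 norm per vector."""
    bound = 6 / np.sqrt(dim)
    vectors = rng.uniform(-bound, bound, size=(len(names), dim))
    vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)  # row-wise normalization
    return dict(zip(names, vectors))


rng = np.random.default_rng(42)  # seeded for reproducibility
emb = init_embeddings(["Beijing", "China", "Paris"], dim=8, rng=rng)
print({name: float(np.linalg.norm(v)) for name, v in emb.items()})
```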

"""
@version: 3.7
@author: jiayalu
@file: trainTransE.py
@time: 22/08/2019 10:56
@description: Trains vectors for the entities and relations of a knowledge graph with the TransE algorithm.
Input data: the triples, plus the entity ids and relation ids.
Output: two text files, entityVector.txt and relationVector.txt, one line per item: name    [vector]
"""
from random import choice, sample, uniform

import numpy as np


class TransE:
    def __init__(self, entityList, relationList, tripleList, margin=1, learningRate=0.00001, dim=10, L1=True):
        self.margin = margin
        self.learningRate = learningRate
        self.dim = dim  # vector dimension
        # entityList starts as a list of entity names; after initialize() it becomes
        # a dict whose keys are entities and whose values are their vectors (numpy arrays).
        self.entityList = entityList
        self.relationList = relationList  # same as above, for relations
        self.tripleList = tripleList  # list of (head, tail, relation) tuples
        self.loss = 0
        self.L1 = L1  # True: L1 distance, False: squared L2 distance

    def initialize(self):
        """Initialize every vector uniformly in the TransE range, then normalize it."""
        entityVectorList = {}
        relationVectorList = {}
        for entity in self.entityList:
            entityVectorList[entity] = norm([init(self.dim) for _ in range(self.dim)])
        print("entityVector initialized, count: %d" % len(entityVectorList))
        for relation in self.relationList:
            relationVectorList[relation] = norm([init(self.dim) for _ in range(self.dim)])
        print("relationVectorList initialized, count: %d" % len(relationVectorList))
        self.entityList = entityVectorList
        self.relationList = relationVectorList

    def transE(self, cI=20):
        print("Training started")
        for cycleIndex in range(cI):
            Sbatch = self.getSample(3)
            # list of pairs (original triple, corrupted triple): [((h, t, r), (h', t', r))]
            Tbatch = []
            for sbatch in Sbatch:
                tripletWithCorruptedTriplet = (sbatch, self.getCorruptedTriplet(sbatch))
                if tripletWithCorruptedTriplet not in Tbatch:
                    Tbatch.append(tripletWithCorruptedTriplet)
            self.update(Tbatch)
            if cycleIndex % 100 == 0:
                print("Cycle %d" % cycleIndex)
                print(self.loss)  # loss accumulated since the previous report
                self.writeRelationVector(r"E:\pythoncode\knownlageGraph\transE-master\relationVector.txt")
                self.writeEntityVector(r"E:\pythoncode\knownlageGraph\transE-master\entityVector.txt")
                self.loss = 0

    def getSample(self, size):
        return sample(self.tripleList, size)

    def getCorruptedTriplet(self, triplet):
        """Replace either the head or the tail of a training triple by a random entity (but not both at the same time)."""
        # random.choice/sample need a sequence; dict views are rejected since Python 3.11
        entities = list(self.entityList)
        if uniform(-1, 1) < 0:  # below 0: corrupt the first item (head)
            while True:
                entityTemp = choice(entities)
                if entityTemp != triplet[0]:
                    break
            corruptedTriplet = (entityTemp, triplet[1], triplet[2])
        else:  # 0 or above: corrupt the second item (tail)
            while True:
                entityTemp = choice(entities)
                if entityTemp != triplet[1]:
                    break
            corruptedTriplet = (triplet[0], entityTemp, triplet[2])
        return corruptedTriplet

    def update(self, Tbatch):
        # Vectors are replaced, never modified in place, so shallow copies are enough.
        copyEntityList = dict(self.entityList)
        copyRelationList = dict(self.relationList)
        distance = distanceL1 if self.L1 else distanceL2
        for triplet, corruptedTriplet in Tbatch:
            head, tail, relation = triplet
            headCorrupted, tailCorrupted = corruptedTriplet[0], corruptedTriplet[1]
            # Distances and gradients use the vectors from before this batch.
            headVector = self.entityList[head]
            tailVector = self.entityList[tail]
            relationVector = self.relationList[relation]
            headCorruptedVector = self.entityList[headCorrupted]
            tailCorruptedVector = self.entityList[tailCorrupted]
            distTriplet = distance(headVector, tailVector, relationVector)
            distCorruptedTriplet = distance(headCorruptedVector, tailCorruptedVector, relationVector)
            eg = self.margin + distTriplet - distCorruptedTriplet
            if eg > 0:  # [x]+ keeps only the positive part: pairs beyond the margin produce no update
                self.loss += eg
                diffPositive = tailVector - headVector - relationVector  # t - h - r
                diffNegative = tailCorruptedVector - headCorruptedVector - relationVector  # t' - h' - r
                if self.L1:
                    # subgradient of |h + r - t| is sign(h + r - t), scaled by the learning rate
                    tempPositive = self.learningRate * np.sign(diffPositive)
                    tempNegative = self.learningRate * np.sign(diffNegative)
                else:
                    # gradient of ||h + r - t||^2 is 2(h + r - t)
                    tempPositive = 2 * self.learningRate * diffPositive
                    tempNegative = 2 * self.learningRate * diffNegative
                # Apply the updates to the running copies one after another, so an entity shared
                # by the correct and the corrupted triple receives both updates.
                copyEntityList[head] = copyEntityList[head] + tempPositive
                copyEntityList[tail] = copyEntityList[tail] - tempPositive
                copyRelationList[relation] = copyRelationList[relation] + tempPositive - tempNegative
                copyEntityList[headCorrupted] = copyEntityList[headCorrupted] - tempNegative
                copyEntityList[tailCorrupted] = copyEntityList[tailCorrupted] + tempNegative
                # Only the vectors just updated are normalized, not all of them at once as in the original paper.
                for entity in {head, tail, headCorrupted, tailCorrupted}:
                    copyEntityList[entity] = norm(copyEntityList[entity])
                copyRelationList[relation] = norm(copyRelationList[relation])
        self.entityList = copyEntityList
        self.relationList = copyRelationList

    def writeEntityVector(self, path):
        print("Writing entities")
        with open(path, "w", encoding="utf-8") as entityVectorFile:
            for entity, vector in self.entityList.items():
                entityVectorFile.write(entity + "    " + str(vector.tolist()) + "\n")

    def writeRelationVector(self, path):
        print("Writing relations")
        with open(path, "w", encoding="utf-8") as relationVectorFile:
            for relation, vector in self.relationList.items():
                relationVectorFile.write(relation + "    " + str(vector.tolist()) + "\n")


def init(dim):
    """One random value drawn uniformly from [-6/sqrt(dim), 6/sqrt(dim)]."""
    bound = 6 / (dim ** 0.5)
    return uniform(-bound, bound)


def norm(vector):
    """Normalize: divide the vector by its L2 norm (square root of the sum of squares)."""
    vector = np.asarray(vector, dtype=float)
    return vector / np.linalg.norm(vector)


def distanceL1(h, t, r):
    s = h + r - t
    return np.abs(s).sum()


def distanceL2(h, t, r):
    # squared L2 distance, matching the gradient 2(h + r - t) used in update()
    s = h + r - t
    return (s * s).sum()


def openDetailsAndId(path, sp="    "):
    """Read a 'name<sp>id' file and return (number of names, list of names)."""
    names = []
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            if line.strip():
                names.append(line.strip().split(sp)[0])
    return len(names), names


def openTrain(path, sp="    "):
    """Read 'head<sp>tail<sp>relation' lines, skipping lines with fewer than three fields."""
    triples = []
    with open(path, "r", encoding="utf-8") as file:
        for line in file:
            triple = line.strip().split(sp)
            if len(triple) < 3:
                continue
            triples.append(tuple(triple[:3]))
    return len(triples), triples


if __name__ == '__main__':
    dirEntity = r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\entity2id.txt"
    entityIdNum, entityList = openDetailsAndId(dirEntity)
    dirRelation = r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\relation2id.txt"
    relationIdNum, relationList = openDetailsAndId(dirRelation)
    dirTrain = r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\train.txt"
    tripleNum, tripleList = openTrain(dirTrain)
    print("Opening TransE")
    transE = TransE(entityList, relationList, tripleList, margin=1, dim=128)
    print("TransE initialization")
    transE.initialize()
    transE.transE(1500)
    transE.writeRelationVector(r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\relationVector.txt")
    transE.writeEntityVector(r"E:\pythoncode\ZXknownlageGraph\TransEgetvector\entityVector.txt")

Data
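The sample data is not reproduced in this text. Judging only from openDetailsAndId and openTrain, entity2id.txt and relation2id.txt hold one "name    id" per line and train.txt holds one "head    tail    relation" per line, separated by four spaces. A minimal sketch that parses made-up contents in that layout from memory (read_names and read_triples are hypothetical helpers mirroring the two functions):

```python
import io

SEP = "    "  # four spaces, the default separator of openDetailsAndId/openTrain

# Hypothetical file contents in the layout the reading functions expect.
entity2id_txt = "Beijing    0\nChina    1\nParis    2\nFrance    3\n"
train_txt = "Beijing    China    capital_of\nParis    France    capital_of\nbad line\n"


def read_names(f, sp=SEP):
    """Keep the first field (the name) of every non-empty line."""
    return [line.strip().split(sp)[0] for line in f if line.strip()]


def read_triples(f, sp=SEP):
    """Keep (head, tail, relation) from lines with at least three fields."""
    triples = []
    for line in f:
        parts = line.strip().split(sp)
        if len(parts) >= 3:
            triples.append(tuple(parts[:3]))
    return triples


entities = read_names(io.StringIO(entity2id_txt))
triples = read_triples(io.StringIO(train_txt))  # the malformed "bad line" is skipped
print(entities, triples)
```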



Resulting vectors
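The result-vector screenshots are not reproduced here. The training script writes one line per item: the name, four spaces, then the vector as a Python list, so the files can be read back with ast.literal_eval. A minimal sketch (load_vectors is a hypothetical helper, and the vectors are made up) that round-trips that format in memory and scores one triple with the L1 distance:

```python
import ast
import io

import numpy as np

SEP = "    "  # four spaces, as used by the script's write functions


def load_vectors(f, sp=SEP):
    """Parse 'name<sp>[v1, v2, ...]' lines into a {name: numpy array} dict."""
    vectors = {}
    for line in f:
        line = line.strip()
        if not line:
            continue
        name, vec = line.rsplit(sp, 1)  # the list text contains no run of four spaces
        vectors[name] = np.array(ast.literal_eval(vec), dtype=float)
    return vectors


# Write hypothetical vectors in the script's output format, then read them back.
original = {"Beijing": np.array([1.0, 0.0]), "China": np.array([1.0, 1.0])}
buf = io.StringIO()
for name, vector in original.items():
    buf.write(name + SEP + str(vector.tolist()) + "\n")
buf.seek(0)
entityVectors = load_vectors(buf)

relationVectors = {"capital_of": np.array([0.0, 1.0])}  # hypothetical relation vector
d = np.abs(entityVectors["Beijing"] + relationVectors["capital_of"] - entityVectors["China"]).sum()
print(d)
```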
