Repost

Reference

This run was done locally on Windows 10 under Python 3.6; only the parameter settings of the run functions need modifying, i.e. changing one or two lines.

Contents

Code structure analysis

Structure

Datasets

PPI dataset details (example_data)

toy-ppi-G.json: graph data

toy-ppi-class_map.json

toy-ppi-id_map.json

toy-ppi-walks.txt

toy-ppi-feats.npy

Environment requirements

Setting up the environment

Running

Running unsupervised_train.py

Running supervised_train.py

Code analysis

__init__.py

utils.py

neigh_samplers.py

models.py

layers.py

minibatch.py

aggregators.py

prediction.py

supervised_train.py

unsupervised_train.py

inits.py

citation_eval.py

ppi_eval.py

reddit_eval.py

References



Code structure analysis

Structure

1. File layout
├── eval_scripts    // evaluation scripts
├── example_data    // the PPI dataset
└── graphsage       // model definitions, GCN layer definitions, ...

2. eval_scripts // evaluation scripts
├── citation_eval.py
├── ppi_eval.py
└── reddit_eval.py

3. example_data // the PPI dataset
├── toy-ppi-class_map.json  // maps graph node ids to classes
├── toy-ppi-feats.npy       // pretrained node features
├── toy-ppi-G.json          // the graph itself
├── toy-ppi-walks.txt       // random walks from each node to its neighbors, 198 per node
└── toy-ppi-id_map.json     // one-to-one mapping between node ids and consecutive indices

4. graphsage // model definitions
├── __init__        // module imports
├── aggregators     // aggregator function definitions
├── inits.py        // shared initialization helpers
├── layers          // GCN layer definitions
├── metrics         // computation of evaluation metrics
├── minibatch       // minibatch iterator definitions
├── models          // the various model architectures
├── neigh_samplers  // samplers that draw from a node's neighbors
├── prediction      //
├── supervised_models
├── supervised_train
├── unsupervised_train
└── utils           // utility function definitions

Datasets

Dataset    #Graphs  #Nodes   #Edges      #Features  #Labels (y)
Cora       1        2708     5429        1433       7
Citeseer   1        3327     4732        3703       6
Pubmed     1        19717    44338       500        3
PPI        24       56944    818716      50         121
Reddit     1        232965   11606919    602        41
Nell       1        65755    266144      61278      210

PPI dataset details (example_data)

toy-ppi-G.json: graph data

The data contains a single graph, used for a node-classification task.
The graph is undirected and consists of a nodes set and a links set; each set is a list, and every node or link inside it is stored as a dict.

Data format:

{
  "directed": false,
  "graph": {"name": "disjoint_union( , )"},
  "nodes": [
    {
      "test": false,
      "id": 0,
      "features": [ ... ],
      "val": false,
      "label": [ ... ]
    },
    { ... }
    ...
  ],
  "links": [
    {
      "test_removed": false,
      "train_removed": false,
      "target": 800,   // target node id (by default edges point from the smaller id to the larger)
      "source": 0      // source node id, listed in order starting from node 0
    },
    { ... }
    ...
  ]
}
  • name: disjoint_union( , ) is the name of the graph.
  • toy-ppi-G.json holds a single graph, presumably because node classification needs only one graph; a graph-classification task would need several.
  • As the format shows, the graph is undirected and consists of a nodes set and a links set; each set is a list, and every node or link inside it is stored as a dict.
  • The source downloaded from GitHub appears to be missing the links data. It is actually there; the file is just too large to display in full (for example, the nodes display only up to id 1883, out of 14754 in total). A minimal loading sketch follows.
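A small sketch of loading the graph (assuming the example_data path used above, and the pinned networkx 1.11): rebuild the networkx graph from the node-link JSON the same way graphsage/utils.py does.

import json
from networkx.readwrite import json_graph

# Load the node-link JSON and rebuild the graph (this mirrors utils.load_data).
G_data = json.load(open("./example_data/toy-ppi-G.json"))
G = json_graph.node_link_graph(G_data)
print(len(G.nodes()), len(G.edges()))   # one graph: its node and edge counts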

toy-ppi-class_map.json

Maps graph node ids to classes. Format: {"0": [1, 0, 0, ...], ..., "14754": [1, 1, 0, 0, ...]}

toy-ppi-id_map.json

One-to-one mapping between node ids and consecutive indices. Format: {"0": 0, "1": 1, ..., "14754": 14754}
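A short sketch of reading these two maps. The int conversion is needed because JSON keys are always strings; the same conversion happens in utils.load_data.

import json

id_map = json.load(open("./example_data/toy-ppi-id_map.json"))
id_map = {int(k): int(v) for k, v in id_map.items()}      # "0": 0  ->  0: 0
class_map = json.load(open("./example_data/toy-ppi-class_map.json"))
class_map = {int(k): v for k, v in class_map.items()}     # v is the multi-hot label list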

toy-ppi-walks.txt

0    708
0   3163
0   276
0   1789
...
1   15
1   1455
1   1327
1   317
1   63
1   1420
...
9715    7369
9715    8983
9715    6983
  • Random walks from each node to its neighbors, 198 walks per node (so repeated pairs are possible).
  • For example, "0 708" means a walk starting at node 0 reached node 708.
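A sketch of reading the walk file, one whitespace-separated "source target" pair per line:

# Collect the random-walk co-occurrence pairs.
walks = []
with open("./example_data/toy-ppi-walks.txt") as fp:
    for line in fp:
        walks.append(tuple(map(int, line.split())))
print(walks[0])   # e.g. (0, 708): a walk starting at node 0 visited node 708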

toy-ppi-feats.npy

Pretrained node features.

The data is handled mainly through two functions:
(1) np.save("test.npy", data)  -- saves the data
(2) data = np.load("test.npy") -- loads the data
For example, saving a list:

z = [[[1, 2, 3], ['w']], [[1, 2, 3], ['w']]]
np.save('test.npy', z)
x = np.load('test.npy')
x
-> array([[list([1, 2, 3]), list(['w'])],
          [list([1, 2, 3]), list(['w'])]], dtype=object)

For example, saving a dict:

x
-> {0: 'wpy', 1: 'scg'}
np.save('test.npy',x)
x = np.load('test.npy')
x
->array({0: 'wpy', 1: 'scg'}, dtype=object)

When the data was saved as a dict, after loading it you must first call
data.item()
to convert the numpy.ndarray object back into a dict.
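Putting the two steps together, a minimal round trip for a dict (this works as-is under the pinned numpy 1.14.5; newer numpy versions additionally require np.load(..., allow_pickle=True)):

import numpy as np

d = {0: 'wpy', 1: 'scg'}
np.save('test.npy', d)        # the dict is stored as a 0-d object array
data = np.load('test.npy')    # numpy.ndarray with dtype=object
d2 = data.item()              # recover the original dict
print(d2 == d)                # True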

Environment requirements

  • networkx must be version <= 1.11: pip install networkx==1.11
  • Follow the rest of the environment spec strictly as well, or you invite unnecessary trouble.
  • Python version 3.6
absl-py==0.2.2
astor==0.6.2
backports.weakref==1.0.post1
bleach==1.5.0
decorator==4.3.0
enum34==1.1.6
funcsigs==1.0.2
futures==3.2.0
gast==0.2.0
grpcio==1.12.1
html5lib==0.9999999
Markdown==2.6.11
mock==2.0.0
networkx==1.11
numpy==1.14.5
pbr==4.0.4
protobuf==3.6.0
scikit-learn==0.19.1
scipy==1.1.0
six==1.11.0
sklearn==0.0
tensorboard==1.8.0
tensorflow==1.8.0
termcolor==1.1.0
Werkzeug==0.14.1

配置环境

python最开始选的3.6
conda activate GraphSAGE-master
conda install tensorflow==1.8.0
pip list
conda install networkx==1.11
pip list
conda install scikit-learn==0.19.1
pip list
#发现上面的py3.6版本安装不了enum34==1.1.6和futures==3.2.0,但是发现运行代码的时候不影响#发现上面的py3.6版本安装不了enum34==1.1.6和futures==3.2.0,于是重新创建环境py2.7(未成功)
conda create -n py27 python=2.7
conda activate py27
conda install enum34==1.1.6
conda install futures==3.2.0
conda install tensorflow==1.8.0#貌似不行,需要离线安装,因为win不支持py2.7安装tf了,但是没安好
conda install networkx==1.11
conda install scikit-learn==0.19.1

【第一次的环境】

【第二次的环境,未成功,就差tf1.8的离线安装】

Running

When selecting the interpreter in PyCharm, some packages may turn out to be incompatible; simply install them when PyCharm prompts you.

[The packages above still would not install; set aside for now, and running the code later was not affected.]

Running unsupervised_train.py

# run from cmd
python -m graphsage.unsupervised_train --train_prefix ./example_data/toy-ppi --model graphsage_mean --max_total_steps 1000 --validate_iter 10
which is equivalent to
# run from PyCharm
python ./graphsage/unsupervised_train.py --train_prefix ./example_data/toy-ppi --model graphsage_mean --max_total_steps 1000 --validate_iter 10

Note that the dataset path above differs from the official one. When running inside PyCharm, the values of the train_prefix and model flags need changing (mind the different argument formats between the IDE and the command line). In the IDE, change them to:

# ./ is the current directory, ../ is the parent directory
flags.DEFINE_string('model', 'graphsage_mean', 'model names. See README for possible values.')
flags.DEFINE_string('train_prefix', '../example_data/toy-ppi', 'prefix identifying training data. must be specified.')

That is, edit those flag definitions, then right-click the .py file and run it directly.

[Run output]

D:\Anaconda\envs\GraphSAGE-master\python.exe F:/code/GraphSAGE-master/graphsage/unsupervised_train.py
Loading training data..
Removed 0 nodes that lacked proper annotations due to networkx versioning issues
Loaded data.. now preprocessing..
Done loading training data..
Unexpected missing: 0
9716 train nodes
5039 test nodes
2021-01-05 12:40:09.431313: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
Epoch: 0001
Iter: 0000 train_loss= 18.78066 train_mrr= 0.23649 train_mrr_ema= 0.23649 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 1.80019
Iter: 0050 train_loss= 18.67712 train_mrr= 0.16173 train_mrr_ema= 0.21749 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.30518
Iter: 0100 train_loss= 18.41344 train_mrr= 0.17981 train_mrr_ema= 0.20753 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.29077
Iter: 0150 train_loss= 18.10207 train_mrr= 0.21065 train_mrr_ema= 0.19910 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28748
Iter: 0200 train_loss= 17.45003 train_mrr= 0.18005 train_mrr_ema= 0.19214 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.29262
Iter: 0250 train_loss= 16.71679 train_mrr= 0.21261 train_mrr_ema= 0.18919 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28971
Iter: 0300 train_loss= 16.64080 train_mrr= 0.20941 train_mrr_ema= 0.18904 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28670
Iter: 0350 train_loss= 16.33514 train_mrr= 0.18145 train_mrr_ema= 0.18745 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28372
Iter: 0400 train_loss= 15.88267 train_mrr= 0.18800 train_mrr_ema= 0.18749 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28125
Iter: 0450 train_loss= 15.74382 train_mrr= 0.18654 train_mrr_ema= 0.18716 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27997
Iter: 0500 train_loss= 15.58050 train_mrr= 0.17311 train_mrr_ema= 0.18805 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28234
Iter: 0550 train_loss= 15.37372 train_mrr= 0.19895 train_mrr_ema= 0.18720 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28504
Iter: 0600 train_loss= 15.11785 train_mrr= 0.18306 train_mrr_ema= 0.18627 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28553
Iter: 0650 train_loss= 15.04833 train_mrr= 0.17784 train_mrr_ema= 0.18642 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28460
Iter: 0700 train_loss= 14.93566 train_mrr= 0.17898 train_mrr_ema= 0.18615 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28595
Iter: 0750 train_loss= 14.94030 train_mrr= 0.16468 train_mrr_ema= 0.18470 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28714
Iter: 0800 train_loss= 14.82021 train_mrr= 0.17996 train_mrr_ema= 0.18407 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28590
Iter: 0850 train_loss= 14.75895 train_mrr= 0.20370 train_mrr_ema= 0.18402 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28603
Iter: 0900 train_loss= 14.79193 train_mrr= 0.17865 train_mrr_ema= 0.18508 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28522
Iter: 0950 train_loss= 14.68051 train_mrr= 0.18984 train_mrr_ema= 0.18638 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28426
Iter: 1000 train_loss= 14.66581 train_mrr= 0.18520 train_mrr_ema= 0.18604 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28330
Iter: 1050 train_loss= 14.64359 train_mrr= 0.18334 train_mrr_ema= 0.18624 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28250
Iter: 1100 train_loss= 14.66787 train_mrr= 0.16166 train_mrr_ema= 0.18589 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28193
Iter: 1150 train_loss= 14.65202 train_mrr= 0.19368 train_mrr_ema= 0.18547 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28159
Iter: 1200 train_loss= 14.65571 train_mrr= 0.17497 train_mrr_ema= 0.18529 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28211
Iter: 1250 train_loss= 14.63282 train_mrr= 0.18568 train_mrr_ema= 0.18499 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28408
Iter: 1300 train_loss= 14.63904 train_mrr= 0.17733 train_mrr_ema= 0.18549 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28504
Iter: 1350 train_loss= 14.62205 train_mrr= 0.17858 train_mrr_ema= 0.18492 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28581
Iter: 1400 train_loss= 14.59377 train_mrr= 0.18335 train_mrr_ema= 0.18578 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28678
Iter: 1450 train_loss= 14.61559 train_mrr= 0.19628 train_mrr_ema= 0.18639 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28726
Iter: 1500 train_loss= 14.58464 train_mrr= 0.18871 train_mrr_ema= 0.18576 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28664
Iter: 1550 train_loss= 14.61813 train_mrr= 0.17187 train_mrr_ema= 0.18629 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28593
Iter: 1600 train_loss= 14.61389 train_mrr= 0.19341 train_mrr_ema= 0.18711 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28527
Iter: 1650 train_loss= 14.61634 train_mrr= 0.19737 train_mrr_ema= 0.18766 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28483
Iter: 1700 train_loss= 14.57671 train_mrr= 0.19137 train_mrr_ema= 0.18717 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28439
Iter: 1750 train_loss= 14.55233 train_mrr= 0.20713 train_mrr_ema= 0.18679 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28405
Iter: 1800 train_loss= 14.58431 train_mrr= 0.20119 train_mrr_ema= 0.18758 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28359
Iter: 1850 train_loss= 14.59033 train_mrr= 0.18874 train_mrr_ema= 0.18673 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28320
Iter: 1900 train_loss= 14.61115 train_mrr= 0.18718 train_mrr_ema= 0.18686 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28276
Iter: 1950 train_loss= 14.59950 train_mrr= 0.17403 train_mrr_ema= 0.18849 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28236
Iter: 2000 train_loss= 14.59908 train_mrr= 0.18091 train_mrr_ema= 0.18714 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28214
Iter: 2050 train_loss= 14.58339 train_mrr= 0.19607 train_mrr_ema= 0.18727 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28182
Iter: 2100 train_loss= 14.56937 train_mrr= 0.19161 train_mrr_ema= 0.18767 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28153
Iter: 2150 train_loss= 14.61444 train_mrr= 0.19147 train_mrr_ema= 0.18828 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28113
Iter: 2200 train_loss= 14.61586 train_mrr= 0.19568 train_mrr_ema= 0.18844 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28100
Iter: 2250 train_loss= 14.58835 train_mrr= 0.17902 train_mrr_ema= 0.18864 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28069
Iter: 2300 train_loss= 14.59437 train_mrr= 0.18586 train_mrr_ema= 0.18851 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28042
Iter: 2350 train_loss= 14.58622 train_mrr= 0.19174 train_mrr_ema= 0.18775 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28034
Iter: 2400 train_loss= 14.59255 train_mrr= 0.19002 train_mrr_ema= 0.18706 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28019
Iter: 2450 train_loss= 14.61124 train_mrr= 0.19741 train_mrr_ema= 0.18872 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.28009
Iter: 2500 train_loss= 14.61114 train_mrr= 0.18167 train_mrr_ema= 0.18885 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27986
Iter: 2550 train_loss= 14.58769 train_mrr= 0.21136 train_mrr_ema= 0.18786 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27974
Iter: 2600 train_loss= 14.59534 train_mrr= 0.19214 train_mrr_ema= 0.18814 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27957
Iter: 2650 train_loss= 14.57917 train_mrr= 0.17698 train_mrr_ema= 0.18707 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27946
Iter: 2700 train_loss= 14.59813 train_mrr= 0.18994 train_mrr_ema= 0.18711 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27930
Iter: 2750 train_loss= 14.58086 train_mrr= 0.18377 train_mrr_ema= 0.18592 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27909
Iter: 2800 train_loss= 14.60047 train_mrr= 0.18927 train_mrr_ema= 0.18653 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27886
Iter: 2850 train_loss= 14.60010 train_mrr= 0.18386 train_mrr_ema= 0.18697 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27862
Iter: 2900 train_loss= 14.60427 train_mrr= 0.18447 train_mrr_ema= 0.18848 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27853
Iter: 2950 train_loss= 14.56377 train_mrr= 0.20184 train_mrr_ema= 0.18818 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27839
Iter: 3000 train_loss= 14.59801 train_mrr= 0.16331 train_mrr_ema= 0.18764 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27824
Iter: 3050 train_loss= 14.60474 train_mrr= 0.18347 train_mrr_ema= 0.18711 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27809
Iter: 3100 train_loss= 14.59281 train_mrr= 0.18725 train_mrr_ema= 0.18756 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27805
Iter: 3150 train_loss= 14.62110 train_mrr= 0.19122 train_mrr_ema= 0.18804 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27790
Iter: 3200 train_loss= 14.57584 train_mrr= 0.17738 train_mrr_ema= 0.18747 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27779
Iter: 3250 train_loss= 14.60866 train_mrr= 0.17803 train_mrr_ema= 0.18844 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27764
Iter: 3300 train_loss= 14.58529 train_mrr= 0.20240 train_mrr_ema= 0.18789 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27751
Iter: 3350 train_loss= 14.62195 train_mrr= 0.18435 train_mrr_ema= 0.18832 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27737
Iter: 3400 train_loss= 14.56922 train_mrr= 0.19166 train_mrr_ema= 0.18750 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27717
Iter: 3450 train_loss= 14.58548 train_mrr= 0.19197 train_mrr_ema= 0.18727 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27697
Iter: 3500 train_loss= 14.58611 train_mrr= 0.18371 train_mrr_ema= 0.18716 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27681
Iter: 3550 train_loss= 14.58547 train_mrr= 0.18298 train_mrr_ema= 0.18861 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27674
Iter: 3600 train_loss= 14.57893 train_mrr= 0.18505 train_mrr_ema= 0.18868 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27701
Iter: 3650 train_loss= 14.57411 train_mrr= 0.17655 train_mrr_ema= 0.18987 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27730
Iter: 3700 train_loss= 14.59241 train_mrr= 0.19698 train_mrr_ema= 0.18930 val_loss= 19.08405 val_mrr= 0.18671 val_mrr_ema= 0.18671 time= 0.27722
Optimization Finished!
Process finished with exit code 0

Running supervised_train.py

Note that the train_prefix value needs the same change: ../example_data/toy-ppi

python -m graphsage.supervised_train --train_prefix ./example_data/toy-ppi --model graphsage_mean --sigmoid
which is equivalent to
python ./graphsage/supervised_train.py --train_prefix ./example_data/toy-ppi --model graphsage_mean --sigmoid

The code changes are the same as above.

[Run output]

D:\Anaconda\envs\GraphSAGE-master\python.exe F:/code/GraphSAGE-master/graphsage/supervised_train.py
Loading training data..
Removed 0 nodes that lacked proper annotations due to networkx versioning issues
Loaded data.. now preprocessing..
Done loading training data..
WARNING:tensorflow:From F:\code\GraphSAGE-master\graphsage\supervised_models.py:118: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.
See @{tf.nn.softmax_cross_entropy_with_logits_v2}.
2021-01-05 14:24:52.211993: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX AVX2
Epoch: 0001
D:\Anaconda\envs\GraphSAGE-master\lib\site-packages\sklearn\metrics\classification.py:1135: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.'precision', 'predicted', average, warn_for)
D:\Anaconda\envs\GraphSAGE-master\lib\site-packages\sklearn\metrics\classification.py:1137: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true samples.'recall', 'true', average, warn_for)
Iter: 0000 train_loss= 160.39902 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.88171 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 1.31548
Iter: 0005 train_loss= 177.72525 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.88171 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.35738
Iter: 0010 train_loss= 168.67435 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.88171 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.26747
Iter: 0015 train_loss= 174.82602 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.88171 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.23475
Epoch: 0002
Iter: 0001 train_loss= 169.43646 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 196.30547 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.22217
Iter: 0006 train_loss= 171.03656 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 196.30547 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.21048
Iter: 0011 train_loss= 168.95322 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 196.30547 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.20330
Iter: 0016 train_loss= 164.48836 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 196.30547 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.19861
Epoch: 0003
Iter: 0002 train_loss= 170.99802 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 181.05011 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.19645
Iter: 0007 train_loss= 170.51253 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 181.05011 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.19307
Iter: 0012 train_loss= 174.38806 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 181.05011 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.19067
Iter: 0017 train_loss= 162.57272 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 181.05011 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18782
Epoch: 0004
Iter: 0003 train_loss= 170.45332 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 190.83347 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18709
Iter: 0008 train_loss= 169.09729 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 190.83347 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18537
Iter: 0013 train_loss= 166.32990 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 190.83347 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18381
Iter: 0018 train_loss= 173.10933 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 190.83347 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18272
Epoch: 0005
Iter: 0004 train_loss= 165.87482 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.74899 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18264
Iter: 0009 train_loss= 168.55566 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.74899 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18191
Iter: 0014 train_loss= 173.05153 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 191.74899 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18122
Epoch: 0006
Iter: 0000 train_loss= 168.48744 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 177.71576 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18131
Iter: 0005 train_loss= 164.95117 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 177.71576 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.18065
Iter: 0010 train_loss= 166.21835 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 177.71576 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17936
Iter: 0015 train_loss= 177.44318 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 177.71576 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17864
Epoch: 0007
Iter: 0001 train_loss= 167.81136 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 186.21609 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17919
Iter: 0006 train_loss= 174.58884 train_f1_mic= 0.00195 train_f1_mac= 0.00022 val_loss= 186.21609 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17845
Iter: 0011 train_loss= 165.81683 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 186.21609 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17768
Iter: 0016 train_loss= 171.65659 train_f1_mic= 0.00195 train_f1_mac= 0.00020 val_loss= 186.21609 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17679
Epoch: 0008
Iter: 0002 train_loss= 174.44943 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 192.19485 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17647
Iter: 0007 train_loss= 176.99825 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 192.19485 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17634
Iter: 0012 train_loss= 167.94687 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 192.19485 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17634
Iter: 0017 train_loss= 172.83304 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 192.19485 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17586
Epoch: 0009
Iter: 0003 train_loss= 176.01657 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 197.73112 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17600
Iter: 0008 train_loss= 169.39464 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 197.73112 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17550
Iter: 0013 train_loss= 168.62959 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 197.73112 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17529
Iter: 0018 train_loss= 168.67769 train_f1_mic= 0.00200 train_f1_mac= 0.00023 val_loss= 197.73112 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17495
Epoch: 0010
Iter: 0004 train_loss= 165.34845 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 185.15099 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17522
Iter: 0009 train_loss= 169.52484 train_f1_mic= 0.00195 train_f1_mac= 0.00019 val_loss= 185.15099 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17464
Iter: 0014 train_loss= 168.72287 train_f1_mic= 0.00000 train_f1_mac= 0.00000 val_loss= 185.15099 val_f1_mic= 0.00000 val_f1_mac= 0.00000 time= 0.17434
Optimization Finished!
Full validation stats: loss= 184.90506 f1_micro= 0.00055 f1_macro= 0.00005 time= 0.54853
Writing test set stats to file (don't peak!)
Process finished with exit code 0
  • graphsage_mean – GraphSage with mean-based aggregator
  • graphsage_seq – GraphSage with LSTM-based aggregator
  • graphsage_maxpool – GraphSage with max-pooling aggregator (as described in the NIPS 2017 paper)
  • graphsage_meanpool – GraphSage with mean-pooling aggregator (a variant of the pooling aggregator, where the element-wise mean replaces the element-wise max).
  • gcn – GraphSage with GCN-based aggregator
  • n2v – an implementation of DeepWalk (called n2v for short in the code.)
  • Note that unsupervised_train.py ran for only 1 epoch: 3700 iterations, validating every 50 iterations, batch_size 512.
  • supervised_train.py ran for only 10 epochs: 40 iterations, validating every 5 iterations, batch_size 512.
  • python -m graphsage.unsupervised_train runs the script as a module, so no explicit path is needed.
  • python ./graphsage/unsupervised_train.py runs it directly as a script file.
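The *_mrr_ema columns in the logs above are exponential moving averages; a sketch of the update rule (the same one that appears in the unsupervised_train.py walkthrough later in this post, with decay 0.99), using hypothetical per-iteration values:

mrr_values = [0.23649, 0.16173, 0.17981]   # hypothetical train_mrr values from successive iterations
ema = None
for mrr in mrr_values:
    # ema -= (1 - 0.99) * (ema - mrr), i.e. ema = 0.99 * ema + 0.01 * mrr
    ema = mrr if ema is None else ema - (1 - 0.99) * (ema - mrr)
print(ema)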

Note

D:\Anaconda\envs\GraphSAGE-master\lib\site-packages\sklearn\metrics\classification.py:1135: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.'precision', 'predicted', average, warn_for)
D:\Anaconda\envs\GraphSAGE-master\lib\site-packages\sklearn\metrics\classification.py:1137: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true samples.'recall', 'true', average, warn_for)
  • Cause: some labels present in y_true are never predicted in y_pred. When the predictions are missing a class that exists in the true labels, sklearn emits this warning and counts the F1 for that class as 0.
    For example:
    y_true = (0, 1, 2, 3, 4)
    y_pred = (0, 1, 1, 3, 4)
    Label '2' is never predicted, so the F-score cannot be computed for that label and is treated as 0.0.
    But since the average over all classes must still include that 0 score, scikit-learn warns you about it; it is a warning, not an error.
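The warning is easy to reproduce with that same example (f1_score is from the scikit-learn 0.19.1 pinned earlier):

from sklearn.metrics import f1_score

y_true = [0, 1, 2, 3, 4]
y_pred = [0, 1, 1, 3, 4]
# Label 2 is never predicted, so its F-score is set to 0.0 and a warning is emitted;
# the macro average still counts that 0.0.
print(f1_score(y_true, y_pred, average='macro'))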

Code analysis

__init__.py

from __future__ import print_function
'''Even on Python 2.x, print must be called with parentheses, as in Python 3.x.'''
from __future__ import division
'''Imports the future language feature division (true division):
without it, the "/" operator performs truncating division;
with it, "/" performs true division and "//" performs truncating division.'''
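A two-line demonstration of what the division import changes:

from __future__ import division
print(7 / 2)    # 3.5 -- true division
print(7 // 2)   # 3   -- truncating (floor) division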

utils.py

from __future__ import print_function

import numpy as np      # numpy
import random           # random
import json             # json
import sys              # system module
import os               # operating-system module
import networkx as nx   # networkx (graph theory): creating and manipulating graphs
from networkx.readwrite import json_graph   # for saving networkx graphs as json graphs

version_info = list(map(int, nx.__version__.split('.')))   # parse the networkx version into a list of ints
major = version_info[0]   # the number before the dot
minor = version_info[1]   # the number after the dot
assert (major <= 1) and (minor <= 11), "networkx major version > 1.11"
# networkx must be version <= 1.11, otherwise the assertion fails

WALK_LEN = 5
N_WALKS = 50
'''
type() vs isinstance():
type() does not treat a subclass as a kind of its parent class; it ignores inheritance.
isinstance() does treat a subclass as a kind of its parent class; it respects inheritance.
To test whether two types are the same, isinstance() is recommended.
Parameters:
    object    - the instance to check.
    classinfo - a direct or indirect class name, a basic type, or a tuple of these.
Returns:
    True if the object's type matches classinfo, otherwise False.
'''
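A quick demonstration of that difference:

class A(object):
    pass

class B(A):
    pass

print(type(B()) == A)        # False: type() ignores inheritance
print(isinstance(B(), A))    # True:  isinstance() respects inheritance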
'''G.nodes() returns the graph's nodes n together with their node attributes nodedata.'''
def load_data(prefix, normalize=True, load_walks=False):
    G_data = json.load(open(prefix + "-G.json"))   # the graph is a json file, so load it with the json module
    G = json_graph.node_link_graph(G_data)         # return graph from node-link data format
    # Define the conversion function: check whether G.nodes()[0] is an int (i.e., carries no nodedata)
    if isinstance(G.nodes()[0], int):
        conversion = lambda n: int(n)   # lambda parameters: expression
    else:
        conversion = lambda n: n        # leave n unchanged

    if os.path.exists(prefix + "-feats.npy"):   # a pretrained features file exists under this path
        feats = np.load(prefix + "-feats.npy")
    else:
        print("No features present.. Only identity features will be used.")
        feats = None

    # A json-stored dict that maps graph node ids to consecutive integers
    id_map = json.load(open(prefix + "-id_map.json"))   # load the one-to-one id-to-index mapping
    id_map = {conversion(k): int(v) for k, v in id_map.items()}
    walks = []
    class_map = json.load(open(prefix + "-class_map.json"))   # load the label data, a dict
    # print("class_map:", class_map)   # {"0": [1, 0, 0, ...], ..., "14754": [1, 1, 0, 0, ...]}
    if isinstance(list(class_map.values())[0], list):   # check whether the values are already lists
        lab_conversion = lambda n: n
    else:
        lab_conversion = lambda n: int(n)   # convert labels to int
    class_map = {conversion(k): lab_conversion(v) for k, v in class_map.items()}
    # Traverse all key-value pairs of the label data and rebuild the dict: in id_map's iteration
    # k is a str and v an int, so convert everything to int.
    # print("class_map:", class_map)   # {0: [1, 0, 0, ...], ..., 14754: [1, 1, 0, 0, ...]}

    ## Remove all nodes that do not have val/test annotations
    ## (necessary because of networkx weirdness with the Reddit data)
    broken_count = 0
    for node in G.nodes():
        if not 'val' in G.node[node] or not 'test' in G.node[node]:
            G.remove_node(node)
            broken_count += 1
    print("Removed {:d} nodes that lacked proper annotations due to networkx versioning issues".format(broken_count))
    # format(): replaces the traditional % with {} for formatted output

    ## Make sure the graph has edge train_removed annotations
    ## (some datasets might already have this..)
    # G.edges() yields the edge list [(u, v), (u, v), ...]; each element holds the two endpoints of
    # an edge (with data=True it would also show attributes such as weights). Iterating, edge[0] and
    # edge[1] are the two endpoints: if at least one of them has val/test set, the edge's
    # 'train_removed' is set to True, otherwise False. This guarantees 'train_removed' is never empty.
    print("Loaded data.. now preprocessing..")
    for edge in G.edges():
        if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
                G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
            G[edge[0]][edge[1]]['train_removed'] = True
        else:
            G[edge[0]][edge[1]]['train_removed'] = False

    # Get the training features and standardize them: select the nodes whose val and test are both
    # False as training data, look up their rows in the feature matrix via id_map, and collect the
    # indices in train_ids; train_feats then holds those nodes' features.
    if normalize and not feats is None:   # feats is non-empty
        from sklearn.preprocessing import StandardScaler
        train_ids = np.array([id_map[n] for n in G.nodes()
                              if not G.node[n]['val'] and not G.node[n]['test']])
        train_feats = feats[train_ids]   # features of the training nodes
        scaler = StandardScaler()
        scaler.fit(train_feats)          # compute the mean and variance of the training data
        feats = scaler.transform(feats)  # transform the features with that mean and variance:
        # every feature dimension gets zero mean and unit variance, so dimensions with large raw
        # values cannot dominate the predictions
    if load_walks:   # False by default
        with open(prefix + "-walks.txt") as fp:
            for line in fp:
                # the excerpt above breaks off here; in the GraphSAGE repo the loop body is:
                walks.append(map(conversion, line.split()))

    return G, feats, id_map, walks, class_map
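A usage sketch (assuming the repository layout above; load_walks=True additionally reads toy-ppi-walks.txt):

# Load the toy PPI data exactly as the training scripts do.
G, feats, id_map, walks, class_map = load_data("./example_data/toy-ppi", load_walks=True)
print(feats.shape)   # one row of standardized features per node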

neigh_samplers.py

models.py

layers.py

minibatch.py

aggregators.py

prediction.py

supervised_train.py

unsupervised_train.py

# _*_ coding: UTF-8 _*_
# supervised_train.py trains on the node-classification labels as its loss; it cannot output
# node embeddings, and it uses NodeMinibatchIterator.
# unsupervised_train.py trains on node-neighbor adjacency information as its loss; after training
# it can output node embeddings, and it uses EdgeMinibatchIterator.
from __future__ import division
from __future__ import print_function
# Even on Python 2.x, print must be called with parentheses, as in Python 3.x.
# With the division feature imported, "/" is true division and "//" is truncating division;
# without it, "/" truncates on integers.

import os            # operating-system module
import time          # timing module
import tensorflow as tf
import numpy as np

from graphsage.models import SampleAndAggregate, SAGEInfo, Node2VecModel
from graphsage.minibatch import EdgeMinibatchIterator
from graphsage.neigh_samplers import UniformNeighborSampler
from graphsage.utils import load_data
# If the server has several GPUs, TensorFlow uses all of them by default;
# CUDA_VISIBLE_DEVICES restricts which ones are visible.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # enumerate GPUs in PCI bus order, starting from 0

# Set random seed: with the same seed, the same random numbers are generated on every run,
# even across devices; without it, every run produces different numbers.
seed = 123
np.random.seed(seed)
tf.set_random_seed(seed)

# Settings
flags = tf.app.flags
FLAGS = flags.FLAGS  # the parser object FLAGS, so parameters can come in from the command line,
                     # e.g. python train.py --model gcn
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")  # a boolean flag
# core params.. (flags defined here can be overridden on the command line)
flags.DEFINE_string('model', 'graphsage', 'model names. See README for possible values.')  # model name
flags.DEFINE_float('learning_rate', 0.00001, 'initial learning rate.')
flags.DEFINE_string("model_size", "small", "Can be big or small; model specific def'ns")
flags.DEFINE_string('train_prefix', '', 'name of the object file that stores the training data. must be specified.')

# left to default values in main experiments
flags.DEFINE_integer('epochs', 1, 'number of epochs to train.')  # number of training epochs
flags.DEFINE_float('dropout', 0.0, 'dropout rate (1 - keep probability).')  # dropout randomly disables a fraction of neurons to avoid overfitting
# the loss adds weight decay as regularization: self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.')  # weight decay shrinks the weights to reduce overfitting
flags.DEFINE_integer('max_degree', 100, 'maximum node degree.')
flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')  # layer-1 sample size: k = 1, s = 25
flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2')  # layer-2 sample size: k = 2, s = 10
# with the concat operation, the output dimension doubles
flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')
flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges')
flags.DEFINE_integer('neg_sample_size', 20, 'number of negative samples')  # negative-sample count
flags.DEFINE_integer('batch_size', 512, 'minibatch size.')
flags.DEFINE_integer('n2v_test_epochs', 1, 'Number of new SGD epochs for n2v.')  # n2v SGD epochs
flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.')

# logging, saving, validation settings etc.
flags.DEFINE_boolean('save_embeddings', True, 'whether to save embeddings for all nodes after training')
flags.DEFINE_string('base_log_dir', '.', 'base directory for logging and saving embeddings')
flags.DEFINE_integer('validate_iter', 5000, "how often to run a validation minibatch.")
flags.DEFINE_integer('validate_batch_size', 256, "how many nodes per validation sample.")
flags.DEFINE_integer('gpu', 1, "which gpu to use.")  # with a single GPU, change 1 to 0
flags.DEFINE_integer('print_every', 50, "How often to print training info.")
flags.DEFINE_integer('max_total_steps', 10**10, "Maximum total number of iterations")

os.environ["CUDA_VISIBLE_DEVICES"] = str(FLAGS.gpu)  # which GPU to use; with one GPU, set the gpu flag to 0
GPU_MEM_FRACTION = 0.8  # fraction of GPU memory the process may take

def log_dir():
    # build the directory where logs and embeddings are saved
    log_dir = FLAGS.base_log_dir + "/unsup-" + FLAGS.train_prefix.split("/")[-2]
    log_dir += "/{model:s}_{model_size:s}_{lr:0.6f}/".format(
            model=FLAGS.model,
            model_size=FLAGS.model_size,
            lr=FLAGS.learning_rate)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)   # create it if it does not exist
    return log_dir
# Define model evaluation function
def evaluate(sess, model, minibatch_iter, size=None):
    t_test = time.time()
    feed_dict_val = minibatch_iter.val_feed_dict(size)  # minibatch gradient descent (faster than SGD/BGD); values are passed in via feed_dict
    outs_val = sess.run([model.loss, model.ranks, model.mrr], feed_dict=feed_dict_val)
    return outs_val[0], outs_val[1], outs_val[2], (time.time() - t_test)  # loss, ranks, mrr, elapsed time

def incremental_evaluate(sess, model, minibatch_iter, size):
    # evaluate over the whole validation set, one minibatch at a time
    t_test = time.time()
    finished = False
    val_losses = []
    val_mrrs = []
    iter_num = 0
    while not finished:
        feed_dict_val, finished, _ = minibatch_iter.incremental_val_feed_dict(size, iter_num)
        iter_num += 1
        outs_val = sess.run([model.loss, model.ranks, model.mrr], feed_dict=feed_dict_val)
        val_losses.append(outs_val[0])
        val_mrrs.append(outs_val[2])
    return np.mean(val_losses), np.mean(val_mrrs), (time.time() - t_test)

def save_val_embeddings(sess, model, minibatch_iter, size, out_dir, mod=""):
    # save the validation-set embeddings
    val_embeddings = []
    finished = False
    seen = set([])
    nodes = []
    iter_num = 0
    name = "val"
    while not finished:
        feed_dict_val, finished, edges = minibatch_iter.incremental_embed_feed_dict(size, iter_num)
        iter_num += 1
        outs_val = sess.run([model.loss, model.mrr, model.outputs1], feed_dict=feed_dict_val)
        # ONLY SAVE FOR embeds1 because of planetoid
        for i, edge in enumerate(edges):
            if not edge[0] in seen:
                val_embeddings.append(outs_val[-1][i, :])
                nodes.append(edge[0])
                seen.add(edge[0])
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    val_embeddings = np.vstack(val_embeddings)  # stack the embeddings vertically
    np.save(out_dir + name + mod + ".npy", val_embeddings)
    with open(out_dir + name + mod + ".txt", "w") as fp:
        fp.write("\n".join(map(str, nodes)))  # write out the node list, one id per line

def construct_placeholders():
    # Define placeholders (TF feed points)
    placeholders = {
        'batch1': tf.placeholder(tf.int32, shape=(None), name='batch1'),
        'batch2': tf.placeholder(tf.int32, shape=(None), name='batch2'),
        # negative samples for all nodes in the batch
        'neg_samples': tf.placeholder(tf.int32, shape=(None,), name='neg_sample_size'),
        'dropout': tf.placeholder_with_default(0., shape=(), name='dropout'),
        'batch_size': tf.placeholder(tf.int32, name='batch_size'),
    }
    return placeholders

def train(train_data, test_data=None):
    G = train_data[0]          # the graph
    features = train_data[1]   # the training data's features
    id_map = train_data[2]     # "n" : n node-id mapping; nodes lacking 'val'/'test' attributes were already removed
    if not features is None:
        # vstack appends one zero row to features, used for the b in WX + b
        features = np.vstack([features, np.zeros((features.shape[1],))])
    context_pairs = train_data[3] if FLAGS.random_context else None   # random-walk node pairs
    placeholders = construct_placeholders()
    # construct_placeholders() defines: batch1, batch2, neg_samples, dropout, batch_size
    minibatch = EdgeMinibatchIterator(G,
            id_map,
            placeholders, batch_size=FLAGS.batch_size,
            max_degree=FLAGS.max_degree,
            num_neg_samples=FLAGS.neg_sample_size,
            context_pairs=context_pairs)
    adj_info_ph = tf.placeholder(tf.int32, shape=minibatch.adj.shape)
    adj_info = tf.Variable(adj_info_ph, trainable=False, name="adj_info")

    if FLAGS.model == 'graphsage_mean':
        # Create model
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'gcn':
        # Create model (GCN aggregator; 2x dims because there is no concat)
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, 2*FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, 2*FLAGS.dim_2)]
        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="gcn",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   concat=False,
                                   logging=True)
    elif FLAGS.model == 'graphsage_seq':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   identity_dim=FLAGS.identity_dim,
                                   aggregator_type="seq",
                                   model_size=FLAGS.model_size,
                                   logging=True)
    elif FLAGS.model == 'graphsage_maxpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="maxpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'graphsage_meanpool':
        sampler = UniformNeighborSampler(adj_info)
        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
        model = SampleAndAggregate(placeholders,
                                   features,
                                   adj_info,
                                   minibatch.deg,
                                   layer_infos=layer_infos,
                                   aggregator_type="meanpool",
                                   model_size=FLAGS.model_size,
                                   identity_dim=FLAGS.identity_dim,
                                   logging=True)
    elif FLAGS.model == 'n2v':
        model = Node2VecModel(placeholders, features.shape[0],
                              minibatch.deg,
                              # 2x because graphsage uses concat
                              nodevec_dim=2*FLAGS.dim_1,
                              lr=FLAGS.learning_rate)
    else:
        raise Exception('Error: model name unrecognized.')

    config = tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)
    config.gpu_options.allow_growth = True
    # with allow_growth, TF starts with a small GPU allocation and grows it on demand
    #config.gpu_options.per_process_gpu_memory_fraction = GPU_MEM_FRACTION
    # per_process_gpu_memory_fraction sets how much of each GPU the process may take (0.4 = 40%)
    config.allow_soft_placement = True
    # automatic device selection: "with tf.device('/cpu:0'):" pins an op to a device, but if that
    # device does not exist or is unusable, the program hangs or fails;
    # allow_soft_placement=True lets TF fall back to an existing, usable device.

    # Initialize session
    sess = tf.Session(config=config)
    merged = tf.summary.merge_all()
    # merge_all collects every summary so the training process and parameter distributions can be
    # written to disk and shown in TensorBoard.
    # tf.summary.FileWriter(path, sess.graph) designates a file to store the graph;
    # its add_summary() method appends training-process data to that file.
    summary_writer = tf.summary.FileWriter(log_dir(), sess.graph)

    # Init variables
    sess.run(tf.global_variables_initializer(), feed_dict={adj_info_ph: minibatch.adj})

    # Train model
    train_shadow_mrr = None
    shadow_mrr = None
    total_steps = 0
    avg_time = 0.0
    epoch_val_costs = []
    train_adj_info = tf.assign(adj_info, minibatch.adj)
    val_adj_info = tf.assign(adj_info, minibatch.test_adj)
    for epoch in range(FLAGS.epochs):
        minibatch.shuffle()
        iter = 0
        print('Epoch: %04d' % (epoch + 1))
        epoch_val_costs.append(0)
        while not minibatch.end():
            # Construct feed dictionary
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})

            t = time.time()
            # Training step
            outs = sess.run([merged, model.opt_op, model.loss, model.ranks,
                             model.aff_all, model.mrr, model.outputs1],
                            feed_dict=feed_dict)
            train_cost = outs[2]
            train_mrr = outs[5]
            if train_shadow_mrr is None:
                train_shadow_mrr = train_mrr
            else:
                train_shadow_mrr -= (1-0.99) * (train_shadow_mrr - train_mrr)

            if iter % FLAGS.validate_iter == 0:
                # Validation
                sess.run(val_adj_info.op)
                val_cost, ranks, val_mrr, duration = evaluate(sess, model, minibatch, size=FLAGS.validate_batch_size)
                sess.run(train_adj_info.op)
                epoch_val_costs[-1] += val_cost
            if shadow_mrr is None:
                shadow_mrr = val_mrr
            else:
                shadow_mrr -= (1-0.99) * (shadow_mrr - val_mrr)

            if total_steps % FLAGS.print_every == 0:
                summary_writer.add_summary(outs[0], total_steps)

            # Print results
            avg_time = (avg_time * total_steps + time.time() - t) / (total_steps + 1)
            if total_steps % FLAGS.print_every == 0:
                print("Iter:", '%04d' % iter,
                      "train_loss=", "{:.5f}".format(train_cost),
                      "train_mrr=", "{:.5f}".format(train_mrr),
                      "train_mrr_ema=", "{:.5f}".format(train_shadow_mrr),  # exponential moving average
                      "val_loss=", "{:.5f}".format(val_cost),
                      "val_mrr=", "{:.5f}".format(val_mrr),
                      "val_mrr_ema=", "{:.5f}".format(shadow_mrr),  # exponential moving average
                      "time=", "{:.5f}".format(avg_time))

            iter += 1
            total_steps += 1
            if total_steps > FLAGS.max_total_steps:
                break
        if total_steps > FLAGS.max_total_steps:
            break
    print("Optimization Finished!")

    if FLAGS.save_embeddings:
        # whether to store the node embeddings after training
        sess.run(val_adj_info.op)
        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir())

    if FLAGS.model == "n2v":
        # stopping the gradient for the already trained nodes
        train_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                 if not G.node[n]['val'] and not G.node[n]['test']],
                                dtype=tf.int32)
        test_ids = tf.constant([[id_map[n]] for n in G.nodes_iter()
                                if G.node[n]['val'] or G.node[n]['test']],
                               dtype=tf.int32)
        update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(test_ids))
        no_update_nodes = tf.nn.embedding_lookup(model.context_embeds, tf.squeeze(train_ids))
        update_nodes = tf.scatter_nd(test_ids, update_nodes, tf.shape(model.context_embeds))
        no_update_nodes = tf.stop_gradient(tf.scatter_nd(train_ids, no_update_nodes, tf.shape(model.context_embeds)))
        model.context_embeds = update_nodes + no_update_nodes
        sess.run(model.context_embeds)

        # run random walks
        from graphsage.utils import run_random_walks
        nodes = [n for n in G.nodes_iter() if G.node[n]["val"] or G.node[n]["test"]]
        start_time = time.time()
        pairs = run_random_walks(G, nodes, num_walks=50)
        walk_time = time.time() - start_time

        test_minibatch = EdgeMinibatchIterator(G,
                id_map,
                placeholders, batch_size=FLAGS.batch_size,
                max_degree=FLAGS.max_degree,
                num_neg_samples=FLAGS.neg_sample_size,
                context_pairs=pairs,
                n2v_retrain=True,
                fixed_n2v=True)

        start_time = time.time()
        print("Doing test training for n2v.")
        test_steps = 0
        for epoch in range(FLAGS.n2v_test_epochs):
            test_minibatch.shuffle()
            while not test_minibatch.end():
                feed_dict = test_minibatch.next_minibatch_feed_dict()
                feed_dict.update({placeholders['dropout']: FLAGS.dropout})
                outs = sess.run([model.opt_op, model.loss, model.ranks,
                                 model.aff_all, model.mrr, model.outputs1],
                                feed_dict=feed_dict)
                if test_steps % FLAGS.print_every == 0:
                    print("Iter:", '%04d' % test_steps,
                          "train_loss=", "{:.5f}".format(outs[1]),
                          "train_mrr=", "{:.5f}".format(outs[-2]))
                test_steps += 1
        train_time = time.time() - start_time
        save_val_embeddings(sess, model, minibatch, FLAGS.validate_batch_size, log_dir(), mod="-test")
        print("Total time: ", train_time + walk_time)
        print("Walk time: ", walk_time)
        print("Train time: ", train_time)

# main: load the data and train
def main(argv=None):
    print("Loading training data..")
    train_data = load_data(FLAGS.train_prefix, load_walks=True)
    # load_data is defined in graphsage.utils and loads the labeled dataset
    print("Done loading training data..")
    train(train_data)   # train(train_data, test_data=None), defined above

if __name__ == '__main__':
    tf.app.run()   # parse the command-line flags, then call main(sys.argv)
What tf.app.run() does: it parses the flags and then executes the main function.

If the entry function in your code is not called main() but something else, say test(), write the entry point as tf.app.run(test()); if it is called main(), tf.app.run() on its own is enough. Since FLAGS = tf.app.flags.FLAGS already appears above, the input has already been set up for parsing: inside tf.app.run(), argv=None, so args = argv[1:] if argv else None leaves args=None (i.e., unspecified; the command line is then parsed automatically). f = flags.FLAGS constructs the parser f used to parse args; f._parse_flags(args) parses the args list or the command-line input (if the args list is empty, the command line is parsed), and the returned flags_passthrough holds the list of values that could not be parsed, excluding the file name.
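A minimal, self-contained sketch of that flag-parsing flow (TF 1.x API; the flag name and values are hypothetical, for illustration only):

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string('model', 'graphsage', 'model name')   # hypothetical flag
FLAGS = flags.FLAGS

def main(argv=None):
    print(FLAGS.model)   # prints "gcn" when run as: python demo.py --model gcn

if __name__ == '__main__':
    tf.app.run()   # parses the command-line flags, then calls main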

inits.py

citation_eval.py

ppi_eval.py

reddit_eval.py

References

[Source] https://github.com/williamleif/GraphSAGE

[Analysis] https://www.cnblogs.com/shiyublog/tag/graphsage/
