废话不多说先放代码。
本文的代码需要两个部分组成——自定义的GCN层的GCN_layer和训练代码train。
首先是自定义的GCN层的GCN_layer.py:

import tensorflow as tf
import numpy as npclass my_GCN_layer(tf.keras.layers.Layer):def __init__(self, fea_dim, out_dim):super().__init__()self.fea_dim = fea_dimself.out_dim = out_dimdef build(self, input_shape):features_dim = self.fea_dimself.wei = self.add_variable(name='wei', shape=[features_dim, self.out_dim], initializer=tf.zeros_initializer())def call(self, inputs, support):#inputs = np.array(inputs, dtype=float)#support = np.array(support, dtype=float)inputs = tf.cast(inputs, dtype=tf.float32)support = tf.cast(support, dtype=tf.float32)H_t = tf.matmul(support, inputs)output = tf.matmul(H_t, self.wei)return tf.sigmoid(output)

本身代码并不复杂,简单熟悉tensorflow语法便可以轻松阅读。就是tensorflow2.0中自定义层的典型流程——注意需要先继承keras的layers类。在构造函数中规定一下矩阵的shape,而后需要定义build函数创建层的权重,也就是训练参数啦。最后的call函数则是定义了自定义层的前向传播过程。
其中的张量support是输入的图数据的邻接矩阵的归一化矩阵,在GCN的公式推导中有提到,这里就不再赘述,推荐一个简洁的推导博客:https://www.cnblogs.com/denny402/p/10917820.html。
同时上面的inputs是图数据的特征矩阵

接下来是train.py:

import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from GCN_layer import my_GCN_layer as Connum_epoch = 10000
learning_rate = 1e-3
path1 = 'cora.content'
path2 = 'cora.cites'class PCAtest():def __init__(self, k):self.k = kdef get_res(self, data):mean_vector = np.mean(data, axis=0)standdata = data - mean_vectorcov_mat = np.cov(standdata, rowvar=0)fValue, fVector = np.linalg.eig(cov_mat)fValueSort = np.argsort(-fValue)fValueTopN = fValueSort[:self.k]vectorMat = fVector[:, fValueTopN]return np.dot(standdata, vectorMat)class data_loader:def __init__(self, path1, path2):self.path1 = path1self.path2 = path2self.raw_data = pd.read_csv(self.path1, sep='\t', header=None)num = self.raw_data.shape[0]a = list(self.raw_data.index)b = list(self.raw_data[0])c = zip(b, a)map = dict(c)self.features = self.raw_data.iloc[:, 1:-1]self.labels = pd.get_dummies(self.raw_data[1434])raw_data_cites = pd.read_csv(self.path2, sep='\t', header=None)self.matrix = np.zeros((num, num))for i, j in zip(raw_data_cites[0], raw_data_cites[1]):x = map[i]; y = map[j]self.matrix[x][y] = self.matrix[y][x] = 1tem_d = np.sum(self.matrix, axis=1)self.d = np.zeros((num, num))for i in range(len(tem_d)):self.d[i][i] = 1.0/np.sqrt(tem_d[i])self.support = tf.matmul(tf.matmul(self.d, self.matrix), self.d)def get_data(self):return np.array(self.features), np.array(self.labels), np.array(self.support)def get_batch(self, batch_size):for i in range(batch_size):class GCN(tf.keras.Model):def __init__(self, node_dim, fea_dim, out_dim):super().__init__()self.node_dim = node_dimself.fea_dim = fea_dimself.out_dim = out_dimself.con1 = Con(self.fea_dim, 2048)self.con2 = Con(2048, self.out_dim)#self.con3 = Con(4096, self.out_dim)self.fla1 = tf.keras.layers.Flatten(input_shape=[self.node_dim, self.out_dim])self.den1 = tf.keras.layers.Dense(self.type_num, activation='softmax')def call(self, inputs, support):hidden1 = self.con1(inputs, support)#hidden2 = self.con2(hidden1, support)unflattened = self.con2(hidden1, support)undensed = self.fla1(unflattened)output = self.den1(undensed)return outputdataloader = data_loader(path1, path2)
X, Y, support = dataloader.get_data()
model = GCN(X.shape[-1], Y.shape[-1])
check_point = tf.train.Checkpoint(myAwesomeModel=model)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
for i in range(num_epoch):with tf.GradientTape() as tape:y_pred = model(X, support)loss = tf.keras.losses.categorical_crossentropy(Y, y_pred)loss = tf.reduce_mean(loss)#loss = tf.nn.l2_loss(Y - y_pred)#loss = tf.reduce_mean(loss)print("epoch:", i+1, " || loss:", loss)grads = tape.gradient(loss, model.variables)optimizer.apply_gradients(grads_and_vars=zip(grads, model.variables))if((i+1)%50 == 0):path = check_point.save('./save/GCN_model.ckpt')print('Saving in:', path)if((i+1) == 10000):pca = PCAtest(2)res = pca.get_res(y_pred)for point in res:plt.plot(point[0], point[1], color='g', marker='o', ls='None')plt.title('test')plt.show()

这部分的代码也并非十分复杂,基本上遵循了tensorflow2.0框架使用keras api进行网络构建的流程,即在GCN中根据之前写的GCN_layer来搭建多层的网络结构。故这部分不再赘述。但需要注意跑代码时需要修改path1与path2的值,修改为自己本地的cora数据集路径便可。

上面代码中较为复杂的还是data_loader类的构建,关系到整个数据集的合理的输入方式以及归一化的邻接矩阵(即上面所提到的张量support)的计算,故下面还是要针对这个类进行一定的讨论的。

首先,cora数据集的python读取方法是参考了大奸猫大佬的博客(https://blog.csdn.net/yeziand01/article/details/93374216)对于本文的工作帮助很大,在此感谢。而将cora数据集全部读取的方法在这个博客里面已经讲得很清楚了,在此就不再赘述。重点讲解一下邻接矩阵的归一化过程,也就是这部分代码:

     tem_d = np.sum(self.matrix, axis=1)self.d = np.zeros((num, num))for i in range(len(tem_d)):self.d[i][i] = 1.0/np.sqrt(tem_d[i])self.support = tf.matmul(tf.matmul(self.d, self.matrix), self.d)

根据GCN的公式推导:

也就是中间的是需要我们进行计算的support。
那么回到代码,tem_d就是通过邻接矩阵计算了每个节点的度,当然包括它自己,而后通过将一个全零的矩阵的对角全部填入对应节点度的-1/2次方之后,就完成了D^(-1/2)的计算(还不熟悉md的公式编写,还望见谅),而A在上面已经计算出来了(self.matrix),故接下来矩阵乘法两次便可求出support。

那么到此代码的解析便结束了,笔者也是刚刚开始学习GCN和tensorflow2.0的相关知识,故代码极为粗糙简陋,还望见谅。

Tensorflow2.0的简单GCN代码(使用cora数据集)相关推荐

  1. 在tensorflow2.0环境下使用RandLA-Net训练S3DIS数据集

    之前的文章介绍了在tensorflow2.0环境下使用RandLA-Net训练Semantic3D数据集,这里我们记录一下如何在在tensorflow2.0环境下使用RandLA-Net训练S3DIS ...

  2. 推荐TensorFlow2.0的样例代码下载

    TensorFlow推出2.0版本后,TF2.0相比于1.x版本默认使用Keras.Eager Execution.支持跨平台.简化了API等.这次更新使得TF2.0更加的接近PyTorch,一系列烦 ...

  3. Tensorflow2.0:Faster RCNN 代码详解(一)

    第一部分给出Fater RCNN文件的代码解析,主要是模型主体的执行过程,在此文件 引入下述几个文件的函数引用,对于backbones,necks和test_mixins文件来说,主要是用来构建模型结 ...

  4. internetreadfile读取数据长度为0_YOLOV3的TensorFlow2.0实现,支持在自己的数据集上训练...

    GitHub链接: calmisential/YOLOv3_TensorFlow2​github.com 我主要参考了yolov3的一个keras实现版本: qqwweee/keras-yolo3​g ...

  5. python血压测量程序代码_利用TensorFlow2.0为胆固醇、血脂、血压数据构建时序深度学习模型(python源代码)...

    背景数据描述 胆固醇.高血脂.高血压是压在广大中年男性头上的三座大山,如何有效的监控他们,做到早发现.早预防.早治疗尤为关键,趁着这个假期我就利用TF2.0构建了一套时序预测模型,一来是可以帮我预发疾 ...

  6. ​TensorFlow2.0系列教程集合版(附PDF下载)

    文章来源于机器学习算法与Python实战,作者奥辰 TensorFlow2.0(1):基本数据结构--张量 TensorFlow2.0(2):数学运算 TensorFlow2.0(3):张量排序.最大 ...

  7. Transformers2.0让你三行代码调用语言模型,兼容TF2.0和PyTorch

    Transformers2.0让你三行代码调用语言模型,兼容TF2.0和PyTorch 能够灵活地调用各种语言模型,一直是 NLP 研究者的期待.近日 HuggingFace 公司开源了最新的 Tra ...

  8. Tensorflow2.0安装(win10系统cpu版本)

    前提条件:已经安装好Anaconda,本文重点不详细介绍,安装Anaconda见这篇文章: [Anaconda教程01]怎么安装Anaconda3 - 知乎 默认安装完Anaconda 步骤一:启动A ...

  9. Tensorflow2.0实现对抗生成网络(GAN)

    在这篇文章中,我们使用Tensorflow2.0来实现GAN,使用的数据集是手写数字数据集. 引入需要的库 import tensorflow as tf from tensorflow import ...

最新文章

  1. 设置IDEA自动导入import 关联的包
  2. 【Android 插件化】Hook 插件化框架 ( 从 Hook 应用角度分析 Activity 启动流程 一 | Activity 进程相关源码 )
  3. hadoop程序开发--- Java
  4. php轻量级的性能分析工具xhprof的安装使用
  5. 【Java6】Date类/Calendar类,System类/Math类,包装类,集合,泛型,内部类
  6. Cocoa的MVC架构分析 cocoa的mvc实现
  7. 《多元统计分析》学习笔记之聚类分析
  8. java gettime_Java Util.getTime方法代码示例
  9. python目前版本强势英雄_王者荣耀目前版本什么英雄强势?
  10. (附源码gitHub下载地址)spring boot -jta-atomikos分布式事务
  11. python爬虫绕过验证码_爬虫怎样绕过验证码?
  12. 数学系鄙视物理系的经典桥段,全部看懂了算我输!
  13. 飞鸽传书为我们提供了方便的聊天工具
  14. 记一次使用Dapper 进行的数据迁移和清洗工作
  15. java在线学习系统源码_java学习成长之路(基础,源码,项目,实战)
  16. 57. web 攻击技术
  17. Windows 95 输入法编辑器
  18. Javascript 设置Cookie
  19. 景区门票怎么在线上渠道分销?
  20. 什么是通配符SSL证书?

热门文章

  1. Python evel函数
  2. “问题事件名称:BEX 故障模块名称:StackHash_9fba”的解决办法
  3. Java基础查漏补缺(个人向)
  4. 【Allegro_SPB_16.6安装详细教程】手把手搭建到Win10
  5. 理解Schnorr签名算法
  6. system verilog断言学习笔记
  7. matlab绘制散点拟合图
  8. 大数据主要所学技术(简介)
  9. android电视自动关机,android实现自动关机的具体方法代码
  10. java拆分excel_apache poi拆分excel表格