1 理论部分

交通预测论文翻译:Deep Learning on Traffic Prediction: Methods,Analysis and Future Directions_UQI-LIUWJ的博客-CSDN博客-4.1.2.1.1 ChebNet

2  类写法

CLASSChebConv(in_channels: int, out_channels: int, K: int, normalization: Optional[str] = 'sym', bias: bool = True, **kwargs)

3 参数说明

in_channels (int)  输入样本的通道数
out_channels (int)

输出样本的通道数

(在Cheb的源码中,每一阶切比雪夫多项式 进行卷积之后,都会再过一个FC,这个就是给每一阶的切比雪夫多项式卷积 修改维度、调整权重用的)

K (int) 几阶切比雪夫多项式近似
normalization (stroptional)

图拉普拉斯矩阵的归一化方法:默认是sym

None 没有归一化       
"sym" 对称归一化        
"rw" 随机游走归一化   

需要将lambda_max参数提供给forward()方法,以防normalization是不对称的

lambda_max 需要时一个[batch_size]维度的Tensor

可以使用torch_geometric.transforms.LaplacianLambdaMax 方法事先计算lambda_max

bias

默认是True ,如果是False,那么这个ChebNet就不会有偏移量

4 forward 函数

forward(x,edge_index, edge_weight: Optional[torch.Tensor] = None, batch: Optional[torch.Tensor] = None, lambda_max: Optional[torch.Tensor] = None)

注:这里的batch是指torch_geometric笔记:数据集 ENZYMES &Minibatches_UQI-LIUWJ的博客-CSDN博客 第2小节中说的batch

5 源码

这里处理得很高妙,它相当于把正则化拉普拉斯矩阵作为新图的邻接矩阵

from typing import Optional
from torch_geometric.typing import OptTensorimport torch
from torch.nn import Parameterfrom torch_geometric.nn.inits import zeros
from torch_geometric.utils import get_laplacian
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loopsclass ChebConv(MessagePassing):def __init__(self, in_channels: int, out_channels: int, K: int,normalization: Optional[str] = 'sym', bias: bool = True,**kwargs):kwargs.setdefault('aggr', 'add')super(ChebConv, self).__init__(**kwargs)#设置聚合方式(add,也就是将各层切比雪夫多项式近似求和)assert K > 0assert normalization in [None, 'sym', 'rw'], 'Invalid normalization'#两个断言,切比雪夫多项式近似的阶数大于0;在这三种normalization里面选择self.in_channels = in_channelsself.out_channels = out_channelsself.normalization = normalizationself.lins = torch.nn.ModuleList([Linear(in_channels, out_channels, bias=False,weight_initializer='glorot') for _ in range(K)])#各层切比雪夫多项式近似之后接的维度转换全连接层if bias:self.bias = Parameter(torch.Tensor(out_channels))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):#初始化参数for lin in self.lins:lin.reset_parameters()zeros(self.bias)def __norm__(self, edge_index, num_nodes: Optional[int],edge_weight: OptTensor, normalization: Optional[str],lambda_max, dtype: Optional[int] = None,batch: OptTensor = None):#这里处理得很高妙,它相当于把正则化拉普拉斯矩阵作为新图的邻接矩阵edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)#去掉自环edge_index, edge_weight = get_laplacian(edge_index, edge_weight,normalization, dtype,num_nodes)#计算拉普拉斯矩阵if batch is not None and lambda_max.numel() > 1:lambda_max = lambda_max[batch[edge_index[0]]]edge_weight = (2.0 * edge_weight) / lambda_maxedge_weight.masked_fill_(edge_weight == float('inf'), 0)#图中所有原来边权重非零的边,权重全部乘以2/lambda_maxedge_index, edge_weight = add_self_loops(edge_index, edge_weight,fill_value=-1.,num_nodes=num_nodes)#由于归一化拉普拉斯矩阵还需要-I,所以所有的自环权重减一assert edge_weight is not Nonereturn edge_index, edge_weight#返回以拉普拉斯矩阵为邻接矩阵的“新图”def forward(self, x, edge_index, edge_weight: OptTensor = None,batch: OptTensor = None, lambda_max: OptTensor = None):""""""if self.normalization != 'sym' and lambda_max is None:raise ValueError('You need to pass `lambda_max` to `forward() in`''case the normalization is non-symmetric.')if lambda_max is None:lambda_max = torch.tensor(2.0, dtype=x.dtype, device=x.device)if not isinstance(lambda_max, torch.Tensor):lambda_max = torch.tensor(lambda_max, dtype=x.dtype,device=x.device)assert lambda_max is not Noneedge_index, norm = self.__norm__(edge_index, x.size(self.node_dim),edge_weight, self.normalization,lambda_max, dtype=x.dtype,batch=batch)#得到以拉普拉斯矩阵为邻接矩阵的“新图”Tx_0 = x#Z_1=Xout = self.lins[0](Tx_0)# propagate_type: (x: Tensor, norm: Tensor)if len(self.lins) > 1:Tx_1 = self.propagate(edge_index, x=x, norm=norm, size=None)#每一轮的propagate相当于对每个点,计算所有邻边的拉普拉斯矩阵权重*临近点,再求和【aggr=add】out = out + self.lins[1](Tx_1)#Z_2=LXfor lin in self.lins[2:]:Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm, size=None)#Tx_2=Z_k=L*Z_k-1Tx_2 = 2. * Tx_2 - Tx_0#Z_k=2*L*k-1-Z_k-2out = out + lin.forward(Tx_2)Tx_0, Tx_1 = Tx_1, Tx_2if self.bias is not None:out += self.biasreturn outdef message(self, x_j, norm):return norm.view(-1, 1) * x_j#就是对应的邻边权重*邻接点def __repr__(self):return '{}({}, {}, K={}, normalization={})'.format(self.__class__.__name__, self.in_channels, self.out_channels,len(self.lins), self.normalization)

6 举例

from torch_geometric.nn import ChebConvdata
#Batch(x=[9893, 1], edge_index=[2, 34637], y=[9893, 1], batch=[9893], ptr=[2])conv1 = ChebConv(1, 32, 2)x = conv1(data.x, data.edge_index)type(x)
#torch.Tensorx.shape
#torch.Size([9893, 32]) 每个点的维度是[9893,32]

torch_geometric 笔记:nn.ChebNet相关推荐

  1. torch_geometric笔记:nn. graclus (图点分类)

    torch_geometric.nn.graclus(edge_index, weight: Optional[torch.Tensor] = None, num_nodes: Optional[in ...

  2. torch_geometric 笔记:TORCH_GEOMETRIC.UTILS(更新中)

    1 torch_geometric.utils.add_self_loops add_self_loops(edge_index, edge_weight: Optional[torch.Tensor ...

  3. torch_geometric笔记:max_pool 与max_pool_x

    1 max_pool 1.1 函数介绍 torch_geometric.nn.max_pool(cluster, data,transform=None) 对由torch_geometricy .da ...

  4. torch_geometric 笔记: 数据集Cora 简易 GNN

    1 获取数据集 该数据集用于semi-supervised的节点分类任务 from torch_geometric.datasets import Planetoiddataset = Planeto ...

  5. torch_geometric笔记:数据集 ENZYMES Minibatches

    Pytorch Geometric中包含大量的常见基准数据集.在初始化数据集的时候,框架会自动下载数据集的原始文件,并将其处理为Data对象.例如要下载ENZYMES数据集(由600个graph划分为 ...

  6. Hadoop学习笔记-NN与2NN

    几个概念 NameNode被格式化之后,将在/opt/module/hadoop-3.1.3/data/tmp/dfs/name/current目录中产生如下文件: fsimage_000000000 ...

  7. torch_geometric 笔记:global_mean_pool

    对全图的点嵌入(node embedding)进行池化操作,返回一个图嵌入(graph embedding) global_mean_pool(x, batch, size=None)[ x (tor ...

  8. pytorch 学习笔记目录

    1 部分内容 pytorch笔记 pytorch模型中的parameter与buffer_刘文巾的博客-CSDN博客 pytorch学习笔记 torchnn.ModuleList_刘文巾的博客-CSD ...

  9. Pytorch使用笔记

    Pytorch使用笔记 nn.Module CNN torch.nn.Conv1d torch.nn.Conv2d torch.nn.ConvTranspose1d RNN LSTM self.reg ...

最新文章

  1. html右侧浮动栏随着滚动,jQuery实现div浮动层跟随页面滚动效果
  2. pfSense 2.4.4-RELEASE现已发布!
  3. APDU命令的结构和处理【转】
  4. 教育谋定应用型高校 经济和信息化研究共建成都工业学院
  5. 赶集网MySQL开发36军规
  6. java转换汇编,请问如何把JAVA程序转为汇编?
  7. 程序员相比于黑客(Hacker),差距有多远?看看程序员怎么说!
  8. 计算机组成原理 第六章 总线
  9. python读取xml编码gb2312_【转】python XML 操作总结(创建、保存和删除,支持utf-8和gb2312)...
  10. cmake使用教(一)多目录下多个文件的构建
  11. 搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(四)
  12. phonegap文件上传(java_php),Android应用开发之使用PhoneGap实现位置上报功能
  13. Eprime 倒计时代码
  14. FeedingBottle3.2的下载网站
  15. 基于asp.net338医院体检信息管理系统
  16. 医学信息学计算机技术,2017年医学信息学专业大学排名
  17. highCharts x轴过长
  18. 【Leetcode】打气球的最大分数 (暴力递归+动态规划)
  19. Unity中一些小技巧
  20. R语言分析财收与税收的线性回归关系

热门文章

  1. python进阶资源整理
  2. 用OleDb写的一个导出Excel的方法
  3. 802.11协议中的广播与tcp/ip中的广播
  4. javaweb学习总结(五)——Servlet开发(一)
  5. 算法提高课-图论-单源最短路的建图方式-AcWing 903. 昂贵的聘礼:建图巧妙、dijkstra、考虑等级
  6. 2020年高等数学方法与提高(上海理工大学)学习笔记:多元函数积分学
  7. html正则表达式确认密码,如何使用正则表达式在流星中验证确认密码
  8. 浪潮linux网卡驱动,浪潮NF5280M5安装redhat7.2下网卡驱动
  9. php的封装继承多态,PHP面向对象深入理解之二(封装、继承、多态、克隆)
  10. java.text.dateformat_使用java.text.SimpleDateFormat类进行文本日期和Date日期的转换