消息传递的图神经网络_图神经网络 信息传递 公式-程序员宅基地

一、消息传递范式介绍

消息传递范式是一种聚合邻接节点信息来更新中心节点信息的范式,它将卷积算子推广到不规则数据领域,实现了图与神经网络的连接。此范式包含三个步骤:(1)邻接节点信息变换;(2)邻接节点信息聚合到中心节点;(3)聚合信息变换。

消息传递图神经网络可以描述为:
x i ( k ) = γ ( k ) ( x i ( k − 1 ) , □ j ∈ N ( i )   ϕ ( k ) ( x i ( k − 1 ) , x j ( k − 1 ) , e j , i ) ) , \mathbf{x}_i^{(k)} = \gamma^{(k)} \left( \mathbf{x}_i^{(k-1)}, \square_{j \in \mathcal{N}(i)} \, \phi^{(k)}\left(\mathbf{x}_i^{(k-1)}, \mathbf{x}_j^{(k-1)},\mathbf{e}_{j,i}\right) \right), xi(k)=γ(k)(xi(k1),jN(i)ϕ(k)(xi(k1),xj(k1),ej,i)),
x i ( k − 1 ) ∈ R F \mathbf{x}^{(k-1)}_i\in\mathbb{R}^F xi(k1)RF表示(k-1)层中节点i的节点特征, e j , i ∈ R D \mathbf{e}_{j,i} \in \mathbb{R}^D ej,iRD表示从节点j到节点i的边的特征, □ \square 表示可微分的、具有排列不变形的函数,具有排列不变形的函数有和函数、均值函数和最大值函数。 γ \gamma γ ϕ \phi ϕ表示可微分的函数。

二、Pytorch Geometric中的MessagePassing基类

Pytorch Geometric提供了MessagePassing类,实现了消息传播的自动处理,继承该基类可以方便地构造消息传递图神经网络,我们只需要定义函数 ϕ \phi ϕ(即message函数)和函数 γ \gamma γ(即update函数),以及消息聚合方案(aggr=“add”、aggr="mean"或aggr=“max”)。

  • MessagePassing(aggr=“add”, flow=“source_to_target”, node_dim=-2):
    aggr: 定义要使用的聚合方案(“add”、“mean"或"max”)
    flow: 定义消息传递的流向(“source_to_target"或"target_to_source”)
    node_dim: 定义沿着哪个轴线传播

  • MessagePassing.propagate(edge_index, size=None, **kwargs):
    开始传播消息的起始调用。它以edge_index(边的端点的索引)和flow(消息的流向)以及一些额外的数据为参数。
    size=(N,M)设置对称邻接矩阵的形状。

  • MessagePassing.message(…)接受最初传递给propagate函数的所有参数。

  • MessagePassing.aggregate(…)将从源节点传递过来的消息聚合在目标节点上,一般可选的聚合方式有sum,mean和max。

  • MessagePassing.message_and_aggregate(…)融合了邻接节点信息变换和邻接节点信息聚合。

  • MessagePassing.update(aggr_out, …)为每个节点更新节点表征,即实现 γ \gamma γ函数。该函数以聚合函数的输出为第一参数,并接收所有传递给propagate函数的参数。

三、继承MessagePassing类的GCNConv

GCNConv的数学定义为:
x i ( k ) = ∑ j ∈ N ( i ) ∪ { i } 1 deg ⁡ ( i ) ⋅ deg ⁡ ( j ) ⋅ ( Θ ⋅ x j ( k − 1 ) ) , \mathbf{x}_i^{(k)} = \sum_{j \in \mathcal{N}(i) \cup \{ i \}} \frac{1}{\sqrt{\deg(i)} \cdot \sqrt{\deg(j)}} \cdot \left( \mathbf{\Theta} \cdot \mathbf{x}_j^{(k-1)} \right), xi(k)=jN(i){ i}deg(i) deg(j) 1(Θxj(k1)),
其中相邻节点的特征通过权重矩阵 Θ \mathbf{\Theta} Θ进行转换,然后按端点的度进行归一化处理,最后进行加总。这个公式可以分为以下几个步骤:

  1. 向邻接矩阵添加自环边。
  2. 线性转换节点特征矩阵。
  3. 计算归一化系数。
  4. 归一化j中的节点特征。
  5. 将相邻节点特征相加。
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        # "Add" aggregation (Step 5).
        # flow='source_to_target' 表示消息从源节点传播到目标节点
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]
        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j

# 初始化和调用
conv = GCNConv(16, 32)
x = conv(x, edge_index)

四、复写message函数

class GCNConv(MessagePassing):
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)
    def message(self, x_j, norm, d_i):
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j * d_i

五、覆写aggregate函数

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)

    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

六、覆写aggregate函数

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        return self.propagate(edge_index, x=x, norm=norm, d=d)

    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)

七、覆写message_and_aggregate函数

from torch_sparse import SparseTensor

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')
        
    def forward(self, x, edge_index):
        # ....
        adjmat = SparseTensor(row=edge_index[0], col=edge_index[1], value=torch.ones(edge_index.shape[1]))
        # 此处传的不再是edge_idex,而是SparseTensor类型的Adjancency Matrix
        return self.propagate(adjmat, x=x, norm=norm, d=d)
    
    def message(self, x_j, norm, d_i):
        # x_j has shape [E, out_channels]
        return norm.view(-1, 1) * x_j * d_i # 这里不管正确性
    
    def aggregate(self, inputs, index, ptr, dim_size):
        print(self.aggr)
        print("`aggregate` is called")
        return super().aggregate(inputs, index, ptr=ptr, dim_size=dim_size)
    
    def message_and_aggregate(self, adj_t, x, norm):
        print('`message_and_aggregate` is called')

八、覆写update函数

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add', flow='source_to_target')

    def update(self, inputs: Tensor) -> Tensor:
        return inputs
版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/Lazybones_3/article/details/118061456

智能推荐

python字符串类型定义_python--字符串类型-程序员宅基地

文章浏览阅读1.2k次。*************** 字符串类型 ***************1.字符串的定义:第一种方式:str1 = 'our company is westos'第二种方式:str2 = "our company is westos"第三种方式:str3 = """our company is westos"""2.转义符号一个反斜线加一个单一字符可以表示一个特殊字符,通常是不可打印的字符\n:..._python字符串类型定义

计算机图形学和工程图学,计算机图形学与印刷工程我与工程图学及计算机图形学...-程序员宅基地

文章浏览阅读418次。一、我与浙大工程图学一同发展      我国工程图学课程量大面广,从新中国成立以来主要承袭于前苏联,经过50多年的建设、改革和发展,课程体系形成了较明显的中国特色,既不同于欧美,也不同于前苏联。可以分为如下四个发展阶段:第一阶段***1949-1962年***称为初期阶段,1957年教育部邀请苏联专家在清华大学举办了画法几何及机械制图教学研究进修班,在总结交流基础上,初步形成了中国模式的工程制图教..._印刷工程 计算机相关课程

(转)C语言家族扩展_c语言家族分支-程序员宅基地

文章浏览阅读3.1k次。(转)C语言家族扩展 翻译:5.1--5.6林峰5.7--5.20董溥5.21--5.26王聪5.27--5.34刘洋5.35--5.43贾孟树致谢:感谢陈老师指出其中的一些错误,已修订。 修订记录: 修订一些用词和标点符号。(董溥)2007年1月12号修订一些用词和错别字。(王聪)2006年12月14号修正一些错误的标签。(王聪,董溥)2006年12月13号GNU C提供了多种在ISO标准C中没有的特性。(‘-pedantic’选项会使GC_c语言家族分支

elementui css,elementUI——主题及自定义-程序员宅基地

文章浏览阅读1.3k次。说明:本文基于[email protected],源码详见element。一、主题相关代码结构:可以看出两点:a. 每个element ui组件基本都对应有单独的scss文件;b. 单独的组件scss文件,支持了组件按需引入时,样式部分也能按需引入的诉求;c.theme-chalk/src/common和theme-chalk/src/mixins目录下,主要是一些公共样式的设置,全局sass变量..._elementui css

H5 如何实现唤起 APP_如何实现拼多多h5直接跳转淘宝或者拼多多-程序员宅基地

文章浏览阅读1.5k次。写过hybrid的同学,想必都会遇到这样的需求,如果用户安装了自己的APP,就打开APP或跳转到APP内某个页面,如果没安装则引导用户到对应页面或应用商店下载。这里就涉及到了H5与Native之间的交互,为什么H5能够唤起APP并且跳转到对应的页面?就算你没写过想必也体验过,最常见的就是一些广告了,如果你点击了广告,他判断你手机装了对应APP,那他就会去打开那个APP,如果没安装,他会帮你跳转到应用商店去下载,这个还算人性化一点的,有些直接后台给你去下载,你完全无感知。_如何实现拼多多h5直接跳转淘宝或者拼多多

《计算机网络技术》教材分析,《计算机网络技术基础》教学计划-程序员宅基地

文章浏览阅读324次。《《计算机网络技术基础》教学计划》由会员分享,可在线阅读,更多相关《《计算机网络技术基础》教学计划(4页珍藏版)》请在人人文库网上搜索。1、临 湘 市 职 业 中 专教师工作计划及实施情况表姓 名 周 小 敏 任教科目 计算机网络技术基础 任教年级 高 二 班 次 14级计算机、苹果班 教研组长 李 岳 军 教务主任 冯 云 主管校长 李 晓 红 时 间 2016 年 上 学期本期教学的主要任务和..._计算机网络教材分析

随便推点

php mysql安装图解_MySQL安装教程图解-程序员宅基地

文章浏览阅读245次。下面的是MySQL安装的图解,用的可执行文件安装的,详细说明了一下!打开下载的mysql安装文件mysql-5.0.27-win32.zip,双击解压缩,运行setup.exe,出现如下界面 mysql安装图文教程1 mysql安装向导启动,按Next继续 mysql图文安装教程2 选择安装类型,有Typica下面的是MySQL安装的图解,用的可执行文件安装的,详细说明了一下!打开下载的mysql..._php-mysql安装

python支持oracle的驱动_python 支持oracle数据库-程序员宅基地

文章浏览阅读339次。大年初一mark一下新的一年将会下线个人ORACLE外文blog,精力有限,会全力在阿里云云栖社区分享,主要内容还是数据库相关,包括但不限于以下内容:ORACLE数据库性能分析PostgreSQL数据库全栈支持PPAS数据库全栈支持专注ORACLE数据库和应用迁移至阿里云PPAS、PostgreSQL等数据...文章唐修2019-02-051169浏览量PostgreSQL修炼之道:从小工到专家...._python官方镜像里面没有oracle驱动

mysql导入报错1071_导入sql文件报错:1071 Specified key was too long; max key length is 767 bytes...-程序员宅基地

文章浏览阅读644次。一、背景今天把服务器的数据库导出了一份sql文件,准备导入到本地,但是在导入的时候,报了个错:Syntax error or access violation: 1071 Specified key was too long; max key length is 767 bytes这就很奇怪了,明明服务器上都可以,凭什么我这边就报错呢。二、错误分析1、错误部分的sql文件CREATE TABLE ..._mysql导入报错[err] cannot create table [商品成本]: 1071 - specified key was

小米路由r2d论坛_小米路由器R2D固件 V2.24.10 官方稳定版-程序员宅基地

文章浏览阅读3k次。小米路由器R2D固件是小米官方为其推出的路由器刷机更新固件,优化了Wifi能力,同时修复了系统中的一些小问题,提升了系统的稳定性和安全性,如果你的路由器出现了一些小问题,可以通过这个固件来更新下系统。【更新说明】1.优化了共享WiFi功能,提升了使用体验2.修复了一些小问题,提高了系统安全性及稳定性【刷机教程】1、准备一个系统格式为FAT或FAT32的U盘;重要的事情再说三遍:U盘刷机会清空硬盘上..._小米路由r2d论坛

OpenWRT使用SNMP监测网络状态_openwrt snmp-程序员宅基地

文章浏览阅读1.9w次。最近在写毕业论文,需要监测路由器的网络状态,路由器是TP-Link TL1043ND v2,操作系统版本是OpenWRT 15.05,本来打算在路由器和监测机上使用Socket通信来交互信息的,写着写着发现太麻烦了,因为路由器有许多个,就需要在监测机上要实现多线程之类的东西,后来发现OpenWRT上已经有编译好的SNMP包了,于是就直接用它了,然后通过配置snmpd.config文件来扩展,调用iw_openwrt snmp

python中concat的用法_pandas中concat()的用法-程序员宅基地

文章浏览阅读5.4k次。pandas.concat()通常用来连接DataFrame对象。默认情况下是对两个DataFrame对象进行纵向连接, 当然通过设置参数,也可以通过它实现DataFrame对象的横向连接。让我们通过几个例子来看看concat()的用法。1. 纵向连接DataFrame对象(1)两个DataFrame对象的列完全相同# 初始化两个DataFrame对象df1 = pd.DataFrame([['a..._python中concat函数