GCN-tensorflow2.0代码实现_keras gcn tensorflow2-程序员宅基地

技术标签: tensorflow  深度学习  

文章目录

代码

定义图卷积层

import tensorflow as tf
from tensorflow.keras import activations, regularizers, constraints, initializers

class GCNConv(tf.keras.layers.Layer):
    def __init__(self,
                 units,
                 activation=lambda x: x,
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 **kwargs):
        super(GCNConv, self).__init__()

        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        self.kernel_initializer = initializers.get(kernel_initializer)
        self.bias_initializer = initializers.get(bias_initializer)


    def build(self, input_shape):
        """ GCN has two inputs : [shape(An), shape(X)]
        """
        fdim = input_shape[1][1]  # feature dim
        # 初始化权重矩阵
        self.weight = self.add_weight(name="weight",
                                      shape=(fdim, self.units),
                                      initializer=self.kernel_initializer,
                                      trainable=True)
        if self.use_bias:
            # 初始化偏置项
            self.bias = self.add_weight(name="bias",
                                        shape=(self.units, ),
                                        initializer=self.bias_initializer,
                                        trainable=True)

    def call(self, inputs):
        """ GCN has two inputs : [An, X]
        """
        self.An = inputs[0]
        self.X = inputs[1]
        # 计算 XW
        if isinstance(self.X, tf.SparseTensor):
            h = tf.sparse.sparse_dense_matmul(self.X, self.weight)
        else:
            h = tf.matmul(self.X, self.weight)
        # 计算 AXW
        output = tf.sparse.sparse_dense_matmul(self.An, h)

        if self.use_bias:
            output = tf.nn.bias_add(output, self.bias)

        if self.activation:
            output = self.activation(output)

        return output

定义 GCN 模型

class GCN():
    def __init__(self, An, X, sizes, **kwargs):
        self.with_relu = True
        self.with_bias = True

        self.lr = FLAGS.learning_rate
        self.dropout = FLAGS.dropout
        self.verbose = FLAGS.verbose
        
        self.An = An
        self.X = X
        self.layer_sizes = sizes
        self.shape = An.shape

        self.An_tf = sp_matrix_to_sp_tensor(self.An)
        self.X_tf = sp_matrix_to_sp_tensor(self.X)

        self.layer1 = GCNConv(self.layer_sizes[0], activation='relu')
        self.layer2 = GCNConv(self.layer_sizes[1])
        self.opt = tf.optimizers.Adam(learning_rate=self.lr)

    def train(self, idx_train, labels_train, idx_val, labels_val):
        K = labels_train.max() + 1
        train_losses = []
        val_losses = []
        # use adam to optimize
        for it in range(FLAGS.epochs):
            tic = time()
            with tf.GradientTape() as tape:
                _loss = self.loss_fn(idx_train, np.eye(K)[labels_train])

            # optimize over weights
            grad_list = tape.gradient(_loss, self.var_list)
            grads_and_vars = zip(grad_list, self.var_list)
            self.opt.apply_gradients(grads_and_vars)

            # evaluate on the training
            train_loss, train_acc = self.evaluate(idx_train, labels_train, training=True)
            train_losses.append(train_loss)
            val_loss, val_acc = self.evaluate(idx_val, labels_val, training=False)
            val_losses.append(val_loss)
            toc = time()
            if self.verbose:
                print("iter:{:03d}".format(it),
                      "train_loss:{:.4f}".format(train_loss),
                      "train_acc:{:.4f}".format(train_acc),
                      "val_loss:{:.4f}".format(val_loss),
                      "val_acc:{:.4f}".format(val_acc),
                      "time:{:.4f}".format(toc - tic))
        return train_losses

    def loss_fn(self, idx, labels, training=True):
        if training:
            # .nnz 是获得X中元素的个数
            _X = sparse_dropout(self.X_tf, self.dropout, [self.X.nnz])
        else:
            _X = self.X_tf

        self.h1 = self.layer1([self.An_tf, _X])
        if training:
            _h1 = tf.nn.dropout(self.h1, self.dropout)
        else:
            _h1 = self.h1

        self.h2 = self.layer2([self.An_tf, _h1])
        self.var_list = self.layer1.weights + self.layer2.weights
        # calculate the loss base on idx and labels
        _logits = tf.gather(self.h2, idx)
        _loss_per_node = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                                 logits=_logits)
        _loss = tf.reduce_mean(_loss_per_node)
        # 加上 l2 正则化项
        _loss += FLAGS.weight_decay * sum(map(tf.nn.l2_loss, self.layer1.weights))
        return _loss

    def evaluate(self, idx, true_labels, training):
        K = true_labels.max() + 1
        _loss = self.loss_fn(idx, np.eye(K)[true_labels], training=training).numpy()
        _pred_logits = tf.gather(self.h2, idx)
        _pred_labels = tf.argmax(_pred_logits, axis=1).numpy()
        _acc = accuracy_score(_pred_labels, true_labels)
        return _loss, _acc

训练模型

# 计算标准化的邻接矩阵:根号D * A * 根号D
def preprocess_graph(adj):
    # _A = A + I
    _adj = adj + sp.eye(adj.shape[0])
    # _dseq:各个节点的度构成的列表
    _dseq = _adj.sum(1).A1
    # 构造开根号的度矩阵
    _D_half = sp.diags(np.power(_dseq, -0.5))
    # 计算标准化的邻接矩阵, @ 表示矩阵乘法
    adj_normalized = _D_half @ _adj @ _D_half
    return adj_normalized.tocsr()

if __name__ == "__main__":
    # 读取数据
    # A_mat:邻接矩阵,以scipy的csr形式存储
    # X_mat:特征矩阵,以scipy的csr形式存储
    # z_vec:label
    # train_idx,val_idx,test_idx: 要使用的节点序号
    A_mat, X_mat, z_vec, train_idx, val_idx, test_idx = load_data_planetoid(FLAGS.dataset)
    # 邻居矩阵标准化
    An_mat = preprocess_graph(A_mat)

    # 节点的类别个数
    K = z_vec.max() + 1

    # 构造GCN模型
    gcn = GCN(An_mat, X_mat, [FLAGS.hidden1, K])
    # 训练
    gcn.train(train_idx, z_vec[train_idx], val_idx, z_vec[val_idx])
    # 测试
    test_res = gcn.evaluate(test_idx, z_vec[test_idx], training=False)
    print("Dataset {}".format(FLAGS.dataset),
          "Test loss {:.4f}".format(test_res[0]),
          "test acc {:.4f}".format(test_res[1]))

详细介绍请看这篇博客https://blog.csdn.net/VariableX/article/details/109820684

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/in546/article/details/120513954

智能推荐

ImportError: No module named six,已解决-程序员宅基地

文章浏览阅读2.9k次。问题如题目所示。在解决问题之前,我们先来看看six这么模块是什么。six: 一个专门用来兼容 Python 2 和 Python 3 的库。它解决了诸如 urllib 的部分方法不兼容, str 和 bytes 类型不兼容等问题。问题:我是在调用pandas时出现的这个问题。在搜索解决办法时,我发现很多人在导入Matlibplot等其他库时也遇到了同样的问题,出错页面如下:解决办法:下载six..._importerror: no module named six

我的博客今天2岁104天了,我领取了…-程序员宅基地

文章浏览阅读45次。我的博客今天2岁104天了,我领取了徽章.2011.06.09,我在新浪博客安家。2011.06.09,我写下了第一篇博文:《看懂这些故事 你做人就很成功了》。2011.06.09,我上传了第一张图片到相册。至今,我的博客共获得174次访问。这些年,新浪博客伴我点点滴滴谱写生活! ...

C/C++ 找出最大素数 算法_c++找最大素数-程序员宅基地

文章浏览阅读5.6k次,点赞2次,收藏7次。24.【中学】找出最大素数小明在中学学习了什么是素数。素数是指一个只能被1和它本身整除的数,在数论中占有重要的研究地位,在当代密码学中也被广泛应用。输入:取值范围输出:该范围内的最大素数#include <stdio.h> #include <stdlib.h> int main(int argc, char *argv[]) { int input = 0, answer = 0; scanf("%d", &input); _c++找最大素数

【最新】IDEA 2021.3 使用_janf_config.txt-程序员宅基地

文章浏览阅读8.1k次。IDEA 2021.3之前的版本都可以采用重试脚本的方式来使用IDEA,最近IDEA官方发布了新款2021.3的版本,原作者不在维护升级重置脚本,遂提供了新型使用方法,如下下载JAR包地址:https://github.com/pengzhile/ja-netfilter/releases编写配置文件指定你自己的目录创建此文件:janf_config.txt,最佳实践为把刚刚的Jar包和这个文件放置在同一目录下,就不用二次指定这个文件了# jb 的 janf_config.txt 配置文件[D_janf_config.txt

在Windows10\11的WSL2中使用图形化应用Chrome和搜狗拼音输入法_win11 wsl2 图形化-程序员宅基地

文章浏览阅读1.4k次,点赞24次,收藏18次。在Windows的WSL2中使用图形化应用Chrome,并配置中文界面和输入法。_win11 wsl2 图形化

python绘制Loss和Acc曲线+读取txt和log文件_python 提取txt中包含loss的行的数据-程序员宅基地

文章浏览阅读5.5k次,点赞4次,收藏40次。log文件如下所示:代码def read_log(filename): fp=open(filename) for line in fp.readlines(): train_loss=line[-27:-18] train_acc=line[-8:] with open('test.txt','a') as fp2: fp2.write(train_loss+train_acc) fp.close()_python 提取txt中包含loss的行的数据

随便推点

Android取消RecyclerView、ListView、ScrollView、HorizontalScrollView滑动到边缘闪现灰白色水波纹动画_android recyclerview去除下拉时波纹-程序员宅基地

文章浏览阅读6.1k次。Android取消RecyclerView、ListView、ScrollView、HorizontalScrollView滑动到边缘闪现灰白色水波纹动画标准的Android RecyclerView、ListView、ScrollView、HorizontalScrollView滑动到边缘,会闪现灰白色水波纹动画,以这样大的动画效果提示用户已经滑动到边缘,没法再滑动了。对于这种增强体验是一个很好..._android recyclerview去除下拉时波纹

javaFx新建弹窗页面并传值_javafx窗口传值-程序员宅基地

文章浏览阅读2.6k次。由于之前图省事在弹窗Controller类中用static定义变量接受原始页面传值导致被sonar校验,特地研究了一下javaFx向弹窗传值的方式。 方式有两种 1、直接传一个controler实例过来,后面弹窗页面如果有用到的话直接可以从controller中获取。 2、传需要的属性到工具类的setControllerFactory中,直接赋值给新建的窗口controller。 我的创建窗口工具类如下,关键在于lorder.setControl..._javafx窗口传值

Linux+libusb开发用户USB无驱通讯_libusb 无驱-程序员宅基地

文章浏览阅读1.1k次。项目上需要将一个自己开发的设备通过USB接口 连接到Linux系统或Android系统的设备,然后通过发送命令来控制我们的设备操作。要求做到“即插即用”,不需要再安装驱动,于是想到用libusb库来做。 在网上搜索了一些关于libusb的使用方法,写篇文章记录下开发过程,主要解决3个问题:① libusb是什么?② libusb有什么用?③ libusb怎么用?1. li_libusb 无驱

cmake:LINK : error LNK2001: 无法解析的外部符号 WinMainCRTStartup_cmake rtk编译报错无法解析的外部符号-程序员宅基地

文章浏览阅读5.8k次。正在设计的一个C/C++混合语言项目是用cmake来管理编译的,用cmake生成的一个Visual Studio工程(c++)在编译时报了个错: LINK : error LNK2001: 无法解析的外部符号 WinMainCRTStartup好是莫名其妙的问题,之前是没有这个问题的,反复查看了GIT提交记录,发现问题出在cmake脚本中 原本项目的定义是这样的,语言指定C,CXX..._cmake rtk编译报错无法解析的外部符号

Win10 Cortana 搜索框字体颜色_win10搜索框字体变成绿色-程序员宅基地

文章浏览阅读3k次。不知什么原因,cortana搜索框的字体颜色突然变成白色,搜索框背景也是白色,这就直接导致看不到自己输入的内容。怎么解决呢?找了一圈也没找到啥好的办法。google了一下才找到办法。据说这个是微软服务器那边的bug。解决办法如下:进入开始 --》 设置 --》时间和语言区域和语言–》国家和地区–》也门注销登入,恢复正常修改区域回中国继续happy吧。..._win10搜索框字体变成绿色

repo下载国内链接android源码_repo下载安卓源码-程序员宅基地

文章浏览阅读1w次,点赞5次,收藏3次。刚好碰到要下载一个指定版本的android源码,在网上没有找到,所以只能自己去下载,看了谷歌官方下载帮助但是苦于墙抽风,下载速度也奇慢,所以找了几个国内的源头下载,那么跟着我动起来,在这之前你需要一台装有linux的电脑或者是虚拟机上装有linux也是可以的. (一).科普一下git与repo的区别 1. Git:Git是一个开源的分布式版本控制系统,用以有效、高速的处理从很小到非常大的项_repo下载安卓源码

推荐文章

热门文章

相关标签