
Convolution Layers

k3_node.layers.AGNNConv
Bases: MessagePassing

Implementation of Attention-based Graph Neural Network (AGNN) layer

Parameters:

    trainable: Whether to learn the scaling factor beta. Default: True
    aggregate: Aggregation function to use (one of 'sum', 'mean', 'max'). Default: 'sum'
    activation: Activation function to use. Default: None
    **kwargs: Additional arguments to pass to the MessagePassing superclass. Default: {}

Source code in k3_node/layers/conv/agnn_conv.py
class AGNNConv(MessagePassing):
    """
    `k3_node.layers.AGNNConv` 
    Implementation of Attention-based Graph Neural Network (AGNN) layer

    Args:
        trainable: Whether to learn the scaling factor beta.
        aggregate: Aggregation function to use (one of 'sum', 'mean', 'max').
        activation: Activation function to use.
        **kwargs: Additional arguments to pass to the `MessagePassing` superclass.
    """
    def __init__(self, trainable=True, aggregate="sum", activation=None, **kwargs):
        super().__init__(aggregate=aggregate, activation=activation, **kwargs)
        self.trainable = trainable

    def build(self, input_shape):
        assert len(input_shape) >= 2
        if self.trainable:
            self.beta = self.add_weight(shape=(1,), initializer="ones", name="beta")
        else:
            self.beta = ops.cast(1.0, self.dtype)
        self.built = True

    def call(self, x, a, **kwargs):
        x_norm = keras.utils.normalize(x, axis=-1)
        output = self.propagate(x, a, x_norm=x_norm)
        output = self.activation(output)

        return output

    def message(self, x, x_norm=None):
        x_j = self.get_sources(x)
        x_norm_i = self.get_targets(x_norm)
        x_norm_j = self.get_sources(x_norm)
        alpha = self.beta * ops.sum(x_norm_i * x_norm_j, axis=-1)

        if len(alpha.shape) == 2:
            alpha = ops.transpose(alpha)  # For mixed mode
        alpha = segment_softmax(alpha, self.index_targets, self.n_nodes)
        if len(alpha.shape) == 2:
            alpha = ops.transpose(alpha)  # For mixed mode
        alpha = alpha[..., None]

        return alpha * x_j

    @property
    def config(self):
        return {
            "trainable": self.trainable,
        }
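
Example: a minimal usage sketch (illustrative only, not from the library's test suite). The import path follows the docstring above; the 4-node toy shapes and the dense adjacency matrix are assumptions made for this example, so substitute whatever adjacency representation your k3_node message-passing pipeline actually uses.

import numpy as np
from k3_node.layers import AGNNConv  # path as documented above

# Toy graph: 4 nodes, 8 features each (shapes chosen only for illustration).
x = np.random.rand(4, 8).astype("float32")
# Assumption: a dense rank-2 adjacency is acceptable here; adapt to the
# adjacency format expected by your version of k3_node.
a = np.array(
    [[0, 1, 0, 0],
     [1, 0, 1, 0],
     [0, 1, 0, 1],
     [0, 0, 1, 0]],
    dtype="float32",
)

layer = AGNNConv(trainable=True, aggregate="sum")
out = layer(x, a)  # attention-weighted neighbour aggregation, shape (4, 8)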

k3_node.layers.APPNPConv
Bases: Conv

Implementation of Approximate Personalized Propagation of Neural Predictions

Parameters:

    channels: The number of output channels. Required.
    alpha: The teleport probability. Default: 0.2
    propagations: The number of propagation steps. Default: 1
    mlp_hidden: A list of hidden channels for the MLP. Default: None
    mlp_activation: The activation function to use in the MLP. Default: 'relu'
    dropout_rate: The dropout rate for the MLP. Default: 0.0
    activation: The activation function to use in the layer. Default: None
    use_bias: Whether to add a bias to the linear transformation. Default: True
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    bias_initializer: Initializer for the bias vector. Default: 'zeros'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    bias_regularizer: Regularizer for the bias vector. Default: None
    activity_regularizer: Regularizer for the output. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    bias_constraint: Constraint for the bias vector. Default: None
    **kwargs: Additional keyword arguments. Default: {}

Source code in k3_node/layers/conv/appnp_conv.py
class APPNPConv(Conv):
    """
        `k3_node.layers.APPNPConv`
        Implementation of Approximate Personalized Propagation of Neural Predictions

        Args:
            channels: The number of output channels.
            alpha: The teleport probability.
            propagations: The number of propagation steps.
            mlp_hidden: A list of hidden channels for the MLP.
            mlp_activation: The activation function to use in the MLP.
            dropout_rate: The dropout rate for the MLP.
            activation: The activation function to use in the layer.
            use_bias: Whether to add a bias to the linear transformation.
            kernel_initializer: Initializer for the `kernel` weights matrix.
            bias_initializer: Initializer for the bias vector.
            kernel_regularizer: Regularizer for the `kernel` weights matrix.
            bias_regularizer: Regularizer for the bias vector.
            activity_regularizer: Regularizer for the output.
            kernel_constraint: Constraint for the `kernel` weights matrix.
            bias_constraint: Constraint for the bias vector.
            **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        channels,
        alpha=0.2,
        propagations=1,
        mlp_hidden=None,
        mlp_activation="relu",
        dropout_rate=0.0,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):

        super().__init__(
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        self.channels = channels
        self.mlp_hidden = mlp_hidden if mlp_hidden else []
        self.alpha = alpha
        self.propagations = propagations
        self.mlp_activation = activations.get(mlp_activation)
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        assert len(input_shape) >= 2
        layer_kwargs = dict(
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
            dtype=self.dtype,
        )
        mlp_layers = []
        for channels in self.mlp_hidden:
            mlp_layers.extend(
                [
                    Dropout(self.dropout_rate),
                    Dense(channels, self.mlp_activation, **layer_kwargs),
                ]
            )
        mlp_layers.append(Dense(self.channels, "linear", **layer_kwargs))
        self.mlp = Sequential(mlp_layers)
        self.built = True

    def call(self, inputs, mask=None):
        x, a = inputs
        mlp_out = self.mlp(x)
        output = mlp_out
        for _ in range(self.propagations):
            output = (1 - self.alpha) * modal_dot(a, output) + self.alpha * mlp_out
        if mask[0] is not None:
            output *= mask[0]
        output = self.activation(output)

        return output

    @property
    def config(self):
        return {
            "channels": self.channels,
            "alpha": self.alpha,
            "propagations": self.propagations,
            "mlp_hidden": self.mlp_hidden,
            "mlp_activation": activations.serialize(self.mlp_activation),
            "dropout_rate": self.dropout_rate,
        }

    @staticmethod
    def preprocess(a):
        return gcn_filter(a)
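
Example: an illustrative sketch of the layer in single mode. The toy shapes, the dense placeholder adjacency, and the use of APPNPConv.preprocess on a NumPy array are assumptions made for this example; adjust them to your data pipeline.

import numpy as np
from k3_node.layers import APPNPConv  # path as documented above

x = np.random.rand(4, 8).astype("float32")   # node features (N, F)
a = np.eye(4, dtype="float32")               # placeholder adjacency (N, N)
a = APPNPConv.preprocess(a)                  # GCN normalisation (see `preprocess` above)

layer = APPNPConv(channels=16, alpha=0.2, propagations=3, mlp_hidden=[32])
out = layer([x, a])  # MLP followed by alpha-teleport propagation, shape (4, 16)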

k3_node.layers.ARMAConv
Bases: Conv

Implementation of ARMAConv layer

Parameters:

    channels: The number of output channels. Required.
    order: The order of the ARMA filter. Default: 1
    iterations: The number of iterations to perform. Default: 1
    share_weights: Whether to share the weights across iterations. Default: False
    gcn_activation: The activation function to use for GCN. Default: 'relu'
    dropout_rate: The dropout rate. Default: 0.0
    activation: The activation function to use in the layer. Default: None
    use_bias: Whether to add a bias to the linear transformation. Default: True
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    bias_initializer: Initializer for the bias vector. Default: 'zeros'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    bias_regularizer: Regularizer for the bias vector. Default: None
    activity_regularizer: Regularizer for the output. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    bias_constraint: Constraint for the bias vector. Default: None
    **kwargs: Additional keyword arguments. Default: {}

Source code in k3_node/layers/conv/arma_conv.py
class ARMAConv(Conv):
    """
    `k3_node.layers.ARMAConv` 
    Implementation of ARMAConv layer

    Args:
        channels: The number of output channels.
        order: The order of the ARMA filter.
        iterations: The number of iterations to perform.
        share_weights: Whether to share the weights across iterations.
        gcn_activation: The activation function to use for GCN.
        dropout_rate: The dropout rate.
        activation: The activation function to use in the layer.
        use_bias: Whether to add a bias to the linear transformation.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer for the `kernel` weights matrix.
        bias_regularizer: Regularizer for the bias vector.
        activity_regularizer: Regularizer for the output.
        kernel_constraint: Constraint for the `kernel` weights matrix.
        bias_constraint: Constraint for the bias vector.
        **kwargs: Additional keyword arguments.
    """
    def __init__(
        self,
        channels,
        order=1,
        iterations=1,
        share_weights=False,
        gcn_activation="relu",
        dropout_rate=0.0,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):

        super().__init__(
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        self.channels = channels
        self.iterations = iterations
        self.order = order
        self.share_weights = share_weights
        self.gcn_activation = activations.get(gcn_activation)
        self.dropout_rate = dropout_rate

    def build(self, input_shape):
        assert len(input_shape) >= 2
        F = input_shape[0][-1]

        # Create weights for parallel stacks
        # self.kernels[k][i] refers to the k-th stack, i-th iteration
        self.kernels = []
        for k in range(self.order):
            kernel_stack = []
            current_shape = F
            for i in range(self.iterations):
                kernel_stack.append(
                    self.create_weights(
                        current_shape, F, self.channels, "ARMA_GCS_{}{}".format(k, i)
                    )
                )
                current_shape = self.channels
                if self.share_weights and i == 1:
                    # No need to continue because all weights will be shared
                    break
            self.kernels.append(kernel_stack)

        self.dropout = Dropout(self.dropout_rate, dtype=self.dtype)
        self.built = True

    def call(self, inputs, mask=None):
        x, a = inputs

        output = []
        for k in range(self.order):
            output_k = x
            for i in range(self.iterations):
                output_k = self.gcs([output_k, x, a], k, i)
            output.append(output_k)
        output = ops.stack(output, axis=-1)
        output = ops.mean(output, axis=-1)

        if mask[0] is not None:
            output *= mask[0]
        output = self.activation(output)

        return output

    def create_weights(self, input_dim, input_dim_skip, channels, name):
        kernel_1 = self.add_weight(
            shape=(input_dim, channels),
            name=name + "_kernel_1",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        kernel_2 = self.add_weight(
            shape=(input_dim_skip, channels),
            name=name + "_kernel_2",
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        bias = None
        if self.use_bias:
            bias = self.add_weight(
                shape=(channels,),
                name=name + "_bias",
                initializer=self.bias_initializer,
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        return kernel_1, kernel_2, bias

    def gcs(self, inputs, stack, iteration):
        x, x_skip, a = inputs

        itr = 1 if self.share_weights and iteration >= 1 else iteration
        kernel_1, kernel_2, bias = self.kernels[stack][itr]

        output = ops.dot(x, kernel_1)
        output = modal_dot(a, output)

        skip = ops.dot(x_skip, kernel_2)
        skip = self.dropout(skip)
        output += skip

        if self.use_bias:
            output = ops.add(output, bias)
        output = self.gcn_activation(output)

        return output

    @property
    def config(self):
        return {
            "channels": self.channels,
            "iterations": self.iterations,
            "order": self.order,
            "share_weights": self.share_weights,
            "gcn_activation": activations.serialize(self.gcn_activation),
            "dropout_rate": self.dropout_rate,
        }

    @staticmethod
    def preprocess(a):
        return normalized_adjacency(a, symmetric=True)
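
Example: an illustrative sketch. The toy shapes and the dense adjacency passed through ARMAConv.preprocess are assumptions made for this example.

import numpy as np
from k3_node.layers import ARMAConv  # path as documented above

x = np.random.rand(4, 8).astype("float32")   # node features (N, F)
a = np.array(
    [[0, 1, 1, 0],
     [1, 0, 1, 0],
     [1, 1, 0, 1],
     [0, 0, 1, 0]],
    dtype="float32",
)
a = ARMAConv.preprocess(a)  # symmetric normalisation (see `preprocess` above)

layer = ARMAConv(channels=16, order=2, iterations=2, share_weights=True)
out = layer([x, a])  # mean over the `order` parallel stacks, shape (4, 16)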

k3_node.layers.CrystalConv
Bases: MessagePassing

Implementation of Crystal Graph Convolutional Neural Networks (CGCNN) layer

Parameters:

    aggregate: Aggregation function to use (one of 'sum', 'mean', 'max'). Default: 'sum'
    activation: Activation function to use. Default: None
    use_bias: Whether to add a bias to the linear transformation. Default: True
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    bias_initializer: Initializer for the bias vector. Default: 'zeros'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    bias_regularizer: Regularizer for the bias vector. Default: None
    activity_regularizer: Regularizer for the output. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    bias_constraint: Constraint for the bias vector. Default: None
    **kwargs: Additional arguments to pass to the MessagePassing superclass. Default: {}

Source code in k3_node/layers/conv/crystal_conv.py
class CrystalConv(MessagePassing):
    """
    `k3_node.layers.CrystalConv`
    Implementation of Crystal Graph Convolutional Neural Networks (CGCNN) layer

    Args:
        aggregate: Aggregation function to use (one of 'sum', 'mean', 'max').
        activation: Activation function to use.
        use_bias: Whether to add a bias to the linear transformation.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer for the `kernel` weights matrix.
        bias_regularizer: Regularizer for the bias vector.
        activity_regularizer: Regularizer for the output.
        kernel_constraint: Constraint for the `kernel` weights matrix.
        bias_constraint: Constraint for the bias vector.
        **kwargs: Additional arguments to pass to the `MessagePassing` superclass. 
    """
    def __init__(
        self,
        aggregate="sum",
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):

        super().__init__(
            aggregate=aggregate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )

    def build(self, input_shape):
        assert len(input_shape) >= 2
        layer_kwargs = dict(
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
            dtype=self.dtype,
        )
        channels = input_shape[0][-1]
        self.dense_f = Dense(channels, activation="sigmoid", **layer_kwargs)
        self.dense_s = Dense(channels, activation=self.activation, **layer_kwargs)

        self.built = True

    def message(self, x, e=None):
        x_i = self.get_targets(x)
        x_j = self.get_sources(x)

        to_concat = [x_i, x_j]
        if e is not None:
            to_concat += [e]
        z = ops.concatenate(to_concat, axis=-1)
        output = self.dense_s(z) * self.dense_f(z)

        return output

    def update(self, embeddings, x=None):
        return x + embeddings
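
Example: an illustrative sketch. CrystalConv keeps the input feature width, since `update` adds the aggregated messages back onto x. The toy shapes and the dense adjacency are assumptions made for this example; use the adjacency representation your k3_node message-passing pipeline expects, and optionally pass edge features as a third input.

import numpy as np
from k3_node.layers import CrystalConv  # path as documented above

x = np.random.rand(4, 8).astype("float32")
# Assumption: a dense rank-2 adjacency is acceptable here.
a = np.eye(4, dtype="float32")

layer = CrystalConv(aggregate="mean", activation="relu")
out = layer([x, a])  # gated message passing plus residual, shape (4, 8)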

k3_node.layers.DiffusionConv
Bases: Conv

Implementation of Diffusion Convolutional Neural Networks (DCNN) layer

Parameters:

    channels: The number of output channels. Required.
    K: The number of diffusion steps. Default: 6
    activation: Activation function to use. Default: 'tanh'
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    **kwargs: Additional arguments to pass to the Conv superclass. Default: {}

Source code in k3_node/layers/conv/diffusion_conv.py
class DiffusionConv(Conv):
    """
    `k3_node.layers.DiffusionConv`
    Implementation of Diffusion Convolutional Neural Networks (DCNN) layer

    Args:
        channels: The number of output channels.
        K: The number of diffusion steps.
        activation: Activation function to use.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        kernel_regularizer: Regularizer for the `kernel` weights matrix.
        kernel_constraint: Constraint for the `kernel` weights matrix.
        **kwargs: Additional arguments to pass to the `Conv` superclass.
    """
    def __init__(
        self,
        channels,
        K=6,
        activation="tanh",
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        kernel_constraint=None,
        **kwargs,
    ):
        super().__init__(
            activation=activation,
            kernel_initializer=kernel_initializer,
            kernel_regularizer=kernel_regularizer,
            kernel_constraint=kernel_constraint,
            **kwargs,
        )

        self.channels = channels
        self.K = K + 1

    def build(self, input_shape):
        self.filters = [
            DiffuseFeatures(
                num_diffusion_steps=self.K,
                kernel_initializer=self.kernel_initializer,
                kernel_regularizer=self.kernel_regularizer,
                kernel_constraint=self.kernel_constraint,
            )
            for _ in range(self.channels)
        ]

    def apply_filters(self, x, a):
        diffused_features = []

        for diffusion in self.filters:
            diffused_feature = diffusion((x, a))
            diffused_features.append(diffused_feature)

        return ops.concatenate(diffused_features, -1)

    def call(self, inputs):
        x, a = inputs
        output = self.apply_filters(x, a)

        output = self.activation(output)

        return output

    @property
    def config(self):
        return {"channels": self.channels, "K": self.K - 1}

    @staticmethod
    def preprocess(a):
        return normalized_adjacency(a)
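
Example: an illustrative sketch. The toy shapes and the dense adjacency passed through DiffusionConv.preprocess are assumptions made for this example.

import numpy as np
from k3_node.layers import DiffusionConv  # path as documented above

x = np.random.rand(4, 8).astype("float32")   # node features (N, F)
a = np.eye(4, dtype="float32")               # placeholder adjacency (N, N)
a = DiffusionConv.preprocess(a)              # normalised adjacency (see `preprocess` above)

# 16 parallel diffusion filters, each running K=3 diffusion steps
layer = DiffusionConv(channels=16, K=3, activation="tanh")
out = layer([x, a])  # shape (4, 16)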

k3_node.layers.GatedGraphConv
Bases: MessagePassing

Implementation of Gated Graph Convolution (GGC) layer

Parameters:

    channels: The number of output channels. Required.
    n_layers: The number of GGC layers to stack. Required.
    activation: Activation function to use. Default: None
    use_bias: Whether to add a bias to the linear transformation. Default: True
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    bias_initializer: Initializer for the bias vector. Default: 'zeros'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    bias_regularizer: Regularizer for the bias vector. Default: None
    activity_regularizer: Regularizer for the output. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    bias_constraint: Constraint for the bias vector. Default: None
    **kwargs: Additional arguments to pass to the MessagePassing superclass. Default: {}

Source code in k3_node/layers/conv/gated_graph_conv.py
class GatedGraphConv(MessagePassing):
    """
    `k3_node.layers.GatedGraphConv` 

    Implementation of Gated Graph Convolution (GGC) layer

    Args:
        channels: The number of output channels.
        n_layers: The number of GGC layers to stack.
        activation: Activation function to use.
        use_bias: Whether to add a bias to the linear transformation.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer for the `kernel` weights matrix.
        bias_regularizer: Regularizer for the bias vector.
        activity_regularizer: Regularizer for the output.
        kernel_constraint: Constraint for the `kernel` weights matrix.
        bias_constraint: Constraint for the bias vector.
        **kwargs: Additional arguments to pass to the `MessagePassing` superclass.

    """
    def __init__(
        self,
        channels,
        n_layers,
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        self.channels = channels
        self.n_layers = n_layers

    def build(self, input_shape):
        assert len(input_shape) >= 2
        self.kernel = self.add_weight(
            name="kernel",
            shape=(self.n_layers, self.channels, self.channels),
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        self.rnn = GRUCell(
            self.channels,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            activity_regularizer=self.activity_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
            use_bias=self.use_bias,
            dtype=self.dtype,
        )
        self.built = True

    def call(self, inputs):
        x, a, _ = self.get_inputs(inputs)
        F = ops.shape(x)[-1]
        assert F <= self.channels
        to_pad = self.channels - F
        ndims = len(x.shape) - 1
        output = ops.pad(x, [[0, 0]] * ndims + [[0, to_pad]])
        for i in range(self.n_layers):
            m = ops.matmul(output, self.kernel[i])
            m = self.propagate(m, a)
            output = self.rnn(m, [output])[0]

        output = self.activation(output)
        return output

    @property
    def config(self):
        return {
            "channels": self.channels,
            "n_layers": self.n_layers,
        }
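
Example: an illustrative sketch. Note that `call` above requires the input feature size to be at most `channels` (features are zero-padded up to `channels` before the GRU updates). The toy shapes and the dense adjacency are assumptions made for this example.

import numpy as np
from k3_node.layers import GatedGraphConv  # path as documented above

x = np.random.rand(4, 8).astype("float32")   # 8 <= channels, as required
# Assumption: a dense rank-2 adjacency is acceptable here.
a = np.eye(4, dtype="float32")

layer = GatedGraphConv(channels=16, n_layers=3)
out = layer([x, a])  # GRU-updated node states, shape (4, 16)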

k3_node.layers.GraphConvolution
Bases: Layer

Implementation of Graph Convolution (GCN) layer

Parameters:

    units: Positive integer, dimensionality of the output space. Required.
    activation: Activation function to use. Default: None
    use_bias: Whether to add a bias to the linear transformation. Default: True
    final_layer: Deprecated, use tf.gather or GatherIndices instead. Default: None
    input_dim: Deprecated, use keras.layers.Input with input_shape instead. Default: None
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    bias_initializer: Initializer for the bias vector. Default: 'zeros'
    bias_regularizer: Regularizer for the bias vector. Default: None
    bias_constraint: Constraint for the bias vector. Default: None
    **kwargs: Additional arguments to pass to the Layer superclass. Default: {}

Source code in k3_node/layers/conv/gcn.py
class GraphConvolution(Layer):
    """
    `k3_node.layers.GraphConvolution` 
    Implementation of Graph Convolution (GCN) layer

    Args:
        units: Positive integer, dimensionality of the output space.
        activation: Activation function to use.
        use_bias: Whether to add a bias to the linear transformation.
        final_layer: Deprecated, use tf.gather or GatherIndices instead.
        input_dim: Deprecated, use `keras.layers.Input` with `input_shape` instead.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        kernel_regularizer: Regularizer for the `kernel` weights matrix.
        kernel_constraint: Constraint for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        bias_regularizer: Regularizer for the bias vector.
        bias_constraint: Constraint for the bias vector.
        **kwargs: Additional arguments to pass to the `Layer` superclass.
    """
    def __init__(
        self,
        units,
        activation=None,
        use_bias=True,
        final_layer=None,
        input_dim=None,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        kernel_constraint=None,
        bias_initializer="zeros",
        bias_regularizer=None,
        bias_constraint=None,
        **kwargs,
    ):
        if "input_shape" not in kwargs and input_dim is not None:
            kwargs["input_shape"] = (input_dim,)

        self.units = units
        self.activation = activations.get(activation)
        self.use_bias = use_bias
        if final_layer is not None:
            raise ValueError(
                "'final_layer' is not longer supported, use 'tf.gather' or 'GatherIndices' separately"
            )

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.bias_constraint = constraints.get(bias_constraint)

        super().__init__(**kwargs)

    def build(self, input_shapes):
        feat_shape = input_shapes[0]
        input_dim = int(feat_shape[-1])

        self.kernel = self.add_weight(
            shape=(1, input_dim, self.units),
            initializer=self.kernel_initializer,
            name="kernel",
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )

        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.units,),
                initializer=self.bias_initializer,
                name="bias",
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        else:
            self.bias = None
        self.built = True

    def call(self, inputs):
        features, A = inputs

        # Calculate the layer operation of GCN

        h_graph = dot((A, features), axes=1)
        b = ops.shape(h_graph)[0]
        kernel = ops.repeat(self.kernel, b, axis=0)
        output = dot((h_graph, kernel), axes=(-1, 1))

        # Add optional bias & apply activation
        if self.bias is not None:
            output += self.bias
        output = self.activation(output)

        return output
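
Example: an illustrative sketch. Following the shape comments in `call` above, this layer operates on batched inputs: node features of shape (batch, N, F) and a dense adjacency of shape (batch, N, N). The toy shapes are assumptions made for this example.

import numpy as np
from k3_node.layers import GraphConvolution  # path as documented above

x = np.random.rand(1, 4, 8).astype("float32")   # (batch, N, F)
a = np.eye(4, dtype="float32")[None, ...]       # (batch, N, N), illustrative

layer = GraphConvolution(units=16, activation="relu")
out = layer([x, a])  # shape (1, 4, 16)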

k3_node.layers.GeneralConv
Bases: MessagePassing

Implementation of General Graph Convolution

Parameters:

    channels: The number of output channels. Default: 256
    batch_norm: Whether to use batch normalization. Default: True
    dropout: The dropout rate. Default: 0.0
    aggregate: Aggregation function to use (one of 'sum', 'mean', 'max'). Default: 'sum'
    activation: Activation function to use. Default: 'prelu'
    use_bias: Whether to add a bias to the linear transformation. Default: True
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    bias_initializer: Initializer for the bias vector. Default: 'zeros'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    bias_regularizer: Regularizer for the bias vector. Default: None
    activity_regularizer: Regularizer for the output. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    bias_constraint: Constraint for the bias vector. Default: None
    **kwargs: Additional arguments to pass to the MessagePassing superclass. Default: {}

Source code in k3_node/layers/conv/general_conv.py
class GeneralConv(MessagePassing):
    """
    `k3_node.layers.GeneralConv`
    Implementation of General Graph Convolution

    Args:
        channels: The number of output channels.
        batch_norm: Whether to use batch normalization.
        dropout: The dropout rate.
        aggregate: Aggregation function to use (one of 'sum', 'mean', 'max').
        activation: Activation function to use.
        use_bias: Whether to add a bias to the linear transformation.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer for the `kernel` weights matrix.
        bias_regularizer: Regularizer for the bias vector.
        activity_regularizer: Regularizer for the output.
        kernel_constraint: Constraint for the `kernel` weights matrix.
        bias_constraint: Constraint for the bias vector.
        **kwargs: Additional arguments to pass to the `MessagePassing` superclass. 
    """
    def __init__(
        self,
        channels=256,
        batch_norm=True,
        dropout=0.0,
        aggregate="sum",
        activation="prelu",
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(
            aggregate=aggregate,
            activation=None,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        self.channels = channels
        self.dropout_rate = dropout
        self.use_batch_norm = batch_norm
        if activation == "prelu" or "prelu" in kwargs:
            self.activation = PReLU()
        else:
            self.activation = activations.get(activation)

    def build(self, input_shape):
        input_dim = input_shape[0][-1]
        self.dropout = Dropout(self.dropout_rate)
        if self.use_batch_norm:
            self.batch_norm = BatchNormalization()
        self.kernel = self.add_weight(
            shape=(input_dim, self.channels),
            initializer=self.kernel_initializer,
            name="kernel",
            regularizer=self.kernel_regularizer,
            constraint=self.kernel_constraint,
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=(self.channels,),
                initializer=self.bias_initializer,
                name="bias",
                regularizer=self.bias_regularizer,
                constraint=self.bias_constraint,
            )
        self.built = True

    def call(self, inputs, **kwargs):
        x, a, _ = self.get_inputs(inputs)

        # TODO: a = add_self_loops(a)

        x = ops.matmul(x, self.kernel)
        if self.use_bias:
            x = ops.add(x, self.bias)
        if self.use_batch_norm:
            x = self.batch_norm(x)
        x = self.dropout(x)
        x = self.activation(x)

        return self.propagate(x, a)

    @property
    def config(self):
        config = {
            "channels": self.channels,
        }
        if self.activation.__class__.__name__ == "PReLU":
            config["prelu"] = True

        return config
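
Example: an illustrative sketch. The toy shapes and the dense adjacency are assumptions made for this example; use the adjacency representation your k3_node message-passing pipeline expects.

import numpy as np
from k3_node.layers import GeneralConv  # path as documented above

x = np.random.rand(4, 8).astype("float32")
# Assumption: a dense rank-2 adjacency is acceptable here.
a = np.eye(4, dtype="float32")

# Linear transform + batch norm + dropout + PReLU, then neighbourhood sum
layer = GeneralConv(channels=32, batch_norm=True, dropout=0.1, aggregate="sum")
out = layer([x, a])  # shape (4, 32)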

k3_node.layers.GINConv
Bases: MessagePassing

Implementation of Graph Isomorphism Network (GIN) layer

Parameters:

    channels: The number of output channels. Required.
    epsilon: The epsilon parameter for the MLP. Default: None
    mlp_hidden: A list of hidden channels for the MLP. Default: None
    mlp_activation: The activation function to use in the MLP. Default: 'relu'
    mlp_batchnorm: Whether to use batch normalization in the MLP. Default: True
    aggregate: Aggregation function to use (one of 'sum', 'mean', 'max'). Default: 'sum'
    activation: Activation function to use. Default: None
    use_bias: Whether to add a bias to the linear transformation. Default: True
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    bias_initializer: Initializer for the bias vector. Default: 'zeros'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    bias_regularizer: Regularizer for the bias vector. Default: None
    activity_regularizer: Regularizer for the output. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    bias_constraint: Constraint for the bias vector. Default: None
    **kwargs: Additional arguments to pass to the MessagePassing superclass. Default: {}

Source code in k3_node/layers/conv/gin_conv.py
class GINConv(MessagePassing):
    """
    `k3_node.layers.GINConv` 
    Implementation of Graph Isomorphism Network (GIN) layer

    Args:
        channels: The number of output channels.
        epsilon: The epsilon parameter for the MLP.
        mlp_hidden: A list of hidden channels for the MLP.
        mlp_activation: The activation function to use in the MLP.
        mlp_batchnorm: Whether to use batch normalization in the MLP.
        aggregate: Aggregation function to use (one of 'sum', 'mean', 'max').
        activation: Activation function to use.
        use_bias: Whether to add a bias to the linear transformation.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        kernel_regularizer: Regularizer for the `kernel` weights matrix.
        bias_regularizer: Regularizer for the bias vector.
        activity_regularizer: Regularizer for the output.
        kernel_constraint: Constraint for the `kernel` weights matrix.
        bias_constraint: Constraint for the bias vector.
        **kwargs: Additional arguments to pass to the `MessagePassing` superclass.
    """
    def __init__(
        self,
        channels,
        epsilon=None,
        mlp_hidden=None,
        mlp_activation="relu",
        mlp_batchnorm=True,
        aggregate="sum",
        activation=None,
        use_bias=True,
        kernel_initializer="glorot_uniform",
        bias_initializer="zeros",
        kernel_regularizer=None,
        bias_regularizer=None,
        activity_regularizer=None,
        kernel_constraint=None,
        bias_constraint=None,
        **kwargs,
    ):
        super().__init__(
            aggregate=aggregate,
            activation=activation,
            use_bias=use_bias,
            kernel_initializer=kernel_initializer,
            bias_initializer=bias_initializer,
            kernel_regularizer=kernel_regularizer,
            bias_regularizer=bias_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=kernel_constraint,
            bias_constraint=bias_constraint,
            **kwargs,
        )
        self.channels = channels
        self.epsilon = epsilon
        self.mlp_hidden = mlp_hidden if mlp_hidden else []
        self.mlp_activation = activations.get(mlp_activation)
        self.mlp_batchnorm = mlp_batchnorm

    def build(self, input_shape):
        assert len(input_shape) >= 2
        layer_kwargs = dict(
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            kernel_regularizer=self.kernel_regularizer,
            bias_regularizer=self.bias_regularizer,
            kernel_constraint=self.kernel_constraint,
            bias_constraint=self.bias_constraint,
        )

        self.mlp = Sequential()
        for channels in self.mlp_hidden:
            self.mlp.add(Dense(channels, self.mlp_activation, **layer_kwargs))
            if self.mlp_batchnorm:
                self.mlp.add(BatchNormalization())
        self.mlp.add(
            Dense(
                self.channels, self.activation, use_bias=self.use_bias, **layer_kwargs
            )
        )

        if self.epsilon is None:
            self.eps = self.add_weight(shape=(1,), initializer="zeros", name="eps")
        else:
            # If epsilon is given, keep it constant
            self.eps = ops.cast(self.epsilon, self.dtype)
        self.one = ops.cast(1, self.dtype)

        self.built = True

    def call(self, inputs, **kwargs):
        x, a, _ = self.get_inputs(inputs)
        output = self.mlp((self.one + self.eps) * x + self.propagate(x, a))

        return output

    @property
    def config(self):
        return {
            "channels": self.channels,
            "epsilon": self.epsilon,
            "mlp_hidden": self.mlp_hidden,
            "mlp_activation": self.mlp_activation,
            "mlp_batchnorm": self.mlp_batchnorm,
        }
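
Example: an illustrative sketch of the GIN update, (1 + eps) * x plus the aggregated neighbour features, followed by the MLP. The toy shapes and the dense adjacency are assumptions made for this example.

import numpy as np
from k3_node.layers import GINConv  # path as documented above

x = np.random.rand(4, 8).astype("float32")
# Assumption: a dense rank-2 adjacency is acceptable here.
a = np.eye(4, dtype="float32")

layer = GINConv(channels=16, mlp_hidden=[32], mlp_activation="relu")
out = layer([x, a])  # shape (4, 16)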

k3_node.layers.GraphAttention
Bases: Layer

Implementation of Graph Attention (GAT) layer

Parameters:

    units: Positive integer, dimensionality of the output space. Required.
    attn_heads: Positive integer, number of attention heads. Default: 1
    attn_heads_reduction: {'concat', 'average'} Method for reducing attention heads. Default: 'concat'
    in_dropout_rate: Dropout rate applied to the input (node features). Default: 0.0
    attn_dropout_rate: Dropout rate applied to attention coefficients. Default: 0.0
    activation: Activation function to use. Default: 'relu'
    use_bias: Whether to add a bias to the linear transformation. Default: True
    final_layer: Deprecated, use tf.gather or GatherIndices instead. Default: None
    saliency_map_support: Whether to support saliency map calculations. Default: False
    kernel_initializer: Initializer for the kernel weights matrix. Default: 'glorot_uniform'
    kernel_regularizer: Regularizer for the kernel weights matrix. Default: None
    kernel_constraint: Constraint for the kernel weights matrix. Default: None
    bias_initializer: Initializer for the bias vector. Default: 'zeros'
    bias_regularizer: Regularizer for the bias vector. Default: None
    bias_constraint: Constraint for the bias vector. Default: None
    attn_kernel_initializer: Initializer for the attention kernel weights matrix. Default: 'glorot_uniform'
    attn_kernel_regularizer: Regularizer for the attention kernel weights matrix. Default: None
    attn_kernel_constraint: Constraint for the attention kernel weights matrix. Default: None
    **kwargs: Additional arguments to pass to the Layer superclass. Default: {}

Source code in k3_node/layers/conv/graph_attention.py
class GraphAttention(Layer):
    """
    `k3_node.layers.GraphAttention`
    Implementation of Graph Attention (GAT) layer

    Args:
        units: Positive integer, dimensionality of the output space.
        attn_heads: Positive integer, number of attention heads.
        attn_heads_reduction: {'concat', 'average'} Method for reducing attention heads.
        in_dropout_rate: Dropout rate applied to the input (node features).
        attn_dropout_rate: Dropout rate applied to attention coefficients.
        activation: Activation function to use.
        use_bias: Whether to add a bias to the linear transformation.
        final_layer: Deprecated, use tf.gather or GatherIndices instead.
        saliency_map_support: Whether to support saliency map calculations.
        kernel_initializer: Initializer for the `kernel` weights matrix.
        kernel_regularizer: Regularizer for the `kernel` weights matrix.
        kernel_constraint: Constraint for the `kernel` weights matrix.
        bias_initializer: Initializer for the bias vector.
        bias_regularizer: Regularizer for the bias vector.
        bias_constraint: Constraint for the bias vector.
        attn_kernel_initializer: Initializer for the attention kernel weights matrix.
        attn_kernel_regularizer: Regularizer for the attention kernel weights matrix.
        attn_kernel_constraint: Constraint for the attention kernel weights matrix.
        **kwargs: Additional arguments to pass to the `Layer` superclass.
    """
    def __init__(
        self,
        units,
        attn_heads=1,
        attn_heads_reduction="concat",  # {'concat', 'average'}
        in_dropout_rate=0.0,
        attn_dropout_rate=0.0,
        activation="relu",
        use_bias=True,
        final_layer=None,
        saliency_map_support=False,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        kernel_constraint=None,
        bias_initializer="zeros",
        bias_regularizer=None,
        bias_constraint=None,
        attn_kernel_initializer="glorot_uniform",
        attn_kernel_regularizer=None,
        attn_kernel_constraint=None,
        **kwargs,
    ):
        if attn_heads_reduction not in {"concat", "average"}:
            raise ValueError(
                "{}: Possible heads reduction methods: concat, average; received {}".format(
                    type(self).__name__, attn_heads_reduction
                )
            )

        self.units = units  # Number of output features (F' in the paper)
        self.attn_heads = attn_heads  # Number of attention heads (K in the paper)
        self.attn_heads_reduction = attn_heads_reduction  # Eq. 5 and 6 in the paper
        self.in_dropout_rate = in_dropout_rate  # dropout rate for node features
        self.attn_dropout_rate = attn_dropout_rate  # dropout rate for attention coefs
        self.activation = activations.get(activation)  # Eq. 4 in the paper
        self.use_bias = use_bias
        if final_layer is not None:
            raise ValueError(
                "'final_layer' is not longer supported, use 'tf.gather' or 'GatherIndices' separately"
            )

        self.saliency_map_support = saliency_map_support
        # Populated by build()
        self.kernels = []  # Layer kernels for attention heads
        self.biases = []  # Layer biases for attention heads
        self.attn_kernels = []  # Attention kernels for attention heads

        if attn_heads_reduction == "concat":
            # Output will have shape (..., K * F')
            self.output_dim = self.units * self.attn_heads
        else:
            # Output will have shape (..., F')
            self.output_dim = self.units

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.kernel_constraint = constraints.get(kernel_constraint)
        self.bias_initializer = initializers.get(bias_initializer)
        self.bias_regularizer = regularizers.get(bias_regularizer)
        self.bias_constraint = constraints.get(bias_constraint)
        self.attn_kernel_initializer = initializers.get(attn_kernel_initializer)
        self.attn_kernel_regularizer = regularizers.get(attn_kernel_regularizer)
        self.attn_kernel_constraint = constraints.get(attn_kernel_constraint)

        super().__init__(**kwargs)

    def build(self, input_shapes):
        feat_shape = input_shapes[0]
        input_dim = int(feat_shape[-1])

        # Variables to support integrated gradients
        self.delta = self.add_weight(
            name="ig_delta", shape=(), trainable=False, initializer=initializers.ones()
        )
        self.non_exist_edge = self.add_weight(
            name="ig_non_exist_edge",
            shape=(),
            trainable=False,
            initializer=initializers.zeros(),
        )

        # Initialize weights for each attention head
        for head in range(self.attn_heads):
            # Layer kernel
            kernel = self.add_weight(
                shape=(input_dim, self.units),
                initializer=self.kernel_initializer,
                regularizer=self.kernel_regularizer,
                constraint=self.kernel_constraint,
                name="kernel_{}".format(head),
            )
            self.kernels.append(kernel)

            # # Layer bias
            if self.use_bias:
                bias = self.add_weight(
                    shape=(self.units,),
                    initializer=self.bias_initializer,
                    regularizer=self.bias_regularizer,
                    constraint=self.bias_constraint,
                    name="bias_{}".format(head),
                )
                self.biases.append(bias)

            # Attention kernels
            attn_kernel_self = self.add_weight(
                shape=(self.units, 1),
                initializer=self.attn_kernel_initializer,
                regularizer=self.attn_kernel_regularizer,
                constraint=self.attn_kernel_constraint,
                name="attn_kernel_self_{}".format(head),
            )
            attn_kernel_neighs = self.add_weight(
                shape=(self.units, 1),
                initializer=self.attn_kernel_initializer,
                regularizer=self.attn_kernel_regularizer,
                constraint=self.attn_kernel_constraint,
                name="attn_kernel_neigh_{}".format(head),
            )
            self.attn_kernels.append([attn_kernel_self, attn_kernel_neighs])
        self.built = True

    def call(self, inputs):
        X = inputs[0]  # Node features (1 x N x F)
        A = inputs[1]  # Adjacency matrix (1 X N x N)
        N = ops.shape(A)[-1]

        assert len(ops.shape(A)) == 2, f"Adjacency matrix A should be 2-D"

        outputs = []
        for head in range(self.attn_heads):
            kernel = self.kernels[head]  # W in the paper (F x F')
            attention_kernel = self.attn_kernels[
                head
            ]  # Attention kernel a in the paper (2F' x 1)

            # Compute inputs to attention network

            features = ops.dot(X, kernel)  # (N x F')

            # Compute feature combinations
            # Note: [[a_1], [a_2]]^T [[Wh_i], [Wh_2]] = [a_1]^T [Wh_i] + [a_2]^T [Wh_j]
            attn_for_self = ops.dot(
                features, attention_kernel[0]
            )  # (N x 1), [a_1]^T [Wh_i]
            attn_for_neighs = ops.dot(
                features, attention_kernel[1]
            )  # (N x 1), [a_2]^T [Wh_j]

            # Attention head a(Wh_i, Wh_j) = a^T [[Wh_i], [Wh_j]]
            dense = attn_for_self + ops.transpose(
                attn_for_neighs
            )  # (N x N) via broadcasting

            dense = LeakyReLU(0.2)(dense)

            if not self.saliency_map_support:
                mask = -10e9 * (1.0 - A)
                dense += mask
                dense = ops.softmax(dense)  # (N x N), Eq. 3 of the paper

            else:
                # dense = dense - tf.reduce_max(dense)
                # GAT with support for saliency calculations
                W = (self.delta * A) * ops.exp(
                    dense - ops.max(dense, axis=1, keepdims=True)
                ) * (1 - self.non_exist_edge) + self.non_exist_edge * (
                    A + self.delta * (ops.ones((N, N)) - A) + ops.eye(N)
                ) * ops.exp(
                    dense - ops.max(dense, axis=1, keepdims=True)
                )
                dense = W / ops.sum(W, axis=1, keepdims=True)

            # Apply dropout to features and attention coefficients
            dropout_feat = Dropout(self.in_dropout_rate)(features)  # (N x F')
            dropout_attn = Dropout(self.attn_dropout_rate)(dense)  # (N x N)

            # Linear combination with neighbors' features [YT: see Eq. 4]
            node_features = ops.dot(dropout_attn, dropout_feat)  # (N x F')

            if self.use_bias:
                node_features = ops.add(node_features, self.biases[head])

            # Add output of attention head to final output
            outputs.append(node_features)

        # Aggregate the heads' output according to the reduction method
        if self.attn_heads_reduction == "concat":
            output = ops.concatenate(outputs, axis=1)  # (N x KF')
        else:
            output = ops.mean(ops.stack(outputs), axis=0)  # N x F')

        output = self.activation(output)

        return output
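
Example: an illustrative sketch. Per the assertion in `call` above, the adjacency must be a dense rank-2 matrix; self-loops are included here so that each node attends to itself. The toy shapes are assumptions made for this example.

import numpy as np
from k3_node.layers import GraphAttention  # path as documented above

x = np.random.rand(4, 8).astype("float32")   # node features (N, F)
a = np.array(
    [[1, 1, 0, 0],
     [1, 1, 1, 0],
     [0, 1, 1, 1],
     [0, 0, 1, 1]],
    dtype="float32",
)

layer = GraphAttention(units=16, attn_heads=4, attn_heads_reduction="concat")
out = layer([x, a])  # 4 heads of 16 units, concatenated: shape (4, 64)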

k3_node.layers.conv.MessagePassing
Bases: Layer

Base class for message passing layers.

Parameters:

    aggregate: Aggregation function to use (one of 'sum', 'mean', 'max'). Default: 'sum'
    **kwargs: Additional arguments to pass to the Layer superclass. Default: {}

Source code in k3_node/layers/conv/message_passing.py
class MessagePassing(layers.Layer):
    """
    `k3_node.layers.conv.MessagePassing`
    Base class for message passing layers.

    Args:
        aggregate: Aggregation function to use (one of 'sum', 'mean', 'max').
        **kwargs: Additional arguments to pass to the `Layer` superclass.
    """
    def __init__(self, aggregate="sum", **kwargs):
        super().__init__(**{k: v for k, v in kwargs.items() if is_keras_kwarg(k)})
        self.kwargs_keys = []
        for key in kwargs:
            if is_layer_kwarg(key):
                attr = kwargs[key]
                attr = deserialize_kwarg(key, attr)
                self.kwargs_keys.append(key)
                setattr(self, key, attr)

        self.msg_signature = inspect.signature(self.message).parameters
        self.agg_signature = inspect.signature(self.aggregate).parameters
        self.upd_signature = inspect.signature(self.update).parameters
        self.agg = deserialize_scatter(aggregate)

    def call(self, inputs, **kwargs):
        x, a, e = self.get_inputs(inputs)
        return self.propagate(x, a, e)

    def build(self, input_shape):
        self.built = True

    def propagate(self, x, a, e=None, **kwargs):
        self.n_nodes = ops.shape(x)[-2]
        self.index_sources, self.index_targets = get_source_target(a)

        msg_kwargs = self.get_kwargs(x, a, e, self.msg_signature, kwargs)
        messages = self.message(x, **msg_kwargs)

        agg_kwargs = self.get_kwargs(x, a, e, self.agg_signature, kwargs)
        embeddings = self.aggregate(messages, **agg_kwargs)

        upd_kwargs = self.get_kwargs(x, a, e, self.upd_signature, kwargs)
        output = self.update(embeddings, **upd_kwargs)
        return output

    def message(self, x, **kwargs):
        return self.get_sources(x)

    def aggregate(self, messages, **kwargs):
        return self.agg(messages, self.index_targets, self.n_nodes)

    def update(self, embeddings, **kwargs):
        return embeddings

    def get_targets(self, x):
        return ops.take(x, self.index_targets, axis=-2)

    def get_sources(self, x):
        return ops.take(x, self.index_sources, axis=-2)

    def get_kwargs(self, x, a, e, signature, kwargs):
        output = {}
        for k in signature.keys():
            if signature[k].default is inspect.Parameter.empty or k == "kwargs":
                pass
            elif k == "x":
                output[k] = x
            elif k == "a":
                output[k] = a
            elif k == "e":
                output[k] = e
            elif k in kwargs:
                output[k] = kwargs[k]
            else:
                raise ValueError("Missing key {} for signature {}".format(k, signature))

        return output

    @staticmethod
    def get_inputs(inputs):
        if len(inputs) == 3:
            x, a, e = inputs
            assert len(ops.shape(e)) in (2, 3), "E must have rank 2 or 3"
        elif len(inputs) == 2:
            x, a = inputs
            e = None
        else:
            raise ValueError(
                "Expected 2 or 3 inputs tensors (X, A, E), got {}.".format(len(inputs))
            )
        assert len(ops.shape(a)) == 2, "A must have rank 2"

        return x, a, e

    @staticmethod
    def preprocess(a):
        return a

    def get_config(self):
        mp_config = {"aggregate": serialize_scatter(self.agg)}
        keras_config = {}
        for key in self.kwargs_keys:
            keras_config[key] = serialize_kwarg(key, getattr(self, key))
        base_config = super().get_config()

        return {**base_config, **keras_config, **mp_config, **self.config}

    @property
    def config(self):
        return {}
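
The base class is intended to be subclassed by overriding message, aggregate and/or update, with get_sources and get_targets used to gather per-edge features. Below is a minimal, hypothetical subclass (not part of the library) showing the extension points; it mean-aggregates neighbour features and concatenates them with each node's own features.

from keras import ops
from k3_node.layers.conv import MessagePassing  # path as documented above

class MeanNeighbourConv(MessagePassing):
    """Hypothetical example: mean-aggregate neighbours, then concatenate."""

    def __init__(self, **kwargs):
        super().__init__(aggregate="mean", **kwargs)

    def message(self, x):
        # One message per edge: the features of the edge's source node.
        return self.get_sources(x)

    def update(self, embeddings, x=None):
        # Combine aggregated neighbour features with the original features.
        return ops.concatenate([x, embeddings], axis=-1)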

k3_node.layers.PPNPPropagation
Bases: Layer

Implementation of PPNP layer

Parameters:

    units: Positive integer, dimensionality of the output space. Required.
    final_layer: Deprecated, use tf.gather or GatherIndices instead. Default: None
    input_dim: Deprecated, use keras.layers.Input with input_shape instead. Default: None
    **kwargs: Additional arguments to pass to the Layer superclass. Default: {}

Source code in k3_node/layers/conv/ppnp.py
class PPNPPropagation(Layer):
    """
    `k3_node.layers.PPNPPropagation`
    Implementation of PPNP layer

    Args:
        units: Positive integer, dimensionality of the output space.
        final_layer: Deprecated, use tf.gather or GatherIndices instead.
        input_dim: Deprecated, use `keras.layers.Input` with `input_shape` instead.
        **kwargs: Additional arguments to pass to the `Layer` superclass. 
    """
    def __init__(self, units, final_layer=None, input_dim=None, **kwargs):
        if "input_shape" not in kwargs and input_dim is not None:
            kwargs["input_shape"] = (input_dim,)

        super().__init__(**kwargs)

        self.units = units
        if final_layer is not None:
            raise ValueError("'final_layer' is not longer supported.")

    def get_config(self):
        config = {"units": self.units}

        base_config = super().get_config()
        return {**base_config, **config}

    def compute_output_shape(self, input_shapes):
        feature_shape, *As_shapes = input_shapes

        batch_dim = feature_shape[0]
        out_dim = feature_shape[1]

        return batch_dim, out_dim, self.units

    def build(self, input_shapes):
        self.built = True

    def call(self, inputs):
        x, a = inputs
        n_nodes, _ = ops.shape(x)
        output = ops.dot(x, a)
        return output

k3_node.layers.SAGEConv
Bases: Layer

Implementation of GraphSAGE layer

Parameters:

    out_channels: The number of output channels. Required.
    normalize: Whether to normalize the output. Default: False
    bias: Whether to add a bias to the linear transformation. Default: True

Source code in k3_node/layers/conv/sage_conv.py
class SAGEConv(layers.Layer):
    """
    `k3_node.layers.SAGEConv`
    Implementation of GraphSAGE layer

    Args:
        out_channels: The number of output channels.
        normalize: Whether to normalize the output.
        bias: Whether to add a bias to the linear transformation.
    """
    def __init__(self, out_channels, normalize=False, bias=True):
        super().__init__()
        self.out_channels = out_channels
        self.normalize = normalize

        self.lin_rel = layers.Dense(out_channels, use_bias=False)
        self.lin_root = layers.Dense(out_channels, use_bias=bias)

    def call(self, x, adj, mask=None):
        # x = ops.expand_dims(x, axis=0) if len(ops.shape(x)) == 2 else x
        # adj = ops.expand_dims(adj, axis=0) if len(ops.shape(adj)) == 2 else adj

        out = ops.matmul(adj, x)
        out = out / ops.clip(
            ops.sum(adj, axis=-1, keepdims=True), x_min=1.0, x_max=float("inf")
        )
        out = self.lin_rel(out) + self.lin_root(x)

        if self.normalize:
            out = keras.utils.normalize(out, axis=-1)
        if mask is not None:
            mask = ops.expand_dims(mask, axis=-1)
            out = ops.multiply(out, mask)

        return out
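
Example: an illustrative sketch. `call` above multiplies the adjacency with the node features directly and divides by the node degrees, so a dense adjacency works; the toy shapes are assumptions made for this example.

import numpy as np
from k3_node.layers import SAGEConv  # path as documented above

x = np.random.rand(4, 8).astype("float32")   # node features (N, F)
adj = np.array(
    [[0, 1, 1, 0],
     [1, 0, 1, 0],
     [1, 1, 0, 1],
     [0, 0, 1, 0]],
    dtype="float32",
)

layer = SAGEConv(out_channels=16)
out = layer(x, adj)  # mean of neighbours + root transform, shape (4, 16)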