[Papers Review & Implementation] Hyperbolic Neural Networks - Part 2 : Implementing Hyperbolic Graph Convolutional Networks (HGCN)


2023, Oct 23    




HGCN : Implementing GCN in Hyperbolic Space


  • Real-world graphs, such as molecules and social networks, have highly complex hierarchical structures that expand exponentially in space.

  • Embedding these tree-like structures into Euclidean space, whose volume grows only polynomially with radius, can cause large distortion in the graph representation.

  • Hyperbolic space, whose volume grows exponentially with radius, can embed real-world hierarchical data with much lower distortion.

  • HGCN (Hyperbolic Graph Convolutional Networks) combines this geometric advantage of hyperbolic space with the inductive capacity of GCN, and shows clearly improved performance on both link prediction and node classification compared to other GNN-based baselines.

Figure 3. Visualization of Euclidean (Left) and Hyperbolic (Right) Embeddings on Poincare Model

  

  • Embeddings from HGCN show better class separation (classes are indicated by different colors).

  • HGCN claims three major contributions.

    1. Deriving the core operations of GCN in the hyperboloid model, so that Euclidean input features can be transformed into hyperbolic space.

    2. Attention-based neighborhood aggregation, which improves the expressiveness of the network by weighting neighboring features with attention scores.

    3. Trainable curvature for each layer of the network, which facilitates optimization by learning the right scale of embeddings at each layer.


0. Hyperboloid Model


  

  • Part 1 describes hyperbolic geometry with a focus on the Poincare disk model, which is well suited to visualizing hierarchical graph structure.

  • However, the Poincare model is less convenient for neural networks: its metric and distance function are comparatively complex, and computations become numerically unstable near the boundary of the disk.

  • In contrast, the hyperboloid model (also known as the Minkowski or Lorentz model) has a much simpler metric and distance function, making it more practical for hyperbolic neural networks.

  • For these practical reasons, HGCN adopts the hyperboloid model instead of the Poincare model as its geometric framework.

  • Although the Poincare model is not used directly in HGCN, it can still be used to visualize the embeddings, since the two models are isometric.

  • This section reviews the basic mathematical machinery of the hyperboloid manifold, together with the implementation of each operation (all methods of the class Hyperboloid).
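As a minimal sketch of why this visualization works, assume the standard isometry from $\mathbb{H}^{d,K}$ to the Poincare ball of radius $\sqrt{K}$, $p(x_0, \dots, x_d) = \sqrt{K}(x_1, \dots, x_d)/(x_0 + \sqrt{K})$, with inverse $p^{-1}(y) = (\sqrt{K}(K + ||y||^2), 2Ky)/(K - ||y||^2)$. The NumPy helpers to_poincare and to_hyperboloid are hypothetical names for illustration, not methods of the Hyperboloid class.

```python
import numpy as np

def to_poincare(x, K=1.0):
    """Map points on H^{d,K} (shape (n, d+1)) into the Poincare ball of radius sqrt(K)."""
    sqrtK = np.sqrt(K)
    return sqrtK * x[:, 1:] / (x[:, :1] + sqrtK)

def to_hyperboloid(y, K=1.0):
    """Inverse map: points in the Poincare ball (shape (n, d)) back onto H^{d,K}."""
    sqrtK = np.sqrt(K)
    sq = np.sum(y ** 2, axis=1, keepdims=True)
    x0 = sqrtK * (K + sq) / (K - sq)
    xs = 2.0 * K * y / (K - sq)
    return np.concatenate([x0, xs], axis=1)

# Build points on H^{2,K} from their spatial parts: x0 = sqrt(K + |x_{1:d}|^2)
K = 2.0
spatial = np.array([[0.5, -1.2], [3.0, 0.1]])
x = np.concatenate([np.sqrt(K + np.sum(spatial ** 2, axis=1, keepdims=True)), spatial], axis=1)
y = to_poincare(x, K)
print(np.all(np.linalg.norm(y, axis=1) < np.sqrt(K)))  # True: every image lies inside the ball
print(np.allclose(to_hyperboloid(y, K), x))            # True: the inverse recovers the points
```

Because the map is a bijection, embeddings can be trained on the hyperboloid and only converted for plotting.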


  1. Riemannian Manifold

    • Hyperboloid with constant negative curvature -1/K, where K > 0: $\large (\mathbb{H}^{d, K}, g^{K}_{x})$

    • Minkowski inner product, which induces the metric $\large g^{K}_{x}$: $\large \langle x, y \rangle_{\mathcal{L}} \, := \, -x_{0}y_{0} \,+ \, x_{1}y_{1} \,+ \, x_{2}y_{2} \,+ \, \dots \, + \, x_{d}y_{d}$

    • $\large \mathbb{H}^{d, K} \, := \, \{ x \, \in \, \mathbb{R}^{d+1} \, | \, \langle x,x \rangle_{\mathcal{L}} \, = \, -K, \; x_{0} \, > \, 0 \}$

    • Then, $\large x_{0} \, = \, \sqrt{K \, + \, ||x_{1:d}||^{2}}$

    • Distance function: $\large d_{\mathcal{L}}^{K}(x,y) \, = \, \sqrt{K}\,\text{arcosh}(-\frac{\langle x,y \rangle_{\mathcal{L}}}{K})$, which reduces to $\large d_{\mathcal{L}}^{1}(x,y) \, = \, \text{arcosh}(-\langle x,y \rangle_{\mathcal{L}})$ for K = 1

     def minkowski_dot(self, x, y, keepdim=True):
         # sum_i x_i*y_i - 2*x_0*y_0 = -x_0*y_0 + x_1*y_1 + ... + x_d*y_d
         res = torch.sum(x * y, dim=-1) - 2 * x[..., 0] * y[..., 0]
         if keepdim:
             res = res.view(res.shape + (1,))
         return res
    
     def minkowski_norm(self, u, keepdim=True):
         dot = self.minkowski_dot(u, u, keepdim=keepdim)
         return torch.sqrt(torch.clamp(dot, min=self.eps[u.dtype]))   
    
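     def sqdist(self, x, y, c):
         K = 1. / c    # the code stores c = 1/K
         prod = self.minkowski_dot(x, y)
         theta = torch.clamp(-prod / K, min=1.0 + self.eps[x.dtype])   # arcosh is only defined for arguments >= 1
         sqdist = K * arcosh(theta) ** 2    # (sqrt(K) * arcosh(-<x, y>_L / K))^2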
         # clamp distance to avoid nans in Fermi-Dirac decoder
         return torch.clamp(sqdist, max=50.0)
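A NumPy sketch of the same quantities may help check the formulas numerically. It is an illustration under simplifying assumptions, not the HGCN code: it takes K directly (the class stores c = 1/K), returns the distance rather than the clamped squared distance, and uses the hypothetical helper names lift and hyp_dist.

```python
import numpy as np

def minkowski_dot(x, y):
    """<x, y>_L = -x_0*y_0 + x_1*y_1 + ... + x_d*y_d over the last axis."""
    return -x[..., 0] * y[..., 0] + np.sum(x[..., 1:] * y[..., 1:], axis=-1)

def lift(spatial, K=1.0):
    """Place a point on H^{d,K} by solving <x, x>_L = -K for x_0 > 0."""
    x0 = np.sqrt(K + np.sum(spatial ** 2, axis=-1, keepdims=True))
    return np.concatenate([x0, spatial], axis=-1)

def hyp_dist(x, y, K=1.0):
    """d_L^K(x, y) = sqrt(K) * arcosh(-<x, y>_L / K); clipping guards against round-off below 1."""
    theta = np.maximum(-minkowski_dot(x, y) / K, 1.0)
    return np.sqrt(K) * np.arccosh(theta)

K = 1.0
origin = lift(np.zeros(2), K)                  # (sqrt(K), 0, 0)
x = lift(np.array([np.sinh(1.0), 0.0]), K)     # (cosh(1), sinh(1), 0)
print(minkowski_dot(x, x))                     # close to -K
print(hyp_dist(origin, x, K))                  # close to 1.0
```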
    


  2. Tangent Space

    • $\large T_{x}\mathbb{H}^{d, K}$ : the hyperplane in $\large \mathbb{R}^{d+1}$ that best approximates the hyperboloid with curvature -1/K around the point x (a first-order Euclidean approximation).

    • Equivalently, it is the space of all tangent vectors at x of curves passing through x on the manifold.

    • $\large T_{x}\mathbb{H}^{d, K} \, := \, \{ v \, \in \, \mathbb{R}^{d+1} \, | \, \langle x, v \rangle_{\mathcal{L}} = 0 \}$

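    • The origin of the manifold and its tangent space $\large T_{o}\mathbb{H}^{d, K}$ are defined as follows

      • $\large o \, := \, (\sqrt{K}, 0, \dots, 0)$, and a Euclidean feature vector $\large x^{0,E} \, \in \, \mathbb{R}^{d}$ is identified with the tangent vector $\large (0, x^{0,E}) \, \in \, T_{o}\mathbb{H}^{d, K}$, since $\large \langle (0, x^{0,E}), o \rangle_{\mathcal{L}} = 0$
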
  3. Projections

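    • Projection of a point $\large (x_{0}, x_{1:d}) \, \in \, \mathbb{R}^{d+1}$ onto the hyperboloid $\large \mathbb{H}^{d,K}$: $\large \Pi_{\mathbb{R}^{d+1} \rightarrow \mathbb{H}^{d,K}}(x) \, := \, (\sqrt{K \, + \, ||x_{1:d}||^{2}}, \, x_{1:d})$

      • This follows directly from the constraint $\large \langle x, x \rangle_{\mathcal{L}} \, = \, -K$: keep $\large x_{1:d}$ and solve for $\large x_{0} \, > \, 0$.

    • Projection of a point $\large v \, \in \, \mathbb{R}^{d+1}$ onto the tangent space $\large T_{x}\mathbb{H}^{d,K}$: $\large \Pi_{\mathbb{R}^{d+1} \rightarrow T_{x}\mathbb{H}^{d,K}}(v) \, := \, v \, + \, \frac{1}{K}\langle x, v \rangle_{\mathcal{L}}x$

      • Starting from the point v, it moves along the vector x by an amount proportional to how much v is aligned with x, which makes the result orthogonal to x: $\large \langle x, v \, + \, \frac{1}{K}\langle x, v \rangle_{\mathcal{L}}x \rangle_{\mathcal{L}} \, = \, \langle x, v \rangle_{\mathcal{L}} \, - \, \langle x, v \rangle_{\mathcal{L}} \, = \, 0$.

    • The implementation below does not follow these formulas literally. proj matches the hyperboloid projection, but proj_tan keeps the spatial part $\large u_{1:d}$ and recomputes $\large u_{0} \, = \, \langle x_{1:d}, u_{1:d} \rangle / x_{0}$ so that $\large \langle x, u \rangle_{\mathcal{L}} \, = \, 0$, and proj_tan0 simply sets the first coordinate to 0, which is the tangent condition at the origin.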

     def proj(self, x, c):
         K = 1. / c
         d = x.size(-1) - 1
         y = x.narrow(-1, 1, d)
         y_sqnorm = torch.norm(y, p=2, dim=1, keepdim=True) ** 2 
         mask = torch.ones_like(x)
         mask[:, 0] = 0
         vals = torch.zeros_like(x)
         vals[:, 0:1] = torch.sqrt(torch.clamp(K + y_sqnorm, min=self.eps[x.dtype]))
         return vals + mask * x
    
     def proj_tan(self, u, x, c):
         K = 1. / c
         d = x.size(1) - 1
         ux = torch.sum(x.narrow(-1, 1, d) * u.narrow(-1, 1, d), dim=1, keepdim=True)
         mask = torch.ones_like(u)
         mask[:, 0] = 0
         vals = torch.zeros_like(u)
         vals[:, 0:1] = ux / torch.clamp(x[:, 0:1], min=self.eps[x.dtype])
         return vals + mask * u
    
     def proj_tan0(self, u, c):
         narrowed = u.narrow(-1, 0, 1)
         vals = torch.zeros_like(u)
         vals[:, 0:1] = narrowed
         return u - vals   
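The sketch below, assuming NumPy and the curvature parameter K passed directly, contrasts the textbook tangent projection with the coordinate-wise variant that proj_tan uses. Both land in $T_x\mathbb{H}^{d,K}$, although they generally produce different vectors. The function names are hypothetical.

```python
import numpy as np

def minkowski_dot(x, y):
    return -x[..., 0] * y[..., 0] + np.sum(x[..., 1:] * y[..., 1:], axis=-1)

def proj(x, K=1.0):
    """Project onto H^{d,K}: keep x_{1:d} and recompute x_0 = sqrt(K + ||x_{1:d}||^2)."""
    out = x.copy()
    out[..., 0] = np.sqrt(K + np.sum(x[..., 1:] ** 2, axis=-1))
    return out

def proj_tan_formula(v, x, K=1.0):
    """Textbook projection onto T_x H^{d,K}: v + <x, v>_L x / K."""
    return v + (minkowski_dot(x, v) / K)[..., None] * x

def proj_tan_code_style(v, x):
    """Variant used by proj_tan: keep v_{1:d} and set v_0 = <x_{1:d}, v_{1:d}> / x_0."""
    out = v.copy()
    out[..., 0] = np.sum(x[..., 1:] * v[..., 1:], axis=-1) / x[..., 0]
    return out

K = 2.0
x = proj(np.array([[0.0, 0.3, -1.1], [5.0, 2.0, 0.5]]), K)   # the given x_0 values are overwritten
v = np.array([[0.7, -0.2, 0.4], [1.0, 1.0, 1.0]])
print(minkowski_dot(x, x))                               # each entry should be close to -K
print(minkowski_dot(x, proj_tan_formula(v, x, K)))       # each entry should be close to 0
print(minkowski_dot(x, proj_tan_code_style(v, x)))       # each entry should be close to 0
```

The code-style variant only touches one coordinate, which is cheaper and avoids an extra inner product with the whole of x.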
    


  4. Geodesic and Exponential/Logarithmic Map

    • Conditions for a geodesic, the curve that locally minimizes length between points on $\large \mathbb{H}^{d,K}$ (the generalization of a straight line).

      • For a tangent vector $\large u \, \in \, \mathcal{T}_{x}\mathbb{H}^{d, K}$ at point $\large x$, geodesic $\large \gamma(t)$ satisfies $\large \gamma(0) \, = \, x$ and $\large \dot{\gamma}(0) \, = \, u$.

      • The acceleration along the geodesic is normal to the manifold, which can be expressed formally as

        • $\large \ddot{\gamma}(t) \, \perp \, T_{\gamma(t)}\mathbb{H}^{d, K}$ with respect to the metric defined on the hyperboloid.

        • In particular, $\large \langle \ddot{\gamma}(t), \dot{\gamma}(t) \rangle_{\mathcal{L}} \, = \, 0$, so the geodesic moves at constant speed.

      • Under these conditions, the unique unit-speed ($\large ||\dot{\gamma}(t)||_{\mathcal{L}} \, = \, 1$) geodesic on the hyperboloid with constant negative curvature $\large -1/K$ is

        $\large \gamma^{K}_{x \rightarrow u}(t) \, = \, \cosh(\frac{t}{\sqrt{K}})x \, + \, \sqrt{K}\sinh(\frac{t}{\sqrt{K}})\frac{u}{||u||_{\mathcal{L}}}$

        • Note that the curvature only rescales the embeddings on the hyperboloid: for any embeddings $\large H \, = \, \{h_{i}\} \, \in \, \mathbb{H}^{d, K}$, the corresponding embeddings on a hyperboloid with parameter $\large K'$ satisfy $\large H' \, = \, \frac{\sqrt{K'}}{\sqrt{K}}H$ (see Section 4.1).

    • The exponential map follows geodesic coordinates: $\large \exp_{x}^{K}(v)$ moves along the geodesic starting from $\large x$ in the direction of the unit vector $\large \frac{v}{||v||_{\mathcal{L}}}$ for a distance $\large ||v||_{\mathcal{L}}$. This leads to

      $\large \exp_{x}^{K}(v) \, = \, \cosh(\frac{||v||_{\mathcal{L}}}{\sqrt{K}})x \, + \, \sqrt{K}\sinh(\frac{||v||_{\mathcal{L}}}{\sqrt{K}})\frac{v}{||v||_{\mathcal{L}}}$

    • Logarithmic Map

      • The exact inverse of the exponential map, $\large \log_{x}^{K}(\exp_{x}^{K}(v)) \, = \, v$: it maps a point $\large y$ on the manifold (here, $\large \exp_{x}^{K}(v)$) onto the tangent space of another point $\large x$ on the manifold.

      • To compose the log and exp maps, both must be taken at the same base point.

        • For example, in $\large \exp_{x}^{K}(\log_{x}^{K}(y)) \, = \, y$, the exponential map is taken at the same point x whose tangent space the logarithmic map maps into.

        • The same holds in the opposite order, $\large \log_{x}^{K}(\exp_{x}^{K}(v)) \, = \, v$.

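      • Start from $\large \langle x, x \rangle_{\mathcal{L}} \, = \, -K$ and $\large \langle x, v \rangle_{\mathcal{L}} \, = \, 0$, and let $\large y \, = \, \exp_{x}^{K}(v)$.

        • Then the inner product is $\large \langle x, y \rangle_{\mathcal{L}} \, = \, -K\cosh(\frac{||v||_{\mathcal{L}}}{\sqrt{K}})$, since $\large y$ is a linear combination of $\large x$ and $\large v$, and $\large \langle x, v \rangle_{\mathcal{L}} \, = \, 0$.

      • This leads to

        $\large y \, + \, \frac{1}{K}\langle x, y \rangle_{\mathcal{L}}x \, = \, \sqrt{K}\sinh(\frac{||v||_{\mathcal{L}}}{\sqrt{K}})\frac{v}{||v||_{\mathcal{L}}}$

      • Solving this equation for $\large v$ (sinh is bijective): the direction of $\large v$ is that of $\large y \, + \, \frac{1}{K}\langle x, y \rangle_{\mathcal{L}}x$, and its length follows from the cosh equation above,

        $\large ||v||_{\mathcal{L}} \, = \, \sqrt{K}\,\text{arcosh}(-\frac{\langle x, y \rangle_{\mathcal{L}}}{K}) \, = \, d_{\mathcal{L}}^{K}(x, y)$

      • The final form of the logarithmic map is therefore

        $\large \log_{x}^{K}(y) \, = \, d_{\mathcal{L}}^{K}(x, y)\frac{y \, + \, \frac{1}{K}\langle x, y \rangle_{\mathcal{L}}x}{||y \, + \, \frac{1}{K}\langle x, y \rangle_{\mathcal{L}}x||_{\mathcal{L}}}$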

     def expmap(self, u, x, c):
         K = 1. / c
         sqrtK = K ** 0.5
         normu = self.minkowski_norm(u)
         normu = torch.clamp(normu, max=self.max_norm)
         theta = normu / sqrtK
         theta = torch.clamp(theta, min=self.min_norm)
         result = cosh(theta) * x + sinh(theta) * u / theta
         return self.proj(result, c)
            
     def logmap(self, x, y, c):
         K = 1. / c
         xy = torch.clamp(self.minkowski_dot(x, y), max=-self.eps[x.dtype])
         u = y + xy * x * c
         normu = self.minkowski_norm(u)
         normu = torch.clamp(normu, min=self.min_norm)
         dist = self.sqdist(x, y, c) ** 0.5
         result = dist * u / normu
         return self.proj_tan(result, x, c)
    
     def expmap0(self, u, c):
         # map a tangent vector u in T_oH onto the manifold with the exp map at the origin o
         # origin here stands for (sqrtK, 0, 0, ..., 0)
         # u : (N_batch, dim+1); only u[:, 1:] is used since the time component is 0 at o
         K = 1. / c
         sqrtK = K ** 0.5
         d = u.size(-1) - 1
         x = u.narrow(-1, 1, d).view(-1, d)
         x_norm = torch.norm(x, p=2, dim=1, keepdim=True)
         x_norm = torch.clamp(x_norm, min=self.min_norm)
         theta = x_norm / sqrtK
         res = torch.ones_like(u)
         res[:, 0:1] = sqrtK * cosh(theta)
         res[:, 1:] = sqrtK * sinh(theta) * (x / x_norm)
         return self.proj(res, c)
    
     def logmap0(self, x, c):
         # project point x on the hyperboloid to the tangent space of the origin (o) of hyperboloid (sqrtK, 0, 0, ..., 0) 
         K = 1. / c
         sqrtK = K ** 0.5
         d = x.size(-1) - 1
         y = x.narrow(-1, 1, d).view(-1, d)    # at the origin, x + <o, x>_L o / K = (0, x_{1:d}), so the direction is just the spatial part
         y_norm = torch.norm(y, p=2, dim=1, keepdim=True)
         y_norm = torch.clamp(y_norm, min=self.min_norm)
         res = torch.zeros_like(x)
         theta = torch.clamp(x[:, 0:1] / sqrtK, min=1.0 + self.eps[x.dtype])   # -<o, x>_L / K = -(-sqrtK * x0) / K = x0 / sqrtK
         res[:, 1:] = sqrtK * arcosh(theta) * y / y_norm
         return res    
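To check the closed forms numerically, here is a minimal NumPy sketch of the exp and log maps on single vectors. It assumes K is passed directly and replaces the class's min_norm/max_norm clamps with a single small EPS; the names are illustrative. It verifies that $\log_x^K$ undoes $\exp_x^K$ and that the geodesic length equals $||v||_{\mathcal{L}}$.

```python
import numpy as np

EPS = 1e-15

def minkowski_dot(x, y):
    return -x[0] * y[0] + np.dot(x[1:], y[1:])

def minkowski_norm(u):
    # tangent vectors are space-like, so <u, u>_L >= 0 up to round-off
    return np.sqrt(max(minkowski_dot(u, u), EPS))

def dist(x, y, K):
    return np.sqrt(K) * np.arccosh(max(-minkowski_dot(x, y) / K, 1.0))

def expmap(x, v, K):
    """exp_x^K(v) = cosh(|v|/sqrt K) x + sqrt K sinh(|v|/sqrt K) v / |v|."""
    sqrtK = np.sqrt(K)
    n = minkowski_norm(v)
    return np.cosh(n / sqrtK) * x + sqrtK * np.sinh(n / sqrtK) * v / n

def logmap(x, y, K):
    """log_x^K(y) = d(x, y) * u / |u| with u = y + <x, y>_L x / K."""
    u = y + minkowski_dot(x, y) / K * x
    return dist(x, y, K) * u / minkowski_norm(u)

K = 2.0
s = np.array([0.3, -0.5])
x = np.concatenate([[np.sqrt(K + s @ s)], s])           # a point on H^{2,K}
w = np.array([0.0, 0.4, 0.2])
v = w + minkowski_dot(x, w) / K * x                     # project w onto T_x H^{2,K}
y = expmap(x, v, K)
print(np.isclose(minkowski_dot(y, y), -K))              # True: y stays on the hyperboloid
print(np.allclose(logmap(x, y, K), v))                  # True: log_x undoes exp_x
print(np.isclose(dist(x, y, K), minkowski_norm(v)))     # True: the geodesic has length |v|_L
```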
    


  5. Parallel Transport

    • Parallel transport moves a tangent vector $\large v$ in $\large \mathcal{T}_{x} \mathbb{H}^{d,K}$ to $\large \mathcal{T}_{y} \mathbb{H}^{d,K}$ along the geodesic connecting $\large x$ and $\large y$, preserving the metric. For K = 1:

      $\large P_{x \rightarrow y}(v) \, = \, v \, - \, \frac{\langle \log_{x}(y), v \rangle_{\mathcal{L}}}{d_{\mathcal{L}}(x, y)^{2}}(\log_{x}(y) \, + \, \log_{y}(x))$

      • The equation above is written for the hyperboloid with K = 1 (curvature -1); it converts to curvature -1/K by replacing the log maps and the distance function with their curvature-K versions, which is what ptransp does.

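  • Note that the hyperboloid model can be mapped isometrically to the Poincare model and back via a diffeomorphism (a smooth bijection with a smooth inverse), e.g. $\large p(x_{0}, \dots, x_{d}) \, = \, \frac{\sqrt{K}(x_{1}, \dots, x_{d})}{x_{0} \, + \, \sqrt{K}}$.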

      def ptransp(self, x, y, u, c):
          logxy = self.logmap(x, y, c)
          logyx = self.logmap(y, x, c)
          sqdist = torch.clamp(self.sqdist(x, y, c), min=self.min_norm)
          alpha = self.minkowski_dot(logxy, u) / sqdist
          res = u - alpha * (logxy + logyx)
          return self.proj_tan(res, y, c)
    
      def ptransp0(self, x, u, c):
          # parallel transport of vector u (ToH) from the origin to the point x on the manifold. 
          K = 1. / c
          sqrtK = K ** 0.5
          x0 = x.narrow(-1, 0, 1)
          d = x.size(-1) - 1
          y = x.narrow(-1, 1, d)
          y_norm = torch.clamp(torch.norm(y, p=2, dim=1, keepdim=True), min=self.min_norm)
          y_normalized = y / y_norm
          v = torch.ones_like(x)
          v[:, 0:1] = - y_norm 
          v[:, 1:] = (sqrtK - x0) * y_normalized
          alpha = torch.sum(y_normalized * u[:, 1:], dim=1, keepdim=True) / sqrtK
          res = u - alpha * v
          return self.proj_tan(res, x, c)    
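A NumPy sketch of the parallel transport formula with curvature-K log maps and distance (the same structure as ptransp) follows; the names and the EPS guard are illustrative assumptions. For a tangent vector v at x, the formula simplifies to the closed form $v + \frac{\langle y, v \rangle_{\mathcal{L}}}{K - \langle x, y \rangle_{\mathcal{L}}}(x + y)$, which the last check compares against.

```python
import numpy as np

EPS = 1e-15

def minkowski_dot(x, y):
    return -x[0] * y[0] + np.dot(x[1:], y[1:])

def dist(x, y, K):
    """d_L^K(x, y) = sqrt(K) * arcosh(-<x, y>_L / K)."""
    return np.sqrt(K) * np.arccosh(max(-minkowski_dot(x, y) / K, 1.0))

def logmap(x, y, K):
    """log_x^K(y) = d(x, y) * u / ||u||_L with u = y + <x, y>_L x / K."""
    u = y + minkowski_dot(x, y) / K * x
    return dist(x, y, K) * u / np.sqrt(max(minkowski_dot(u, u), EPS))

def ptransp(x, y, v, K):
    """P_{x->y}(v) = v - <log_x(y), v>_L / d(x, y)^2 * (log_x(y) + log_y(x))."""
    log_xy, log_yx = logmap(x, y, K), logmap(y, x, K)
    alpha = minkowski_dot(log_xy, v) / max(dist(x, y, K) ** 2, EPS)
    return v - alpha * (log_xy + log_yx)

def lift(s, K):
    """Point on H^{d,K} with spatial part s."""
    return np.concatenate([[np.sqrt(K + s @ s)], s])

K = 0.5
x, y = lift(np.array([0.2, 0.1]), K), lift(np.array([-0.6, 0.9]), K)
w = np.array([0.0, 1.0, -0.3])
v = w + minkowski_dot(x, w) / K * x          # project w onto T_x H
pv = ptransp(x, y, v, K)
closed = v + minkowski_dot(y, v) / (K - minkowski_dot(x, y)) * (x + y)
print(np.isclose(minkowski_dot(y, pv), 0.0, atol=1e-9))        # True: pv is tangent at y
print(np.isclose(minkowski_dot(pv, pv), minkowski_dot(v, v)))  # True: Minkowski norm preserved
print(np.allclose(pv, closed))                                 # True: matches the closed form
```

The closed form needs no log maps at all, which is one reason implementations such as ptransp0 specialize the transport for a fixed endpoint.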
    


1. HGCN Layer


class HyperbolicGraphConvolution(nn.Module):
    """
    Hyperbolic graph convolution layer.
    """

    def __init__(self, manifold, in_features, out_features, c_in, c_out, dropout, act, use_bias, use_att, local_agg):
        super(HyperbolicGraphConvolution, self).__init__()
        self.linear = HypLinear(manifold, in_features, out_features, c_in, dropout, use_bias)
        self.agg = HypAgg(manifold, c_in, out_features, dropout, use_att, local_agg)
        self.hyp_act = HypAct(manifold, c_in, c_out, act)

    def forward(self, input):
        x, adj = input
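        h = self.linear.forward(x)        # hyperbolic linear transformation (curvature c_in)
        h = self.agg.forward(h, adj)      # neighborhood aggregation (curvature c_in)
        h = self.hyp_act.forward(h)       # activation, moving from curvature c_in to c_out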
        output = h, adj
        return output


  • The three sub-layers (HypLinear, HypAgg, HypAct) generalize the linear transformation, the attention-based neighborhood aggregation, and the non-linear activation of Euclidean graph neural networks, respectively.

  • The operations required by each sub-layer are discussed in detail below, together with their implementations.


2. Generalization of Hyperbolic Linear Feature Transformation


  • Linear transformation is the most fundamental layer in neural networks: it projects the given features into more meaningful embedding representations.

  • Using the mathematical machinery defined on the hyperboloid manifold, linear feature transformation can be generalized to hyperbolic space.

  • Feature Transformation in Euclidean Neural Networks

    • Each feature transformation layer consists of (1) multiplication by a weight matrix and (2) bias addition, followed by a non-linear activation $\large \sigma$

    • $\large H^{k+1} \, = \, \sigma(H^{k}W + b)$


  • Feature Transformation in Hyperbolic Neural Networks

    • The hyperbolic linear layer (weight matrix multiplication and bias addition) is implemented analogously, by performing the Euclidean operations in a tangent space and mapping the results back to the hyperboloid.


2.1. Matrix Multiplication


  • Multiplying the node features embedded on the hyperboloid by a weight matrix is done by

    $\large W \, ⊗^{K} \, X^{H} \, := \, \exp_{o}^{K}(W \log_{o}^{K}(X^{H}))$

    1. Map the hyperbolic features to the tangent space at the origin with the log map.

    2. Perform the matrix multiplication in Euclidean (tangent) space.

    3. Map the result back to the hyperbolic manifold with the exp map.

      def mobius_matvec(self, m, x, c):
          u = self.logmap0(x, c)          # map the point x on the hyperboloid to the tangent space at the origin (T_oH)
          mu = u @ m.transpose(-1, -2)    # Euclidean matrix multiplication in T_oH
          return self.expmap0(mu, c)      # map back to the manifold at the origin
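Below is a batched NumPy sketch of $W ⊗^{K} X^{H}$. As a simplifying assumption it applies W only to the spatial part of $\log_o^K(x)$ (the time component is 0 at the origin), whereas mobius_matvec multiplies the full (d+1)-dimensional tangent vector and expmap0 then ignores the first output coordinate. hyp_matvec, expmap0 and logmap0 here are illustrative re-implementations, not the class methods.

```python
import numpy as np

def expmap0(u, K=1.0):
    """Exp map at the origin o = (sqrt K, 0, ..., 0); u holds spatial parts (n, d) of tangent vectors."""
    sqrtK = np.sqrt(K)
    norm = np.maximum(np.linalg.norm(u, axis=-1, keepdims=True), 1e-15)
    x0 = sqrtK * np.cosh(norm / sqrtK)
    return np.concatenate([x0, sqrtK * np.sinh(norm / sqrtK) * u / norm], axis=-1)

def logmap0(x, K=1.0):
    """Log map at the origin; returns the spatial parts (n, d), since the time component is 0 at o."""
    sqrtK = np.sqrt(K)
    y = x[..., 1:]
    norm = np.maximum(np.linalg.norm(y, axis=-1, keepdims=True), 1e-15)
    theta = np.maximum(x[..., :1] / sqrtK, 1.0)
    return sqrtK * np.arccosh(theta) * y / norm

def hyp_matvec(W, x, K=1.0):
    """W (x)^K x = exp_o(W log_o(x)); W has shape (d_out, d_in), x has shape (n, d_in + 1)."""
    return expmap0(logmap0(x, K) @ W.T, K)

K = 0.5
x = expmap0(np.array([[0.2, -0.1, 0.3], [1.0, 0.5, -0.4]]), K)   # two points on H^{3,K}
W = np.array([[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]])                 # maps R^3 -> R^2
h = hyp_matvec(W, x, K)
print(h.shape)   # (2, 3): two points on H^{2,K}
```

With W equal to the identity, the log and exp maps cancel and the points are returned unchanged, which is a convenient sanity check.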
    


2.2. Bias Addition


  • Bias addition uses parallel transport between tangent spaces of two points.

    1. Treat the Euclidean bias vector ($\large b$) as a vector in the tangent space at the origin ($\large \mathcal{T}_{o}\mathbb{H}$), map it onto the manifold with the exp map at the origin, and project the outcome onto the hyperboloid again.

    2. Pass the transformed features ($\large W \, ⊗^{K} \, X^{H}$) and the hyperbolic bias, both on the hyperboloid, to the mobius_add function. There, the bias is mapped back to the tangent space of the origin and parallel transported to the tangent space of the target point x (where the hyperbolic features lie). The transported vector $\large v$ is then mapped onto the manifold at point x via the exp map.

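       def mobius_add(self, x, y, c):
           u = self.logmap0(y, c)         # hyperbolic bias y -> tangent vector at the origin
           v = self.ptransp0(x, u, c)     # parallel transport from T_oH to T_xH
           return self.expmap(v, x, c)    # exp map at x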
      


    3. Project the result back onto the hyperboloid.
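The three steps can be condensed into one NumPy sketch, under the assumption that the bias b is given in Euclidean coordinates and used directly as the tangent vector (0, b) at the origin; the code's exp/log round trip at the origin is skipped because $\log_o^K(\exp_o^K(v)) = v$. The transport from o to x uses the closed form $v + \frac{\langle x, v \rangle_{\mathcal{L}}}{K - \langle o, x \rangle_{\mathcal{L}}}(o + x)$, i.e., the general parallel transport formula specialized to the origin. All names are hypothetical.

```python
import numpy as np

EPS = 1e-15

def minkowski_dot(x, y):
    return -x[..., 0] * y[..., 0] + np.sum(x[..., 1:] * y[..., 1:], axis=-1)

def expmap(x, v, K):
    """Row-wise exp_x^K(v) for points x and tangent vectors v of shape (n, d+1)."""
    sqrtK = np.sqrt(K)
    n = np.sqrt(np.maximum(minkowski_dot(v, v), EPS))[..., None]
    return np.cosh(n / sqrtK) * x + sqrtK * np.sinh(n / sqrtK) * v / n

def ptransp0(x, v, K):
    """Parallel transport of v in T_oH from the origin o to each row of x (closed form)."""
    o = np.zeros_like(x)
    o[..., 0] = np.sqrt(K)
    coef = minkowski_dot(x, v) / (K - minkowski_dot(o, x))
    return v + coef[..., None] * (o + x)

def hyp_bias_add(x, b, K):
    """exp_x(P_{o->x}((0, b))): add a Euclidean bias b of shape (d,) to points x of shape (n, d+1)."""
    b_tan = np.concatenate([[0.0], b])        # (0, b) lies in T_oH because <o, (0, b)>_L = 0
    v = ptransp0(x, np.broadcast_to(b_tan, x.shape), K)
    return expmap(x, v, K)

K = 1.5
s = np.array([[0.4, -0.2], [-1.0, 0.7]])
x = np.concatenate([np.sqrt(K + np.sum(s ** 2, axis=1, keepdims=True)), s], axis=1)
b = np.array([0.3, 0.1])
out = hyp_bias_add(x, b, K)
print(np.allclose(minkowski_dot(out, out), -K))   # True: the result stays on the hyperboloid
```

At x = o the transport is the identity, so the bias addition reduces to $\exp_o^K((0, b))$.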


  • The entire hyperbolic linear feature transformation is implemented in the class HypLinear

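      import math
      import torch
      import torch.nn as nn
      import torch.nn.functional as F
      from torch.nn import init
    
      class HypLinear(nn.Module):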
          """
          Hyperbolic linear layer.
          """
    
          def __init__(self, manifold, in_features, out_features, c, dropout, use_bias):
              super(HypLinear, self).__init__()
              self.manifold = manifold
              self.in_features = in_features
              self.out_features = out_features
              self.c = c
              self.dropout = dropout
              self.use_bias = use_bias
              self.bias = nn.Parameter(torch.Tensor(out_features))
              self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
              self.reset_parameters()
    
          def reset_parameters(self):
              init.xavier_uniform_(self.weight, gain=math.sqrt(2))
              init.constant_(self.bias, 0)
    
          def forward(self, x):
              drop_weight = F.dropout(self.weight, self.dropout, training=self.training)
              mv = self.manifold.mobius_matvec(drop_weight, x, self.c)
              res = self.manifold.proj(mv, self.c)
              if self.use_bias:
                  bias = self.manifold.proj_tan0(self.bias.view(1, -1), self.c)
                  hyp_bias = self.manifold.expmap0(bias, self.c)
                  hyp_bias = self.manifold.proj(hyp_bias, self.c)
                  res = self.manifold.mobius_add(res, hyp_bias, c=self.c)
                  res = self.manifold.proj(res, self.c)
              return res
    
          def extra_repr(self):
              return 'in_features={}, out_features={}, c={}'.format(
                  self.in_features, self.out_features, self.c
              )
    


3. Attention-Based Neighborhood Feature Aggregation


  • Attention is a crucial deep learning technique for measuring the relative association between features and for constructing new features weighted by these attention scores.

  • The basic operations required to compute attention scores are matrix multiplication, inner products between feature vectors, and a softmax function that normalizes the scores for each feature.

Operations of Attention


3.1. Attention-Weighted Feature Aggregation


  • Analogously, GAT (Graph Attention Networks) uses attention to weight the relative importance of neighboring nodes in the graph.

  • HGCN also uses attention to weight each node's features in the neighborhood aggregation step.

    • $\large a_{ij}$ = $\large \text{Attn} (x_{i}^{H}, \, x_{j}^{H})$

    • $\large AGG^{K}(x_{i}^{H}) \, = \, \exp_{x_{i}^{H}}^{K}(\sum_{j \in N(i)} a_{ij} \log_{x_{i}^{H}}^{K}(x_{j}^{H}))$

    • Note that both the logarithmic and exponential maps are taken at the point $\large x_{i}$, the center of the neighborhood: the tangent space of $\large x_{i}$ gives a better local Euclidean approximation of hyperbolic aggregation than the tangent space of the origin, and the ablation studies show that this local aggregation performs better.


      class HypAgg(Module):
          """
          Hyperbolic aggregation layer.
          """
    
          def __init__(self, manifold, c, in_features, dropout, use_att, local_agg):
              super(HypAgg, self).__init__()
              self.manifold = manifold
              self.c = c
    
              self.in_features = in_features
              self.dropout = dropout
              self.local_agg = local_agg
              self.use_att = use_att
              if self.use_att:
                  self.att = DenseAtt(in_features, dropout)
    
          def forward(self, x, adj):
              x_tangent = self.manifold.logmap0(x, c=self.c)   # Entire features (x) mapped onto the tangent space at origin
              if self.use_att:
                  if self.local_agg:           # aggregate in the tangent space of each x[i] instead of the origin (o)
                      x_local_tangent = []
                      for i in range(x.size(0)):
                          # log_{x_i}(x_j) for every node j; slicing x[i:i + 1] keeps the (1, d+1) shape that proj_tan expects
                          x_local_tangent.append(self.manifold.logmap(x[i:i + 1], x, c=self.c))
                      x_local_tangent = torch.stack(x_local_tangent, dim=0)      # (n, n, d+1): entry [i, j] = log_{x_i}(x_j)
                      adj_att = self.att(x_tangent, adj)                         # (n, n) attention scores from the origin-tangent features (not local) and adj
                      att_rep = adj_att.unsqueeze(-1) * x_local_tangent          # (n, n, 1) * (n, n, d+1) -> (n, n, d+1): weight each neighbor's tangent vector
                      support_t = torch.sum(att_rep, dim=1)                      # sum over neighbors j -> (n, d+1); row i lies in T_{x_i}
                      # expmap(u, x, c) takes the tangent vector first, then the base point: map each row back to the manifold at x_i
                      output = self.manifold.proj(self.manifold.expmap(support_t, x, c=self.c), c=self.c)
                      return output
                  else:
                      adj_att = self.att(x_tangent, adj)
                      support_t = torch.matmul(adj_att, x_tangent)
              else:
                  support_t = torch.spmm(adj, x_tangent)
              output = self.manifold.proj(self.manifold.expmap0(support_t, c=self.c), c=self.c)
              return output
    
          def extra_repr(self):
              return 'c={}'.format(self.c)    
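As an illustration of $AGG^K$ with local aggregation, the NumPy sketch below replaces the Python loop over nodes with broadcasting, so logs[i, j] holds $\log_{x_i}^K(x_j)$ for all pairs. The weight matrix att is assumed to be given (for example, masked scores like those produced by DenseAtt); the function names are hypothetical, and the dense n x n x (d+1) tensor only suits small graphs.

```python
import numpy as np

EPS = 1e-15

def mdot(x, y):
    """Minkowski inner product over the last axis (broadcasts like NumPy arithmetic)."""
    return -x[..., 0] * y[..., 0] + np.sum(x[..., 1:] * y[..., 1:], axis=-1)

def expmap(x, v, K):
    sqrtK = np.sqrt(K)
    n = np.sqrt(np.maximum(mdot(v, v), EPS))[..., None]
    return np.cosh(n / sqrtK) * x + sqrtK * np.sinh(n / sqrtK) * v / n

def logmap(x, y, K):
    """log_x^K(y); with x of shape (n, 1, D) and y of shape (1, n, D) it returns all pairs."""
    xy = mdot(x, y)[..., None]
    u = y + xy / K * x
    d = np.sqrt(K) * np.arccosh(np.maximum(-xy / K, 1.0))
    return d * u / np.sqrt(np.maximum(mdot(u, u), EPS))[..., None]

def local_agg(x, att, K):
    """AGG^K(x_i) = exp_{x_i}(sum_j att[i, j] * log_{x_i}(x_j)) for x of shape (n, D), att of shape (n, n)."""
    logs = logmap(x[:, None, :], x[None, :, :], K)      # logs[i, j] = log_{x_i}(x_j)
    support = np.sum(att[..., None] * logs, axis=1)      # weighted sum inside each T_{x_i}
    return expmap(x, support, K)                         # back to the manifold at each x_i

K = 1.0
s = np.array([[0.1, 0.2], [-0.5, 0.3], [0.8, -0.9]])
x = np.concatenate([np.sqrt(K + np.sum(s ** 2, axis=1, keepdims=True)), s], axis=1)
att = np.array([[0.5, 0.5, 0.0], [0.2, 0.6, 0.2], [0.0, 0.3, 0.7]])   # hypothetical attention weights
h = local_agg(x, att, K)
print(np.allclose(mdot(h, h), -K))   # True: aggregated features stay on the hyperboloid
```

If row i of att puts all its weight on a single node j, the output is exactly $\exp_{x_i}(\log_{x_i}(x_j)) = x_j$, which is a handy check of the base-point bookkeeping.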
    


  • To get the attention scores $\large \text{Attn} (x_{i}^{H}, \, x_{j}^{H})$

    • Map $\large x$ to the tangent space of the origin and pass it to an MLP to get an n x n attention score matrix (see the DenseAtt implementation below, with added comments).


      class DenseAtt(nn.Module):
          def __init__(self, in_features, dropout):
              super(DenseAtt, self).__init__()
              self.dropout = dropout
              self.linear = nn.Linear(2 * in_features, 1, bias=True)  # takes 2d and returns 1
              self.in_features = in_features
    
          def forward(self, x, adj):
              n = x.size(0)   # number of nodes
                
              x_left = torch.unsqueeze(x, 1)   # n x 1 x d
              x_left = x_left.expand(-1, n, -1)   # n x n x d : x_left[i, j] = x[i]
              x_right = torch.unsqueeze(x, 0)   # 1 x n x d
              x_right = x_right.expand(n, -1, -1)  # n x n x d : x_right[i, j] = x[j]
    
              x_cat = torch.cat((x_left, x_right), dim=2)  # n x n x 2d 
              att_adj = self.linear(x_cat).squeeze(-1)     # n x n x 1 -> n x n (squeeze(-1) keeps the n x n shape even when n == 1)
              att_adj = torch.sigmoid(att_adj)             # independent scores in (0, 1); unlike softmax, not normalized over neighbors
              att_adj = torch.mul(adj.to_dense(), att_adj)   # mask the attention with the adjacency matrix to keep only real edges
              return att_adj
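A NumPy sketch of the same scoring, followed by the non-local aggregation at the origin (support_t = adj_att @ x_tangent, then expmap0), may make the shapes easier to follow. The parameters a and b stand in for nn.Linear(2 * in_features, 1) and are hypothetical; features are kept as spatial parts only, and, as in DenseAtt, the scores are sigmoid-gated rather than softmax-normalized.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def dense_att(u, adj, a, b):
    """Pairwise scores sigmoid(a . [u_i || u_j] + b), masked by the adjacency matrix.

    u: (n, d) features in the tangent space at the origin; a: (2d,) weights; b: scalar bias.
    """
    n = u.shape[0]
    left = np.repeat(u[:, None, :], n, axis=1)      # left[i, j] = u_i
    right = np.repeat(u[None, :, :], n, axis=0)     # right[i, j] = u_j
    scores = sigmoid(np.concatenate([left, right], axis=-1) @ a + b)   # (n, n)
    return adj * scores                             # zero out non-edges

def expmap0(u, K=1.0):
    """Exp map at the origin for tangent vectors given by their spatial parts (n, d)."""
    sqrtK = np.sqrt(K)
    norm = np.maximum(np.linalg.norm(u, axis=-1, keepdims=True), 1e-15)
    return np.concatenate([sqrtK * np.cosh(norm / sqrtK), sqrtK * np.sinh(norm / sqrtK) * u / norm], axis=-1)

rng = np.random.default_rng(0)
u = rng.normal(size=(4, 3))                          # hypothetical tangent features of 4 nodes
adj = np.array([[1, 1, 0, 0], [1, 1, 1, 0], [0, 1, 1, 1], [0, 0, 1, 1]], dtype=float)
a, b = rng.normal(size=6), 0.1                       # hypothetical attention parameters
att = dense_att(u, adj, a, b)
h = expmap0(att @ u, K=1.0)                          # aggregate in T_oH, then map back
```

Because the score is linear in the concatenation, $a \cdot [u_i || u_j] = a_{1:d} \cdot u_i + a_{d+1:2d} \cdot u_j$; computing the two per-node terms and broadcasting them gives the same matrix without materializing the n x n x 2d tensor.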
    


3.2. Non-Linear Activation


  • Finally, a non-linear activation is applied to the output features (the attention-weighted aggregation) to give the network sufficient complexity and expressiveness.

  • However, since non-linear functions such as sigmoid and ReLU are designed for Euclidean space, they cannot be applied directly in hyperbolic space.

  • Since the activation is typically the final operation within a layer, its output serves as the input features for the next layer.

  • Therefore, the output must be mapped onto the hyperboloid of the next layer, whose curvature may differ (HGCN uses a separate curvature for each layer).

  • Formally, the non-linear activation between layers $\large l-1$ and $\large l$ is performed as follows

    • $\large \sigma^{⊗^{K_{l-1}, K_{l}}}(X^{H}) \, = \, \exp_{o}^{K_{l}}(\sigma(\log_{o}^{K_{l-1}}(X^{H})))$

    • $\sigma$ can be any Euclidean non-linear function

    1. Take the attention-aggregated features and map them to the tangent space of the origin of the hyperboloid with curvature $\large K_{l-1}$

    2. Then, apply non-linear activation.

    3. Re-map back to the new hyperboloid of curvature $\large K_{l}$ using exponential map.

      class HypAct(Module):
          """
          Hyperbolic activation layer.
          """
    
          def __init__(self, manifold, c_in, c_out, act):
              super(HypAct, self).__init__()
              self.manifold = manifold
              self.c_in = c_in
              self.c_out = c_out
              self.act = act
    
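          def forward(self, x):
              xt = self.act(self.manifold.logmap0(x, c=self.c_in))   # log map at the origin of the input hyperboloid (c_in), then the Euclidean activation
              xt = self.manifold.proj_tan0(xt, c=self.c_out)          # zero the time coordinate so xt is a valid tangent vector at the origin
              return self.manifold.proj(self.manifold.expmap0(xt, c=self.c_out), c=self.c_out)   # exp map onto the output hyperboloid (c_out)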
    
          def extra_repr(self):
              return 'c_in={}, c_out={}'.format(
                  self.c_in, self.c_out
              )
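The activation with a curvature change can be sketched in NumPy as follows, assuming features are carried as spatial parts of tangent vectors at the origin (so the time coordinate is implicitly 0, which is what proj_tan0 enforces). hyp_act and the other names are illustrative, not the HGCN API.

```python
import numpy as np

def expmap0(u, K):
    """Exp map at the origin o = (sqrt K, 0, ..., 0); u holds spatial parts of shape (n, d)."""
    sqrtK = np.sqrt(K)
    norm = np.maximum(np.linalg.norm(u, axis=-1, keepdims=True), 1e-15)
    return np.concatenate([sqrtK * np.cosh(norm / sqrtK), sqrtK * np.sinh(norm / sqrtK) * u / norm], axis=-1)

def logmap0(x, K):
    """Log map at the origin; returns the spatial parts (n, d), since the time component is 0 at o."""
    sqrtK = np.sqrt(K)
    y = x[..., 1:]
    norm = np.maximum(np.linalg.norm(y, axis=-1, keepdims=True), 1e-15)
    return sqrtK * np.arccosh(np.maximum(x[..., :1] / sqrtK, 1.0)) * y / norm

def hyp_act(x, act, K_in, K_out):
    """exp_o^{K_out}(act(log_o^{K_in}(x))): activation plus a change of curvature."""
    return expmap0(act(logmap0(x, K_in)), K_out)

def relu(z):
    return np.maximum(z, 0.0)

K_in, K_out = 1.0, 0.25
x = expmap0(np.array([[0.5, -0.3], [-1.2, 0.4]]), K_in)   # points on H^{2,K_in}
out = hyp_act(x, relu, K_in, K_out)                        # points on H^{2,K_out}
print(np.allclose(-out[:, 0] ** 2 + np.sum(out[:, 1:] ** 2, axis=1), -K_out))   # True
```

The tangent vector after the activation is the same for both curvatures; only the exp map that places it on the manifold changes.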
    


4. Trainable Curvature


  • The geometric properties of a hyperboloid depend strongly on its curvature.

  • This implies that the curvature that best fits the embeddings produced by each layer can also differ from layer to layer.

  • HGCN adjusts the curvature at every layer by making the curvatures trainable parameters.

  • Even with varying curvatures, the Fermi-Dirac decoder, which is invariant to this rescaling, keeps the same expressive power in the link prediction task.


4.1. Features Embedded in Hyperboloid of Differing Curvature


  • Now let's see how differing curvatures affect the embeddings on the hyperboloid.

  • According to Lemma 1 in B.2 (Curvature) of the paper:

    • For any hyperbolic spaces with constant curvatures ($\large -1/K, \, -1/K'$) and any pair of hyperbolic points $\large (u, v)$ embedded in $\large \mathbb{H}^{d, K}$, there always exists a mapping $\large \phi \, : \, \mathbb{H}^{d, K} \, \rightarrow \, \mathbb{H}^{d, K'}$ to another pair of corresponding hyperbolic points in $\large \mathbb{H}^{d, K'}$, $\large (\phi(u), \phi(v))$, such that the Minkowski inner product is scaled by a constant factor.

    • For the hyperboloid with curvature $\large -1/K$ ($\large \mathbb{H}^{d, K}$), the Minkowski inner product satisfies $\large \langle x, \, x \rangle_{\mathcal{L}} \, = \, -K$ by definition.

      • Then the embeddings $\large H \, \in \, \mathbb{H}^{d, K}$ and $\large H' \, \in \, \mathbb{H}^{d, K'}$ satisfy the relation

        $\large H' \, = \, \frac{\sqrt{K'}}{\sqrt{K}}H$

      • Hence, the mapping $\large \phi$ satisfies $\large \phi(x^{H}) \, = \, \frac{\sqrt{K'}}{\sqrt{K}}x^{H}$ and $\large \langle \phi(x), \, \phi(x) \rangle_{\mathcal{L}} \, = \, -K'$

      • This leads to the distance function for $\large H' \, \in \, \mathbb{H}^{d, K'}$

        $\large d_{\mathcal{L}}^{K'}(\phi(x), \phi(y)) \, = \, \sqrt{K'}\,\text{arcosh}(-\frac{K'}{K}\frac{\langle x, y \rangle_{\mathcal{L}}}{K'}) \, = \, \frac{\sqrt{K'}}{\sqrt{K}}d_{\mathcal{L}}^{K}(x, y)$

  • Because changing the curvature only rescales the embeddings (and their distances) by a constant factor, a decoder that is invariant to this rescaling (with its hyperparameters rescaled accordingly) reconstructs the same graph regardless of the scale of the input embeddings.
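A quick numerical check of Lemma 1, assuming NumPy and hypothetical helper names: scaling points of $\mathbb{H}^{d,K}$ by $\sqrt{K'/K}$ lands them on $\mathbb{H}^{d,K'}$, and every distance is scaled by the same factor.

```python
import numpy as np

def mdot(x, y):
    return -x[..., 0] * y[..., 0] + np.sum(x[..., 1:] * y[..., 1:], axis=-1)

def dist(x, y, K):
    return np.sqrt(K) * np.arccosh(np.maximum(-mdot(x, y) / K, 1.0))

def rescale(x, K, K_new):
    """phi(x) = sqrt(K_new / K) * x maps H^{d,K} onto H^{d,K_new}."""
    return np.sqrt(K_new / K) * x

K, K_new = 1.0, 4.0
s = np.array([[0.3, -0.8], [1.5, 0.2]])
x = np.concatenate([np.sqrt(K + np.sum(s ** 2, axis=1, keepdims=True)), s], axis=1)
phi_x = rescale(x, K, K_new)
print(np.allclose(mdot(phi_x, phi_x), -K_new))      # True
ratio = dist(phi_x[0], phi_x[1], K_new) / dist(x[0], x[1], K)
print(np.isclose(ratio, np.sqrt(K_new / K)))         # True
```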


4.2. Fermi-Dirac Decoder and Hyperparameters


  • The Fermi-Dirac decoder used in the link prediction task computes the probability that an edge exists between two nodes as

    • $\large p((i, j) \in \mathcal{E} \, | \, x_{i}^{H}, \, x_{j}^{H})$ = $\large [e^{(d^{K}(x_{i}^{H}, \, x_{j}^{H})^{2} \, - r)/t} + 1]^{-1}$ where $r, t$ are hyperparameters.


      class FermiDiracDecoder(Module):
          """Fermi Dirac to compute edge probabilities based on distances."""
    
          def __init__(self, r, t):
              super(FermiDiracDecoder, self).__init__()
              self.r = r
              self.t = t
    
          def forward(self, dist):
              # dist is the squared hyperbolic distance (LPModel.decode passes sqdist)
              probs = 1. / (torch.exp((dist - self.r) / self.t) + 1.0)
              return probs
    


  • Then $\large x_{i}^{H}, \, x_{j}^{H}$ are connected by an edge iff $\large p((i, j) \in \mathcal{E} \, | \, x_{i}^{H}, \, x_{j}^{H}) \, \geq \, b$

    • which leads to the equivalent criterion on the squared distance

      $\large d_{\mathcal{L}}^{K}(x_{i}^{H}, \, x_{j}^{H})^{2} \, \leq \, r \, + \, t\log(\frac{1 - b}{b})$ where $\large b \, \in \, (0, 1)$

  • Given the final embeddings from HGCN, $\large H \, = \, [h_{1}, h_{2}, \dots , h_{n}]$, the decoder gives the edge set $\large E_{H}$

    • $\large E_{H} \, = \, \{ (i, j) \, | \, d_{\mathcal{L}}^{K}(h_{i}^{H}, \, h_{j}^{H})^{2} \, \leq \, r \, + \, t\log(\frac{1 - b}{b}) \}$

  • Since squared distances on $\large \mathbb{H}^{d, K'}$ are scaled by the constant factor $\large \frac{K'}{K}$ and the threshold is linear in $\large r, t$, scaling $\large r', t'$ for $\large \mathbb{H}^{d, K'}$ by the same factor gives the same edge reconstruction.

    • The two criteria below are equivalent when $\large r' \, = \, \frac{K'}{K}r$ and $\large t' \, = \, \frac{K'}{K}t$

    • $\large d_{\mathcal{L}}^{K}(h_{i}^{H}, \, h_{j}^{H})^{2} \, \leq \, r \, + \, t\log(\frac{1 - b}{b})$

    • $\large d_{\mathcal{L}}^{K'}(\phi(h_{i}^{H}), \, \phi(h_{j}^{H}))^{2} \, \leq \, r' \, + \, t'\log(\frac{1 - b}{b})$
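The equivalence can be checked with a small NumPy sketch. The squared distances are hypothetical values, and the decoder acts on squared distances, as FermiDiracDecoder does when LPModel passes it sqdist; rescaling the squared distances, r and t by K'/K leaves the edge set unchanged.

```python
import numpy as np

def fermi_dirac(sqdist, r, t):
    """p(edge) = 1 / (exp((d^2 - r) / t) + 1), the same formula as FermiDiracDecoder.forward."""
    return 1.0 / (np.exp((sqdist - r) / t) + 1.0)

def edges(sqdist, r, t, b):
    """Equivalent threshold form of p >= b: d^2 <= r + t * log((1 - b) / b)."""
    return sqdist <= r + t * np.log((1.0 - b) / b)

K, K_new = 1.0, 2.5
r, t, b = 2.0, 1.0, 0.3
sq = np.array([0.3, 1.7, 2.4, 5.0])        # hypothetical squared distances on H^{d,K}
scale = K_new / K                            # squared distances scale by K'/K on H^{d,K'}
print(edges(sq, r, t, b).tolist())           # [True, True, True, False]
print(np.array_equal(edges(sq, r, t, b), fermi_dirac(sq, r, t) >= b))                  # True
print(np.array_equal(edges(sq, r, t, b), edges(scale * sq, scale * r, scale * t, b)))  # True
```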

  • Although the affine-invariant Fermi-Dirac decoder has the same expressive capacity for any curvature, trainable curvatures still provide embeddings of the right scale, which plays an important role in model performance by stabilizing and facilitating optimization.

  • Implementing trainable curvature is straightforward: set the curvature of each layer as a trainable tensor parameter (nn.Parameter).


      def get_dim_act_curv(args):
          """
          Helper function to get dimension and activation at every layer.
          :param args:
          :return:
          """
          if not args.act:
              act = lambda x: x
          else:
              act = getattr(F, args.act)
          acts = [act] * (args.num_layers - 1)
          dims = [args.feat_dim] + ([args.dim] * (args.num_layers - 1))
          if args.task in ['lp', 'rec']:
              dims += [args.dim]
              acts += [act]
              n_curvatures = args.num_layers
          else:
              n_curvatures = args.num_layers - 1
          if args.c is None:
              # create list of trainable curvature parameters
              curvatures = [nn.Parameter(torch.Tensor([1.])) for _ in range(n_curvatures)]
          else:
              # fixed curvature
              curvatures = [torch.tensor([args.c]) for _ in range(n_curvatures)]
              if not args.cuda == -1:
                  curvatures = [curv.to(args.device) for curv in curvatures]
          return dims, acts, curvatures
    



5. Architecture and Performance of HGCN


5.1. Encoder and Decoder of HGCN


Encoder


class HGCN(Encoder):
    """
    Hyperbolic-GCN.
    """

    def __init__(self, c, args):
        super(HGCN, self).__init__(c)
        self.manifold = getattr(manifolds, args.manifold)()
        assert args.num_layers > 1
        dims, acts, self.curvatures = hyp_layers.get_dim_act_curv(args)
        self.curvatures.append(self.c)
        hgc_layers = []
        for i in range(len(dims) - 1):
            c_in, c_out = self.curvatures[i], self.curvatures[i + 1]
            in_dim, out_dim = dims[i], dims[i + 1]
            act = acts[i]
            hgc_layers.append(
                    hyp_layers.HyperbolicGraphConvolution(
                            self.manifold, in_dim, out_dim, c_in, c_out, args.dropout, act, args.bias, args.use_att, args.local_agg
                    )
            )
        self.layers = nn.Sequential(*hgc_layers)
        self.encode_graph = True

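    def encode(self, x, adj):
        # treat the Euclidean input features as tangent vectors at the origin: proj_tan0 zeroes the first coordinate
        x_tan = self.manifold.proj_tan0(x, self.curvatures[0])
        # the exp map at the origin lifts them onto the hyperboloid of the first layer's curvature
        x_hyp = self.manifold.expmap0(x_tan, c=self.curvatures[0])
        x_hyp = self.manifold.proj(x_hyp, c=self.curvatures[0])
        return super(HGCN, self).encode(x_hyp, adj)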



NCModel : Decoder for Node Classification


class NCModel(BaseModel):
    """
    Base model for node classification task.
    """

    def __init__(self, args):
        super(NCModel, self).__init__(args)
        self.decoder = model2decoder[args.model](self.c, args)    # Linear Decoder for HGCN
        if args.n_classes > 2:
            self.f1_average = 'micro'
        else:
            self.f1_average = 'binary'
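        if args.pos_weight:
            # NOTE: `data` and `idx_train` are not defined in this scope; as written, this branch raises a NameError
            self.weights = torch.Tensor([1., 1. / data['labels'][idx_train].mean()])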
        else:
            self.weights = torch.Tensor([1.] * args.n_classes)
        if not args.cuda == -1:
            self.weights = self.weights.to(args.device)

    def decode(self, h, adj, idx):
        output = self.decoder.decode(h, adj)
        return F.log_softmax(output[idx], dim=1)


class LinearDecoder(Decoder):
    """
    MLP Decoder for Hyperbolic/Euclidean node classification models.
    """

    def __init__(self, c, args):
        super(LinearDecoder, self).__init__(c)
        self.manifold = getattr(manifolds, args.manifold)()
        self.input_dim = args.dim
        self.output_dim = args.n_classes
        self.bias = args.bias
        self.cls = Linear(self.input_dim, self.output_dim, args.dropout, lambda x: x, self.bias)
        self.decode_adj = False

    def decode(self, x, adj):
        h = self.manifold.proj_tan0(self.manifold.logmap0(x, c=self.c), c=self.c)
        return super(LinearDecoder, self).decode(h, adj)


LPModel : Decoder for Link Prediction


class LPModel(BaseModel):
    """
    Base model for link prediction task.
    """

    def __init__(self, args):
        super(LPModel, self).__init__(args)
        self.dc = FermiDiracDecoder(r=args.r, t=args.t)
        self.nb_false_edges = args.nb_false_edges
        self.nb_edges = args.nb_edges

    def decode(self, h, idx):
        if self.manifold_name == 'Euclidean':
            h = self.manifold.normalize(h)
        emb_in = h[idx[:, 0], :]
        emb_out = h[idx[:, 1], :]
        sqdist = self.manifold.sqdist(emb_in, emb_out, self.c)
        probs = self.dc.forward(sqdist)
        return probs


5.2. Performance of HGCN


  • HGCN is evaluated on two types of tasks: node classification and link prediction.


5.2.1. Ablation Studies


  • Ablation studies on the link prediction task with the DISEASE and AIRPORT datasets test the contribution of attention-based aggregation and trainable curvature to performance.

  • ATTo : attention and aggregation in the tangent space of the origin

  • ATT : attention with local aggregation (in the tangent space of each center node $\large x_{i}$)

  • C : trainable curvature

  • HGCN-ATT-C performs best, and all the other variants outperform the baseline HGCN without ATT and C.


5.2.2. Comparison to Other Models


Table 3. ROC AUC for LP and F1 score for NC tasks

  • HGCN shows strong performance on most of the datasets compared to the other models.


Figure 3. Visualization of Embeddings of LP and NC by GCN and HGCN (embedded in Poincare Model)

  • LP on DISEASE dataset.

  • Depth indicated by the shade of colors.

  • GCN barely captures the hierarchy (deeper nodes end up closer to the root than nodes that are hierarchically close to the root).

  • On the other hand, hierarchical distances are well preserved in the embeddings by HGCN.

  • NC on CORA dataset

  • Different colors represent different classes.

  • HGCN gives better class separation.