Fast weighted decomposition kernel

I’d like to use a weighted decomposition kernel in a Gaussian process in Stan.

The weighted decomposition kernel is:

k(s, s') = \sum_{i=1}^{L} \left( S(s_i, s'_i) \sum_{l \in \text{nbs}(i)} S(s_l, s'_l) \right)

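where S is a substitution matrix, s_i is the (integer-coded) symbol at position i of sequence s, and nbs(i) is the set of positions in contact with position i.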

I can implement it in Stan like this:

  functions {
    // Unnormalised weighted decomposition kernel between two sequences.
    //   x1, x2 : integer-coded sequences of length L (codes index S).
    //   adj    : row k lists the neighbour positions of position k. An entry
    //            equal to L + 1 selects the zero sentinel below, so it can be
    //            used to pad rows with fewer neighbours.
    //   S      : substitution matrix.
    // Returns sum_k S[x1[k], x2[k]] * sum_{l in adj[k]} S[x1[l], x2[l]].
    real wdk(int[] x1, int[] x2, int[,] adj, int L, matrix S){
      vector[L] K;
      real subs[L + 1];
      subs[L + 1] = 0;  // sentinel: padded adjacency slots contribute 0
      for (k in 1:L){
        subs[k] = S[x1[k], x2[k]];
      }
      for (k in 1:L){
        // multi-indexing gathers the neighbours' substitution scores
        K[k] = subs[k] * sum(subs[adj[k]]);
      }
      return sum(K);
    }

    // Normalised kernel matrix between the rows of X1 and X2:
    //   K[i, j] = wdk(X1[i], X2[j]) / sqrt(wdk(X1[i], X1[i]) * wdk(X2[j], X2[j]))
    // A non-positive self-kernel value makes the corresponding row/column
    // non-finite.
    matrix wd_cov(int[,] X1, int[,] X2, int[,] adj, int L, matrix S){
      int n1 = size(X1);
      int n2 = size(X2);
      matrix[n1, n2] K;
      vector[n1] norm1;
      row_vector[n2] norm2;
      for (i in 1:n1){
        norm1[i] = sqrt(wdk(X1[i], X1[i], adj, L, S));
        for (j in 1:n2){
          K[i, j] = wdk(X1[i], X2[j], adj, L, S);
        }
      }
      for (j in 1:n2){
        norm2[j] = sqrt(wdk(X2[j], X2[j], adj, L, S));
      }
      // vector * row_vector is the n1 x n2 outer product of the normalisers,
      // replacing the separate column-wise and row-wise rescaling loops.
      return K ./ (norm1 * norm2);
    }
  }
  
  data {
    int<lower=1> n1;      // number of sequences in X1
    int<lower=1> n2;      // number of sequences in X2
    int<lower=1> L;       // length of every sequence
    int<lower=1> n_subs;  // number of neighbour slots per position in adj
    int<lower=2> D;       // size of the substitution matrix (D x D)
    // Sequences are integer codes that index rows/columns of S; the bounds
    // reject out-of-range codes at data-load time instead of deep inside wd_cov.
    int<lower=1, upper=D> X1[n1, L];
    int<lower=1, upper=D> X2[n2, L];
    // Row k lists the neighbours of position k. Pad unused slots with L + 1,
    // which wdk maps to a zero substitution score.
    int<lower=1, upper=L + 1> adj[L, n_subs];
    matrix[D, D] S;       // substitution matrix
  }

  generated quantities {
      // Normalised weighted decomposition kernel between every row of X1 and
      // every row of X2 (n1 x n2), computed once per draw from data only.
      matrix[n1, n2] cov = wd_cov(X1, X2, adj, L, S);
  }

But the loop is very slow. In NumPy or PyTorch I can implement a much faster (but memory-intensive) version of the kernel with fancy indexing:

class WeightedDecompositionKernel(object):
    """Normalised weighted decomposition kernel between integer-coded sequences.

    For two sequences ``s`` and ``t`` of length ``L`` the unnormalised kernel is

        k(s, t) = sum_i S[s_i, t_i] * sum_{l in nbs(i)} S[s_l, t_l]

    where ``nbs(i)`` are the positions joined to ``i`` by a contact.
    :meth:`cov` returns ``k(s, t) / sqrt(k(s, s) * k(t, t))``.

    Args:
        contacts: iterable of ``(i, j)`` pairs of 0-based positions; each pair
            makes ``i`` and ``j`` neighbours of one another. Integer-valued
            floats are accepted.
        S: 2-D substitution matrix indexed by the sequence codes.
        L: number of positions in each sequence.
    """

    def __init__(self, contacts, S, L):
        super(WeightedDecompositionKernel, self).__init__()
        self.S = S
        # (L, max_degree) integer neighbour table, right-padded with -1.
        self.graph = self.make_graph(contacts, L)

    def make_graph(self, contacts, L):
        """Return the padded neighbour table built from ``contacts``.

        Row ``i`` lists the neighbours of position ``i`` in the order the
        contacts were given, followed by ``-1`` padding so every row has the
        length of the largest neighbour list. With no contacts the table has
        shape ``(L, 0)``.
        """
        graph = [[] for _ in range(L)]
        for c1, c2 in contacts:
            # Cast both endpoints: float-typed contacts (e.g. rows of a float
            # array) cannot be used as list indices.
            c1, c2 = int(c1), int(c2)
            graph[c1].append(c2)
            graph[c2].append(c1)
        # "+ [0]" keeps max() defined when L == 0.
        max_degree = max([len(g) for g in graph] + [0])
        table = np.full((L, max_degree), -1, dtype=int)
        for i, neighbours in enumerate(graph):
            table[i, :len(neighbours)] = neighbours
        return table

    def wdk(self, subs):
        """Unnormalised kernel for a batch of substitution-score rows.

        Args:
            subs: array of shape ``(n, L)`` where ``subs[a, i]`` is the
                substitution score at position ``i`` of the a-th sequence pair.

        Returns:
            Array of shape ``(n,)`` with one kernel value per row.
        """
        n = len(subs)
        # Append a zero column: the -1 padding in self.graph indexes this last
        # column, so padded neighbour slots contribute nothing.
        padded = np.concatenate([subs, np.zeros((n, 1))], axis=1)
        neighbour_sums = padded[:, self.graph].sum(axis=2)  # (n, L)
        return np.sum(neighbour_sums * subs, axis=1)

    def cov(self, X1, X2):
        """Normalised kernel matrix between the rows of ``X1`` and ``X2``.

        Args:
            X1: integer array of shape ``(n1, L)``.
            X2: integer array of shape ``(n2, L)``; must have the same number
                of columns as ``X1``.

        Returns:
            Array of shape ``(n1, n2)``. A non-positive self-kernel value for
            any sequence yields nan/inf in its row or column.
        """
        n1, L = X1.shape
        n2, _ = X2.shape
        k1 = self.wdk(self.S[X1, X1]).reshape((n1, 1))
        k2 = self.wdk(self.S[X2, X2]).reshape((1, n2))
        # Broadcasting (n1, 1, L) against (1, n2, L) gathers
        # S[X1[a, i], X2[b, i]] directly, without first materialising the
        # (n1, L, D) array S[X1].
        cross = self.S[X1[:, None, :], X2[None, :, :]].reshape((n1 * n2, L))
        K = self.wdk(cross).reshape((n1, n2))
        return K / np.sqrt(k1) / np.sqrt(k2)

Is there any way to get an equivalent speedup in Stan?

No