Faster weighted decomposition kernel

I’m implementing the weighted decomposition kernel. The implementation works, but it’s ~10x slower than a numba’d python/numpy implementation. What can I do to speed it up?

The weighted decomposition kernel is:

k(s, s') = \sum_{i=1}^L \left(S(s_i, s'_i) \sum_{l\in \text{nbs}(i)} S(s_l, s'_l) \right)

Where S is a substitution matrix.
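
(Equivalently, writing u_i = S(s_i, s'_i) and letting A be the 0/1 adjacency matrix of the neighbourhood graph, with A_{il} = 1 iff l \in \text{nbs}(i), this is just the bilinear form

k(s, s') = u^\top A u.)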

  functions {
    // Weighted decomposition kernel for one pair of sequences x1, x2.
    // adj is padded with the index L + 1 so every row has equal length;
    // subs[L + 1] = 0 makes the padding entries contribute nothing.
    real wdk(int[] x1, int[] x2, int[,] adj, int L, matrix S){
      vector[L] K;
      real subs[L + 1];
      subs[L + 1] = 0;
      for (k in 1:L){
        subs[k] = S[x1[k], x2[k]];
      }
      for (k in 1:L){
        K[k] = subs[k] * sum(subs[adj[k]]);
      }
      return sum(K);
    }
    
    // Cross-covariance between X1 and X2, normalized so that every
    // sequence has unit self-similarity.
    matrix wd_cov(int[,] X1, int[,] X2, int[,] adj, int L, matrix S){
      int n1 = size(X1);
      int n2 = size(X2);
      matrix[n1, n2] K;
      vector[n1] K1;
      row_vector[n2] K2;
      for (i in 1:n1){
        for (j in 1:n2){
          K[i, j] = wdk(X1[i], X2[j], adj, L, S);
        }
      }
      for (i in 1:n1){
        K1[i] = wdk(X1[i], X1[i], adj, L, S);
      }
      for (i in 1:n2){
        K2[i] = wdk(X2[i], X2[i], adj, L, S);
      }
      K1 = sqrt(K1);
      K2 = sqrt(K2);
      for (i in 1:n2){
        K[:, i] = K[:, i] ./ K1;
      }
      for (i in 1:n1){
        K[i] = K[i] ./ K2;
      }
      return K;
    }
  }
  
  data {
    int<lower=1> n1;
    int<lower=1> n2;
    int<lower=1> L;
    int<lower=1> n_subs;
    int<lower=2> D;
    int X1[n1, L];
    int X2[n2, L];
    int adj[L, n_subs];
    matrix[D, D] S;
  }

  generated quantities {
    matrix[n1, n2] cov = wd_cov(X1, X2, adj, L, S);
  }

What exactly are you comparing? Did you configure things to run only a single iteration and not do sampling?

I don’t see an easy way to make this a lot faster. Do you have access to the code computing this in NumPy? If so, I’d look there for suggestions. Stan compiles straight down to C++ pretty directly, so any fast algorithm should be pretty easy to implement.

The place to look would be for repeated work among the different wdk calls that could be cached and reused. I don’t see a much faster way to do it internally.
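
One direction for that restructuring: each wdk call is a bilinear form u' A u, where u is the length-L vector of per-position scores for one pair and A is the 0/1 adjacency matrix. Stacking the u vectors for all pairs as rows of a matrix U collapses the whole double loop into one matrix product plus a row-wise dot product. A sketch in NumPy (the names U, A, and all_wdk are mine):

import numpy as np

def all_wdk(U, A):
    # Row r of U is the per-position score vector u for one (s, s') pair;
    # k_r = u @ A @ u, computed for all pairs with one matrix product.
    return ((U @ A) * U).sum(axis=1)

The same shape should map onto U * A followed by rows_dot_product in Stan; building U is still a double loop, but over scalar lookups rather than full wdk calls.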

Yes, it is a single iteration with no sampling.

Here’s the numpy code:

import numpy as np


class WeightedDecompositionKernel(object):

    def __init__(self, contacts, S, L):
        super(WeightedDecompositionKernel, self).__init__()
        self.S = S
        self.graph = self.make_graph(contacts, L)

    def make_graph(self, contacts, L):
        graph = [[] for i in range(L)]
        for c1, c2 in contacts:
            graph[c1].append(int(c2))
            graph[c2].append(int(c1))
        max_L = max([len(g) for g in graph])
        # Fill with -1s so that every row has the same length
        graph = [g + [-1] * (max_L - len(g)) for g in graph]
        return np.array(graph).astype(int)
  
    def wdk(self, subs):
        # Append a zero column so the -1 padding in self.graph
        # contributes nothing to the neighbour sums.
        n = len(subs)
        z = np.zeros((n, 1))
        subs_ = np.concatenate([subs, z], axis=1)
        return np.sum(subs_[:, self.graph].sum(axis=2) * subs, axis=1)
    
    def cov(self, X1, X2):
        n1, L = X1.shape
        n2, _ = X2.shape
        # Self-kernels for normalization: S[X1, X1] picks the diagonal
        # scores S[x, x] element-wise, giving an (n1, L) array.
        subs = self.S[X1, X1]
        k1 = self.wdk(subs).reshape((n1, 1))
        subs = self.S[X2, X2]
        k2 = self.wdk(subs).reshape((1, n2))
        # Cross scores: row i * n2 + j holds S[X1[i, k], X2[j, k]] over k.
        L_inds = np.arange(L).astype(int)
        subs = self.S[X1][:, L_inds, X2].reshape((n1 * n2, L))
        K = self.wdk(subs).reshape((n1, n2))

        return K / np.sqrt(k1) / np.sqrt(k2)

The speedups in NumPy come from extensive fancy indexing to vectorize operations. A lot of these aren’t possible in Stan, as far as I know?
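
To unpack the main trick: subs_[:, self.graph] gathers, for every row and position, the scores of that position’s neighbours, and the -1 padding lands on the appended zero column (the same trick the Stan code plays with subs[L + 1] = 0). A toy run:

import numpy as np

# One pair, L = 3, with a zero column appended (index -1 hits it).
subs_ = np.array([[1., 2., 3., 0.]])
# Neighbour lists per position, padded with -1 to equal length.
graph = np.array([[1, 2], [0, -1], [0, -1]])
print(subs_[:, graph].sum(axis=2))  # [[5. 1. 1.]]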

I’m not sure what all those operations like S[X1, X1] are doing, as I’m not quite sure what X1 and X2 are here. Stan has a lot of built-in reshaping functions, but some of these I’m not sure what they’re doing. What we don’t have is any way of doing something like subs = self.S[X1, X1] if that’s supposed to produce a singly-indexed object.

S is a 21 x 21 substitution matrix. X1 and X2 are n x L and m x L integer inputs with values between 0 and 20 (0-based, in Python).
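
Concretely, S[X1, X1] with two integer arrays of the same shape is element-wise rather than a block slice: entry (i, k) is S[X1[i, k], X1[i, k]], so the result is n x L. A toy example:

import numpy as np

S = np.arange(9.).reshape(3, 3)  # stand-in 3 x 3 substitution matrix
X1 = np.array([[0, 2],
               [1, 1]])          # n = 2 sequences of length L = 2
print(S[X1, X1])                 # [[0. 8.]
                                 #  [4. 4.]]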

It’s an implementation of the weighted decomposition kernel from https://arxiv.org/pdf/1802.02852.pdf.