I’d like to use a weighted decomposition kernel in a Gaussian process in Stan.
The weighted decomposition kernel is:
$$k(s, s') = \sum_{i=1}^{L} \left( S(s_i, s'_i) \sum_{l \in \text{nbs}(i)} S(s_l, s'_l) \right)$$
where $S$ is a substitution matrix and $\text{nbs}(i)$ is the set of positions in contact with position $i$.
I can do it like this:
functions {
  // NOTE(review): this uses the pre-2.33 array syntax (`int[] x`, `real a[N]`).
  // Newer Stan releases require `array[] int x` / `array[N] real a` instead.

  // Unnormalised weighted decomposition kernel between two token sequences.
  //   x1, x2 : sequences of length L; entries are row/column indices into S.
  //   adj    : adj[k] lists the neighbour positions of site k. An entry equal
  //            to L + 1 selects subs[L + 1] == 0, so it can be used as padding
  //            to give every row the same width.
  //   L      : sequence length.
  //   S      : substitution matrix; S[a, b] scores token a against token b.
  // Returns sum_k S[x1[k], x2[k]] * sum_{l in adj[k]} S[x1[l], x2[l]].
  real wdk(int[] x1, int[] x2, int[,] adj, int L, matrix S) {
    real subs[L + 1];
    real total = 0;
    subs[L + 1] = 0;  // sink for padded neighbour indices
    for (k in 1:L) {
      subs[k] = S[x1[k], x2[k]];
    }
    // Accumulate directly instead of filling a length-L vector and summing it.
    for (k in 1:L) {
      total += subs[k] * sum(subs[adj[k]]);
    }
    return total;
  }

  // Normalised weighted decomposition kernel matrix:
  //   K[i, j] = wdk(X1[i], X2[j]) / sqrt(wdk(X1[i], X1[i]) * wdk(X2[j], X2[j])).
  // X1 has size(X1) sequences and X2 has size(X2) sequences, each of length L.
  // Cost is O(n1 * n2 * L * width(adj)); because every argument in the calling
  // program is data, call this once (e.g. in transformed data), not per draw.
  matrix wd_cov(int[,] X1, int[,] X2, int[,] adj, int L, matrix S) {
    int n1 = size(X1);
    int n2 = size(X2);
    matrix[n1, n2] K;
    vector[n1] norm1;      // sqrt of the self-similarities of X1
    row_vector[n2] norm2;  // sqrt of the self-similarities of X2
    for (i in 1:n1) {
      norm1[i] = sqrt(wdk(X1[i], X1[i], adj, L, S));
      for (j in 1:n2) {
        K[i, j] = wdk(X1[i], X2[j], adj, L, S);
      }
    }
    for (j in 1:n2) {
      norm2[j] = sqrt(wdk(X2[j], X2[j], adj, L, S));
    }
    // vector * row_vector is the outer product, so this divides each K[i, j]
    // by norm1[i] * norm2[j] in one step instead of two normalisation loops.
    return K ./ (norm1 * norm2);
  }
}
data {
  int<lower=1> n1;      // number of sequences in X1
  int<lower=1> n2;      // number of sequences in X2
  int<lower=1> L;       // sequence length
  int<lower=1> n_subs;  // width of each (padded) neighbour list in adj
  int<lower=2> D;       // size of the substitution matrix S
  // Tokens are used as S[X1[., k], X2[., k]], so they must be valid indices.
  int<lower=1, upper=D> X1[n1, L];
  int<lower=1, upper=D> X2[n2, L];
  // Neighbour positions; the value L + 1 selects the zero pad inside wdk.
  int<lower=1, upper=L + 1> adj[L, n_subs];
  matrix[D, D] S;
}
transformed data {
  // Every argument is data, so compute the covariance once here rather than
  // re-evaluating it in generated quantities for every draw.
  matrix[n1, n2] cov_data = wd_cov(X1, X2, adj, L, S);
}
generated quantities {
  matrix[n1, n2] cov = cov_data;
}
But the loops are very slow. In NumPy or PyTorch I can implement a much faster (but memory-intensive) version of the kernel with fancy indexing:
class WeightedDecompositionKernel:
    """Normalised weighted decomposition kernel over fixed-length token sequences.

    For sequences x, x' of length L the raw kernel is
        sum_k S[x_k, x'_k] * sum_{l in nbs(k)} S[x_l, x'_l],
    and ``cov`` divides it by sqrt(k(x, x) * k(x', x')), so cov(X, X) has a
    unit diagonal whenever the self-similarities are positive.
    """

    def __init__(self, contacts, S, L):
        """Build the kernel.

        Args:
            contacts: iterable of (i, j) pairs of 0-based positions in contact;
                each pair makes i a neighbour of j and j a neighbour of i.
            S: 2-D substitution matrix; S[a, b] scores token a against token b.
            L: sequence length (number of positions).
        """
        super().__init__()
        self.S = S
        self.graph = self.make_graph(contacts, L)

    def make_graph(self, contacts, L):
        """Return an (L, max_degree) int array of neighbour positions.

        Rows shorter than the maximum degree are padded with -1, which ``wdk``
        maps onto an appended zero column so padding contributes nothing.
        """
        graph = [[] for _ in range(L)]
        for c1, c2 in contacts:
            graph[c1].append(int(c2))
            graph[c2].append(int(c1))
        # default=0 keeps L == 0 (no positions) from raising on an empty max().
        max_degree = max((len(g) for g in graph), default=0)
        graph = [g + [-1] * (max_degree - len(g)) for g in graph]
        # reshape guarantees a 2-D result even when L == 0.
        return np.array(graph, dtype=int).reshape(L, max_degree)

    def wdk(self, subs):
        """Raw (unnormalised) kernel for each row of per-position scores.

        Args:
            subs: (n, L) array with subs[r, k] = S[x_k, x'_k] for pair r.

        Returns:
            (n,) array: sum_k subs[r, k] * sum_{l in graph[k]} subs[r, l].
        """
        n = len(subs)
        # Append a zero column so the -1 padding in self.graph selects 0.
        padded = np.concatenate([subs, np.zeros((n, 1))], axis=1)
        # (n, L, max_degree) gather, summed over neighbours -> (n, L).
        neighbour_sums = padded[:, self.graph].sum(axis=2)
        return np.sum(neighbour_sums * subs, axis=1)

    def cov(self, X1, X2):
        """Normalised kernel matrix between the rows of X1 (n1, L) and X2 (n2, L)."""
        n1, L = X1.shape
        n2 = X2.shape[0]
        k1 = self.wdk(self.S[X1, X1]).reshape((n1, 1))
        k2 = self.wdk(self.S[X2, X2]).reshape((1, n2))
        # Broadcast indexing gives subs[i, j, k] = S[X1[i, k], X2[j, k]] directly,
        # without materialising the (n1, L, D) intermediate that S[X1] creates.
        subs = self.S[X1[:, None, :], X2[None, :, :]].reshape((n1 * n2, L))
        K = self.wdk(subs).reshape((n1, n2))
        return K / np.sqrt(k1) / np.sqrt(k2)
Is there any way to get an equivalent speedup in Stan?