Help needed specifying and optimizing tensor factorization model

Dear Stan community,

I’m looking for help specifying the following tensor factorization model (I used this resource as a starting point).

$$\mathcal{Y} \approx \sum_{k=1}^{K} \mathbf{x}_k \circ \mathbf{w}_k \circ \mathbf{u}_k$$

where $\circ$ is the outer product.
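Written out cell by cell, let $x_{nk}$, $w_{dk}$ and $u_{lk}$ be the entries of the $N \times K$, $D \times K$ and $L \times K$ factor matrices `X`, `W` and `U` in the code below, and assume the Gaussian measurement error used in the Stan program ($\sigma$ is a standard deviation, as in Stan's `normal`). The model is then:

$$
y_{ndl} \sim \mathcal{N}\!\left(\sum_{k=1}^{K} x_{nk}\, w_{dk}\, u_{lk},\; \sigma\right),
\qquad n = 1, \dots, N,\; d = 1, \dots, D,\; l = 1, \dots, L.
$$

Each cell's mean is the $(n, d, l)$ entry of the sum of outer products. This is what the `tmp` accumulator in the innermost loop computes, one cell at a time.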

Specifically, I’m looking for suggestions to (1) improve sampling efficiency, since lack of convergence is a problem, and (2) vectorize the nested loop, since the current program will not scale well.

I hope the reproducible example below illustrates the problem.

```r
# setup -------------------------------------------------------------------

library("cmdstanr")
#> This is cmdstanr version 0.5.3
#> - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
#> - CmdStan path: /Users/cwrs/.cmdstan/cmdstan-2.32.2
#> - CmdStan version: 2.32.2
cores <- parallel::detectCores()

# simulate ----------------------------------------------------------------

# simulate a tensor to factorize:

N <- 10 # dimension 1
D <- 10 # dimension 2
L <- 10 # dimension 3
K <- 1 # number of components

# for K=1, tensor is the outer product of three vectors:
X <- matrix(rnorm(N * K, 0, 1), N, K)
W <- matrix(rnorm(D * K, 0, 1), D, K)
U <- matrix(rnorm(L * K, 0, 1), L, K)

sigma <- .05 # sd of error
e <- rnorm(N * D * L, 0, sigma) # error
mu <- lapply(1:K, \(k) (X[,k] %o% W[,k]) %o% U[,k]) # generate outer product

Y <- Reduce('+', mu) + # sum components
  array(e, c(N, D, L)) # add measurement error

# stan program ------------------------------------------------------------

scode <- "
// modified from: https://www.cs.helsinki.fi/u/sakaya/tutorial/code/tensor2.R

data {
  int<lower=0> N;
  int<lower=0> D;
  int<lower=0> K;
  int<lower=0> L;
  array[N, D, L] real Y;
}
parameters {
  matrix[N, K] X;
  matrix[L, K] U;
  matrix[D, K] W;
  real<lower=0> sigma;
}
model {
  real tmp;
  sigma ~ normal(0,1);
  to_vector(X) ~ normal(0,1);
  to_vector(U) ~ normal(0,1);
  to_vector(W) ~ normal(0,1);
  for (n in 1:N) {
    for (d in 1:D) {
      for (l in 1:L) {
        tmp = 0;
        for (k in 1:K)
          tmp = tmp + X[n,k] * W[d,k] * U[l,k];
        Y[n,d,l] ~ normal(tmp, sigma);
      }
    }
  }
}
"

# sample ------------------------------------------------------------------

# optional code to run the model:

# standata <- list(N = N, D = D, L = L, K = 1, Y = Y)
#
# stanseed <- 215678
#
# model <- cmdstan_model(stan_file = write_stan_file(scode))
#
# fit <- model$sample(
#   data = standata,
#   seed = stanseed,
#   parallel_chains = cores
# )
#
# fit$cmdstan_diagnose()
```


Created on 2023-07-12 with reprex v2.0.2

The trick is to reduce everything to matrix and elementwise operations wherever possible. In Stan these have much more efficient autodiff than the equivalent loops of scalar operations.
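Concretely, fix $n$ and $d$. Let $X_{n,:}$ and $W_{d,:}$ denote row $n$ of $X$ and row $d$ of $W$ (each of length $K$), and let $\odot$ denote the elementwise product. The $L$ means for the cells $y_{nd1}, \dots, y_{ndL}$ then come from one elementwise product and one vector-matrix product:

$$
\big(\mu_{nd1}, \dots, \mu_{ndL}\big) = \big(X_{n,:} \odot W_{d,:}\big)\, U^{\top},
\qquad \mu_{ndl} = \sum_{k=1}^{K} x_{nk}\, w_{dk}\, u_{lk}.
$$

Here $U$ is the $L \times K$ factor matrix from the question, so $U^{\top}$ is $K \times L$. This replaces the scalar loops over $l$ and $k$ with two vectorized operations and leaves only the loops over $n$ and $d$.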

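```stan
data {
  int<lower=0> N;
  int<lower=0> D;
  int<lower=0> K;
  int<lower=0> L;
  array[N, D, L] real Y;
}
parameters {
  array[N] row_vector[K] X;  // X[n] is row n of the N x K factor matrix
  matrix[K, L] U;            // stored transposed (K x L) relative to the question
  array[D] row_vector[K] W;  // W[d] is row d of the D x K factor matrix
  real<lower=0> sigma;
}
model {
  sigma ~ normal(0, 1);
  // X and W are arrays of length N and D, so the prior loops run over n and d
  for (n in 1:N) {
    X[n] ~ normal(0, 1);
  }
  for (d in 1:D) {
    W[d] ~ normal(0, 1);
  }
  to_vector(U) ~ normal(0, 1);

  for (n in 1:N) {
    for (d in 1:D) {
      // X[n] .* W[d] is a row_vector[K]; multiplying by U gives the L means of Y[n, d]
      Y[n, d] ~ normal(X[n] .* W[d] * U, sigma);
    }
  }
}
```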


I rewrote the shapes of the parameters so that X[n] and W[d] don't involve copying by making them arrays of row vectors. Indexing into an array doesn't require a copy, but extracting a row of a matrix does. Then I transposed U (now K × L) so I didn't have to transpose it again in the sampling statement. This meant I had to write a loop for each of the X and W priors, but that's minor compared to all the arithmetic.
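If the remaining loop over `n` and `d` still shows up as a bottleneck, one way to remove it is to stack all $N \cdot D$ row pairs at once with Stan's multiple indexing. This is a sketch I have not benchmarked, and it reuses the `matrix[K, L] U` orientation from above. `X[ii]` and `W[jj]` select rows of the factor matrices, so `(X[ii] .* W[jj]) * U` yields every mean in a single matrix product. Because rows are selected in bulk here, `X` and `W` can go back to being plain matrices with one-line priors.

The index arrays `ii` and `jj` and the flattened data vector `y_flat` are hypothetical names made up for this illustration. They are built once in `transformed data`, with `Y` flattened so that `l` varies fastest, then `d`, then `n`.

```stan
data {
  int<lower=0> N;
  int<lower=0> D;
  int<lower=0> K;
  int<lower=0> L;
  array[N, D, L] real Y;
}
transformed data {
  // Hypothetical helpers: row r = (n - 1) * D + d of the stacked design
  // pairs row n of X with row d of W.
  array[N * D] int ii;
  array[N * D] int jj;
  // Y flattened with l varying fastest, then d, then n.
  vector[N * D * L] y_flat;
  for (n in 1:N) {
    for (d in 1:D) {
      int r = (n - 1) * D + d;
      ii[r] = n;
      jj[r] = d;
      for (l in 1:L) {
        y_flat[(r - 1) * L + l] = Y[n, d, l];
      }
    }
  }
}
parameters {
  matrix[N, K] X;
  matrix[K, L] U; // stored transposed (K x L), as in the answer above
  matrix[D, K] W;
  real<lower=0> sigma;
}
model {
  // Row r of mu is (X[n] .* W[d]) * U, i.e. the L means for Y[n, d].
  matrix[N * D, L] mu = (X[ii] .* W[jj]) * U;
  sigma ~ normal(0, 1);
  to_vector(X) ~ normal(0, 1);
  to_vector(U) ~ normal(0, 1);
  to_vector(W) ~ normal(0, 1);
  // to_vector() flattens a matrix in column-major order, so transposing mu
  // first puts the L means of each row r next to each other, matching y_flat.
  y_flat ~ normal(to_vector(mu'), sigma);
}
```

The trade-off is memory: `X[ii]` and `W[jj]` materialize two $(N \cdot D) \times K$ intermediates. Whether this beats the looped version depends on the sizes involved, so it is worth checking with Stan's `profile` blocks before adopting it.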