Help needed specifying and optimizing tensor factorization model

Dear Stan community,

I’m looking for help specifying the following tensor factorization model (I used this resource as a starting point).

\mathcal{Y} \approx \sum_{k=1}^{K} \mathbf x_k \circ \mathbf w_k \circ \mathbf u_k

where \circ is the outer product.
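
Written out element by element, the same model reads as follows. This is only a restatement: the index notation is my own, chosen to match the Stan code below, with X_{nk} denoting the n-th entry of \mathbf x_k, and likewise for W and U.

\mathcal{Y}_{ndl} \approx \sum_{k=1}^{K} X_{nk} \, W_{dk} \, U_{lk}, \quad n = 1, \dots, N, \; d = 1, \dots, D, \; l = 1, \dots, L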

Specifically, I’m looking for suggestions to (1) improve sampling efficiency, since lack of convergence is a problem, and (2) vectorize the nested loop, since the current program will not scale well.

I hope the reproducible example below illustrates the problem.

# setup -------------------------------------------------------------------

library("cmdstanr")
#> This is cmdstanr version 0.5.3
#> - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
#> - CmdStan path: /Users/cwrs/.cmdstan/cmdstan-2.32.2
#> - CmdStan version: 2.32.2
cores <- parallel::detectCores()

# simulate ----------------------------------------------------------------

# simulate a tensor to factorize:

N <- 10 # dimension 1
D <- 10 # dimension 2
L <- 10 # dimension 3
K <- 1 # number of components

# for K=1, tensor is the outer product of three vectors:
X <- matrix(rnorm(N * K, 0, 1), N, K)
W <- matrix(rnorm(D * K, 0, 1), D, K)
U <- matrix(rnorm(L * K, 0, 1), L, K)

sigma <- .05 # sd of error
e <- rnorm(N * D * L, 0, sigma) # error
mu <- lapply(1:K, \(k) (X[,k] %o% W[,k]) %o% U[,k]) # generate outer product

Y <- Reduce('+', mu) + # sum components
  array(e, c(N, D, L)) # add measurement error

# stan program ------------------------------------------------------------

scode <- "
// modified from: https://www.cs.helsinki.fi/u/sakaya/tutorial/code/tensor2.R

data {
    int<lower=0> N; 
    int<lower=0> D; 
    int<lower=0> K;
    int<lower=0> L;
    array[N, D, L] real Y;
}
parameters {
    matrix[N, K] X;
    matrix[L, K] U;
    matrix[D, K] W;
    real<lower=0> sigma;
}
model {
    real tmp;
    sigma ~ normal(0,1);
    to_vector(X) ~ normal(0,1);         
    to_vector(U) ~ normal(0,1);         
    to_vector(W) ~ normal(0,1);
    for (n in 1:N) {
        for (d in 1:D) {
            for (l in 1:L) {
                tmp = 0;
                for (k in 1:K)
                    tmp = tmp + X[n,k] * W[d,k] * U[l,k];           
                Y[n,d,l] ~ normal(tmp, sigma);
            }
        }
    }
}
"

# sample ------------------------------------------------------------------

# optional code to run the model:

# standata <- list(N = N, D = D, L = L, K = K, Y = Y)
# 
# stanseed <- 215678
# 
# model <- cmdstan_model(stan_file = write_stan_file(scode))
# 
# fit <- model$sample(
#   data = standata,
#   seed = stanseed,
#   parallel_chains = cores
# )
# 
# fit$cmdstan_diagnose()

Created on 2023-07-12 with reprex v2.0.2

The trick is to reduce everything to matrix and element-wise operations wherever possible, because they have much more efficient autodiff than the equivalent scalar operations.

data {
    int<lower=0> N; 
    int<lower=0> D; 
    int<lower=0> K;
    int<lower=0> L;
    array[N, D, L] real Y;
}
parameters {
    array[N] row_vector[K] X;
    matrix[K, L] U;
    array[D] row_vector[K] W;
    real<lower=0> sigma;
}
model {
    sigma ~ normal(0, 1);
    // X has N elements and W has D elements, so loop over those sizes (not K)
    for (n in 1:N) {
      X[n] ~ normal(0, 1);
    }
    for (d in 1:D) {
      W[d] ~ normal(0, 1);
    }
    to_vector(U) ~ normal(0, 1);         

    for (n in 1:N) {
        for (d in 1:D) {
            // (X[n] .* W[d]) * U is a row_vector[L]: the means for Y[n, d, 1:L]
            Y[n,d] ~ normal((X[n] .* W[d]) * U, sigma);
        }
    }
}

I changed the parameter shapes so that X[n] and W[d] don’t involve copying: they are now arrays of row vectors (indexing an array doesn’t require a copy, but extracting a row of a matrix does). I also declared U already transposed (K x L), so I didn’t have to transpose it again in the sampling statement. This meant I had to loop over the X and W priors, but that’s minor compared to all the arithmetic.
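
As a quick check that the vectorized statement computes the same means as the original inner loop, here is the algebra for a fixed n and d. The notation is an assumption for illustration: \mathbf x_n and \mathbf w_d stand for the row vectors X[n] and W[d], \odot is element-wise multiplication (Stan’s .*), and U is the K x L matrix declared above.

\left[ (\mathbf x_n \odot \mathbf w_d) \, U \right]_l = \sum_{k=1}^{K} X_{nk} \, W_{dk} \, U_{kl}, \quad l = 1, \dots, L

The right-hand side is the tmp accumulated in the original program, with U[l,k] replaced by U[k,l] because U is now stored transposed. One row-vector-times-matrix product therefore gives all L means for Y[n,d] at once.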

Thank you—that is tremendously helpful!

My parameterization, though, seems to be problematic. The result of the original example, modified with the suggested changes to the Stan code, is a fit where most Rhats are far too high (>1.5) and most bulk ESS estimates are far too low (<10).

I’m struggling to come up with an alternative parameterization. Do you have any suggestions?