Dear Stan community,
I’m looking for help specifying the following tensor factorization model, which I adapted from the tutorial code cited at the top of the Stan program below.
$$
\mathcal{Y} \approx \sum_{k=1}^{K} \mathbf x_k \circ \mathbf w_k \circ \mathbf u_k
$$
where $\circ$ denotes the outer product.
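Written elementwise, and with the tensor unfolded along its first mode, the same model reads as follows. This uses standard CP/PARAFAC notation; the unfolding $\mathbf Y_{(1)}$ and the Khatri-Rao product $\odot$ are introduced here for illustration and do not appear in the original tutorial code. The unfolded form is what lets the nested loop be replaced by a single matrix product.

$$
y_{ndl} \approx \sum_{k=1}^{K} x_{nk}\, w_{dk}\, u_{lk},
\qquad
\mathbf Y_{(1)} \approx \mathbf X\,(\mathbf U \odot \mathbf W)^{\top},
$$

where $\mathbf X$ is $N \times K$ with columns $\mathbf x_k$ (likewise $\mathbf W$ is $D \times K$ and $\mathbf U$ is $L \times K$), $\mathbf Y_{(1)}$ is the $N \times DL$ matrix with entry $[n,\, d + (l-1)D] = y_{ndl}$, and $\mathbf U \odot \mathbf W$ is the $DL \times K$ matrix whose $k$-th column is the Kronecker product $\mathbf u_k \otimes \mathbf w_k$.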
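Specifically, I’m looking for suggestions to (1) improve sampling efficiency, since the chains do not converge, and (2) vectorize the nested loop, since the current program evaluates $N \times D \times L$ scalar sampling statements, each with an inner sum over $K$, and will not scale to larger tensors.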
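Regarding (1), one possible contributor (an assumption, not verified on this particular fit): both the likelihood and the symmetric normal(0, 1) priors are unchanged when the signs of any two factor vectors in a component are flipped,

$$
\mathbf x_k \circ \mathbf w_k \circ \mathbf u_k
= (-\mathbf x_k) \circ (-\mathbf w_k) \circ \mathbf u_k
= (-\mathbf x_k) \circ \mathbf w_k \circ (-\mathbf u_k)
= \mathbf x_k \circ (-\mathbf w_k) \circ (-\mathbf u_k),
$$

so even with $K = 1$ the posterior has four mirror-image modes, and with $K > 1$ permuting the components multiplies this to $4^K K!$ equivalent configurations. The split of scale, $(a\mathbf x_k) \circ (b\mathbf w_k) \circ (\mathbf u_k / ab)$, likewise leaves the likelihood unchanged and is pinned down only by the priors. Chains that settle in different modes will typically report a high $\hat R$ even if each one mixes well within its own mode.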
I hope the reproducible example below illustrates the problem.
```r
# setup -------------------------------------------------------------------
library("cmdstanr")
#> This is cmdstanr version 0.5.3
#> - CmdStanR documentation and vignettes: mc-stan.org/cmdstanr
#> - CmdStan path: /Users/cwrs/.cmdstan/cmdstan-2.32.2
#> - CmdStan version: 2.32.2
cores <- parallel::detectCores()

# simulate ----------------------------------------------------------------
set.seed(215678) # seed the simulation too, so the simulated Y is reproducible
# simulate a tensor to factorize:
N <- 10 # dimension 1
D <- 10 # dimension 2
L <- 10 # dimension 3
K <- 1  # number of components
# for K = 1, the tensor is the outer product of three vectors:
X <- matrix(rnorm(N * K, 0, 1), N, K)
W <- matrix(rnorm(D * K, 0, 1), D, K)
U <- matrix(rnorm(L * K, 0, 1), L, K)
sigma <- .05                    # sd of error
e <- rnorm(N * D * L, 0, sigma) # error
mu <- lapply(1:K, \(k) (X[, k] %o% W[, k]) %o% U[, k]) # outer product per component
Y <- Reduce('+', mu) +          # sum components
  array(e, c(N, D, L))          # add measurement error

# stan program ------------------------------------------------------------
scode <- "
// modified from: https://www.cs.helsinki.fi/u/sakaya/tutorial/code/tensor2.R
data {
  int<lower=0> N;
  int<lower=0> D;
  int<lower=0> K;
  int<lower=0> L;
  array[N, D, L] real Y;
}
parameters {
  matrix[N, K] X;
  matrix[L, K] U;
  matrix[D, K] W;
  real<lower=0> sigma;
}
model {
  real tmp;
  sigma ~ normal(0, 1);
  to_vector(X) ~ normal(0, 1);
  to_vector(U) ~ normal(0, 1);
  to_vector(W) ~ normal(0, 1);
  for (n in 1:N) {
    for (d in 1:D) {
      for (l in 1:L) {
        tmp = 0;
        for (k in 1:K)
          tmp = tmp + X[n, k] * W[d, k] * U[l, k];
        Y[n, d, l] ~ normal(tmp, sigma);
      }
    }
  }
}
"

# sample ------------------------------------------------------------------
# optional code to run the model:
# standata <- list(N = N, D = D, L = L, K = K, Y = Y)
#
# stanseed <- 215678
#
# model <- cmdstan_model(stan_file = write_stan_file(scode))
#
# fit <- model$sample(
#   data = standata,
#   seed = stanseed,
#   parallel_chains = cores
# )
#
# fit$cmdstan_diagnose()
```
Created on 2023-07-12 with reprex v2.0.2
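For (2), here is a minimal sketch, assuming the data are passed to Stan as the mode-1 unfolding `Y1`, built in R with `matrix(Y, N, D * L)`. The Khatri-Rao product is assembled with a loop over the K components only, and the likelihood becomes one vectorized `normal` statement. The names `scode_vec`, `Y1`, `WU`, `KR`, and `check_unfolding()` are hypothetical, and the Stan program has not been compiled or benchmarked here; the plain-R function only checks that the unfolded product reproduces the nested-loop mean.

```r
# Hypothetical vectorized variant of scode (untested sketch).
scode_vec <- "
data {
  int<lower=0> N;
  int<lower=0> D;
  int<lower=0> K;
  int<lower=0> L;
  matrix[N, D * L] Y1;  // mode-1 unfolding: Y1[n, d + (l - 1) * D] = Y[n, d, l]
}
transformed data {
  vector[N * D * L] y = to_vector(Y1);  // flatten once, column-major
}
parameters {
  matrix[N, K] X;
  matrix[L, K] U;
  matrix[D, K] W;
  real<lower=0> sigma;
}
model {
  // Khatri-Rao product: column k holds vec(W[:, k] * U[:, k]'), d varying fastest
  matrix[D * L, K] WU;
  for (k in 1:K)
    WU[:, k] = to_vector(W[:, k] * U[:, k]');
  sigma ~ normal(0, 1);
  to_vector(X) ~ normal(0, 1);
  to_vector(U) ~ normal(0, 1);
  to_vector(W) ~ normal(0, 1);
  // one vectorized statement replaces the N * D * L scalar statements
  y ~ normal(to_vector(X * WU'), sigma);
}
"

# Plain-R check that X %*% t(KR) equals the nested-loop mean, unfolded.
check_unfolding <- function(N, D, L, K, seed = 1) {
  set.seed(seed)
  X <- matrix(rnorm(N * K), N, K)
  W <- matrix(rnorm(D * K), D, K)
  U <- matrix(rnorm(L * K), L, K)
  # nested-loop mean, as in the original model block
  mu_loop <- array(0, c(N, D, L))
  for (n in 1:N) {
    for (d in 1:D) {
      for (l in 1:L) {
        mu_loop[n, d, l] <- sum(X[n, ] * W[d, ] * U[l, ])
      }
    }
  }
  # Khatri-Rao product: column k is as.vector(W[, k] %o% U[, k]), d varying fastest
  KR <- sapply(1:K, function(k) as.vector(W[, k] %o% U[, k]))
  mu_vec <- X %*% t(KR) # N x (D * L)
  # matrix(<N x D x L array>, N, D * L) is the same unfolding that Y1 expects
  isTRUE(all.equal(matrix(mu_loop, N, D * L), mu_vec))
}

ok <- check_unfolding(N = 4, D = 3, L = 5, K = 2)
print(ok)
#> [1] TRUE

# data for the sketch, using N, D, L, K, Y from the reprex above (not run):
# standata_vec <- list(N = N, D = D, L = L, K = K, Y1 = matrix(Y, N, D * L))
```

The design choice here is that `X * WU'` is a single matrix product, so Stan's autodiff typically has far fewer scalar nodes to track than with N·D·L·K separate multiplications. This rewrite changes only the cost of each gradient evaluation, not the shape of the posterior, so on its own it should not be expected to fix the convergence problem.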