Speeding up construction of linear predictor

Hello,

I am hoping to get some advice on speeding up a dynamic latent factor model, which I am fitting using cmdstanr. The major speed bottleneck seems to be the construction of the linear predictor nu_metric. Because parameters have to be assigned by time period t as well as by unit j and item i, my basic approach has been to loop over every observation (the data are in “long” form, with each row a period-unit-item triad). My attempts to vectorize the construction of nu_metric and to parallelize within chains (both sketched below, after the model code) have not sped up sampling; if anything, they have slowed it down.

Here is a highly simplified version of the model that illustrates the essence of the problem:

data {
  int<lower=1> D;                           // number of latent dimensions
  int<lower=1> J;                           // number of units
  int<lower=1> T;                           // number of time periods
  int<lower=0> N_metric;                    // number of observations
  int<lower=0> I_metric;                    // number of items
  vector[N_metric] yy_metric;               // outcomes
  array[N_metric] int<lower=1> ii_metric;   // item indicator
  array[N_metric] int<lower=1> jj_metric;   // unit indicator
  array[N_metric] int<lower=1> tt_metric;   // period indicator
}
parameters {
  array[T, I_metric] real alpha_metric;   // metric intercepts
  array[I_metric, D] real lambda_metric;  // metric loadings
  array[T, J, D] real eta;                // latent factors
  vector<lower=0>[I_metric] sigma_metric; // residual sd
}
model {
  /* Priors */
  profile("priors") {
    to_array_1d(alpha_metric) ~ std_normal();
    to_array_1d(lambda_metric) ~ std_normal();
    to_array_1d(eta) ~ std_normal();
    sigma_metric ~ student_t(4, 0.5, 0.5);
  }
  /* Linear predictor */
  vector[N_metric] nu_metric;
  profile("linear_predictor") {
    for (n in 1:N_metric) {
      nu_metric[n] = alpha_metric[tt_metric[n], ii_metric[n]] +
        to_row_vector(lambda_metric[ii_metric[n], 1:D]) *
        to_vector(eta[tt_metric[n], jj_metric[n], 1:D]);
    }
  } 
  /* Likelihood */
  profile("likelihood") {
    target += normal_lupdf(yy_metric | nu_metric, sigma_metric[ii_metric]);
  }
}
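
For concreteness, my vectorization attempts have looked roughly like the following (a sketch rather than the exact code I ran): gather the loadings, factor scores, and intercepts for each observation into N_metric-by-D matrices and a vector, then replace the per-observation dot product with a single rows_dot_product() call, leaving only the gathering loop.

  /* Linear predictor (vectorized sketch; replaces the block above) */
  vector[N_metric] nu_metric;
  profile("linear_predictor") {
    matrix[N_metric, D] lambda_rows;  // loading row for each observation
    matrix[N_metric, D] eta_rows;     // factor-score row for each observation
    vector[N_metric] alpha_n;         // intercept for each observation
    for (n in 1:N_metric) {
      lambda_rows[n] = to_row_vector(lambda_metric[ii_metric[n]]);
      eta_rows[n]    = to_row_vector(eta[tt_metric[n], jj_metric[n]]);
      alpha_n[n]     = alpha_metric[tt_metric[n], ii_metric[n]];
    }
    // one vectorized dot product across all observations
    nu_metric = alpha_n + rows_dot_product(lambda_rows, eta_rows);
  }

As noted above, rewrites along these lines have not improved sampling time for me.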

Here is a glimpse at the toy data I’ve been using:

 $ D              : num 2
 $ J              : int 50
 $ T              : int 1
 $ N_metric       : int 498
 $ I_metric       : int 10
 $ yy_metric      : num [1:498] -1.5271 -0.2829 -0.4005 -0.0445 2.816 ...
 $ ii_metric      : int [1:498] 1 1 1 1 1 1 1 1 1 1 ...
 $ jj_metric      : int [1:498] 1 2 3 4 5 6 7 8 9 10 ...
 $ tt_metric      : int [1:498] 1 1 1 1 1 1 1 1 1 1 ...

Fitting the model with 1000 warmup and 1000 sampling iterations on the toy data takes about two minutes per chain. The output of profiles() for a typical chain is

              name   thread_id total_time forward_time reverse_time chain_stack no_chain_stack autodiff_calls no_autodiff_calls
1       likelihood 0x115491600    5.18406      4.50953     0.674531      298901              0         298901                 1
2 linear_predictor 0x115491600   77.50600     73.98060     3.525410   297706392      148853196         298901                 1
3           priors 0x115491600    2.28105      2.15572     0.125335     1195608              0         298901                 1

which makes it clear that constructing nu_metric dominates the computation time. If anyone has bright ideas for speeding that step up, I would love to hear them.
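
For reference, my within-chain parallelization attempt was roughly of the following form, slicing the likelihood over an index array with reduce_sum (again a sketch rather than the exact code; the partial-sum function uses the normalized normal_lpdf to keep things simple, and grainsize = 1 is illustrative):

functions {
  real partial_ll(array[] int n_slice, int start, int end,
                  vector yy, array[] int ii, array[] int jj, array[] int tt,
                  array[,] real alpha, array[,] real lambda,
                  array[,,] real eta, vector sigma) {
    int M = end - start + 1;
    vector[M] nu;
    for (m in 1:M) {
      int n = n_slice[m];
      nu[m] = alpha[tt[n], ii[n]] +
        to_row_vector(lambda[ii[n]]) * to_vector(eta[tt[n], jj[n]]);
    }
    return normal_lpdf(yy[n_slice] | nu, sigma[ii[n_slice]]);
  }
}
transformed data {
  int grainsize = 1;                  // let the scheduler pick chunk sizes
  array[N_metric] int seq_n;          // observation indices to slice over
  for (n in 1:N_metric) seq_n[n] = n;
}
model {
  // priors as in the model above, then:
  target += reduce_sum(partial_ll, seq_n, grainsize,
                       yy_metric, ii_metric, jj_metric, tt_metric,
                       alpha_metric, lambda_metric, eta, sigma_metric);
}

(For reduce_sum to actually run in parallel, the model has to be compiled with cpp_options = list(stan_threads = TRUE) and sampled with threads_per_chain > 1.)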

Many thanks!
Devin