Within-chain parallelization of a hierarchical model generated by the rethinking package

Hello,
I am trying to run a hierarchical model with a centered parameterization. Because the sample is large, I would prefer to run the model using within-chain parallelization.

I created the model with the rethinking package; however, I haven't been able to run it using threads.

```r
library(data.table)
library(rethinking)
library(cmdstanr)
dat = fread("output/data/example.csv")

f = alist(
    mx ~ normal(mu, sigma),
    mu <- a_county[county] + b_time_county[county] * time + b_period_county[county] * period + 
        b_time_period_county[county] * time_period,
    c(a_county, b_time_county, b_period_county, b_time_period_county)[county] ~ 
        multi_normal(c(a,b_time, b_period, b_time_period), Rho, sigma_county),
    c(a, b_time, b_period, b_time_period) ~ normal(0, 1.0),
    c(sigma_county, sigma) ~ exponential(1),
    Rho ~ lkj_corr(4)
)

m0 = ulam(f, data = dat, sample = FALSE, threads = 10, cmdstan=TRUE)
cat(stancode(m0), file = "src/model.stan")
```

Attachments: example.csv (773.9 KB), model.stan (2.3 KB)

The Stan code generated by rethinking has a problem in the reducer function it creates for reduce_sum:

```stan
corr_matrix[] Rho;
```

As @jonah and @WardBrian pointed out, function arguments can't be constrained types such as correlation matrices, so that argument should simply be matrix Rho. After changing it to matrix Rho, I still get an error:

```text
Compiling Stan program...
Semantic error in '/tmp/RtmpeINNBr/model-174e1161d04c07.stan', line 66, column 14 to line 83, column 17:
   -------------------------------------------------
    64:      YY ~ multi_normal( MU , quad_form_diag(Rho , sigma_county) );
    65:      }
    66:      target += reduce_sum( reducer , mx , 1 , 
                       ^
    67:              N,
    68:              mx_sd,
   -------------------------------------------------

Ill-typed arguments supplied to function 'reduce_sum':
(<F1>, vector, int, int, vector, array[] int, array[] int, array[] int,
 array[] int, vector, vector, vector, vector, real, real, real, real, real,
 vector, matrix)
where F1 = (vector, int, int, int, vector, array[] int, array[] int,
            array[] int, array[] int, vector, vector, vector, vector, real,
            real, real, real, real, vector, matrix) => real
Available signatures:
(<F2>, array[] real, int) => real
where F2 = (array[] real, data int, data int) => real
  The first argument must be
   (array[] real, data int, data int) => real
  but got
   (vector, int, int, int, vector, array[] int, array[] int, array[] int,
    array[] int, vector, vector, vector, vector, real, real, real, real,
    real, vector, matrix) => real
  These are not compatible because:
    The types for the first argument are incompatible: one is
     vector
    but the other is
     array[] real
make: *** [make/program:50: /tmp/RtmpeINNBr/model-174e1161d04c07.hpp] Error 1
Error: An error occured during compilation! See the message above for more information.
```

Any suggestions or ideas on how to solve this problem?
Thank you so much!
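A minimal sketch of the point about constrained argument types: the correlation constraint is declared once, on the parameter, and any user-defined function that receives it takes a plain matrix. The helper county_cov, the names b_mean and b_county, and the fixed count of 10 counties are hypothetical placeholders chosen for illustration, not what rethinking generates.

```stan
functions {
  // Function arguments cannot carry constraints: Rho is a plain matrix here.
  // Writing "corr_matrix Rho" in this signature would be rejected.
  matrix county_cov(matrix Rho, vector sigma_county) {
    return quad_form_diag(Rho, sigma_county);
  }
}
parameters {
  corr_matrix[4] Rho;              // the correlation constraint lives here
  vector<lower=0>[4] sigma_county;
  vector[4] b_mean;
  array[10] vector[4] b_county;    // hypothetical: 10 counties, 4 effects each
}
model {
  Rho ~ lkj_corr(4);
  sigma_county ~ exponential(1);
  b_mean ~ normal(0, 1);
  // A corr_matrix can be passed wherever a matrix argument is expected
  b_county ~ multi_normal(b_mean, county_cov(Rho, sigma_county));
}
```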

I have no idea what that code’s doing (is it rethinking?). The key to the error message is this:

```text
The first argument must be
   (array[] real, data int, data int) => real
but got
   (vector, int, int, int, vector, array[] int, array[] int, array[] int,
    array[] int, vector, vector, vector, vector, real, real, real, real,
    real, vector, matrix) => real
```

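This says the first argument to reduce_sum must be a function whose arguments begin with (array[] real, data int, data int) and that returns a value of type real. The extra arguments after those three are not the problem in themselves, since reduce_sum passes any additional shared arguments through to the partial-sum function. The actual mismatch is the one reported at the end of the message: the function's first (sliced) argument is a vector, but reduce_sum only slices arrays. So mx has to be an array[] real, both in the reducer's signature and in what is passed to reduce_sum.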
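A minimal hand-written sketch of the same model with a partial-sum function that type-checks. It assumes the data list carries three extra entries, N, N_county and grainsize (hypothetical names added for illustration), and it stores the county-level effects as an array of 4-vectors instead of the four separate vectors rethinking generates. The outcome mx is declared as array[N] real; if it has to stay a vector elsewhere, to_array_1d(mx) converts it at the reduce_sum call.

```stan
functions {
  // Partial-sum function for reduce_sum. The first (sliced) argument must
  // be an array; declaring it as a vector triggers the error above.
  real partial_sum(array[] real mx_slice, int start, int end,
                   array[] int county, vector time, vector period,
                   vector time_period, array[] vector b_county,
                   real sigma) {
    int n = end - start + 1;
    vector[n] mu;
    for (i in 1:n) {
      int k = start + i - 1;             // index into the full data set
      vector[4] b = b_county[county[k]]; // (a, b_time, b_period, b_time_period)
      mu[i] = b[1] + b[2] * time[k] + b[3] * period[k]
              + b[4] * time_period[k];
    }
    return normal_lpdf(mx_slice | mu, sigma);
  }
}
data {
  int<lower=1> N;
  int<lower=1> N_county;
  array[N] int<lower=1, upper=N_county> county;
  array[N] real mx;          // an array, not a vector, so it can be sliced
  vector[N] time;
  vector[N] period;
  vector[N] time_period;
  int<lower=1> grainsize;    // 1 lets the scheduler choose slice sizes
}
parameters {
  vector[4] b_mean;                      // a, b_time, b_period, b_time_period
  corr_matrix[4] Rho;
  vector<lower=0>[4] sigma_county;
  real<lower=0> sigma;
  array[N_county] vector[4] b_county;    // centered county-level effects
}
model {
  b_mean ~ normal(0, 1);
  sigma_county ~ exponential(1);
  sigma ~ exponential(1);
  Rho ~ lkj_corr(4);
  b_county ~ multi_normal(b_mean, quad_form_diag(Rho, sigma_county));
  // Shared arguments follow grainsize, in the same order as in partial_sum
  target += reduce_sum(partial_sum, mx, grainsize,
                       county, time, period, time_period, b_county, sigma);
}
```

For the threads to actually be used, the model also has to be compiled with threading enabled, for example in cmdstanr with cmdstan_model(file, cpp_options = list(stan_threads = TRUE)), and then sampled with the threads_per_chain argument of $sample().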