Very slow MCMC sampling in stan_jm() due to large number of individuals

I have been struggling to fit a joint model with stan_jm() from the rstanarm package in a reasonable amount of time, because of the large number of individuals (~10,000), each with multiple repeated measurements. Modeling the time variable with a restricted cubic spline adds to the computational burden. I recreated the structure of the model using the example code from the stan_jm vignette; please see below.

I was wondering if there is any way to speed up the MCMC sampling. I don’t think within-chain parallelization with ‘reduce_sum’ is available for rstanarm models, but I was hoping there would be a way to incorporate it. Any help would be appreciated! Thank you!

library(rstanarm)  # provides stan_jm() and the pbcLong / pbcSurv example data

stan_jm(
  formulaLong = list(
    logBili ~ age + sex + trt + rms::rcs(year, 4) + (rms::rcs(year, 4) | id)),
  formulaEvent = survival::Surv(futimeYears, death) ~ age + sex + trt,
  dataLong = pbcLong, dataEvent = pbcSurv, assoc = c("etavalue"),
  time_var = "year", cores = parallel::detectCores(),
  chains = 2, iter = 2000, refresh = 200, seed = 12345, max_treedepth = 18)

This indeed looks like a very big model; I fear that either the longitudinal model or the time-to-event component would be very demanding on its own. However, it seems that at least part of the problem may be that the model is not well informed by the data: you seem to need a very large max_treedepth setting. Do you know why that is? As the “folk theorem” goes, slow computation often also means a problematic model.
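One way to answer the tree-depth question is to check how often post-warmup iterations actually hit the cap. A minimal sketch, assuming the fitted rstanarm object stores its underlying `stanfit` in `$stanfit` (as stanreg objects do); `share_at_max_treedepth()` is a hypothetical helper, and the `stan_glm()` fit on `mtcars` is only a small stand-in for the joint model:

```r
library(rstanarm)

# Hypothetical helper: share of post-warmup iterations whose tree depth
# reached the cap. A large share means the sampler keeps needing very long
# trajectories, which often points to a poorly identified posterior.
share_at_max_treedepth <- function(stanfit, max_depth) {
  sp <- rstan::get_sampler_params(stanfit, inc_warmup = FALSE)
  depths <- unlist(lapply(sp, function(chain) chain[, "treedepth__"]))
  mean(depths >= max_depth)
}

# Small stand-in model with an explicit cap of 10.
fit <- stan_glm(mpg ~ wt, data = mtcars, chains = 1, iter = 500,
                refresh = 0, seed = 123,
                control = list(max_treedepth = 10))
share <- share_at_max_treedepth(fit$stanfit, max_depth = 10)
print(share)

# For a stan_jm() fit stored as fit_jm, the analogous call would be
# share_at_max_treedepth(fit_jm$stanfit, max_depth = 18)
```

If the share is essentially zero, the cap is not what is slowing things down, and the cost is coming from the gradient evaluations themselves (data size, number of random effects).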

In any case, I would definitely start by experimenting with simpler models and subsets of the individuals, to get a grip on what makes the model misbehave and to get an idea of how big a model you can fit in reasonable time. The dataset is indeed big, so it is not implausible that the full problem is currently just out of reach.
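When subsetting, the longitudinal and event data need to be cut on the same set of IDs. A minimal base-R sketch, assuming an `id` column in both data frames as in `pbcLong`/`pbcSurv`; `subset_individuals()` and the toy data are hypothetical, for illustration only:

```r
# Hypothetical helper: draw a random subset of individuals and keep all of
# their rows in both data frames, so the longitudinal and event data stay
# aligned on the same IDs. Note: set.seed() changes the global RNG state.
subset_individuals <- function(dataLong, dataEvent, n_ids, id_var = "id",
                               seed = 1) {
  set.seed(seed)
  all_ids <- unique(dataEvent[[id_var]])
  keep <- all_ids[sample.int(length(all_ids), min(n_ids, length(all_ids)))]
  list(
    dataLong  = dataLong[dataLong[[id_var]] %in% keep, , drop = FALSE],
    dataEvent = dataEvent[dataEvent[[id_var]] %in% keep, , drop = FALSE]
  )
}

# Toy data: 10 individuals with 1 to 5 measurements each.
set.seed(42)
n_obs <- sample(1:5, 10, replace = TRUE)
toyLong <- data.frame(
  id   = rep(1:10, times = n_obs),
  year = unlist(lapply(n_obs, function(k) seq_len(k) - 1))
)
toyEvent <- data.frame(id = 1:10, futimeYears = n_obs,
                       death = rbinom(10, 1, 0.3))

sub <- subset_individuals(toyLong, toyEvent, n_ids = 4)
print(sort(unique(sub$dataLong$id)))

# With the rstanarm example data this would be, e.g.:
# sub <- subset_individuals(pbcLong, pbcSurv, n_ids = 20)
```

Fitting the same model on a few increasing subset sizes and timing each run gives a rough idea of how run time scales before committing to the full cohort.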

Best of luck with the model!


Thank you so much for the reply! The max_treedepth argument was actually left over from another model, so it wasn’t intentional. The survival model alone finished sampling on my local machine (4 chains on 4 cores) in a few days without any issues, which makes me think the real problem is the data size rather than model misspecification. I have also put the longitudinal and joint models on our high-performance computing cluster; so far they have been running for a few days. I’ll update the post later with the total computing time, to give forum members a point of reference for fitting large, complex joint/multilevel/survival models.
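To make such timing reports comparable, it may help to give warmup and sampling time per chain rather than only wall-clock time. A minimal sketch using `rstan::get_elapsed_time()` on the `stanfit` stored inside an rstanarm fit; `report_timing()` is a hypothetical helper and the `stan_glm()` fit is just a small stand-in:

```r
library(rstanarm)

# Hypothetical helper: per-chain warmup and sampling time in seconds, taken
# from the stanfit object stored inside an rstanarm fit.
report_timing <- function(fit) {
  times <- rstan::get_elapsed_time(fit$stanfit)
  data.frame(
    chain      = seq_len(nrow(times)),
    warmup_s   = times[, "warmup"],
    sampling_s = times[, "sample"],
    total_s    = rowSums(times),
    row.names  = NULL
  )
}

# Small stand-in model; for the joint model, pass the stan_jm() fit instead.
fit <- stan_glm(mpg ~ wt, data = mtcars, chains = 2, iter = 500,
                refresh = 0, seed = 123)
timing <- report_timing(fit)
print(timing)
```

Reporting the number of individuals, the total number of measurements, and the iterations per chain alongside these numbers would make them easier for others to compare.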


Do you have a strong reason for giving the rcs 4 knots? This seems like a lot, and lowering it would greatly reduce the complexity of the model. I don’t typically use Stan to fit JMs (I usually use the CRAN package JMbayes2 these days), but in any case I’ve found that splines with more than 2 or 3 df tend not to work. Of course, the number of longitudinal measurements per individual also matters here, as does the heterogeneity in that number across people. For example, if all of your individuals have exactly 5 readings, the model will fit more easily than if you have a mean of 5 readings per person but a significant fraction of people with only 1 or 2 readings; if you are then using rcs(, 4), that’s going to be very hard to fit!
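The point about heterogeneous measurement counts can be checked directly. With `k` knots, `rcs()` creates `k - 1` basis columns, so `(rcs(year, k) | id)` gives each person `k` random effects (intercept plus spline terms) and an unstructured covariance with `k(k + 1)/2` parameters. A minimal base-R sketch, with hypothetical helper names and toy data; "fewer measurements than random effects" is only a rough heuristic for individuals whose random effects are informed mostly by the population distribution:

```r
# Number of random effects per individual for (rcs(year, k) | id):
# one intercept plus the k - 1 spline basis columns that rcs() creates.
n_random_effects <- function(knots) 1 + (knots - 1)

# Free parameters in an unstructured covariance matrix for q random effects.
n_cov_params <- function(q) q * (q + 1) / 2

# Hypothetical helper: summarize measurements per individual and the share of
# individuals with fewer measurements than random effects.
measurement_summary <- function(dataLong, knots, id_var = "id") {
  counts <- as.vector(table(dataLong[[id_var]]))
  q <- n_random_effects(knots)
  list(
    random_effects_per_id = q,
    covariance_params     = n_cov_params(q),
    count_distribution    = table(counts),
    share_below_q         = mean(counts < q)
  )
}

# Toy data: 6 individuals with 1, 2, 5, 5, 8 and 3 measurements.
toyLong <- data.frame(id = rep(1:6, times = c(1, 2, 5, 5, 8, 3)))

s4 <- measurement_summary(toyLong, knots = 4)
s3 <- measurement_summary(toyLong, knots = 3)
print(c(knots4 = s4$share_below_q, knots3 = s3$share_below_q))
```

On real data the call would be, e.g., `measurement_summary(pbcLong, knots = 4)`; the `count_distribution` table shows at a glance whether a large group of people has only 1 or 2 readings.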


Yes, I did lower it to 3, as I was okay with that level of flexibility in modeling the relationship. I had heard of JMbayes2 but assumed it would be even slower to fit, since it doesn’t use Stan’s No-U-Turn sampler (I believe it uses JAGS for MCMC sampling). The number of measurements also varies among participants; as you pointed out, that is another factor contributing to the speed issues. If I’m able to fit the model within a reasonable time frame, I will share some metrics as a point of reference for forum members who plan to fit similar models on a cohort of comparable size.

JMbayes has options to fit the longitudinal components with JAGS or Stan, but JMbayes2 is totally new and uses custom-written C++ code. I don’t know what MCMC algorithm it uses underneath (I don’t grok C), but it is very, very fast in the dabbling I have done. It could be worth trying.
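A minimal sketch of trying JMbayes2, modeled on the package’s documented examples and its bundled `pbc2`/`pbc2.id` data (not the poster’s data or model). The event submodel is a `coxph()` fit, the longitudinal submodel an `nlme::lme()` fit with a natural cubic spline in time, and `jm()` combines them; by default it links the current value of the longitudinal outcome to the hazard:

```r
library(survival)
library(nlme)
library(splines)
library(JMbayes2)

# Event submodel: Cox model on one row per patient (pbc2.id ships with JMbayes2).
pbc2.id$status2 <- as.numeric(pbc2.id$status != "alive")
cox_fit <- coxph(Surv(years, status2) ~ sex, data = pbc2.id)

# Longitudinal submodel: linear mixed model with a natural cubic spline in
# time (2 df) in both the fixed and the random effects.
lme_fit <- lme(log(serBilir) ~ ns(year, 2) + sex, data = pbc2,
               random = ~ ns(year, 2) | id,
               control = lmeControl(opt = "optim"))

# Joint model; time_var names the time variable in the longitudinal data.
joint_fit <- jm(cox_fit, lme_fit, time_var = "year")
summary(joint_fit)
```

Because the submodels are fitted separately first, it is also cheap to check the `lme()` fit on its own before running the joint model.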

In the other thread, you are using the area-under-the-curve association, right? The default “etavalue” seems to run faster in the package examples for me. I’d restate @martinmodrak’s recommendations: use a subset of the data and make the model simpler. Try the default association parameter, and even drop the splines altogether, and see if you can fit that; you can then make it more complex later.
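A sketch of that simplest starting point, based on the example model in the stan_jm vignette: linear time with a random intercept and slope, and the default "etavalue" association. It uses the small `pbcLong`/`pbcSurv` example data bundled with rstanarm and a single short chain, so it is a smoke test rather than a recommended setting:

```r
library(rstanarm)

# Simplest model: no splines, linear time, default association ("etavalue").
fit_simple <- stan_jm(
  formulaLong  = logBili ~ sex + trt + year + (year | id),
  dataLong     = pbcLong,
  formulaEvent = survival::Surv(futimeYears, death) ~ sex + trt,
  dataEvent    = pbcSurv,
  time_var     = "year",
  chains = 1, iter = 1000, refresh = 0, seed = 12345
)
print(fit_simple)

# Possible next steps, one change at a time:
#   1. replace `year` with rms::rcs(year, 3) in the fixed effects only;
#   2. add the spline to the random effects as well;
#   3. switch to the association structure you actually need.
```

Changing one thing at a time makes it easier to see which addition causes the slowdown or the sampling problems.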


I’ll give JMbayes2 a try as well, thank you for the suggestion. The model fit without issues on a subset of 400 individuals, but I will also try a simplified version on a larger chunk of the data. Thank you both! @jroon @martinmodrak
