One chain considerably slower than all others

I’m running Stan via rstan / brms / cmdstanr, or directly via CmdStan, on a cloud instance supplied by CodeOcean (details below). This is a 16-core machine.

Regardless of which interface I use, or which model I run, I often get one chain going much slower than the others - e.g. in a simple model I just ran through brms, the slow chain produced its first sample only after the other three had run through 40% of their samples.

There is a sufficient number of cores available, so I am perplexed by this behavior. Any ideas about where to look for a solution?

sessionInfo()
R version 4.0.3 (2020-10-10)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 18.04.5 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8     LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                  LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] brms_2.14.4       Rcpp_1.0.6        cmdstanr_0.3.0    cowplot_1.1.1     data.table_1.13.6 ggplot2_3.3.3    

loaded via a namespace (and not attached):
  [1] minqa_1.2.4          colorspace_2.0-0     ellipsis_0.3.1       ggridges_0.5.3       rsconnect_0.8.16     markdown_1.1         base64enc_0.1-3      rstudioapi_0.13     
  [9] farver_2.0.3         rstan_2.21.2         DT_0.17              fansi_0.4.2          mvtnorm_1.1-1        bridgesampling_1.0-0 codetools_0.2-16     splines_4.0.3       
 [17] knitr_1.31           shinythemes_1.2.0    bayesplot_1.8.0      projpred_2.0.2       jsonlite_1.7.2       nloptr_1.2.2.2       packrat_0.5.0        shiny_1.6.0         
 [25] compiler_4.0.3       backports_1.2.1      assertthat_0.2.1     Matrix_1.2-18        fastmap_1.1.0        cli_2.2.0            later_1.1.0.1        htmltools_0.5.1.1   
 [33] prettyunits_1.1.1    tools_4.0.3          igraph_1.2.6         coda_0.19-4          gtable_0.3.0         glue_1.4.2           posterior_0.1.3      RWiener_1.3-3       
 [41] reshape2_1.4.4       dplyr_1.0.2          V8_3.4.0             vctrs_0.3.6          nlme_3.1-149         crosstalk_1.1.1      xfun_0.20            stringr_1.4.0       
 [49] ps_1.5.0             lme4_1.1-26          mime_0.9             miniUI_0.1.1.1       lifecycle_0.2.0      gtools_3.8.2         statmod_1.4.35       MASS_7.3-53         
 [57] zoo_1.8-8            scales_1.1.1         colourpicker_1.1.0   promises_1.1.1       Brobdingnag_1.2-6    parallel_4.0.3       inline_0.3.17        shinystan_2.5.0     
 [65] yaml_2.2.1           gamm4_0.2-6          curl_4.3             gridExtra_2.3        loo_2.4.1            StanHeaders_2.21.0-7 stringi_1.5.3        dygraphs_1.1.1.6    
 [73] checkmate_2.0.0      boot_1.3-25          pkgbuild_1.2.0       rlang_0.4.10         pkgconfig_2.0.3      matrixStats_0.57.0   evaluate_0.14        lattice_0.20-41     
 [81] purrr_0.3.4          rstantools_2.1.1     htmlwidgets_1.5.3    labeling_0.4.2       processx_3.4.5       tidyselect_1.1.0     plyr_1.8.6           magrittr_2.0.1      
 [89] R6_2.5.0             generics_0.0.2       pillar_1.4.7         withr_2.4.1          mgcv_1.8-33          xts_0.12.1           abind_1.4-5          tibble_3.0.5        
 [97] crayon_1.3.4         rmarkdown_2.4        grid_4.0.3           callr_3.5.1          threejs_0.3.3        digest_0.6.27        xtable_1.8-4         httpuv_1.5.5        
[105] RcppParallel_5.0.2   stats4_4.0.3         munsell_0.5.0        shinyjs_2.0.0
cmdstan_version()
[1] "2.26.0"

I’m guessing here (but correct me if I’m wrong) that you mean that some chains are getting through 40% of their sampling iterations (so like 1400 iterations total, including warmup, if you’re using the defaults) before another chain finishes warmup (i.e. its first thousand iterations, if you’re using the defaults). Thus, one chain is running about 1.5 or 2 times slower.

This is not so uncommon, and is likely to be due to the vicissitudes of warmup in the slow chain. Stochastically, the chain might take longer to find the typical set than the others (possibly due to unlucky inits, but possibly also due to the stochasticity inherent in the algorithm). Once the chain arrives in the typical set, the adaptation of the mass matrix begins to substantially speed up sampling by allowing the sampler to take fewer, longer leapfrog steps. If one chain ends up lagging behind the others during this process, then it might plausibly end up reaching systematically deeper treedepths than the others during warmup, which will cause it to run dramatically slower.
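If you want to check whether this is what is happening, you can compare per-chain sampler diagnostics. A rough sketch using rstan's get_sampler_params() (here `fit` is a placeholder for your stanfit object; for a brms model the underlying stanfit lives in brms_fit$fit):

library(rstan)

# Per-chain sampler parameters, including warmup iterations
sp <- get_sampler_params(fit, inc_warmup = TRUE)

# One column per chain: a chain with a much larger mean treedepth /
# leapfrog count is the one doing the extra work per iteration.
sapply(sp, function(chain) c(
  mean_treedepth  = mean(chain[, "treedepth__"]),
  mean_n_leapfrog = mean(chain[, "n_leapfrog__"]),
  final_stepsize  = tail(chain[, "stepsize__"], 1)
))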

If the initial values are the culprit, you can ameliorate the behavior by passing better inits.
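For example, with the brms version you have listed, per-chain inits go through the inits argument (newer releases call it init). A minimal sketch with placeholder formula, data, and parameter names - adapt these to your own model:

# One init list per chain; parameters not listed are initialized randomly.
init_fun <- function() {
  list(Intercept = rnorm(1, 0, 0.5), sigma = 1)
}

fit <- brm(
  y ~ x, data = d,
  chains = 4, cores = 4,
  inits = list(init_fun(), init_fun(), init_fun(), init_fun())
)

With cmdstanr directly, the equivalent is passing init (a list of lists or a function) to mod$sample().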

In any case, these differences in runtime can even persist into the sampling phase of the model (even for well-conditioned models with “nice” posterior geometries) if one chain happens to adapt a step size that is just small enough that it frequently reaches a deeper tree-depth than the other chains. Does that make sense?

Thanks!

I meant that 3 chains go through 40% of their total samples (warmup + sampling) before the 4th chain completes even its first sample.

I’ll try different initial values and see if it makes a difference - though, as I said, I see this in a variety of models, some very simple, so I suspect it is something more software-related. I’ll report back.

This explanation by @jsocolar was very enlightening for me. I found this post after noticing that one of my chains runs 4-5 times slower on a program that takes about 4-5 hours, so that one chain always takes almost a full day, which I would like to avoid in the future. It seems to affect about 25% of the chains (so with 4 chains, 1 chain takes 4x longer; with 8 chains, 2 chains take 4x longer).

I already provide initial values (I’m fitting a system of ODEs, so it’s hard without them), so I’m pretty sure it’s the treedepth issue you are describing. In fact, all the treedepth warnings come from that one chain.

Is there any way to lower the chance of a chain lagging behind like this? 25% of the time seems like a lot, and I’m wondering if I can reduce it. If I increase max_treedepth to 12 for all chains, it will make them all run a little slower, but will it lower the chance of one chain lagging behind? One chain taking 24 hours is still worse than all chains taking 12 hours. But I’m not sure whether that can be done (a rough sketch of how I would set it is below).
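For reference, this is how I would set it - formula, data, and model objects are placeholders, not my actual model:

# brms: passed through `control` to the underlying sampler
fit <- brm(y ~ x, data = d, chains = 4, cores = 4,
           control = list(max_treedepth = 12))

# cmdstanr: an argument of $sample()
fit <- mod$sample(data = stan_data, chains = 4, parallel_chains = 4,
                  max_treedepth = 12)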

thanks!
