Sampling works in rstan but not cmdstan with a wiener mixture model

I’m trying to fit a mixture of a Wiener likelihood and a small-probability uniform likelihood that accounts for outliers. The model is described in this thread.
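
Written out as a density, this is what the Stan function below implements (lambda is fixed):

p(rt) = lambda * uniform(rt | min_rt, max_rt) + (1 - lambda) * wiener(rt | bs, ndt, bias, mu),   if rt >= ndt
p(rt) = lambda * uniform(rt | min_rt, max_rt),                                                   if rt < ndt

where the bias and drift are flipped to (1 - bias, -mu) for lower-boundary decisions.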

The model runs just fine in rstan, but when I use cmdstanr to interface with cmdstan, I get a very large number of divergent transitions. That was the original problem in the thread above, which the model code I’m using had seemed to solve for rstan.

This is the model:

functions {
  // Mixture of a Wiener first-passage-time density and a uniform "lapse"
  // density with mixing weight lambda. When y < ndt the Wiener component
  // has no support, so only the uniform component contributes.
  real wiener_diffusion2_lpdf(real y, real mu, real bs,
                              real ndt, real bias, real lambda,
                              int dec, real min_rt, real max_rt) {
    if (y < ndt) {
      return log(lambda) + uniform_lpdf(y | min_rt, max_rt);
    } else if (dec == 1) {
      return log_mix(lambda,
                     uniform_lpdf(y | min_rt, max_rt),
                     wiener_lpdf(y | bs, ndt, bias, mu));
    } else {
      return log_mix(lambda,
                     uniform_lpdf(y | min_rt, max_rt),
                     wiener_lpdf(y | bs, ndt, 1 - bias, -mu));
    }
  }
}

data {
  int<lower=0> N; // Total trial count
  vector[N] rt; // RT responses in s
  int <lower=0,upper=1> dec[N];  // decisions
  
  int<lower=1> K;  // number of effects for drift rate
  matrix[N, K] X;  // design matrix for drift rate
  
  real lambda; // Fixed lapse rate
  
  int prior_only;  // should the likelihood be ignored?
}

transformed data{
  real maxRT = max(rt);
}

parameters {
  vector[K] b;  // population-level effects
  
  real<lower=0> bs;  // boundary separation population parameter
  
  real<lower=0> ndt;  // non-decision time population parameter

  real<lower=0,upper=1> bias;  // initial bias parameter
}

model {
  // Prior on drift-rate coefficients
  b ~ normal(0, 10);
  
  // Boundary separation constrained to be positive, for identifiability.
  bs ~ normal(0, 2.5);

  // Assuming that non-decision time is distributed lognormally in the population
  ndt ~ lognormal(-1.2, 0.5);

  // Assuming that bias follows a beta distribution centered on 0.5.
  bias ~ beta(15, 15);
  
  // Likelihood
  if (!prior_only) {
    // Compute linear predictor term
    vector[N] mu = X * b;
    
    // Compute likelihood per trial
    for (n in 1:N) {
      target += wiener_diffusion2_lpdf(rt[n] | mu[n], bs, ndt, bias, lambda, 
                                          dec[n], 0.0, maxRT);
    }
  }
}

And then with rstan:

  lapse_fit0 <- stan("ddm_lapse0.stan",
                  data = dat_list,
                  chains = 4,
                  cores = 4,
                  seed = 34)

I get:

Inference for Stan model: ddm_lapse0.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

        mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
b[1]    9.78    0.00 0.17    9.44    9.67    9.77    9.89   10.11  2738    1
bs      1.98    0.00 0.03    1.93    1.96    1.98    2.00    2.03  2659    1
ndt     0.20    0.00 0.00    0.19    0.20    0.20    0.20    0.20  2684    1
bias    0.51    0.00 0.00    0.50    0.50    0.51    0.51    0.51  3263    1
lp__ -143.50    0.03 1.41 -147.03 -144.18 -143.15 -142.47 -141.77  1894    1

Which is spot on.

But with cmdstanr:

ddm_lapse0 <- cmdstan_model("ddm_lapse0.stan")
lapse_fit0cmd <- ddm_lapse0$sample(
  data = dat_list,
  chains = 4,
  parallel_chains = 4,
  seed = 34
)

I get:

Warning: 2973 of 4000 (74.0%) transitions ended with a divergence.
This may indicate insufficient exploration of the posterior distribution.
Possible remedies include: 
  * Increasing adapt_delta closer to 1 (default is 0.8) 
  * Reparameterizing the model (e.g. using a non-centered parameterization)
  * Using informative or weakly informative prior distributions 

variable     mean    median      sd     mad        q5     q95 rhat ess_bulk ess_tail
     lp__ -8400.15 -10478.83 4921.36 2514.66 -12555.60 -143.31 3.06        4       26
     b[1]     2.40      0.07    4.26    0.31     -0.30    9.92 2.93        4       12
     bs       1.41      1.37    0.45    0.60      0.89    2.00 3.54        4       11
     ndt      0.67      0.64    0.38    0.47      0.20    1.20 2.85        4       31
     bias     0.38      0.41    0.15    0.18      0.16    0.54 3.92        4       11

This is the R code for generating the fake data used here:

library(RWiener)
library(brms)
library(data.table)

# Set parameters for test
bs = 2
k = 10
ndt = 0.2
bias = 0.5
lambda = 0.02
tightspacing = ((seq(0, sqrt(0.5), length.out = 5))^2)[2:5]
xs = c(-1, -tightspacing, 0, tightspacing , 1)
n = 200

# Simulate RTs and responses
rdat <- function(x) {
  dat <- rwiener(n, alpha = bs, tau = ndt, beta = bias, delta = k * x)
  dat$lapsed <- runif(n) <= lambda
  dat$rt <- dat$lapsed * runif(n, 0, 5) + (1 - dat$lapsed) * dat$q
  dat$dec <- dat$lapsed * (runif(n) > 0.5) + (1 - dat$lapsed) * dat$resp
  dat$x <- x
  return(dat)
}
set.seed(34)
dat1 <- as.data.table(do.call(rbind, lapply(xs, rdat)))

# Add one pernicious trial
dat1 <- rbind(dat1, data.table(q = 0.1, resp = 1, lapsed = T, rt = 0.1, dec = 1, x = xs[3]))

dat_list <- list(N = nrow(dat1),
              rt = dat1$rt,
              dec = dat1$dec,
              K = 1,
              X = model.matrix(rt ~ 0 + x, dat1),
              lambda = 0.01,
              prior_only = 0)

Session info:

sessionInfo()

R version 4.0.3 (2020-10-10)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 18.04.5 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1

locale:
 [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C               LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8     LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                  LC_ADDRESS=C               LC_TELEPHONE=C             LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] brms_2.14.4          Rcpp_1.0.6           rstan_2.21.2         StanHeaders_2.21.0-7 cmdstanr_0.3.0       cowplot_1.1.1        data.table_1.13.6    ggplot2_3.3.3       

loaded via a namespace (and not attached):
  [1] minqa_1.2.4          colorspace_2.0-0     ellipsis_0.3.1       ggridges_0.5.3       rsconnect_0.8.16     markdown_1.1         base64enc_0.1-3      rstudioapi_0.13     
  [9] farver_2.0.3         DT_0.17              fansi_0.4.2          mvtnorm_1.1-1        bridgesampling_1.0-0 codetools_0.2-16     splines_4.0.3        knitr_1.31          
 [17] shinythemes_1.2.0    bayesplot_1.8.0      projpred_2.0.2       jsonlite_1.7.2       nloptr_1.2.2.2       packrat_0.5.0        shiny_1.6.0          compiler_4.0.3      
 [25] backports_1.2.1      assertthat_0.2.1     Matrix_1.2-18        fastmap_1.1.0        cli_2.2.0            later_1.1.0.1        htmltools_0.5.1.1    prettyunits_1.1.1   
 [33] tools_4.0.3          igraph_1.2.6         coda_0.19-4          gtable_0.3.0         glue_1.4.2           posterior_0.1.3      RWiener_1.3-3        reshape2_1.4.4      
 [41] dplyr_1.0.2          V8_3.4.0             vctrs_0.3.6          nlme_3.1-149         crosstalk_1.1.1      xfun_0.20            stringr_1.4.0        ps_1.5.0            
 [49] lme4_1.1-26          mime_0.9             miniUI_0.1.1.1       lifecycle_0.2.0      gtools_3.8.2         statmod_1.4.35       MASS_7.3-53          zoo_1.8-8           
 [57] scales_1.1.1         colourpicker_1.1.0   promises_1.1.1       Brobdingnag_1.2-6    parallel_4.0.3       inline_0.3.17        shinystan_2.5.0      gamm4_0.2-6         
 [65] yaml_2.2.1           curl_4.3             gridExtra_2.3        loo_2.4.1            stringi_1.5.3        dygraphs_1.1.1.6     checkmate_2.0.0      boot_1.3-25         
 [73] pkgbuild_1.2.0       rlang_0.4.10         pkgconfig_2.0.3      matrixStats_0.57.0   evaluate_0.14        lattice_0.20-41      purrr_0.3.4          rstantools_2.1.1    
 [81] htmlwidgets_1.5.3    labeling_0.4.2       processx_3.4.5       tidyselect_1.1.0     plyr_1.8.6           magrittr_2.0.1       R6_2.5.0             generics_0.0.2      
 [89] pillar_1.4.7         withr_2.4.1          mgcv_1.8-33          xts_0.12.1           abind_1.4-5          tibble_3.0.5         crayon_1.3.4         utf8_1.1.4          
 [97] rmarkdown_2.4        grid_4.0.3           callr_3.5.1          threejs_0.3.3        digest_0.6.27        xtable_1.8-4         httpuv_1.5.5         RcppParallel_5.0.2  
[105] stats4_4.0.3         munsell_0.5.0        shinyjs_2.0.0       

and cmdstan version:

cmdstan_version()
[1] "2.26.0"

Your reprex doesn’t produce valid data for the model, can you fix that so I can replicate locally?


Do you mean that the data is not formatted correctly for the model? rstan runs it for me without complaint, on two different setups…

Maybe my post wasn’t written clearly enough. You can find a reprex here.

However, trying this on another setup, I realized that this isn’t a cmdstan issue: on my second setup I don’t get parameter recovery with rstan either. These are the details of the environment where rstan didn’t work; see the sessionInfo above for the environment where it did work (coincidentally, that’s on a CodeOcean capsule, which I can share if it helps).

sessionInfo()

R version 3.6.0 (2019-04-26)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Mojave 10.14.6

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] rstan_2.19.3         StanHeaders_2.21.0-1 brms_2.14.4          Rcpp_1.0.4.6         cowplot_1.0.0       
[6] data.table_1.12.8    ggplot2_3.3.0        RWiener_1.3-3       

loaded via a namespace (and not attached):
 [1] nlme_3.1-147         matrixStats_0.56.0   xts_0.12-0           threejs_0.3.3        backports_1.1.6     
 [6] tools_3.6.0          R6_2.4.1             DT_0.13              mgcv_1.8-31          projpred_2.0.2      
[11] colorspace_1.4-1     withr_2.4.1          prettyunits_1.1.1    tidyselect_1.0.0     gridExtra_2.3       
[16] processx_3.4.5       Brobdingnag_1.2-6    compiler_3.6.0       cli_2.0.2            shinyjs_1.1         
[21] labeling_0.3         colourpicker_1.0     scales_1.1.0         dygraphs_1.1.1.6     mvtnorm_1.1-0       
[26] ggridges_0.5.2       callr_3.4.3          stringr_1.4.0        digest_0.6.25        minqa_1.2.4         
[31] base64enc_0.1-3      pkgconfig_2.0.3      htmltools_0.4.0      lme4_1.1-23          fastmap_1.0.1       
[36] htmlwidgets_1.5.1    rlang_0.4.10         rstudioapi_0.11      shiny_1.4.0.2        farver_2.0.3        
[41] zoo_1.8-7            crosstalk_1.1.0.1    gtools_3.8.2         dplyr_0.8.5          inline_0.3.15       
[46] magrittr_1.5         loo_2.4.1            bayesplot_1.7.1      Matrix_1.2-18        munsell_0.5.0       
[51] fansi_0.4.1          abind_1.4-5          lifecycle_0.2.0      stringi_1.4.6        MASS_7.3-51.5       
[56] pkgbuild_1.0.6       plyr_1.8.6           grid_3.6.0           parallel_3.6.0       promises_1.1.0      
[61] crayon_1.3.4         miniUI_0.1.1.1       lattice_0.20-41      splines_3.6.0        knitr_1.28          
[66] ps_1.3.2             pillar_1.4.3         igraph_1.2.5         boot_1.3-24          markdown_1.1        
[71] shinystan_2.5.0      codetools_0.2-16     reshape2_1.4.4       stats4_3.6.0         rstantools_2.1.1    
[76] glue_1.4.0           packrat_0.5.0        vctrs_0.2.4          nloptr_1.2.2.1       httpuv_1.5.2        
[81] gtable_0.3.0         purrr_0.3.3          assertthat_0.2.1     xfun_0.13            mime_0.9            
[86] xtable_1.8-4         coda_0.19-3          later_1.0.0          rsconnect_0.8.16     tibble_3.0.0        
[91] shinythemes_1.1.2    gamm4_0.2-6          statmod_1.4.34       ellipsis_0.3.0       bridgesampling_1.0-0

Are you still seeing a difference in rstan and cmdstanr?

Or is the problem now that the fits only intermittently converge?

For the same model, same data, and same seed, I see differences between the setups above:

rstan 2.21.2 converges
cmdstan 2.26.0 does not
rstan 2.19.3 does not

Can you let the seed vary and run a few times in each case? I will dig into this later if cmdstanr is still not working, but the seed is not consistent across versions or interfaces, so that’s worth trying.
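
Something along these lines, perhaps (a rough sketch with arbitrary seeds, reusing the compiled ddm_lapse0 model and dat_list from above):

# Sketch: refit with a few different seeds and compare the summaries
for (s in c(1, 101, 1001)) {
  fit_s <- ddm_lapse0$sample(
    data = dat_list,
    chains = 4,
    parallel_chains = 4,
    seed = s
  )
  print(fit_s$summary(c("b[1]", "bs", "ndt", "bias")))
}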


Trying three seeds, indeed I get different results: perfect recovery with the first seed I chose, but no convergence with two other seeds.

How does one go about figuring this out and making the model more robust?

By eyeballing the initial values, I can’t spot what’s special about the seed that worked…

Looking at the traces for the failed seeds, they are completely flat; there is almost no within-chain movement at all.
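
Roughly, this is how the traces can be inspected (a sketch using bayesplot on the cmdstanr draws; substitute whichever failed fit for lapse_fit0cmd):

library(bayesplot)

# Trace plots for the main parameters of a cmdstanr fit
mcmc_trace(lapse_fit0cmd$draws(), pars = c("b[1]", "bs", "ndt", "bias"))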

Leaving the seed unspecified helps you catch these issues better (and then you can set the seed to debug them further).

This would fit in as a workflow problem (the fit isn’t really working; what to do?): https://statmodeling.stat.columbia.edu/2020/11/10/bayesian-workflow/

Various things to try:

  1. Make predictions to see if the chains that ‘fit’ are actually fitting the data
  2. Run the model with simulated data instead of real data
  3. Make the inits smaller (for instance, init = 0.5 or something; see the sketch below). It might be that Stan’s default inits are just bad for your model. By default, parameters are initialized uniformly on [-2, 2] on the unconstrained scale.

The workflow paper has ideas like that.
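
For point 3, a minimal sketch of what that could look like with cmdstanr (init = x with x > 0 draws initial values uniformly from [-x, x] on the unconstrained scale):

# Sketch: shrink the range of the random initial values
lapse_fit0cmd <- ddm_lapse0$sample(
  data = dat_list,
  chains = 4,
  parallel_chains = 4,
  init = 0.5
)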


Thanks for the leads!

It turns out, as you suggested, that both my troubles here and in the previous thread came down to the same thing: any change to the model or the seed moved me away from the very particular trajectories I (un)luckily happened to stumble upon on the first try.

The solution was to give initial values for the ndt parameter, which is the non-decision time of the Wiener likelihood. The way the mixture model is set up, if rt < ndt, only the uniform component contributes to the likelihood. If the initial value of ndt is too high, too much of the data is assigned to the uniform component alone, so the shape of the data does little to inform the fit, and the sampler can’t leave that region of parameter space, I guess.

Setting very small initial values for ndt (e.g. 0.01) solves this.
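
With cmdstanr, one way to pass such inits is an init function (a sketch; only ndt gets an explicit starting value here, the remaining parameters keep their random inits):

# Sketch: start every chain with a very small non-decision time
init_small_ndt <- function() list(ndt = 0.01)

lapse_fit0cmd <- ddm_lapse0$sample(
  data = dat_list,
  chains = 4,
  parallel_chains = 4,
  init = init_small_ndt
)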
