I’m trying to fit a mixture of a Wiener likelihood and a uniform likelihood with a small mixing probability, the latter to account for outliers. The model is described in this thread.
The model runs just fine in rstan, but when I use cmdstanr to interface with cmdstan I get a very large number of divergent transitions. That was the original problem in the thread above, which was solved for rstan by the model code I’m using here.
This is the model:
functions {
  /**
   * Log density of a single response time under a two-component mixture:
   * a uniform "lapse" density on [min_rt, max_rt] with weight lambda, and a
   * Wiener first-passage-time density with weight 1 - lambda.
   *
   * @param y       observed response time
   * @param mu      drift rate
   * @param bs      boundary separation
   * @param ndt     non-decision time
   * @param bias    starting-point bias
   * @param lambda  mixing weight of the uniform lapse component
   * @param dec     decision; 1 evaluates the Wiener density with (bias, mu),
   *                any other value with the reflected (1 - bias, -mu)
   * @param min_rt  lower limit of the uniform lapse density
   * @param max_rt  upper limit of the uniform lapse density
   * @return log density of y
   */
  real wiener_diffusion2_lpdf(real y, real mu, real bs,
                              real ndt, real bias, real lambda,
                              int dec, real min_rt, real max_rt) {
    // The lapse term is shared by every branch, so evaluate it once.
    real lapse_lp = uniform_lpdf(y | min_rt, max_rt);
    // wiener_lpdf is only defined for y strictly greater than ndt, so a
    // response at or below ndt can only have come from the lapse component.
    // Using <= (rather than <) keeps y == ndt away from wiener_lpdf.
    if (y <= ndt) {
      return log(lambda) + lapse_lp;
    }
    if (dec == 1) {
      return log_mix(lambda, lapse_lp, wiener_lpdf(y | bs, ndt, bias, mu));
    }
    // Other boundary: reflect the starting point and flip the drift sign.
    return log_mix(lambda, lapse_lp, wiener_lpdf(y | bs, ndt, 1 - bias, -mu));
  }
}
data {
  int<lower=0> N; // Total trial count
  vector<lower=0>[N] rt; // RT responses in s; the lapse density starts at 0
  int<lower=0, upper=1> dec[N]; // decisions, coded 0/1
  int<lower=1> K; // number of effects for drift rate
  matrix[N, K] X; // design matrix for drift rate
  // Fixed lapse rate. It is used as a mixing weight and inside log(), so it
  // is declared as a probability and invalid input is rejected at load time.
  real<lower=0, upper=1> lambda;
  int prior_only; // should the likelihood be ignored? (non-zero = yes)
}
transformed data {
  // Largest observed response time; the model block passes it as the upper
  // limit of the uniform lapse density.
  real maxRT;
  maxRT = max(rt);
}
// Parameters sampled by the model. The drift rate itself is not a parameter:
// the model block builds it per trial as X * b.
parameters {
vector[K] b; // population-level effects (coefficients of the drift-rate design matrix)
real<lower=0> bs; // boundary separation population parameter (positive)
real<lower=0> ndt; // non-decision time population parameter (positive, no upper bound)
real<lower=0,upper=1> bias; // initial bias parameter, on the unit interval
}
model {
  // Priors
  b ~ normal(0, 10); // drift-rate coefficients
  bs ~ normal(0, 2.5); // boundary separation; half-normal because bs >= 0
  ndt ~ lognormal(-1.2, 0.5); // non-decision time
  bias ~ beta(15, 15); // starting-point bias, symmetric around 0.5
  // Likelihood; skipped entirely when only the prior is to be sampled
  if (prior_only == 0) {
    // Trial-wise drift rates from the linear predictor
    vector[N] drift = X * b;
    // Accumulate the lapse-contaminated Wiener log density trial by trial;
    // the lapse component is uniform on [0, maxRT].
    for (i in 1:N) {
      target += wiener_diffusion2_lpdf(rt[i] | drift[i], bs, ndt, bias,
                                       lambda, dec[i], 0.0, maxRT);
    }
  }
}
And then with rstan:
# Compile and sample the lapse model through rstan.
lapse_fit0 <- stan(
  file = "ddm_lapse0.stan",
  data = dat_list,
  chains = 4,
  cores = 4,
  seed = 34
)
I get:
Inference for Stan model: ddm_lapse0.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
b[1] 9.78 0.00 0.17 9.44 9.67 9.77 9.89 10.11 2738 1
bs 1.98 0.00 0.03 1.93 1.96 1.98 2.00 2.03 2659 1
ndt 0.20 0.00 0.00 0.19 0.20 0.20 0.20 0.20 2684 1
bias 0.51 0.00 0.00 0.50 0.50 0.51 0.51 0.51 3263 1
lp__ -143.50 0.03 1.41 -147.03 -144.18 -143.15 -142.47 -141.77 1894 1
Which is spot on.
But with cmdstanr:
# Compile the same Stan file with CmdStan and sample it with the same data,
# number of chains and seed as the rstan run.
ddm_lapse0 <- cmdstan_model("ddm_lapse0.stan")
lapse_fit0cmd <- ddm_lapse0$sample(
  data = dat_list,
  chains = 4,
  parallel_chains = 4,
  seed = 34 # no trailing comma: it would pass an extra, empty argument
)
I get:
Warning: 2973 of 4000 (74.0%) transitions ended with a divergence.
This may indicate insufficient exploration of the posterior distribution.
Possible remedies include:
* Increasing adapt_delta closer to 1 (default is 0.8)
* Reparameterizing the model (e.g. using a non-centered parameterization)
* Using informative or weakly informative prior distributions
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
lp__ -8400.15 -10478.83 4921.36 2514.66 -12555.60 -143.31 3.06 4 26
b[1] 2.40 0.07 4.26 0.31 -0.30 9.92 2.93 4 12
bs 1.41 1.37 0.45 0.60 0.89 2.00 3.54 4 11
ndt 0.67 0.64 0.38 0.47 0.20 1.20 2.85 4 31
bias 0.38 0.41 0.15 0.18 0.16 0.54 3.92 4 11
This is the R code for generating the fake data used here:
library(RWiener)
library(brms)
library(data.table)
# Set parameters for test
bs <- 2          # boundary separation (passed to rwiener() as alpha)
k <- 10          # drift-rate slope: the drift for condition x is k * x
ndt <- 0.2       # non-decision time (passed to rwiener() as tau)
bias <- 0.5      # starting-point bias (passed to rwiener() as beta)
lambda <- 0.02   # probability that a simulated trial is a lapse
# Four quadratically spaced predictor values in (0, 0.5]: denser near zero
tightspacing <- (seq(0, sqrt(0.5), length.out = 5)^2)[2:5]
# Full set of predictor values, mirrored around zero
xs <- c(-1, -tightspacing, 0, tightspacing, 1)
n <- 200         # trials simulated per predictor value
# Simulate RTs and responses
#
# Draws `n` Wiener trials for predictor value `x` (drift rate `k * x`) and
# replaces a random subset with lapses: each trial is a lapse with probability
# `lambda`, in which case its RT is uniform on [0, 5] and its decision is a
# coin flip. Reads `n`, `bs`, `ndt`, `bias`, `k` and `lambda` from the
# enclosing environment.
#
# Returns the rwiener() data frame with columns `lapsed`, `rt`, `dec` (0/1)
# and `x` added.
rdat <- function(x) {
  dat <- rwiener(n, alpha = bs, tau = ndt, beta = bias, delta = k * x)
  # RWiener documents `resp` as a factor with levels "upper"/"lower";
  # arithmetic on a factor yields NA, so recode it to 1 = upper, 0 = lower.
  # An already numeric response is passed through unchanged.
  resp <- if (is.numeric(dat$resp)) {
    dat$resp
  } else {
    as.numeric(dat$resp == "upper")
  }
  # Draw all random numbers up front, in a fixed order (lapse flags, lapse
  # RTs, lapse decisions), so results for a given seed do not depend on
  # which trials happen to lapse.
  dat$lapsed <- runif(n) <= lambda
  lapse_rt <- runif(n, 0, 5)
  lapse_dec <- as.numeric(runif(n) > 0.5)
  dat$rt <- ifelse(dat$lapsed, lapse_rt, dat$q)
  dat$dec <- ifelse(dat$lapsed, lapse_dec, resp)
  dat$x <- x
  dat
}
set.seed(34)
# Simulate one data set per predictor value and stack them
dat1 <- as.data.table(do.call(rbind, lapply(xs, rdat)))
# Add one pernicious trial: a lapse whose RT (0.1) is below the true
# non-decision time
dat1 <- rbind(
  dat1,
  data.table(q = 0.1, resp = 1, lapsed = TRUE, rt = 0.1, dec = 1, x = xs[3])
)
# Data in the layout expected by the Stan model's data block
dat_list <- list(
  N = nrow(dat1),
  rt = dat1$rt,
  dec = dat1$dec,
  K = 1,
  X = model.matrix(rt ~ 0 + x, dat1),
  # NOTE(review): the data were simulated with lambda = 0.02 but the model is
  # given a fixed lapse rate of 0.01 -- confirm the mismatch is intended.
  lambda = 0.01,
  prior_only = 0
)
Session info:
sessionInfo()
R version 4.0.3 (2020-10-10)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 18.04.5 LTS
Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/blas/libblas.so.3.7.1
LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.7.1
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8 LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
[7] LC_PAPER=en_US.UTF-8 LC_NAME=C LC_ADDRESS=C LC_TELEPHONE=C LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] brms_2.14.4 Rcpp_1.0.6 rstan_2.21.2 StanHeaders_2.21.0-7 cmdstanr_0.3.0 cowplot_1.1.1 data.table_1.13.6 ggplot2_3.3.3
loaded via a namespace (and not attached):
[1] minqa_1.2.4 colorspace_2.0-0 ellipsis_0.3.1 ggridges_0.5.3 rsconnect_0.8.16 markdown_1.1 base64enc_0.1-3 rstudioapi_0.13
[9] farver_2.0.3 DT_0.17 fansi_0.4.2 mvtnorm_1.1-1 bridgesampling_1.0-0 codetools_0.2-16 splines_4.0.3 knitr_1.31
[17] shinythemes_1.2.0 bayesplot_1.8.0 projpred_2.0.2 jsonlite_1.7.2 nloptr_1.2.2.2 packrat_0.5.0 shiny_1.6.0 compiler_4.0.3
[25] backports_1.2.1 assertthat_0.2.1 Matrix_1.2-18 fastmap_1.1.0 cli_2.2.0 later_1.1.0.1 htmltools_0.5.1.1 prettyunits_1.1.1
[33] tools_4.0.3 igraph_1.2.6 coda_0.19-4 gtable_0.3.0 glue_1.4.2 posterior_0.1.3 RWiener_1.3-3 reshape2_1.4.4
[41] dplyr_1.0.2 V8_3.4.0 vctrs_0.3.6 nlme_3.1-149 crosstalk_1.1.1 xfun_0.20 stringr_1.4.0 ps_1.5.0
[49] lme4_1.1-26 mime_0.9 miniUI_0.1.1.1 lifecycle_0.2.0 gtools_3.8.2 statmod_1.4.35 MASS_7.3-53 zoo_1.8-8
[57] scales_1.1.1 colourpicker_1.1.0 promises_1.1.1 Brobdingnag_1.2-6 parallel_4.0.3 inline_0.3.17 shinystan_2.5.0 gamm4_0.2-6
[65] yaml_2.2.1 curl_4.3 gridExtra_2.3 loo_2.4.1 stringi_1.5.3 dygraphs_1.1.1.6 checkmate_2.0.0 boot_1.3-25
[73] pkgbuild_1.2.0 rlang_0.4.10 pkgconfig_2.0.3 matrixStats_0.57.0 evaluate_0.14 lattice_0.20-41 purrr_0.3.4 rstantools_2.1.1
[81] htmlwidgets_1.5.3 labeling_0.4.2 processx_3.4.5 tidyselect_1.1.0 plyr_1.8.6 magrittr_2.0.1 R6_2.5.0 generics_0.0.2
[89] pillar_1.4.7 withr_2.4.1 mgcv_1.8-33 xts_0.12.1 abind_1.4-5 tibble_3.0.5 crayon_1.3.4 utf8_1.1.4
[97] rmarkdown_2.4 grid_4.0.3 callr_3.5.1 threejs_0.3.3 digest_0.6.27 xtable_1.8-4 httpuv_1.5.5 RcppParallel_5.0.2
[105] stats4_4.0.3 munsell_0.5.0 shinyjs_2.0.0
and cmdstan version:
cmdstan_version()
[1] "2.26.0"