Stan's optimizer seems to yield biased MAP estimates in a weighted model, especially for scale parameters.
Attached is the cmdstanr model (modified from brms) together with the corresponding standata.
stan-opt.Rdata (599.4 KB)
This model has two sets of weights: one at the subject level (weights_ID) and one at the observation level (weights). The former is derived from the first observation of the latter.
for (n in 1:N) {
  int nn = n + start - 1;
  ptarget += weights[nn] * (normal_lpdf(Y[nn] | mu[n], sigma));
}
...
target += weights_ID[n] * std_normal_lpdf(z_1[:, n]);
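For reference, the two weighted statements above imply roughly the following log target. This is only a sketch under assumptions: $j$ indexes subjects, $m$ indexes the rows of z_1, $w_n$ and $w^{\mathrm{ID}}_j$ stand for weights and weights_ID, and the last term collects whatever priors sit in the elided part of the model.

```latex
\log p(\theta \mid Y)
  = \mathrm{const}
  + \sum_{n=1}^{N} w_n \, \log \mathcal{N}\!\left(Y_n \mid \mu_n, \sigma\right)
  + \sum_{j} w^{\mathrm{ID}}_j \sum_{m} \log \mathcal{N}\!\left(z_{m,j} \mid 0, 1\right)
  + \log p(\text{other parameters})
```

MCMC explores this whole surface, while optimize only reports its maximizer, so the two can disagree sharply for the scale parameters.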
The point estimates for the random-effect standard deviations (sd_1) are clearly off under both optimize and laplace:
- MCMC:
mcmc <- model$sample(data=sdata, parallel_chains=2, chains=2, iter_sampling = 6000, max_treedepth=11)
mcmc$summary('sd_1')
# A tibble: 2 × 10
  variable  mean median     sd    mad    q5   q95  rhat ess_bulk ess_tail
  <chr>    <dbl>  <dbl>  <dbl>  <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
1 sd_1[1]  1.17   1.17  0.145  0.144  0.932 1.41   1.00    3133.    5452.
2 sd_1[2]  0.498  0.497 0.0236 0.0234 0.460 0.537  1.00     334.    1173.
- Optimize:
opt <- model$optimize(data=sdata, tol_obj = 1e-9,
                      tol_rel_obj = 1e-3,
                      tol_grad = 1e-9,
                      tol_rel_grad = 1e-3,
                      tol_param = 1e-13,
                      init_alpha = 1e-2,
                      algorithm = 'lbfgs',
                      iter = 3e5)
opt$summary('sd_1')
# A tibble: 2 × 2
  variable estimate
  <chr>       <dbl>
1 sd_1[1]  3.36e+ 1
2 sd_1[2]  1.62e-29
- Laplace:
laplace <- model$laplace(data=sdata, draws=1000,
                         opt_args=list(tol_obj = 1e-9,
                                       tol_rel_obj = 1e-3,
                                       tol_grad = 1e-9,
                                       tol_rel_grad = 1e-3,
                                       tol_param = 1e-13,
                                       init_alpha = 1e-3,
                                       algorithm = 'lbfgs',
                                       iter = 3e5))
laplace$summary('sd_1')
# A tibble: 2 × 7
  variable  mean median    sd   mad    q5   q95
  <chr>    <dbl>  <dbl> <dbl> <dbl> <dbl> <dbl>
1 sd_1[1]  45.5   41.8  19.1  17.1  22.7   79.0
2 sd_1[2]   9.51   8.88  4.06  3.69  4.39  16.9
Is there any tweak I could make to prevent this? Or is this a problem with gradient ascent?
P.S.: VB seems to be a bit better. However, it sometimes fails (which is bad).
vb <- model$variational(sdata, adapt_iter=1000, tol_rel_obj = 1e-3)
vb$summary('sd_1')
# A tibble: 2 × 7
  variable   mean median      sd     mad      q5    q95
  <chr>     <dbl>  <dbl>   <dbl>   <dbl>   <dbl>  <dbl>
1 sd_1[1]  0.0215 0.0139 0.0251  0.0110  0.00316 0.0643
2 sd_1[2]  0.468  0.468  0.00415 0.00422 0.461   0.475