Hi,
I am using K-fold cross-validation for model comparison, so I need to use the generated quantities block to get out-of-sample predictions. I have done this before using Stan and it went just fine. However, I am now running into some weird problem:
Once I add an array of matrices in the generated quantities, it takes a very long time until the first iteration.
Can you please let me know what I can do to improve this?
I find this weird because I thought that the generated quantities block is only executed once per iteration. So if the gq block were somehow computationally intensive or inefficient, this would increase the time between iterations, but it shouldn’t influence the time before the first iteration.
Description of the example:
- estimating a very simple linear model: y = \alpha + \beta * x + \epsilon
- in the generated quantities block, I include an array of matrices (or vectors) and fill it with zeros. The zero matrix (vector) is created in transformed data and then just assigned to the array.
- stan.R generates data, compiles stan code, and samples
- stan_wo_gq.stan: estimates \alpha, \beta, and \sigma (the sd of \epsilon) but doesn’t have the generated quantities block
- stan_w_gq_matrix.stan: estimates the same model, but creates an array of matrices in the gq block
- stan_w_gq_vector.stan: estimates the same model, but creates an array of vectors in the gq block
stan.R (5.2 KB)
stan_w_gq_matrix.stan (1.0 KB)
stan_w_gq_vector.stan (1.1 KB)
stan_wo_gq.stan (877 Bytes)
Without any gq, the estimation is very fast:
> print(get_elapsed_time(res_wo_gq))
        warmup sample
chain:1  0.422  0.314
chain:2  0.416  0.284
With the gq block:
transformed data {
  vector[m] O_m;
  matrix[m, n] O_m_n;
  O_m = rep_vector(0.0, m);
  O_m_n = rep_matrix(0.0, m, n);
}
...
generated quantities {
  matrix[m, n] sim_matrix[p, q];
  for (i in 1:p) {
    for (j in 1:q) {
      sim_matrix[i, j] = O_m_n;
    }
  }
  print("End gq ");
}
For matrix[70, 5] sim_matrix[20, 50], it takes way too long until the first iteration (more than 30 minutes).
For matrix[70, 5] sim_matrix[20, 20], it takes about 5 minutes until the first iteration, and estimation then takes about 20 times longer than with wo_gq. In total, it takes about 9 minutes (it also seems to get stuck after sampling is done).
> print(get_elapsed_time(res_w_gq_matrix))
        warmup sample
chain:1  2.098  2.062
chain:2  2.197  2.173
This is what I have tried so far:
- the good old Google search. I’ve found some previous discussions that seem related to this, but I didn’t find an answer and it’s not clear whether this issue has been solved yet (it is closed on GitHub): https://github.com/stan-dev/stan/issues/2516
- added print statements in the Stan code to see where it gets stuck.
- tried with an array of vectors instead of matrices, but kept the same number of values to be stored.
- played with the dimensions of the matrices and of the arrays, keeping the same number of values
- initialized the array in the gq
- used pars = c("sim_matrix"), include = TRUE
Nothing seems to work. I would need 3 of these arrays with dims matrix[70, 5] sim_matrix[20, 50]. Is that too much? I would think the PC I am using has enough RAM (64GB), and during estimation memory use never came close to the maximum (it goes to 13GB at most).
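For scale, here is a rough back-of-the-envelope sketch of how many values this request amounts to. The dimensions and the two chains come from the example above. The number of saved iterations per chain is not something I have stated, so `iter_saved` is purely an assumed placeholder, and the 8 bytes per value only counts the raw doubles, not any overhead R or rstan adds on top.

```r
# Rough size of the generated quantities output for the dimensions in this post.
m <- 70; n <- 5      # each element is a matrix[m, n]
p <- 20; q <- 50     # array dimensions [p, q]
n_arrays <- 3        # number of such arrays I would need

# Scalars written by the gq block on every iteration
values_per_draw <- m * n * p * q * n_arrays
# Raw storage per iteration, at 8 bytes per double
bytes_per_draw <- values_per_draw * 8

# ASSUMPTION: 2000 saved iterations per chain (placeholder, not from the post)
chains <- 2
iter_saved <- 2000
total_gb <- bytes_per_draw * chains * iter_saved / 1024^3

cat("values per draw:", values_per_draw, "\n")
cat("approx. GB of raw doubles for all saved draws:", round(total_gb, 1), "\n")
```

With three arrays this comes to 1,050,000 values per draw, and each of those is a separately named output column, so the total depends heavily on how many iterations are kept.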
I would appreciate some advice on what I can do to speed this up, while still getting the simulated values I need.
Ana
Viewer pane stays like this for about 30 minutes before the first iteration:
> sessionInfo()
R version 3.5.1 (2018-07-02)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows >= 8 x64 (build 9200)
Matrix products: default
locale:
[1] LC_COLLATE=Dutch_Netherlands.1252 LC_CTYPE=Dutch_Netherlands.1252 LC_MONETARY=Dutch_Netherlands.1252 LC_NUMERIC=C
[5] LC_TIME=Dutch_Netherlands.1252
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] rstan_2.18.2 StanHeaders_2.18.0-1 ggplot2_3.1.0 loo_2.0.0
loaded via a namespace (and not attached):
[1] Rcpp_1.0.0 pillar_1.3.1 compiler_3.5.1 plyr_1.8.4 bindr_0.1.1 prettyunits_1.0.2 tools_3.5.1 pkgbuild_1.0.2
[9] tibble_1.4.2 gtable_0.2.0 pkgconfig_2.0.2 rlang_0.3.0.1 cli_1.0.1 rstudioapi_0.8 yaml_2.2.0 parallel_3.5.1
[17] bindrcpp_0.2.2 gridExtra_2.3 withr_2.1.2 dplyr_0.7.8 stats4_3.5.1 grid_3.5.1 tidyselect_0.2.5 glue_1.3.0
[25] inline_0.3.15 R6_2.3.0 processx_3.2.1 purrr_0.2.5 callr_3.1.1 magrittr_1.5 codetools_0.2-16 scales_1.0.0
[33] ps_1.3.0 matrixStats_0.54.0 assertthat_0.2.0 colorspace_1.3-2 lazyeval_0.2.1 munsell_0.5.0 crayon_1.3.4