So @bbbales2 solved this by discovering that (1) I wasn’t composing the inits properly, and (2) CmdStan was silently ignoring my inits (because they weren’t in the right format) and using random inits instead.
I think the following (super kludgey, possibly over-tidyversed) code should compose the inits better, handling single-value parameters, vector parameters, and matrices. It should handle one- and two-dimensional arrays too, but I haven’t tried those. N-dimensional arrays are probably a fairly straightforward generalization, but I’ll leave that for another time.
# Build initial values for one chain from a single warmup draw.
#
# Reads two objects from the calling environment:
#   * `warmup`: a cmdstanr fit run with `save_warmup = TRUE`.
#   * `iter_warmup`: the index of the warmup iteration whose values are used.
# Keep the signature `function(chain_id)`. cmdstanr's `init` argument accepts
# a function with either no arguments or one argument named `chain_id`.
#
# Every variable except the first one in the draws (the first is assumed to be
# `lp__`; NOTE(review): confirm) becomes one list element:
#   * a scalar for an unindexed name such as `sigma`,
#   * a vector for a singly indexed name such as `beta[3]`,
#   * a matrix for a doubly indexed name such as `Sigma[2,1]`
#     (matrix parameters and 2-D arrays take this same path).
# Names with three or more indices are not handled.
#
# @param chain_id Index of the chain whose warmup draw seeds the inits.
# @return A named list of initial values, one element per variable.
get_inits <- function(chain_id) {
  warmup_draws <- warmup$draws(inc_warmup = TRUE)
  # Drop the first variable (lp__), which is not a parameter.
  final_warmup_value <- warmup_draws[iter_warmup, chain_id, -1]
  variable_names <- dimnames(final_warmup_value)$variable

  final_warmup_value %>%
    tibble::as_tibble(.name_repair = function(names) variable_names) %>%
    tidyr::pivot_longer(cols = dplyr::everything()) %>%
    # "Sigma[2,1]" -> variable "Sigma", index "2,1]"; unindexed names get NA.
    tidyr::separate(
      name,
      into = c("variable", "index"),
      sep = "\\[",
      fill = "right"
    ) %>%
    dplyr::group_split(variable) %>%
    purrr::map(.f = function(var_draws) {
      var_draws <- var_draws %>%
        dplyr::mutate(index = stringr::str_replace(index, "]", "")) %>%
        tidyr::separate(
          index,
          into = c("first", "second"),
          sep = ",",
          fill = "right",
          convert = TRUE
        ) %>%
        # Sort column-major (row index varies fastest) to match matrix(),
        # which fills column by column. Sorting by (first, second) would
        # transpose square matrices and scramble non-square ones.
        dplyr::arrange(second, first)

      value <- if (all(is.na(var_draws$second))) {
        # Scalar or vector: values are already ordered by `first`.
        var_draws$value
      } else {
        matrix(
          var_draws$value,
          nrow = max(var_draws$first),
          ncol = max(var_draws$second)
        )
      }
      stats::setNames(list(value), var_draws$variable[1])
    }) %>%
    # Flatten the list of one-element named lists into a single named list.
    unlist(recursive = FALSE)
}
# Example usage ----
# Phase 1: warmup only. The warmup draws are saved so get_inits() can read
# a warmup iteration from them. sig_figs = 18 writes the draws at high
# precision.
warmup <- mod$sample(
  data = data_for_stan,
  chains = parallel_chains,
  parallel_chains = parallel_chains,
  seed = seed,
  iter_warmup = iter_warmup,
  save_warmup = TRUE, # needed by get_inits()
  sig_figs = 18,
  iter_sampling = 0
)
# Phase 2: sampling with adaptation turned off. It reuses the warmup run's
# adapted inverse metric and step size, and seeds each chain via get_inits().
samples <- mod$sample(
  data = data_for_stan,
  chains = parallel_chains,
  parallel_chains = parallel_chains,
  seed = seed + 1,
  iter_warmup = 0,
  adapt_engaged = FALSE,
  inv_metric = warmup$inv_metric(matrix = FALSE),
  step_size = warmup$metadata()$step_size_adaptation,
  iter_sampling = iter_sample,
  init = get_inits
)