I’m looking for a way to invert extract() so that I can overwrite the parameter draws in a stanfit object, i.e. an assignment method along the lines of extract()<-. I have a silly example below with some helper functions to demonstrate my intent. Unfortunately, my version, inv_extract(), is pretty limited:
- I use nested for loops because I’m subsetting by both parameter and warmup, and I couldn’t figure out how to do that with nested apply calls.
- I actually just make a copy of the stanfit object and overwrite some of it. I don’t actually write to the original object.
- Other parts of the object, like the summary, are not updated.
- Only works for one set of pars at a time. I’d have to apply over a list of pars if I had several of interest.
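On the nested-loop point in the list above, here is a minimal sketch of one way to do the overwrite with lapply() and replace() instead. It runs on a mock list that only mimics the layout assumed here for stanfit@sim$samples (a list of chains, each a named list of per-parameter draw vectors with warmup first). The names mock_samples, new_draws and overwrite_draws() are hypothetical, and nothing in it touches rstan itself.

```r
# Mock of the layout assumed for stanfit@sim$samples: a list of chains, each a
# named list with one numeric vector of draws per scalar parameter (warmup first).
set.seed(1)
n_warm <- 3; n_post <- 4; n_chain <- 2
par_names <- c("beta[1]", "beta[2]", "lp__")
mock_samples <- lapply(seq_len(n_chain), function(j) {
  setNames(lapply(par_names, function(p) rnorm(n_warm + n_post)), par_names)
})

# Replacement draws laid out as iterations x chains x parameters,
# like the output of extract(..., permuted = FALSE).
new_pars <- c("beta[1]", "beta[2]")
new_draws <- array(as.numeric(seq_len(n_post * n_chain * 2)),
                   dim = c(n_post, n_chain, 2),
                   dimnames = list(NULL, NULL, new_pars))

# Hypothetical helper: overwrite post-warmup draws, matching names exactly.
overwrite_draws <- function(samples, new_draws, n_warm) {
  pars <- dimnames(new_draws)[[3]]
  keep <- n_warm + seq_len(dim(new_draws)[1])  # post-warmup positions
  # samples[] <- ... keeps any attributes on the outer list
  samples[] <- lapply(seq_along(samples), function(j) {
    chain <- samples[[j]]
    chain[pars] <- lapply(pars, function(p) {
      replace(chain[[p]], keep, new_draws[, j, p])
    })
    chain
  })
  samples
}

updated <- overwrite_draws(mock_samples, new_draws, n_warm)
str(updated[[1]][["beta[1]"]])
```

Because the parameter names come from the third dimension of the replacement array, several parameters are handled in one call, and warmup draws and untouched parameters such as lp__ are left as they were.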
The goal is to take the samples for a subset of parameters, adjust them (for example, rescale them according to some asymptotic or external correction), and then overwrite the values in the stanfit object. I’d like to keep the stanfit object so that the associated methods, such as the plots and diagnostics, still work.
Toy Example
#helper functions
#rescaling a vector about a center
vscale <- function(v, center, Mscale){(v - center)%*%Mscale + center}
#trying to write to stanfit object
inv_extract <- function(stanfit, samp_ex, pars){
  simname <- names(stanfit@sim$samples[[1]]) #use first chain to get individual parameter names
  #match only this parameter ("beta" or "beta[1]", but not e.g. "beta_raw[1]")
  gp1 <- grep(paste0("^", pars, "(\\[|$)"), simname)
  warm <- stanfit@sim$warmup #offset; assumes warmup draws are saved and thin = 1
  nchain <- stanfit@sim$chains #number of chains
  post <- warm + seq_len(dim(samp_ex)[1]) #post-warmup positions (also safe when warm = 0)
  for(j in seq_len(nchain)){
    for(i in seq_along(gp1)){ #only overwrite samples after warmup
      stanfit@sim$samples[[j]][[gp1[i]]][post] <- samp_ex[,j,i]
    }
  }
  return(stanfit) #a modified copy; the original object is untouched
}
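Since the aim is really an assignment form, here is a minimal sketch of how an R replacement function behaves. It uses a plain list as a stand-in for the fit, and the name draws<- is hypothetical. For an S4 object like stanfit, the same idea would presumably go through setGeneric() and setReplaceMethod(), which this sketch does not attempt.

```r
# Stand-in for a fitted object: a plain list, NOT a real stanfit.
fit <- list(samples = list(alpha = c(0.1, 0.2), beta = c(1, 2)))

# Hypothetical replacement function. R rewrites
#   draws(fit2, "beta") <- v
# as
#   fit2 <- `draws<-`(fit2, "beta", value = v)
# so the last argument must be called `value` and the function must
# return the modified object.
`draws<-` <- function(x, pars, value) {
  x$samples[[pars]] <- value
  x
}

fit2 <- fit                                   # copy-on-modify: fit is left alone
draws(fit2, "beta") <- fit2$samples$beta * 3  # overwrite the draws in fit2 only

fit$samples$beta    # original draws
fit2$samples$beta   # rescaled draws
```

Note that even the assignment form works on a copy under the hood, so copying the stanfit object and overwriting part of it is in line with how R normally handles this.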
Here’s a toy with a two-parameter logistic regression
###
#simple logistic model
N <- 5000
beta1 <- 1
x1 <- rnorm(N,0,1)
mu <- x1*beta1
theta <- plogis(mu)
y <- rbinom(n = N, size = 1, prob = theta)
#fit in stan
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
mod <- stan_model('logistic.stan')
X <- model.matrix( ~ x1)
k <- dim(X)[2]
n <- length(y)
weights <- rep(1,n)
stan_data <- list(y = as.vector(y), X = X, k = k, n = n)
out_stan <- sampling(object = mod, data = stan_data,
                     pars = c("beta"),
                     chains = 2, iter = 100, warmup = 50)
Extracting the parameters, applying a silly adjustment, and writing to a new stanfit object that looks like the original but with the sample values overwritten:
samp_ex <- extract(out_stan, "beta", permuted = FALSE)
pm <- colMeans(samp_ex, dims = 2) #posterior mean of each parameter, over iterations and chains
#silly rescaling centered at posterior mean and triple the standard deviation
library(plyr) #for aaply()
samp_adj <- aaply(samp_ex, c(1,2), vscale, center = pm, Mscale = diag(rep(3,2)), .drop = TRUE)
out_stan2 <- inv_extract(out_stan, samp_adj, "beta")
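For a diagonal scaling like the one above, the same adjustment can be sketched in base R with sweep(), without plyr. This is only an illustration on a simulated array with the iterations x chains x parameters layout of extract(..., permuted = FALSE). The names draws and draws_adj are hypothetical, and a full (non-diagonal) Mscale would still need a matrix product per draw.

```r
# Simulated stand-in for extract(fit, "beta", permuted = FALSE):
# 50 iterations x 2 chains x 2 parameters.
set.seed(42)
draws <- array(rnorm(50 * 2 * 2), dim = c(50, 2, 2),
               dimnames = list(NULL, c("chain:1", "chain:2"),
                               c("beta[1]", "beta[2]")))

pm <- colMeans(draws, dims = 2)         # one posterior mean per parameter

centered  <- sweep(draws, 3, pm, "-")   # subtract each parameter's mean
draws_adj <- sweep(3 * centered, 3, pm, "+")  # triple the spread, restore the mean

# Means are preserved; standard deviations are tripled.
rbind(before = apply(draws, 3, sd), after = apply(draws_adj, 3, sd))
```

The result has the same dimensions and dimnames as the input, so it can be passed on in place of samp_adj.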
Plots are updated, reflecting the new scale
#trace plots
stan_trace(out_stan)
stan_trace(out_stan2) #note the scale increase
Summary is unchanged
#summaries #unchanged
summary(out_stan)
summary(out_stan2)
> summary(out_stan)
$`summary`
mean se_mean sd 2.5% 25% 50% 75% 97.5%
beta[1] 2.024584e-02 0.002653286 0.03263762 -3.786105e-02 -1.352632e-03 1.430888e-02 4.446862e-02 0.0800016
beta[2] 1.045112e+00 0.004918603 0.03828606 9.600493e-01 1.028806e+00 1.052076e+00 1.073302e+00 1.1131749
lp__ -2.976962e+03 0.118281223 0.99796267 -2.979155e+03 -2.977417e+03 -2.976525e+03 -2.976227e+03 -2975.9961847
n_eff Rhat
beta[1] 151.31036 1.0093243
beta[2] 60.58956 0.9851996
lp__ 71.18639 1.0224404
$c_summary
, , chains = chain:1
stats
parameter mean sd 2.5% 25% 50% 75% 97.5%
beta[1] 1.699874e-02 0.03345020 -3.906042e-02 -5.510431e-03 9.995533e-03 4.501224e-02 7.059195e-02
beta[2] 1.046645e+00 0.04114109 9.602575e-01 1.033106e+00 1.057698e+00 1.075464e+00 1.115309e+00
lp__ -2.977059e+03 0.87580622 -2.978931e+03 -2.977639e+03 -2.976746e+03 -2.976386e+03 -2.976074e+03
, , chains = chain:2
stats
parameter mean sd 2.5% 25% 50% 75% 97.5%
beta[1] 2.349295e-02 0.03180775 -3.064287e-02 6.011404e-03 1.960032e-02 4.225356e-02 0.0829192
beta[2] 1.043579e+00 0.03555520 9.614698e-01 1.026773e+00 1.042931e+00 1.065930e+00 1.1069984
lp__ -2.976865e+03 1.10726663 -2.979756e+03 -2.977376e+03 -2.976331e+03 -2.976145e+03 -2975.9881334
> summary(out_stan2)
$`summary`
mean se_mean sd 2.5% 25% 50% 75% 97.5%
beta[1] 2.024584e-02 0.002653286 0.03263762 -3.786105e-02 -1.352632e-03 1.430888e-02 4.446862e-02 0.0800016
beta[2] 1.045112e+00 0.004918603 0.03828606 9.600493e-01 1.028806e+00 1.052076e+00 1.073302e+00 1.1131749
lp__ -2.976962e+03 0.118281223 0.99796267 -2.979155e+03 -2.977417e+03 -2.976525e+03 -2.976227e+03 -2975.9961847
n_eff Rhat
beta[1] 151.31036 1.0093243
beta[2] 60.58956 0.9851996
lp__ 71.18639 1.0224404
$c_summary
, , chains = chain:1
stats
parameter mean sd 2.5% 25% 50% 75% 97.5%
beta[1] 1.699874e-02 0.03345020 -3.906042e-02 -5.510431e-03 9.995533e-03 4.501224e-02 7.059195e-02
beta[2] 1.046645e+00 0.04114109 9.602575e-01 1.033106e+00 1.057698e+00 1.075464e+00 1.115309e+00
lp__ -2.977059e+03 0.87580622 -2.978931e+03 -2.977639e+03 -2.976746e+03 -2.976386e+03 -2.976074e+03
, , chains = chain:2
stats
parameter mean sd 2.5% 25% 50% 75% 97.5%
beta[1] 2.349295e-02 0.03180775 -3.064287e-02 6.011404e-03 1.960032e-02 4.225356e-02 0.0829192
beta[2] 1.043579e+00 0.03555520 9.614698e-01 1.026773e+00 1.042931e+00 1.065930e+00 1.1069984
lp__ -2.976865e+03 1.10726663 -2.979756e+03 -2.977376e+03 -2.976331e+03 -2.976145e+03 -2975.9881334
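One possible explanation for the identical summaries (an assumption, not verified against the rstan source) is that the summary is computed once and cached in an environment attached to the object. Environments are shared by reference, so a copy of the object would see the same stale cache. The sketch below mimics that with a plain list. make_fit() and summarize_fit() are hypothetical and only illustrate the mechanism.

```r
# Hypothetical stand-in for an object that caches its summary in an environment.
make_fit <- function(draws) {
  list(draws = draws, cache = new.env())
}

summarize_fit <- function(fit, use_cache = TRUE) {
  if (!use_cache) return(mean(fit$draws))  # always recompute
  if (!exists("mean", envir = fit$cache, inherits = FALSE)) {
    assign("mean", mean(fit$draws), envir = fit$cache)  # compute once, then reuse
  }
  get("mean", envir = fit$cache)
}

fit <- make_fit(c(1, 2, 3))
first <- summarize_fit(fit)   # computes and caches the mean of the original draws

fit2 <- fit                   # the list is copied, the environment is shared
fit2$draws <- fit2$draws * 3  # overwrite the draws in the copy

stale <- summarize_fit(fit2)                     # still the cached value
fresh <- summarize_fit(fit2, use_cache = FALSE)  # reflects the new draws
c(first = first, stale = stale, fresh = fresh)
```

If that is what is happening here, the use_cache argument documented for rstan's summary method may be worth trying, e.g. summary(out_stan2, pars = "beta", use_cache = FALSE).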