Assignment version of extract() in rstan

I’m looking for a way to invert extract() so that I can overwrite the parameters in a stanfit object, i.e. an assignment method along the lines of extract(fit) <- value. Below is a silly example with some helper functions that demonstrates my intent. Unfortunately, my version, inv_extract(), is pretty limited:

  • I use nested for loops because I need to subset by both parameter and warmup, and I couldn’t figure out how to do that with nested apply() calls.
  • It just makes a copy of the stanfit object and overwrites part of it; it doesn’t write to the original object.
  • Other derived results, such as the summary, are not updated.
  • It only works for one parameter at a time; with several parameters of interest, I’d have to apply it over a list of them.

The goal is to take the samples for a subset of parameters, adjust them (for example, rescale them according to some asymptotic or external correction), and then overwrite the values in the stanfit object. I’d like to keep the stanfit object so that the associated methods, such as the plots and diagnostics, still work.

Toy Example

#helper functions
#rescale a vector: (v - center) %*% Mscale + center
vscale <- function(v, center, Mscale){(v - center)%*%Mscale + center}

#write adjusted post-warmup draws into a copy of a stanfit object
inv_extract <- function(stanfit, samp_ex, pars){

	simname <- names(stanfit@sim$samples[[1]]) #use first chain to get individual parameter names
	gp1 <- which(sub("\\[.*$", "", simname) == pars) #exact match, so "beta" does not also match e.g. "beta_raw"
	nchain <- stanfit@sim$chains #number of chains
	warm <- rep_len(stanfit@sim$warmup2, nchain) #saved warmup draws per chain (accounts for thinning)

	for(j in 1:nchain){
		#positions after warmup; avoids -c(1:warm) misbehaving when warm is 0
		post <- seq.int(warm[j] + 1, length.out = length(stanfit@sim$samples[[j]][[1]]) - warm[j])
		for(i in seq_along(gp1)){ #works for any parameter dimension, not just vectors
			stanfit@sim$samples[[j]][[gp1[i]]][post] <- samp_ex[,j,i]
		}
	}
	return(stanfit)
}
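A minimal sketch of a more general version, assuming the stanfit internals the question already relies on: sim$samples holds one list per chain, each containing one numeric vector per scalar parameter (saved warmup draws first), and sim$warmup2 gives the number of saved warmup draws. It matches parameters exactly by the names in dimnames(samp)[[3]], so several parameters can be replaced in one call, and it adds a replacement function so that extract(fit) <- samp works. Both replace_draws and extract<- are hypothetical names, not part of rstan, and since they touch internals they may break with other rstan versions.

```r
# Hypothetical helper (not part of rstan): overwrite the post-warmup draws in a
# stanfit-style `sim` list with `samp`, an iterations x chains x parameters
# array such as extract(fit, pars, permuted = FALSE) returns.
replace_draws <- function(sim, samp) {
  if (length(dim(samp)) != 3 || is.null(dimnames(samp)[[3]]))
    stop("samp must be a 3-d array with parameter names in dimnames(samp)[[3]]")
  if (dim(samp)[2] != sim$chains) stop("samp has the wrong number of chains")
  pnames <- dimnames(samp)[[3]]
  # assumed: number of saved warmup draws per chain (recycled if it is a scalar)
  warm <- rep_len(sim$warmup2, sim$chains)
  for (j in seq_len(sim$chains)) {
    unknown <- setdiff(pnames, names(sim$samples[[j]]))
    if (length(unknown) > 0)
      stop("unknown parameter(s): ", paste(unknown, collapse = ", "))
    for (p in pnames) {
      draws <- sim$samples[[j]][[p]]
      post <- seq.int(warm[j] + 1, length.out = length(draws) - warm[j])
      if (length(post) != dim(samp)[1])
        stop("chain ", j, ": expected ", length(post), " post-warmup draws")
      draws[post] <- samp[, j, p]        # exact name match, no regex
      sim$samples[[j]][[p]] <- draws     # attributes of the chain list are kept
    }
  }
  sim
}

# Hypothetical replacement function: `extract(fit) <- samp` is evaluated by R
# as fit <- `extract<-`(fit, value = samp), i.e. fit is rebound to a modified copy.
`extract<-` <- function(object, value) {
  object@sim <- replace_draws(object@sim, value)
  object
}

# usage with the question's objects (samp_adj needs parameter names in its
# third dimension; restore them with dimnames(samp_adj) <- dimnames(samp_ex)):
# out_stan2 <- out_stan
# extract(out_stan2) <- samp_adj
```

Exact name matching replaces grep(), which would also match parameters such as beta_raw. The per-chain loop is kept because it is simple and fast enough for a handful of parameters.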

Here’s a toy example with a two-parameter logistic regression:

###
#simple logistic model
N <- 5000
beta1 <- 1

x1 <- rnorm(N,0,1)

mu <- x1*beta1
theta <- plogis(mu)
y <- rbinom(n = N, size = 1, prob = theta)

#fit in stan
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())

mod  <- stan_model('logistic.stan')

X <- model.matrix( ~ x1)
k   <- dim(X)[2]
n   <- length(y)
weights <- rep(1,n)

stan_data <- list(y = as.vector(y), X = X, k = k, n = n)

out_stan  <- sampling(object = mod, data = stan_data,
                      pars = c("beta"),
                      chains = 2, iter = 100, warmup = 50)
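The script loads logistic.stan, which the question does not show. For anyone reproducing the example, here is a minimal guess at its contents, assuming only the data names passed in stan_data (y, X, k, n) and the parameter beta; it is not the author's file. The array syntax requires Stan 2.26 or later; older rstan releases would need int<lower=0, upper=1> y[n]; instead. Run it before the stan_model() call.

```r
# Hypothetical logistic.stan matching the names in stan_data; a guess, not the
# original file. Writes the model to the path that stan_model() reads.
logistic_code <- "
data {
  int<lower=1> n;                    // number of observations
  int<lower=1> k;                    // number of columns of X (incl. intercept)
  matrix[n, k] X;                    // design matrix from model.matrix()
  array[n] int<lower=0, upper=1> y;  // binary outcome
}
parameters {
  vector[k] beta;                    // regression coefficients
}
model {
  y ~ bernoulli_logit(X * beta);     // implicit flat prior on beta
}
"
writeLines(logistic_code, "logistic.stan")
```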

Next, I extract the parameters, apply a silly adjustment, and write the result to a new stanfit object that looks like the original but with the sampled values overwritten:

samp_ex <- extract(out_stan, "beta", permuted = FALSE)

pm <- colMeans(samp_ex, dims = 2) #posterior mean of each parameter

#silly rescaling: center at the posterior mean and triple the standard deviation
library(plyr) #aaply() comes from plyr
samp_adj <- aaply(samp_ex, c(1,2), vscale, center = pm, Mscale = diag(rep(3,2)), .drop = TRUE)
 
out_stan2 <- inv_extract(out_stan, samp_adj, "beta")
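aaply() comes from the plyr package. A base-R sketch of the same adjustment, assuming Mscale is a parameters x parameters matrix as in the call above: because vscale() is an affine map, all draws can be transformed with one matrix product after flattening the iteration and chain dimensions. The name rescale_draws is hypothetical. Unlike vscale(), whose matrix product drops the parameter names, it keeps the array's dimnames.

```r
# Hypothetical base-R replacement for the aaply()/vscale() step: applies
# (draw - center) %*% Mscale + center to every draw of an
# iterations x chains x parameters array in one matrix product.
rescale_draws <- function(samp, Mscale, center = colMeans(samp, dims = 2)) {
  d <- dim(samp)
  # one row per (iteration, chain) pair, one column per parameter
  m <- matrix(samp, nrow = d[1] * d[2], ncol = d[3])
  m <- sweep(m, 2, center) %*% Mscale          # center each parameter, then transform
  m <- sweep(m, 2, center, "+")                 # shift back to the original center
  array(m, dim = d, dimnames = dimnames(samp))  # keep parameter names
}

# usage with the question's objects:
# samp_adj <- rescale_draws(samp_ex, Mscale = diag(rep(3, 2)))
```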

The plots are updated and reflect the new scale:

#trace plots
stan_trace(out_stan)

stan_trace(out_stan2) #note the scale increase


The summary, however, is unchanged:

#summaries #unchanged
summary(out_stan) 
summary(out_stan2)

> summary(out_stan)
$`summary`
                 mean     se_mean         sd          2.5%           25%           50%           75%         97.5%
beta[1]  2.024584e-02 0.002653286 0.03263762 -3.786105e-02 -1.352632e-03  1.430888e-02  4.446862e-02     0.0800016
beta[2]  1.045112e+00 0.004918603 0.03828606  9.600493e-01  1.028806e+00  1.052076e+00  1.073302e+00     1.1131749
lp__    -2.976962e+03 0.118281223 0.99796267 -2.979155e+03 -2.977417e+03 -2.976525e+03 -2.976227e+03 -2975.9961847
            n_eff      Rhat
beta[1] 151.31036 1.0093243
beta[2]  60.58956 0.9851996
lp__     71.18639 1.0224404

$c_summary
, , chains = chain:1

         stats
parameter          mean         sd          2.5%           25%           50%           75%         97.5%
  beta[1]  1.699874e-02 0.03345020 -3.906042e-02 -5.510431e-03  9.995533e-03  4.501224e-02  7.059195e-02
  beta[2]  1.046645e+00 0.04114109  9.602575e-01  1.033106e+00  1.057698e+00  1.075464e+00  1.115309e+00
  lp__    -2.977059e+03 0.87580622 -2.978931e+03 -2.977639e+03 -2.976746e+03 -2.976386e+03 -2.976074e+03

, , chains = chain:2

         stats
parameter          mean         sd          2.5%           25%           50%           75%         97.5%
  beta[1]  2.349295e-02 0.03180775 -3.064287e-02  6.011404e-03  1.960032e-02  4.225356e-02     0.0829192
  beta[2]  1.043579e+00 0.03555520  9.614698e-01  1.026773e+00  1.042931e+00  1.065930e+00     1.1069984
  lp__    -2.976865e+03 1.10726663 -2.979756e+03 -2.977376e+03 -2.976331e+03 -2.976145e+03 -2975.9881334


> summary(out_stan2)
$`summary`
                 mean     se_mean         sd          2.5%           25%           50%           75%         97.5%
beta[1]  2.024584e-02 0.002653286 0.03263762 -3.786105e-02 -1.352632e-03  1.430888e-02  4.446862e-02     0.0800016
beta[2]  1.045112e+00 0.004918603 0.03828606  9.600493e-01  1.028806e+00  1.052076e+00  1.073302e+00     1.1131749
lp__    -2.976962e+03 0.118281223 0.99796267 -2.979155e+03 -2.977417e+03 -2.976525e+03 -2.976227e+03 -2975.9961847
            n_eff      Rhat
beta[1] 151.31036 1.0093243
beta[2]  60.58956 0.9851996
lp__     71.18639 1.0224404

$c_summary
, , chains = chain:1

         stats
parameter          mean         sd          2.5%           25%           50%           75%         97.5%
  beta[1]  1.699874e-02 0.03345020 -3.906042e-02 -5.510431e-03  9.995533e-03  4.501224e-02  7.059195e-02
  beta[2]  1.046645e+00 0.04114109  9.602575e-01  1.033106e+00  1.057698e+00  1.075464e+00  1.115309e+00
  lp__    -2.977059e+03 0.87580622 -2.978931e+03 -2.977639e+03 -2.976746e+03 -2.976386e+03 -2.976074e+03

, , chains = chain:2

         stats
parameter          mean         sd          2.5%           25%           50%           75%         97.5%
  beta[1]  2.349295e-02 0.03180775 -3.064287e-02  6.011404e-03  1.960032e-02  4.225356e-02     0.0829192
  beta[2]  1.043579e+00 0.03555520  9.614698e-01  1.026773e+00  1.042931e+00  1.065930e+00     1.1069984
  lp__    -2.976865e+03 1.10726663 -2.979756e+03 -2.977376e+03 -2.976331e+03 -2.976145e+03 -2975.9881334
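The plots change while the summary does not, which may be a caching effect rather than a problem with inv_extract(). As far as I know (not verified for every rstan release), the summary() method for stanfit objects caches its results in the object's .MISC slot, which is an environment. Environments have reference semantics, so after copying out_stan into out_stan2 both objects point to the same cache and read the same stored summary. The toy below shows the mechanism in plain base R; the class toyfit and the function toy_summary() are hypothetical stand-ins, not rstan code.

```r
library(methods)

# Hypothetical stand-in for a stanfit object: `cache` plays the role of .MISC
setClass("toyfit", slots = c(draws = "numeric", cache = "environment"))

# Compute the mean once and store it in the cache environment; later calls
# return the stored value (mimicking a cached summary)
toy_summary <- function(x) {
  if (!exists("avg", envir = x@cache, inherits = FALSE))
    assign("avg", mean(x@draws), envir = x@cache)
  get("avg", envir = x@cache)
}

fit1 <- new("toyfit", draws = c(1, 2, 3), cache = new.env())
fit2 <- fit1                  # copies the slots, but both share one environment
fit2@draws <- c(10, 20, 30)   # overwrite the copy's draws

toy_summary(fit1)             # 2, now cached in the shared environment
toy_summary(fit2)             # also 2: stale value from the shared cache

fit2@cache <- new.env()       # give the copy its own, empty cache
toy_summary(fit2)             # 20
```

If this is the cause, summary(out_stan2, use_cache = FALSE) should recompute the summary from the modified draws; use_cache is an argument of rstan's summary() method for stanfit objects, although its behavior may differ between versions.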

I wouldn’t do this, but if I had to (and I have done it in rstanarm), I would change the internal values in the list of lists inside the stanfit object, so that everything that is expected to work on a stanfit object still works.