In an effort to improve my understanding and development of Stan models, I am “working backwards” from brms to “pure” Stan, comparing the outputs of what should be the same model.
## BRMS model
## Model formula
# Bernoulli model formula for CR_Present.
# Population-level terms: Condition, Session_Type, Facilitator and the
# Condition:Session_Type and Facilitator:Session_Type interactions (brms
# labels the latter Session_Type...:Facilitator... in its summary).
# Group-level terms: (1 | Condition / Group) is shorthand for
# (1 | Condition) + (1 | Condition:Group), i.e. a random intercept per
# Condition and one per Group-within-Condition.
# NOTE(review): Condition is also a population-level term and has only two
# levels (two entries in z_1 / r_1_1 of the Stan fit), so its group-level
# intercepts overlap with Intercept + ConditionIntervention and their SD is
# weakly identified -- confirm this term is intended.
mod1_fmla <- bf(
  CR_Present ~ Condition +
    Session_Type +
    Facilitator +
    Condition:Session_Type +
    Facilitator:Session_Type +
    (1 | Condition / Group))
## Specify global priors
# Priors for the whole model.
# - class "b": every population-level coefficient (b[1..Kc] in the Stan code).
# - class "Intercept": brms places this prior on the intercept of the
#   *centered* design (Stan parameter `Intercept`), not on the reported
#   b_Intercept.
# - class "sd": every group-level standard deviation (here the Condition and
#   Condition:Group intercept SDs, sd_1 and sd_2 in the Stan code). The SDs
#   are declared with <lower=0>, so this Cauchy acts as a half-Cauchy.
# One class-level "sd" entry covers both grouping factors. Add a
# `group =` / `coef =` entry only when a factor needs a *different* prior;
# such an entry overrides the class-level one for that coefficient.
priors <- c(
  set_prior("normal(0,1)", class = "b"),
  set_prior("normal(0,1)", class = "Intercept"),
  set_prior("cauchy(0,1)", class = "sd")
)
## Compile and run model
# Fit the model with brms. The sampler settings below are mirrored in the
# direct rstan sampling() refit so that the two fits can be compared.
# - family: Bernoulli with logit link (bernoulli_logit_lpmf in the Stan code).
# - init = 0: start every parameter at 0 on the unconstrained scale
#   (`inits` is the deprecated spelling of this argument in current brms).
# - sample_prior = "yes": additionally draw prior_b, prior_Intercept and
#   prior_sd_* in generated quantities.
# `ncores` is assumed to be defined earlier in the script.
BRMS_Mod <- brm(
  mod1_fmla,
  data = How_Data,
  family = bernoulli(link = "logit"),
  prior = priors,
  init = 0,
  iter = 2000,
  warmup = 1000,
  chains = 2,
  cores = ncores,
  sample_prior = "yes",
  seed = 1234,
  control = list(max_treedepth = 12, adapt_delta = 0.99)
)
There are 14 parameters in the fitted model:
# Names of the parameter blocks declared by the Stan model inside the brms fit
slot(BRMS_Mod[["fit"]], "model_pars")
[1] "b" "Intercept" "sd_1"
[4] "z_1" "sd_2" "z_2"
[7] "r_1_1" "r_2_1" "b_Intercept"
[10] "prior_b" "prior_Intercept" "prior_sd_1_1"
[13] "prior_sd_2_1" "lp__"
The summary has 15 parameter estimates:
Population-Level Effects:
Estimate
Intercept -0.20
ConditionIntervention 0.22
Session_TypeInitial_Homecare -1.54
Session_TypeInitial_Debrief -1.17
Session_TypeFollowMup_Homecare -2.92
Session_TypeFollowMup_Debrief -1.45
FacilitatorFacilitator2 -0.26
ConditionIntervention:Session_TypeInitial_Homecare -1.05
ConditionIntervention:Session_TypeInitial_Debrief 0.32
ConditionIntervention:Session_TypeFollowMup_Homecare -0.99
ConditionIntervention:Session_TypeFollowMup_Debrief -0.13
Session_TypeInitial_Homecare:FacilitatorFacilitator2 -0.44
Session_TypeInitial_Debrief:FacilitatorFacilitator2 0.75
Session_TypeFollowMup_Homecare:FacilitatorFacilitator2 -0.65
Session_TypeFollowMup_Debrief:FacilitatorFacilitator2 0.22
For comparison, I fit the same model directly in rstan, using the `standata()` and `stancode()` functions from brms:
# Refit the brms-generated Stan program directly with rstan, using the data
# brms prepared and the same sampler settings as the brm() call.
# NOTE(review): Stan_Code is not defined in this snippet; presumably it holds
# stancode(BRMS_Mod) -- confirm.
# stan_model()'s auto_write argument already defaults to
# rstan_options("auto_write"), so it is not repeated here.
Stan_Fit <- sampling(
  object = stan_model(model_code = Stan_Code),
  data = standata(BRMS_Mod),
  chains = 2,
  iter = 2000,
  warmup = 1000,
  init = 0,
  seed = 1234,
  cores = ncores,
  control = list(adapt_delta = 0.99, max_treedepth = 12)
)
The Stan model has the same estimated parameters:
# Names of the parameter blocks declared by the directly compiled Stan model
slot(Stan_Fit, "model_pars")
[1] "b" "Intercept" "sd_1"
[4] "z_1" "sd_2" "z_2"
[7] "r_1_1" "r_2_1" "b_Intercept"
[10] "prior_b" "prior_Intercept" "prior_sd_1_1"
[13] "prior_sd_2_1" "lp__"
However, the summary shows more than 50 parameter estimates:
# Unlike summary(BRMS_Mod), the stanfit summary lists every saved quantity:
# b[k], Intercept, sd_*, z_*, r_*, b_Intercept, prior_* and lp__.
# b[k] matches the k-th non-intercept row of the brms "Population-Level
# Effects" table, and b_Intercept (not Intercept) matches its Intercept row.
summary(object = Stan_Fit)
$summary
mean se_mean sd
b[1] 2.217257e-01 0.020471036 0.8290820
b[2] -1.537814e+00 0.008043568 0.3371027
b[3] -1.174772e+00 0.006399751 0.2449436
b[4] -2.923468e+00 0.011252569 0.5230500
b[5] -1.450151e+00 0.008074908 0.3097133
b[6] -2.599561e-01 0.008185724 0.2617239
b[7] -1.052058e+00 0.008810477 0.4305872
b[8] 3.241392e-01 0.005720877 0.2425289
b[9] -9.873551e-01 0.012379691 0.6608076
b[10] -1.287215e-01 0.007753539 0.3090895
b[11] -4.429590e-01 0.009369959 0.4255439
b[12] 7.499469e-01 0.005516026 0.2483604
b[13] -6.541740e-01 0.012289146 0.6643431
b[14] 2.192564e-01 0.007258343 0.3046114
Intercept -1.525383e+00 0.033777919 0.9652861
sd_1[1] 1.436836e+00 0.046714680 1.3970096
z_1[1,1] -7.517089e-01 0.021172056 0.7810144
z_1[1,2] -4.972927e-01 0.018919796 0.7594828
sd_2[1] 4.477611e-01 0.004041604 0.1209679
z_2[1,1] 8.596555e-02 0.014483056 0.5831262
z_2[1,2] -1.253229e+00 0.014916998 0.6980695
z_2[1,3] 3.111350e-01 0.014900960 0.6099702
z_2[1,4] 1.149322e+00 0.016503509 0.6165432
z_2[1,5] -4.060654e-01 0.015903545 0.6457641
z_2[1,6] 2.134038e-01 0.015484912 0.5889044
z_2[1,7] 5.544974e-02 0.015525774 0.6473537
z_2[1,8] -6.181450e-01 0.014747755 0.6660775
z_2[1,9] 5.290829e-01 0.015906044 0.5698653
z_2[1,10] 8.729432e-01 0.016571499 0.5810122
z_2[1,11] 7.845846e-01 0.017139551 0.6229427
z_2[1,12] -6.204374e-01 0.015044339 0.5690281
z_2[1,13] -9.586822e-01 0.015809491 0.6092595
z_2[1,14] -4.878422e-01 0.016512104 0.5412494
z_2[1,15] 8.643315e-01 0.017285002 0.6057127
z_2[1,16] -1.332695e+00 0.016735423 0.6309517
r_1_1[1] -1.144309e+00 0.040164351 1.1217041
r_1_1[2] -8.140889e-01 0.031927232 1.0065935
r_2_1[1] 3.782811e-02 0.006578949 0.2588802
r_2_1[2] -5.487705e-01 0.007058007 0.3202292
r_2_1[3] 1.352888e-01 0.006769787 0.2653286
r_2_1[4] 4.959403e-01 0.007115804 0.2696668
r_2_1[5] -1.719174e-01 0.007350662 0.2827683
r_2_1[6] 9.256252e-02 0.006880261 0.2568993
r_2_1[7] 2.529879e-02 0.007081619 0.2878717
r_2_1[8] -2.672047e-01 0.006906700 0.2965187
r_2_1[9] 2.256720e-01 0.006860215 0.2517440
r_2_1[10] 3.736942e-01 0.006813081 0.2510305
r_2_1[11] 3.369743e-01 0.007350852 0.2714495
r_2_1[12] -2.701398e-01 0.006758221 0.2507979
r_2_1[13] -4.192210e-01 0.007321869 0.2733448
r_2_1[14] -2.096125e-01 0.007437828 0.2380768
r_2_1[15] 3.655870e-01 0.007103823 0.2513691
r_2_1[16] -5.749828e-01 0.007158903 0.2695888
b_Intercept -1.997591e-01 0.039201459 1.1117260
prior_b -1.647482e-02 0.022177523 0.9896906
prior_Intercept 2.471019e-02 0.022081330 0.9972855
prior_sd_1_1 1.006901e+01 3.391166288 151.9934237
prior_sd_2_1 5.821483e+00 0.944451437 42.5894717
lp__ -1.285477e+03 0.216157036 5.1362030
I’m trying to figure out which Stan estimates map onto the brms estimates, and how. Does anyone know where this is documented?
I’m also unsure how the model is specified in the generated Stan code:
data {
  int<lower=1> N;  // number of observations
  int Y[N];  // response variable (0/1 outcome for bernoulli_logit_lpmf)
  int<lower=1> K;  // number of population-level effects, incl. the intercept column
  matrix[N, K] X;  // population-level design matrix; column 1 (intercept) is dropped in transformed data
  // data for group-level effects of ID 1
  // NOTE(review): with (1 | Condition / Group), ID 1 appears to be Condition
  // (2 levels in the fit) and ID 2 Condition:Group (16 levels) -- confirm
  // against standata(BRMS_Mod).
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation (index into r_1_1)
  // group-level predictor values (for an intercept-only term, presumably all 1)
  vector[N] Z_1_1;
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  int<lower=1> J_2[N];  // grouping indicator per observation (index into r_2_1)
  // group-level predictor values (for an intercept-only term, presumably all 1)
  vector[N] Z_2_1;
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  // Centre every non-intercept column of X. The model then estimates the
  // intercept of the centered design; generated quantities converts it back.
  int Kc = K - 1;  // number of population-level effects excluding column 1 of X
  matrix[N, Kc] Xc;  // columns 2..K of X, each minus its mean
  vector[Kc] means_X;  // mean of each of columns 2..K of X before centering
  for (k in 1:Kc) {
    means_X[k] = mean(X[, k + 1]);
    Xc[, k] = X[, k + 1] - means_X[k];
  }
}
parameters {
  // population-level effects; b[k] corresponds to the k-th non-intercept row
  // of the brms "Population-Level Effects" summary (compare the posterior means)
  vector[Kc] b;
  // temporary intercept for centered predictors; this is NOT the intercept
  // brms reports -- that one is b_Intercept in generated quantities
  real Intercept;
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  // standardized group-level effects: z_1[1][j] is level j's offset in SD
  // units; the actual offset is r_1_1[j] = sd_1[1] * z_1[1][j]
  vector[N_1] z_1[M_1];
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  // standardized group-level effects; actual offsets are r_2_1 = sd_2[1] * z_2[1]
  vector[N_2] z_2[M_2];
}
transformed parameters {
  // Actual group-level effects (non-centered parameterization): each level's
  // standardized effect z is scaled by its grouping factor's SD.
  vector[N_1] r_1_1 = sd_1[1] * z_1[1];  // one intercept offset per level of grouping factor 1
  vector[N_2] r_2_1 = sd_2[1] * z_2[1];  // one intercept offset per level of grouping factor 2
}
model {
  // linear predictor on the logit scale: centered population-level part ...
  vector[N] eta = Intercept + Xc * b;
  // ... plus each observation's group-level intercept offsets: J_* selects
  // the level, Z_*_1 is the group-level predictor value
  eta += r_1_1[J_1] .* Z_1_1 + r_2_1[J_2] .* Z_2_1;
  // priors including all constants
  target += normal_lpdf(b | 0, 1);
  target += normal_lpdf(Intercept | 0, 1);
  // subtracting the log-CCDF at 0 normalizes the Cauchy truncated to sd > 0
  target += cauchy_lpdf(sd_1[1] | 0, 1) - 1 * cauchy_lccdf(0 | 0, 1);
  target += std_normal_lpdf(z_1[1]);
  target += cauchy_lpdf(sd_2[1] | 0, 1) - 1 * cauchy_lccdf(0 | 0, 1);
  target += std_normal_lpdf(z_2[1]);
  // likelihood including all constants; skipped when sampling only the prior
  if (!prior_only) {
    target += bernoulli_logit_lpmf(Y | eta);
  }
}
generated quantities {
  // actual population-level intercept: undoes the column centering of
  // transformed data; this is the "Intercept" row of the brms summary
  real b_Intercept = Intercept - dot_product(means_X, b);
  // additionally draw samples from priors
  // (one prior_b draw serves all b[k], since they share normal(0,1));
  // NOTE: the order of the *_rng calls below fixes which random numbers
  // each quantity receives for a given seed
  real prior_b = normal_rng(0,1);
  real prior_Intercept = normal_rng(0,1);
  real prior_sd_1_1 = cauchy_rng(0,1);
  real prior_sd_2_1 = cauchy_rng(0,1);
  // use rejection sampling for truncated priors
  // (redraw until the SD draw is non-negative, matching sd_* <lower=0>)
  while (prior_sd_1_1 < 0) {
    prior_sd_1_1 = cauchy_rng(0,1);
  }
  while (prior_sd_2_1 < 0) {
    prior_sd_2_1 = cauchy_rng(0,1);
  }
}
I think `r_1_1` represents “Condition” and `r_2_1` is “Group”, but I’m unsure what `z_*` represents. Are the standardized estimates what brms shows in its summary? `J_*` and `Z_*_*` are all vectors of `1` for each observation.