Dear all,
I am trying to fit a hierarchical (varying-intercept) linear (Gaussian) model with two approaches:
- Without QR Reparameterization
- With QR Reparameterization
QR gives a substantial speed-up, but there are large inconsistencies in the estimated parameters, especially in the varying intercepts (varying_a).
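For reference, the QR model follows the thin, scaled QR reparameterization from the Stan User's Guide:

$$X = Q^\ast R^\ast, \qquad Q^\ast = Q\,\sqrt{N-1}, \qquad R^\ast = R / \sqrt{N-1},$$

so the linear predictor should be unchanged and the coefficients map back via

$$X\beta = Q^\ast\theta \quad\text{with}\quad \theta = R^\ast\beta, \qquad \beta = (R^\ast)^{-1}\theta.$$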
Data Setup
```r
mpg <- ggplot2::mpg
dat <- list(
  N  = nrow(mpg),
  J  = length(unique(mpg$class)),
  K  = 2,
  id = as.numeric(as.factor(mpg$class)),
  X  = cbind(mpg$displ, mpg$year),
  y  = mpg$hwy
)
```
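Both models are fit and summarized along these lines (a minimal sketch assuming cmdstanr and posterior; the file name, chains, and seed here are illustrative):

```r
library(cmdstanr)

# compile and sample; "model1.stan" is a placeholder file name
mod <- cmdstan_model("model1.stan")
fit <- mod$sample(data = dat, chains = 4, parallel_chains = 4, seed = 1)

# fit$summary() wraps posterior::summarise_draws() and produces
# the tibbles shown below
fit$summary()
```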
Model 1 - Vanilla Hierarchical Linear Model
```stan
data {
  int<lower=1> N;               // number of observations
  int<lower=1> J;               // number of groups
  int<lower=1> K;               // number of columns in the model matrix
  int<lower=1,upper=J> id[N];   // vector of group indices
  matrix[N,K] X;                // the model matrix
  vector[N] y;                  // the response variable
}
parameters {
  vector[K] beta;               // population-level regression coefficients
  vector[J] varying_a;          // group-level regression intercepts
  real<lower=0> sigma;          // residual error
  real<lower=0> sigma_a;        // sd of the group-level intercepts
}
model {
  // priors
  beta ~ student_t(3, 0, 1);          // weakly informative priors on the regression coefficients
  varying_a ~ normal(0, sigma_a);     // prior on the group-level intercepts
  sigma_a ~ normal(0, 2.5 * sd(y));   // hierarchical prior on the intercept sd
  sigma ~ gamma(2, 0.1);              // weakly informative prior, see section 6.9 in the Stan User's Guide
  // likelihood
  y ~ normal(varying_a[id] + X * beta, sigma);
}
```
Results Model 1
```
# A tibble: 12 x 10
   variable      mean    median   sd       mad       q5       q95     rhat ess_bulk ess_tail
   <chr>         <dbl>   <dbl>    <dbl>    <dbl>     <dbl>    <dbl>   <dbl>   <dbl>    <dbl>
 1 lp__         -365.   -365.     2.46     2.33   -370.    -362.      1.00   1075.    1254.
 2 beta[1]        -2.26   -2.27   0.216    0.215    -2.62    -1.91    1.00   2433.    2235.
 3 beta[2]         0.0161  0.0161 0.00112  0.000960  0.0144   0.0180  1.01    667.     764.
 4 varying_a[1]    6.03    6.02   2.38     2.18      2.18     9.98    1.01    658.     790.
 5 varying_a[2]    1.31    1.35   2.14     1.80     -2.33     4.54    1.01    601.     692.
 6 varying_a[3]    1.65    1.68   2.12     1.83     -1.89     4.91    1.01    604.     669.
 7 varying_a[4]   -2.11   -2.04   2.21     1.92     -5.86     1.31    1.01    629.     756.
 8 varying_a[5]   -5.30   -5.27   2.14     1.88     -8.82    -2.05    1.01    599.     660.
 9 varying_a[6]    1.90    1.95   2.15     1.86     -1.64     5.19    1.01    616.     692.
10 varying_a[7]   -4.01   -3.98   2.13     1.83     -7.54    -0.717   1.01    590.     647.
11 sigma           2.75    2.75   0.127    0.131     2.55     2.97    1.00   2535.    2542.
12 sigma_a         5.09    4.65   1.99     1.52      2.88     8.72    1.00   1220.    1349.
```
Model 2 - QR Hierarchical Linear Model
```stan
data {
  int<lower=1> N;               // number of observations
  int<lower=1> J;               // number of groups
  int<lower=1> K;               // number of columns in the model matrix
  int<lower=1,upper=J> id[N];   // vector of group indices
  matrix[N,K] X;                // the model matrix
  vector[N] y;                  // the response variable
}
transformed data {
  matrix[N, K] Q_ast;
  matrix[K, K] R_ast;
  matrix[K, K] R_ast_inverse;
  // thin and scale the QR decomposition
  Q_ast = qr_thin_Q(X) * sqrt(N - 1);
  R_ast = qr_thin_R(X) / sqrt(N - 1);
  R_ast_inverse = inverse(R_ast);
}
parameters {
  vector[K] theta;              // coefficients on Q_ast
  vector[J] varying_a;          // group-level regression intercepts
  real<lower=0> sigma;          // residual error
  real<lower=0> sigma_a;        // sd of the group-level intercepts
}
model {
  // priors
  theta ~ student_t(3, 0, 1);         // weakly informative priors on the Q-scale coefficients
  varying_a ~ normal(0, sigma_a);     // prior on the group-level intercepts
  sigma_a ~ normal(0, 2.5 * sd(y));   // hierarchical prior on the intercept sd
  sigma ~ gamma(2, 0.1);              // weakly informative prior, see section 6.9 in the Stan User's Guide
  // likelihood
  y ~ normal(varying_a[id] + Q_ast * theta, sigma);
}
generated quantities {
  vector[K] beta;                     // reconstructed population-level regression coefficients
  beta = R_ast_inverse * theta;       // coefficients on the original X scale
}
```
Results Model 2
```
# A tibble: 14 x 10
   variable      mean    median   sd      mad     q5       q95     rhat ess_bulk ess_tail
   <chr>         <dbl>   <dbl>    <dbl>   <dbl>   <dbl>    <dbl>   <dbl>   <dbl>    <dbl>
 1 lp__         -374.   -374.     2.40    2.23   -379.    -371.     1.00   1736.    2231.
 2 theta[1]       -2.95   -2.95   0.269   0.264    -3.38    -2.50   1.00   2510.    2905.
 3 theta[2]        0.542   0.543  0.184   0.182     0.233    0.847  1.00   6600.    2627.
 4 varying_a[1]   30.9    30.9    1.35    1.35     28.7     33.1    1.00   4596.    2987.
 5 varying_a[2]   25.6    25.6    0.459   0.460    24.9     26.4    1.00   4037.    2917.
 6 varying_a[3]   26.0    26.0    0.443   0.440    25.2     26.7    1.00   5829.    2579.
 7 varying_a[4]   22.2    22.2    0.814   0.812    20.8     23.5    1.00   7733.    2993.
 8 varying_a[5]   19.1    19.1    0.508   0.498    18.2     19.9    1.00   4570.    2939.
 9 varying_a[6]   26.3    26.3    0.487   0.490    25.5     27.1    1.00   4149.    2981.
10 varying_a[7]   20.4    20.4    0.408   0.418    19.7     21.1    1.00   4044.    3126.
11 sigma           2.71    2.71   0.130   0.128     2.51     2.94   1.00   7141.    2749.
12 sigma_a        23.8    23.0    5.24    4.91     16.5     33.5    1.00   7110.    3158.
13 beta[1]        -2.34   -2.35   0.210   0.209    -2.69    -1.99   1.00   2522.    3030.
14 beta[2]         0.121   0.122  0.0413  0.0408    0.0523   0.190  1.00   6600.    2627.
```
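To rule out the decomposition itself, I also checked in R that the thin, scaled QR factors reproduce X (a minimal sketch with base R's qr(); sign conventions may differ from Stan's qr_thin_Q, but the product Q_ast %*% R_ast is unaffected):

```r
qr_x  <- qr(dat$X)                       # base R QR (assumes no column pivoting is triggered)
Q_ast <- qr.Q(qr_x) * sqrt(dat$N - 1)    # thin Q, scaled as in the Stan model
R_ast <- qr.R(qr_x) / sqrt(dat$N - 1)    # upper-triangular R, scaled to match

# should be ~0 up to floating-point error
max(abs(Q_ast %*% R_ast - dat$X))
```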
What am I doing wrong? Any help would be appreciated.