Different results using QR parameterization in Hierarchical Model

Dear all,
I am trying to fit a hierarchical (varying-intercept) linear (Gaussian) model using two approaches:

  1. Without QR Reparameterization
  2. With QR Reparameterization

QR gives a large speed-up, but the estimated parameters are quite inconsistent between the two approaches, especially the varying intercepts (varying_a).

Data Setup

mpg <- ggplot2::mpg

dat <- list(
  N = nrow(mpg),
  J = length(unique(mpg$class)),
  K = 2,
  id = as.numeric(as.factor(mpg$class)),
  X = cbind(mpg$displ, mpg$year),
  y = mpg$hwy
)
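For reference, here is a minimal sketch of how each model below can be compiled and fit. I'm assuming cmdstanr, since the summary tables below match its fit$summary() output; the file name model1.stan is just a placeholder:

library(cmdstanr)

# Placeholder file name; save the Stan program for Model 1 below as "model1.stan"
mod1 <- cmdstan_model("model1.stan")
fit1 <- mod1$sample(data = dat, seed = 123, parallel_chains = 4)
fit1$summary()  # produces summary tibbles like the ones shown below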

Model 1 - Vanilla Hierarchical Linear Model

data {
  int<lower=1> N; //the number of observations
  int<lower=1> J; //the number of groups
  int<lower=1> K; //number of columns in the model matrix
  int<lower=1,upper=J> id[N]; //vector of group indices
  matrix[N,K] X; //the model matrix
  vector[N] y; //the response variable
}
parameters {
  vector[K] beta; //population-level regression coefficients
  vector[J] varying_a; //group-level regression intercepts
  real<lower=0> sigma; //residual error
  real<lower=0> sigma_a; //hierarchical intercept residual error
}
model {
  //priors
  beta ~ student_t(3,0,1); //weakly informative priors on the regression coefficients
  varying_a ~ normal(0, sigma_a); //prior on the group-level intercepts
  sigma_a ~ normal(0, 2.5 * sd(y)); //prior on the sd of the group-level intercepts
  sigma ~ gamma(2,0.1); //weakly informative prior, see section 6.9 of the Stan User's Guide
  
  //likelihood
  y ~ normal(varying_a[id] + X * beta, sigma);
}

Results Model 1

# A tibble: 12 x 10
   variable          mean    median      sd      mad        q5       q95  rhat ess_bulk ess_tail
   <chr>            <dbl>     <dbl>   <dbl>    <dbl>     <dbl>     <dbl> <dbl>    <dbl>    <dbl>
 1 lp__         -365.     -365.     2.46    2.33     -370.     -362.      1.00    1075.    1254.
 2 beta[1]        -2.26     -2.27   0.216   0.215      -2.62     -1.91    1.00    2433.    2235.
 3 beta[2]         0.0161    0.0161 0.00112 0.000960    0.0144    0.0180  1.01     667.     764.
 4 varying_a[1]    6.03      6.02   2.38    2.18        2.18      9.98    1.01     658.     790.
 5 varying_a[2]    1.31      1.35   2.14    1.80       -2.33      4.54    1.01     601.     692.
 6 varying_a[3]    1.65      1.68   2.12    1.83       -1.89      4.91    1.01     604.     669.
 7 varying_a[4]   -2.11     -2.04   2.21    1.92       -5.86      1.31    1.01     629.     756.
 8 varying_a[5]   -5.30     -5.27   2.14    1.88       -8.82     -2.05    1.01     599.     660.
 9 varying_a[6]    1.90      1.95   2.15    1.86       -1.64      5.19    1.01     616.     692.
10 varying_a[7]   -4.01     -3.98   2.13    1.83       -7.54     -0.717   1.01     590.     647.
11 sigma           2.75      2.75   0.127   0.131       2.55      2.97    1.00    2535.    2542.
12 sigma_a         5.09      4.65   1.99    1.52        2.88      8.72    1.00    1220.    1349.

Model 2 - QR Hierarchical Linear Model

data {
  int<lower=1> N; //the number of observations
  int<lower=1> J; //the number of groups
  int<lower=1> K; //number of columns in the model matrix
  int<lower=1,upper=J> id[N]; //vector of group indices
  matrix[N,K] X; //the model matrix
  vector[N] y; //the response variable
}
transformed data {
  matrix[N, K] Q_ast;
  matrix[K, K] R_ast;
  matrix[K, K] R_ast_inverse;
  
  // thin and scale the QR decomposition
  Q_ast = qr_thin_Q(X) * sqrt(N - 1);
  R_ast = qr_thin_R(X) / sqrt(N - 1);
  R_ast_inverse = inverse(R_ast);
}
parameters {
  vector[K] theta;      // coefficients on Q_ast
  vector[J] varying_a; //group-level regression intercepts
  real<lower=0> sigma; //residual error
  real<lower=0> sigma_a; //hierarchical intercept residual error
}
model {
  //priors
  theta ~ student_t(3,0,1); //weakly informative prior on the Q-scale coefficients
  varying_a ~ normal(0, sigma_a); //prior on the group-level intercepts
  sigma_a ~ normal(0, 2.5 * sd(y)); //prior on the sd of the group-level intercepts
  sigma ~ gamma(2,0.1); //weakly informative prior, see section 6.9 of the Stan User's Guide
  
  //likelihood
  y ~ normal(varying_a[id] + Q_ast * theta, sigma);
}
generated quantities {
  vector[K] beta; //reconstructed population-level regression coefficients
  
  beta = R_ast_inverse * theta; // coefficients on X
}
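As a sanity check on the transformed data block, here is a quick R sketch (my own; note that base R's qr() may pick different column signs than Stan's qr_thin_Q/qr_thin_R, but the product is the same) verifying the identity that makes Q_ast * theta equivalent to X * beta:

# Reproduce the scaled, thin QR decomposition and check X = Q_ast %*% R_ast,
# so X %*% beta == Q_ast %*% theta whenever theta = R_ast %*% beta
X <- dat$X
qr_X <- qr(X)
Q_ast <- qr.Q(qr_X) * sqrt(dat$N - 1)
R_ast <- qr.R(qr_X) / sqrt(dat$N - 1)
max(abs(X - Q_ast %*% R_ast))  # ~1e-12, i.e. zero up to floating-point error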

Results Model 2

# A tibble: 14 x 10
   variable          mean   median     sd    mad        q5      q95  rhat ess_bulk ess_tail
   <chr>            <dbl>    <dbl>  <dbl>  <dbl>     <dbl>    <dbl> <dbl>    <dbl>    <dbl>
 1 lp__          -374.    -374.    2.40   2.23   -379.     -371.     1.00    1736.    2231.
 2 theta[1]        -2.95    -2.95  0.269  0.264    -3.38     -2.50   1.00    2510.    2905.
 3 theta[2]         0.542    0.543 0.184  0.182     0.233     0.847  1.00    6600.    2627.
 4 varying_a[1]    30.9     30.9   1.35   1.35     28.7      33.1    1.00    4596.    2987.
 5 varying_a[2]    25.6     25.6   0.459  0.460    24.9      26.4    1.00    4037.    2917.
 6 varying_a[3]    26.0     26.0   0.443  0.440    25.2      26.7    1.00    5829.    2579.
 7 varying_a[4]    22.2     22.2   0.814  0.812    20.8      23.5    1.00    7733.    2993.
 8 varying_a[5]    19.1     19.1   0.508  0.498    18.2      19.9    1.00    4570.    2939.
 9 varying_a[6]    26.3     26.3   0.487  0.490    25.5      27.1    1.00    4149.    2981.
10 varying_a[7]    20.4     20.4   0.408  0.418    19.7      21.1    1.00    4044.    3126.
11 sigma            2.71     2.71  0.130  0.128     2.51      2.94   1.00    7141.    2749.
12 sigma_a         23.8     23.0   5.24   4.91     16.5      33.5    1.00    7110.    3158.
13 beta[1]         -2.34    -2.35  0.210  0.209    -2.69     -1.99   1.00    2522.    3030.
14 beta[2]          0.121    0.122 0.0413 0.0408    0.0523    0.190  1.00    6600.    2627.

What am I doing wrong? Any help would be appreciated.

After much tinkering, I've found the main problem: I was missing a population-level intercept alpha.
Without it, the grand mean of y has to be absorbed somewhere: in Model 1 the year column soaks it up through beta[2] (0.0161 × ~2003 ≈ 32), while in the QR model the tight prior on theta prevents that, so varying_a is pushed to around 19-31 and sigma_a is inflated.
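A quick R check of that diagnosis (the grand mean of y is what the varying intercepts in the QR fit were absorbing):

mean(dat$y)  # ~23.4, matching the varying_a values of roughly 19-31 above
sd(dat$y)    # ~6.0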

So in both models, the parameters and model blocks became:

parameters {
  real alpha; //population-level intercept
  vector[K] beta; //population-level regression coefficients
  vector[J] varying_alpha; //group-level regression intercepts
  real<lower=0> sigma; //model residual error
  real<lower=0> sigma_alpha; //standard error for the group-level regression intercepts
}
model {
  //priors
  alpha ~ normal(mean(y), 2.5 * sd(y));
  beta ~ student_t(3,0,1);
  varying_alpha ~ normal(0, sigma_alpha);
  sigma ~ exponential(1/sd(y));
  sigma_alpha ~ exponential(0.1);
  
  //likelihood
  y ~ normal(alpha + varying_alpha[id] + X * beta, sigma);
}

and in the QR model:

model {
  //priors (as above, but with theta in place of beta)
  theta ~ student_t(3,0,1);

  //likelihood
  y ~ normal(alpha + varying_alpha[id] + Q_ast * theta, sigma);
}
generated quantities {
  vector[K] beta; //reconstructed population-level regression coefficients
  beta = R_ast_inverse * theta; // coefficients on X
}

This gave a much better fit from both models, but the QR model is still vastly superior (especially for the alpha parameter; I don't know why), and the time to fit was also vastly improved (11.7 s vs. 2.5 s).
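For what it's worth, a minimal R sketch of why the QR version samples faster (my understanding: the raw predictor columns are correlated, while the columns of Q_ast are orthogonal, which gives the sampler a much easier geometry for theta):

# displ and year are correlated in X; Q_ast's columns are exactly orthogonal
X <- dat$X
Q_ast <- qr.Q(qr(X)) * sqrt(dat$N - 1)
cor(X)                                     # nonzero off-diagonal correlation
round(crossprod(Q_ast) / (dat$N - 1), 10)  # identity matrix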

Results Model 1

# A tibble: 13 x 10
   variable               mean     median       sd      mad         q5       q95  rhat ess_bulk ess_tail
   <chr>                 <dbl>      <dbl>    <dbl>    <dbl>      <dbl>     <dbl> <dbl>    <dbl>    <dbl>
 1 lp__             -368.      -368.       2.52     2.48    -373.      -365.      1.00    1487.    2205.
 2 alpha              15.9       15.8     14.9     14.8       -8.85      40.6     1.00    2095.    2080.
 3 beta[1]            -2.25      -2.26     0.207    0.208     -2.59      -1.91    1.00    3045.    2613.
 4 beta[2]             0.00809    0.00819  0.00753  0.00758   -0.00436    0.0207  1.00    2040.    1973.
 5 varying_alpha[1]    6.01       5.92     2.26     2.08       2.54       9.74    1.00     939.     840.
 6 varying_alpha[2]    1.38       1.34     1.97     1.73      -1.68       4.63    1.00     723.     803.
 7 varying_alpha[3]    1.72       1.68     1.95     1.72      -1.28       4.95    1.00     746.     701.
 8 varying_alpha[4]   -2.06      -2.10     2.03     1.83      -5.20       1.26    1.00     825.     897.
 9 varying_alpha[5]   -5.25      -5.25     1.97     1.77      -8.23      -1.98    1.00     765.     773.
10 varying_alpha[6]    1.96       1.91     1.96     1.77      -1.08       5.22    1.00     724.     731.
11 varying_alpha[7]   -3.94      -3.94     1.94     1.75      -6.89      -0.762   1.00     711.     731.
12 sigma               2.75       2.75     0.130    0.127      2.55       2.98    1.00    2819.    2079.
13 sigma_alpha         4.81       4.45     1.74     1.40       2.76       8.10    1.00    1732.    1609.

Results Model 2

# A tibble: 15 x 10
   variable               mean     median      sd      mad         q5        q95  rhat ess_bulk ess_tail
   <chr>                 <dbl>      <dbl>   <dbl>    <dbl>      <dbl>      <dbl> <dbl>    <dbl>    <dbl>
 1 lp__             -370.      -369.      2.66    2.56     -374.      -366.       1.00    1113.    1985.
 2 alpha              24.3       24.0     2.84    2.39       20.2       29.7      1.00     760.     720.
 3 theta[1]           -0.972     -0.571   2.06    1.31       -5.20       1.46     1.00    1168.     709.
 4 theta[2]            2.72       2.84    0.819   0.615       1.05       3.77     1.00    1197.     714.
 5 varying_alpha[1]    5.98       5.92    2.14    2.01        2.60       9.49     1.00    1169.    1378.
 6 varying_alpha[2]    1.40       1.42    1.85    1.65       -1.61       4.37     1.00     858.    1145.
 7 varying_alpha[3]    1.73       1.75    1.84    1.64       -1.27       4.62     1.00     867.    1163.
 8 varying_alpha[4]   -2.03      -2.04    1.94    1.79       -5.11       1.02     1.00     927.    1245.
 9 varying_alpha[5]   -5.24      -5.19    1.85    1.69       -8.23      -2.28     1.00     830.    1138.
10 varying_alpha[6]    1.99       2.00    1.88    1.73       -0.957      4.98     1.00     853.    1060.
11 varying_alpha[7]   -3.96      -3.94    1.84    1.67       -6.94      -1.08     1.00     835.    1008.
12 sigma               2.75       2.75    0.133   0.132       2.54       2.99     1.00    2412.    2212.
13 sigma_alpha         4.69       4.36    1.62    1.37        2.76       7.70     1.00    1972.    1851.
14 beta[1]            -2.23      -2.23    0.215   0.209      -2.58      -1.88     1.00    4066.    3173.
15 beta[2]             0.00389    0.00407 0.00117 0.000881    0.00150    0.00540  1.00    1197.     714.

The QR should be faster.

I assume the remaining differences in the results are because the priors are effectively different:

beta ~ student_t(3,0,1);

vs.

theta ~ student_t(3,0,1);

I think that because the variable transform here is constant, you can still write:

beta ~ student_t(3,0,1);

in the second model.
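To see how different the two priors really are, here is a minimal R sketch (my own illustration) of the prior on beta implied by placing student_t(3, 0, 1) on theta:

# theta = R_ast %*% beta, so prior draws of theta map back to beta via solve(R_ast)
set.seed(1)
R_ast <- qr.R(qr(dat$X)) / sqrt(dat$N - 1)
theta_draws <- matrix(rt(4000 * 2, df = 3), ncol = 2)  # theta ~ student_t(3, 0, 1)
beta_implied <- theta_draws %*% t(solve(R_ast))        # beta = R_ast^{-1} theta
apply(beta_implied, 2, mad)  # implied scales on beta are far from 1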

See 10.1 Changes of Variables in the Stan Reference Manual: a constant transform adds only a constant Jacobian determinant to the log density, and additive constants can be dropped from the log density while still sampling the right distribution.
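Spelled out (my restatement of that argument, using the R_ast from the model above): since theta = R_ast * beta and R_ast is fixed by the data,

log p_beta(beta) = log p_theta(R_ast * beta) + log |det(R_ast)|

and the Jacobian term log |det(R_ast)| involves no parameters, so dropping it leaves the posterior over beta unchanged.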
