I’m adapting code to Stan from a course taught in BUGS. We are tasked with comparing two simple models: a linear regression `y ~ x + g`, where `x` is continuous and `g` is a three-group categorical variable, and one with an interaction, `y ~ x + g + x * g`. I have included code for both models below. In the class, the professor has been using DIC, since that is the statistic BUGS reports by default.
I am trying to extract fit statistics to compare the two models, using `rstan` and `loo`. It appears that this community prefers leave-one-out cross-validation (LOO) and WAIC. However, I am having trouble computing `log_lik`. After looking at the documentation for `loo::extract_log_lik` and the code here (https://jrnold.github.io/bayesian_notes/introduction-to-stan-and-linear-regression.html), I believe my `log_lik` is specified correctly.
I know I am using vector notation instead of matrix notation, but I’m stumped as to where I’m going wrong (see the `generated quantities` blocks):
Model 1:
```stan
data {
  int<lower=0> n_obs;
  vector[n_obs] y;
  vector[n_obs] x;
  int g[n_obs];
}
transformed data {
  vector[n_obs] group2;
  vector[n_obs] group3;
  for (i in 1:n_obs) {
    group2[i] = g[i] == 2;
    group3[i] = g[i] == 3;
  }
}
parameters {
  vector[4] beta;
  real<lower=0> sigma;
}
transformed parameters {
  real beta43_diff;
  beta43_diff = beta[4] - beta[3];
}
model {
  vector[n_obs] yhat;
  for (i in 1:n_obs) {
    yhat[i] = beta[1] + beta[2] * x[i] + beta[3] * group2[i] + beta[4] * group3[i];
  }
  y ~ normal(yhat, sigma);
  sigma ~ cauchy(0, 2.5);
  beta ~ normal(0, 100);
}
generated quantities {
  vector[n_obs] log_lik;
  for (i in 1:n_obs)
    log_lik[i] = normal_lpdf(y[i] | yhat[i], sigma);
}
```
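A likely explanation is variable scope: `yhat` is declared inside the `model` block, so it is a local variable there and does not exist in `generated quantities`. A minimal sketch of one fix moves `yhat` into `transformed parameters` (and vectorizes it) so that both blocks can see it. It assumes Stan 2.26 or later for the `array[n_obs] int g;` declaration; older `rstan` versions would keep `int g[n_obs];`. Priors are unchanged from Model 1.

```stan
// Sketch: Model 1 with yhat declared in transformed parameters,
// which makes it visible to generated quantities (Stan 2.26+ syntax).
data {
  int<lower=0> n_obs;
  vector[n_obs] y;
  vector[n_obs] x;
  array[n_obs] int g;
}
transformed data {
  vector[n_obs] group2;
  vector[n_obs] group3;
  for (i in 1:n_obs) {
    group2[i] = g[i] == 2;
    group3[i] = g[i] == 3;
  }
}
parameters {
  vector[4] beta;
  real<lower=0> sigma;
}
transformed parameters {
  real beta43_diff = beta[4] - beta[3];
  // Linear predictor, now in scope for model and generated quantities
  vector[n_obs] yhat = beta[1] + beta[2] * x + beta[3] * group2 + beta[4] * group3;
}
model {
  y ~ normal(yhat, sigma);
  sigma ~ cauchy(0, 2.5);
  beta ~ normal(0, 100);
}
generated quantities {
  // Pointwise log-likelihood, read by loo::extract_log_lik(fit, "log_lik")
  vector[n_obs] log_lik;
  for (i in 1:n_obs)
    log_lik[i] = normal_lpdf(y[i] | yhat[i], sigma);
}
```

The trade-off is that everything declared in `transformed parameters` is saved for every draw, so this adds `n_obs` columns of `yhat` to the fit object.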
Model 2:
```stan
data {
  int<lower=0> n_obs;
  vector[n_obs] y;
  vector[n_obs] x;
  int g[n_obs];
}
transformed data {
  vector[n_obs] group2;
  vector[n_obs] group3;
  for (i in 1:n_obs) {
    group2[i] = g[i] == 2;
    group3[i] = g[i] == 3;
  }
}
parameters {
  vector[6] beta;
  real<lower=0> sigma;
}
model {
  vector[n_obs] yhat;
  for (i in 1:n_obs) {
    yhat[i] = beta[1] + beta[2] * x[i] +
              beta[3] * group2[i] + beta[4] * group3[i] +
              beta[5] * x[i] * group2[i] + beta[6] * x[i] * group3[i];
  }
  y ~ normal(yhat, sigma);
  sigma ~ cauchy(0, 2.5);
  beta ~ normal(0, 100);
}
generated quantities {
  vector[n_obs] log_lik;
  for (i in 1:n_obs)
    log_lik[i] = normal_lpdf(y[i] | yhat[i], sigma);
}
```
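Since matrix notation came up, here is a sketch of Model 2 written with a design matrix that also avoids the scoping problem in a different way: instead of sharing `yhat`, the linear predictor is recomputed inside `generated quantities`, in a nested local block so it is not written to the output. The names `K` and `X` are introduced here for illustration, and the `array[n_obs] int g;` declaration assumes Stan 2.26 or later.

```stan
// Sketch: Model 2 in matrix notation; log_lik recomputes X * beta locally.
data {
  int<lower=0> n_obs;
  vector[n_obs] y;
  vector[n_obs] x;
  array[n_obs] int g;
}
transformed data {
  int K = 6;
  // Columns: intercept, x, group 2 dummy, group 3 dummy, x:g2, x:g3
  matrix[n_obs, K] X;
  for (i in 1:n_obs) {
    real g2 = g[i] == 2;
    real g3 = g[i] == 3;
    X[i] = [1.0, x[i], g2, g3, x[i] * g2, x[i] * g3];
  }
}
parameters {
  vector[K] beta;
  real<lower=0> sigma;
}
model {
  y ~ normal(X * beta, sigma);
  sigma ~ cauchy(0, 2.5);
  beta ~ normal(0, 100);
}
generated quantities {
  vector[n_obs] log_lik;
  {
    // Variables in a nested block are local and are not saved per draw
    vector[n_obs] yhat = X * beta;
    for (i in 1:n_obs)
      log_lik[i] = normal_lpdf(y[i] | yhat[i], sigma);
  }
}
```

Recomputing costs one matrix-vector product per draw, which is usually cheap compared with storing `n_obs` extra columns.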
If I comment out the `generated quantities` blocks, the code runs fine. What am I misspecifying? I have included the data (from `dput`) below.
Additionally, is there a general way of specifying `log_lik` that is easily comparable across different types of models (e.g., multilevel models, different link functions)? If not, I think it would be a great feature for `rstan` to calculate the information needed for these fit statistics; however, I understand if this is more complicated than I am giving it credit for.
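On the general question, the usual convention (and the default `parameter_name = "log_lik"` in `loo::extract_log_lik`) is that `log_lik[i]` holds the log of the same observation-level likelihood term that appears in the `model` block, whatever the family or link. The sketch below illustrates that with a hypothetical varying-intercept logistic regression (binary `y`, not the course data); the parameter names and priors are placeholders. One caveat: because `log_lik` is conditional on the group effects, LOO on it approximates leaving out single observations, not whole groups.

```stan
// Sketch (hypothetical model): varying-intercept logistic regression.
// log_lik uses the same bernoulli_logit likelihood as the model block.
data {
  int<lower=0> n_obs;
  int<lower=1> n_grp;
  array[n_obs] int<lower=0, upper=1> y;
  vector[n_obs] x;
  array[n_obs] int<lower=1, upper=n_grp> g;
}
parameters {
  real alpha;
  real beta;
  vector[n_grp] z;     // standardized group effects
  real<lower=0> tau;   // between-group standard deviation
}
transformed parameters {
  // Linear predictor on the logit scale, one entry per observation
  vector[n_obs] eta = alpha + beta * x + tau * z[g];
}
model {
  z ~ std_normal();
  tau ~ normal(0, 1);
  alpha ~ normal(0, 2.5);
  beta ~ normal(0, 2.5);
  y ~ bernoulli_logit(eta);
}
generated quantities {
  vector[n_obs] log_lik;
  for (i in 1:n_obs)
    log_lik[i] = bernoulli_logit_lpmf(y[i] | eta[i]);
}
```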
structure(list(n_obs = 200L, y = c(82L, 79L, 127L, 85L, 84L,
114L, 75L, 112L, 101L, 87L, 104L, 103L, 106L, 123L, 82L, 102L,
76L, 119L, 129L, 84L, 101L, 116L, 82L, 99L, 95L, 82L, 124L, 92L,
120L, 83L, 97L, 115L, 103L, 112L, 99L, 107L, 96L, 119L, 96L,
73L, 113L, 114L, 100L, 114L, 82L, 96L, 102L, 118L, 79L, 91L,
127L, 95L, 87L, 133L, 91L, 121L, 118L, 117L, 99L, 120L, 80L,
75L, 76L, 95L, 104L, 106L, 118L, 96L, 102L, 135L, 101L, 119L,
103L, 125L, 101L, 106L, 131L, 100L, 83L, 96L, 106L, 122L, 76L,
105L, 109L, 108L, 72L, 108L, 104L, 97L, 95L, 90L, 135L, 105L,
93L, 103L, 88L, 99L, 119L, 106L, 93L, 90L, 93L, 89L, 110L, 97L,
69L, 106L, 94L, 73L, 108L, 115L, 101L, 100L, 81L, 92L, 96L, 112L,
111L, 74L, 103L, 113L, 98L, 110L, 97L, 108L, 88L, 84L, 110L,
97L, 132L, 95L, 92L, 92L, 90L, 107L, 117L, 79L, 130L, 124L, 116L,
123L, 82L, 100L, 102L, 86L, 102L, 113L, 79L, 113L, 108L, 111L,
100L, 94L, 98L, 110L, 90L, 118L, 70L, 88L, 93L, 97L, 104L, 99L,
76L, 111L, 87L, 116L, 99L, 100L, 91L, 94L, 96L, 93L, 104L, 121L,
73L, 100L, 103L, 117L, 108L, 99L, 96L, 111L, 96L, 88L, 123L,
95L, 87L, 90L, 90L, 111L, 86L, 86L, 100L, 97L, 95L, 103L, 124L,
108L), x = c(-1.61778110598111, -1.41930244948108, 1.66055088814571,
-0.877834399732974, -0.514547487661916, -0.193145998933965, -0.983695115804797,
0.513319581703684, -0.860280368023465, -1.88155404296431, 0.339182014456294,
0.442938863359524, 0.759015596456956, 0.599229113883343, -1.92479433732238,
-0.184039203388914, -1.61397117479425, 0.962066579913893, 2.46414076460808,
-1.17607277971075, -0.272868824943055, 1.35967097940892, -1.05438362825567,
-0.228787420019154, -0.292992101740388, -1.82954098345819, 2.08519534735967,
0.591107398640649, 1.95579014918031, -0.65667842751543, 0.0997133328373552,
1.5252950891141, -0.0926206198821385, 0.902974476600102, 0.249136489148938,
-0.0830554643004599, 0.549419221786291, 1.71525303107621, -0.850629708048084,
-1.98109122606004, 1.1222263062255, 1.11184843718797, -0.275291338001666,
1.07060421189662, -1.43309225350286, -1.13249022646723, 0.532765933295707,
0.752377526984537, -1.51346118188762, -1.17005988275818, 1.9165727402041,
-0.784129344335195, 0.387125876635209, 0.729149295988579, 0.341384258953968,
2.16299269693693, 1.3422990840893, 1.85248225327614, 0.0491641778834516,
0.99144115233223, -0.890310948536989, -1.11569362896512, -1.41715180607983,
-0.738615505051399, -0.00459357194037879, 1.13696375696993, -0.506374974084071,
0.252297985832738, -0.368866868114261, 1.08774792247216, -0.341718291299489,
0.15639183470376, 0.409860225688048, 1.89837230163239, -0.520120130525424,
1.97000596158883, 1.67776757200812, -0.168419663083854, -0.781874943580764,
-0.514731489700144, 0.0589933531189247, 0.227893986184042, -2.4374109865297,
0.171513222362348, 0.903167999470544, -0.0344474632158082, -1.03997041718748,
-0.950822657655817, -0.705289953893956, -0.653414563552836, -0.190110785154354,
-0.354290358646747, 2.39976174652172, -1.35617767206351, -1.22221321624565,
-0.454598225635334, -0.990033092897734, -0.244847502676501, 0.505704018544499,
0.669561573651021, 0.324323516964131, -0.076309426351195, -0.458350076494571,
-1.83698987236975, 0.0803829331577828, 0.87788869488688, 2.40273116380173,
-0.0670704513110104, 1.09911896218748, 0.997356439322442, -0.435618412408813,
-1.67972593760864, 1.04078761388651, -0.470079835122939, 0.574807777314493,
0.102452519681578, 0.957329543824806, -1.31012959393378, -0.369434151260721,
0.848718206044066, 0.65304393518083, -0.157812135445461, -0.643681461075948,
1.37852013777139, -0.981482791650922, -0.194101148603517, 1.20084662370845,
-0.327739108052317, -0.815923921147731, -0.364336156241106, 0.412741737882205,
-1.06602917308272, 0.585301017920321, 1.07595977751635, -0.913953240303675,
-0.960950830448768, -2.3264600669633, 1.06706497395435, -0.628006453541378,
-0.0829198446716328, -1.65412177364532, -2.08977332907322, 0.553902505801274,
-0.292075055808902, -0.233261312573706, 0.751055086579009, -0.114927863866158,
0.374181429758709, 2.52238371229229, 0.107655809876072, -1.02102044562293,
-1.07644673407697, 0.00571436701147894, 0.495689755744636, -0.221297772945544,
-0.632483960431338, 0.424388055322506, 0.234124916335216, -0.163031665697092,
-0.3342207825213, -0.333464178669596, 1.13597895106278, 0.0802457519807691,
0.442004414767879, 0.0631890969970137, -0.928784884004674, 0.600321357042254,
0.113300847973133, -1.62841351193006, -1.33808720350458, 0.523483509043114,
-1.49571955064019, -0.601478311316166, -1.2554232917663, 1.36126110478246,
0.200018774239335, -0.620303665440682, -0.440943383259836, -0.317157817670251,
3.34334429646825e-05, 0.289794566256829, 0.968321868754466, 1.76731991341036,
0.536365489366773, -1.29110416808371, -0.0950369217019353, -1.36594489408091,
0.771589976670854, 1.46438162528982, 1.0807684713817, -0.225632514269493,
-0.584571340561916, 0.508562770886174, -1.11150779418741, -0.539016239263928,
-0.904335948222398, -0.215635290995049, -1.0022074154576, 0.599938732154603,
0.830811853601161), g = c(1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L,
1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 2L, 2L, 2L, 2L, 2L,
2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L,
2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L,
2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L,
2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L,
2L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L,
3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L, 3L)), .Names = c("n_obs",
"y", "x", "g"))