I am attempting to fit a latent class analysis model using rstan. I have been generally following the case study here: http://mc-stan.org/users/documentation/case-studies/dina_independent.html, both model have latent discrete parameters that have to be marginalized.
Following the case study and using wide-format data, the model estimates in the expected way and recovers the parameters at an acceptable rate. However, the data that I wish to use has missing data, and therefore will need to be in long format with the missing responses filtered out (mentioned in this IRT case study: http://mc-stan.org/users/documentation/case-studies/tutorial_twopl.html). I created a new .stan
file to estimate using long-format data, but the parameter estimates are way off, leading me to conclude that something in my specification of the long-format model is incorrect. Can anyone offer guidance as to how the model should be specified differently?
Here is the code to generate and estimate the models:
Generate data:
library(tidyverse)
library(rstan)
rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores() - 1)
inv_logit <- function(x) {
exp(x) / (1 + exp(x))
}
set.seed(9416)
num_stu <- 500
num_item <- 10
att_mastery <- 0.4
stu_mastery <- data_frame(
student_id = seq_len(num_stu),
mastery = sample(c(0, 1), size = num_stu, replace = TRUE,
prob = c(1 - att_mastery, att_mastery))
)
items <- data_frame(
item_id = seq_len(num_item),
intercept = runif(num_item, -3, 0.3),
maineffect = runif(num_item, 0, 3)
)
response_data <- crossing(
student_id = seq_len(num_stu),
item_id = seq_len(num_item)
) %>%
left_join(stu_mastery, by = "student_id") %>%
left_join(items, by = "item_id") %>%
mutate(
probability = case_when(mastery == 1 ~ inv_logit(intercept + maineffect),
TRUE ~ inv_logit(intercept)),
random = runif(nrow(.)),
score = case_when(random <= probability ~ 1, TRUE ~ 0)
)
response_matrix <- response_data %>%
select(student_id, item_id, score) %>%
spread(key = item_id, value = score) %>%
select(-student_id) %>%
as.matrix()
Fit model with wide format data:
wide_data <- list(
I = num_item,
J = num_stu,
y = response_matrix
)
wide_model <- test_model <- stan(file = "lca_wide.stan", data = wide_data,
chains = 3, iter = 3000)
wide_summary <- summary(wide_model, pars = c("intercept", "maineffect"),
probs = c("0.025", "0.975"))$summary
wide_summary %>%
as.data.frame() %>%
rownames_to_column() %>%
select(rowname, mean) %>%
mutate(
parameter = map_chr(rowname, function(x) {
gsub("\\[.*\\]", "", x)
}),
item = map_dbl(rowname, function(x) {
as.numeric(gsub("[^\\d]+", "", x, perl=TRUE))
})
) %>%
left_join(gather(items, key = "parameter", value = "true", 2:3),
by = c("item" = "item_id", "parameter")) %>%
ggplot(aes(x = true, y = mean)) +
geom_abline(slope = 1, intercept = 0) +
geom_point() +
facet_wrap(~ parameter, scales = "free") +
labs(x = "true value", y = "estimated value") +
expand_limits(x = 0, y = 0)
Fit model with long format data:
response_matrix <- response_matrix %>%
as.data.frame() %>%
rowid_to_column() %>%
gather(key = item_id, value = score, `1`:`10`) %>%
as_data_frame() %>%
select(respondent = rowid, item = item_id, score) %>%
mutate(item = as.numeric(item))
long_data <- list(
I = num_item,
J = num_stu,
N = nrow(response_matrix),
ii = response_matrix$item,
jj = response_matrix$respondent,
y = response_matrix$score
)
long_model <- stan(file = "lca_long.stan", data = long_data, chains = 3,
iter = 3000)
long_summary <- summary(long_model, pars = c("intercept", "maineffect"),
probs = c("0.025", "0.975"))$summary
long_summary %>%
as.data.frame() %>%
rownames_to_column() %>%
select(rowname, mean) %>%
mutate(
parameter = map_chr(rowname, function(x) {
gsub("\\[.*\\]", "", x)
}),
item = map_dbl(rowname, function(x) {
as.numeric(gsub("[^\\d]+", "", x, perl=TRUE))
})
) %>%
left_join(gather(items, key = "parameter", value = "true", 2:3),
by = c("item" = "item_id", "parameter")) %>%
ggplot(aes(x = true, y = mean)) +
geom_abline(slope = 1, intercept = 0) +
geom_point() +
facet_wrap(~ parameter, scales = "free") +
expand_limits(x = 0, y = 0)
Stan code for wide format:
data {
int<lower=1> I; // number of items
int<lower=1> J; // number of respondents
matrix[J, I] y; // score for obs n
}
parameters {
simplex[2] nu;
// average intercept and main effect
real mean_intercept;
real<lower=0> mean_maineffect;
// item level deviations from average effects
real dev_intercept[I];
real<lower=-mean_maineffect> dev_maineffect[I];
}
transformed parameters {
vector[2] log_nu = log(nu);
real intercept[I];
real maineffect[I];
vector[I] master_pi;
vector[I] nonmaster_pi;
for (i in 1:I) {
intercept[i] = mean_intercept + dev_intercept[i];
maineffect[i] = mean_maineffect + dev_maineffect[i];
nonmaster_pi[i] = inv_logit(intercept[i]);
master_pi[i] = inv_logit(intercept[i] + maineffect[i]);
}
}
model{
real ps[2];
real log_items[I];
matrix[I,2] pi;
// Priors
mean_intercept ~ normal(0, 5);
mean_maineffect ~ normal(0, 5);
dev_intercept ~ normal(0, 3);
dev_maineffect ~ normal(0, 3);
// Probability of correct response for each class on each item
for (c in 1:2) {
for (i in 1:I) {
pi[i,c] = master_pi[i]^(c - 1) * nonmaster_pi[i]^(1 - (c - 1));
}
}
// Define model
for (j in 1:J) {
for (c in 1:2) {
for (i in 1:I) {
log_items[i] = y[j,i] * log(pi[i,c]) + (1 - y[j,i]) * log(1 - pi[i,c]);
}
ps[c] = log_nu[c] + sum(log_items);
}
target += log_sum_exp(ps);
}
}
Stan code for long-format data:
data {
int<lower=1> I; // number of items
int<lower=1> J; // number of respondents
int<lower=1> N; // number of observations
int<lower=1,upper=I> ii[N]; // item for obs n
int<lower=1,upper=J> jj[N]; // respondent for obs n
int<lower=0,upper=1> y[N]; // score for obs n
}
parameters {
simplex[2] nu;
// average intercept and main effect
real mean_intercept;
real<lower=0> mean_maineffect;
// item level deviations from average effects
real dev_intercept[I];
real<lower=-mean_maineffect> dev_maineffect[I];
}
transformed parameters {
vector[2] log_nu = log(nu);
real intercept[I];
real maineffect[I];
vector[I] master_pi;
vector[I] nonmaster_pi;
for (i in 1:I) {
intercept[i] = mean_intercept + dev_intercept[i];
maineffect[i] = mean_maineffect + dev_maineffect[i];
nonmaster_pi[i] = inv_logit(intercept[i]);
master_pi[i] = inv_logit(intercept[i] + maineffect[i]);
}
}
model{
real ps[2];
real log_items[N];
matrix[I,2] pi;
// Priors
mean_intercept ~ normal(0, 5);
mean_maineffect ~ normal(0, 5);
dev_intercept ~ normal(0, 3);
dev_maineffect ~ normal(0, 3);
// Probability of correct response for each class on each item
for (c in 1:2) {
for (i in 1:I) {
pi[i,c] = master_pi[i]^(c - 1) * nonmaster_pi[i]^(1 - (c - 1));
}
}
// Define model
for (n in 1:N) {
for (c in 1:2) {
ps[c] = log_nu[c] + (y[n] * log(pi[ii[n],c]) + (1 - y[n]) * log(1 - pi[ii[n],c]));
}
target += log_sum_exp(ps);
}
}
Session information:
Session info -----------------------------------------------------------------------
setting value
version R version 3.4.1 (2017-06-30)
system x86_64, darwin15.6.0
ui RStudio (1.1.334)
language (EN)
collate en_US.UTF-8
tz America/Chicago
date 2017-09-05
Packages ---------------------------------------------------------------------------
package * version date source
assertthat 0.2.0 2017-04-11 CRAN (R 3.4.0)
base * 3.4.1 2017-07-07 local
bindr 0.1 2016-11-13 cran (@0.1)
bindrcpp * 0.2 2017-06-17 cran (@0.2)
broom 0.4.2 2017-02-13 CRAN (R 3.4.0)
cellranger 1.1.0 2016-07-27 CRAN (R 3.4.0)
codetools 0.2-15 2016-10-05 CRAN (R 3.4.1)
colorspace 1.3-2 2016-12-14 CRAN (R 3.4.0)
compiler 3.4.1 2017-07-07 local
datasets * 3.4.1 2017-07-07 local
devtools 1.13.3 2017-08-02 cran (@1.13.3)
digest 0.6.12 2017-01-27 CRAN (R 3.4.0)
dplyr * 0.7.2 2017-07-20 CRAN (R 3.4.1)
forcats 0.2.0 2017-01-23 CRAN (R 3.4.0)
foreign 0.8-69 2017-06-22 CRAN (R 3.4.1)
ggplot2 * 2.2.1.9000 2017-09-05 Github (tidyverse/ggplot2@c592e32)
glue 1.1.1 2017-06-21 cran (@1.1.1)
graphics * 3.4.1 2017-07-07 local
grDevices * 3.4.1 2017-07-07 local
grid 3.4.1 2017-07-07 local
gridExtra 2.2.1 2016-02-29 CRAN (R 3.4.0)
gtable 0.2.0 2016-02-26 CRAN (R 3.4.0)
haven 1.1.0 2017-07-09 CRAN (R 3.4.1)
hms 0.3 2016-11-22 CRAN (R 3.4.0)
httr 1.3.1 2017-08-20 CRAN (R 3.4.1)
inline 0.3.14 2015-04-13 CRAN (R 3.4.0)
jsonlite 1.5 2017-06-01 cran (@1.5)
labeling 0.3 2014-08-23 CRAN (R 3.4.0)
lattice 0.20-35 2017-03-25 CRAN (R 3.4.1)
lazyeval 0.2.0 2016-06-12 CRAN (R 3.4.0)
lubridate 1.6.0 2016-09-13 CRAN (R 3.4.0)
magrittr 1.5 2014-11-22 CRAN (R 3.4.0)
memoise 1.1.0 2017-04-21 CRAN (R 3.4.0)
methods * 3.4.1 2017-07-07 local
mnormt 1.5-5 2016-10-15 CRAN (R 3.4.0)
modelr 0.1.1 2017-07-24 cran (@0.1.1)
munsell 0.4.3 2016-02-13 CRAN (R 3.4.0)
nlme 3.1-131 2017-02-06 CRAN (R 3.4.0)
parallel 3.4.1 2017-07-07 local
pkgconfig 2.0.1 2017-03-21 cran (@2.0.1)
plyr 1.8.4 2016-06-08 CRAN (R 3.4.0)
psych 1.7.5 2017-05-03 CRAN (R 3.4.1)
purrr * 0.2.3 2017-08-02 cran (@0.2.3)
R6 2.2.2 2017-06-17 cran (@2.2.2)
Rcpp 0.12.12 2017-07-15 CRAN (R 3.4.1)
readr * 1.1.1 2017-05-16 CRAN (R 3.4.0)
readxl 1.0.0 2017-04-18 CRAN (R 3.4.0)
reshape2 1.4.2 2016-10-22 CRAN (R 3.4.0)
rlang 0.1.2.9000 2017-08-30 Github (hadley/rlang@f20124b)
rstan * 2.16.2 2017-07-03 CRAN (R 3.4.1)
rvest 0.3.2 2016-06-17 CRAN (R 3.4.0)
scales 0.5.0.9000 2017-08-30 Github (hadley/scales@d767915)
StanHeaders * 2.16.0-1 2017-07-03 CRAN (R 3.4.1)
stats * 3.4.1 2017-07-07 local
stats4 3.4.1 2017-07-07 local
stringi 1.1.5 2017-04-07 CRAN (R 3.4.0)
stringr 1.2.0 2017-02-18 CRAN (R 3.4.0)
tibble * 1.3.4 2017-08-22 cran (@1.3.4)
tidyr * 0.7.0 2017-08-16 CRAN (R 3.4.1)
tidyselect 0.1.1 2017-07-24 CRAN (R 3.4.1)
tidyverse * 1.1.1 2017-01-27 CRAN (R 3.4.0)
tools 3.4.1 2017-07-07 local
utils * 3.4.1 2017-07-07 local
withr 2.0.0 2017-09-05 Github (jimhester/withr@eff4818)
xml2 1.1.1 2017-01-24 CRAN (R 3.4.0)
yaml 2.1.14 2016-11-12 CRAN (R 3.4.0)