Latent class model estimation with long-format data

I am attempting to fit a latent class analysis model using rstan. I have been generally following the case study here: http://mc-stan.org/users/documentation/case-studies/dina_independent.html, both model have latent discrete parameters that have to be marginalized.

Following the case study and using wide-format data, the model estimates in the expected way and recovers the parameters at an acceptable rate. However, the data that I wish to use has missing data, and therefore will need to be in long format with the missing responses filtered out (mentioned in this IRT case study: http://mc-stan.org/users/documentation/case-studies/tutorial_twopl.html). I created a new .stan file to estimate using long-format data, but the parameter estimates are way off, leading me to conclude that something in my specification of the long-format model is incorrect. Can anyone offer guidance as to how the model should be specified differently?

Here is the code to generate and estimate the models:

Generate data:

library(tidyverse)
library(rstan)

rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores() - 1)

inv_logit <- function(x) {
  exp(x) / (1 + exp(x))
}

set.seed(9416)

num_stu <- 500
num_item <- 10

att_mastery <- 0.4

stu_mastery <- data_frame(
  student_id = seq_len(num_stu),
  mastery = sample(c(0, 1), size = num_stu, replace = TRUE,
    prob = c(1 - att_mastery, att_mastery))
)

items <- data_frame(
  item_id = seq_len(num_item),
  intercept = runif(num_item, -3, 0.3),
  maineffect = runif(num_item, 0, 3)
)

response_data <- crossing(
  student_id = seq_len(num_stu),
  item_id = seq_len(num_item)
) %>%
  left_join(stu_mastery, by = "student_id") %>%
  left_join(items, by = "item_id") %>%
  mutate(
    probability = case_when(mastery == 1 ~ inv_logit(intercept + maineffect),
      TRUE ~ inv_logit(intercept)),
    random = runif(nrow(.)),
    score = case_when(random <= probability ~ 1, TRUE ~ 0)
  )

response_matrix <- response_data %>%
  select(student_id, item_id, score) %>%
  spread(key = item_id, value = score) %>%
  select(-student_id) %>%
  as.matrix()

Fit model with wide format data:

wide_data <- list(
  I = num_item,
  J = num_stu,
  y = response_matrix
)

wide_model <- test_model <- stan(file = "lca_wide.stan", data = wide_data,
  chains = 3, iter = 3000)

wide_summary <- summary(wide_model, pars = c("intercept", "maineffect"),
  probs = c("0.025", "0.975"))$summary

wide_summary %>%
  as.data.frame() %>%
  rownames_to_column() %>%
  select(rowname, mean) %>%
  mutate(
    parameter = map_chr(rowname, function(x) {
      gsub("\\[.*\\]", "", x)
    }),
    item = map_dbl(rowname, function(x) {
      as.numeric(gsub("[^\\d]+", "", x, perl=TRUE))
    })
  ) %>%
  left_join(gather(items, key = "parameter", value = "true", 2:3),
    by = c("item" = "item_id", "parameter")) %>%
  ggplot(aes(x = true, y = mean)) +
  geom_abline(slope = 1, intercept = 0) +
  geom_point() +
  facet_wrap(~ parameter, scales = "free") +
  labs(x = "true value", y = "estimated value") +
  expand_limits(x = 0, y = 0)

Fit model with long format data:

response_matrix <- response_matrix %>%
  as.data.frame() %>%
  rowid_to_column() %>%
  gather(key = item_id, value = score, `1`:`10`) %>%
  as_data_frame() %>%
  select(respondent = rowid, item = item_id, score) %>%
  mutate(item = as.numeric(item))

long_data <- list(
  I = num_item,
  J = num_stu,
  N = nrow(response_matrix),
  ii = response_matrix$item,
  jj = response_matrix$respondent,
  y = response_matrix$score
)

long_model <- stan(file = "lca_long.stan", data = long_data, chains = 3,
  iter = 3000)

long_summary <- summary(long_model, pars = c("intercept", "maineffect"),
  probs = c("0.025", "0.975"))$summary

long_summary %>%
  as.data.frame() %>%
  rownames_to_column() %>%
  select(rowname, mean) %>%
  mutate(
    parameter = map_chr(rowname, function(x) {
      gsub("\\[.*\\]", "", x)
    }),
    item = map_dbl(rowname, function(x) {
      as.numeric(gsub("[^\\d]+", "", x, perl=TRUE))
    })
  ) %>%
  left_join(gather(items, key = "parameter", value = "true", 2:3),
    by = c("item" = "item_id", "parameter")) %>%
  ggplot(aes(x = true, y = mean)) +
  geom_abline(slope = 1, intercept = 0) +
  geom_point() +
  facet_wrap(~ parameter, scales = "free") +
  expand_limits(x = 0, y = 0)

Stan code for wide format:

data {
  int<lower=1> I;                                     // number of items
  int<lower=1> J;                                     // number of respondents
  matrix[J, I] y;                                     // score for obs n
}
parameters {
  simplex[2] nu;
  
  // average intercept and main effect
  real mean_intercept;
  real<lower=0> mean_maineffect;
  
  // item level deviations from average effects
  real dev_intercept[I];
  real<lower=-mean_maineffect> dev_maineffect[I];
}
transformed parameters {
  vector[2] log_nu = log(nu);
  
  real intercept[I];
  real maineffect[I];
  vector[I] master_pi;
  vector[I] nonmaster_pi;
  
  for (i in 1:I) {
    intercept[i] = mean_intercept + dev_intercept[i];
    maineffect[i] = mean_maineffect + dev_maineffect[i];
    
    nonmaster_pi[i] = inv_logit(intercept[i]);
    master_pi[i] = inv_logit(intercept[i] + maineffect[i]);
  }
}
model{
  real ps[2];
  real log_items[I];
  matrix[I,2] pi;
  
  // Priors
  mean_intercept ~ normal(0, 5);
  mean_maineffect ~ normal(0, 5);
  
  dev_intercept ~ normal(0, 3);
  dev_maineffect ~ normal(0, 3);
  
  // Probability of correct response for each class on each item
  for (c in 1:2) {
    for (i in 1:I) {
      pi[i,c] = master_pi[i]^(c - 1) * nonmaster_pi[i]^(1 - (c - 1));
    }
  }
  
  // Define model
  for (j in 1:J) {
    for (c in 1:2) {
      for (i in 1:I) {
        log_items[i] = y[j,i] * log(pi[i,c]) + (1 - y[j,i]) * log(1 - pi[i,c]);
      }
      ps[c] = log_nu[c] + sum(log_items);
    }
    target += log_sum_exp(ps);
  }
}

Stan code for long-format data:

data {
  int<lower=1> I;                                     // number of items
  int<lower=1> J;                                     // number of respondents
  int<lower=1> N;                                     // number of observations
  int<lower=1,upper=I> ii[N];                         // item for obs n
  int<lower=1,upper=J> jj[N];                         // respondent for obs n
  int<lower=0,upper=1> y[N];                          // score for obs n
}
parameters {
  simplex[2] nu;
  
  // average intercept and main effect
  real mean_intercept;
  real<lower=0> mean_maineffect;
  
  // item level deviations from average effects
  real dev_intercept[I];
  real<lower=-mean_maineffect> dev_maineffect[I];
}
transformed parameters {
  vector[2] log_nu = log(nu);
  
  real intercept[I];
  real maineffect[I];
  vector[I] master_pi;
  vector[I] nonmaster_pi;
  
  for (i in 1:I) {
    intercept[i] = mean_intercept + dev_intercept[i];
    maineffect[i] = mean_maineffect + dev_maineffect[i];
    
    nonmaster_pi[i] = inv_logit(intercept[i]);
    master_pi[i] = inv_logit(intercept[i] + maineffect[i]);
  }
}
model{
  real ps[2];
  real log_items[N];
  matrix[I,2] pi;
  
  // Priors
  mean_intercept ~ normal(0, 5);
  mean_maineffect ~ normal(0, 5);
  
  dev_intercept ~ normal(0, 3);
  dev_maineffect ~ normal(0, 3);
  
  // Probability of correct response for each class on each item
  for (c in 1:2) {
    for (i in 1:I) {
      pi[i,c] = master_pi[i]^(c - 1) * nonmaster_pi[i]^(1 - (c - 1));
    }
  }
  
  // Define model
  for (n in 1:N) {
    for (c in 1:2) {
      ps[c] = log_nu[c] + (y[n] * log(pi[ii[n],c]) + (1 - y[n]) * log(1 - pi[ii[n],c]));
    }
    target += log_sum_exp(ps);
  }
}

Session information:

Session info -----------------------------------------------------------------------
 setting  value                       
 version  R version 3.4.1 (2017-06-30)
 system   x86_64, darwin15.6.0        
 ui       RStudio (1.1.334)           
 language (EN)                        
 collate  en_US.UTF-8                 
 tz       America/Chicago             
 date     2017-09-05                  

Packages ---------------------------------------------------------------------------
 package     * version    date       source                            
 assertthat    0.2.0      2017-04-11 CRAN (R 3.4.0)                    
 base        * 3.4.1      2017-07-07 local                             
 bindr         0.1        2016-11-13 cran (@0.1)                       
 bindrcpp    * 0.2        2017-06-17 cran (@0.2)                       
 broom         0.4.2      2017-02-13 CRAN (R 3.4.0)                    
 cellranger    1.1.0      2016-07-27 CRAN (R 3.4.0)                    
 codetools     0.2-15     2016-10-05 CRAN (R 3.4.1)                    
 colorspace    1.3-2      2016-12-14 CRAN (R 3.4.0)                    
 compiler      3.4.1      2017-07-07 local                             
 datasets    * 3.4.1      2017-07-07 local                             
 devtools      1.13.3     2017-08-02 cran (@1.13.3)                    
 digest        0.6.12     2017-01-27 CRAN (R 3.4.0)                    
 dplyr       * 0.7.2      2017-07-20 CRAN (R 3.4.1)                    
 forcats       0.2.0      2017-01-23 CRAN (R 3.4.0)                    
 foreign       0.8-69     2017-06-22 CRAN (R 3.4.1)                    
 ggplot2     * 2.2.1.9000 2017-09-05 Github (tidyverse/ggplot2@c592e32)
 glue          1.1.1      2017-06-21 cran (@1.1.1)                     
 graphics    * 3.4.1      2017-07-07 local                             
 grDevices   * 3.4.1      2017-07-07 local                             
 grid          3.4.1      2017-07-07 local                             
 gridExtra     2.2.1      2016-02-29 CRAN (R 3.4.0)                    
 gtable        0.2.0      2016-02-26 CRAN (R 3.4.0)                    
 haven         1.1.0      2017-07-09 CRAN (R 3.4.1)                    
 hms           0.3        2016-11-22 CRAN (R 3.4.0)                    
 httr          1.3.1      2017-08-20 CRAN (R 3.4.1)                    
 inline        0.3.14     2015-04-13 CRAN (R 3.4.0)                    
 jsonlite      1.5        2017-06-01 cran (@1.5)                       
 labeling      0.3        2014-08-23 CRAN (R 3.4.0)                    
 lattice       0.20-35    2017-03-25 CRAN (R 3.4.1)                    
 lazyeval      0.2.0      2016-06-12 CRAN (R 3.4.0)                    
 lubridate     1.6.0      2016-09-13 CRAN (R 3.4.0)                    
 magrittr      1.5        2014-11-22 CRAN (R 3.4.0)                    
 memoise       1.1.0      2017-04-21 CRAN (R 3.4.0)                    
 methods     * 3.4.1      2017-07-07 local                             
 mnormt        1.5-5      2016-10-15 CRAN (R 3.4.0)                    
 modelr        0.1.1      2017-07-24 cran (@0.1.1)                     
 munsell       0.4.3      2016-02-13 CRAN (R 3.4.0)                    
 nlme          3.1-131    2017-02-06 CRAN (R 3.4.0)                    
 parallel      3.4.1      2017-07-07 local                             
 pkgconfig     2.0.1      2017-03-21 cran (@2.0.1)                     
 plyr          1.8.4      2016-06-08 CRAN (R 3.4.0)                    
 psych         1.7.5      2017-05-03 CRAN (R 3.4.1)                    
 purrr       * 0.2.3      2017-08-02 cran (@0.2.3)                     
 R6            2.2.2      2017-06-17 cran (@2.2.2)                     
 Rcpp          0.12.12    2017-07-15 CRAN (R 3.4.1)                    
 readr       * 1.1.1      2017-05-16 CRAN (R 3.4.0)                    
 readxl        1.0.0      2017-04-18 CRAN (R 3.4.0)                    
 reshape2      1.4.2      2016-10-22 CRAN (R 3.4.0)                    
 rlang         0.1.2.9000 2017-08-30 Github (hadley/rlang@f20124b)     
 rstan       * 2.16.2     2017-07-03 CRAN (R 3.4.1)                    
 rvest         0.3.2      2016-06-17 CRAN (R 3.4.0)                    
 scales        0.5.0.9000 2017-08-30 Github (hadley/scales@d767915)    
 StanHeaders * 2.16.0-1   2017-07-03 CRAN (R 3.4.1)                    
 stats       * 3.4.1      2017-07-07 local                             
 stats4        3.4.1      2017-07-07 local                             
 stringi       1.1.5      2017-04-07 CRAN (R 3.4.0)                    
 stringr       1.2.0      2017-02-18 CRAN (R 3.4.0)                    
 tibble      * 1.3.4      2017-08-22 cran (@1.3.4)                     
 tidyr       * 0.7.0      2017-08-16 CRAN (R 3.4.1)                    
 tidyselect    0.1.1      2017-07-24 CRAN (R 3.4.1)                    
 tidyverse   * 1.1.1      2017-01-27 CRAN (R 3.4.0)                    
 tools         3.4.1      2017-07-07 local                             
 utils       * 3.4.1      2017-07-07 local                             
 withr         2.0.0      2017-09-05 Github (jimhester/withr@eff4818)  
 xml2          1.1.1      2017-01-24 CRAN (R 3.4.0)                    
 yaml          2.1.14     2016-11-12 CRAN (R 3.4.0)   

I think the two for loops at the end are different. This (in the long form model) is not what you wrote in the wide format:

// Define model
  for (n in 1:N) {
    for (c in 1:2) {
      ps[c] = log_nu[c] + (y[n] * log(pi[ii[n],c]) + (1 - y[n]) * log(1 - pi[ii[n],c]));
    }
    target += log_sum_exp(ps);
  }

This is what the long-format code would look like if it were in the wide format, which seems wrong based on the number of times target gets incremented (the long-format probability is the product of I * J things, the wide-format is just J things). Maybe this is an easier way to see the difference?

for (j in 1:J) {
  for (i in 1:I) {
    for (c in 1:2) {
      ps[c] = log_nu[c] + y[j,i] * log(pi[i,c]) + (1 - y[j,i]) * log(1 - pi[i,c]);
    }
    target += log_sum_exp(ps);
  }
}

What you want is the other direction though, haha. There’s probably a ragged array way of working with this.

Probably start by sorting your y data first by blocks of j, and then building a couple new arrays (calling them s and l here) that are the start index and length of the blocks of j. For example:

jj: 1 1 1 1 2 2 2 3 3 3 3
ii: 1 2 5 7 1 2 3 2 5 7 8
s: 1 5 8
l:  4 3 4
y: ... 

And then write something like:

for(k in 1:size(s)) {
  int j = jj[s];
  for(c in 1:2) {
    real log_items[l[k]];
    for(m in s[k] : s[k] + l[k]) {
      int i = ii[m];
      log_items[m] = y[k] * log(pi[i,c]) + (1 - y[k]) * log(1 - pi[i,c]);
    }
    ps[c] = log_nu[c] + sum(log_items);
  }
  target += log_sum_exp(ps);
}

Look for the ragged array examples in the manual. This stuff can be confusing. There’s always little annoying issues with things like the log_items array would be different sizes for each j (not 100% sure what I wrote would work – but you could make it an array of length max_size where max_size is the maximum length block of j indexes and just be careful about rezeroing it).

Hope that helps! Thanks for getting the format really nice on your question. Makes it a lot easier to parse what’s going on.

This is exactly what I needed! I had to tweak the model definition you provided just a little bit. Here is what ended up working (with l and s defined as suggested by @bbbales2):

  // Define model
  for (j in 1:J) {
    for (c in 1:2) {
      real log_items[l[j]];
      for (m in 1:l[j]) {
        int i = ii[s[j] + m - 1];
        log_items[m] = y[s[j] + m - 1] * log(pi[i,c]) + (1 - y[s[j] + m - 1]) * log(1 - pi[i,c]);
      }
      ps[c] = log_nu[c] + sum(log_items);
    }
    target += log_sum_exp(ps);
  }
1 Like