Meta-analysis latent variable model - prohibitively slow estmation

Hi,

I am running the following model. The estimates are seem fine so it seems to “work” in that sense. However, the estimation is extremely slow. In total there are 15 studies (see data below) and I am just subsetting the data on the first 4 to do test runs. However, even just with these 4 studies it takes ~ 45 minutes just to do 500 iterations.

I am wondering if anybody has any ideas for how to speed this up, or whether these models are inherently this slow.

I am not aware of these models being used much in the literature except for this paper , however they used SAS and do not explicitly mention computation times.


data {
                int<lower=2> nt; // number of tests
		int<lower=0> NS;	// total # of studies
                int<lower=1> n[NS]; // number of individuals in each study 
                int r[NS,4]; // 2x2 table for each study
	}

transformed data {
                int res1[NS, max(n)] ; 
                int res2[NS, max(n)] ; 

for (s in 1:NS) {
	res1[s] = append_array( append_array( rep_array(0, r[s,1]), rep_array(1, r[s,2]) ) , append_array( rep_array(0, r[s,3]), rep_array(1, r[s,4]) ) )   ;
        res2[s] = append_array( append_array( rep_array(0, r[s,1]), rep_array(0, r[s,2]) ) , append_array( rep_array(1, r[s,3]), rep_array(1, r[s,4]) ) )   ; 
    }
}
	parameters {
	  vector[nt] a11_raw;
          vector[nt] a10_raw;
          real<lower=0,upper=1> p[NS]; 
	}

	transformed parameters {
          vector[nt] a11;
          vector[nt] a10;
     
        a11 = a11_raw;
        a10 = a10_raw;
     }

model {
    a11_raw ~ std_normal(); 
    a10_raw ~ std_normal(); 

 for (s in 1:NS) {
   for (i in 1:n[s]) { 
     real lps[2];
    // logprob conditional on non-diseased
    lps[1] = bernoulli_lpmf(0| p[s])
    +  bernoulli_lpmf(res1[s,i] | Phi_approx(-a10[1] ) )
    +  bernoulli_lpmf(res2[s,i] | Phi_approx(-a10[2] ) );
    // logprob conditional on diseased
    lps[2] = bernoulli_lpmf(1| p[s])
    +  bernoulli_lpmf(res1[s,i] | Phi_approx(a11[1] ) )
    +  bernoulli_lpmf(res2[s,i] | Phi_approx(a11[2] ) );

    // marginalize
    target += log_sum_exp(lps);
   }
 }
}

generated quantities {
        vector<lower=0,upper=1>[nt] Se; 
        vector<lower=0,upper=1>[nt] Sp; 
        vector<lower=0,upper=1>[nt] fp; 
        vector<lower=0,upper=1>[nt] fn; 
              
       for (j in 1:nt) {
      // overall test accuracy estimates, for each test
          fn[j]   =  1-Phi_approx( (a11[j] ) );
          Se[j]   =    Phi_approx( (a11[j] ) );
          Sp[j]   =    Phi_approx( (a10[j] ) );
          fp[j]   = 1- Phi_approx( (a10[j] ) );
        } 
}

The data and code to run the model from R is:

num <- 15
T1 <- c(1,2,2,2,2,3,2,3,2,2,2,2,1,3,3)
T2 <- rep(4, times = 15)
T <- matrix(c(T1, T2), ncol=2, nrow=15)
r1 <- c(97, 1,20,14,93,             31, 132, 305, 17, 44, 54, 39, 64, 7, 25) ; length(r1) #tp
r2 <- c(20,  0 ,3,   1   ,30       ,6,23,9,19,5,22,4, 0,   4,44)  ; length(r2) #fn
r3 <- c(14,9,2,44,183,            3,54,98,31,61,63,536, 612 ,  31, 3) ; length(r3) #fp. gs=0. ind=1
r4 <- c(297, 45, 231, 285, 1845,   248, 226, 256, 580, 552, 980, 227, 2051, 98, 170)  ; length(r4) # tn, ind=0, 
ns <- c()
for (i in 1:num) {ns[i] <- r1[i] + r2[i] + r3[i] + r4[i]}
# order by test
data <- data.frame(r1,r2,r3,r4, ns, t1 = T[,1], t2= T[,2]) #%>% arrange(t1)
r1 <- data$r1 ; r2 <- data$r2 ;  r3 <- data$r3 ; r4 <- data$r4
r <- matrix(ncol = 4, nrow = num, c(r1,r2,r3,r4)) ; r
ns <- data$ns
data24 <-list()
pos <- r1+r2
neg <- r3+r4
#data24 <- list( r = r, n = ns, NS= num , pos=pos, neg=neg, T=data$t1, num_ref=3, nt=2)
data24 <- list( r = r[1:4,], n = ns[1:4], NS= 4 , num_ref=3, nt=2)
data24

meta_model <- stan(file = "latent_trait_2_tests_meta_analysis.stan", 
           data =data24,
           iter = 500, 
           chains = 1,# verbose=TRUE,
         #  init = list(init,init, init, init),
            control=list(adapt_delta=0.8, max_treedepth =10)
            )
1 Like

I get:

meta_model <- stan_model(file = "meta.stan")
SYNTAX ERROR, MESSAGE(S) FROM PARSER:
Variable "b1" does not exist.
 error in 'model168671c7b2140_meta' at line 81, column 49
  -------------------------------------------------
    79:  for (j in 1:nt) {
    80:   // overall test accuracy estimates, for each test
    81:     fn[j] =  1 -Phi_approx((a11[j]) / sqrt(1 + b1[j]^2));
                                                        ^
    82:     Se[j] =    Phi_approx((a11[j]) / sqrt(1 + b1[j]^2));
  -------------------------------------------------

from this model. Remove the generated quantities block and running the model with the supplied data yields:

> samples <- sampling(
+   meta_model,
+   data = data24,
+   iter = 2000,
+   chains = 4,
+   cores = 4
+ )
Chain 3:  Elapsed Time: 5.04136 seconds (Warm-up)
Chain 3:                4.32505 seconds (Sampling)
Chain 3:                9.36641 seconds (Total)
Chain 3: 
Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
Chain 2: 
Chain 2:  Elapsed Time: 5.10571 seconds (Warm-up)
Chain 2:                4.4371 seconds (Sampling)
Chain 2:                9.54281 seconds (Total)
Chain 2: 
Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
Chain 1: 
Chain 1:  Elapsed Time: 5.41459 seconds (Warm-up)
Chain 1:                4.40047 seconds (Sampling)
Chain 1:                9.81506 seconds (Total)
Chain 1: 
Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
Chain 4: 
Chain 4:  Elapsed Time: 5.5526 seconds (Warm-up)
Chain 4:                4.4272 seconds (Sampling)
Chain 4:                9.97981 seconds (Total)
Chain 4: 
Warning messages:
1: The largest R-hat is 1.73, indicating chains have not mixed.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#r-hat 
2: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#bulk-ess 
3: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#tail-ess 
> samples
Inference for Stan model: meta.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

              mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff  Rhat
a11_raw[1]    0.20    0.89 1.26   -1.28   -1.05    0.30    1.45    1.55     2 12.64
a11_raw[2]    0.31    1.79 2.55   -2.99   -2.16    0.46    2.80    3.49     2  6.70
a10_raw[1]   -0.20    0.88 1.25   -1.56   -1.45   -0.29    1.05    1.28     2 12.56
a10_raw[2]   -0.31    1.79 2.56   -3.50   -2.79   -0.42    2.16    3.02     2  6.68
p[1]          0.50    0.16 0.23    0.24    0.28    0.50    0.73    0.76     2 11.16
p[2]          0.50    0.33 0.46    0.01    0.03    0.50    0.97    0.99     2 20.14
p[3]          0.50    0.29 0.41    0.06    0.09    0.50    0.91    0.94     2 23.84
p[4]          0.50    0.32 0.45    0.03    0.05    0.50    0.95    0.97     2 39.84
a11[1]        0.20    0.89 1.26   -1.28   -1.05    0.30    1.45    1.55     2 12.64
a11[2]        0.31    1.79 2.55   -2.99   -2.16    0.46    2.80    3.49     2  6.70
a10[1]       -0.20    0.88 1.25   -1.56   -1.45   -0.29    1.05    1.28     2 12.56
a10[2]       -0.31    1.79 2.56   -3.50   -2.79   -0.42    2.16    3.02     2  6.68
lp__       -731.68    0.05 2.03 -736.59 -732.76 -731.36 -730.20 -728.72  1541  1.00

Samples were drawn using NUTS(diag_e) at Thu Aug 27 13:35:32 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

So (for me) not a massively long run time, but the diagnostics suggests pretty serious issues with the model. The traces show that the posterior using data from the 4 studies is multi modal, so possibly some additional constraints are required.

2 Likes

Hi,

Sorry I forgot to remove the b variables from generated quantities. I simplified it to post here.

This is a mixture model hence you will get label-switching unless declaring appropriate starting values and/or constrains - so for now I am just running with one chain.

When I run with four chains as you did, I still get a huge run time, 500 iterations take > 30 mins… this is very odd

I think this may be some sort of windows 10 bug. I tried re-installing R and rstan packages but didn’t help

I think it is more likely to be a problem in the model (see the folk theorem). Even if I constrain a11, a10 > 0 I still get a posterior that is multi-modal:

There chains to take longer to run (300 warmup iterations + 300 sampling iterations takes ~150 seconds), but still not quite as long as you experience (and are clearly no good for inference).

It looks like you have used the full dataset there, which is why it probably takes longer to run. If i try to do that it will take well over 5 hours.

1 Like

Hi,

Thanks for taking the time to look at this!

Here is a model with parameter constraints which should prevent the label-switching (at least it does for me)

I think that the constraints need to be parameter ordering constraints, not hard constraints like a10 > 0 like you set (see this).

It took ~30 seconds to run 1000 iterations on the reduced dataset now. Stangely, even the previous model without the contraints to prevent label switching runs fast now without the multimodality, I am not sure why it was so slow before.


data {
                int<lower=2> nt; // number of tests
		int<lower=0> NS;	// total # of studies
                int<lower=1> n[NS]; // number of individuals in each study 
                int r[NS,4];
	}

transformed data {
                int res1[NS, max(n)] ; // index as res1[s,i]
                int res2[NS, max(n)] ; 

for (s in 1:NS) {
	res1[s] = append_array( append_array( rep_array(0, r[s,1]), rep_array(1, r[s,2]) ) , append_array( rep_array(0, r[s,3]), rep_array(1, r[s,4]) ) )   ;
        res2[s] = append_array( append_array( rep_array(0, r[s,1]), rep_array(0, r[s,2]) ) , append_array( rep_array(1, r[s,3]), rep_array(1, r[s,4]) ) )   ; 
    }
}
	parameters {
	  vector[1] a10;
          vector[1] a20;
          real<lower=0,upper=1> p[NS]; 
          vector<upper=0>[1] alpha_raw1;
          vector<upper=0>[1] alpha_raw2;
	}

	transformed parameters {
          vector[1] a11;
          vector[1] a21;

        a11 = a10 + alpha_raw1; 
        a21 = a20 + alpha_raw2; 
     }

model {

    a10 ~ std_normal(); 
    a20 ~ std_normal(); 
    alpha_raw1 ~ std_normal(); 
    alpha_raw2 ~ std_normal(); 

 for (s in 1:NS) {
   for (i in 1:n[s]) { 
     real lps[2];
    lps[1] = bernoulli_lpmf(0| p[s])
    +  bernoulli_lpmf(res1[s,i] | Phi_approx(a10 ) ) // Fp (= 1-Sp) of each i 
    +  bernoulli_lpmf(res2[s,i] | Phi_approx(a20 ) );
    lps[2] = bernoulli_lpmf(1| p[s])
    +  bernoulli_lpmf(res1[s,i] | Phi_approx(a11 ) ) // Se of each i
    +  bernoulli_lpmf(res2[s,i] | Phi_approx(a21 ) );

    // marginalize
    target += log_sum_exp(lps);
   }
 }

}


Plots:

1 Like