Separate partial_sum statements work, but not reduce_sum

Hi,

I have specified a partial_sum function. The model is hierarchical (patients nested within studies), so I have broken the likelihood up by study.

It seems to work fine if I specify the likelihood as two separate target += partial_sum(…) statements, as in the commented-out code below. However, when I pass the partial_sum function through to reduce_sum, compile, and try to run it, I get “chain x finished unexpectedly” errors on all chains.

Any advice would be appreciated.

functions {
  real partial_sum(int[,,] y, int start, int end, int nt,
                   real[,,] u_nd,
                   real[,,] u_d,
                   vector p,
                   matrix[] L_Omega_nd,
                   matrix[] L_Omega_d,
                   vector[] d,
                   vector[] nd,
                   int[] ns) {
    int len = end - start + 1;
    int r0 = start - 1;
    real log_lik;
    vector[len] log_lik_s;
    for (s in 1:len) {  // loop over studies
      vector[ns[s + r0]] log_lik_i;
      real lp_bern0;
      real lp_bern1;
      lp_bern0 = bernoulli_lpmf(0 | p[s + r0]);
      lp_bern1 = bernoulli_lpmf(1 | p[s + r0]);
      for (n in 1:ns[s + r0]) {  // loop over individuals
        vector[nt] z_d;
        vector[nt] z_nd;
        vector[nt] y1dz;
        vector[nt] y1ndz;
        real lp0;
        real lp1;
        real prev_d;
        real prev_nd;
        prev_d = 0;
        prev_nd = 0;
        for (i in 1:nt) {  // loop over each test
          real bound_d = inv_logit(1.7 * (-(d[s + r0, i] + prev_d) / L_Omega_d[s + r0, i, i]));
          real bound_nd = inv_logit(1.7 * (-(nd[s + r0, i] + prev_nd) / L_Omega_nd[s + r0, i, i]));
          if (y[n, i, s] == 1) {
            real u_d1 = u_d[n, i, s + r0];
            real u_nd1 = u_nd[n, i, s + r0];
            z_d[i] = (1 / 1.7) * logit(bound_d + (1 - bound_d) * u_d1);
            z_nd[i] = (1 / 1.7) * logit(bound_nd + (1 - bound_nd) * u_nd1);
            y1dz[i] = log1m(bound_d);    // Jacobian adjustment
            y1ndz[i] = log1m(bound_nd);  // Jacobian adjustment
          }
          if (y[n, i, s] == 0) {
            real u_d1 = u_d[n, i, s + r0];
            real u_nd1 = u_nd[n, i, s + r0];
            z_d[i] = (1 / 1.7) * logit(bound_d * u_d1);
            z_nd[i] = (1 / 1.7) * logit(bound_nd * u_nd1);
            y1dz[i] = log(bound_d);    // Jacobian adjustment
            y1ndz[i] = log(bound_nd);  // Jacobian adjustment
          }
          if (i < nt) prev_d = L_Omega_d[s + r0, i + 1, 1:i] * head(z_d, i);
          if (i < nt) prev_nd = L_Omega_nd[s + r0, i + 1, 1:i] * head(z_nd, i);
          // Jacobian adjustments imply z is truncated standard normal,
          // thus the utility --- mu + L_Omega * z --- is truncated multivariate normal
        }
        lp1 = sum(y1dz) + lp_bern1;
        lp0 = sum(y1ndz) + lp_bern0;
        log_lik_i[n] = log_sum_exp(lp1, lp0);
      }
      log_lik_s[s] = sum(log_lik_i[1:ns[s + r0]]);
    }
    log_lik = sum(log_lik_s);
    return log_lik;
  }
}

data {
  int gr;           // grainsize passed to reduce_sum
  real x;           // shape (eta) of the LKJ priors on the correlation matrices
  real y1;          // multiplier on the test 1 between-study random effects
  real y2;          // multiplier on the test 2 between-study random effects
  int<lower=1> nt;  // number of tests
  int<lower=1> NS;  // total # of studies
  int<lower=1> ns[NS]; // number of individuals in each study  
  int<lower=0> y[max(ns),nt, NS]; // N individuals and nt tests, NS studies 
  int r[NS,4]; // data in summary (2x2 tables) format for generated quantities block
}
parameters {
     ordered[2] a1_m; 
     ordered[2] a2_m;   
     vector<lower=0>[nt] d_sd;
     vector<lower=0>[nt] nd_sd;
     vector[4] z1[NS];
     real<lower=0,upper=1> p[NS];   
     real<lower=0,upper=1> u_d[max(ns),nt, NS]; // nuisance that absorbs inequality constraints
     real<lower=0,upper=1> u_nd[max(ns),nt, NS]; // nuisance that absorbs inequality constraints
     cholesky_factor_corr[nt] L_Omega_nd[NS];
     cholesky_factor_corr[nt] L_Omega_d[NS];
}
transformed parameters {
     vector[nt] d_m; 
     vector[nt] nd_m;  
     vector[nt] d[NS];
     vector[nt] nd[NS];
     vector[NS] y1dz[4,nt];
     vector[NS] y1ndz[4,nt];

  d_m[1] = a1_m[2]; 
  d_m[2] = a2_m[2]; 
  nd_m[1] = a1_m[1]; 
  nd_m[2] = a2_m[1];

  for (s in 1:NS) {
    d[s,1]  = d_m[1]  + d_sd[1]*z1[s,1]*y1;
    d[s,2]  = d_m[2]  + d_sd[2]*z1[s,2]*y2;
    nd[s,1] = nd_m[1] + nd_sd[1]*z1[s,3]*y1;
    nd[s,2] = nd_m[2] + nd_sd[2]*z1[s,4]*y2;
  }
}
model {
  // p ~ beta(1, 5);
  d_sd ~ std_normal();
  nd_sd ~ std_normal();
  for (s in 1:NS)
    z1[s,] ~ std_normal();
  for (s in 1:NS) {
    L_Omega_nd[s,] ~ lkj_corr_cholesky(x);
    L_Omega_d[s,]  ~ lkj_corr_cholesky(x);
  }
  // sens
  a1_m[2] ~ normal(1, 0.4);
  a2_m[2] ~ normal(1, 0.4);
  // spec
  a1_m[1] ~ normal(-1, 0.4);
  a2_m[1] ~ normal(-1, 0.4);

  // target += partial_sum(y[, , 1:2], 1, 2, nt, u_nd[,,], u_d[,,], to_vector(p),
  //                       L_Omega_nd[,,], L_Omega_d[,,], d[,], nd[,], ns);
  // target += partial_sum(y[, , 3:4], 3, 4, nt, u_nd[,,], u_d[,,], to_vector(p),
  //                       L_Omega_nd[,,], L_Omega_d[,,], d[,], nd[,], ns);

  target += reduce_sum(partial_sum, y, gr, nt, u_nd[,,], u_d[,,], to_vector(p),
                       L_Omega_nd[,,], L_Omega_d[,,], d[,], nd[,], ns);
}


R code:

num <- 15
T1 <- c(1,2,2,2,2,3,2,3,2,2,2,2,1,3,3)
T2 <- rep(4, times = 15)
T <- matrix(c(T1, T2), ncol=2, nrow=15)
r1 <- c(97, 1,20,14,93, 31, 132, 305, 17, 44, 54, 39, 64, 7, 25) ; length(r1) #tp
r2 <- c(20, 0 ,3, 1 ,30 ,6,23,9,19,5,22,4, 0, 4,44) ; length(r2) #fn
r3 <- c(14,9,2,44,183, 3,54,98,31,61,63,536, 612 , 31, 3) ; length(r3) #fp. gs=0. ind=1
r4 <- c(297, 45, 231, 285, 1845, 248, 226, 256, 580, 552, 980, 227, 2051, 98, 170) ; length(r4) # tn, ind=0, gs = 0
ns <- c()
for (i in 1:num) {ns[i] <- r1[i] + r2[i] + r3[i] + r4[i]}
data <- data.frame(r1,r2,r3,r4, ns, t1 = T[,1], t2= T[,2]) #%>% arrange(t1)
r1 <- data$r1 ; r2 <- data$r2 ; r3 <- data$r3 ; r4 <- data$r4
r <- matrix(ncol = 4, nrow = num, c(r1,r2,r3,r4)) ; r
ns <- data$ns
data24 <-list()
pos <- r1+r2
neg <- r3+r4
data24 <- list( r = r, n = ns, NS= num , pos=pos, neg=neg, T=data$t1, num_ref=3, nt=2)
NS=15
sum(ns) # N= 10,968

y_list <- list()
y1a <- vector("list", NS)  # per-study results for test 1
y1b <- vector("list", NS)  # per-study results for test 2
pa  <- vector("list", NS)

max <- max(ns)

for (i in 1:NS) {
y1a[[i]] = c(rep(1, r[i,1]), rep(1, r[i,2]), rep(0, r[i,3]), rep(0, r[i,4]), rep(2, max - ns[i] ))
y1b[[i]] = c(rep(1, r[i,1]), rep(0, r[i,2]), rep(1, r[i,3]), rep(0, r[i,4]), rep(2, max - ns[i] ))
y_list[[i]] = matrix(ncol = 2, c(y1a[[i]] , y1b[[i]] ))
}

y = array(data= unlist(y_list), dim = c( max(ns), 2, NS))

# mod is the compiled model for the Stan program above, e.g.
# mod <- cmdstan_model("model_file.stan")   # (path is a placeholder)
meta_model2 <- mod$sample(
  data = list(x = 4, y1 = 0, y2 = 0.6, gr = 1,            # grainsize left at 1
              NS = 4, ns = ns[1:4], y = y[1:428, , 1:4],   # first 4 studies
              nt = 2, r = r[1:4, ]),
  seed = 123,
  chains = 4,
  parallel_chains = 4,
  iter_warmup = 250,
  iter_sampling = 350,
  refresh = 12,
  adapt_delta = 0.80,
  max_treedepth = 10)

meta_model2r <- rstan::read_stan_csv(meta_model2$output_files())

reduce_sum slices the first dimension, not the last one. Your partial_sum indexes studies in the last dimension of y (y[n, i, s]), so the slices it receives from reduce_sum are blocks of individuals rather than blocks of studies, and the s + r0 indexing no longer refers to studies and runs past the study-level arrays.
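For illustration, here is a minimal toy program (not the model from this thread; partial_sum_toy, N_max, mu, and the Bernoulli likelihood are invented for the example) showing the shape this takes once the study index is moved to the first dimension of y, so that start and end refer to studies:

functions {
  // y_slice is y[start:end, , ]: a block of whole studies, because
  // reduce_sum always slices its second argument along the first dimension.
  real partial_sum_toy(int[,,] y_slice, int start, int end, vector mu) {
    real lp = 0;
    for (s in 1:(end - start + 1)) {      // local study index within the slice
      int s_global = start + s - 1;       // global study index
      for (n in 1:dims(y_slice)[2])
        for (i in 1:dims(y_slice)[3])
          if (y_slice[s, n, i] != 2)      // 2 = padding, as in the R code above
            lp += bernoulli_logit_lpmf(y_slice[s, n, i] | mu[s_global]);
    }
    return lp;
  }
}
data {
  int<lower=1> NS;                      // studies
  int<lower=1> N_max;                   // size of the largest study
  int<lower=1> nt;                      // tests
  int<lower=0, upper=2> y[NS, N_max, nt];  // study index first, then individual, then test
  int<lower=1> gr;                      // grainsize
}
parameters {
  vector[NS] mu;
}
model {
  mu ~ std_normal();
  target += reduce_sum(partial_sum_toy, y, gr, mu);
}

Applying the same idea to the model above would mean declaring y as int y[NS, max(ns), nt], building it in that order in R, and indexing y_slice[s, n, i] inside partial_sum instead of y[n, i, s].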


Thanks!

I have few studies (which is what I am splitting over) but many observations within each study. For speed, would it make more sense to partition over the individuals instead and loop over the studies inside the function passed to reduce_sum?

Start simple!

Do you have more CPUs than studies? Then study may not be the best factor to split over. If I recall correctly, your hierarchical model is grouped by patient… then it often makes sense to slice over patients and keep the patient-specific parameters in the slicing argument.

If you got things working with splitting by study, then as a first step you may nest things: the outer reduce_sum slices over studies, and within each study you call reduce_sum over patients… but that should be less efficient than slicing everything over patients.

In any case… just see where you are now performance-wise with what you have. Maybe it's fast enough already.
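For reference, a minimal sketch of the "slice over patients, keep patient-specific parameters in the sliced argument" call shape (a toy random-effects logistic regression, not the model from this thread; patient_partial_sum, u, and mu are invented names):

functions {
  real patient_partial_sum(vector[] u_slice, int start, int end,
                           int[,] y, vector mu) {
    real lp = 0;
    for (n in 1:(end - start + 1)) {
      int n_global = start + n - 1;      // row of y this slice element belongs to
      for (i in 1:num_elements(u_slice[n]))
        lp += bernoulli_logit_lpmf(y[n_global, i] | mu[i] + u_slice[n][i]);
    }
    return lp;
  }
}
data {
  int<lower=1> N;                   // patients
  int<lower=1> nt;                  // tests
  int<lower=0, upper=1> y[N, nt];
  int<lower=1> gr;                  // grainsize
}
parameters {
  vector[nt] mu;
  vector[nt] u[N];                  // patient-specific parameters, sliced below
}
model {
  mu ~ std_normal();
  for (n in 1:N)
    u[n] ~ std_normal();
  // slicing over patients: the per-patient parameters travel with the slices,
  // so only mu and the data y sit in the shared arguments
  target += reduce_sum(patient_partial_sum, u, gr, y, mu);
}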


I have 12 cores/24 threads, and at the moment I am testing on a smaller dataset with 13 studies (not the data shown in the post, which has 15 studies, a few of them large, and takes hours to run). Since the number of cores is similar to the number of studies, I left the grainsize at 1. I'll update this with results for different threads_per_chain values. CPU usage seems lower than I expected: at the moment I am running 4 chains with 6 threads per chain and CPU is only at 50% utilization (if I run 4 chains normally it's ~20%). Is that normal?

Don't use hyper-threading. It has never been seen to be useful for Stan programs. Not sure about the rest.

As in, disable hyper-threading in the BIOS?

Edit: there seems to be a very small benefit to multithreading here: ~640 seconds normally (20% CPU utilization) vs ~550 seconds with 3 threads per chain (~40% CPU utilization), both running 4 parallel chains. I will try slicing over individuals. I still have no idea why the CPU usage is so low.

Edit 2: I also tried using reduce_sum on a non-hierarchical version of the model (just patients, a single study) by modifying partial_sum so that it is essentially the same but without the loop over s. I get no speedup whatsoever relative to 1 thread per chain, across various grainsizes.

No need to disable anything in the BIOS. Just use as many threads as you have cores.

Your problem is heavily memory-bound. That means the computations are very cheap and you move around a lot of variables. Right now the model runs in 10 min… that does not sound so bad to me, so why go parallel?

What I would try is

  • move the random effect and the L_Omega into the slicing argument by packing them together (see the sketch after this reply)… this one is important, because right now you are moving a lot of parameters into the shared arguments, which is not good for performance
  • if lots of parameters must stay in the shared arguments, then choose a largish grainsize
  • maybe code up a first non-hierarchical version to see how that goes
  • try to find a formulation where you can take advantage of vectorization over the observations, though looking at the model this might be impossible.

… but, sorry to say, I am not very optimistic about this model going well with reduce_sum. All you do is move memory around; as far as I can tell, you don't really compute a lot.
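A sketch of the packing idea from the first bullet above (again a toy, not this thread's model; the pack layout, study_partial_sum, and the multivariate-normal likelihood on x are all invented). Each study's parameters are flattened into one vector so that a single parameter array can be sliced:

functions {
  real study_partial_sum(vector[] pack_slice, int start, int end,
                         int nt, vector[] x) {
    real lp = 0;
    for (s in 1:(end - start + 1)) {
      int s_global = start + s - 1;
      // unpack: first nt entries are the study mean, the rest the flattened Cholesky factor
      vector[nt] mu_s = pack_slice[s][1:nt];
      matrix[nt, nt] L_s = to_matrix(pack_slice[s][(nt + 1):(nt + nt * nt)], nt, nt);
      lp += multi_normal_cholesky_lpdf(x[s_global] | mu_s, L_s);
    }
    return lp;
  }
}
data {
  int<lower=1> NS;
  int<lower=1> nt;
  vector[nt] x[NS];       // one toy observation per study
  int<lower=1> gr;        // grainsize
}
parameters {
  vector[nt] mu[NS];
  cholesky_factor_corr[nt] L_Omega[NS];
}
transformed parameters {
  // pack each study's mean and flattened Cholesky factor into a single vector
  vector[nt + nt * nt] pack[NS];
  for (s in 1:NS)
    pack[s] = append_row(mu[s], to_vector(L_Omega[s]));
}
model {
  for (s in 1:NS) {
    mu[s] ~ std_normal();
    L_Omega[s] ~ lkj_corr_cholesky(2);
  }
  // the study-specific parameters travel in the sliced argument, not the shared ones
  target += reduce_sum(study_partial_sum, pack, gr, nt, x);
}

For the actual model, d[s], nd[s], and the two Cholesky factors would be packed the same way and unpacked inside partial_sum.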


It only takes 10 minutes because I'm using a smaller dataset to test; otherwise it takes hours. I have tried a non-hierarchical version just slicing over patients, and multithreading doesn't seem to make a difference at all. I'll try again with different grainsizes and the other suggestions to see if that helps, thanks.

I haven't had luck getting reduce_sum to make anything notably more efficient. However, I have noticed that *lowering* adapt_delta to 0.6 and max_treedepth to 6, and using step_size = 1.1, increases speed quite a bit with no errors on the datasets tested (and with the same estimated posterior densities, so it doesn't produce biased answers for these data).

@wds15 How would I do this without a loop? I can do it with a loop over reduce_sum, but I'm pretty sure several calls to reduce_sum (one call for each study) aren't a good idea.


There are no treedepth warnings?

I think you can do something like this to check it in cmdstanr:

library(posterior)  # as_draws_df()
library(dplyr)

as_draws_df(fit$sampler_diagnostics()) %>%
  mutate(max_treedepths = treedepth__ == 6) %>%  # TRUE for iterations that hit the max treedepth (here 6)
  pull(max_treedepths) %>%
  sum()

You're looking at how many iterations have treedepth__ equal to the max_treedepth you set (here 6) in the sampler diagnostics.


Looks like it is very dataset- and prior-dependent. On the dataset with relatively larger studies (two have >2000 patients), which takes a really long time to run, I can't lower the treedepth below the default of 10 without warnings.

The treedepth warning is about efficiency, so if you still get good posterior ESS you are fine. It's known that warmup can be very generous with the treedepth, which is sometimes excessive.
