Map_rect/MPI & Performance with Many Shards

I am currently parallelizing some of my time-series latent variable models using map_rect on a cluster with MPI. I have the code working, but am having some trouble accurately predicting/understanding the parallel workload overhead. Given conditional independence restrictions, I end up with a large number of shards, on the order of several thousand. Each shard contains about 4000 observations.

What I have recently noticed, and find very interesting, is that given a fixed (and very large) number of shards, there is an overhead limit on how much MPI can improve performance. Part of this is that the model I’m fitting has some hierarchical parameters that cannot be broken up across shards (they have to be shared with every shard), so communication overhead grows as the number of cores increases. As a result, I find that the gradient calculations are actually faster with 300 cores than with 700 cores (the maximum allotted to me on the cluster).

I’m bringing this up because it seems surprising to me that increasing the number of cores with a fixed number of shards, even allowing for some overhead from communicating the shared (hierarchical) parameters, would not just fail to improve performance but would actually decrease it. The difference is small: a gradient calculation takes 0.35 seconds with 300 cores versus 0.4 seconds with 700 cores.
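A rough way I think about it (ignoring load imbalance and any serial work outside map_rect) is that the per-gradient cost scales roughly like

T_grad(C) ≈ (S / C) * t_shard + t_comm(C)

where S is the number of shards, C the number of cores, t_shard the cost of evaluating one shard, and t_comm(C) the cost of shipping the shared parameters out to, and gathering results back from, C workers. The first term shrinks as C grows but the second grows with it, so past some core count the total creeps back up, which would at least be consistent with 0.35 s at 300 cores versus 0.4 s at 700.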

It’s difficult of course to experiment well with data this large, but I wanted to post here as I found these upper limits surprising.

Operating System: Linux (CentOS)
Interface Version: latest CmdStan from GitHub
Compiler/Toolkit: GCC with MPI


Wow! I would be curious to hear more about this problem… it’s certainly big! If you could post the Stan code, that would help in understanding it, and maybe one can say more.

Does your cluster have InfiniBand, or what do you use for the network?

Other than that, it is clear that at some point the overhead due to communication kills you. 80 cores is the largest I have ever run (or heard of)… this is a different order of magnitude!

Yes I’m always pushing the limit :). It’s a bad habit…

I have a time-varying latent variable model. The code is below. Sorry it’s not commented very well, but hopefully the map_rect part is clear enough.

functions {
 vector overT(vector allparams, vector cit,
               real[] time_points, int[] allintdata) {
                 
    // load indices
    int J = allintdata[1];
    int T = allintdata[2];
    int coup = allintdata[3];
    real t = time_points[1];
    int N = (num_elements(allintdata)-3)/5; // has to be hard-coded unfortunately, no way to 
                                        // pass other info in
    int gen_out[N] = allintdata[4:(N+3)]; // retrieve indices in order
    int country_code[N] = allintdata[(N+4):(2*N+3)];
    int tt[N] = allintdata[((2*N)+4):(3*N+3)];
    int jj1[N] = allintdata[((3*N)+4):(4*N+3)];
    int jj2[N] = allintdata[((4*N)+4):(5*N+3)];
    // citizen params 
    
    real delta_11 = cit[1];
    real delta_10 = cit[2];
    real delta_21 = cit[3];
    real delta_20 = cit[4];
    real beta_1 = cit[5];
    real beta_0 = cit[6];
    
    vector[N] lp_out;
                
  // loop over outcome
  // use hurdle model for missing data (99999)
  // conditional on passing hurdle, use Poisson distribution for retweet counts
  
    for(n in 1:N) {
      
      real this_alpha1 = allparams[(J*(tt[n]-1))+jj1[n]];
      real this_alpha2 = allparams[(T*J + J*(tt[n]-1))+jj2[n]];
      
      if(gen_out[n]==99999) {
        real lp =  bernoulli_logit_lpmf(1 | delta_10*this_alpha1  +
                                        delta_20*this_alpha2 - beta_0);
        lp_out[n] = lp;
      } else {
        real lp = bernoulli_logit_lpmf(0 | delta_10*this_alpha1  +
                                        delta_20*this_alpha2 - beta_0);
        real ll = poisson_log_lpmf(gen_out[n]|delta_11*this_alpha1 +
                                  delta_21*this_alpha2 -
                                  beta_1);
        lp_out[n] = lp + ll;
      }
    }
    return [sum(lp_out) + normal_lpdf(cit|0,3)]'; // return the shard's joint log probability
 }
}
data {
  int<lower=1> J;              // number of elites
  int<lower=1> K;              // number of citizens
  int<lower=1> N;              // number of observations per shard
  int<lower=1> T;  //number of time points
  int<lower=1> C; //number of total data columns
  int<lower=1> S; //number of shards in data for map_rect = number of citizens
  int alldata[S,N]; // data in shard format
  real time_points[S,1]; // counter for citizen IDs
  int coup; //when the coup happens
  vector[T-1] time_gamma;
}
transformed data {
  // calculate how the indices will work for the transformed parameters
  // varying is easy for citizens
  // static must include *all* time-varying parameters as they *cannot* vary across shards
  int vP = 6; // all citizen parameters (discrimination + difficulty) for one citizen = 6 parameters
  int dP = 2*T*J; // all alpha parameters plus adjustment/betax/sigmas/country
  //need a vector to pad the time-varying parameters for the case when T=1
  vector[J] padT = rep_vector(0,J);
  real x_r[S,0]; // empty real data array to fill out the map_rect signature
}
parameters {    
  vector[vP] varparams[S];
  vector<lower=-1,upper=1>[4] adj_in1;  //adjustment parameters
  vector<lower=-1,upper=1>[4] adj_out1; //adjustment parameters
  vector<lower=-1,upper=1>[4] adj_in2;  //adjustment parameters
  vector<lower=-1,upper=1>[4] adj_out2; //adjustment parameters
  vector[4] alpha_int1; //drift
  vector[4] alpha_int2; //drift
  vector[4] betax1; //effects of coup
  vector[4] betax2; //effects of coup
  vector[dP-(2*J)] dparams_nonc; // non-centering time series
  vector<lower=0>[3] sigma_time1; //heteroskedastic variance by ideological group
  vector<lower=0>[3] sigma_time2; //heteroskedastic variance by ideological group
}

transformed parameters {
  
    vector[dP] dparams;
    vector[J] sigma_time1_con;
    vector[J] sigma_time2_con;
    
    sigma_time1_con = append_row([.1]',sigma_time1);
    sigma_time2_con = append_row([.1]',sigma_time2);
  
  // pack all the citizen parameters into an array of vectors for use in map_rect
  
  // all elite params are in non-varying vectors
  
  // append all other parameters, in order, to one big vector that is passed to all shards
  // pack this vector in a way that allows the time series to 
  // influence each other via hierarchical priors
  
  for(t in 1:T) {
    if(t==1) {
      dparams[1:J] = alpha_int1;
      dparams[(T*J+1):(T*J+J)] = alpha_int2;
    } else {
      for(j in 1:J) {
        int other;
        if(j==1) {
          other = 2;
        } else if(j==2) {
          other=1;
        } else if(j==3) {
          other=4;
        } else if(j==4) {
          other=3;
        }
        dparams[((t-1)*J + j)] = alpha_int1[j] +
                              adj_in1[j]*dparams[((t-2)*J + j)] +
                              adj_out1[j]*dparams[((t-2)*J + other)] +
                              betax1[j]*time_gamma[t-1] +
                              sigma_time1_con[j]*dparams_nonc[((t-1)*J + j)];
        dparams[(T*J + (t-1)*J +j)] = alpha_int2[j] +
                              adj_in2[j]*dparams[((t-2)*J + j + T*J)] +
                              adj_out2[j]*dparams[((t-2)*J + other + T*J)] +
                              betax2[j]*time_gamma[t-1] +
                              sigma_time2_con[j]*dparams_nonc[((t-2)*J + j + (T-1)*J)];
      }
    }
  }
  
}

model {
  
  //pin the intercepts for D2
  
  alpha_int2[1] ~ normal(1,.01);
  alpha_int2[2] ~ normal(-1,.01);
  alpha_int2[3:4] ~ normal(0,1);
  
  alpha_int1[1] ~ normal(-1,.01);
  alpha_int1[2] ~ normal(1,.01);
  alpha_int1[3:4] ~ normal(0,1);
  adj_out1 ~ normal(0,2);
  adj_in2 ~ normal(0,2);
  adj_out2 ~ normal(0,2);
  adj_in1 ~ normal(0,2);
  dparams_nonc ~ normal(0,1); // non-centering time series prior

  sigma_time1 ~ inv_gamma(30,3); // constrain the variance to push for better identification
  sigma_time2 ~ inv_gamma(30,3); // constrain the variance to push for better identification
  betax1 ~ normal(0,3);
  betax2 ~ normal(0,3);
  
  // parallelize the likelihood with map_rect
  
  target += sum(map_rect(overT, dparams, varparams, time_points, alldata));

}
generated quantities {
  matrix[J,T] alpha1_m = to_matrix(dparams[1:(T*J)],J,T);
  matrix[J,T] alpha2_m = to_matrix(dparams[(T*J+1):(2*T*J)],J,T);
}


As far as the cluster goes, it has Intel Xeon 2.4 GHz chips with InfiniBand.

This is fairly involved… so I may easily overlook things. However, from glancing at the code, you should consider vectorizing things. For example, you could create an array of indices containing all the data items that are missing (where gen_out is 99999) and another for those that are observed. The loop over N in the map_rect function can then build up the arguments for the bernoulli and the poisson, so that the calls to the bernoulli and the poisson can finally happen with their vectorized versions. I hope it’s clear what I mean. In pseudo-code:

int idx_missing[number_of_missing];
int idx_nonmissing[number_of_nonmissing];
vector[number_of_missing] missing_logit;
vector[number_of_nonmissing] nonmissing_logit;
// one more for the poisson_log

for(i in 1:number_of_missing) {
   // setup missing_logit
}
for(i in 1:number_of_nonmissing) {
   // setup nonmissing_logit and the poisson_log
}

// finally you call in a vectorized way the bernoulli and the poisson

If each vectorized chunk is large, then this should make things run a lot faster.
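Applied to your overT function, the body could look roughly like the sketch below (untested, just to illustrate the idea; it reuses your variable names, keeps the index-unpacking declarations at the top of overT as they are, and would replace the for(n in 1:N) loop plus the return statement):

    // count missing vs. observed outcomes first so the containers can be sized
    int n_mis = 0;
    for (n in 1:N)
      if (gen_out[n] == 99999)
        n_mis += 1;
    {
      int n_obs = N - n_mis;
      vector[n_mis] eta_mis;   // hurdle linear predictors, missing outcomes
      vector[n_obs] eta_obs;   // hurdle linear predictors, observed outcomes
      vector[n_obs] log_rate;  // Poisson log-rates for observed outcomes
      int y_obs[n_obs];        // observed retweet counts
      int i_mis = 1;
      int i_obs = 1;
      for (n in 1:N) {
        real a1 = allparams[(J * (tt[n] - 1)) + jj1[n]];
        real a2 = allparams[(T * J + J * (tt[n] - 1)) + jj2[n]];
        real eta = delta_10 * a1 + delta_20 * a2 - beta_0;
        if (gen_out[n] == 99999) {
          eta_mis[i_mis] = eta;
          i_mis += 1;
        } else {
          eta_obs[i_obs] = eta;
          log_rate[i_obs] = delta_11 * a1 + delta_21 * a2 - beta_1;
          y_obs[i_obs] = gen_out[n];
          i_obs += 1;
        }
      }
      // one vectorized call per distribution instead of N scalar calls
      return [ bernoulli_logit_lpmf(1 | eta_mis)
               + bernoulli_logit_lpmf(0 | eta_obs)
               + poisson_log_lpmf(y_obs | log_rate)
               + normal_lpdf(cit | 0, 3) ]';
    }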

FYI: if you have defined S shards and you have C CPUs, then each CPU ends up calculating blocks of size S/C per evaluation… the “scheduling” is very simple-minded.
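To make that concrete with hypothetical round numbers close to yours: with S = 7000 shards, at C = 700 each core evaluates 7000 / 700 = 10 shards per gradient call, while at C = 300 it evaluates about 7000 / 300 ≈ 23–24. So going from 300 to 700 cores only removes roughly 14 shards’ worth of work per core, while every additional core adds its own share of communication and synchronization cost to each evaluation.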

Out of curiosity… what are the problem sizes? Does this model converge and what is the scientific question addressed with it?

Hi @wds15 -

Thanks so much for asking about my research :D. This is a study of polarization between Islamists and secularists on Twitter in the Arab Middle East. You can see an earlier draft of the paper using variational Bayes here:

https://osf.io/preprints/socarxiv/wykmj/

What the model is doing is trying to uncover trends in polarization between rival groups on Twitter over time, and in particular to what extent the groups are influencing each other. So it’s a combined measurement / time-series model. In long format I have a ~20m-row dataset representing roughly 7k columns × 8 rows × 250 time points (or thereabouts). The time dimension adds a lot of heft to the data.

To update on the code, I did do the vector packing you mentioned and it did speed things up, probably by about 10% or so. I also discovered while doing this that there may not be a true bottleneck after all: the gradient-calculation timings that Stan reports are not predictive of MPI/map_rect performance. Even though the gradient calculations might take longer, using additional cores can still improve performance. So I ended up using all 700 cores and it did help (though I will admit the marginal benefit appeared to be much smaller than it could have been with a better-parallelized model).

As I have been able to run multiple chains and they converged, I consider this done for now. Thanks for your help!


It’s really great to hear that map_rect basically made this possible… without scalable parallelism your model would probably not fit in finite time. So congrats on managing all that!
