reduce_sum performance with many threads

I want to use the nice reduce_sum implementation to parallelize the likelihood evaluation of my model. In the process, I noticed that two different implementations that give the same result have very different run times when many threads are used. Unfortunately, the slower implementation is much nicer, and in some cases it is the only possible implementation short of moving part of the likelihood evaluation outside of the partial_sum function. I wonder why the one implementation is so much faster, and whether there is a way to get the other implementation to the same speed.


Below is the code for the redcard example (as given in https://mc-stan.org/users/documentation/case-studies/reduce_sum_tutorial.html), which I will refer to as “Case A”. Here the vector beta[1] + beta[2] * rating[start:end] is passed directly to the binomial_logit_lpmf function. Another way is to define a vector res, set it to beta[1] + beta[2] * rating[start:end], and pass this res vector to binomial_logit_lpmf (“Case B” from here on).
Redcard Example Case A

functions {
  real partial_sum(int[] slice_n_redcards,
                   int start, int end,
                   int[] n_games,
                   vector rating,
                   vector beta) {
    return binomial_logit_lpmf(slice_n_redcards |
                               n_games[start:end],
                               beta[1] + beta[2] * rating[start:end]);
  }
}
data {
  int<lower=0> N;
  int<lower=0> n_redcards[N];
  int<lower=0> n_games[N];
  vector[N] rating;
  int<lower=1> grainsize;
}
parameters {
  vector[2] beta;
}
model {
  beta[1] ~ normal(0, 10);
  beta[2] ~ normal(0, 1);

  target += reduce_sum(partial_sum, n_redcards, grainsize,
                       n_games, rating, beta);
}

Redcard Example Case B

functions {
  real partial_sum(int[] slice_n_redcards,
                   int start, int end,
                   int[] n_games,
                   vector rating,
                   vector beta) {
    vector[end-start+1] res;
    res = beta[1] + beta[2] * rating[start:end];
    return binomial_logit_lpmf(slice_n_redcards | n_games[start:end], res);
  }
}

In Case B only the partial_sum function differs; the rest is the same as in Case A.
When I time both cases with one thread per chain, each takes about 2:40 min for 1000 warm-up and 1000 sampling iterations. With many threads (48 in this example), however, Case A takes about 10 seconds and Case B about 45 seconds. So defining and filling the res vector before passing it to binomial_logit_lpmf slows down the calculation significantly when many threads are used.
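To put these timings side by side, a small Python helper (my own, not part of Stan) can turn the quoted wall-clock times into speedup (one-thread time divided by 48-thread time) and parallel efficiency (speedup divided by the thread count). It only redoes the arithmetic on the numbers reported above, read as 160 s, 10 s and 45 s.

```python
def speedup_and_efficiency(t_serial_s, t_parallel_s, threads):
    """Return (speedup, parallel efficiency) from wall-clock times in seconds."""
    speedup = t_serial_s / t_parallel_s
    efficiency = speedup / threads
    return speedup, efficiency


# Redcard timings quoted above: ~2:40 (160 s) with one thread vs. 48 threads.
t_one_thread = 160.0
for case, t_48 in [("A", 10.0), ("B", 45.0)]:
    s, e = speedup_and_efficiency(t_one_thread, t_48, 48)
    print(f"Case {case}: speedup {s:.1f}x, efficiency {e:.0%}")
```

With these inputs Case A reaches a 16x speedup (about 33% efficiency) and Case B about 3.6x (about 7%), which makes the gap clearer than the raw times.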


In the redcard example above it is of course no problem to just use Case A, but in the model I want to fit, the total model is a superposition of different sources. In a very simple version of this model, the total model is given by

$$M_{\text{tot}} = \sum_i c_i \, V_i$$

where V_i are precalculated vectors for the individual sources and I only fit the normalizations c_i. The number of sources can be up to 30, and then it becomes impractical to write everything out as in Case A of the redcard example. I also did the run time comparison for a simple Stan model with 5 sources and some simulated mock data. The code could be either

Simple Model Case A

functions {
  // Case A: hard-coded for exactly five sources (num_bases is not used here)
  real partial_sum(int[] counts, int start, int stop,
                   real[] c, vector[] base_counts, int num_bases){
    return poisson_propto_lpmf(counts |
                               c[1]*base_counts[1][start:stop]+
                               c[2]*base_counts[2][start:stop]+
                               c[3]*base_counts[3][start:stop]+
                               c[4]*base_counts[4][start:stop]+
                               c[5]*base_counts[5][start:stop]
                               );
  }
}

data {
  int<lower=0> datapoints;
  int counts[datapoints];
  int num_bases;
  vector[datapoints] base_counts[num_bases];
  int grainsize;
}

parameters {
  real<lower=0> c[num_bases];
}

model{
  c ~ lognormal(0,1);
  target += reduce_sum(partial_sum, counts, grainsize, c, base_counts, num_bases);
}

generated quantities {
  int ppc[datapoints];
  vector[datapoints] tot;
  tot = rep_vector(0.0, datapoints);
  for (i in 1:num_bases){
    tot += c[i]*base_counts[i];
  }
  ppc = poisson_rng(tot);
}

or Simple Model Case B

functions {
  real partial_sum(int[] counts, int start, int stop,
                   real[] c, vector[] base_counts, int num_bases){
    vector[stop-start+1] model_counts;
    model_counts = rep_vector(0.0, stop-start+1);
    for (i in 1:num_bases){
      model_counts += c[i]*base_counts[i][start:stop];
    }
    return poisson_propto_lpmf(counts | model_counts);
  }
}

The run time behavior is similar to the redcard example. With one thread per chain, the run time is nearly the same for both cases (about 8:30 min for 500 warm-up and 300 sampling iterations), but with 48 threads Case A takes 1:30 min and Case B 9:00 min. So Case B with 48 threads is even slower than with one thread. In a real fit I have up to 30 of these sources, and not always the same number of sources, so a for loop like in Case B would be much nicer. Also, in a more realistic model the sources are more complicated, and for them I fit more than just a normalization. Adding them directly inside the poisson_propto_lpmf call as in Case A is not always possible (or at least very difficult). That means I would probably need to move parts of the calculation outside of the partial_sum function, just to avoid defining a vector inside partial_sum that sums up all the different source contributions.
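The superposition M_tot = Σ_i c_i V_i can be computed either as a loop over sources (as in Case B) or as a single matrix-vector product, with one row per data point holding every source's value. The pure-Python sketch below (function names are made up for illustration) checks that both layouts give the same model counts. In Stan, the matrix form would correspond to storing the sources in a data matrix and multiplying it by a vector c; whether that is faster inside partial_sum under reduce_sum is untested here.

```python
def total_by_loop(c, base_counts):
    """Case B style: start at zero and add c[i] * V_i one source at a time.
    base_counts is source-major: base_counts[i][n] is source i at data point n."""
    model_counts = [0.0] * len(base_counts[0])
    for c_i, v_i in zip(c, base_counts):
        for n, v_in in enumerate(v_i):
            model_counts[n] += c_i * v_in
    return model_counts


def total_by_matrix(c, base_matrix):
    """Matrix form M_tot = B c: row n of B holds every source's value at point n."""
    return [sum(c_i * b for c_i, b in zip(c, row)) for row in base_matrix]


def to_rows(base_counts):
    """Transpose source-major vectors into one row per data point."""
    return [list(row) for row in zip(*base_counts)]


c = [0.5, 2.0, 1.0]
base_counts = [[1.0, 2.0, 3.0, 4.0],   # V_1
               [0.0, 1.0, 0.0, 1.0],   # V_2
               [3.0, 3.0, 3.0, 3.0]]   # V_3
print(total_by_loop(c, base_counts))             # [3.5, 6.0, 4.5, 7.0]
print(total_by_matrix(c, to_rows(base_counts)))  # [3.5, 6.0, 4.5, 7.0]
```

The loop works for any number of sources, which is the flexibility wanted here; the matrix form expresses the same sum as one operation.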


Is this large run time difference with many threads due to the memory allocation for the model_counts vector, which every thread has to do?
And is there a way to make Case B as fast as Case A, even with many threads?

Hi!

What is the size of the data set here?

Throwing 48 threads at a 2:40 min problem does not sound like it is really needed, but still, your finding is interesting.

My guess (not sure, though) is that Case A leads to a smaller AD footprint: there you create the product of parameters with a large vector as a temporary, which Stan can handle in a more optimized way (essentially you avoid creating N terms on the AD tape). The Case B code will always create N terms on the AD stack regardless.
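To illustrate the idea of "N terms on the AD tape" versus a single optimized term, here is a toy reverse-mode tape in Python. It is not how Stan Math is implemented; the classes and functions are hypothetical and only show that the same scaling c * v can record one backward step per element or one backward step for the whole vector, with identical gradients. Whether this is what actually separates Case A from Case B remains the guess stated above.

```python
class Var:
    """A scalar with a value and an adjoint (gradient accumulator)."""
    def __init__(self, value):
        self.value = value
        self.adj = 0.0


class Tape:
    """Records backward steps; backward() replays them in reverse order."""
    def __init__(self):
        self.nodes = []

    def backward(self):
        for node in reversed(self.nodes):
            node()


def scale_elementwise(tape, c, v):
    """c * v with one tape node per element: N nodes for N data points."""
    out = []
    for v_n in v:
        y = Var(c.value * v_n)
        def node(y=y, v_n=v_n):          # default args bind this element
            c.adj += y.adj * v_n
        tape.nodes.append(node)
        out.append(y)
    return out


def scale_vectorized(tape, c, v):
    """c * v with a single tape node that loops over the data when replayed."""
    out = [Var(c.value * v_n) for v_n in v]
    def node():
        c.adj += sum(y.adj * v_n for y, v_n in zip(out, v))
    tape.nodes.append(node)
    return out


def gradient_and_tape_size(scale, c_value, v, upstream):
    """Gradient of sum_n upstream[n] * (c * v[n]) w.r.t. c, plus the tape length."""
    tape = Tape()
    c = Var(c_value)
    outputs = scale(tape, c, v)
    for y, w in zip(outputs, upstream):
        y.adj = w                        # stand-in for adjoints coming from the lpmf
    tape.backward()
    return c.adj, len(tape.nodes)


v = [1.0, 2.0, 3.0, 4.0]
upstream = [1.0, 1.0, 0.5, 0.25]
print(gradient_and_tape_size(scale_elementwise, 0.7, v, upstream))  # (5.5, 4)
print(gradient_and_tape_size(scale_vectorized, 0.7, v, upstream))   # (5.5, 1)
```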

Since you observe good performance with 1 thread for Case B as well, this suggests that you may be able to get good performance back by increasing the grainsize. Have you tried varying that?

Essentially, what may happen here is that Case B uses more cache whenever things are sliced into many (too) small pieces. This is why you should probably use the grainsize to limit how small the chunks get.
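As far as I know, reduce_sum hands the grainsize on to TBB, which splits the index range roughly in half and never splits a piece that already has grainsize terms or fewer. The Python sketch below (a hypothetical helper, not Stan or TBB code) applies that halving rule all the way down, which is the most splitting a given grainsize allows; TBB's default partitioner may stop earlier. It shows why grainsize = 1 permits one-element slices, while a larger grainsize caps the number of pieces.

```python
def split_by_grainsize(n_terms, grainsize):
    """Split [0, n_terms) by repeated halving until no piece has more than
    `grainsize` terms. Returns (begin, end) pairs in order."""
    if grainsize < 1:
        raise ValueError("grainsize must be >= 1")
    pending = [(0, n_terms)]
    chunks = []
    while pending:
        begin, end = pending.pop()
        if end - begin > grainsize:
            middle = begin + (end - begin) // 2
            pending.append((middle, end))   # right half, handled later
            pending.append((begin, middle)) # left half, handled next
        else:
            chunks.append((begin, end))
    return chunks


print(len(split_by_grainsize(10000, 1)))          # 10000 one-term slices
chunks = split_by_grainsize(10000, 10000 // 48)   # grainsize 208
print(len(chunks), min(e - b for b, e in chunks), max(e - b for b, e in chunks))
# 64 156 157
```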

What CPU are you using? As I recall, AMD CPUs do not have an equal amount of cache per core, which can make a difference here.

The CPUs are Intel Xeons. The real model and data used in this situation take much longer to fit, but it is a very complicated model which is not as easy to understand as what @bbiltzing put here. The two are analogous, though.

Thanks for the reply, @wds15!

The data sizes in the examples I used here are 124621 for the redcard example and 10000 for the other example.

In the run time comparison above I simply used grainsize = 1. If I set grainsize = size_data_set/48, the redcard run time for Case A goes down to ~7 seconds and for Case B to 30 seconds. So Case A is still much faster.
For the simple model with the 5 sources, both cases get much faster with grainsize = size_data_set/48: the run time of Case A goes down to 30 seconds and that of Case B to 1:30 min. But Case A is still a factor of 3 faster.

And yes, the real model takes much longer to evaluate; I just created these simple examples, which show the same behavior.

You should try more values for grainsize. The optimal grainsize is the one that makes the best use of your CPU cache, and that is not necessarily size_data_set/48.
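One way to follow this advice is to benchmark a geometric grid of grainsizes rather than a single value. The hypothetical helper below doubles the grainsize from 1 up to n_terms // threads (roughly one slice per thread; larger values tend to leave some threads without work) and always includes that upper value. How each run is timed is left to your own setup, for example by passing grainsize in the data.

```python
def grainsize_grid(n_terms, threads, factor=2):
    """Candidate grainsizes spaced by an integer `factor`, from 1 up to
    n_terms // threads, always including that upper value."""
    if factor < 2:
        raise ValueError("factor must be at least 2")
    upper = max(1, n_terms // threads)
    grid = []
    g = 1
    while g < upper:
        grid.append(g)
        g *= factor
    grid.append(upper)
    return grid


print(grainsize_grid(10000, 48))    # [1, 2, 4, 8, 16, 32, 64, 128, 208]
print(grainsize_grid(124621, 48))   # 1, 2, 4, ..., 2048, 2596
```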

Maybe have a look at the TBB doc:

https://software.intel.com/en-us/node/506060

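If c could be a vector, I think you could try the following, if I am reading the loop right. This needs base_counts stored with one vector of source values per data point, i.e. vector[num_bases] base_counts[datapoints]: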

functions {
  real partial_sum(int[] counts, int start, int stop,
                   vector c, vector[] base_counts, int num_bases){
    vector[stop-start+1] model_counts;
    for (i in 1:(stop-start+1)){
      model_counts[i] = dot_product(c, base_counts[i+start-1]);
    }
    return poisson_propto_lpmf(counts | model_counts);
  }
}

This would maybe be a bit faster?

Maybe even leaving c as an array and using to_vector would work, but I think the to_vector would hurt performance here.

functions {
  real partial_sum(int[] counts, int start, int stop,
                   real[] c, vector[] base_counts, int num_bases){
    vector[stop-start+1] model_counts;
    for (i in 1:(stop-start+1)){
      model_counts[i] = dot_product(to_vector(c), base_counts[i+start-1]);
    }
    return poisson_propto_lpmf(counts | model_counts);
  }
}

Though the indexing here gets really fun.
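The off-by-one bookkeeping in base_counts[i+start-1] is easy to get wrong. The Python sketch below (hypothetical helper names; Python lists are 0-based, hence one extra -1) mimics the per-slice loop and checks that the slices stitched together reproduce the full-data result. It assumes base_counts has been re-laid out with one row of source values per data point.

```python
def partial_model_counts(c, base_rows, start, end):
    """Model counts for the 1-based slice start..end, as in the dot_product
    version: local element i (1-based) reads global row i + start - 1."""
    out = []
    for i in range(1, end - start + 2):          # i = 1 .. end-start+1
        row = base_rows[(i + start - 1) - 1]     # extra -1 for 0-based lists
        out.append(sum(c_j * b for c_j, b in zip(c, row)))
    return out


def full_model_counts(c, base_rows):
    """Model counts for all data points at once."""
    return [sum(c_j * b for c_j, b in zip(c, row)) for row in base_rows]


c = [1.0, 0.5]
base_rows = [[1, 2], [2, 2], [3, 0], [4, 4], [5, 2], [6, 6]]
slices = [(1, 2), (3, 5), (6, 6)]                # 1-based, inclusive, like reduce_sum
stitched = []
for start, end in slices:
    stitched += partial_model_counts(c, base_rows, start, end)
print(stitched)                                  # [2.0, 3.0, 3.0, 6.0, 6.0, 9.0]
print(stitched == full_model_counts(c, base_rows))  # True
```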

Thanks for the link.

I did some further testing on the influence of the grainsize:

Redcard Example
Attachment: Runtime_compare_redcard.pdf

Simple Model Example
Here I reduced the number of warm-up and sampling iterations to save some time.
Attachment: Runtime_compare_simple_model.pdf


Thanks for the reply!
I tried your idea, but it turned out to be slower in this case (at least as I implemented it):

functions {
  real partial_sum(int[] counts, int start, int stop,
                   vector c, vector[] base_counts, int num_bases){
    vector[stop-start+1] model_counts;
    for (i in 1:(stop-start+1)){
      model_counts[i] = dot_product(c, base_counts[i+start-1]);
    }

    return poisson_propto_lpmf(counts | model_counts);
  }
}

data {
  int<lower=0> datapoints;
  int counts[datapoints];
  int num_bases;
  // one vector of source contributions per data point
  vector[num_bases] base_counts[datapoints];
  int grainsize;
}
// parameters and model blocks (not shown) as in Simple Model Case A,
// with c declared as a vector instead of a real array

If you are on Windows, you can add this to your CXXFLAGS:

-DBOOST_MATH_PROMOTE_DOUBLE_POLICY=false

(You can also set it on Linux. There you won't see speedups for lgamma-based things like the Poisson, but you may for other distributions that use digamma and the like.)

Okay, thanks for the tip. I am running this on Linux.