Tips to vectorize / speed up custom function with nested loops

I want to start by acknowledging that I am very new to Stan and statistical modeling in general, and I would appreciate any feedback, on this model and otherwise.


Problem

As part of my Stan model specification, I wrote a custom function to bin data in 2D and compute counts. Currently this involves three nested loops and is very slow, and I don’t know how to vectorize it or otherwise make it faster.

The dataset I have is generated from ~75,000 data points, with 2D bin vectors of length ~35 each. The gradient computation is extremely slow, as you can see from this warning message:

Aug 12, 2020, 5:44:43 PM	WARNING	Gradient evaluation took 0.320882 seconds
Aug 12, 2020, 5:44:43 PM	WARNING	1000 transitions using 10 leapfrog steps per transition would take 3208.82 seconds.

I also want to note that I have tried specifying initial values for some of the parameters in the hope that it might speed up sampling, but to no avail.

Below, I have provided some additional description of the data and the model I’m trying to create.


Data and Model Description:

The data I am working with is counts of galaxies, binned by two properties: mass and colour (colour is a continuous variable here).

I am trying to create a model that assumes each galaxy comes from one of two galaxy types (called blue and red). I understand that HMC does not work with discrete parameters, so instead I generate the galaxy type by sampling a uniform random variable and assigning the type blue or red depending on whether the sampled value falls below or above a threshold, which I am also learning.

For each galaxy type, I model the mass and colour as coming from independent Gaussian distributions, with priors on the means and variances, which are themselves learned. Once I have generated the masses and colours of the galaxies, I call a custom function to compute the counts of galaxies falling in each bin.

The only data I have are counts of galaxies, not actual labels. To specify the likelihood, I model the observed counts in each bin as normally distributed about the counts computed by my function. Maybe this is not the best way to model things, in which case I’d appreciate feedback.


My model is specified below. I have also attached the CSV containing the binned galaxy counts. Any help with making the sampling faster, or with improving the model in general, would be really appreciated.


functions {
    // Bin each galaxy by (mass, colour) and return per-bin counts, with bin
    // (i, j) stored at index (i-1)*N_colour_bins + j.
    vector group_by_bins(int N_galaxies, int N_counts, int N_mass_bins, int N_colour_bins, vector mass, vector colour, vector mass_bins, vector colour_bins) {
        vector[N_counts] counts = rep_vector(0.0, N_counts);
        for (k in 1:N_galaxies) {
            real logmstar = mass[k];
            real uminusr = colour[k];
            int found = 0;
            int i = 1;
            // Search runs up to and including the last bin edge so that
            // galaxies falling in the final mass or colour bin are also counted.
            while (i <= N_mass_bins && found == 0) {
                int j = 1;
                while (j <= N_colour_bins && found == 0) {
                    if (   (mass_bins[i] <= logmstar)
                        && (colour_bins[j] <= uminusr)
                        && (i == N_mass_bins || mass_bins[i+1] > logmstar)
                        && (j == N_colour_bins || colour_bins[j+1] > uminusr)) {
                        counts[(i-1)*N_colour_bins + j] += 1;
                        found = 1;
                    }
                    j += 1;
                }
                i += 1;
            }
        }
        return counts;
    }
}

data {
    int<lower=0> N_galaxies;
    int<lower=0> N_mass_bins;
    int<lower=0> N_colour_bins;
    vector[N_mass_bins] mass_bins;
    vector[N_colour_bins] colour_bins;
    vector[N_mass_bins*N_colour_bins] count_values;
}

transformed data {
    real rho = 0.5;
    int N_counts = N_mass_bins*N_colour_bins;

    real mass_mu_blue = 9.0;
    real mass_mu_red = 10.4;
    real colour_mu_blue = 1.0;
    real colour_mu_red = 2.0;
    real mu_sigma = 0.1;

    //corresponding mean = 0.66, sigma = 0.01
    real mass_sigma_blue_a = 4446.44;
    real mass_sigma_blue_b = 2963.63;

    //corresponding mean = 0.33, sigma = 0.01
    real mass_sigma_red_a = 1113.11;
    real mass_sigma_red_b = 370.70;

    //corresponding mean = 0.25, sigma = 0.01
    real colour_sigma_blue_a = 627.0;
    real colour_sigma_blue_b = 156.5;

    //corresponding mean = 0.16, sigma = 0.01
    real colour_sigma_red_a = 279.77;
    real colour_sigma_red_b = 46.46;

    real sigma = 1;
}

parameters {
    real<lower=0, upper=1> pi;
    vector<lower=0, upper=1>[N_galaxies] gs;
    //these constraints are based on the dataset I have, maybe they're not necessary.
    vector<lower=9, upper=12>[N_galaxies] mass;
    //same for these
    vector<lower=-0.184, upper=3>[N_galaxies] colour;
    //same for these
    vector<lower=9, upper=12>[2] mu_mass;
    vector<lower=0>[2] sigma_mass;
    //same for these
    vector<lower=-0.184, upper=3>[2] mu_colour;
    vector<lower=0>[2] sigma_colour;
}


model {
    int gal_type[N_galaxies];
    vector[N_counts] counts;
    pi ~ beta(2, 2);
    gs ~ uniform(0, 1);

    for (i in 1:N_galaxies) {
        gal_type[i] = gs[i] <= pi ? 1 : 2; //1 is blue, 2 is red
    }

    mu_mass[1] ~ normal(mass_mu_blue, mu_sigma);
    mu_mass[2] ~ normal(mass_mu_red, mu_sigma);
    mu_colour[1] ~ normal(colour_mu_blue, mu_sigma);
    mu_colour[2] ~ normal(colour_mu_red, mu_sigma);

    sigma_mass[1] ~ inv_gamma(mass_sigma_blue_a, mass_sigma_blue_b);
    sigma_mass[2] ~ inv_gamma(mass_sigma_red_a, mass_sigma_red_b);
    sigma_colour[1] ~ inv_gamma(colour_sigma_blue_a, colour_sigma_blue_b);
    sigma_colour[2] ~ inv_gamma(colour_sigma_red_a, colour_sigma_red_b);

    mass ~ normal(mu_mass[gal_type], sigma_mass[gal_type]);
    colour ~ normal(mu_colour[gal_type], sigma_colour[gal_type]);

    counts = group_by_bins(N_galaxies, N_counts, N_mass_bins, N_colour_bins, mass, colour, mass_bins, colour_bins);
    count_values ~ normal(counts, sigma);
}

counts.csv (20.6 KB)

Hi,

I can’t comment on the modeling, as I’m not familiar with the subject.

But on the Stan side, you can move the counts definition to the transformed parameters block, which might make this faster:

transformed parameters {
    vector[N_counts] counts;
    counts = group_by_bins(N_galaxies, N_counts, N_mass_bins, N_colour_bins, mass, colour, mass_bins, colour_bins);
}

All those if statements with respect to parameters kill the performance of Stan!

I do not really have time to go into detail, but the model you describe sounds like a mixture model, yet I do not see any marginalisation happening. In case I am totally off, then sorry for that… otherwise, check the manuals for mixture modelling and how to marginalise.
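In case it helps, here is a rough sketch of what the marginalisation could look like for your two components (just a sketch, assuming mass and colour are conditionally independent given the type; it would replace your gs/gal_type construction, not the rest of the model):

model {
    // Marginalise over the latent galaxy type instead of thresholding a
    // uniform draw; pi is the mixing proportion of the blue (first) component.
    for (k in 1:N_galaxies) {
        target += log_mix(pi,
                          normal_lpdf(mass[k] | mu_mass[1], sigma_mass[1])
                            + normal_lpdf(colour[k] | mu_colour[1], sigma_colour[1]),
                          normal_lpdf(mass[k] | mu_mass[2], sigma_mass[2])
                            + normal_lpdf(colour[k] | mu_colour[2], sigma_colour[2]));
    }
}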

I would definitely second the suggestion to take a look at the user guide sections on mixture models and latent discrete parameters. You should be able to find more examples by searching for those terms on Discourse too.

From what I understand, you have a number of galaxies with two traits (color, mass); these traits have been measured on an ordinal scale but have underlying continuous values. If this is the case, I think you may be able to represent this as a pair of ordinal logistic regressions (or some similar ordinal response model); there’s some information in the manual here, and you can find more examples on the Discourse; the brms documentation also has some good discussion of ordinal models.
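Something along these lines (purely a sketch with made-up names; it assumes you have, or can expand the counts into, per-galaxy bin indices, and only shows the mass dimension):

data {
    int<lower=0> N;                      // number of galaxies
    int<lower=2> K;                      // number of mass bins
    int<lower=1, upper=K> mass_bin[N];   // observed mass-bin index per galaxy
}
parameters {
    real eta;                   // latent location on the mass scale
    ordered[K - 1] cutpoints;   // bin boundaries on the latent scale
}
model {
    eta ~ normal(0, 5);         // placeholder prior; cutpoint priors omitted
    for (n in 1:N)
        mass_bin[n] ~ ordered_logistic(eta, cutpoints);
}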

Finally, it may be a good idea to look at the prior choice recommendation wiki; in particular, the inverse gamma priors on sigma are not recommended.
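For example, a half-normal prior on the scales is a commonly suggested alternative (just a sketch; the scale of 1 is a placeholder, not a recommendation for your data):

// With the <lower=0> constraint already declared on these parameters,
// a normal prior is effectively half-normal.
sigma_mass ~ normal(0, 1);
sigma_colour ~ normal(0, 1);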

Welcome to the Stan community.


Thank you @ahartikainen, @wds15 and @Christopher-Peterson for your comments and suggestions! I will work on incorporating them and will update here on how they fared.

I did come up with a way to remove the loops and the if statement, but it involves 1) using the ceil function to find the index of the closest bin, and 2) converting the output of ceil from a real value to an integer so that I can use it as an array index (I used binary search to find the closest integer, as described in some other posts on Discourse). I understand from other posts and from the Stan User’s Guide that both of these steps hurt the efficiency of NUTS. I implemented these changes on my end to see if anything improved, and got only a marginal performance improvement.
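For reference, a sketch of the kind of helper I mean (hypothetical function name; it follows the binary-search pattern described in those posts and assumes x lies in [min_val, max_val]):

functions {
    // Return the smallest integer i in [min_val, max_val] with i >= x.
    // Applied to the output of ceil(), this recovers the bin index as an int.
    int real_to_int(real x, int min_val, int max_val) {
        int lo = min_val;
        int hi = max_val;
        while (lo < hi) {
            int mid = (lo + hi) / 2;  // integer division
            if (mid < x) {
                lo = mid + 1;
            } else {
                hi = mid;
            }
        }
        return lo;
    }
}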

Re the suggestions you guys made,

  1. The transformed parameters change is easy to try, so I’ll do that first.

  2. @Christopher-Peterson your understanding is correct. I will read the user guide section on mixture models and the other links you suggested. Thanks also for pointing me to the prior choice recommendations wiki; I look forward to understanding why the IG prior is not recommended.

I just want to mention that this question/topic can be closed, though I’m not sure how exactly to do that.

As for what worked, I was actually able to obtain additional data that avoids the need to have the binning code in Stan, so the above issues no longer bug me. The suggestion to use mixture models was still very helpful. The resulting model is still slow due to the presence of a for loop, but as per the documentation that is unavoidable with mixture models.