I want to start by acknowledging that I am very new to Stan and to statistical modeling in general, and I would appreciate any feedback, on this model or otherwise.
Problem
As part of my Stan model specification, I wrote a custom function that bins data in 2D and computes counts. It currently uses three nested loops and is very slow, and I don’t know how to vectorize it or make it faster in any other way.
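One possible restructuring, given only as a minimal untested sketch: a galaxy's mass bin and colour bin can be located independently, so the search over (i, j) pairs can become two single scans per galaxy. That is at most N_mass_bins + N_colour_bins comparisons per galaxy rather than up to roughly N_mass_bins * N_colour_bins. The names find_bin and group_by_bins_two_scans are hypothetical. The sketch assumes each bin vector holds sorted lower edges and that values at or above the last edge fall into the last bin, which differs from loop bounds that stop before the last index.

```stan
functions {
  // Hypothetical helper: index of the last lower edge that is <= x,
  // or 0 if x lies below the first edge. Assumes edges is sorted ascending.
  int find_bin(real x, vector edges) {
    int idx = 0;
    for (b in 1:num_elements(edges)) {
      if (edges[b] > x) {
        break;  // every later edge is larger as well
      }
      idx = b;
    }
    return idx;
  }

  // Flat layout: counts[(i - 1) * N_colour_bins + j].
  vector group_by_bins_two_scans(vector mass, vector colour,
                                 vector mass_bins, vector colour_bins) {
    int N_mass_bins = num_elements(mass_bins);
    int N_colour_bins = num_elements(colour_bins);
    vector[N_mass_bins * N_colour_bins] counts
      = rep_vector(0.0, N_mass_bins * N_colour_bins);
    for (k in 1:num_elements(mass)) {
      int i = find_bin(mass[k], mass_bins);
      int j = find_bin(colour[k], colour_bins);
      // Galaxies below the first edge on either axis are not counted.
      if (i > 0 && j > 0) {
        counts[(i - 1) * N_colour_bins + j] += 1;
      }
    }
    return counts;
  }
}
```

This only reduces the number of comparisons per galaxy; the loop over all galaxies remains.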
The dataset is generated from about 75,000 data points, with two bin vectors (mass and colour) of length ~35 each. The gradient computation is extremely slow, as you can see in these messages:
Aug 12, 2020, 5:44:43 PM WARNING Gradient evaluation took 0.320882 seconds
Aug 12, 2020, 5:44:43 PM WARNING 1000 transitions using 10 leapfrog steps per transition would take 3208.82 seconds.
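The projected total in the second warning is the per-gradient time scaled by the number of leapfrog steps and transitions:

$$
0.320882\ \text{s} \times 10 \times 1000 = 3208.82\ \text{s} \approx 53.5\ \text{min}
$$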
I also want to note that I have tried specifying initial values for some of the parameters in the hope that this would speed up sampling, but to no avail.
Below, I have provided some additional description of the data and the model I’m trying to create.
Data and Model Description:
The data I am working with are counts of galaxies, binned by two properties: mass and colour (colour is a continuous variable here).
I am trying to create a model that assumes each galaxy comes from one of two galaxy types (called blue and red). I understand that HMC does not work with discrete parameters, so instead I generate the galaxy type by sampling a uniform random variable and assigning the type blue or red depending on whether the sampled value is below or above a threshold, which is itself a parameter that I am learning.
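For comparison, a commonly used alternative to thresholding a uniform draw is to sum the discrete type out of the density, which Stan supports for two components through log_mix. The following is a minimal sketch of that idea, not the model in this post: it reuses the parameter names and bounds from the model specification, reads pi as the probability that a galaxy is blue, drops gs, and leaves out the remaining priors and the binned-count likelihood.

```stan
data {
  int<lower=0> N_galaxies;
}
parameters {
  real<lower=0, upper=1> pi;               // probability of type 1 (blue)
  vector<lower=9, upper=12>[N_galaxies] mass;
  vector<lower=-0.184, upper=3>[N_galaxies] colour;
  vector<lower=9, upper=12>[2] mu_mass;
  vector<lower=0>[2] sigma_mass;
  vector<lower=-0.184, upper=3>[2] mu_colour;
  vector<lower=0>[2] sigma_colour;
}
model {
  pi ~ beta(2, 2);
  // Priors on mu_* and sigma_* and the count likelihood are omitted here.
  for (n in 1:N_galaxies) {
    // log(pi * p_blue(mass, colour) + (1 - pi) * p_red(mass, colour))
    target += log_mix(pi,
                      normal_lpdf(mass[n] | mu_mass[1], sigma_mass[1])
                      + normal_lpdf(colour[n] | mu_colour[1], sigma_colour[1]),
                      normal_lpdf(mass[n] | mu_mass[2], sigma_mass[2])
                      + normal_lpdf(colour[n] | mu_colour[2], sigma_colour[2]));
  }
}
```

With this form the density is a smooth function of pi, whereas a hard threshold makes the type assignment change in jumps.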
For each galaxy type, I am modeling the mass and colour as coming from independent Gaussian distributions, with priors on the means and standard deviations, which are themselves being learned. Once I generate the masses and colours of the galaxies, I call a custom function to calculate the number of galaxies falling in each bin.
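The inverse-gamma constants in the transformed data block are consistent with matching a target mean $m$ and standard deviation $s$. For $\sigma \sim \text{InvGamma}(a, b)$ with $a > 2$:

$$
\mathbb{E}[\sigma] = \frac{b}{a - 1}, \qquad
\operatorname{Var}[\sigma] = \frac{b^2}{(a - 1)^2 (a - 2)}
\quad\Longrightarrow\quad
a = \left(\frac{m}{s}\right)^2 + 2, \qquad b = m\,(a - 1).
$$

For example, $m = 2/3$ and $s = 0.01$ give $a = 4444.44 + 2 = 4446.44$ and $b = \tfrac{2}{3} \times 4445.44 \approx 2963.63$, which are the values used for the blue mass scale.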
The only data I have are the counts of galaxies, not the actual labels. To specify the likelihood, I model the observed count in each bin as normally distributed about the count computed by my function. Maybe this is not the best way to model things, in which case I would appreciate feedback.
My model is specified below. I have also attached the CSV containing binned galaxy counts. Any help on making the sampling faster, or improving the model in general, would be really helpful.
functions {
  vector group_by_bins(int N_galaxies, int N_counts, int N_mass_bins, int N_colour_bins, vector mass, vector colour, vector mass_bins, vector colour_bins) {
    vector[N_counts] counts;
    counts = rep_vector(0.0, N_counts);
    for (k in 1:N_galaxies) {
      real logmstar = mass[k];
      real uminusr = colour[k];
      int i;
      int found;
      int j;
      i = 1;
      found = 0;
      while (i < N_mass_bins && found == 0) {
        j = 1;
        while (j < N_colour_bins && found == 0) {
          if ( (mass_bins[i] <= logmstar)
               && (colour_bins[j] <= uminusr)
               && (i == N_mass_bins || mass_bins[i+1] > logmstar)
               && (j == N_colour_bins || colour_bins[j+1] > uminusr)) {
            counts[(i-1)*N_colour_bins + j] += 1;
            found = 1;
          }
          j += 1;
        }
        i += 1;
      }
    }
    return counts;
  }
}
data {
  int<lower=0> N_galaxies;
  int<lower=0> N_mass_bins;
  int<lower=0> N_colour_bins;
  vector[N_mass_bins] mass_bins;
  vector[N_colour_bins] colour_bins;
  vector[N_mass_bins*N_colour_bins] count_values;
}
transformed data {
  real rho = 0.5;
  int N_counts = N_mass_bins*N_colour_bins;
  real mass_mu_blue = 9.0;
  real mass_mu_red = 10.4;
  real colour_mu_blue = 1.0;
  real colour_mu_red = 2.0;
  real mu_sigma = 0.1;
  //corresponding mean = 0.66, sigma = 0.01
  real mass_sigma_blue_a = 4446.44;
  real mass_sigma_blue_b = 2963.63;
  //corresponding mean = 0.33, sigma = 0.01
  real mass_sigma_red_a = 1113.11;
  real mass_sigma_red_b = 370.70;
  //corresponding mean = 0.25, sigma = 0.01
  real colour_sigma_blue_a = 627.0;
  real colour_sigma_blue_b = 156.5;
  //corresponding mean = 0.16, sigma = 0.01
  real colour_sigma_red_a = 279.77;
  real colour_sigma_red_b = 46.46;
  real sigma = 1;
}
parameters {
  real<lower=0, upper=1> pi;
  vector<lower=0, upper=1>[N_galaxies] gs;
  //these constraints are based on the dataset I have, maybe they're not necessary.
  vector<lower=9, upper=12>[N_galaxies] mass;
  //same for these
  vector<lower=-0.184, upper=3>[N_galaxies] colour;
  //same for these
  vector<lower=9, upper=12>[2] mu_mass;
  vector<lower=0>[2] sigma_mass;
  //same for these
  vector<lower=-0.184, upper=3>[2] mu_colour;
  vector<lower=0>[2] sigma_colour;
}
model {
  int gal_type[N_galaxies];
  vector[N_counts] counts;
  pi ~ beta(2, 2);
  gs ~ uniform(0, 1);
  for (i in 1:N_galaxies) {
    gal_type[i] = gs[i] <= pi ? 1 : 2; //1 is blue, 2 is red
  }
  mu_mass[1] ~ normal(mass_mu_blue, mu_sigma);
  mu_mass[2] ~ normal(mass_mu_red, mu_sigma);
  mu_colour[1] ~ normal(colour_mu_blue, mu_sigma);
  mu_colour[2] ~ normal(colour_mu_red, mu_sigma);
  sigma_mass[1] ~ inv_gamma(mass_sigma_blue_a, mass_sigma_blue_b);
  sigma_mass[2] ~ inv_gamma(mass_sigma_red_a, mass_sigma_red_b);
  sigma_colour[1] ~ inv_gamma(colour_sigma_blue_a, colour_sigma_blue_b);
  sigma_colour[2] ~ inv_gamma(colour_sigma_red_a, colour_sigma_red_b);
  mass ~ normal(mu_mass[gal_type], sigma_mass[gal_type]);
  colour ~ normal(mu_colour[gal_type], sigma_colour[gal_type]);
  counts = group_by_bins(N_galaxies, N_counts, N_mass_bins, N_colour_bins, mass, colour, mass_bins, colour_bins);
  count_values ~ normal(counts, sigma);
}
counts.csv (20.6 KB)