Hi all,
New to Stan and could use some help. I have 96 threads available to me, but I can’t seem to specify reduce_sum in a way that speeds up my model. I’m too new to Stan to understand how I should properly use reduce_sum, so if there’s an answer I would appreciate as much detail as possible.
With N at a 10K sample, the original takes 15 minutes. With N at the full 900K, it takes about 72 hours.
Thanks in advance!
Original model spec
data {
int<lower=1> N;
matrix[N, 5] X;
vector[N] y;
int<lower=1> K1[1];
int<lower=1,upper=max(K1)> L1[1, N];
}
parameters {
real mu_a;
real mu_b1;
real mu_b2;
real log_sigma_a;
real log_sigma_b;
real theta_a;
real theta_b;
real theta_n;
real theta_n0;
real<lower=0> dmu_sc_tau;
vector[K1[1]] dmu_sc;
}
model {
vector[N] mu = mu_a + mu_b1 * X[:, 1] + mu_b2 * square(X[:, 1]) + dmu_sc_tau * dmu_sc[L1[1]] ;
vector[N] sigma = exp( (log_sigma_a - 3) + (log_sigma_b - 3) * X[:, 1] );
vector[N] theta = theta_a + theta_b * X[:,1] + theta_n * X[:,2] + theta_n0 * X[:,4] ;
vector[N] p = inv_logit(theta);
mu_a ~ normal(0, 0.3);
mu_b1 ~ normal(0, 1);
mu_b2 ~ normal(0, 1);
log_sigma_a ~ normal(0, 2);
log_sigma_b ~ normal(0, 6);
theta_a ~ normal(0, 1);
theta_b ~ normal(0, 3);
theta_n ~ normal(0, 6);
theta_n0 ~ normal(6, 6);
dmu_sc_tau ~ normal(0, 0.2);
dmu_sc ~ std_normal();
for (i in 1:N) {
if (y[i] == 0) {
target += log(p[i]);
} else {
target += log(1 - p[i]) + normal_lpdf(y[i] | mu[i], sigma[i]);
}
}
}
Attempt at reduce_sum, runs about 50% longer
functions {
real partial_sum(int[] dummy_index, int start, int end,
vector y, vector p, vector mu, vector sigma) {
real psum = 0;
for (i in start:end) {
if (y[i] == 0) {
psum += log(p[i]);
} else {
psum += log(1 - p[i]) + normal_lpdf(y[i] | mu[i], sigma[i]);
}
}
return psum;
}
}
data {
int<lower=1> N;
matrix[N, 5] X;
vector[N] y;
int<lower=1> grainsize;
int<lower=1> K1[1];
int<lower=1,upper=max(K1)> L1[1, N];
}
transformed data {
int dummy_index[N] = rep_array(1, N);
}
parameters {
real mu_a;
real mu_b1;
real mu_b2;
real log_sigma_a;
real log_sigma_b;
real theta_a;
real theta_b;
real theta_n;
real theta_n0;
real<lower=0> dmu_sc_tau;
vector[K1[1]] dmu_sc;
}
model {
vector[N] mu = mu_a + mu_b1 * X[:, 1] + mu_b2 * square(X[:, 1]) + dmu_sc_tau * dmu_sc[L1[1]] ;
vector[N] sigma = exp( (log_sigma_a - 3) + (log_sigma_b - 3) * X[:, 1] );
vector[N] theta = theta_a + theta_b * X[:,1] + theta_n * X[:,2] + theta_n0 * X[:,4] ;
vector[N] p = inv_logit(theta);
mu_a ~ normal(0, 0.3);
mu_b1 ~ normal(0, 1);
mu_b2 ~ normal(0, 1);
log_sigma_a ~ normal(0, 2);
log_sigma_b ~ normal(0, 6);
theta_a ~ normal(0, 1);
theta_b ~ normal(0, 3);
theta_n ~ normal(0, 6);
theta_n0 ~ normal(6, 6);
dmu_sc_tau ~ normal(0, 0.2);
dmu_sc ~ std_normal();
target += reduce_sum(partial_sum, dummy_index, grainsize, y, p, mu, sigma);
}