How can I improve the speed of a linear regression Stan model with one break point?

#1

Hi, any suggestions on how to improve the speed would be appreciated.

I have the following simple model, which I have been sampling with Stan, JAGS and BUGS. It is essentially a linear regression with one breakpoint, i.e., a regression using two connected straight lines.

Fitting the same model (5000 samples × 3 chains) to eight data sets takes 25 seconds in total in JAGS, 50 seconds in BUGS and 65 seconds in Stan. This surprises me, because Stan is a newer language and I expected it to be faster. Speed is very important to me: I have more than 20,000 such data sets to fit in a short period.
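To put the quoted timings in perspective, here is a back-of-the-envelope projection for the full job. It is a sketch that assumes the time per data set stays constant and that data sets are fitted one after another (no parallelism across data sets); the constant and function names are illustrative only.

```python
# Rough projection of total wall-clock time for the full job, based on the
# benchmark quoted above (8 data sets, 5000 samples x 3 chains each).
# Assumption: time per data set is constant and fits run sequentially.

N_BENCHMARK_DATASETS = 8
N_TARGET_DATASETS = 20_000

# Total seconds for the 8-data-set benchmark, per program (from the post).
benchmark_seconds = {"JAGS": 25.0, "BUGS": 50.0, "Stan": 65.0}


def projected_hours(total_seconds: float,
                    n_benchmark: int = N_BENCHMARK_DATASETS,
                    n_target: int = N_TARGET_DATASETS) -> float:
    """Scale the benchmark time linearly to n_target data sets, in hours."""
    per_dataset = total_seconds / n_benchmark
    return per_dataset * n_target / 3600.0


if __name__ == "__main__":
    for program, seconds in benchmark_seconds.items():
        print(f"{program}: {seconds / N_BENCHMARK_DATASETS:.2f} s per data set, "
              f"~{projected_hours(seconds):.1f} h for {N_TARGET_DATASETS} data sets")
```

Under these assumptions Stan comes out at roughly 45 hours and JAGS at roughly 17 hours for 20,000 data sets, which is why the comparison should be made per effective sample rather than per raw draw.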

Alternatively, is it the case that 5000 samples in Stan are worth far more than 5000 samples in BUGS? If so, which statistic should I look at to make sure the models have the same effective number of samples?

```stan
/*
 * Normal regression example with change point
 */

data {
  real<lower=0> ymax;
  real<lower=0> xmax;
  int<lower=1> N;                    // the number of students
  vector<lower=0, upper=50>[N] y;    // the exam mark
  vector<lower=0, upper=50>[N] X;    // the school assessment mark
}

parameters {
  real<lower=0.05, upper=0.2> s;     // scale parameter for the Laplace distribution
  real<lower=0> a1;                  // slope before the breakpoint
  real<lower=0> a2;                  // slope after the breakpoint
  real<lower=0, upper=xmax> bp;      // the breakpoint on the X scale, with some constraints
}

transformed parameters {
  vector<lower=0.00000000000000000001, upper=50>[N] linpred;
  real<lower=-ymax, upper=ymax> b1;  // intercept before the breakpoint
  real<lower=-ymax, upper=ymax> b2;  // intercept after the breakpoint

  b2 = ymax - a2 * xmax;
  b1 = b2 + (a2 - a1) * bp;

  for (i in 1:N) {
    if (X[i] < bp) {
      linpred[i] = fmax(a1 * X[i] + b1, 0.00000000000000000001);
    } else {
      linpred[i] = fmax(a2 * X[i] + b2, 0.00000000000000000001);
    }
  }
}

model {
  s ~ normal(0.125, 2.24) T[0.05, 0.2];
  a1 ~ lognormal(-0.01, 0.82);
  a2 ~ lognormal(0.02, 0.82);
  y ~ double_exponential(log(linpred), s);
}
```
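The two intercepts in the `transformed parameters` block are what make the two lines connect. Setting b2 = ymax − a2·xmax forces the second line through the point (xmax, ymax), and b1 = b2 + (a2 − a1)·bp gives a1·bp + b1 = a1·bp + b2 + (a2 − a1)·bp = a2·bp + b2, so both lines take the same value at the breakpoint. The Python sketch below mirrors that block so the continuity can be checked numerically. It is not part of the original post; the function names are illustrative, and the `floor` argument stands in for the tiny constant passed to `fmax` in the Stan code.

```python
def intercepts(a1: float, a2: float, bp: float, xmax: float, ymax: float):
    """Mirror of the Stan transformed parameters: returns (b1, b2)."""
    b2 = ymax - a2 * xmax          # line 2 passes through (xmax, ymax)
    b1 = b2 + (a2 - a1) * bp       # line 1 meets line 2 at x = bp
    return b1, b2


def linpred(x: float, a1: float, a2: float, bp: float,
            xmax: float, ymax: float, floor: float = 1e-20) -> float:
    """Piecewise-linear mean, floored at a tiny positive value like fmax()."""
    b1, b2 = intercepts(a1, a2, bp, xmax, ymax)
    value = a1 * x + b1 if x < bp else a2 * x + b2
    return max(value, floor)


def continuity_gap(a1: float, a2: float, bp: float,
                   xmax: float, ymax: float) -> float:
    """Difference between the two lines evaluated at the breakpoint (should be ~0)."""
    b1, b2 = intercepts(a1, a2, bp, xmax, ymax)
    return (a1 * bp + b1) - (a2 * bp + b2)


if __name__ == "__main__":
    # Illustrative parameter values, not fitted estimates.
    print(intercepts(0.5, 1.2, 20.0, 50.0, 50.0))
    print(continuity_gap(0.5, 1.2, 20.0, 50.0, 50.0))
```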

#2

It always seems dangerous to make absolute statements, but I'll go with yes: 5000 samples from Stan are probably worth more than 5000 samples from BUGS.

You want to check the n_eff (effective sample size) diagnostic (see https://mc-stan.org/users/documentation/case-studies/rstan_workflow.html or https://mc-stan.org/users/documentation/case-studies/pystan_workflow.html).

It estimates how many independent draws your autocorrelated MCMC draws are worth. BUGS and JAGS probably report this number as well (though all three programs probably estimate it differently).
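To make the idea concrete, here is a simplified single-chain effective sample size estimate using Geyer's initial positive sequence. It is an illustrative sketch only: Stan's own n_eff combines several chains and uses split (and, in recent versions, rank-normalized) draws, so its numbers will differ. The AR(1) generator is a hypothetical stand-in for real MCMC output.

```python
import random


def effective_sample_size(draws):
    """Simplified single-chain ESS via Geyer's initial positive sequence.

    Illustrative only; not Stan's exact algorithm (no chain splitting,
    no multi-chain combination, no rank normalization).
    """
    n = len(draws)
    mean = sum(draws) / n
    centred = [d - mean for d in draws]
    var = sum(c * c for c in centred) / n
    if var == 0.0:
        raise ValueError("chain is constant; ESS is undefined")

    def rho(lag):
        # Sample autocorrelation at `lag` (biased estimator, divides by n).
        return sum(centred[t] * centred[t + lag] for t in range(n - lag)) / (n * var)

    # Sum consecutive pairs rho(2k) + rho(2k+1) until the first non-positive pair.
    total = 0.0
    k = 0
    while 2 * k + 1 < n:
        pair = rho(2 * k) + rho(2 * k + 1)
        if pair <= 0.0:
            break
        total += pair
        k += 1

    tau = -1.0 + 2.0 * total  # integrated autocorrelation time
    if tau <= 0.0:
        return float(n)  # degenerate (strongly anticorrelated) case
    return n / tau


def ar1_chain(n, phi, seed=0):
    """Hypothetical stand-in for MCMC output: x_t = phi * x_{t-1} + N(0, 1) noise."""
    rng = random.Random(seed)
    x = 0.0
    out = []
    for _ in range(n):
        x = phi * x + rng.gauss(0.0, 1.0)
        out.append(x)
    return out


if __name__ == "__main__":
    print(effective_sample_size(ar1_chain(20000, 0.0, seed=1)))
    print(effective_sample_size(ar1_chain(20000, 0.9, seed=2)))
```

For an AR(1) chain with phi = 0.9 the theoretical autocorrelation time is (1 + phi)/(1 − phi) = 19, so 20,000 draws are worth only about 1,050 independent ones. Dividing each program's n_eff by its run time (effective samples per second) gives a fairer speed comparison between Stan, JAGS and BUGS than raw draws per second.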


#3

Whoops, I missed this: there's no option for this in Stan right now. Just try to estimate the smallest number of samples that is enough for your needs and roll with that. Add a conditional to your post-processing scripts that warns you if a fit drops below your threshold.
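A minimal sketch of that workflow, under stated assumptions: the number of draws for the production runs is scaled from a small pilot run, assuming the effective sample size grows roughly linearly with the number of draws (reasonable for a chain that mixes well), and a post-processing check flags any data set/parameter whose ESS falls below a chosen threshold. The function names, the safety factor and the threshold are all hypothetical; the ESS values would come from whatever summary your interface reports (e.g. n_eff).

```python
import math


def draws_needed(pilot_draws: int, pilot_ess: float, target_ess: float,
                 safety: float = 1.5) -> int:
    """Draws to request so that ESS should reach target_ess (with a safety margin).

    Assumption: ESS is roughly proportional to the number of draws.
    """
    if pilot_draws <= 0 or pilot_ess <= 0:
        raise ValueError("pilot draws and pilot ESS must be positive")
    ess_per_draw = pilot_ess / pilot_draws
    return math.ceil(safety * target_ess / ess_per_draw)


def below_threshold(ess_by_dataset: dict, threshold: float) -> list:
    """Return (dataset, parameter, ess) for every parameter whose ESS fell short."""
    flagged = []
    for dataset, ess_by_param in ess_by_dataset.items():
        for param, ess in ess_by_param.items():
            if ess < threshold:
                flagged.append((dataset, param, ess))
    return flagged


if __name__ == "__main__":
    # Hypothetical numbers for illustration only.
    print(draws_needed(pilot_draws=1000, pilot_ess=250.0, target_ess=400.0))
    summaries = {"dataset_001": {"a1": 812.0, "a2": 760.0, "bp": 310.0}}
    for dataset, param, ess in below_threshold(summaries, threshold=400.0):
        print(f"WARNING: {dataset}: ESS for {param} is {ess:.0f}")
```

Checking the smallest ESS across all parameters (here the breakpoint `bp` would be the usual suspect) is what matters, since one poorly mixing parameter is enough to make a fit unreliable.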


#4

thx!
