Failing to marginalize a linear regression model with one change point

Hi,

I have a linear regression model with one change point, i.e. a regression model made of two straight lines, where the position of the change point is a latent variable.

I originally coded it in JAGS and it worked. However, in Stan I cannot give the latent change point a categorical prior (i.e. assume the change point is one of the data points), because Stan does not support discrete parameters, so I need to marginalize the change point out.
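
My understanding of the marginalization is that, with a uniform prior over the N candidate change points, tau is summed out of the likelihood:

$$
\log p(y \mid a_1, a_2, s) = \log \sum_{\tau=1}^{N} \frac{1}{N} \prod_{n=1}^{N} \text{DoubleExponential}\big(y_n \mid \log \mu_n(\tau),\, s\big),
$$

where $\mu_n(\tau) = a_2 X_n + b_2$ for $n \le \tau$ and $\mu_n(\tau) = a_1 X_n + b_1(\tau)$ for $n > \tau$. This is what target += log_sum_exp(lp) is meant to compute in the code below.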

I read the Stan documentation (the change point model section) and tried to translate the model into Stan using marginalization, but the marginalized model always gives two lines with similar slopes. I also tried to cheat by treating the change point as a continuous, uniformly distributed parameter. That works in Stan, but it is slow: JAGS takes 25 seconds to generate 5,000 samples × 3 chains when fitting 8 data sets with sample sizes of 10-20, while the same number of samples takes 60+ seconds in Stan. So I really want to get the marginalization working for efficiency.
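
For reference, the continuous-uniform version looks roughly like this (a sketch rather than the exact code I ran; cp is the continuous change point, and the data block is the same as below):

parameters {
  real<lower=0,upper=xmax> cp;            // continuous change point (replaces the discrete tau)
  real<lower=0.05,upper=0.2> s;
  real<lower=0> a1;
  real<lower=0> a2;
}
model {
  real b2 = ymax - a2 * xmax;
  real b1 = b2 + (a2 - a1) * cp;          // keeps the two lines continuous at cp
  s ~ normal(0.125, 2.24) T[0.05, 0.2];
  a1 ~ lognormal(-0.01, 0.82);
  a2 ~ lognormal(0.02, 0.82);
  for (n in 1:N) {
    real mu = X[n] <= cp ? a2 * X[n] + b2 : a1 * X[n] + b1;
    y[n] ~ double_exponential(log(fmax(mu, 1e-5)), s);   // fmax guards log(), as in the JAGS max(mu, 1e-5)
  }
}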

Once I get the marginalization to work, I will try the dynamic programming approach suggested in the Stan documentation to improve efficiency.
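
Concretely, the saving I can see is that the terms for points up to the change point use a2*X + b2, which does not depend on tau, so their running sum can be precomputed once; the terms after the change point still have to be recomputed for each tau because b1 depends on X[tau]. An untested sketch of that transformed parameters block (reusing the same data, transformed data, and parameters blocks as below):

transformed parameters {
  real b2 = ymax - a2 * xmax;
  vector[N] lp = rep_vector(log_unif, N);
  {
    vector[N+1] lp_head;                       // lp_head[n+1] = running log lik of y[1:n] under the first line
    lp_head[1] = 0;
    for (n in 1:N)
      lp_head[n+1] = lp_head[n]
        + double_exponential_lpdf(y[n] | log(fmax(a2 * X[n] + b2, 1e-5)), s);
    for (tau in 1:N) {
      real b1_tau = b2 + (a2 - a1) * X[tau];   // second-line intercept for this candidate tau
      lp[tau] += lp_head[tau+1];               // all points up to and including tau in one step
      for (n in (tau+1):N)                     // points after tau still depend on tau through b1_tau
        lp[tau] += double_exponential_lpdf(y[n] | log(fmax(a1 * X[n] + b1_tau, 1e-5)), s);
    }
  }
}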

/*
 * normal regression example with one change point
 */

data {
  real<lower=0> ymax;
  real<lower=0> xmax;
  int<lower=1> N;                         // the number of students
  vector<lower=0,upper=50>[N] y;          // the exam mark
  vector<lower=0,upper=50>[N] X;          // the school assessment mark
}
transformed data {
  real log_unif;
  log_unif = -log(N);
}
parameters {
  real<lower=0.05,upper=0.2> s;           // scale parameter for the Laplace (double exponential) distribution
  real<lower=0> a1;                       // slope of the line for points after the change point (X[(tau+1):N])
  real<lower=0> a2;                       // slope of the line for points up to the change point (X[1:tau])
}
transformed parameters {
  vector<lower=1e-20,upper=50>[N] linpred;
  real<lower=-ymax,upper=ymax> b1;        // intercept of the line after the change point
  real<lower=-ymax,upper=ymax> b2;        // intercept of the line up to the change point
  vector[N] lp;
  b2 = ymax - a2 * xmax;
  lp = rep_vector(log_unif, N);
  for (tau in 1:N) {
    b1 = b2 + (a2 - a1) * X[tau];         // keeps the two lines continuous at X[tau]
    print("b1: ", b1, " b2: ", b2, " a1: ", a1, " a2: ", a2, " change point: ", X[tau]);
    linpred[1:tau] = a2 * X[1:tau] + b2;
    linpred[(tau+1):N] = a1 * X[(tau+1):N] + b1;
    print("X[1]: ", X[1], " linpred[1]: ", linpred[1], " xmax: ", xmax, " ymax: ", ymax);
    for (n in 1:N)                        // include all N observations, as in the JAGS model
      lp[tau] = lp[tau] + double_exponential_lpdf(y[n] | log(linpred[n]), s);
      //+ double_exponential_lpdf(linpred[n] | n < tau ? e : l, s);
  }
}

model {
  s ~ normal(0.125, 2.24) T[0.05, 0.2];
  a1 ~ lognormal(-0.01, 0.82);
  a2 ~ lognormal(0.02, 0.82);
  target += log_sum_exp(lp);
}
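
Once the marginalization works, I would also like to recover the change point itself; my understanding is that a generated quantities block along these lines (untested) would draw it from its posterior, since lp holds the per-tau log weights:

generated quantities {
  int<lower=1,upper=N> tau_draw;          // posterior draw of the change point index
  tau_draw = categorical_rng(softmax(lp));
}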

The JAGS code:

"
model {
  tau ~ dcat(xs[])                  # the change point
  s ~ dnorm(0.125, 0.2) I(0.05, 0.2)
  a1 ~ dlnorm(-0.01, 1.5)
  a2 ~ dlnorm(0.02, 1.5)
  b2 <- ymax - a2 * xmax
  b1 <- b2 + (a2 - a1) * x[tau]

  for (i in 1:N) {
    xs[i] <- 1/N                    # every x[i] has equal prior probability of being the change point

    mu[i] <- step(i - tau - 1) * (a1 * x[i] + b1) +
             step(tau - i) * (a2 * x[i] + b2)

    mu2[i] <- log(max(mu[i], 0.00001))
    y[i] ~ ddexp(mu2[i], tau2)
  }
  tau2 <- 1/s
}
    "