Dear Stan community,
I am experiencing issues fitting a mixture regression model with explanatory variables. I believe the problem is non-identifiability, which becomes more pronounced as model complexity increases. I have tried addressing it with non-exchangeable priors and with repulsive priors, but neither approach has yielded satisfactory results.
There doesn’t seem to be much discussion on mixture models with explanatory variables, so I am posting my problem and a detailed description here. I am relatively new to Stan and some of the mathematical terminology, so please feel free to correct any mistakes. Any suggestions and insights would be greatly appreciated!
Background information:
I am trying to fit a mixture regression model to plant functional traits such as leaf width, biomass, etc. The reason for using mixture regression is that some traits clearly have two or three forms (see the histogram below).
We suspect that each form reacts differently to environmental covariates. Therefore, I am trying to fit a mixture regression model composed of two lognormal distributions, each with its own set of coefficients for the environmental covariates.
Here is the data file LW_mix_2var.RData (1.4 KB). It contains a response variable (LW) and two explanatory variables (de, rei), in a data list that can be fed directly into the Stan code below.
Approaches:
I started from the easiest setting, a model without any explanatory variables, using an ordered mu to resolve the label degeneracy. This model gives nice estimates of mu, and the clustering proportion is reasonable (around 3:7).
data {
  int<lower=1> K;  // number of mixture components (clusters)
  int<lower=1> N;  // number of observations
  real LW[N];      // measured leaf width
}
parameters {
  simplex[K] theta;          // mixture proportions
  vector<lower=0>[K] sigma;  // dispersion parameters
  ordered[K] mu;             // ordered component means (breaks label switching)
}
model {
  // priors
  vector[K] log_theta = log(theta);  // cache log calculation
  sigma ~ lognormal(0, 2);
  mu ~ normal(0, 10);
  // likelihood: marginalize over component membership
  for (n in 1:N) {
    vector[K] lps = log_theta;
    for (k in 1:K)
      lps[k] += normal_lpdf(LW[n] | mu[k], sigma[k]);
    target += log_sum_exp(lps);
  }
}
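The natural next step seemed to be to keep an ordered vector for the intercepts once the covariates enter the mean. Here is a minimal sketch of that direction (the per-component slope vectors bde and brei are illustrative names, not part of my actual model; the data block is the same as in the full model further below):
parameters {
  simplex[K] theta;          // mixture proportions
  ordered[K] alpha;          // ordered intercepts, intended to break label switching
  vector[K] bde;             // per-component slope for de (illustrative)
  vector[K] brei;            // per-component slope for rei (illustrative)
  vector<lower=0>[K] sigma;  // dispersion parameters
}
model {
  alpha ~ normal(0, 1);
  bde ~ normal(0, 1);
  brei ~ normal(0, 1);
  sigma ~ lognormal(0, 1);
  for (n in 1:N) {
    vector[K] lps = log(theta);
    for (k in 1:K)
      lps[k] += lognormal_lpdf(LW[n] | alpha[k] + de[n] * bde[k] + rei[n] * brei[k], sigma[k]);
    target += log_sum_exp(lps);
  }
}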
However, I could not get this ordered-intercept version to work, at least not with my current skill level; presumably the ordering on the intercepts no longer pins down the components once each component has its own covariate effects. I therefore followed chapter 19 in this book and formed the model with two explanatory variables. Here, I use a non-exchangeable prior on the intercepts to break the label degeneracy:
target += normal_lpdf(alpha | 0, 1);
target += normal_lpdf(gamma | 0, 1)
          - normal_lcdf(alpha | 0, 1);
The full Stan code is here:
data {
  int<lower=1> K;  // number of mixture components (clusters)
  int<lower=1> N;  // number of observations
  vector[N] LW;    // measured leaf width
  vector[N] de;    // explanatory variable
  vector[N] rei;   // explanatory variable
}
parameters {
  real bde1;                // slope of de, component 1
  real bde2;                // slope of de, component 2
  real brei1;               // slope of rei, component 1
  real brei2;               // slope of rei, component 2
  real alpha;               // intercept, component 1
  real<upper=alpha> gamma;  // intercept, component 2 (constrained below alpha)
  simplex[K] theta;         // mixture proportions
  real<lower=0> sigma1;     // dispersion parameter, component 1
  real<lower=0> sigma2;     // dispersion parameter, component 2
}
model {
  // priors
  target += normal_lpdf(bde1 | 0, 1);
  target += normal_lpdf(bde2 | 0, 1);
  target += normal_lpdf(brei1 | 0, 1);
  target += normal_lpdf(brei2 | 0, 1);
  target += normal_lpdf(alpha | 0, 1);
  target += normal_lpdf(gamma | 0, 1)
            - normal_lcdf(alpha | 0, 1);  // truncation adjustment for gamma < alpha
  target += lognormal_lpdf(sigma1 | 0, 1) - normal_lccdf(0 | 0, 1);
  target += lognormal_lpdf(sigma2 | 0, 1) - normal_lccdf(0 | 0, 1);
  target += beta_lpdf(theta | 1, 1);  // uniform prior on the mixture proportions
  // likelihood: marginalize over the two components
  for (n in 1:N) {
    target += log_sum_exp(log(theta[1])
                            + lognormal_lpdf(LW[n] | alpha + de[n] * bde1 + rei[n] * brei1, sigma1),
                          log(theta[2])
                            + lognormal_lpdf(LW[n] | gamma + de[n] * bde2 + rei[n] * brei2, sigma2));
  }
}
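To inspect the classification itself, I imagine a generated quantities block along these lines could be appended to recover per-observation membership probabilities (a sketch, reusing the names from the model above):
generated quantities {
  // posterior probability that each observation belongs to component 1
  vector[N] prob_comp1;
  for (n in 1:N) {
    real lp1 = log(theta[1]) + lognormal_lpdf(LW[n] | alpha + de[n] * bde1 + rei[n] * brei1, sigma1);
    real lp2 = log(theta[2]) + lognormal_lpdf(LW[n] | gamma + de[n] * bde2 + rei[n] * brei2, sigma2);
    prob_comp1[n] = exp(lp1 - log_sum_exp(lp1, lp2));
  }
}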
This model produced clustering proportions and estimates similar to the previous model (which contained only the response variable). Unfortunately, the prior did not resolve the label-switching problem, leading to poor chain behavior, terrible Rhat values, and low effective sample sizes.
mean sd 5.5% 94.5% n_eff Rhat4
bde1 -0.50078036 0.36323954 -0.89303802 -0.0910946 2.024952 9.140970
bde2 0.56871652 0.37493290 0.07015436 1.1625266 2.025801 8.910180
brei1 -0.04276832 0.08592037 -0.16398533 0.1223121 2.698703 1.941315
brei2 0.02580702 0.11141669 -0.08710628 0.2280988 2.162180 3.529945
alpha 0.56568359 0.40147285 -0.09748634 1.0737749 2.013091 13.228487
gamma -0.35791068 0.14097606 -0.61514816 -0.2268431 2.107936 4.248952
theta[1] 0.37051511 0.19125896 0.20693312 0.7235836 2.043524 6.719915
theta[2] 0.62948489 0.19125896 0.27641641 0.7930669 2.043524 6.719915
sigma1 0.29554891 0.25886354 0.12517962 0.7719165 2.016440 11.399809
I followed the guidance in this Mixture Models post and added a repulsive prior. This additional step did help chain behavior, but it significantly affected the estimates of the clustering proportions and the coefficients.
Full Stan code with the repulsive prior:
functions {
  // repulsive potential: 0 when x == y, approaching 1 as x and y separate
  // (squared_distance() is defined for vectors, so use square() for scalars)
  real potential(real x, real y) {
    return 1 - exp(-square(x - y));
  }
}
data {
  int<lower=1> K;  // number of mixture components (clusters)
  int<lower=1> N;  // number of observations
  vector[N] LW;    // measured leaf width
  vector[N] de;    // explanatory variable
  vector[N] rei;   // explanatory variable
}
parameters {
  real bde1;                // slope of de, component 1
  real bde2;                // slope of de, component 2
  real brei1;               // slope of rei, component 1
  real brei2;               // slope of rei, component 2
  real alpha;               // intercept, component 1
  real<upper=alpha> gamma;  // intercept, component 2 (constrained below alpha)
  simplex[K] theta;         // mixture proportions
  real<lower=0> sigma1;     // dispersion parameter, component 1
  real<lower=0> sigma2;     // dispersion parameter, component 2
  positive_ordered[K] p;    // locations for the repulsive potential
}
model {
  // priors
  target += normal_lpdf(p | 0, 10);
  target += normal_lpdf(bde1 | 0, 1);
  target += normal_lpdf(bde2 | 0, 1);
  target += normal_lpdf(brei1 | 0, 1);
  target += normal_lpdf(brei2 | 0, 1);
  //target += normal_lpdf(alpha | 0, 1);
  //target += normal_lpdf(gamma | 0, 1) -
  //          normal_lcdf(alpha | 0, 1);
  target += normal_lpdf(alpha | 0, 1);
  target += normal_lpdf(gamma | 0, 1);
  target += lognormal_lpdf(sigma1 | 0, 1) - normal_lccdf(0 | 0, 1);
  target += lognormal_lpdf(sigma2 | 0, 1) - normal_lccdf(0 | 0, 1);
  target += beta_lpdf(theta | 1, 1);  // uniform prior on the mixture proportions
  // repulsive prior: added once, since it does not depend on n
  target += log(potential(p[1], p[2]));
  // likelihood: marginalize over the two components for every observation
  for (n in 1:N) {
    vector[K] ps;
    ps[1] = log(theta[1]) + lognormal_lpdf(LW[n] | alpha + de[n] * bde1 + rei[n] * brei1, sigma1);
    ps[2] = log(theta[2]) + lognormal_lpdf(LW[n] | gamma + de[n] * bde2 + rei[n] * brei2, sigma2);
    target += log_sum_exp(ps);
  }
}
And here are the results of this version:
mean sd 5.5% 94.5% n_eff Rhat4
bde1 0.0896689435 0.9367895 -1.40200595 1.6043320 1884.749 0.9986699
bde2 0.0007346021 0.9309193 -1.49240790 1.4591302 2025.558 0.9993420
brei1 0.1362264287 0.9445447 -1.40160450 1.6530227 1519.362 1.0001007
brei2 -0.0668727475 0.9648211 -1.63400785 1.4601342 1937.525 0.9988794
alpha 0.4534608600 0.7597606 -0.68611967 1.7174675 1195.345 1.0008294
gamma -0.5860839531 0.7354700 -1.78427745 0.5600027 1990.181 0.9996545
theta[1] 0.4896182546 0.2815705 0.05829443 0.9371015 1660.950 0.9986631
theta[2] 0.5103817460 0.2815705 0.06289817 0.9417057 1660.950 0.9986631
sigma1 1.4061470344 1.8500527 0.18286187 4.2627391 1488.600 1.0031409
sigma2 1.3035895523 1.4137183 0.19231260 3.6827207 1328.768 1.0020604
p[1] 4.6711393356 3.7555106 0.36324149 11.7720050 2700.857 0.9982377
p[2] 11.9647803050 5.8779740 3.95653555 21.9628990 2526.463 0.9991870
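One thing I am unsure about: p never enters the likelihood, so the repulsion acts on a free-floating pair of locations rather than on the mixture itself. If the repulsion is instead meant to separate the component intercepts, I imagine it would be written on alpha and gamma directly, replacing the p-related lines in the model block above (sketch only):
// repulsion acting directly on the two component intercepts;
// log1m_exp(x) = log(1 - exp(x)), so this equals log(potential(alpha, gamma))
target += log1m_exp(-square(alpha - gamma));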
Thank you for your time and help! Any advice on improving my approach or resolving the non-identifiability issue would be greatly appreciated.
Best regards,
Chieh