Hello Stan community! I've been working on a classification problem using variational inference in Stan. In my setup, the $N \times D$ matrix $Z$ represents cluster memberships, with each row a one-hot encoding. Because Stan only supports continuous parameters (for ADVI as well as for NUTS), I've used a Gumbel-Softmax as a continuous relaxation of these discrete rows. The data were simulated from the same stochastic process that the log-likelihood describes.
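For reference, here is a minimal stdlib-Python sketch of the relaxation used in my transformed parameters (the helper names are hypothetical). One difference from my Stan model: the Gumbel noise is drawn explicitly here, whereas in Stan `G` is declared as a parameter with a `gumbel(0, 1)` prior.

```python
import math
import random


def sample_gumbel(size, rng):
    """Draw `size` standard Gumbel(0, 1) variates via the inverse CDF."""
    # Keep u strictly inside (0, 1) so both logarithms stay finite.
    return [-math.log(-math.log(rng.uniform(1e-12, 1.0 - 1e-12))) for _ in range(size)]


def gumbel_softmax(theta, g, tau):
    """Relaxed one-hot vector softmax((log(theta) + g) / tau).

    theta: class probabilities (a simplex), g: Gumbel noise, tau: temperature > 0.
    """
    logits = [(math.log(t) + gi) / tau for t, gi in zip(theta, g)]
    top = max(logits)  # subtract the max before exponentiating, for stability
    weights = [math.exp(x - top) for x in logits]
    total = sum(weights)
    return [w / total for w in weights]


if __name__ == "__main__":
    rng = random.Random(1)
    theta = [0.2, 0.5, 0.3]
    g = sample_gumbel(len(theta), rng)
    for tau in (5.0, 1.0, 0.1):
        print(tau, [round(z, 3) for z in gumbel_softmax(theta, g, tau)])
```

As `tau` shrinks, the output approaches the one-hot vector at the argmax of `log(theta) + g`; as `tau` grows, it flattens toward uniform.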
However, the ELBO (evidence lower bound) shows no significant improvement, even with `adapt_engaged = TRUE`. The parameters ($a_0$, $a_1$ and $\beta$) hover around 2 (roughly between 1.5 and 2.5). Is this lack of ascent caused by the dimensionality ($N = 100$, $D = 3$), or is the model simply too complex for variational inference? Any insights or suggestions would be greatly appreciated!
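Two quick numbers worked out from the model itself (a back-of-the-envelope sketch, not a diagnosis): the `gamma(3, 1)` priors (Stan's shape/rate parameterization) have mean 3 and mode 2, so estimates around 2 sit near the prior mode; and with $N = 100$, $D = 3$ the unconstrained parameter vector that ADVI works on has 4 + N(D - 1) + D = 207 entries, most of them coming from `theta`. The function names below are hypothetical.

```python
import math


def gamma_summary(shape, rate):
    """Mean, mode and standard deviation of gamma(shape, rate), as in Stan's gamma()."""
    mean = shape / rate
    mode = (shape - 1.0) / rate if shape >= 1.0 else 0.0
    sd = math.sqrt(shape) / rate
    return mean, mode, sd


def n_unconstrained(N, D):
    """Length of the unconstrained parameter vector for the model in this post."""
    scalars = 4              # a0, a1, beta, tbeta (each log-transformed)
    simplexes = N * (D - 1)  # theta: N simplexes with D - 1 free values each
    gumbel = D               # G: a single length-D vector
    return scalars + simplexes + gumbel


print(gamma_summary(3, 1))      # prior on a0, a1, beta and tbeta
print(n_unconstrained(100, 3))  # -> 207
```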
Here is my Stan code:
```stan
data {
  int<lower=1> N;
  int<lower=0> D;
  int<lower=0> tmax;
  int<lower=0> period;
  matrix[N,N] T_dis;
  matrix[2,N] d;
  vector[N] s;
  vector[N] t_inf;
  real tau;                       // Gumbel-Softmax temperature
}
parameters {
  real<lower=0> a0;
  real<lower=0> a1;
  real<lower=0> beta;
  real<lower=0> tbeta;
  simplex[D] theta[N];
  vector[D] G;
}
transformed parameters {
  // Gumbel-Softmax relaxation of each one-hot row of Z.
  // Note: the single length-D vector G is shared by all N rows.
  matrix[N, D] Z;
  for (n in 1:N) {
    Z[n] = softmax((log(theta[n]) + G) / tau)';
  }
  matrix[N,N] Z_dotprod;          // Z_dotprod[i,j] = <Z[i], Z[j]> (soft co-membership)
  matrix[2,N] d_Z_dotprod;        // co-membership-weighted sums of the columns of d
  vector[N] ones_vector = rep_vector(1, N);
  vector[N] one_Z_dotprod;        // co-membership-weighted counts
  for (i in 1:N) {
    for (j in 1:N) {
      Z_dotprod[i,j] = dot_product(Z[i], Z[j]);
    }
  }
  for (i in 1:N) {
    d_Z_dotprod[1,i] = dot_product(d[1], Z_dotprod'[i]);
    d_Z_dotprod[2,i] = dot_product(d[2], Z_dotprod'[i]);
    one_Z_dotprod[i] = dot_product(ones_vector, Z_dotprod'[i]);
  }
}
model {
  a0 ~ gamma(3, 1);
  a1 ~ gamma(3, 1);
  beta ~ gamma(3, 1);
  tbeta ~ gamma(3, 1);
  for (n in 1:N) {
    //G[n] ~ gumbel(0,1);
    // Adds only a constant on a simplex, i.e. the same as the implicit flat prior
    theta[n] ~ uniform(0, 1);
  }
  G ~ gumbel(0, 1);
  for (t in 1:tmax) {
    vector[N] p = rep_vector(0, N);
    for (n in 1:N) {
      if (t_inf[n] >= t || t_inf[n] == -1) {
        // Sum over every other j whose window [t_inf[j], t_inf[j] + period) contains t-1
        for (j in 1:N) {
          if (j != n && t_inf[j] <= t-1 && t-1 < t_inf[j] + period && t_inf[j] != -1) {
            p[n] += Z_dotprod[n,j] * pow(T_dis[n,j], -beta)
                    + (1 - Z_dotprod[n,j]) * pow(distance(d_Z_dotprod'[j] / one_Z_dotprod[j], d'[n]), -tbeta);
          }
        }
        p[n] = 1 - exp(-(a0 + a1 * s[n]) * p[n]);
        if (t_inf[n] == t) {
          if (p[n] != 0) {
            target += log(p[n]);      // bernoulli_lpmf(1 | p[n]);
          }
        }
        else if (t_inf[n] > t || t_inf[n] == -1) {
          if (p[n] != 0) {
            target += log(1 - p[n]);  // bernoulli_lpmf(0 | p[n]);
          }
        }
      }
    }
  }
}
```
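Since the data were simulated from the same process, one check is to evaluate the likelihood outside Stan at the true parameter values. Below is a hypothetical stdlib-Python mirror of the transformed parameters and model block above (the function and argument names are my own, and `d` is passed as N (x, y) pairs rather than as a 2-by-N matrix). It computes both terms in log space, using log(1 - p) = -(a0 + a1*s[n]) * (accumulated sum), which follows directly from p = 1 - exp(-(a0 + a1*s[n]) * sum).

```python
import math


def dist(p, q):
    """Euclidean distance between two 2-D points (like Stan's distance())."""
    return math.hypot(p[0] - q[0], p[1] - q[1])


def log_lik(a0, a1, beta, tbeta, Z, T_dis, d, s, t_inf, tmax, period):
    """Log-likelihood of the Stan model block at fixed parameter values.

    Z: N rows of (relaxed) one-hot memberships; T_dis: N x N nested lists;
    d: N (x, y) pairs (the columns of the Stan matrix d); s: N covariates;
    t_inf: N times, with -1 handled exactly as in the Stan code.
    Assumes no zero distances (a negative power of 0.0 raises ZeroDivisionError).
    """
    N = len(Z)
    # Z_dotprod[i][j] = <Z[i], Z[j]>
    C = [[sum(a * b for a, b in zip(Z[i], Z[j])) for j in range(N)] for i in range(N)]
    # d_Z_dotprod[, j] / one_Z_dotprod[j]: co-membership-weighted mean location for j
    centre = []
    for j in range(N):
        w = sum(C[k][j] for k in range(N))
        centre.append((sum(d[k][0] * C[k][j] for k in range(N)) / w,
                       sum(d[k][1] * C[k][j] for k in range(N)) / w))

    ll = 0.0
    for t in range(1, tmax + 1):
        for n in range(N):
            if not (t_inf[n] >= t or t_inf[n] == -1):
                continue  # same guard as the outer if() in the Stan loop
            total = 0.0
            for j in range(N):
                active = t_inf[j] != -1 and t_inf[j] <= t - 1 < t_inf[j] + period
                if j != n and active:
                    total += (C[n][j] * T_dis[n][j] ** (-beta)
                              + (1 - C[n][j]) * dist(centre[j], d[n]) ** (-tbeta))
            if total == 0.0:
                continue  # p[n] == 0 here, and the Stan code skips the increment
            rate = (a0 + a1 * s[n]) * total
            if t_inf[n] == t:
                ll += math.log(-math.expm1(-rate))  # log(p), with p = 1 - exp(-rate)
            else:
                ll -= rate                          # log(1 - p) = -rate exactly
    return ll


Z = [[1.0, 0.0], [1.0, 0.0]]  # toy example: two individuals in the same cluster
print(log_lik(1.0, 0.0, 1.0, 1.0, Z, [[0.0, 2.0], [2.0, 0.0]],
              [(0.0, 0.0), (3.0, 4.0)], [0.0, 0.0], [1, -1], 2, 5))  # -> -0.5
```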
Here is the ADVI output, including the ELBO trace, for 4000 iterations:
```
Chain 1: ------------------------------------------------------------
Chain 1: EXPERIMENTAL ALGORITHM:
Chain 1:   This procedure has not been thoroughly tested and may be unstable
Chain 1:   or buggy. The interface is subject to change.
Chain 1: ------------------------------------------------------------
Chain 1:
Chain 1:
Chain 1:
Chain 1: Gradient evaluation took 0.018342 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 183.42 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1:
Chain 1:
Chain 1: Begin eta adaptation.
Chain 1: Iteration:   1 / 250 [  0%]  (Adaptation)
Chain 1: Iteration:  50 / 250 [ 20%]  (Adaptation)
Chain 1: Iteration: 100 / 250 [ 40%]  (Adaptation)
Chain 1: Iteration: 150 / 250 [ 60%]  (Adaptation)
Chain 1: Iteration: 200 / 250 [ 80%]  (Adaptation)
Chain 1: Success! Found best value [eta = 1] earlier than expected.
Chain 1:
Chain 1: Begin stochastic gradient ascent.
Chain 1:   iter       ELBO   delta_ELBO_mean   delta_ELBO_med   notes
Chain 1:    100   -199.889             1.000            1.000
Chain 1:    200   -201.878             0.505            1.000
Chain 1:    300   -200.462             0.339            0.010
Chain 1:    400   -201.057             0.255            0.010
Chain 1:    500   -198.243             0.009            0.010
Chain 1:    600   -196.104             0.009            0.011
Chain 1:    700   -199.136             0.011            0.014
Chain 1:    800   -196.473             0.013            0.014
Chain 1:    900   -199.098             0.013            0.014
Chain 1:   1000   -198.023             0.012            0.014
Chain 1:   1100   -196.379             0.010            0.013
Chain 1:   1200   -199.133             0.010            0.013
Chain 1:   1300   -195.545             0.011            0.014
Chain 1:   1400   -198.170             0.013            0.014
Chain 1:   1500   -197.403             0.012            0.014
Chain 1:   1600   -200.285             0.012            0.014
Chain 1:   1700   -202.428             0.011            0.013
Chain 1:   1800   -198.055             0.013            0.014
Chain 1:   1900   -199.806             0.014            0.014
Chain 1:   2000   -199.264             0.011            0.011
Chain 1:   2100   -199.527             0.009            0.009
Chain 1:   2200   -200.470             0.004            0.005
Chain 1:   2300   -198.774             0.004            0.005
Chain 1:   2400   -198.749             0.004            0.005
Chain 1:   2500   -198.151             0.004            0.005
Chain 1:   2600   -199.765             0.005            0.008
Chain 1:   2700   -198.928             0.004            0.004
Chain 1:   2800   -198.186             0.005            0.004
Chain 1:   2900   -195.917             0.007            0.008
Chain 1:   3000   -198.191             0.008            0.011
Chain 1:   3100   -198.158             0.007            0.011
Chain 1:   3200   -198.538             0.006            0.011
Chain 1:   3300   -197.713             0.004            0.004
Chain 1:   3400   -198.019             0.002            0.002
Chain 1:   3500   -200.362             0.005            0.004
Chain 1:   3600   -198.486             0.007            0.009
Chain 1:   3700   -196.311             0.008            0.011
Chain 1:   3800   -198.683             0.011            0.012
Chain 1:   3900   -197.990             0.009            0.011
Chain 1:   4000   -196.480             0.009            0.011
Chain 1: Informational Message: The maximum number of iterations is reached! The algorithm may not have converged.
Chain 1: This variational approximation is not guaranteed to be meaningful.
Chain 1:
Chain 1: Drawing a sample of size 1000 from the approximate posterior…
Chain 1: COMPLETED.
Pareto k diagnostic value is 2.18. Resampling is disabled. Decreasing tol_rel_obj may help if variational algorithm has terminated prematurely. Otherwise consider using sampling instead.
```
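To read the `delta_ELBO_mean` and `delta_ELBO_med` columns above, here is a small Python sketch of how I understand Stan's ADVI computes them. This is based on my reading of Stan's source, so treat it as an assumption. At every ELBO evaluation (every `eval_elbo` = 100 iterations here) it records the relative change |(ELBO_prev - ELBO) / ELBO|, and the two columns are the mean and the upper median of the last max(0.1 * iter / eval_elbo, 2) changes; as far as I can tell, Stan stops early once either one falls below `tol_rel_obj`. Fed the first six ELBO values from the log, the sketch reproduces the first six rows of both columns. The function names are hypothetical.

```python
def window_size(max_iterations, eval_elbo=100):
    """Length of the rolling window used for the ELBO-change statistics."""
    return int(max(0.1 * max_iterations / eval_elbo, 2.0))


def elbo_deltas(elbos, window):
    """Return (delta_ELBO_mean, delta_ELBO_med) after each ELBO evaluation."""
    recent, rows, prev = [], [], 0.0  # prev = 0 makes the first change exactly 1
    for elbo in elbos:
        recent.append(abs((prev - elbo) / elbo))
        recent = recent[-window:]          # keep only the last `window` changes
        ordered = sorted(recent)
        rows.append((sum(recent) / len(recent), ordered[len(ordered) // 2]))
        prev = elbo
    return rows


elbos = [-199.889, -201.878, -200.462, -201.057, -198.243, -196.104]  # iters 100-600
for it, (mean, med) in zip(range(100, 700, 100), elbo_deltas(elbos, window_size(4000))):
    print(f"{it:5d} {mean:.3f} {med:.3f}")
```

Taking the upper median (index `len // 2` of the sorted window) is why the 200-iteration row still reports a median of 1.000.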