Variational Inference: ELBO is not ascending

Hello Stan community! I’ve been working on a classification problem using variational inference in Stan. In my setup, the matrix Z of dimensions N \times D represents cluster memberships, with each row a one-hot encoding. Since Stan’s ADVI requires continuous parameters, I’ve implemented a Gumbel-Softmax relaxation of these discrete variables. The data were generated using the same stochastic process outlined in the log-likelihood.
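
Concretely, each relaxed membership row is computed as Z_n = \mathrm{softmax}\big((\log \theta_n + G) / \tau\big), where \theta_n is a simplex over the D clusters, G \sim \mathrm{Gumbel}(0, 1), and \tau is a fixed temperature (see the transformed parameters block in the code below).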

However, I’ve observed that the ELBO (evidence lower bound) isn’t showing significant improvement, even with adapt_engaged = TRUE. The parameters (a0, a1, and beta) seem to hover around 2 (in the range 1.5-2.5). I’m curious whether this lack of ascent is due to the dimensionality (N = 100, D = 3) or whether the model’s complexity might be challenging for variational inference. Any insights or suggestions would be greatly appreciated!

Here is my Stan code:

data {
  int<lower=1> N;
  int<lower=0> D;
  int<lower=0> tmax;
  int<lower=0> period;
  matrix[N,N] T_dis;
  matrix[2,N] d;
  vector[N] s;
  vector[N] t_inf;
  real tau;
}

parameters {
  real<lower=0> a0;
  real<lower=0> a1;
  real<lower=0> beta;
  real<lower=0> tbeta;
  simplex[D] theta[N];
  vector[D] G;
}

transformed parameters{
  matrix[N, D] Z;
  for(n in 1:N){
    Z[n] = softmax((log(theta[n]) + G)/tau)';
  }
  matrix[N,N] Z_dotprod;
  matrix[2,N] d_Z_dotprod;
  vector[N] ones_vector = rep_vector(1, N);
  vector[N] one_Z_dotprod;
  for(i in 1:N){
      for(j in 1:N){
          Z_dotprod[i,j] = dot_product(Z[i], Z[j]);
      }
  }
  for(i in 1:N){
    d_Z_dotprod[1,i] = dot_product(d[1], Z_dotprod'[i]);
    d_Z_dotprod[2,i] = dot_product(d[2], Z_dotprod'[i]);
    one_Z_dotprod[i] = dot_product(ones_vector, Z_dotprod'[i]);
  }
}

model {
  a0 ~ gamma(3,1);
  a1 ~ gamma(3,1);
  beta ~ gamma(3,1);
  tbeta ~ gamma(3,1);
  for(n in 1:N){
    //G[n] ~ gumbel(0,1);
    theta[n] ~ uniform(0,1);
  }
  G ~ gumbel(0,1);
  for(t in 1:tmax){
    vector[N] p = rep_vector(0, N);
    for(n in 1:N){
      if(t_inf[n]>=t || t_inf[n]==-1){
        for(j in 1:N){
          if(j!=n && t_inf[j]<=t-1 && t-1<t_inf[j]+period && t_inf[j]!=-1){
            p[n] += Z_dotprod[n,j]*pow(T_dis[n,j], -beta) + (1-Z_dotprod[n,j])*pow(distance(d_Z_dotprod'[j]/one_Z_dotprod[j],d'[n]),-tbeta);
          }
        }
        p[n] = 1-exp(-(a0+a1*s[n])*p[n]);
        if(t_inf[n]==t){
          if(p[n]!=0){
            target += log(p[n]);//bernoulli_lpmf(1|p[n]);
          }
        }
        else if(t_inf[n]>t || t_inf[n]==-1){
          if(p[n]!=0){
            target += log(1-p[n]);//bernoulli_lpmf(0|p[n]);
          }
        }
      }
    }
  }
}

Here are the ELBOs for 4000 iterations:

Chain 1: ------------------------------------------------------------
Chain 1: EXPERIMENTAL ALGORITHM:
Chain 1: This procedure has not been thoroughly tested and may be unstable
Chain 1: or buggy. The interface is subject to change.
Chain 1: ------------------------------------------------------------
Chain 1:
Chain 1:
Chain 1:
Chain 1: Gradient evaluation took 0.018342 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 183.42 seconds.
Chain 1: Adjust your expectations accordingly!
Chain 1:
Chain 1:
Chain 1: Begin eta adaptation.
Chain 1: Iteration: 1 / 250 [ 0%] (Adaptation)
Chain 1: Iteration: 50 / 250 [ 20%] (Adaptation)
Chain 1: Iteration: 100 / 250 [ 40%] (Adaptation)
Chain 1: Iteration: 150 / 250 [ 60%] (Adaptation)
Chain 1: Iteration: 200 / 250 [ 80%] (Adaptation)
Chain 1: Success! Found best value [eta = 1] earlier than expected.
Chain 1:
Chain 1: Begin stochastic gradient ascent.
Chain 1: iter ELBO delta_ELBO_mean delta_ELBO_med notes
Chain 1: 100 -199.889 1.000 1.000
Chain 1: 200 -201.878 0.505 1.000
Chain 1: 300 -200.462 0.339 0.010
Chain 1: 400 -201.057 0.255 0.010
Chain 1: 500 -198.243 0.009 0.010
Chain 1: 600 -196.104 0.009 0.011
Chain 1: 700 -199.136 0.011 0.014
Chain 1: 800 -196.473 0.013 0.014
Chain 1: 900 -199.098 0.013 0.014
Chain 1: 1000 -198.023 0.012 0.014
Chain 1: 1100 -196.379 0.010 0.013
Chain 1: 1200 -199.133 0.010 0.013
Chain 1: 1300 -195.545 0.011 0.014
Chain 1: 1400 -198.170 0.013 0.014
Chain 1: 1500 -197.403 0.012 0.014
Chain 1: 1600 -200.285 0.012 0.014
Chain 1: 1700 -202.428 0.011 0.013
Chain 1: 1800 -198.055 0.013 0.014
Chain 1: 1900 -199.806 0.014 0.014
Chain 1: 2000 -199.264 0.011 0.011
Chain 1: 2100 -199.527 0.009 0.009
Chain 1: 2200 -200.470 0.004 0.005
Chain 1: 2300 -198.774 0.004 0.005
Chain 1: 2400 -198.749 0.004 0.005
Chain 1: 2500 -198.151 0.004 0.005
Chain 1: 2600 -199.765 0.005 0.008
Chain 1: 2700 -198.928 0.004 0.004
Chain 1: 2800 -198.186 0.005 0.004
Chain 1: 2900 -195.917 0.007 0.008
Chain 1: 3000 -198.191 0.008 0.011
Chain 1: 3100 -198.158 0.007 0.011
Chain 1: 3200 -198.538 0.006 0.011
Chain 1: 3300 -197.713 0.004 0.004
Chain 1: 3400 -198.019 0.002 0.002
Chain 1: 3500 -200.362 0.005 0.004
Chain 1: 3600 -198.486 0.007 0.009
Chain 1: 3700 -196.311 0.008 0.011
Chain 1: 3800 -198.683 0.011 0.012
Chain 1: 3900 -197.990 0.009 0.011
Chain 1: 4000 -196.480 0.009 0.011
Chain 1: Informational Message: The maximum number of iterations is reached! The algorithm may not have converged.
Chain 1: This variational approximation is not guaranteed to be meaningful.
Chain 1:
Chain 1: Drawing a sample of size 1000 from the approximate posterior…
Chain 1: COMPLETED.
Pareto k diagnostic value is 2.18. Resampling is disabled. Decreasing tol_rel_obj may help if variational algorithm has terminated prematurely. Otherwise consider using sampling instead.

Have you tried sampling this with MCMC to make sure the model is correct? The amount of fiddly indexing and conditional testing in this code is worrisome, as it’s very difficult to get all the indexing right.

Also, I’m afraid Stan’s ADVI algorithm isn’t very robust to bad step sizes. Rather than relying on step-size adaptation, you might want to try a grid of fixed eta values to find one that works, once you can verify your model is correct.

The code you have in the transformed parameters block is doing a lot of work that you should push into standard matrix operations rather than element-wise dot products. For example, isn’t Z_dotprod just Z * Z'? It’s far more efficient to compute that with a single matrix multiply than with element-wise dot products, and the same goes for the other loops. This won’t change the ADVI algorithm, but it’ll make each iteration much faster.
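
A rough, untested sketch of what I mean (keeping your variable names, and relying on the fact that Z * Z' is symmetric):

transformed parameters {
  matrix[N, D] Z;
  for (n in 1:N) {
    Z[n] = softmax((log(theta[n]) + G) / tau)';
  }
  // one matrix multiply replaces the pairwise dot_product loop
  matrix[N, N] Z_dotprod = Z * Z';
  // column i matches dot_product(d[k], Z_dotprod'[i]) for k = 1, 2
  matrix[2, N] d_Z_dotprod = d * Z_dotprod;
  // column sums of Z_dotprod; ones_vector as you define it (or in transformed data, see below)
  vector[N] one_Z_dotprod = Z_dotprod' * ones_vector;
}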

You also want to move ones_vector to the transformed data block so it’s not recreated and autodiffed through on every gradient evaluation.
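
That’s just a one-liner (shown with your variable name):

transformed data {
  vector[N] ones_vector = rep_vector(1, N);
}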

These convoluted conditionals usually lead to non-differentiable objectives and can be difficult to reason about. Also, I’d strongly suggest using standard formatting here: spaces around operators, after commas, etc.

These kinds of operations are highly unstable due to underflow/rounding:

p[n] = 1-exp(-(a0 + a1 * s[n]) * p[n]);
...
 ... log(p[n]) ...
 ... log(1 - p[n])...

You need to combine those operations for stability: use log1m_exp(-(a0 + a1 * s[n]) * p[n]) in place of log(p[n]), and -(a0 + a1 * s[n]) * p[n] in place of log(1 - p[n]). The log1m_exp built-in exists precisely for the arithmetic stability of this otherwise challenging operation.

Your problems may be due to instability here.
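
In context, the two branches would look something like this (just a sketch; here p[n] means the accumulated kernel sum, i.e., its value before your 1 - exp(...) line):

real lambda = (a0 + a1 * s[n]) * p[n];
if (t_inf[n] == t) {
  if (lambda != 0) {
    target += log1m_exp(-lambda);  // stable form of log(1 - exp(-lambda)), i.e., log(p[n])
  }
} else if (t_inf[n] > t || t_inf[n] == -1) {
  target += -lambda;               // log(exp(-lambda)), i.e., log(1 - p[n]), with no rounding at all
}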

One other thing you can do for efficiency is to precompute all of those tests on t_inf; since t_inf is data, the results of the conditionals are known ahead of time. In general, rather than

for (n in 1:N) {
  if (foo(x[n])) {
    x[n] ~ bar(theta1);
  } else {
    x[n] ~ baz(theta2);
  }
}

then, if x[n] is data, you can instead build two arrays in the transformed data block:

transformed data {
  vector[num_foo(x)] foo_x;
  vector[num_not_foo(x)] non_foo_x;
  ... write functions to count sizes, num_foo, num_not_foo, and then populate here ...
}
model {
   foo_x ~ bar(theta1);
   non_foo_x ~ baz(theta2);
}

This then allows vectorization of bar and baz, but more importantly removes runtime choice points, which are very costly on modern CPUs relative to arithmetic operations.
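
A concrete, hypothetical version of that pattern (with x[n] > 0 standing in for foo(x[n]) and normal standing in for bar and baz):

data {
  int<lower=1> N;
  vector[N] x;
}
transformed data {
  // count observations satisfying the condition (computed once, as data)
  int N_pos = 0;
  for (n in 1:N) {
    if (x[n] > 0) N_pos += 1;
  }
  // split the data once, before any gradient evaluations
  vector[N_pos] x_pos;
  vector[N - N_pos] x_nonpos;
  {
    int i = 1;
    int j = 1;
    for (n in 1:N) {
      if (x[n] > 0) {
        x_pos[i] = x[n];
        i += 1;
      } else {
        x_nonpos[j] = x[n];
        j += 1;
      }
    }
  }
}
parameters {
  real theta1;
  real theta2;
}
model {
  // vectorized statements, no branching in the model block
  x_pos ~ normal(theta1, 1);
  x_nonpos ~ normal(theta2, 1);
}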