Parameters way off with variational inference, not sure why

I am trying to fit a model to some simulated neural data. In the simulated data, I have a set of neurons indexed by i={1, ..., N} and two conditions, a baseline condition Stimulation=0 and a stimulated condition Stimulation=1. I measure the activity of each neuron M times in each condition.

I am fitting a hierarchical model with random intercepts and random slopes across neurons. Because it is a lot of data (N=400, M=500), I am trying out variational inference.

However, the parameters estimated by the model with VI are way off. For the example below, I downsampled the data to N=90 and M=200 so as to run NUTS for comparison.

For two of the model parameters, effectPop and interceptPop, which are the mean slope and the mean intercept across neurons respectively, I list below the estimates computed from raw data (i.e. compute the mean difference between conditions for each neuron and average across neurons), the estimate from NUTS and the estimate from VI with 68% credible intervals.
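For reference, the "raw data" estimates described above can be computed along these lines. This is only a minimal sketch: it assumes the data are held as nested lists `response[i][j]` and `stimulation[i][j]`, laid out like the Stan data block further down. The function name `raw_estimates` and the toy numbers are made up for illustration and are not the simulated data from this post.

```python
from statistics import mean

def raw_estimates(response, stimulation):
    """Return (interceptPop, effectPop) estimated directly from the data.

    response[i][j]    : j-th measurement of neuron i
    stimulation[i][j] : 0 (baseline) or 1 (stimulated) for that measurement
    """
    baselines, effects = [], []
    for resp_i, stim_i in zip(response, stimulation):
        # Per-neuron mean activity in each condition
        base = mean(r for r, s in zip(resp_i, stim_i) if s == 0)
        stim = mean(r for r, s in zip(resp_i, stim_i) if s == 1)
        baselines.append(base)
        effects.append(stim - base)
    # Average the per-neuron quantities across neurons
    return mean(baselines), mean(effects)

# Toy data: two neurons, two measurements per condition (made-up numbers).
response = [[4.0, 4.2, 4.3, 4.5], [3.9, 4.1, 4.2, 4.4]]
stimulation = [[0, 0, 1, 1], [0, 0, 1, 1]]
intercept_hat, effect_hat = raw_estimates(response, stimulation)
```

These per-neuron averages ignore the hierarchical shrinkage, so they are only a sanity check on the population-level parameters, not a replacement for the model.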

               | Raw data |  NUTS              |  VI
effectPop      |  0.25    |  0.25 [0.24-0.27]  |  0.80 [0.71-0.90]
interceptPop   |  4.2     |  4.20 [4.22-4.24]  |  1.85 [0.96-3.70]

Furthermore, the VI results vary considerably from fit to fit, although there is some consistency in the ways they are off.

The posteriors of these parameters obtained with NUTS don’t look very complicated.

So, what are some possible reasons why VI would be performing SO poorly? Is VI just a bad fit for this problem? The data and problem look fairly simple, so I didn’t expect this. Are there some heuristics to help VI perform better? And if VI just can’t work, what are some alternatives for fitting a complex statistical model (see (1) below) to this type of dataset, which seems too big for NUTS sampling?

This is the model I’m fitting. It looks a bit quirky because it’s a simplification of the more complex model I’m trying to fit (see (1)):

data {
  int<lower=1> nNeurons;  // Number of neurons (groups)
  int<lower=1> nObsPerNeuron; // Number of observations per neuron (both conditions)
  array[nNeurons, nObsPerNeuron] real Response; // Measured activity for each neuron and observation
  array[nNeurons, nObsPerNeuron] int<lower=0, upper=1> Stimulation; // Stimulation indicator
}

parameters {
  // Baseline firing rate parameters
  real<lower=0> interceptPop;
  real<lower=0> interceptSigma;
  vector<lower=0>[nNeurons] interceptNrn;
  // Stimulation effect parameters
  real<lower=0> effectPop;
  real<lower=0> effectSigma;
  vector<lower=0>[nNeurons] effectNrn;
  // Residual noise parameter
  real<lower=0> residualSigma;
}

model {
  // Priors
  interceptPop ~ normal(4, 2);
  interceptSigma ~ normal(0, 1);
  interceptNrn ~ normal(interceptPop, interceptSigma);
  effectPop ~ normal(0.3, 0.4);
  effectSigma ~ normal(0, 0.1);
  effectNrn ~ normal(effectPop, effectSigma);
  residualSigma ~ normal(0, 1);
  // Likelihood
  // Loop over neurons
  for (i in 1:nNeurons) {
    // Log-likelihood of this neuron's responses
    real log_prob_1 = normal_lpdf(to_vector(Response[i,:]) | interceptNrn[i] +
      effectNrn[i] * to_vector(Stimulation[i,:]), residualSigma);
    target += log_prob_1;
  }
}

I’m fitting the model with cmdstanpy in Python.

(1) The model I’m interested in is complicated, with random effects that come from a mixture distribution (which I asked about here). However, to figure out the problems I’m having right now, I’m using a simpler model without the mixture distribution.


Unfortunately, the VI implementation in Stan is known to be quite fragile. In some cases, you can even get consistently better results via taking a Laplace (normal) approximation at the MAP estimate from the optimize method - see SBC for ADVI and optimizing in Stan (+HMMs) • SBC for some brief exploration of this. However, the Laplace approximation requires getting the Hessian, which I am not sure is supported by cmdstanpy.
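On the cmdstanpy side: as far as I can tell, recent releases (cmdstanpy 1.2 or later, together with CmdStan 2.32 or later) expose a `laplace_sample` method, so the mode-plus-Laplace route might look roughly like the sketch below. Treat the version requirements as an assumption to check against your installation. The helper name `laplace_at_mode` and the arguments `stan_file` and `data` are placeholders, not anything from the original post.

```python
def laplace_at_mode(stan_file, data, draws=1000, seed=1):
    """Sketch: optimize, then draw from a normal approximation at the mode.

    stan_file and data are placeholders for your own model file and data dict.
    """
    # Imported inside the function so that defining it does not require CmdStan.
    from cmdstanpy import CmdStanModel

    model = CmdStanModel(stan_file=stan_file)
    # jacobian=True includes the change-of-variables adjustment, so the mode
    # is found for the density on the unconstrained scale.
    mode = model.optimize(data=data, jacobian=True, seed=seed)
    # Normal approximation centred at that mode, using the Hessian there.
    laplace = model.laplace_sample(
        data=data, mode=mode, draws=draws, jacobian=True, seed=seed
    )
    # Approximate posterior draws for the two population-level parameters.
    return laplace.stan_variable("effectPop"), laplace.stan_variable("interceptPop")
```

The `jacobian` setting should be the same in both calls, so that the approximation is centred on the mode it was computed for.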

Note also that by default, ADVI will assume the covariance matrix is diagonal. You can definitely try using algorithm = 'fullrank' to let ADVI estimate a full covariance matrix, which may help a bit.
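In cmdstanpy, switching to the full-rank approximation could look something like the following. It is a hedged sketch rather than a recommended recipe: the helper name `fit_fullrank_advi` and the arguments `stan_file` and `data` are placeholders for your own setup.

```python
def fit_fullrank_advi(stan_file, data, seed=1):
    """Sketch: fit a Stan model with full-rank ADVI.

    stan_file and data are placeholders for your own model file and data dict.
    """
    # Imported inside the function so that defining it does not require CmdStan.
    from cmdstanpy import CmdStanModel

    model = CmdStanModel(stan_file=stan_file)
    fit = model.variational(
        data=data,
        algorithm="fullrank",  # the default is "meanfield" (diagonal covariance)
        seed=seed,             # fixing the seed makes runs easier to compare
    )
    # Means of the variational approximation, keyed by output column name.
    return fit.variational_params_dict
```

Running this with a few different seeds and comparing `effectPop` and `interceptPop` across runs gives a quick read on whether the fit-to-fit variability shrinks.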

The algorithm assumes the posterior is jointly multivariate normal on the unconstrained scale with either a diagonal (meanfield) or fully flexible (fullrank) covariance matrix. There are many “non-complicated” shapes that don’t look like that. This is especially true for parameters that represent variances (e.g. residualSigma, effectSigma) that tend to have a skewed posterior on the unconstrained scale, even with large datasets. This should be quite easy to check - log-transform all your lower bounded parameters to get the unconstrained scale and then inspect the pairs…
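A rough numerical version of that check is to log-transform the NUTS draws of each lower-bounded parameter and look at how far from symmetric they are. The sketch below uses only the standard library. The helper names and the stand-in draws are made up for illustration; with a real fit you would pass in the draws from your NUTS run instead.

```python
import math
import random

def skewness(values):
    """Sample skewness (third standardized moment); 0 for a symmetric sample."""
    n = len(values)
    mu = sum(values) / n
    m2 = sum((v - mu) ** 2 for v in values) / n
    m3 = sum((v - mu) ** 3 for v in values) / n
    return m3 / m2 ** 1.5

def unconstrained_skewness(draws):
    """Skewness of log(draws): the unconstrained scale for a <lower=0> parameter."""
    return skewness([math.log(d) for d in draws])

# Stand-in draws (NOT from the model in this thread): a half-normal sample
# playing the role of posterior draws for a scale parameter.
rng = random.Random(1)
fake_sigma_draws = [abs(rng.gauss(0.0, 0.1)) for _ in range(4000)]
print(unconstrained_skewness(fake_sigma_draws))
```

With cmdstanpy, the draws for a parameter such as `effectSigma` can be pulled from the NUTS fit with `fit.stan_variable("effectSigma")`. A skewness far from zero on the log scale suggests that a normal approximation will struggle for that parameter.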

If you need speed, then your model looks like something that would be highly tractable with INLA / inlabru, however I think the package/method only has an R interface. There are plans to bring integrated Laplace approximation into Stan, which should give huge speedups but AFAIK, this is still experimental…

Hope that helps at least a little.


That is very useful, thanks! What you say about the transformed parameters makes sense, I hadn’t thought about that. I may try the R suggestions and see if they give fits that are good enough and fast. If I do, I’ll post how it goes in this thread.