I am trying to fit a model to simulated neural data. In the simulated data, I have a set of neurons indexed by i = 1, ..., N and two conditions: a baseline condition (Stimulation = 0) and a stimulated condition (Stimulation = 1). I measure the activity of each neuron M times in each condition.
I am fitting a hierarchical model with random intercepts and random slopes across neurons. Because this is a lot of data (N=400, M=500), I am trying out variational inference (VI).
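For context, the data are generated along these lines (a minimal sketch with illustrative parameter values, not my exact simulation script):

```python
import numpy as np

rng = np.random.default_rng(0)
nNeurons, M = 400, 500                   # M observations per condition
interceptPop, interceptSigma = 4.2, 1.0  # illustrative values
effectPop, effectSigma = 0.25, 0.1
residualSigma = 1.0

# Per-neuron random intercepts and slopes
interceptNrn = rng.normal(interceptPop, interceptSigma, nNeurons)
effectNrn = rng.normal(effectPop, effectSigma, nNeurons)

# Each neuron: M baseline observations (stim=0), then M stimulated ones (stim=1)
stim = np.tile(np.repeat([0, 1], M), (nNeurons, 1))
mu = interceptNrn[:, None] + effectNrn[:, None] * stim
response = rng.normal(mu, residualSigma)  # shape (nNeurons, 2 * M)
```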
However, the parameter estimates from VI are way off. For the example below, I downsampled the data to N=90 and M=200 so that I could also run NUTS for comparison.
For two of the model parameters, `effectPop` and `interceptPop`, which are the mean slope and the mean intercept across neurons respectively, I list below the estimate computed from the raw data (i.e., compute the mean difference between conditions for each neuron and average across neurons; see the snippet after the table), the estimate from NUTS, and the estimate from VI, the latter two with 68% credible intervals.
| Parameter | Raw data | NUTS | VI |
|---|---|---|---|
| `effectPop` | 0.25 | 0.25 [0.24, 0.27] | 0.80 [0.71, 0.90] |
| `interceptPop` | 4.2 | 4.20 [4.22, 4.24] | 1.85 [0.96, 3.70] |
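For reference, the raw-data estimates in the table are computed roughly like this (array names are placeholders matching the simulation sketch above):

```python
import numpy as np

def raw_estimates(response, stim):
    """Per-neuron condition means, then averaged across neurons."""
    n = response.shape[0]
    base = np.array([response[i][stim[i] == 0].mean() for i in range(n)])
    stimmed = np.array([response[i][stim[i] == 1].mean() for i in range(n)])
    interceptPop_raw = base.mean()           # mean intercept across neurons
    effectPop_raw = (stimmed - base).mean()  # mean condition difference across neurons
    return interceptPop_raw, effectPop_raw
```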
Furthermore, the VI results vary considerably from fit to fit, although there is some consistency in how they are off.
The posteriors of these parameters obtained with NUTS don't look very complicated.
So, what are some possible reasons why VI performs SO poorly here? Is VI simply a bad fit for this problem? The data and model look fairly simple, so I didn't expect VI to do this badly. Are there heuristics to help VI perform better? And if VI just can't work, what are some alternatives for fitting a complex statistical model (see (1) below) to a dataset that seems too big for NUTS sampling?
This is the model I'm fitting. It looks a bit quirky because it's a simplification of the more complex model I'm ultimately trying to fit (see (1) below):
```stan
data {
  int<lower=1> nNeurons;                        // Number of neurons (groups)
  int<lower=1> nObsPerNeuron;                   // Number of observations per neuron
  array[nNeurons, nObsPerNeuron] real Response; // Measured responses
  array[nNeurons, nObsPerNeuron] int<lower=0, upper=1> Stimulation; // Stimulation indicator
}
parameters {
  // Baseline firing rate parameters
  real<lower=0> interceptPop;
  real<lower=0> interceptSigma;
  vector<lower=0>[nNeurons] interceptNrn;
  // Stimulation effect parameters
  real<lower=0> effectPop;
  real<lower=0> effectSigma;
  vector<lower=0>[nNeurons] effectNrn;
  // Residual noise parameter
  real<lower=0> residualSigma;
}
model {
  // Priors
  interceptPop ~ normal(4, 2);
  interceptSigma ~ normal(0, 1);
  interceptNrn ~ normal(interceptPop, interceptSigma);
  effectPop ~ normal(0.3, 0.4);
  effectSigma ~ normal(0, 0.1);
  effectNrn ~ normal(effectPop, effectSigma);
  residualSigma ~ normal(0, 1);
  // Likelihood: vectorized over observations within each neuron
  for (i in 1:nNeurons) {
    target += normal_lpdf(to_vector(Response[i, :]) | interceptNrn[i]
                          + effectNrn[i] * to_vector(Stimulation[i, :]),
                          residualSigma);
  }
}
```
I’m fitting the model with cmdstanpy in Python.
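Roughly how I invoke the two fits (the file path and data dictionary are placeholders; in practice I pass the downsampled arrays with N=90 and M=200):

```python
from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="hierarchical_model.stan")  # placeholder path
data = {
    "nNeurons": response.shape[0],
    "nObsPerNeuron": response.shape[1],
    "Response": response.tolist(),
    "Stimulation": stim.tolist(),
}

# NUTS on the downsampled data, for comparison
fit_nuts = model.sample(data=data, chains=4)

# ADVI (meanfield by default; "fullrank" is the other option)
fit_vi = model.variational(data=data, algorithm="meanfield", seed=1)
```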
(1) The model I'm actually interested in is more complicated, with random effects that come from a mixture distribution (which I asked about here). But to diagnose the problems I'm having right now, I'm using this simpler model without the mixture distribution.