Zero chain movement/mixing in forced ODE model

I am having trouble with the following ODE model, where I basically get no mixing between the Markov chains. I cannot seem to get to the bottom of why this is the case – the likelihood surface looks nicely curved towards the area of high density. I have tried using highly informative priors, and restricting the bounds of the parameters considerably.

I have also tried fixing all but one parameter, and still I have the issue that there is no exploration of the Markov chains. However when I look at a plot of the log-likelihood as I vary this parameter, there is considerable curvature towards the maximum.

Finally I tried starting the chains at the maximum likelihood estimates (to see whether the poor exploration of the chains is because of long tails) and still the chains hardly move at all.

Does anyone have any other ideas to diagnose what is wrong with my model?

functions{
  int find_interval_elem(real x, vector sorted, int start_ind){
    int res;
    int N;
    int max_iter;
    real left;
    real right;
    int left_ind;
    int right_ind;
    int iter;

    N = num_elements(sorted);

    if(N == 0) return(0);

    left_ind  = start_ind;
    right_ind = N;

    max_iter = 100 * N;
    left  = sorted[left_ind ] - x;
    right = sorted[right_ind] - x;

    if(0 <= left)  return(left_ind-1);
    if(0 == right) return(N-1);
    if(0 >  right) return(N);

    iter = 1;
    while((right_ind - left_ind) > 1  && iter != max_iter) {
      int mid_ind;
      real mid;
      // is there a controlled way without being yelled at with a
      // warning?
      mid_ind = (left_ind + right_ind) / 2;
      mid = sorted[mid_ind] - x;
      if (mid == 0) return(mid_ind-1);
      if (left  * mid < 0) { right = mid; right_ind = mid_ind; }
      if (right * mid < 0) { left  = mid; left_ind  = mid_ind; }
      iter = iter + 1;
    }
    if(iter == max_iter)
      print("Maximum number of iterations reached.");
    return(left_ind);
  }
  
  real[] deriv_aslanidi(real t, real[] I, real[] theta, real[] x_r, int[] x_i){
    
    int aLen = x_i[1];
    vector[aLen] ts = to_vector(x_r[1:aLen]);
    vector[aLen] V = to_vector(x_r[(aLen+1):(2*aLen)]);
    int aT = find_interval_elem(t, ts, 1);
    real aV = (aT==0) ? V[1] : V[aT];
    
    real xtau = theta[1] / (1 + exp(aV/ theta[2])) + theta[3];
    real xinf = 1 / (1 + exp(-(aV + theta[4]) / theta[5]));
    
    real dydt[1];
    dydt[1] = (xinf - I[1]) / xtau;
    return dydt;
  }
  
  vector solve_aslanidi_forced_ode(real[] ts, real X0, real[] theta, real[] V, real t0){
    int x_i[1];
    real X_Kr[size(V),1];
    vector[size(V)] I;
    x_i[1] = size(V);
    
    // relative tolerance, absolute tolerance, maximum number of steps
    X_Kr = integrate_ode_bdf(deriv_aslanidi, rep_array(X0, 1), t0, ts, theta, to_array_1d(append_row(to_vector(ts), to_vector(V))), x_i, 1e-5, 1e-5, 1000);

    for(i in 1:x_i[1]){
      real t = ts[i];
      int aT = find_interval_elem(t, to_vector(ts), 1);
      real aV = (aT==0) ? V[1] : V[aT];
      real rInf = 1 / (1 + exp((aV + theta[6]) / theta[7]));
      I[i] = theta[8] * X_Kr[i,1] * rInf * (aV + 85);
    }
    
    return(I);
  }
  
  real calculateLogLikelihood(real[] I, real[] ts, real X0, real[] theta, real[] V, real t0, real sigma, int N){
    
  vector[N] I_int;
  real aLogProb;
  
  // solve ODE using stiff solver
  I_int = solve_aslanidi_forced_ode(ts, X0, theta, V, t0);  // use the t0 argument rather than a hard-coded -0.1
  
  // likelihood
  aLogProb = normal_lpdf(I|I_int,sigma);

  return(aLogProb);
  }
}

data{
  int N;
  real V[N];
  real I[N];
  real ts[N];
  real t0;
}

transformed data {
  int x_i[0];
}

parameters{
  real<lower=0> p1;     // ms
  real<lower=0> p2;     // mV
  real<lower=0> p3;     // ms
  real<lower=0> p4;     // mV
  real<lower=0> p5;     // mV
  real p6;              // mV
  real<lower=0> p7;     // mV
  real<lower=0> p8;
  real<lower=0,upper=1> X0;
  real<lower=0> sigma;
}

transformed parameters{
  real theta[8];
  theta[1] = p1;
  theta[2] = p2;
  theta[3] = p3;
  theta[4] = p4;
  theta[5] = p5;
  theta[6] = p6;
  theta[7] = p7;
  theta[8] = p8;
}

model{
  
  target += calculateLogLikelihood(I, ts, X0, theta, V, -0.01, sigma, N);
  
  //priors
  p1 ~ normal(900,500);
  p2 ~ normal(5,1);
  p3 ~ normal(100,10);
  p4 ~ normal(0.1,0.02);
  p5 ~ normal(12.25,3);
  p6 ~ normal(-5.6,1);
  p7 ~ normal(20.4,3);
  p8 ~ normal(0.01,0.001);
  sigma ~ normal(1,0.1);
}

Out of curiosity, what is find_interval_elem doing? Any chance you could post a plot of a run of the ODE (just so it’s easier to look at)?

Hi Ben,

It is finding the index at which a time t is found in the vector ts. This
is then used to find the particular driving voltage at that point in time
that affects the RHS of the ODE.
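
For readers who want to check the lookup outside Stan, here is a minimal R sketch of the same rule. It assumes (from reading the Stan function above) that the index returned is the number of grid points strictly below t, and that index 0 falls back to the first voltage. The toy `ts` and `V` values and the helper names `lookup_index` and `lookup_voltage` are hypothetical.

```r
# Hypothetical toy grid and voltages, not taken from simple.csv.
ts <- c(0.0, 0.5, 1.0, 1.5)
V  <- c(-80, -40, 20, -80)

# Number of grid points strictly below t, i.e. sorted[i] < t <= sorted[i + 1].
lookup_index <- function(t, sorted) {
  findInterval(t, sorted, left.open = TRUE)
}

# Voltage seen by the right-hand side at time t; index 0 falls back to V[1].
lookup_voltage <- function(t, sorted, V) {
  V[pmax(lookup_index(t, sorted), 1)]
}

t_query <- c(-0.1, 0.0, 0.25, 0.5, 1.2, 2.0)
idx   <- lookup_index(t_query, ts)
volts <- lookup_voltage(t_query, ts, V)
print(idx)
print(volts)
```

Comparing these indices with the output of the exposed Stan function on the same inputs is a cheap way to test the interval search in isolation.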

Errrm, what sort of plot would you like? In terms of the chains – their
evolution is just a flat line…

Best,

Ben

Makes sense @ the interval thing

I was mainly curious what sort of dynamics the ODE exhibited. It looks like it is 1D, so is it a wiggly line? Is it a flat line? Exponential? What’s the range of the values it might explore? etc.

Hi Ben,

Yes, it is a 1d output. The output (a current) is sensitive to the applied
(cyclical) voltage, although overall has a roughly stationary mean. I will
send over a graph later if you think useful? I am not at a computer that I
can run the model on currently. Would it be more useful if I sent over some
fake data?

Best,

Ben

Hi Ben,

Please find attached some fake data for this model (current which is the
output, and the driving voltage). To solve the ODE (using ODE45 in Stan)
and plot I just do the following,

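library(rstan)     # expose_stan_functions
library(reshape2)  # melt (assumed to come from reshape2)
library(ggplot2)   # ggplot, geom_path

aDF <- read.csv('simple.csv')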
lThinning <- seq(1,nrow(aDF),1)
N <- length(lThinning)
I <- aDF$ikr.IKr[lThinning]
V <- aDF$membrane.V[lThinning]
ts <- aDF$engine.time[lThinning]

expose_stan_functions('aslanidi_4_bdf_logLikelihood.stan')
p1=1000
p2=5
p3=100
p4=0.085
p5=12.25
p6=-5.4
p7=20.4
p8 = 0.03
sigma=0.1
theta = c(p1,p2,p3,p4,p5,p6,p7,p8)
complex_forced <- solve_aslanidi_forced_ode(ts,0.4, theta, V, -0.0001)

plot(ts,complex_forced,type='l')

aShortDF <- data.frame(time=ts,model=complex_forced,true=I)
aShortDF <- melt(aShortDF,id.vars = 'time')
ggplot(aShortDF,aes(x=time,y=value,colour=as.factor(variable))) +
geom_path()

I also attach a picture of the typical solution to the ODE – as you can
see there are some quite abrupt changes in the current!

Best,

Ben

simple.csv (101 KB)

Dang, those are some pretty big wiggles for sure. What does your noise look like in your measured data?

@wds15 you have any suggestions on this? I’ve never messed with ODEs this complicated in Stan.

Hi Ben,

There is some noise, although not enough to affect parameter
identification. I’m going to do some ABC and MCMC using adaptive covariance
matrix sampling and this should give us a better idea about the shape of
the posteriors.

As an aside, at Oxford we have just begun an attempt to develop a parameter
inference methodology for fitting ODE and PDE models from electrochemistry
and cardiac modelling. This means that the bulk of the models are at least as
complex as, and typically more complex than, the one I shared. We also commonly use
cyclically-varying voltage to drive the system as this helps with parameter
identification. As part of this we are trying Stan’s NUTS vs some other
MCMC methods that we have used before (typically those that do not require
the gradient of the likelihood).

Best,

Ben

Ben: Do you really want that prior on sigma?

normal(1,0.1);
Since your initial value for sigma was 0.1 and your signal is O(1) and your noise is small, I’m assuming you meant

normal(0,0.1);

Hope that helps

Hi Daniel,

Yes, good spot. I will try that, although I think I’ve run the model before
when I set sigma at 0.1 (rather than infer it) and still had the issues.
Will let you know how it goes!

Best,

Ben

If your model predicts within 0.1 or less and your sigma is 1 then your parameters could wander all over the place and not change the likelihood much. On the other hand, if you initialized sigma = 0.1 with that model, and the initial conditions fit well, then I could imagine it would be hard to move away from the initial conditions, as you’d be stuck in a local optimum.

At least, it’s plausible. Hope it helps.

Also in my experience things become pathological with very small error scale, so having a prior that places maximum density on sigma=0 seems not as good as something like

gamma(3.0,3.0/0.1);

which puts a barrier against very small values of sigma.
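
To make the difference between the two priors concrete, the following R sketch compares their densities near zero. It assumes the shape/rate parameterisation of the gamma (which is what both Stan's `gamma` and R's `dgamma` use) and treats `normal(0, 0.1)` with a lower bound of zero as a half-normal. The probe value `sigma_small` is an arbitrary choice for illustration.

```r
shape <- 3.0
rate  <- 3.0 / 0.1

prior_mean <- shape / rate        # 0.1
prior_mode <- (shape - 1) / rate  # roughly 0.067

# Density at a very small sigma under each prior.
sigma_small <- 1e-3
dens_gamma      <- dgamma(sigma_small, shape = shape, rate = rate)
dens_halfnormal <- 2 * dnorm(sigma_small, mean = 0, sd = 0.1)

print(c(mean = prior_mean, mode = prior_mode))
print(c(gamma = dens_gamma, half_normal = dens_halfnormal))
```

The gamma density vanishes as sigma approaches zero, whereas the half-normal is at its maximum there.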

Hello all,

Thanks for your suggestions.

I tried the various priors on sigma that you suggested, but unfortunately,
this didn’t seem to make much of a difference. The chains are still not
mixing well at all.

Any other suggestions? I am currently doing some Metropolis-Hastings
fitting to see how the posteriors might look, and will let you know what I
find.

Best,

Ben

What does this line mean?
aLogProb = normal_lpdf(I|I_int,sigma);

I and I_int are vectors, sigma is real, aLogProb is real. Does this automatically sum the lpdf over all the elements of the vector? Or does it look only at the very first data point and drop the rest of the data?

I think the target += … syntax automatically sums a vector return value (pg 75 of current manual), but that won’t be the case in your statements here.

Hi Dan,

normal_lpdf in this context works out the log probability summed across all
the data points; it’s just a vectorised way of writing it in a loop (see
page 484 of the Stan manual).

target += just increments the log probability by the amount on the right
hand side, which is just a scalar. I looked at page 75 of the manual and
don’t think (I could be wrong) that Stan will automatically sum the RHS.

Ben

The lpdf functions always return a scalar. The summation over multiple components happens internally to the function.
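
As an R analogue of that behaviour (a sketch, not Stan itself), the vectorised log density summed over all points equals the explicit loop. The three data values below are made up for illustration.

```r
# Hypothetical observed and predicted currents.
I     <- c(0.10, 0.25, 0.40)
I_int <- c(0.12, 0.20, 0.43)
sigma <- 0.1

# Vectorised: one scalar, summed over every element.
vectorised <- sum(dnorm(I, mean = I_int, sd = sigma, log = TRUE))

# Explicit loop over data points.
looped <- 0
for (n in seq_along(I)) {
  looped <- looped + dnorm(I[n], mean = I_int[n], sd = sigma, log = TRUE)
}

print(c(vectorised = vectorised, looped = looped))
```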

You can also try reducing the tolerances on the ODEs (they can be specified as further arguments) and starting at a lower stepsize and targeting a higher acceptance rate (e.g., control=list(stepsize=0.01, adapt_delta=0.99) in R). You can also try running adaptation longer. If the chains don’t converge, you should check what the diagonal mass matrix and stepsize look like; if every chain’s explored enough of the posterior to estimate these, they should be roughly the same.

If they move a bit, but not much, then you can do pairs plots on the parameters to see if you have any problematic shapes (high correlations, banana-like shapes, or funnel-like shapes) to give you an idea where the problem might be.
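
A rough numerical companion to pairs plots is to list the parameter pairs whose posterior correlation is large. The sketch below works on any numeric matrix of draws with one named column per parameter; the helper name `strong_pairs`, the 0.9 threshold, and the synthetic `toy` draws are all assumptions made for illustration.

```r
# draws: numeric matrix, one named column per parameter.
strong_pairs <- function(draws, threshold = 0.9) {
  r <- cor(draws)
  # Keep each pair once (upper triangle) when |correlation| exceeds the threshold.
  idx <- which(abs(r) > threshold & upper.tri(r), arr.ind = TRUE)
  data.frame(par1 = colnames(r)[idx[, "row"]],
             par2 = colnames(r)[idx[, "col"]],
             corr = r[idx],
             stringsAsFactors = FALSE)
}

# Synthetic draws: p1 and p3 nearly collinear, p5 unrelated.
x <- seq(-1, 1, length.out = 50)
toy <- cbind(p1 = x, p3 = 2 * x + 0.01 * sin(40 * x), p5 = cos(7 * x))
flagged <- strong_pairs(toy)
print(flagged)

# With a stanfit object `fit` (not created here), one could try
# strong_pairs(as.matrix(fit)) and then pairs(fit, pars = c("p1", "p3")).
```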

If you haven’t already, I’d suggest testing those functions to make sure they’re doing the right thing. Bugs can cause this problem, too.

The usual reason for not moving is ill-conditioning of the Hessian (that is, sharp curvature in one direction, flat in another region, as in the funnel example). So anything you can do to reparameterize so the posterior is on the unit scale can help with adaptation.

Hi Bob,

I have tried running this at a lower step size and higher acceptance rate – this doesn’t seem to help. Similarly so for the adaptation time. How do I check the diagonal mass matrix?
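
On the mass matrix question, one possible route in rstan is sketched below. It assumes `fit` is a stanfit object returned by `rstan::stan()`; `print_adaptation`, `stepsize_by_chain` and `stan_data` are hypothetical names, and the sampling call is left commented out because it needs the model file and data.

```r
# Print, per chain, the adaptation summary stored in the fit: the adapted
# step size and the diagonal of the inverse mass matrix.
print_adaptation <- function(fit) {
  info <- rstan::get_adaptation_info(fit)
  for (k in seq_along(info)) cat("Chain", k, "\n", info[[k]], "\n\n")
  invisible(info)
}

# Mean post-warmup step size per chain, computed from the list of matrices
# returned by rstan::get_sampler_params(fit, inc_warmup = FALSE).
stepsize_by_chain <- function(sampler_params) {
  vapply(sampler_params, function(m) mean(m[, "stepsize__"]), numeric(1))
}

# Usage (not run here; stan_data is a hypothetical data list):
# fit <- rstan::stan("aslanidi_4_bdf_logLikelihood.stan", data = stan_data,
#                    control = list(stepsize = 0.01, adapt_delta = 0.99))
# print_adaptation(fit)
# stepsize_by_chain(rstan::get_sampler_params(fit, inc_warmup = FALSE))
```

If the chains have adapted to the same posterior, the step sizes and inverse mass matrix entries should be of similar magnitude across chains.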

I have run MCMC for the model using Metropolis-Hastings and a method that adapts the proposal covariance matrix over time, and I do get chain mixing and convergence. There are, however, some strong covariances in the posterior samples (I attach a data file containing a few of the samples from the nine parameters).

I am a bit confused here why I seem to get reasonable performance from Metropolis-Hastings but Stan’s NUTS is getting stuck? Does anyone have any suggestions here?

aslanidi_1.csv (1.3 MB)

The discontinuities of your forcing function can be a problem for the ode integrator.

You should make sure that you output from the ode integrator before and after those jumps. So pass into the integrate_ode function a times vector which includes those time-points, but you then drop those additional time-points when you add things to the log-lik. This will ensure that the ODE integrator “sees” the time-resolution of the forcing function. I have seen cases where the ODE integrator would otherwise step over those jumps.
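
On the R side, the bookkeeping for this could look like the sketch below: build an augmented output grid that brackets each jump, and keep the indices of the original observation times so the extra outputs can be dropped before the likelihood. The observation times, jump times and the offset `eps` are hypothetical values.

```r
ts_obs <- c(0.1, 0.2, 0.3, 0.4)  # hypothetical observation times
t_jump <- c(0.15, 0.35)          # hypothetical jump times of the forcing
eps    <- 1e-6                   # small offset either side of each jump

# Output grid passed to the ODE integrator.
ts_aug <- sort(unique(c(ts_obs, t_jump - eps, t_jump + eps)))

# Positions of the observation times within the augmented grid; only these
# rows of the ODE solution would enter the log-likelihood.
obs_idx <- match(ts_obs, ts_aug)

print(ts_aug)
print(obs_idx)
```

In the Stan program, `ts_aug` would replace `ts` in the integrator call and `obs_idx` would be passed as data to subset the solution.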

As a sanity check, you could also include the integral over the forcing function as part of the output. This lets you check that the ODE integrator is doing the right thing: the forcing function depends only on data, so you know its integral in advance. This way you know whether you need the additional steps outlined above.
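
A sketch of the reference values for such a check, assuming the forcing is held constant between grid points in the way the lookup in the posted model does (V[1] up to ts[1], then V[i] on the interval from ts[i] to ts[i+1]). The helper name `forcing_integral` and the toy inputs are hypothetical.

```r
# Integral of the piecewise-constant forcing from t0 up to each ts[i].
forcing_integral <- function(ts, V, t0) {
  n <- length(ts)
  V[1] * (ts[1] - t0) + c(0, cumsum(V[-n] * diff(ts)))
}

ts <- c(0.0, 0.5, 1.0, 1.5)
V  <- c(-80, -40, 20, -80)
ref <- forcing_integral(ts, V, t0 = -0.1)
print(ref)
```

These values could be compared with an extra ODE state whose derivative is simply the looked-up voltage; a mismatch would suggest the integrator is stepping over changes in the forcing.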

One more thing: if your forcing function is step-like, then integrate step-wise through the system. That is a bit more tedious, but it is actually the preferred solution for step-function inputs.
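
For the gating equation posted above, step-wise integration can even be done in closed form, because with a constant voltage the ODE is linear in the state. The R sketch below swaps the numerical integrator for that exact per-step update; it is only valid under the assumption that the voltage is held constant between grid points, as the lookup in the posted model implies. The function name `gate_stepwise` and the test inputs are hypothetical.

```r
# Exact update of dx/dt = (xinf(V) - x) / xtau(V) over each interval,
# with V constant within the interval.
gate_stepwise <- function(ts, V, theta, X0, t0) {
  x <- numeric(length(ts))
  x_prev <- X0
  t_prev <- t0
  for (i in seq_along(ts)) {
    # Voltage in force on the step ending at ts[i] (V[1] before ts[1]).
    v <- V[max(i - 1, 1)]
    xtau <- theta[1] / (1 + exp(v / theta[2])) + theta[3]
    xinf <- 1 / (1 + exp(-(v + theta[4]) / theta[5]))
    # Relax exponentially towards xinf with time constant xtau.
    x_prev <- xinf + (x_prev - xinf) * exp(-(ts[i] - t_prev) / xtau)
    x[i] <- x_prev
    t_prev <- ts[i]
  }
  x
}

theta <- c(1000, 5, 100, 0.085, 12.25)
ts <- c(50, 200, 1000)
x_step <- gate_stepwise(ts, V = rep(-20, 3), theta = theta, X0 = 0.4, t0 = 0)
print(x_step)
```

Comparing this against the output of the ODE solver on the same inputs gives an independent check of the integrator, and the same update could be ported into the Stan functions block.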