How can I reparametrize an ODE-based filter to improve efficiency (especially treedepth)?

I have noisy measurements of acceleration and position and want to obtain corrected values for acceleration, velocity, and position. To do this I express acceleration using splines, and integrate it in time to obtain velocity and position.

The basic idea is the same as for a Kalman filter.

The method works well in principle; a simple example result is shown below.

The main problem I have is that Stan consistently needs very high treedepths when sampling this model. In the small example I show it is 11, which I could live with, but in some instances a treedepth of 15 is not enough. Since the number of leapfrog steps per iteration can grow up to 2^treedepth, this gets expensive quickly.

How can I debug this issue in a structured way and reparametrize the model to make it more computationally tractable? I have done a lot of trial and error, including following the Gaussian process example in the user manual, but everything I’ve tried so far shows the same problem.

Any other hints on how I could improve the model would of course also be appreciated. Are there other people out there who use Stan for similar applications?

functions {
  // b-spline code from https://github.com/milkha/Splines_in_Stan
  // forward declaration so build_b_spline can call itself recursively
  real build_b_spline(real t, vector ext_knots, int ind, int order);
  real build_b_spline(real t, vector ext_knots, int ind, int order) {
    // INPUTS:
    //    t:          the point at which the B-spline is evaluated
    //    ext_knots:  the set of extended knots
    //    ind:        the index of the B-spline
    //    order:      the order of the B-spline
    real b_spline;
    real w1 = 0;
    real w2 = 0;
    if (order==1)
      // B-splines of order 1 are piece-wise constant
      b_spline = (ext_knots[ind] <= t) && (t < ext_knots[ind+1]);
    else {
      if (ext_knots[ind] != ext_knots[ind+order-1])
        w1 = (t - ext_knots[ind]) /
             (ext_knots[ind+order-1] - ext_knots[ind]);
      if (ext_knots[ind+1] != ext_knots[ind+order])
        w2 = 1 - (t - ext_knots[ind+1]) /
                 (ext_knots[ind+order] - ext_knots[ind+1]);
      // Calculating the B-spline recursively as linear interpolation of two lower-order splines
      b_spline = w1 * build_b_spline(t, ext_knots, ind, order-1) +
                 w2 * build_b_spline(t, ext_knots, ind+1, order-1);
    }
    return b_spline;
  }

  // ODE right-hand side; the state is y = [velocity, position]'
  vector sho(real t, vector y, int num_basis, vector eta, vector ext_knots, int order) {
    vector[2] dydt;
    dydt[1] = 0.;
    // acceleration: the spline evaluated at time t
    for (i in 1:num_basis) dydt[1] += build_b_spline(t, ext_knots, i, order)*eta[i];
    dydt[2] = y[1]; // velocity
    return dydt;
  }
}

data {
  // signal
  int<lower=0> N;
  vector[N] time;
  vector[N] altitude;
  vector[N] acceleration;
  // spline parameters
  int num_knots;            // num of knots
  vector[num_knots] knots;  // the sequence of knots
  int spline_degree;        // the degree of spline (is equal to order - 1)
}

transformed data {
  int order = spline_degree + 1;
  int num_basis = num_knots + spline_degree - 1; // total number of B-splines
  vector[spline_degree + num_knots] ext_knots_temp;
  vector[2*spline_degree + num_knots] ext_knots; // set of extended knots
  ext_knots_temp = append_row(rep_vector(knots[1], spline_degree), knots);
  ext_knots = append_row(ext_knots_temp, rep_vector(knots[num_knots], spline_degree));

  // Build spline matrix at measurement points
  matrix[N,num_basis] B;
  for (i in 1:N) {
    for (j in 1:num_basis) {
      B[i,j] = build_b_spline(time[i], ext_knots, j, order);
    }
  }
  
  vector[N-1] dt = time[2:N] - time[1:(N-1)];
}

parameters {
  real v0;
  real<lower=0> h0;
  vector[num_basis] eta; // for a

  real<lower=0> sigma_a;
  real<lower=0> sigma_h;
}

transformed parameters {
  vector[2] y_init = [v0, h0]';
  vector[N] a_est;
  profile("splines") {
  a_est = B*eta;
  // for (t in 1:N) {
  //   a_est[t] = 0.;
  //   for (i in 1:num_basis) a_est[t] += build_b_spline(time[t], ext_knots, i, order)*eta[i];
  // }
  }
  
  vector[N] v_est;
  vector[N] h_est;
  
  v_est[1] = y_init[1];
  h_est[1] = y_init[2];
  
  
  profile("time_integration") {
  for (i in 1:(N-1)) {
    v_est[i+1] = v_est[i] + dt[i]*a_est[i];
    h_est[i+1] = h_est[i] + dt[i]*v_est[i];
  }
  }
}

model {
  profile("priors") {
  y_init[1] ~ normal(0,10);
  y_init[2] ~ normal(0,1000);
  eta ~ normal(0,10);
  sigma_a ~ exponential(.1);
  sigma_h ~ exponential(.01);
  }
  
  profile("measurements") {
  acceleration ~ normal(a_est, sigma_a);
  altitude ~ normal(h_est, sigma_h);
  }
}


Since you mention the Kalman filter, marginalizing over the latent states is one way that will certainly reduce treedepth. You will pay for it with added code complexity / time per gradient evaluation, but my experience is that overall performance is usually much better.
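For concreteness, a minimal sketch of what that marginalization could look like for this model (schematic and untested: it assumes an Euler-discretized state-space version of the dynamics with a hypothetical process-noise scale sigma_q as an extra parameter, and it marginalizes v0 and h0 along with the latent states):

model {
  // acceleration likelihood is unchanged, since a_est is deterministic in eta
  acceleration ~ normal(a_est, sigma_a);

  // forward Kalman recursion: the latent state [velocity, position]'
  // is integrated out analytically instead of being sampled
  vector[2] m = [0, 0]';                      // prior mean of [v0, h0]
  matrix[2, 2] P = diag_matrix([100, 1e6]');  // prior covariance, matching the priors above
  for (i in 1:N) {
    if (i > 1) {
      // predict: one Euler step of the dynamics plus process noise on velocity
      matrix[2, 2] F = [[1, 0], [dt[i - 1], 1]];
      m = F * m + [dt[i - 1] * a_est[i - 1], 0]';
      P = F * P * F' + diag_matrix([square(sigma_q), 0]');
    }
    // update with the altitude measurement (observes the position component)
    real S = P[2, 2] + square(sigma_h);
    target += normal_lpdf(altitude[i] | m[2], sqrt(S));
    vector[2] K = col(P, 2) / S;  // Kalman gain
    m += K * (altitude[i] - m[2]);
    P -= K * row(P, 2);
  }
}

The latent velocity and position then disappear from the parameter space entirely; only eta, the noise scales, and sigma_q remain, which is what typically tames the geometry.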

Thank you for this interesting suggestion. I need to think about how I could best do this for my model.

In the meantime, I played with the number of basis functions in my model to get a feel for how the number of parameters in the model influences the tree depth. I don’t see a clear influence right now.

I also tried a QR reparametrization as recommended by the user manual for linear regression.

As was perhaps to be expected, since this is not a linear regression problem, I did not see any improvement.

My understanding is that a large number of leapfrog steps is caused by a complex posterior geometry. But I am really stumped as to how I can understand the complex posterior geometry in my case and perhaps reduce its complexity through reparametrization.

Edit: On second thought, it seems evident that a linear transformation of the parameter space will not change the complexity of the posterior geometry.

Maybe something like what’s described here is happening and causing the large treedepth? CoDatMo Liverpool & UNINOVE models (slow ODE implementation and Trapezoidal Solver) - #37 by Funko_Unko

Your lines seem quite wiggly, I’m guessing without regularization, any sequence of wildly oscillating coefficients for the acceleration will have a similar effect on the measured (?) position.


That’s exactly it. Integrating twice smooths out these oscillations; for example, an oscillating acceleration a(t) = sin(ωt) shows up in the position only as -sin(ωt)/ω², so high-frequency wiggles are attenuated quadratically. Thus many different (oscillating) acceleration time histories lead to essentially the same posterior probability, resulting in a very complex posterior geometry.

One confirmation for this is that just removing this one line will lead to a huge speedup:

altitude ~ normal(h_est, sigma_h);

Also, my basis function matrix integrated twice is fairly ill-conditioned.
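For reference, one way to check this directly in Stan, assuming the twice-integrated position basis matrix is available in transformed data under the hypothetical name B_h:

transformed data {
  // singular_values() returns the singular values in descending order,
  // so this prints the condition number of B_h
  vector[min(N, num_basis + 2)] sv = singular_values(B_h);
  print("condition number of B_h: ", sv[1] / sv[num_elements(sv)]);
}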

Unfortunately, I cannot use a regularizing prior on the acceleration coefficients, since this will also reduce my ability to follow the jumps in the acceleration signal.

But I see that keeping acceleration measurement noise at a fixed, unrealistically small level instead of inferring it has a regularizing effect while still following the jumps. While that’s not a good permanent solution, it’s at least a temporary fix.
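Concretely, that temporary fix just means moving sigma_a from the parameters block to the data block (and dropping its prior from the model block):

data {
  // ... everything as before, plus:
  real<lower=0> sigma_a;  // fixed at a small value, e.g. 0.01, instead of being inferred
}

parameters {
  real v0;
  real<lower=0> h0;
  vector[num_basis] eta;
  real<lower=0> sigma_h;  // sigma_a no longer appears here
}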

What this tells me is I need to put more thought into finding more suitable basis functions. Right now I only have some vague ideas:

  • Use splines for position and differentiate instead of integrating
  • Find basis functions that allow occasional big jumps while suppressing small oscillations. Wavelets?
  • Anything that will improve the condition number of the position basis function matrix without jeopardizing the condition number of the acceleration basis function matrix (but what?)

I’d keep it simple with either piecewise linear and globally continuous or piecewise constant and not globally continuous basis functions for the acceleration. Or maybe piecewise linear + non-continuous? I don’t see what higher order splines would add.

As for the regularization, my naive thinking is that something like the horseshoe prior for the “jumps” might be what you are looking for. But I have no idea whether this has been used for something like this, is suitable or efficient. Maybe @avehtari knows more.
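Very roughly, something like this, maybe (a naive, untested sketch; eta1, delta, lambda, and tau are illustrative names and the scales are guesses):

parameters {
  real eta1;                               // first spline coefficient
  vector[num_basis - 1] delta;             // increments between adjacent coefficients ("jumps")
  vector<lower=0>[num_basis - 1] lambda;   // horseshoe local scales
  real<lower=0> tau;                       // horseshoe global scale
}
transformed parameters {
  // coefficients recovered by accumulating the jumps
  vector[num_basis] eta = cumulative_sum(append_row(eta1, delta));
}
model {
  eta1 ~ normal(0, 10);
  lambda ~ cauchy(0, 1);
  tau ~ cauchy(0, 0.1);
  delta ~ normal(0, tau * lambda);  // most jumps shrink toward zero, a few stay large
}

In practice the regularized horseshoe may behave better than this plain version, but the idea is the same.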

You are right, higher order splines don’t add much except making the inferred accelerations look less “synthetic”. Switching to piecewise linear basis functions gives me another significant speedup.

Between this and keeping acceleration measurement noise at a fixed, small value I see a speedup of roughly 10x compared to my original model. Thank you @Funko_Unko for the crucial hint.

Thank you also for pointing me towards horseshoe priors. I did not manage to get additional benefits from it in this case, but it looks very interesting for other applications I have.

Debugging is ill-posed: there is a near-infinite number of potential problems that can lead to the same output diagnostics, so you have to find some way to prioritize the most likely problems when looking deeper into the diagnostics. Considering the model structure is particularly important. For some general discussion see Identity Crisis.

Here the problem isn’t the basis functions per se but most likely the coupling of neighboring latent states in the lines

v_est[i+1] = v_est[i] + dt[i]*a_est[i];
h_est[i+1] = h_est[i] + dt[i]*v_est[i];

For those familiar with hierarchical models this should look like a non-centered parameterization, and from a mathematical perspective it is essentially equivalent. Funnels abound, and the considerations in Sections 3.2 and 3.3 of Hierarchical Modeling become very relevant. In particular, you will probably have to center the parameters at the observation times with sudden changes, where the realized likelihood becomes much more informative than the latent time-series model.
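Schematically, centering here means promoting the latent states to parameters and expressing the dynamics as a distribution over them, which requires giving the time-series model an explicit scale (sigma_q below is hypothetical, and in practice you would center only some components, as discussed in the case study):

parameters {
  vector[N] v_latent;     // centered: the states themselves are parameters
  vector[N] h_latent;
  real<lower=0> sigma_q;  // hypothetical process-noise scale
}
model {
  v_latent[1] ~ normal(0, 10);
  h_latent[1] ~ normal(0, 1000);
  for (i in 1:(N - 1)) {
    v_latent[i + 1] ~ normal(v_latent[i] + dt[i] * a_est[i], sigma_q);
    h_latent[i + 1] ~ normal(h_latent[i] + dt[i] * v_latent[i], sigma_q);
  }
  altitude ~ normal(h_latent, sigma_h);
}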

The solution was under my nose the whole time. @betanalpha’s remarks about the coupling of the latent states gave me a push in the right direction.

The key is to realize that velocity and position can be expressed as linear combinations of the coefficients of the spline approximation. This means that, from a mathematical perspective, I am just solving a linear regression problem on the coefficients.
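Schematically, stacking the two measurement vectors gives a single linear model,

$$
y = \begin{pmatrix} a_{\text{meas}} \\ h_{\text{meas}} \end{pmatrix}, \qquad
x = \begin{pmatrix} 0 \quad 0 \quad B_a \\ B_h \end{pmatrix}, \qquad
y = x \beta + \varepsilon,
$$

where $B_a$ is the acceleration basis matrix evaluated at the measurement times, $B_h$ is its twice-integrated counterpart (its two extra leading columns carry the integration constants for $h_0$ and $v_0$), and the two zero columns pad $B_a$ to the same width.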

Using the QR reparametrization from the Stan user’s manual reduces the treedepth to 4 instead of 11. That drops the cost from up to 2^11 = 2048 leapfrog steps per iteration to around 2^4 = 16, more than a 100x speedup compared to the original code.

Details in case it helps anybody else:

I assemble the outcome vector y and predictor matrix x in R as follows:

### assemble outcome vector ###
y <- c(df$acceleration, df$altitude)

### assemble predictor matrix ###

library(splines)

n <- 100
basis_a <- bs(df$time, df = n, degree = 1, intercept = TRUE)

# tedious matrix algebra: integrate a basis matrix in time with the Euler
# rule; the added first column of ones carries the integration constant
dt <- diff(df$time)
integrate_basis <- function(basis) {
  basis_i <- matrix(0, nrow = nrow(basis), ncol = ncol(basis) + 1)
  basis_i[, 1] <- 1
  for (i in 2:nrow(basis)) {
    basis_i[i, 2:(ncol(basis) + 1)] <- basis_i[i - 1, 2:(ncol(basis) + 1)] + dt[i - 1] * basis[i - 1, ]
  }
  basis_i
}

basis_v <- integrate_basis(basis_a)
basis_h <- integrate_basis(basis_v)

x <- matrix(0, nrow = 2 * nrow(df), ncol = n + 2)
x[1:nrow(df), 3:ncol(x)] <- basis_a                      # acceleration rows: two zero columns, then basis_a
x[(nrow(df) + 1):(2 * nrow(df)), 1:ncol(x)] <- basis_h   # altitude rows

Of course, I could do the same thing in the transformed data block in Stan.
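For reference, a sketch of that transformed data variant, mirroring the R function above (B and dt are the basis matrix and time steps already built there; B_v and B_h are hypothetical names for the once- and twice-integrated versions, each with an extra leading column of ones for the integration constant):

transformed data {
  matrix[N, num_basis + 1] B_v = rep_matrix(0, N, num_basis + 1);
  matrix[N, num_basis + 2] B_h = rep_matrix(0, N, num_basis + 2);
  B_v[1:N, 1] = rep_vector(1, N);  // integration constant -> v0
  B_h[1:N, 1] = rep_vector(1, N);  // integration constant -> h0
  for (i in 2:N) {
    // Euler integration, same as the R loop
    B_v[i, 2:(num_basis + 1)] = B_v[i - 1, 2:(num_basis + 1)] + dt[i - 1] * B[i - 1];
    B_h[i, 2:(num_basis + 2)] = B_h[i - 1, 2:(num_basis + 2)] + dt[i - 1] * B_v[i - 1];
  }
}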

The Stan code is straight from the user’s guide with a few modifications:

// linear regression lifted straight from Stan users manual part 1 section 1.2 with small modifications
data {
  int<lower=0> N;   // number of data items
  int<lower=0> K;   // number of predictors
  matrix[N, K] x;   // predictor matrix
  vector[N] y;      // outcome vector
}
transformed data {
  matrix[N, K] Q_ast;
  matrix[K, K] R_ast;
  matrix[K, K] R_ast_inverse;
  // thin and scale the QR decomposition
  Q_ast = qr_thin_Q(x) * sqrt(N - 1);
  R_ast = qr_thin_R(x) / sqrt(N - 1);
  R_ast_inverse = inverse(R_ast);
}
parameters {
  vector[K] theta;      // coefficients on Q_ast
  real<lower=0> sigma_a;  // error scale
  real<lower=0> sigma_h;
}
model {
  sigma_a ~ exponential(1);
  sigma_h ~ exponential(1);
  
  vector[N] xx = Q_ast * theta;
  y[1:(N %/% 2)] ~ normal(xx[1:(N %/% 2)], sigma_a);              // likelihood: acceleration
  y[((N %/% 2) + 1):N] ~ normal(xx[((N %/% 2) + 1):N], sigma_h);  // likelihood: position
}
generated quantities {
  vector[K] beta;
  beta = R_ast_inverse * theta; // coefficients on x
}

In R, I reconstruct acceleration, velocity, and position from beta as follows:

# p$beta holds the posterior draws of beta (one row per draw)
a_est <- p$beta[, 3:(n + 2)] %*% t(basis_a)
v_est <- p$beta[, 2:(n + 2)] %*% t(basis_v)
h_est <- p$beta[, 1:(n + 2)] %*% t(basis_h)

Nice! In hindsight this makes sense; because the differential equations that relate acceleration, velocity, and position are linear the linear structure of splines will propagate through each and everything will compose well with linear transformations.
