# Gaussian process gradient matching for inference in dynamical systems

#1

Hia,

I’ve been looking at ways to scale parameter inference in dynamical systems, both by brute force (courtesy of MPI and threading) and through approximations. One promising approach is gradient matching with a Gaussian process (GP) fit to the data.

Recently Wenk et al. (2018) gave a thorough treatment of this approach in enough detail that I felt comfortable trying to implement it in Stan. Their paper can be found here: https://arxiv.org/pdf/1804.04378.pdf and their source code in Python here: https://github.com/wenkph/FGPGM.

The gist of the idea is that instead of comparing data ($y$) to the true states ($x$) of a dynamical system generated by the model parameters ($\theta$), we first fit a GP to the data ($f(x, \phi)$) and compare the gradients of the optimised GP function ($\dot{x} \mid \phi$) against the gradients generated by the dynamical system for a given state ($\dot{x} \mid \theta$). This sidesteps the need to integrate the entire system, which becomes prohibitive as the number of states ($S$) or samples ($N$) grows, as might occur in hierarchical modelling where different instances of a system share hyperparameters. A derivation of the full likelihood is given on pg. 9. [Note: any bastardisation of notation and syntax is my own.]
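For reference, the GP-derivative conditional behind this comparison is standard Gaussian conditioning. Written in the notation of the code below ($C'_\phi$ is `C_Dash`, $C''_\phi$ is `C_DoubleDash`, and $\gamma$ is the mismatch term added in `covKern`), it reads:

$$
\begin{bmatrix} \mathbf{x} \\ \dot{\mathbf{x}} \end{bmatrix}
\sim \mathcal{N}\!\left(\mathbf{0},
\begin{bmatrix} C_\phi & C'_\phi \\ {C'_\phi}^{\top} & C''_\phi \end{bmatrix}\right),
\qquad
[C'_\phi]_{ij} = \frac{\partial k_\phi(t_i, t_j)}{\partial t_j},
\qquad
[C''_\phi]_{ij} = \frac{\partial^2 k_\phi(t_i, t_j)}{\partial t_i \, \partial t_j}
$$

$$
\dot{\mathbf{x}} \mid \mathbf{x}, \phi \sim \mathcal{N}(D\mathbf{x}, A),
\qquad D = {C'_\phi}^{\top} C_\phi^{-1},
\qquad A = C''_\phi - {C'_\phi}^{\top} C_\phi^{-1} C'_\phi
$$

$$
\log p_{\text{match}}(\theta) = \log \mathcal{N}\big(f(\mathbf{x}, \theta) \mid D\mathbf{x},\, A + \gamma I\big)
$$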

Implementing this in Stan was fun (in a twisted sort of way) but resulted in a moderately complicated model that doesn’t work. Below is an example of the lynx-hare system from @Bob_Carpenter’s case study. For $S = 2$ and $N = 100$, 2000 iterations take about 20 mins for four chains but give $n_{\text{eff}} = 2$, $\hat{R} > 50{,}000$ and unhelpfully wide posterior intervals. The authors, on the other hand, report correct inference in about 50 mins with a single Gibbs chain (which I was dubious about, hence Stan).

If anyone is game to look over the model, I’d really appreciate having any obvious mistakes I’ve probably made pointed out. Alternatively, if someone is in a better position to comment on the validity of this approach it’d be great to know if I’m barking up the wrong tree.

Model code is below and attached, along with an R script and data. Optimising $N \times S$ GPs can take a while, so a sensible set of toy parameters is already included.

Thanks for the help,
Andrew

hudson-bay-lynx-hare.csv (470 Bytes)

```stan
functions {
  // Function to calculate derivatives
  real[] dx_dt(vector x, real[] theta,
               real[] x_r, int[] x_i) {

    real u = x[1];
    real v = x[2];

    real alpha = theta[1];
    real beta = theta[2];
    real gamma = theta[3];
    real delta = theta[4];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;
    return { du_dt, dv_dt };
  }

  // Function for log determinant of matrices:
  // log|X| = 2 * sum(log(diag(chol(X)))), which avoids under/overflow in prod()
  real log_cholesky_determinant(matrix x) {
    return 2 * sum(log(diagonal(cholesky_decompose(x))));
  }

  // Function to calculate GP-ODE match
  real logProbGrad(vector dx_f, vector x,
                   matrix D, matrix A) {

    int T = rows(x);
    vector[T] dx_gp = D * x;
    vector[T] diff = dx_f - dx_gp;

    real prob = -0.5 * dot_product(diff, A \ diff);
    real det = -0.5 * log_cholesky_determinant(A);

    return det + prob;
  }

  // Function to calculate GP-obs match
  real logProbGP(vector y, vector x,
                 matrix C_Phi, real sigma) {

    int T = rows(y);

    // Prob of x given GP prior
    real det_phi = -0.5 * log_cholesky_determinant(C_Phi); // determinants are constant
    real p_phi = -0.5 * dot_product(x, C_Phi \ x);

    // Prob of x given data
    real det_y = -0.5 * log(prod(rep_vector(sigma^2, T))); // could precalc as data
    real p_y = -0.5 * inv_square(sigma) * squared_distance(x, y);

    return det_phi + p_phi + det_y + p_y;
  }

  // Function to calculate log-likelihood
  real logLike(vector[] y, vector[] x, vector y_mean,
               vector y_sd, matrix[,] kernArr, vector sigma,
               real[] pars, real[] x_r, int[] x_i) {

    int S = x_i[1];
    int T = x_i[2];
    vector[T] x_raw[S];
    vector[T] dx_f[S];
    real scale = 2 * S * T;
    real ll = 0;

    // Convert to ODE scale
    for (i in 1:S)
      x_raw[i] = y_mean[i] + x[i] * y_sd[i];

    // Generate gradients for each timestep
    for (i in 1:T)
      dx_f[, i] = dx_dt(to_vector(x_raw[, i]), pars, x_r, x_i);

    // Evaluate for each state
    for (i in 1:S) {
      real p_ode;
      real p_gp;

      // Normalise gradients to match GP
      dx_f[i] = dx_f[i] / y_sd[i];

      // Compare against GP
      p_ode = logProbGrad(dx_f[i], x[i], kernArr[i, 2], kernArr[i, 3]);

      p_gp = logProbGP(y[i], x[i], kernArr[i, 1], sigma[i]);

      // Increment log-likelihood
      ll = ll + p_ode + p_gp;
    }

    // Returns the negated log-likelihood, scaled by 1 / (2 * S * T)
    return -ll / scale;
  }

  // Function to pre-calculate GP covariance over gradients
  matrix[] covKern(real rbf_var, real rbf_len, vector t,
                   real nugget, real gamma) {

    int T = rows(t);
    matrix[T, T] C_Phi;         // Covariance over x
    matrix[T, T] C_Dash;        // Covariance of gradients over x
    matrix[T, T] C_DoubleDash;  // Covariance over gradients
    matrix[T, T] D;             // Linear system for gradients
    matrix[T, T] A;             // Combined GP + gradients kernel

    matrix[T, T] kernArr[3];

    for (i in 1:T) {
      for (j in 1:T) {

        real t_diff = t[i] - t[j];
        real k = rbf_var * exp(-0.5 * t_diff^2 * inv_square(rbf_len));

        C_Phi[i, j] = k;
        C_Dash[i, j] = inv_square(rbf_len) * t_diff * k;
        C_DoubleDash[i, j] = (inv_square(rbf_len) - t_diff^2 / rbf_len^4) * k;
      }

      C_Phi[i, i] = C_Phi[i, i] + nugget; // ensure PSD
    }

    D = C_Dash' / C_Phi;
    A = C_DoubleDash - C_Dash' / C_Phi * C_Dash;

    for (i in 1:T)
      A[i, i] = A[i, i] + gamma; // Allows for GP-ODE mismatch

    // Package matrices
    kernArr[1] = C_Phi;
    kernArr[2] = D;
    kernArr[3] = A;

    return(kernArr);
  }
}
data {
  int<lower=1> N;         // # samples
  int<lower=1> S;         // # states
  int<lower=1> T;         // # observation times

  vector[T] y[N, S];      // normalised data
  vector[S] y_mean[N];    // data mean
  vector[S] y_sd[N];      // data scale
  vector<lower=0>[T] t;   // observation times

  // GP hyperparameters, optimised on normalised data
  vector<lower=0>[S] rbf_var[N];
  vector<lower=0>[S] rbf_len[N];
  vector<lower=0>[S] sigma[N];
  real nugget;
  real gamma;
}
transformed data {
  // Lower bounds for non-negative states
  vector[S] L[N];

  // Real and integer data
  real x_r[0];
  int x_i[2] = {S, T};

  // GP covariance function
  matrix[T, T] kernArr[N, S, 3];
  for (i in 1:N) {
    for (j in 1:S) {
      kernArr[i, j] = covKern(rbf_var[i, j], rbf_len[i, j], t, nugget, gamma);

      L[i, j] = -y_mean[i, j] / y_sd[i, j];
    }
  }
}
parameters {
  real<lower=0> theta[4];           // parameters of system
  vector<lower=0>[T] x_ub[N, S];    // Unbounded states
}
transformed parameters {
  vector[T] x[N, S];                // True states (normalised)

  // Constrain states above zero
  for (i in 1:N)
    for (j in 1:S)
      x[i, j] = L[i, j] + x_ub[i, j];
}
model {
  // Priors
  theta[{1, 3}] ~ normal(1, 0.5);
  theta[{2, 4}] ~ normal(0.05, 0.05);

  for (i in 1:N) {

    // Estimate true state
    for (j in 1:S)
      x_ub[i, j] ~ normal(0, 1);

    // Match against observations
    target += logLike(y[i], x[i], y_mean[i], y_sd[i], kernArr[i], sigma[i], theta, x_r, x_i);
  }
}
generated quantities {
  // x are already predicted states,
  // could simulate with ODE solver for comparison
}
```
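The matrices `covKern` builds can be cross-checked outside Stan. Below is a plain-Python sketch that mirrors `covKern` and returns $D$ and $A + \gamma I$; the helper names are hypothetical, and `solve` is a textbook Gauss-Jordan elimination standing in for Stan's `/` operator. A finite-difference check of `C_dash` is a cheap way to confirm the sign convention $\partial k(t_i, t_j)/\partial t_j$ before trusting the model.

```python
import math


def rbf_kernel_blocks(t, rbf_var, rbf_len, nugget):
    """Squared-exponential kernel and its derivative blocks, mirroring covKern.

    C_phi[i][j]   = k(t_i, t_j), plus `nugget` on the diagonal
    C_dash[i][j]  = dk/dt_j         = cov(x(t_i), x'(t_j))
    C_ddash[i][j] = d2k/(dt_i dt_j) = cov(x'(t_i), x'(t_j))
    """
    n = len(t)
    inv_l2 = 1.0 / rbf_len ** 2
    C_phi = [[0.0] * n for _ in range(n)]
    C_dash = [[0.0] * n for _ in range(n)]
    C_ddash = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            d = t[i] - t[j]
            k = rbf_var * math.exp(-0.5 * d * d * inv_l2)
            C_phi[i][j] = k
            C_dash[i][j] = inv_l2 * d * k
            C_ddash[i][j] = (inv_l2 - d * d * inv_l2 ** 2) * k
        C_phi[i][i] += nugget  # jitter on the diagonal, as in covKern
    return C_phi, C_dash, C_ddash


def solve(M, B):
    """Solve M X = B by Gauss-Jordan elimination with partial pivoting."""
    n = len(M)
    aug = [list(M[i]) + list(B[i]) for i in range(n)]
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(aug[r][c]))
        aug[c], aug[p] = aug[p], aug[c]
        piv = aug[c][c]
        aug[c] = [val / piv for val in aug[c]]
        for r in range(n):
            if r != c:
                f = aug[r][c]
                aug[r] = [a - f * b for a, b in zip(aug[r], aug[c])]
    return [row[n:] for row in aug]


def gradient_conditional(t, rbf_var, rbf_len, nugget, gamma):
    """Return (D, A + gamma*I) so that x' | x ~ N(D x, A + gamma*I)."""
    C_phi, C_dash, C_ddash = rbf_kernel_blocks(t, rbf_var, rbf_len, nugget)
    n = len(t)
    X = solve(C_phi, C_dash)  # X = C_phi^{-1} C_dash
    # C_phi is symmetric, so D = C_dash' C_phi^{-1} = X'
    D = [[X[j][i] for j in range(n)] for i in range(n)]
    # A = C_ddash - C_dash' C_phi^{-1} C_dash, plus gamma on the diagonal
    A = [[C_ddash[i][j] - sum(C_dash[k][i] * X[k][j] for k in range(n))
          + (gamma if i == j else 0.0)
          for j in range(n)] for i in range(n)]
    return D, A
```

With a single time point the derivative is uncorrelated with the state, so $D = 0$ and $A$ reduces to the prior derivative variance $\sigma^2_{\text{rbf}} / \ell^2$ plus $\gamma$.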


#2

Howdy, howdy! Have some unorganized thoughts (edit: this sentence made no sense previously):

Have you looked at pairplots for things? Is it that different chains are in different places? Or that the posterior is super correlated and there are treedepth things? Or is it divergences?

```stan
vector<lower=0>[T] x_ub[N, S];    // Unbounded states
```


That looks bounded, but it says it isn’t bounded?

Going by what’s in the paper, it seems like there are four terms in their log density:

- $\theta \sim \text{something}$: prior on ODE parameters
- $x \sim \text{multi\_normal}(0, C_\phi)$: GP prior on states
- $y \sim \text{normal}(x, \sigma^2)$: observational noise
- and then a $\text{multi\_normal\_lpdf}(f(x, \theta) \mid Dx, A + \gamma I)$ term (how they get this seems weird to me, but ignore that for now)

Any particular reason to avoid writing it like this? Since you’re not sampling your GP hyperparameters, you can do all your Choleskys n’ such beforehand and this should be fast enough.
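To show what the precomputation buys: with fixed hyperparameters, the Cholesky factors of $C_\phi$ and $A + \gamma I$ are computed once, and each density evaluation is then only triangular solves. This is a plain-Python sketch with hypothetical helper names, not Stan code; in Stan the same terms would be `multi_normal_cholesky` statements with factors precomputed in `transformed data`. The prior on $\theta$ is left out, `sigma` is a standard deviation (as in Stan's `normal`), and `dx_f` stands for $f(x, \theta)$ already evaluated at the observation times.

```python
import math


def cholesky(K):
    """Lower-triangular L with L L' = K, for symmetric positive-definite K."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def forward_solve(L, b):
    """Solve L z = b for lower-triangular L."""
    z = []
    for i in range(len(b)):
        z.append((b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i])
    return z


def mvn_logpdf_chol(y, mu, L):
    """log N(y | mu, L L'), reusing a precomputed Cholesky factor L."""
    z = forward_solve(L, [a - b for a, b in zip(y, mu)])
    log_det = 2.0 * sum(math.log(L[i][i]) for i in range(len(L)))
    return -0.5 * (len(y) * math.log(2.0 * math.pi) + log_det + sum(v * v for v in z))


def state_log_density(x, y, sigma, dx_f, L_phi, D, L_A):
    """GP prior + observation noise + gradient matching for one state.

    L_phi and L_A are Cholesky factors of C_phi and A + gamma*I, computed once.
    """
    n = len(x)
    lp = mvn_logpdf_chol(x, [0.0] * n, L_phi)  # x ~ MVN(0, C_phi)
    lp += sum(-0.5 * math.log(2.0 * math.pi) - math.log(sigma)
              - 0.5 * ((yi - xi) / sigma) ** 2
              for yi, xi in zip(y, x))  # y ~ normal(x, sigma)
    dx_gp = [sum(D[i][k] * x[k] for k in range(n)) for i in range(n)]
    lp += mvn_logpdf_chol(dx_f, dx_gp, L_A)  # f(x, theta) ~ MVN(D x, A + gamma*I)
    return lp
```

The only per-iteration work is the matrix-vector product `D x` and three triangular solves, which is why fixing the hyperparameters up front keeps this cheap.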

#3

Ben, you little ripper!

That makes so much more sense. I a) missed that $x$ could/should be multivariate normal and b) got so hung up on breaking down each step that I missed the forest for the trees.

Bungled the labelling of this. It’s the positive increment from the lower bound of x which maintains positive states on the data scale (unnormalised):

This means I get a warning about non-linear transforms when sampling $x \sim \text{multi\_normal}(0, C_{\phi})$, but I think it’s ok in this case, right?

Making the suggested changes, the model runs much slower but more correctly: 2000 iterations took ~90 mins, $n_{\text{eff}} > 750$ and $\hat{R} \approx 1$. It recovers parameters in line with Bob’s case study, but off by an order of magnitude, which I think means there’s a mistake in my pre-processing code (3D arrays are a pain to keep straight).

About 1/4 of the transitions were divergent and 3/4 saturated the maximum treedepth of 10, so I’ll have a play at tuning this. The authors also suggest that $\gamma$ can be manually tuned to improve fit, but I don’t yet have a good intuition for this.

Lastly, care to explain what you find weird about $\text{MVN}(f(x, \theta) \mid Dx, A + \gamma I)$? Always happy to improve things.

Cheers!
A

```stan
functions {
  // Function to calculate derivatives
  real[] dx_dt(vector x, real[] theta,
               real[] x_r, int[] x_i) {

    real u = x[1];
    real v = x[2];

    real alpha = theta[1];
    real beta = theta[2];
    real gamma = theta[3];
    real delta = theta[4];

    real du_dt = (alpha - beta * v) * u;
    real dv_dt = (-gamma + delta * u) * v;
    return { du_dt, dv_dt };
  }

  // Function to pre-calculate GP covariance over gradients
  matrix[] covKern(real rbf_var, real rbf_len, vector t,
                   real nugget, real gamma) {

    int T = rows(t);
    matrix[T, T] C_Phi;         // Covariance over x
    matrix[T, T] C_Dash;        // Covariance of gradients over x
    matrix[T, T] C_DoubleDash;  // Covariance over gradients
    matrix[T, T] D;             // Linear system for gradients
    matrix[T, T] A;             // Combined GP + gradients kernel

    matrix[T, T] kernArr[3];

    for (i in 1:T) {
      for (j in 1:T) {

        real t_diff = t[i] - t[j];
        real k = rbf_var * exp(-0.5 * t_diff^2 * inv_square(rbf_len));

        C_Phi[i, j] = k;
        C_Dash[i, j] = inv_square(rbf_len) * t_diff * k;
        C_DoubleDash[i, j] = (inv_square(rbf_len) - t_diff^2 / rbf_len^4) * k;
      }

      C_Phi[i, i] = C_Phi[i, i] + nugget; // ensure PSD
    }

    D = C_Dash' / C_Phi;
    A = C_DoubleDash - C_Dash' / C_Phi * C_Dash;

    for (i in 1:T)
      A[i, i] = A[i, i] + gamma; // Allows for GP-ODE mismatch

    // Package matrices
    kernArr[1] = cholesky_decompose(C_Phi);
    kernArr[2] = D;
    kernArr[3] = cholesky_decompose(A);

    return(kernArr);
  }
}
data {
  int<lower=1> N;         // # samples
  int<lower=1> S;         // # states
  int<lower=1> T;         // # observation times

  vector[T] y[N, S];      // normalised data
  vector[S] y_mean[N];    // data mean
  vector[S] y_sd[N];      // data scale
  vector<lower=0>[T] t;   // observation times

  // GP hyperparameters, optimised on normalised data
  vector<lower=0>[S] rbf_var[N];
  vector<lower=0>[S] rbf_len[N];
  vector<lower=0>[S] sigma[N];
  real nugget;
  real gamma;
}
transformed data {
  // Lower bounds for non-negative states
  vector[S] L[N];

  // Real and integer data
  real x_r[0];
  int x_i[2] = {S, T};

  // GP covariance function
  matrix[T, T] kernArr[N, S, 3];
  for (i in 1:N) {
    for (j in 1:S) {
      kernArr[i, j] = covKern(rbf_var[i, j], rbf_len[i, j], t, nugget, gamma);

      L[i, j] = -y_mean[i, j] / y_sd[i, j];
    }
  }
}
parameters {
  real<lower=0> theta[4];           // parameters of system
  vector<lower=0>[T] x_ub[N, S];    // Positive increments from lower bounds
}
transformed parameters {
  vector[T] x[N, S];                // True states (normalised)
  vector[T] x_raw[N, S];

  for (i in 1:N) {
    for (j in 1:S) {
      x[i, j] = L[i, j] + x_ub[i, j]; // Constrain states above zero
      x_raw[i, j] = y_mean[i, j] + x[i, j] * y_sd[i, j]; // Convert to data scale
    }
  }
}
model {
  // Priors
  theta[{1, 3}] ~ normal(1, 0.5);
  theta[{2, 4}] ~ normal(0.05, 0.05);

  for (i in 1:N) {
    vector[T] dx_f[S];
    vector[T] dx_gp[S];

    // Generate function gradients for each timestep
    for (j in 1:T)
      dx_f[, j] = dx_dt(to_vector(x_raw[i, , j]), theta, x_r, x_i);

    // Compare for each state
    for (j in 1:S) {
      // GP prior
      x[i, j] ~ multi_normal_cholesky(rep_vector(0, T), kernArr[i, j, 1]);

      // Observational noise
      y[i, j] ~ normal(x[i, j], sigma[i, j]);

      // Normalise gradients to match GP
      dx_f[j] = dx_f[j] / y_sd[i, j];

      dx_gp[j] = kernArr[i, j, 2] * x[i, j];

      target += multi_normal_cholesky_lpdf(dx_f[j] | dx_gp[j], kernArr[i, j, 3]);
    }
  }
}
generated quantities {
  // x are already predicted states,
  // could simulate with ODE solver, just need to add back real t
}
```


#4

Oooh, I see what’s happening here.

I wouldn’t do this, honestly. I think there are a few ways to do non-parametric fits of things constrained to be positive. For GPs, I think the standard thing is to pass it through a non-linear transform like exp. The issue with that here is that you then can’t get the derivatives of your non-parametric function so easily.

Part of the issue is that even if you constrain pointwise bits of your GP to be positive, that doesn’t mean that things in between these points on your GP aren’t negative.

More on the GP later!

Uh oh, not good :P. I’d start by getting rid of the $x$ bounds and then doing a non-centered parameterization for $x$ to see if that helps (that’s what’s happening in the 2.17 manual, section “Latent variable GP”, with the `f = L_K * eta;` stuff).
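A minimal sketch of the non-centered idea, in plain Python with hypothetical helper names: the sampler would work on independent standard normals `eta`, and the correlated GP values are recovered deterministically as `L_K * eta`, so `x` has covariance $L_K L_K^\top = K$. In Stan terms that would be `eta` in `parameters` with a standard normal prior and `x = L_K * eta` in `transformed parameters`. The 2x2 kernel and draw count are arbitrary illustrative choices.

```python
import math
import random


def cholesky(K):
    """Lower-triangular L with L L' = K, for symmetric positive-definite K."""
    n = len(K)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = K[i][j] - sum(L[i][k] * L[j][k] for k in range(j))
            L[i][j] = math.sqrt(s) if i == j else s / L[j][j]
    return L


def non_centered_draw(L, rng):
    """x = L * eta with eta ~ N(0, I), so x ~ MVN(0, L L')."""
    eta = [rng.gauss(0.0, 1.0) for _ in range(len(L))]
    return [sum(L[i][k] * eta[k] for k in range(i + 1)) for i in range(len(L))]


def empirical_cov(draws):
    """Sample covariance matrix of a list of equal-length draws."""
    n, d = len(draws), len(draws[0])
    means = [sum(x[i] for x in draws) / n for i in range(d)]
    return [[sum((x[i] - means[i]) * (x[j] - means[j]) for x in draws) / (n - 1)
             for j in range(d)] for i in range(d)]


K = [[1.0, 0.8], [0.8, 1.0]]  # strongly correlated, like neighbouring GP values
L_K = cholesky(K)
rng = random.Random(1)
draws = [non_centered_draw(L_K, rng) for _ in range(50000)]
```

The sampler only ever sees the uncorrelated `eta`, which is usually an easier geometry than the strongly correlated `x` itself.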

Lemme start by saying I’m pretty uninformed about this time series stuff, but I am interested enough in it to have bad opinions. So take whatever I say below with a grain of salt.

What seems to be happening in that paper is they’re using a GP to non-parametrically estimate x (the state) and x’ (the derivatives of the state) from some data.

The part I don’t like is that they introduce these F1/F2 things. I see what they’re doing with the math, but I do not see the intuition behind it. It feels like the way they set up F1 and F2 was so that they could eventually add a term to their likelihood, $\text{MVN}(f(x, \theta) \mid Dx, A + \gamma I)$.

That’s why it seems sketchy to me. It’s clearly not a generative process, and I’m not sure why you’d talk about two separate random variables F1 and F2 and then say they have to be equal. Their argument seems to be based on being able to draw a nice graphical model.

Anyway, long story short, I think what they’re doing is super handwavy. There’s nothing super wrong with handwavy – kinda necessary in a lot of timeseries stuff I think, but if we’re gonna handwave why don’t we do something else?

First off, if it’s inferring states of an ODE and we have fixed hyperparameters, I’d look at Kalman filter stuff. That’s gonna be way faster than GPs (and judging by the title of the paper, speed is important). “Bayesian Filtering and Smoothing” by Simo Särkkä (edit: added the dots to the 'a’s) is really nice and compact and recent. He has a copy on his website for free (or grab a copy, it’s a nice little book): https://users.aalto.fi/~ssarkka/
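To make the Kalman filter suggestion concrete: a state-space model whose state includes a velocity component returns filtered estimates of both $x$ and $\dot{x}$, which is the pair gradient matching needs. The sketch below is a plain-Python constant-velocity (white-noise acceleration) filter for a single observed state; the model choice, the noise parameters `q` and `r`, and the function name are illustrative assumptions, not something from the paper. It does not build the Lotka-Volterra dynamics into the filter.

```python
def kalman_filter_cv(ys, dt, q, r, m0, P0):
    """Kalman filter for a constant-velocity (white-noise acceleration) model.

    State s = [x, dx/dt]; s_k = F s_{k-1} + N(0, Q); y_k = x_k + N(0, r).
    Returns a list of (filtered x, filtered dx/dt, covariance) per observation.
    """
    F = [[1.0, dt], [0.0, 1.0]]
    Q = [[q * dt ** 3 / 3, q * dt ** 2 / 2], [q * dt ** 2 / 2, q * dt]]
    m = list(m0)
    P = [list(row) for row in P0]
    out = []
    for y in ys:
        # Predict: m <- F m,  P <- F P F' + Q
        m = [m[0] + dt * m[1], m[1]]
        FP = [[sum(F[i][k] * P[k][j] for k in range(2)) for j in range(2)]
              for i in range(2)]
        P = [[sum(FP[i][k] * F[j][k] for k in range(2)) + Q[i][j] for j in range(2)]
             for i in range(2)]
        # Update with a scalar observation of x only (H = [1, 0])
        S = P[0][0] + r
        K = [P[0][0] / S, P[1][0] / S]
        innov = y - m[0]
        m = [m[0] + K[0] * innov, m[1] + K[1] * innov]
        P = [[P[i][j] - K[i] * K[j] * S for j in range(2)] for i in range(2)]
        out.append((m[0], m[1], P))
    return out
```

Fed an exact straight line such as $y_k = 2 + 3 t_k$ with a vague prior, the filtered velocity settles at about 3. A smoother pass would additionally use future observations for each time point.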

Second, I’m assuming the goal of this gradient matching stuff is that eventually you want to be able to write down:

$$\dot{x} \sim \text{normal}(f(x, \theta), \sigma)$$

which is just a simple non-linear regression (and this is really where I think $\text{MVN}(f(x, \theta) \mid Dx, A + \gamma I)$ came from).

So why not just do that? Estimate your states/gradients of states with some GP/Kalman filter thing, and then just use that as data in a Stan model.
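A minimal sketch of that two-stage idea for the lynx-hare (Lotka-Volterra) system, in plain Python. It assumes noise-free simulated states and uses central finite differences as a stand-in for the GP/Kalman gradient estimates; the parameter values and function names are illustrative. Because $\dot{u} = \alpha u - \beta u v$ and $\dot{v} = -\gamma v + \delta u v$ are linear in the parameters, least squares on the estimated gradients gives the posterior mode of that regression under flat priors.

```python
def lv_rhs(u, v, alpha, beta, gamma, delta):
    """Lotka-Volterra right-hand side, matching dx_dt in the Stan code."""
    return (alpha - beta * v) * u, (-gamma + delta * u) * v


def simulate_lv(u0, v0, theta, dt, n_steps):
    """Classic RK4 integration; returns lists of u and v at every step."""
    us, vs = [u0], [v0]
    u, v = u0, v0
    for _ in range(n_steps):
        k1 = lv_rhs(u, v, *theta)
        k2 = lv_rhs(u + 0.5 * dt * k1[0], v + 0.5 * dt * k1[1], *theta)
        k3 = lv_rhs(u + 0.5 * dt * k2[0], v + 0.5 * dt * k2[1], *theta)
        k4 = lv_rhs(u + dt * k3[0], v + dt * k3[1], *theta)
        u += dt / 6 * (k1[0] + 2 * k2[0] + 2 * k3[0] + k4[0])
        v += dt / 6 * (k1[1] + 2 * k2[1] + 2 * k3[1] + k4[1])
        us.append(u)
        vs.append(v)
    return us, vs


def central_diff(xs, dt):
    """Stand-in for GP/Kalman gradient estimates: interior central differences."""
    return [(xs[i + 1] - xs[i - 1]) / (2 * dt) for i in range(1, len(xs) - 1)]


def ols_2(x1, x2, y):
    """Least squares for y ~ b1 * x1 + b2 * x2 (no intercept) via normal equations."""
    a11 = sum(a * a for a in x1)
    a12 = sum(a * b for a, b in zip(x1, x2))
    a22 = sum(b * b for b in x2)
    c1 = sum(a * c for a, c in zip(x1, y))
    c2 = sum(b * c for b, c in zip(x2, y))
    det = a11 * a22 - a12 * a12
    return (a22 * c1 - a12 * c2) / det, (a11 * c2 - a12 * c1) / det


def fit_lv_two_stage(us, vs, dt):
    """Stage 1: estimate gradients. Stage 2: regress them on f(x, theta)."""
    du = central_diff(us, dt)
    dv = central_diff(vs, dt)
    u, v = us[1:-1], vs[1:-1]
    uv = [a * b for a, b in zip(u, v)]
    alpha, neg_beta = ols_2(u, uv, du)   # du/dt = alpha*u - beta*u*v
    neg_gamma, delta = ols_2(v, uv, dv)  # dv/dt = -gamma*v + delta*u*v
    return alpha, -neg_beta, -neg_gamma, delta
```

With clean, densely sampled states this recovers the simulating parameters closely; with real, noisy data, the uncertainty from stage one is simply dropped, which is the main cost of splitting the pipeline.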

You’re splitting the inference here, so Strictly Speaking this is frowned on, but it’s not terribly unrealistic. God knows any sensor data you get probably goes through fifteen dubiously tuned Kalman filters before it gets to you anyway, and this is the sort of thing it’s easy to evaluate with simulated data (does my posterior contain the truth? Is my filtering messing things up?)

There’s a pile of issues when you split the pipeline like this. If you fit x with a model, do you feed an estimate of the mean of x on to the next thing? Do you sample x from the posterior of the first model? How many data points do you use? Any of these things can easily mess up the reliability of the posteriors in your final model. This is just something you’ll have to be aware of. It’s the downside of being handwavy, but it’s not so bad!

Also, your states might be positive, but that doesn’t mean the gradients have to be. So this is a handwavy argument for not having a positivity constraint on your GPs :P.

Anyway, interesting paper, I had fun looking at it. Hope this helps at all!