Simple Hierarchical Spatial Inference Model

Hi Stan Community,

Sorry for the long post – but at least the model is short!

I was hoping for some help learning how to use Stan to develop a hierarchical spatial model that infers a complete spatial field from noisy, biased, and gappy observations. Specifically, I’m trying to recreate the simple example here (Section 4). There is even a Matlab implementation of a custom Gibbs Sampler (zip file here;/GHCN_CONUS_Demo; note the link in the previous document does not work).

I’ve tried to start very basic, by regarding the posterior medians of the parameters and the process (from the Matlab solution) as ‘truth’, giving them as data to Stan, and then just letting it infer the process. This produces the correct result and is reasonably fast.

However, as soon as I start letting it infer the parameters, it takes quite a while, despite all the priors being conjugate. For example, the attached model took 1.5 hours for 500 iterations. The Gibbs sampler in Matlab takes about a minute for 1000 iterations.

Surely there must be something wrong with how I’m specifying the model. I suppose there could also be a problem with this debugging strategy, but I have also tried specifying the model completely with the real input data, as well as with noise added to the ‘true’ posterior process, and I have the same problems. I’ve also tried specifying an upper bound on the covariance parameters (see the snippet below), but this didn’t make a difference.
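(For reference, the upper bound was just an extra constraint on the parameter declaration, something like the snippet below; the cap of 50 is an arbitrary value I picked while debugging, nothing principled:)

parameters {
  vector[N] y;                      // process
  real<lower=0, upper=50> sig_2;    // process variance, with an arbitrary upper bound for debugging only
}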

The model code I’m attaching is a demonstration of inferring just the sig_2 parameter. As noted, it took 1.5 hours for 500 iterations and produced a reasonable-looking posterior process, but it substantially underestimated the parameter: the posterior median of sig_2 is 7.9 with a max of 9.8, while the posterior median ‘truth’ is 16.7.

A specific question, which perhaps betrays my lack of understanding of Stan, is: why am I getting this warning:

Informational Message: The current Metropolis proposal is about to be rejected because of the following issue:
Exception: multi_normal_lpdf: LDLT_Factor of covariance parameter is not positive definite.  last conditional variance is 0. 

I don’t see how sig_2 * R could fail to be positive definite. This only happens early in the simulation; I guess it could be drawing a 0 for sig_2, but it still seems strange.

Another problem I have is that the model hangs after it finishes sampling. This happens when I specify the full covariance matrix (SIG_2 in the code) as a transformed parameter. My guess is that it is struggling with the amount of output (a 730 x 730 matrix saved for each of 500 iterations), but I don’t think this is THAT big…

Any general help would be much appreciated, including pointing to any examples/papers implementing such a model in Stan, which I haven’t been able to find. I can provide the pystan driver and data if you’d like, as well as figures comparing the posterior between Stan and the Gibbs Sampler in Matlab.

            data {
              int<lower=0> N;        // process locations to infer at
              vector[N] z;           // temperature observations
              real mu;               // process mean
              real<lower=0.0> tau_2; // data covariance
              cov_matrix[N] R;       // correlation matrix (scaled distances between locations)
            }
            transformed data {
              vector[N] MU;
              // cov_matrix[N] TAU_2;

              MU    = rep_vector(mu, N);
              // TAU_2 = diag_matrix(rep_vector(tau_2, N));
            }

            parameters {
              vector[N] y;            // process
              real<lower=0> sig_2;    // process covariance
            }

            transformed parameters {
             //cov_matrix[N] SIG_2;
             //SIG_2 = sig_2 * R;
            }

            model {
                // priors
                // mu     ~ normal(0, 25);
                // tau_2  ~ inv_gamma(5, 25);
                sig_2  ~ inv_gamma(5, 100);

                // process/state level
                y ~ multi_normal(MU, sig_2*R);


                // data/observation level
                //z ~ multi_normal(y, TAU_2); // believe this produces same result as below
                z ~ normal(y, sqrt(tau_2));

            }

Thanks,
CC

So this seems to be a special case of a 2D Gaussian Process (which I think is called a Gaussian Random Field) where isotropy has been assumed (collapsing the 2D distances to 1D) and an empirical covariance matrix is determined simply by the squared distances (so, no latent-covariance parameters are modelled). There are a couple of ways to speed up this specific implementation, and I’ll include those below, but if you’re not already familiar with GPs I suggest reading the SUG section on GPs, as the implementation there is rather more flexible/general.

[posting the above now so others don’t waste time responding while I’m coding/verifying my implementations…]

Sticking with your restricted implementation, but renaming a few things for clarity (note also the use of SDs rather than variances), it should be faster if you pre-compute the Cholesky decomposition of the correlation matrix in the transformed data block and then use a so-called “non-centered” parameterization for the latent function samples. If you’re not doing inference on mu specifically, it’d also be faster to simply subtract mu from y before sampling. You could do that outside Stan, but I’ll do it in transformed data for clarity:

data {
  int<lower=0> N;        // process locations to infer at
  vector[N] y;           // temperature observations
  real mu;               // latent process mean
  real<lower=0.0> noise; // observation noise (SD)
  cov_matrix[N] R ; //correlation matrix
}
transformed data {
  vector[N] y_minus_mu  = y - mu ;
  matrix[N,N] chol_R = cholesky_decompose(R) ;
}
parameters {
  vector[N] f_prime ;  // helper for efficient mvn sampling of the latent process
  real<lower=0> f_amplitude ;  // latent amplitude
}
transformed parameters{
  row_vector[N] f = transpose(	diag_pre_multiply(rep_vector(f_amplitude,N),chol_R)	* f_prime ) ;
}
model {
  f_amplitude ~ weibull(2,1) ; //change this to your prior on the latent amplitude
  f_prime ~ std_normal();
  y_minus_mu ~ normal(f, noise) ;
}

And if you want to play, here’s an R script to generate data from a 1D GP and then fit it with a variant of the above (computing the covariance matrix in transformed data because I was lazy and knew how to do it quickly in Stan):


library(tidyverse)
library(cmdstanr)

sim_code = "
//adapted from SUG 10.2 (https://mc-stan.org/docs/2_25/stan-users-guide/simulating-from-a-gaussian-process.html)
// much more efficient through use of cholesky & non-centered parameterization
data {
  int<lower=1> N;
  real x[N];
  real<lower=0.0> f_amplitude ;
  real<lower=0.0> f_lengthscale ;
}
transformed data {
  vector[N] mu = rep_vector(0, N);
  matrix[N, N] K = cov_exp_quad(x, f_amplitude, f_lengthscale);
  for(n in 1:N){
    K[n,n] = K[n,n] + 1e-15; //to ensure positive-definite
  }
  matrix[N,N] chol_K = cholesky_decompose(K) ;
}
parameters {
  vector[N] f_;
}
transformed parameters{
  vector[N] f = chol_K*f_ ;
}
model {
  f_ ~ std_normal();
}
"

sim_mod = 
	(
		sim_code
		%>% cmdstanr::write_stan_file()
		%>% cmdstanr::cmdstan_model()
	)


x = seq(-10,10,length.out = 100)

sim_fit = sim_mod$sample(
	data = tibble::lst(
		x = x
		, N = length(x)
		, f_amplitude = 1
		, f_lengthscale = 1
	)
	, chains = 1
	, iter_warmup = 1e3 #we DO want to warmup, even when generating data!
	, iter_sampling = 1
)

f = as.numeric(sim_fit$draws(variables='f'))

#show the sample latent function
dat = tibble::tibble(x=x,f=f)
(
	dat
	%>% ggplot(aes(x=x,y=f))
	+ geom_line()
	+ geom_point()
)
#nice and smooth!

#add some noise
dat = 
	(
		dat
		%>% dplyr::mutate(
			y = rnorm(n(),f,.5)
		)
	)

#show the noisy observations
(
	dat
	%>% ggplot()
	+ geom_line(aes(x=x,y=f),colour='red')
	+ geom_point(aes(x=x,y=y),colour='blue')
)
#noisier

# code for inference
inference_code = "
data {
  int<lower=0> N;        // process locations to infer at
  vector[N] y;           // temperature observations
  real x[N] ; //locations
  real mu;               // latent process mean
  //real<lower=0.0> f_amplitude ; //latent process amplitude
  real<lower=0.0> f_lengthscale; // latent process lengthscale
  real<lower=0.0> noise; // observation noise (SD)
}
transformed data {
  vector[N] y_minus_mu  = y - mu ;
  matrix[N, N] K = cov_exp_quad(x, 1.0 , f_lengthscale);
  for(n in 1:N){
    K[n,n] = K[n,n] + 1e-6; //to ensure positive-definite
  }
  matrix[N,N] chol_K = cholesky_decompose(K) ;
}
parameters {
  vector[N] f_prime ;  // helper for efficient mvn sampling of the latent process
  real<lower=0> f_amplitude ;  // latent amplitude
}
transformed parameters{
  row_vector[N] f = transpose(	diag_pre_multiply(rep_vector(f_amplitude,N),chol_K)	* f_prime ) ;
}
model {
  f_amplitude ~ weibull(2,1) ;
  f_prime ~ std_normal();
  y_minus_mu ~ normal(f, noise) ;
}
"

inference_mod = 
	(
		inference_code
		%>% cmdstanr::write_stan_file()
		%>% cmdstanr::cmdstan_model()
	)

#fit inference_mod
fit = inference_mod$sample(
	data = tibble::lst(
		y = dat$y
		, x = dat$x
		, N = length(y)
		, mu = 0 # treating mu as known
		, f_lengthscale = 1 #treating as known
		, noise = 0.5 #treating sigma as known
	)
	, chains = 1
	, iter_warmup = 1e3
	, iter_sampling = 1e3
)

#check diagnostics
fit$cmdstan_diagnose()

#look at the posterior for f_amplitude
(
	fit$draws(variables='f_amplitude')
	%>% bayesplot::mcmc_dens()
)

#look at the posterior on the latent function
(
	fit$draws(variables='f')
	%>% posterior::as_draws_df()
	%>% tibble::as_tibble()
	%>% dplyr::select(-.chain,-.iteration)
	%>% tidyr::pivot_longer(
		cols = c(-.draw)
	)
	%>% dplyr::mutate(
		name = stringr::str_replace(name,stringr::fixed('f['),'')
		, name = stringr::str_replace(name,stringr::fixed(']'),'')
		, xindex = as.numeric(name)
		, x = dat$x[xindex]
	)
	%>% dplyr::select(-name,-xindex)
	%>% dplyr::group_by(x)
	%>% dplyr::summarise(
		mean = mean(value)
		, lo50 = quantile(value,.25)
		, hi50 = quantile(value,.75)
		, lo95 = quantile(value,.025)
		, hi95 = quantile(value,.975)
		, .groups = 'drop'
	)
	%>% ggplot()
	+ geom_ribbon(
		aes(
			x = x
			, ymin = lo95
			, ymax = hi95
		)
		, alpha = .5
	)
	+ geom_ribbon(
		aes(
			x = x
			, ymin = lo50
			, ymax = hi50
		)
		, alpha = .5
	)
	+ geom_line(
		aes(
			x = x
			, y = mean
		)
	)
	+ geom_line(
		data = dat
		, mapping = aes(x=x,y=f)
		, colour = 'red'
	)
	+ geom_point(
		data = dat
		, mapping = aes(x=x,y=y)
		, colour = 'blue'
	)
)

Not sure if Discourse sends a new notification when an edit has occurred, but I just updated my response with the promised code/demos.


This is absolutely incredible! Thank you so much. I will go through it and may pester you again before long… :D


Hi @mike-lawrence, I’ve looked through most of the code you shared and implemented the speed-up for the full hierarchical model (a sketch of what that looks like is below). It seems to work nicely. I’m still very confused about why my original model is so slow, though.
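For concreteness, here is roughly what the sped-up full model looks like (just a sketch: it follows your non-centered approach with the priors from my original post, so details may differ slightly from what I actually ran):

data {
  int<lower=0> N;        // process locations to infer at
  vector[N] z;           // temperature observations
  cov_matrix[N] R;       // correlation matrix
}
transformed data {
  matrix[N, N] chol_R = cholesky_decompose(R);  // factor the correlation matrix once
}
parameters {
  real mu;                 // process mean
  real<lower=0> tau_2;     // observation (data) variance
  real<lower=0> sig_2;     // process variance
  vector[N] y_prime;       // non-centered helper for the latent process
}
transformed parameters {
  // implies y ~ multi_normal(rep_vector(mu, N), sig_2 * R)
  vector[N] y = mu + sqrt(sig_2) * (chol_R * y_prime);
}
model {
  mu      ~ normal(0, 25);
  tau_2   ~ inv_gamma(5, 25);
  sig_2   ~ inv_gamma(5, 100);
  y_prime ~ std_normal();
  z ~ normal(y, sqrt(tau_2));
}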

I think the main reason is that Stan isn’t picking up on the conjugacy of the priors.
Could you explain why Stan doesn’t just update the posterior using the closed form? Is it perhaps confused by the inv_gamma instead of an Inverse-Wishart?

Thanks,
CC

I suspect it’s due to the centered parameterization, which is known to cause sampling issues.

Indeed, as far as I know Stan doesn’t care about conjugacy whatsoever. That is, when expressing a model, each distribution used to express structure in the model is used behind the scenes (or explicitly if you use the target += syntax) to increment an accumulator for the log-probability. When cmdstan runs, it sees data (things it shouldn’t change), parameters (things it will change), and a single function that takes the data and parameters as input and spits out the log-probability as output. It uses auto-diff to build a function that also yields (for a given set of input data and parameters) the gradient (i.e. the slope of the log-probability in the parameter space at the location implied by the input parameters), and uses this, combined with dynamic HMC, to achieve exploration of parameter values that (hopefully) converge on the posterior implied by the data & model.

So far as I understand, conjugacy is a property that’s useful if, instead of MC sampling, you are working out the posterior analytically. For MC samplers, it’s irrelevant. Well, except in the sense that a mathematically-equivalent-but-computationally-simpler expression might be achievable if you put the work in to derive one. I think the section of the GP chapter showing how you can integrate out the latent function would be an example of this.
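To make that concrete in the notation from your original model (MU, R, tau_2, sig_2 as in your data/parameters blocks; this is just an illustrative sketch, not something I’ve benchmarked):

model {
  // these two statements add the same term (up to a constant) to the
  // log-probability accumulator; Stan never "recognizes" the inv_gamma
  // as conjugate, it just evaluates it:
  sig_2 ~ inv_gamma(5, 100);
  // target += inv_gamma_lpdf(sig_2 | 5, 100);

  // because the observation model is Gaussian, the latent y can be
  // integrated out analytically, leaving a single multivariate normal
  // on the data:
  z ~ multi_normal(MU, sig_2 * R + diag_matrix(rep_vector(tau_2, N)));
}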


FWIW here is an example that uses some analytic results based on conjugacy to speed up a Stan model, I think by directly coding the easier-to-sample expression that arises from conjugacy for the part of the model that is conjugate.

Disclaimer: I haven’t taken the time to really understand what goes on at that link, so I’m taking on faith that it does what it says it does.


Thanks for all the information, guys. The GP chapter is helpful, as is that conjugacy model comparison. I will probably work on developing a similar approach, since the full model I’m leveraging (and expanding) was designed with conjugacy in mind.
