I read this older thread, https://discourse.mc-stan.org/t/hilbert-space-approximate-gp-prior-why-use-de-meaned-basis-functions-phi/25852, which is why I'm starting a new one about the Hilbert space GP approximation.
@avehtari mentioned:

> Removing the intercept term removes just one correlating variable, but half of the basis functions are still correlating, and thus removing the intercept term is not a complete solution.

So I thought: why not QR-decompose the basis functions, as described in the Stan User's Guide section on the QR decomposition?
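To spell out the algebra I have in mind (this is just my reading of that chapter, using its scaled factors): with the thin QR decomposition of the basis-function matrix,

$$
\Phi = Q^* R^*, \qquad Q^* = Q\,\sqrt{N-1}, \qquad R^* = \frac{1}{\sqrt{N-1}}\,R,
$$

the columns of $Q^*$ are orthogonal, $Q^{*\top} Q^* = (N-1)\,I_M$, so the approximate GP prior draw $f = \Phi\,\mathrm{diag}(s)\,\beta$ with $\beta \sim \mathrm{N}(0, I_M)$ can instead be built on an orthogonal basis by using $Q^*$ in place of $\Phi$. That is what I try here, and I came up with the following Stan program: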
```stan
functions {
  // Square roots of the spectral densities of the exponentiated quadratic
  // kernel, evaluated at the square roots of the Laplacian eigenvalues.
  vector diagSPD_EQ(real alpha, real rho, real L, int M) {
    vector[M] j = linspaced_vector(M, 1, M);
    vector[M] lambda = square(j * pi() / (2 * L));
    vector[M] s = (alpha^2) * sqrt(2 * pi()) * rho * exp(-0.5 * (rho^2) * lambda);
    return sqrt(s);
  }
  // Basis functions [N x M]: Laplacian eigenfunctions on [-L, L],
  // optionally replaced by the scaled thin-QR factor Q*.
  matrix PHI(int N, int M, real L, vector x, int QR_decomp) {
    vector[M] j = linspaced_vector(M, 1, M);
    vector[M] sqrt_lambda = j * pi() / (2 * L);
    matrix[N, M] x_L = rep_matrix(x + L, M);
    matrix[N, M] phi = sin(diag_post_multiply(x_L, sqrt_lambda)) / sqrt(L);
    if (QR_decomp)
      return qr_thin_Q(phi) * sqrt(N - 1); // Q* as in the User's Guide
    return phi;
  }
}
data {
  int<lower=1> N;                // number of observations
  vector[N] x;                   // univariate covariate
  real<lower=0> c_f;             // boundary factor
  int<lower=1> M_f;              // number of basis functions
  real<lower=0> lengthscale_f;   // lengthscale of f
  real<lower=0> sigma_f;         // scale of f
  int<lower=0, upper=1> QR_phi;  // QR-decompose basis functions?
}
transformed data {
  // Normalize data
  real xmean = mean(x);
  real xsd = sd(x);
  vector[N] xn = (x - xmean) / xsd;
  // Boundary value
  real L_f = c_f * max(xn);
  // Basis functions for f
  matrix[N, M_f] PHI_f = PHI(N, M_f, L_f, xn, QR_phi);
  // Spectral densities
  vector[M_f] diagSPD_f = diagSPD_EQ(sigma_f, lengthscale_f, L_f, M_f);
}
generated quantities {
  // One draw from the approximate GP prior: f = PHI * (s .* beta), beta ~ N(0, 1)
  vector[N] f = PHI_f * (diagSPD_f .* to_vector(normal_rng(0, rep_vector(1, M_f))));
}
```
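As a quick sanity check on the `qr_thin_Q(phi) * sqrt(N - 1)` scaling, here is a small R sketch (my own reconstruction of `PHI()` above, using the same N, M, and c_f values as in the data below) verifying that the columns of the scaled Q really are orthogonal:

```r
# Rebuild PHI in R (mirrors the Stan function above) and check that
# Q* = Q * sqrt(N - 1) satisfies t(Q*) %*% Q* = (N - 1) * I.
N <- 100; M <- 40; c_f <- 1.5
xn <- as.numeric(scale(seq(0, 1, length.out = N)))
L <- c_f * max(xn)
phi <- sapply(1:M, function(j) sin(j * pi / (2 * L) * (xn + L)) / sqrt(L))
Q_star <- qr.Q(qr(phi)) * sqrt(N - 1)  # column signs may differ from Stan's qr_thin_Q
max(abs(crossprod(Q_star) - (N - 1) * diag(M)))  # ~ 0, i.e. orthogonal columns
```

And here is the R code I use to draw from the prior with and without the QR step: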
```r
library(cmdstanr)
library(ggplot2)
library(dplyr)

fileq <- "gb_fit.stan"
# Compile model
modq <- cmdstan_model(stan_file = fileq)
# Input data
dat <- list(x = seq(0, 1, length.out = 100),
            N = 100,
            c_f = 1.5,
            M_f = 40,
            sigma_f = 1,
            lengthscale_f = 1,
            QR_phi = 1)  # bool: QR-decompose the basis functions?
# Prior draws with the QR-decomposed basis functions
s_QR <- modq$sample(data = dat, fixed_param = TRUE, chains = 1,
                    adapt_engaged = FALSE, iter_warmup = 0, iter_sampling = 1000)
# Prior draws with the plain basis functions
dat$QR_phi <- 0
s_PHI_f <- modq$sample(data = dat, fixed_param = TRUE, chains = 1,
                       adapt_engaged = FALSE, iter_warmup = 0, iter_sampling = 1000)
# Plotting helper: pointwise 5% and 95% quantiles of the prior draws of f
xn <- as.numeric(scale(dat$x))
f_q <- function(s, x = xn, id = "x") {
  f <- s$draws("f")  # [iterations, chains, variables]
  fmat <- matrix(0, nrow = dim(f)[1], ncol = dim(f)[3])
  for (i in 1:dim(f)[1]) fmat[i, ] <- as.numeric(f[i, 1, ])
  u <- apply(fmat, 2, quantile, probs = 0.95)
  l <- apply(fmat, 2, quantile, probs = 0.05)
  data.frame(x = x, u = u, l = l, phi = id)
}
# Plotting
dq <- rbind(f_q(s_QR, id = "QR PHI bf"), f_q(s_PHI_f, id = "PHI bf"))
dq %>%
  ggplot(aes(x = x, ymin = l, ymax = u, lty = phi, col = phi, fill = phi)) +
  geom_ribbon(lwd = 0.8, alpha = 0.3) +
  theme_classic() +
  labs(title = "90% Coverage Interval", y = "f") +
  theme(legend.text = element_text(size = 12),
        legend.position = "top",
        legend.title = element_blank())
```
The variances are on a different scale, but the rest looks good:
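To put a number on that scale difference instead of eyeballing the ribbons, one could compare the pointwise prior standard deviations of `f` from the two fits (a quick sketch reusing `s_QR` and `s_PHI_f` from above):

```r
# Pointwise prior SDs of f under each basis; with chains = 1 the draws
# array flattens to iterations x variables, column by column.
sd_QR  <- apply(matrix(as.numeric(s_QR$draws("f")),   nrow = 1000), 2, sd)
sd_PHI <- apply(matrix(as.numeric(s_PHI_f$draws("f")), nrow = 1000), 2, sd)
summary(sd_QR / sd_PHI)  # a ratio away from 1 confirms the rescaling
```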
This approach should make the basis functions linearly independent and thus improve sampling. Am I right about that, or am I missing something?