I’m writing a Stan model and would like to use penalized smooths similar to how brms implements `s(x)`.
Thanks to TJ Mahr’s excellent blog post on the topic, I feel like I have a handle on how brms handles these smooths: there is one fixed effect (which appears to be the linear component of the smooth) plus a set of random effects. The random-effect coefficients multiply the penalized basis functions, and together with the fixed effect they form the smooth.
I’m looking for clarification on which matrices to extract from mgcv for use in the model. Below, I show a small example that I think is correct and compare it to a brms fit. I use the `mcycle` data, as in TJ’s post.
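The Stan model is:

```stan
data {
  int n;            // number of observations
  int k;            // number of unpenalized (fixed-effect) smooth columns
  int k2;           // number of penalized (random-effect) smooth columns
  vector[n] y;      // response
  matrix[n, k] X;   // unpenalized part of the smooth basis (re$Xf)
  matrix[n, k2] Z;  // penalized part of the smooth basis (re$rand$Xr)
}
parameters {
  real Intercept;
  vector[k] beta;         // fixed-effect smooth coefficients
  vector[k2] gamma;       // standardized random-effect coefficients
  real<lower=0> sigma;    // residual SD
  real<lower=0> sigma2;   // SD of the random effects (controls wiggliness)
}
transformed parameters {
  // non-centered parameterization: random-effect coefficients are sigma2 * gamma
  vector[n] mu = X * beta + Z * (sigma2 * gamma) + Intercept;
}
model {
  Intercept ~ student_t(3, -13.3, 35.6);
  beta ~ normal(0, 1);
  gamma ~ normal(0, 1);
  sigma ~ student_t(3.5, 0, 35.6);
  sigma2 ~ student_t(3.5, 0, 35.6);
  y ~ normal(mu, sigma);
}
```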
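For reference, here is my reading of the Stan program above written out as a model, just to make the roles of `X`, `Z`, `sigma2` and `gamma` explicit:

$$
\begin{aligned}
y_i &\sim \mathcal{N}(\mu_i,\ \sigma), \\
\mu &= \text{Intercept} + X\beta + Z b, \qquad b = \sigma_2\,\gamma, \\
\gamma &\sim \mathcal{N}(0, I) \quad\Longrightarrow\quad b \sim \mathcal{N}(0,\ \sigma_2^2 I).
\end{aligned}
$$

Writing $b = \sigma_2\gamma$ is the non-centered parameterization of the random effects. An independent, equal-variance prior on $b$ only makes sense because `diagonal.penalty = TRUE` reparameterizes the basis so that the penalty on the wiggly part is the identity matrix. In that setting $\sigma_2$ acts as an inverse smoothing parameter: for a Gaussian model, adding a penalty $\lambda\lVert b\rVert^2$ to the residual sum of squares corresponds to $b \sim \mathcal{N}(0, (\sigma^2/\lambda) I)$, i.e. $\sigma_2 = \sigma/\sqrt{\lambda}$.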
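The matrices `X` and `Z` are extracted as follows:

```r
library(tidyverse)
library(mgcv)

mcycle <- MASS::mcycle %>%
  tibble::rowid_to_column(var = 'i')

# Build the smooth basis; k = -1 uses mgcv's default basis dimension.
# absorb.cons = TRUE absorbs the sum-to-zero constraint, and
# diagonal.penalty = TRUE reparameterizes so the penalty matrix is diagonal.
sm <- smoothCon(
  s(times, k = -1),
  data = mcycle,
  absorb.cons = TRUE,
  diagonal.penalty = TRUE
)

# Split the smooth into unpenalized (fixed) and penalized (random) parts
re <- smooth2random(sm[[1]], "", type = 2)
X <- re$Xf       # fixed-effect columns
Z <- re$rand$Xr  # random-effect columns
```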
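A quick sanity check on the extracted matrices can help confirm the fixed/random split described earlier. This is a sketch that rebuilds the same objects using only mgcv and MASS; the expectations in the comments are my assumptions about mgcv’s behaviour rather than anything verified against its internals. With mgcv’s default basis size of 10 and one constraint absorbed, `X` should have one column and `Z` eight, the penalty should be diagonal, and the single fixed column should be a linear function of `times`.

```r
library(mgcv)

# Rebuild the smooth with the same settings as the extraction code
mcycle <- MASS::mcycle
sm <- smoothCon(
  s(times, k = -1),
  data = mcycle,
  absorb.cons = TRUE,
  diagonal.penalty = TRUE
)
re <- smooth2random(sm[[1]], "", type = 2)
X <- re$Xf
Z <- re$rand$Xr

# Assumption: the default 1-D tp basis has 10 functions; the sum-to-zero
# constraint removes one, leaving 9 columns split between X and Z
print(c(total = ncol(sm[[1]]$X), fixed = ncol(X), random = ncol(Z)))

# With diagonal.penalty = TRUE the penalty should be diagonal:
# ones for penalized directions, zero for the unpenalized one
S <- sm[[1]]$S[[1]]
print(round(diag(S), 6))

# The unpenalized column should be (centered) linear in times,
# i.e. the "linear component" of the smooth; expect a correlation of +/- 1
print(cor(X[, 1], mcycle$times))
```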
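When I fit this model and a similar model using brms, I get very similar estimates. Shown below is a plot of the predictions, where the line is the posterior mean of `mu` from my model and the dots are the results from `predict(fit_brms)` (strictly, `fitted(fit_brms)` is the closer match to `mu`, since `predict()` summarizes posterior predictive draws that include residual noise, but the means should be close either way). Things look pretty good, so I’m hopeful my approach is correct, but wanted to check.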
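Rather than comparing only predictions, one could compare the matrices directly with the data brms itself sends to Stan. A sketch, assuming brms is installed: `make_standata()` builds the Stan data list without compiling or fitting anything (newer brms versions also call it `standata()`). The element names `Xs` and `Zs_1_1` are what I believe brms uses for the unpenalized columns and the first smooth’s penalized basis; if they differ in your version, `names(sdata)` will show the actual ones.

```r
library(mgcv)
library(brms)

mcycle <- MASS::mcycle

# Matrices built by hand, as in the question
sm <- smoothCon(s(times, k = -1), data = mcycle,
                absorb.cons = TRUE, diagonal.penalty = TRUE)
re <- smooth2random(sm[[1]], "", type = 2)
X <- re$Xf
Z <- re$rand$Xr

# Data brms would pass to Stan for the same formula (no sampling happens here)
sdata <- make_standata(accel ~ s(times, k = 10), data = mcycle)

# Assumed names: Xs (unpenalized smooth columns), Zs_1_1 (penalized basis)
print(dim(sdata$Xs))
print(dim(sdata$Zs_1_1))

# If brms constructs the smooth the same way, these should agree up to
# numerical tolerance; a mismatch (e.g. a sign flip) is worth inspecting
print(all.equal(X, sdata$Xs, check.attributes = FALSE))
print(all.equal(Z, sdata$Zs_1_1, check.attributes = FALSE))

# The generated Stan program shows how brms combines Xs and Zs_1_1 into mu
print(make_stancode(accel ~ s(times, k = 10), data = mcycle))
```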
Full code for reproducibility:

```r
library(tidyverse)
library(mgcv)
library(brms)
library(tidybayes)
library(cmdstanr)  # provides write_stan_file() and cmdstan_model()

mcycle <- MASS::mcycle %>%
  tibble::rowid_to_column(var = 'i')

# Fit with brms
brms_formula <- accel ~ s(times, k = 10)
fit_brms <- brm(
  brms_formula,
  prior = c(
    prior(normal(0, 1), class = 'b')
  ),
  data = mcycle,
  backend = 'cmdstanr',
  control = list(adapt_delta = 0.99)
)

# brms prediction summaries (one row per observation), joined to the data
pred <- predict(fit_brms) %>%
  as_tibble() %>%
  bind_cols(mcycle)

# -------------------------------------------------------------------------
# Hand-written Stan model using mgcv's fixed/random split of the smooth
stan_code <- '
data {
  int n;
  int k;
  int k2;
  vector[n] y;
  matrix[n, k] X;
  matrix[n, k2] Z;
}
parameters {
  real Intercept;
  vector[k] beta;
  vector[k2] gamma;
  real<lower=0> sigma;
  real<lower=0> sigma2;
}
transformed parameters {
  vector[n] mu = X * beta + Z * (sigma2 * gamma) + Intercept;
}
model {
  Intercept ~ student_t(3, -13.3, 35.6);
  beta ~ normal(0, 1);
  gamma ~ normal(0, 1);
  sigma ~ student_t(3.5, 0, 35.6);
  sigma2 ~ student_t(3.5, 0, 35.6);
  y ~ normal(mu, sigma);
}
'

# Smooth basis, split into unpenalized (X) and penalized (Z) parts
sm <- smoothCon(
  s(times, k = -1),
  data = mcycle,
  absorb.cons = TRUE,
  diagonal.penalty = TRUE
)
re <- smooth2random(sm[[1]], "", type = 2)
X <- re$Xf
Z <- re$rand$Xr

stan_data <- list(
  n = nrow(mcycle),
  y = mcycle$accel,
  X = X,
  Z = Z,
  k = ncol(X),
  k2 = ncol(Z)
)

model <- stan_code %>%
  write_stan_file() %>%
  cmdstan_model()
fit_stan <- model$sample(data = stan_data, adapt_delta = 0.99)

# Posterior mean of mu from the Stan fit (line) vs. brms predictions (points)
fit_stan %>%
  spread_draws(mu[i]) %>%
  mean_qi(mu) %>%
  inner_join(mcycle, by = 'i') %>%
  ggplot(aes(times, mu)) +
  geom_line() +
  # geom_point(data=mcycle, aes(times, accel), color='red', inherit.aes = F) +
  geom_point(data = pred, aes(times, Estimate), color = 'red')
```