How to model a lag function?

I am attempting to model the effects of drought on an ecological response variable. This is somewhat complicated, because the effect is believed to be an aggregate (smooth) function of predictors over an unknown time window, with an unknown lag in its effect on the response. I would prefer to code this in brms if possible; if that is not possible, I'd like to model it with rstan.

The generative model is:

y \sim \mathcal{N}(\mu, \sigma)
\mu = \eta(\text{lag}_{l}(f(x, w)))

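Where y is the response variable, x is a set of predictors (for simplicity, just one), \eta is the function linking the lagged aggregate to the mean (linear in the simple implementation described next), \text{lag}_{l} delays the effect of f on y by time l, and f is an arbitrary backward-looking smooth function of predictor x with parameters w.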

A simple implementation follows, for which the smooth function f is a rolling mean with width w = 5 and lag time is l = 2. Data are simulated for \mu = \beta_0 + \beta_1 \times \text{lag}_2(\text{rolling mean}(x, 5)), with fixed parameters \beta_0 = 0 and \beta_1 = 0.2.
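
Written out for this simple case, with observations indexed by a time step t (a subscript I am adding here for clarity; it is not in the notation above), the mean is:

\mu_t = \beta_0 + \beta_1 \cdot \frac{1}{w} \sum_{k=0}^{w-1} x_{t-l-k}

With w = 5 and l = 2 this reduces to \mu_t = \beta_0 + \beta_1 \cdot \frac{1}{5}(x_{t-2} + x_{t-3} + x_{t-4} + x_{t-5} + x_{t-6}), so y at time t depends only on x from 2 to 6 steps earlier.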

The goal is to infer lag time l, aggregate window w for a rolling mean, and parameters \beta_0, \beta_1 from the data, which include only x, y, and the time index, \text{year}. I have not included brms or rstan code, because besides the linear function, I’m not sure how to get started.
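
For reference, the only baseline I can think of is a brute-force grid search, which is not what I ultimately want. Because w and l are integers, and as far as I understand Stan cannot sample integer-valued parameters directly, the sketch below fits an ordinary linear regression for every candidate (w, l) pair and compares the maximised log-likelihoods. It is a minimal base R sketch under stated assumptions: the helper `lagged_rolling_mean()`, the bounds `max_w` and `max_l`, and the toy data are all made up for illustration, and the weights are normalised profile likelihoods rather than a proper posterior.

```r
# Hypothetical helper (not from any package): rolling mean of width w, delayed by l steps.
lagged_rolling_mean <- function(x, w, l) {
  # sides = 1 makes the filter backward-looking: mean of x[t - w + 1], ..., x[t].
  rm <- as.numeric(stats::filter(x, rep(1 / w, w), sides = 1))
  # Shift by l steps, padding the start with NA, and trim to the original length.
  c(rep(NA_real_, l), rm)[seq_along(x)]
}

# Toy data with the same structure as my simulation (w = 5, l = 2, b0 = 0, b1 = 0.2).
set.seed(1)
n <- 300
x <- rnorm(n, mean = 10, sd = 2)
y <- 0 + 0.2 * lagged_rolling_mean(x, w = 5, l = 2) + rnorm(n, sd = 1)

# Candidate windows and lags; the upper bounds are arbitrary assumptions.
max_w <- 8
max_l <- 5
grid <- expand.grid(w = 1:max_w, l = 0:max_l)

# Use the same rows for every candidate so the log-likelihoods are comparable.
keep <- (max_w + max_l):n

# Maximised Gaussian log-likelihood of y ~ b0 + b1 * z for each (w, l).
grid$loglik <- mapply(function(w, l) {
  z <- lagged_rolling_mean(x, w, l)[keep]
  as.numeric(logLik(lm(y[keep] ~ z)))
}, grid$w, grid$l)

# Normalised profile-likelihood weights (not a proper posterior).
grid$weight <- exp(grid$loglik - max(grid$loglik))
grid$weight <- grid$weight / sum(grid$weight)

best <- grid[which.max(grid$loglik), ]
best
head(grid[order(-grid$weight), ])
```

With a signal this weak relative to the noise, I would not expect the best-scoring pair to match the true values reliably, which is part of why I would like a model that propagates the uncertainty in w and l into the estimates of \beta_0 and \beta_1.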

I found a related post here, in which @paul.buerkner wrote that lag functions were not available in brms at that time (4 years ago). I'm hopeful that they have been implemented since then or, failing that, that I can find a solution with rstan.

# Simulate drought effect

# Generative model: ----
# Data represent a time series
# y is the response
# x is a continuous numeric predictor
# y ~ Normal(mu, sd)
# mu = lag(f(x))
# lag is a function that delays the effect of f on mu for some time.
# f is a backward-looking smoothing function.

# Implementation ----
library(dplyr, warn.conflicts = FALSE)
library(purrr)
## Model parameters ----
# Goal is to estimate b0, b1, w, and l from the data.
b0 <- 0     # intercept for y ~ b0 + b1 * f(x)
b1 <- 0.2   # linear effect of precip on y

w <- 5    # Width of the aggregation window for x affecting y.
l <- 2    # Lag time between aggregated x and its effect on y.

# y ~ N(mu = b0 + b1 * x, sd = 1)
y_fun <- function (x) {
  if (is.na(x)) return(NA_real_)  # typed NA so map_dbl() always receives a double
  y <- rnorm(n = 1, mean = b0 + b1 * x, sd = 1)
  return(y)
}

# Smooth function for x. Could be anything, but a rolling mean is simple.
# Return a rolling mean with a fixed window for a numeric vector.
rolling_mean <- function (x, width) {
  if (length(x) < width) stop("Rolling window must not be larger than input vector")
  result <- numeric(length = length(x))
  # The first (width - 1) positions have no complete window; seq_len() also handles width = 1.
  result[seq_len(width - 1)] <- NA
  for (i in width:length(x)) {
    result[i] <- mean(x[(i - width + 1):i])
  }
  return(result)
}

# Simulate data
{
  set.seed(42)
  # Pad the series so that 100 rows remain after the first (w - 1) + l NA rows are dropped.
  sim_years <- 99 + w + l
  precip <- rnorm(sim_years, mean = 10, sd = 2)
  precip_5yr <- rolling_mean(precip, w)
  precip_5yr_2lag <- dplyr::lag(precip_5yr, n = l)  # dplyr::lag(), not stats::lag()
  y <- map_dbl(precip_5yr_2lag, .f = y_fun)
  df <- tibble(precip, y) |>
    na.omit() |>
    mutate(year = row_number())
}
df
#> # A tibble: 100 × 3
#>    precip     y  year
#>     <dbl> <dbl> <int>
#>  1  13.0  1.75      1
#>  2   9.81 1.94      2
#>  3  14.0  2.41      3
#>  4   9.87 2.31      4
#>  5  12.6  2.27      5
#>  6  14.6  2.37      6
#>  7   7.22 1.89      7
#>  8   9.44 1.93      8
#>  9   9.73 0.672     9
#> 10  11.3  1.77     10
#> # ℹ 90 more rows

Created on 2024-01-09 with reprex v2.0.2