Understanding the Likelihood Model for the Cox Family in brms

I am trying to understand the basic approach of the cox() family in brms. My understanding is that the baseline hazard is modelled as a spline using the splines2 package, and the covariates then act as proportional hazards on top of this baseline hazard.

Effectively, rather than leaving the baseline hazard unspecified as in the classic semi-parametric Cox PH model, you have an explicitly modelled baseline curve given by a spline.
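To make that reading concrete, here is how I would write the model down. This is a sketch of my understanding, not something taken from the brms documentation: the symbols \gamma_k (spline coefficients, which I assume correspond to the sbhaz parameters), M_k and I_k (M-spline basis functions and their integrals) and \eta (the linear predictor) are my own notation.

```latex
\begin{aligned}
h(t \mid x)   &= h_0(t)\,\exp(\eta), \qquad h_0(t) = \sum_{k=1}^{K} \gamma_k M_k(t), \\
H_0(t)        &= \int_0^t h_0(u)\,du = \sum_{k=1}^{K} \gamma_k I_k(t), \\
S(t \mid x)   &= \exp\bigl(-H_0(t)\,\exp(\eta)\bigr).
\end{aligned}
```

If this is right, the survival curve depends on the cumulative baseline hazard H_0 scaled by exp(\eta), so anything missing from \eta would shift the survival estimates.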

So far, so good.

The problem arises when I try to reconstruct that spline from the parameters the model outputs. I appreciate that this output is not currently available in brms, and I am happy to hack out a solution myself, but I do not quite understand the logic of the cox() family code in brms.

I tried extracting the spline, but the baseline rate seems to start at 1, so my estimated survival rates are much too low, even though the brms fit and the classical MLE fit mostly agree on the values of the model parameters.

Any pointers are welcome, and I am happy to share code and anything else that might help!

Let's assume your brms model fit is called bfit. First, let's get the baseline hazard (M-spline) basis matrix. This matrix is based on the time-to-event values being modelled.

library(brms)

s_data <- standata(bfit)
## you might need to modify the arguments passed to the basis function if you changed the defaults
b0 <- brms:::bhaz_basis_matrix(s_data$Y, list(df = 5, intercept = TRUE))

Now let's pull out the posterior samples of the spline coefficients.

## You'll need to do some more indexing if you stratified your model; I'll leave that as an exercise
sbhaz_post <- rstan::extract(bfit$fit, "sbhaz")$sbhaz

Let's construct a new basis matrix over a fine grid of time points that we wish to plot. To get the cumulative hazard, these need to be I-splines (the integrals of the M-splines), so we make sure to set integrate=TRUE.

cb <- brms:::bhaz_basis_matrix(seq(0, max(s_data$Y), by = 0.1), basis = b0, integrate = TRUE)
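If the M-spline versus I-spline relationship is unfamiliar, the following sketch may help. It calls splines2 directly rather than going through brms, and the grid, df and boundary knots are arbitrary values I picked for illustration, not the ones brms would choose for your data.

```r
library(splines2)

# Arbitrary illustrative grid and boundaries (not taken from any fitted model)
x <- seq(0, 10, by = 0.01)
m_basis <- mSpline(x, df = 5, intercept = TRUE, Boundary.knots = c(0, 10))
i_basis <- iSpline(x, df = 5, intercept = TRUE, Boundary.knots = c(0, 10))

# Each I-spline column should run from 0 at the lower boundary to 1 at the upper one
range_check <- rbind(start = i_basis[1, ], end = i_basis[nrow(i_basis), ])

# Central differences of the I-spline columns should roughly recover the M-splines
n <- nrow(i_basis)
num_deriv <- (i_basis[3:n, ] - i_basis[1:(n - 2), ]) / (x[3:n] - x[1:(n - 2)])
max_err <- max(abs(num_deriv - m_basis[2:(n - 1), ]))
```

Here max_err should be small, which is just a numerical way of saying that a weighted sum of I-spline columns is the integral of the same weighted sum of M-spline columns.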

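Now samples of the cumulative baseline hazard (since we used the integrated basis) are given by cb %*% t(sbhaz_post), with one row per time point and one column per posterior draw.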

You can plot it with something like this:

library(tibble)
library(ggplot2)

## compute the draws once: rows are time points, columns are posterior draws
cum_haz_draws <- cb %*% t(sbhaz_post)

tibble(
    time = seq(0, max(s_data$Y), by = 0.1),
    cum_haz = rowMeans(cum_haz_draws),
    lower = apply(cum_haz_draws, 1, quantile, 0.05 / 2),
    upper = apply(cum_haz_draws, 1, quantile, 1 - 0.05 / 2)
) |>
ggplot(aes(x = time, y = cum_haz)) +
    geom_line() +
    geom_ribbon(aes(ymin = lower, ymax = upper), alpha = 0.15)
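Since the original question was about survival rates, here is a rough sketch of how I would turn cumulative hazard draws into survival curves using S(t) = exp(-H0(t) * exp(lp)). The helper cum_haz_to_surv is a hypothetical name of my own, and the toy matrix stands in for cb %*% t(sbhaz_post) so the snippet runs without a fitted model. I am assuming lp is the full linear predictor for the individual of interest, including the model's intercept if it has one; leaving the intercept out would rescale the hazard and could produce survival estimates that look off.

```r
# Toy stand-in for cb %*% t(sbhaz_post): rows are time points, columns are draws.
times <- seq(0, 5, by = 0.5)
rates <- c(0.10, 0.12, 0.08, 0.11)      # made-up constant hazard per "draw"
cum_haz_draws <- outer(times, rates)    # cumulative hazard = rate * time

# Hypothetical helper (not part of brms): S(t) = exp(-H0(t) * exp(lp)).
# lp = 0 gives the baseline survival curve.
cum_haz_to_surv <- function(cum_haz, lp = 0) {
  exp(-cum_haz * exp(lp))
}

# Survival draws for an individual whose linear predictor is log(2),
# i.e. whose hazard is twice the baseline hazard
surv_draws <- cum_haz_to_surv(cum_haz_draws, lp = log(2))

# Posterior mean and 95% interval at each time point
surv_summary <- data.frame(
  time  = times,
  surv  = rowMeans(surv_draws),
  lower = apply(surv_draws, 1, quantile, probs = 0.025),
  upper = apply(surv_draws, 1, quantile, probs = 0.975)
)
```

With a real fit you would pass the actual cumulative hazard draws, and lp could be a vector with one value per posterior draw if you rearrange things so that it lines up with the columns.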