I am fitting a Bayesian survival model. There is a continuous covariate that is of primary interest, and I have very short time series for each patient (2 or 3 points). In the literature, the relative change of this covariate from baseline to a specified “evaluation” window is used. However, this window is quite broad, so the times at which these measurements are taken can differ substantially between patients. Furthermore, I have a significant number of patients whose measurements fall just outside this window.
So I thought it would be good to fit a simple linear model with random effects for slope and intercept for each patient (jointly with my survival model) and interpolate this covariate at the same time point for all patients. I did this, and it worked out great.
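Schematically, what I mean is the following (my notation, not the variable names in the code below; $t_{ij}$ are the observed measurement times, $t^\ast$ is the common interpolation time, and I put baseline at $t = 0$ for simplicity):

$$\log y_{ij} = (\beta_0 + r_{0i}) + (\beta_1 + r_{1i})\,t_{ij} + \varepsilon_{ij}, \qquad \varepsilon_{ij} \sim \text{Normal}(0, \sigma^2),$$

$$\text{relchange}_i = \frac{\hat{y}_i(t^\ast) - \hat{y}_i(0)}{\hat{y}_i(0)},$$

with $\hat{y}_i$ taken from each patient's fitted line.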
However, in practice/literature this covariate is actually dichotomized: there is a defined threshold, and patients are categorized as above or below it. I know Stan cannot sample discrete parameters due to the nature of Hamiltonian Monte Carlo; however, in my case I am sampling a continuous variable and then deriving a dichotomous variable from it.
Is this okay to do? When I coded it up, I could not get the sampler to behave (even though it did just fine in the latent interpolated covariate case). I thought maybe it was because of the discreteness. I'll attach my Stan code (there is a lot going on) in case someone can spot something I missed:
functions {
  /* compute correlated group-level effects
   * Args:
   *   z: matrix of unscaled group-level effects
   *   SD: vector of standard deviation parameters
   *   L: cholesky factor correlation matrix
   * Returns:
   *   matrix of scaled group-level effects
   */
  matrix scale_r_cor(matrix z, vector SD, matrix L) {
    // r is stored in another dimension order than z
    return transpose(diag_pre_multiply(SD, L) * z);
  }
}
data {
  int<lower=1> N;  // total number of observations
  int<lower=1> M;  // total number of patients
  int<lower=1> NP; // number of prediction time points
  array[N] int<lower=1, upper=M> id;
  array[NP] int<lower=1, upper=M> pred_id;
  vector[N] Y;     // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int<lower=1> Kc; // number of population-level effects after centering
  // data for group-level effects of ID 1
  int<lower=1> R;  // number of coefficients per level
  // group-level predictor values
  vector[N] Z1;
  vector[N] Z2;
  int<lower=1> NC; // number of group-level correlations
  matrix[M, K] X_baseline;   // population-level design matrix for size data at baseline
  matrix[M, K] X_evaluation; // population-level design matrix for size data at the evaluation time
  matrix[NP, K] X_pred;      // population-level design matrix for size data at the prediction points
  vector[M] Z1_baseline;
  vector[M] Z1_evaluation;
  vector[NP] Z1_pred;
  vector[M] Z2_baseline;
  vector[M] Z2_evaluation;
  vector[NP] Z2_pred;
  int<lower=0> J;     // number of time intervals
  vector[J] hPriorSh; // shape parameters for the gamma prior on the baseline hazard
  real c0;            // rate parameter for the gamma prior on the baseline hazard
  int<lower=0> P;     // dimensionality of the survival covariates
  matrix[M, P] Xd;    // matrix of covariates for survival
  matrix[M, J] R_tilde_minus_D_tilde; // indicator matrix: risk set minus event set per interval
  matrix[M, J] D_tilde;               // indicator matrix: intervals in which each observation has an event
}
transformed data {
  matrix[N, Kc] Xc;            // centered version of X without an intercept
  matrix[M, Kc] Xc_baseline;   // centered version of X_baseline without an intercept
  matrix[M, Kc] Xc_evaluation; // centered version of X_evaluation without an intercept
  matrix[NP, Kc] Xc_pred;      // centered version of X_pred without an intercept
  vector[Kc] means_X;          // column means of X before centering
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
    Xc_baseline[, i - 1] = X_baseline[, i] - means_X[i - 1];
    Xc_evaluation[, i - 1] = X_evaluation[, i] - means_X[i - 1];
    Xc_pred[, i - 1] = X_pred[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;              // regression coefficients (longitudinal submodel)
  real Intercept;            // temporary intercept for centered predictors
  real<lower=0> sigma;       // dispersion parameter
  vector<lower=0>[R] sd;     // group-level standard deviations
  matrix[R, M] z;            // standardized group-level effects
  cholesky_factor_corr[R] L; // cholesky factor of correlation matrix
  vector[P + 2] beta;        // regression coefficients (survival submodel)
  vector<lower=0>[J] h_seq;  // baseline hazard per interval (gamma prior)
}
transformed parameters {
  matrix[M, R] r; // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[M] r1;
  vector[M] r2;
  real lprior = 0; // prior contributions to the log posterior
  // compute actual group-level effects
  r = scale_r_cor(z, sd, L);
  r1 = r[, 1];
  r2 = r[, 2];
  lprior += student_t_lpdf(Intercept | 3, 2.3, 2.5);
  lprior += student_t_lpdf(sigma | 3, 0, 2.5) - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  lprior += student_t_lpdf(b | 3, 0, 2.5);
  //lprior += student_t_lpdf(sd | 3, 0, 2.5) - 2 * student_t_lccdf(0 | 3, 0, 2.5);
  lprior += std_normal_lpdf(sd) - 2 * std_normal_lccdf(0);
  lprior += lkj_corr_cholesky_lpdf(L | 1);
  lprior += student_t_lpdf(beta | 3, 0, 2.5);
  lprior += gamma_lpdf(h_seq | hPriorSh, c0);
}
model {
  vector[N] mu = rep_vector(0.0, N);
  matrix[M, P + 2] X_imp;
  matrix[M, J] exp_xbeta_mat;   // each column is exp(X_imp * beta)
  vector[J] first_sum;          // summation terms for the risk set minus event set
  matrix[M, J] h_mat;           // hazard sequence replicated across M rows
  matrix[M, J] h_exp_xbeta_mat; // elementwise product of hazard sequence and exp(X_imp * beta)
  vector[J] second_sum;         // summation terms for the event set
  vector[M] log_size_baseline = Intercept + Xc_baseline * b;
  vector[M] log_size_evaluation = Intercept + Xc_evaluation * b;
  vector[M] rel_change_10wk;
  for (n in 1:M) {
    log_size_baseline[n] += r1[n] * Z1_baseline[n] + r2[n] * Z2_baseline[n];
    log_size_evaluation[n] += r1[n] * Z1_evaluation[n] + r2[n] * Z2_evaluation[n];
  }
  rel_change_10wk = (exp(log_size_evaluation) - exp(log_size_baseline)) ./ exp(log_size_baseline);
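  // equivalently: rel_change_10wk = expm1(log_size_evaluation - log_size_baseline);
  // same value, but avoids overflowing exp() (assuming expm1 is vectorized in your Stan version)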
  for (k in 1:P) {
    X_imp[, k] = Xd[, k];
  }
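  // dichotomize the latent relative change at the fixed thresholds from the literature;
  // note this makes X_imp, and hence the survival likelihood, a step function of the
  // latent patient-level parameters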
  for (i in 1:M) {
    X_imp[i, P + 1] = 0;
    X_imp[i, P + 2] = 0;
    if (rel_change_10wk[i] <= -0.3) {
      X_imp[i, P + 1] = 1;
    } else if (rel_change_10wk[i] >= 0.2) {
      X_imp[i, P + 2] = 1;
    }
  }
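  // grouped survival likelihood; interval j contributes
  //   -h_seq[j] * sum_{i in risk set minus event set} exp(x_i' beta)   (survived interval j)
  //   + sum_{i in event set} log(1 - exp(-h_seq[j] * exp(x_i' beta)))  (event in interval j)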
  exp_xbeta_mat = rep_matrix(exp(X_imp * beta), J); // each column is exp(X_imp * beta)
  h_mat = rep_matrix(h_seq', M);                    // hazard sequence replicated across M rows
  h_exp_xbeta_mat = -h_mat .* exp_xbeta_mat;
  for (j in 1:J) {
    first_sum[j] = sum(exp_xbeta_mat[, j] .* R_tilde_minus_D_tilde[, j]); // risk set minus event set, interval j
    second_sum[j] = sum(log1m_exp(h_exp_xbeta_mat[, j]) .* D_tilde[, j]); // event set, interval j
  }
  target += sum(-h_seq .* first_sum + second_sum); // survival likelihood contribution
  mu += Intercept;
  for (n in 1:N) {
    mu[n] += r1[id[n]] * Z1[n] + r2[id[n]] * Z2[n];
  }
  target += normal_id_glm_lpdf(Y | Xc, mu, b, sigma); // longitudinal likelihood
  target += std_normal_lpdf(to_vector(z));
  target += lprior;
}
generated quantities {
}
The issue is that the sampled chains look terrible: low ESS, high Rhat, traceplots that are not nice fuzzy caterpillars, and every iteration exceeding the maximum treedepth. Note that I had none of these issues when I just used the latent interpolated continuous covariate.
If this should be possible, then I can only assume that my sample size is much too small for this complex model (I only have 58 observations).
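For what it's worth, the only differentiable alternative I can think of is softening the threshold inside the model block, e.g. a logistic approximation of the two indicators. This is only a sketch (tau here is a hypothetical smoothness constant I made up, not something from the literature):

// replaces the hard if/else over rel_change_10wk in the model block;
// real tau = 0.01; // hypothetical smoothness constant; as tau -> 0 these
// approach the 0/1 indicators
for (i in 1:M) {
  X_imp[i, P + 1] = inv_logit((-0.3 - rel_change_10wk[i]) / tau); // ~1 when rel change <= -0.3
  X_imp[i, P + 2] = inv_logit((rel_change_10wk[i] - 0.2) / tau);  // ~1 when rel change >= 0.2
}

But before going down that road, I would like to understand whether the hard threshold is really what is breaking the sampler.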