Interpolating continuous variable and dichotomizing

I am fitting a Bayesian survival model. There is a continuous covariate that is of primary interest. I have very short time series for each patient (2 or 3 points). In the literature, the relative change of this covariate during a specified “evaluation” window, relative to baseline, is used. However, this window is quite broad, so the times at which these measurements are taken can differ quite a bit between patients. Furthermore, I have a significant number of patients whose measurements fall just outside this window.

So I thought it would be good to fit a simple linear model with random effects for slope and intercept (jointly with my survival model) for each patient and interpolate this covariate at the same time point for all patients. I did this, and it worked out great.

However, in practice and in the literature this covariate is actually dichotomized (there is some defined threshold, and patients are categorized as below or above it). I know Stan cannot sample discrete variables due to the nature of Hamiltonian Monte Carlo; however, in my case I am sampling a continuous variable and then deriving a dichotomous variable from it.

Is this okay to do? When I coded it up I could not get the sampler to work (even though it did just fine in the latent interpolated covariate case). I thought maybe it was because of the discreteness. I’ll attach my Stan code (there is a lot going on) in case someone can spot something I missed:

functions {
 /* compute correlated group-level effects       
  * Args:
  *   z: matrix of unscaled group-level effects  
  *   SD: vector of standard deviation parameters
  *   L: cholesky factor correlation matrix      
  * Returns:
  *   matrix of scaled group-level effects
  */
  matrix scale_r_cor(matrix z, vector SD, matrix L) {
    // r is stored in another dimension order than z
    return transpose(diag_pre_multiply(SD, L) * z);
  }
}
data {
  int<lower=1> N;  // total number of observations
  int<lower=1> M; // total number of patients
  int<lower=1> NP; // number of prediction time points
  array[N] int<lower=1, upper=M> id;
  array[NP] int<lower=1, upper=M> pred_id;
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int<lower=1> Kc;  // number of population-level effects after centering
  // data for group-level effects of ID 1
  int<lower=1> R;  // number of coefficients per level
  // group-level predictor values
  vector[N] Z1;
  vector[N] Z2;
  int<lower=1> NC;  // number of group-level correlations
  matrix[M, K] X_baseline;  // population-level design matrix for size data at baseline
  matrix[M, K] X_evaluation;  // population-level design matrix for size data at the evaluation time
  matrix[NP, K] X_pred;  // population-level design matrix for size data at the prediction time points
  vector[M] Z1_baseline;
  vector[M] Z1_evaluation;
  vector[NP] Z1_pred;
  vector[M] Z2_baseline;
  vector[M] Z2_evaluation;
  vector[NP] Z2_pred;

  int<lower=0> J;  // Number of time intervals.
  vector[J] hPriorSh;  // Shape parameters for the gamma prior distribution of the baseline hazard.
  real c0;  // Rate parameter for the gamma prior distribution of the baseline hazard.
  
  int<lower=0> P;  // Dimensionality of the covariates.
  matrix[M, P] Xd;  // Matrix of covariates for survival
  
  matrix[M, J] R_tilde_minus_D_tilde;  // Matrix indicating risk set minus event set for each observation across intervals.
  matrix[M, J] D_tilde;  // Matrix indicating which intervals an observation has an event.
}
transformed data {
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  matrix[M, Kc] Xc_baseline;  // centered version of X without an intercept
  matrix[M, Kc] Xc_evaluation;  // centered version of X without an intercept
  matrix[NP, Kc] Xc_pred;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering

  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];

    Xc_baseline[, i - 1] = X_baseline[, i] - means_X[i - 1];
    Xc_evaluation[, i - 1] = X_evaluation[, i] - means_X[i - 1];
    Xc_pred[, i - 1] = X_pred[, i] - means_X[i - 1];
  }
}
parameters {
  vector[Kc] b;  // regression coefficients
  real Intercept;  // temporary intercept for centered predictors
  real<lower=0> sigma;  // dispersion parameter
  vector<lower=0>[R] sd;  // group-level standard deviations
  matrix[R, M] z;  // standardized group-level effects
  cholesky_factor_corr[R] L;  // cholesky factor of correlation matrix

  vector[P+2] beta;  // regression coefficients
  vector<lower=0>[J] h_seq;  // parameter with a gamma prior
}
transformed parameters {
  matrix[M, R] r;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[M] r1;
  vector[M] r2;
  real lprior = 0;  // prior contributions to the log posterior
  // compute actual group-level effects
  r = scale_r_cor(z, sd, L);
  r1 = r[, 1];
  r2 = r[, 2];

  lprior += student_t_lpdf(Intercept | 3, 2.3, 2.5);
  lprior += student_t_lpdf(sigma | 3, 0, 2.5) - 1 * student_t_lccdf(0 | 3, 0, 2.5);
  lprior += student_t_lpdf(b | 3, 0, 2.5);
  //lprior += student_t_lpdf(sd | 3, 0, 2.5) - 2 * student_t_lccdf(0 | 3, 0, 2.5);
  lprior += std_normal_lpdf(sd) - 2 * std_normal_lccdf(0);
  lprior += lkj_corr_cholesky_lpdf(L | 1);

  lprior += student_t_lpdf(beta | 3, 0, 2.5);
  lprior += gamma_lpdf(h_seq | hPriorSh, c0);
}
model {
  vector[N] mu = rep_vector(0.0, N);

  matrix[M, P+2] X_imp;
  matrix[M, J] exp_xbeta_mat;  // Matrix where each column is the exponential of X multiplied by beta.
  vector[J] first_sum;  // Vector to store the summation terms for the risk set minus event set.
  matrix[M, J] h_mat;  // Replicating the hazard sequence across the M rows.
  matrix[M, J] h_exp_xbeta_mat;  // Matrix storing the negative product of the hazard sequence and the exponential of X times beta.
  vector[J] second_sum;  // Vector to store the summation terms for the event set.
  vector[M] log_size_baseline = Intercept + Xc_baseline * b;
  vector[M] log_size_evaluation = Intercept + Xc_evaluation * b;
  vector[M] rel_change_10wk;

  for (n in 1:M){
    log_size_baseline[n] += r1[n] * Z1_baseline[n] + r2[n] * Z2_baseline[n];
    log_size_evaluation[n] += r1[n] * Z1_evaluation[n] + r2[n] * Z2_evaluation[n];
  }
  rel_change_10wk = (exp(log_size_evaluation) - exp(log_size_baseline)) ./ exp(log_size_baseline);
  for (k in 1:P){
    X_imp[,k] = Xd[,k];
  }
  for (i in 1:M){
    X_imp[i,P+1] = 0;
    X_imp[i,P+2] = 0;
    if (rel_change_10wk[i] <= -0.3){
      X_imp[i,P+1] = 1;
    }else if (rel_change_10wk[i] >= 0.2){
      X_imp[i,P+2] = 1;
    }
  }

  exp_xbeta_mat = rep_matrix(exp(X_imp * beta), J);  // Matrix where each column is the exponential of X multiplied by beta.
  h_mat = rep_matrix(h_seq', M);  // Replicating the hazard sequence across the M rows.
  h_exp_xbeta_mat = -h_mat .* exp_xbeta_mat;  // Negative product of the hazard sequence and the exponential of X times beta.
  for (j in 1:J) {
    first_sum[j] = sum(exp_xbeta_mat[, j] .* R_tilde_minus_D_tilde[, j]);  // Summing over the risk set minus event set for the `j-th` interval.
    second_sum[j] = sum(log1m_exp(h_exp_xbeta_mat[, j]) .* D_tilde[, j]);  // Summing over the event set for the `j-th` interval using the log1m_exp transformation.
  }

  target += sum(-h_seq .* first_sum + second_sum);  // Update the target log posterior with the likelihood component.

  mu += Intercept;
  for (n in 1:N) {
    mu[n] += r1[id[n]] * Z1[n] + r2[id[n]] * Z2[n];
  }
  target += normal_id_glm_lpdf(Y | Xc, mu, b, sigma);
  target += std_normal_lpdf(to_vector(z));
  target += lprior;
}
generated quantities {
}

The issue is that the sampled chains look terrible (low ESS, high Rhat, traceplots that are not nice fuzzy caterpillars, and every iteration exceeding the maximum treedepth). Note that I had none of these issues when I just used the latent interpolated continuous covariate.

If this should be possible, then I assume my sample size is just much too small for this complex model (I only have 58 observations).

I haven’t looked closely at your Stan code to verify what you’re doing, but from your description you sample a parameter representing the continuous value for the unobserved covariate, then threshold that value and use it in the downstream likelihood computation. This is a problem for Stan, because the likelihood is now discontinuous in the value of the continuous parameter, so gradient-based methods will break: the thresholded indicator is piecewise constant, so the part of the likelihood that depends on it contributes zero gradient away from the threshold and an undefined gradient at it, and HMC gets no information about where the latent value sits relative to the threshold.

Alternatives might include:

  1. Interpolate and dichotomize the covariate in a first step, then treat that information as fixed in a second step.
  2. Probabilistically interpolate whether the covariate value for each patient is greater than or less than the threshold, so that you have a quantity interpretable as the probability that the latent value is above the threshold, then marginalize over the possibilities that it is above and below the threshold (see the sketch below).
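
To make option 2 concrete, here is a minimal, self-contained sketch (not your joint model, and every name in it is a placeholder). It assumes one noisy observation rel_obs of each patient's relative change with a known measurement sd sigma_obs, a normal population distribution for the latent value, a single threshold thresh, and, purely for illustration, an exponential survival model with log hazard ratio beta_above for being above the threshold. The latent value is integrated out analytically, so the probability of exceeding the threshold is a smooth function of the parameters and the survival term can be marginalized with log_mix:

data {
  int<lower=1> M;              // number of patients
  real thresh;                 // dichotomization threshold, e.g. -0.3
  vector[M] rel_obs;           // one noisy observation of each patient's relative change
  real<lower=0> sigma_obs;     // measurement sd, taken as known for brevity
  vector<lower=0>[M] t_event;  // event times, all uncensored for brevity
}
parameters {
  real mu_pop;                 // population mean of the latent relative change
  real<lower=0> tau_pop;       // population sd of the latent relative change
  real<lower=0> h0;            // baseline hazard of the exponential survival model
  real beta_above;             // log hazard ratio for being above the threshold
}
model {
  mu_pop ~ normal(0, 1);
  tau_pop ~ normal(0, 1);
  h0 ~ normal(0, 1);
  beta_above ~ normal(0, 1);
  for (i in 1:M) {
    // conditional mean and variance of the latent value given its observation
    real s2 = 1 / (1 / square(tau_pop) + 1 / square(sigma_obs));
    real m = s2 * (mu_pop / square(tau_pop) + rel_obs[i] / square(sigma_obs));
    // probability that the latent value exceeds the threshold; smooth in all parameters
    real p_above = 1 - normal_cdf(thresh | m, sqrt(s2));
    // survival log likelihood under each branch of the dichotomy
    real ll_above = exponential_lpdf(t_event[i] | h0 * exp(beta_above));
    real ll_below = exponential_lpdf(t_event[i] | h0);
    // marginal likelihood of the observation with the latent value integrated out
    target += normal_lpdf(rel_obs[i] | mu_pop, sqrt(square(tau_pop) + square(sigma_obs)));
    // marginalize the survival term over above/below instead of hard-thresholding
    target += log_mix(p_above, ll_above, ll_below);
  }
}

The key property is that p_above depends smoothly on the parameters, so the gradient is well defined everywhere. In your model you would need an analogous smooth expression for the distribution of rel_change_10wk, and your three-category coding would marginalize over three states with log_sum_exp rather than two with log_mix.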