Hi all, although I found a very similar thread on the forum (Loo: calculating point-wise log-lik when using data augmentation for censored observations), I don’t think my question has been answered.
Consider the model $Y \sim \mathcal{N}(aX + b, \sigma)$, where some of the observations $Y$ are left-censored. We can integrate out the unobserved values using `normal_lcdf`, which gives the following Stan model (`C[n] == 0` indicates that the $n$th observation is uncensored, and `C[n] == 1` indicates that it is left-censored):
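```stan
// cens_reg.stan
data {
  int<lower=0> N;
  array[N] real X;
  array[N] real Y;                   // censored entries hold the limit value
  array[N] int<lower=0, upper=1> C;  // 1 = left-censored, 0 = observed
}
parameters {
  real a;
  real b;
  real<lower=0> sigma;
}
transformed parameters {
  // pointwise log-likelihood, also used for PSIS-LOO
  array[N] real log_lik;
  for (n in 1:N) {
    if (C[n] == 0) {
      // observed value: normal density
      log_lik[n] = normal_lpdf(Y[n] | a * X[n] + b, sigma);
    } else {
      // left-censored: latent value integrated out, log P(Y* <= Y[n])
      log_lik[n] = normal_lcdf(Y[n] | a * X[n] + b, sigma);
    }
  }
}
model {
  target += log_lik;
}
```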
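Written out, with $\mu_n = aX_n + b$ and with each censored observation recorded at the limit $L$ (so $Y_n = L$ when $C_n = 1$), the pointwise log-likelihood stored in `log_lik` is

$$
\log p(Y_n \mid a, b, \sigma) =
\begin{cases}
\log \mathcal{N}(Y_n \mid \mu_n, \sigma), & C_n = 0, \\
\log \displaystyle\int_{-\infty}^{L} \mathcal{N}(y \mid \mu_n, \sigma)\, dy = \log \Phi\!\left(\dfrac{L - \mu_n}{\sigma}\right), & C_n = 1,
\end{cases}
$$

where $\Phi$ is the standard normal CDF. When $\mu_n$ lies several $\sigma$ below $L$, $\Phi\big((L - \mu_n)/\sigma\big) \approx 1$, so the censored term is close to 0.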
I now want to compare this model with other models using PSIS-LOO. However, this produces “very bad” Pareto k values. These correspond to some of the left-censored observations (values below the observation limit of -0.5), whose probability of being censored is close to 1, so their log-likelihood is close to 0. The following figure shows the (potentially censored) observations (black) and the “ground truth” values (red), together with the fitted linear model (blue). The bottom panel is made with `plot_khat` from the `arviz` package.
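To make “log-likelihood close to 0” concrete, here is a small SciPy sketch. It first checks numerically that integrating the normal density up to the limit gives the same value as the log-CDF (the SciPy analogue of `normal_lcdf`). It then evaluates the censored log-likelihood for two hypothetical censored points: one whose linear predictor is far below the limit and one close to it. The posterior draws are made up for illustration (they are not taken from the fit above), and `censored_log_lik` is a helper introduced here. Because the raw PSIS-LOO importance ratios are $r^{(s)} = 1/p(y_n \mid \theta^{(s)})$, their range is printed too.

```python
import numpy as np
import scipy.stats as sts
from scipy.integrate import quad

lb = -0.5  # observation limit, as in the simulation script


def censored_log_lik(mu, sigma, lb):
    """log P(Y* <= lb) for Y* ~ Normal(mu, sigma); SciPy analogue of Stan's normal_lcdf."""
    return sts.norm.logcdf(lb, loc=mu, scale=sigma)


# 1) Integrating the latent value out over (-inf, lb] gives the log-CDF.
mu0, sigma0 = -0.3, 0.2  # arbitrary illustrative values
integral, _ = quad(lambda y: sts.norm.pdf(y, loc=mu0, scale=sigma0), -np.inf, lb)
print("log of integral:", np.log(integral), " logcdf:", censored_log_lik(mu0, sigma0, lb))

# 2) Hypothetical posterior draws (NOT from the fit above).
rng = np.random.default_rng(1)
n_draws = 4000
sigma_draws = 0.2 * np.exp(rng.normal(0.0, 0.05, n_draws))  # positive by construction
mu_far = rng.normal(-2.0, 0.05, n_draws)   # linear predictor about 7.5 sigma below lb
mu_near = rng.normal(-0.5, 0.05, n_draws)  # linear predictor close to lb

ll_far = censored_log_lik(mu_far, sigma_draws, lb)
ll_near = censored_log_lik(mu_near, sigma_draws, lb)

for label, ll in [("far below lb", ll_far), ("near lb", ll_near)]:
    ratios = np.exp(-ll)  # raw importance ratios 1 / p(y_n | theta) for each draw
    print(f"{label:>12}: mean log_lik = {ll.mean():.3g}, sd = {ll.std():.3g}, "
          f"ratio range = [{ratios.min():.6g}, {ratios.max():.6g}]")
```

For the point far below the limit, the log-likelihood is essentially 0 in every draw and all ratios are essentially 1. For the point near the limit, both vary noticeably across draws.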
My question is: does anyone know how to resolve this? Can we just ignore such warnings and trust the `loo` results?
This is the Python code I used to create the figure:
```python
import cmdstanpy
import arviz as az
import matplotlib.pyplot as plt
import scipy.stats as sts
import numpy as np

sm = cmdstanpy.CmdStanModel(stan_file="cens_reg.stan")

# Simulate data
a = 0.4
b = 0.1
sigma = 0.2
N = 100
X = sts.norm.rvs(size=N)
X.sort()
Y_unc = sts.norm.rvs(loc=X, scale=sigma)  # latent ("ground truth") values

# Left-censor at the observation limit lb: censored values are recorded as lb
lb = -0.5
C = (Y_unc < lb).astype(int)  # 1 = left-censored, 0 = observed
Y = np.where(C == 0, Y_unc, lb)

sam = sm.sample(data={"X": X, "Y": Y, "N": N, "C": C})

# The Stan program stores the pointwise log-likelihood in `log_lik`
azsam = az.from_cmdstanpy(sam, log_likelihood="log_lik")
loo = az.loo(azsam, pointwise=True)
print(loo)

fig, (ax, bx) = plt.subplots(2, 1, figsize=(7, 7))
# Top panel: latent values (red), recorded values (black), fitted line (blue)
ax.scatter(X, Y_unc, color="r", s=1, zorder=1)
ax.scatter(X, Y, color="k", s=2, zorder=2)
ax.set_xlabel("X")
ax.set_ylabel("Y")
xs = np.linspace(np.min(X), np.max(X), 100)
a_est = sam.stan_variable("a")
b_est = sam.stan_variable("b")
yss = np.outer(a_est, xs) + b_est[:, None]  # regression line for each draw
lys, mys, uys = np.percentile(yss, q=[2.5, 50, 97.5], axis=0)
ax.plot(xs, mys)
ax.fill_between(xs, lys, uys, alpha=0.3)
# Bottom panel: Pareto k diagnostic per observation
az.plot_khat(loo, ax=bx, show_hlines=True)
plt.show()
```
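To check which observations the high Pareto k values belong to, a small helper can cross-tabulate the per-observation k estimates against the censoring indicator. This is a sketch under stated assumptions: `high_k_report` is a hypothetical helper, 0.7 is used as the usual cutoff above which PSIS-LOO estimates are considered unreliable, and the toy arrays are invented for illustration.

```python
import numpy as np


def high_k_report(khat, C, threshold=0.7):
    """Hypothetical helper: indices with Pareto k above `threshold`,
    and how many of those observations are left-censored.

    khat: per-observation Pareto k estimates, shape (N,)
    C:    censoring indicator (1 = left-censored, 0 = observed), shape (N,)
    """
    khat = np.asarray(khat, dtype=float)
    C = np.asarray(C, dtype=int)
    if khat.shape != C.shape:
        raise ValueError("khat and C must have the same shape")
    idx = np.flatnonzero(khat > threshold)
    n_censored = int(C[idx].sum())
    return idx, n_censored


# Toy values for illustration only (not results from the model above)
khat_toy = [0.1, 0.8, 1.2, 0.3, 0.05]
C_toy = [0, 1, 1, 0, 1]
idx, n_censored = high_k_report(khat_toy, C_toy)
print(f"{len(idx)} observations with k > 0.7 at indices {idx.tolist()}, "
      f"{n_censored} of them censored")
```

With the script above, one could call it as `high_k_report(loo.pareto_k.values, C)`, assuming the ArviZ ELPD object exposes the pointwise k estimates as `pareto_k` (which it should when `az.loo` is called with `pointwise=True`).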