Psis-loo with censored observations

Hi all, although I found a very similar thread on the forum (Loo: calculating point-wise log-lik when using data augmentation for censored observations), I don’t think my question has been answered.
Consider the model $Y \sim \mathcal{N}(aX + b, \sigma)$, where some of the observations $Y$ are left-censored. We can integrate out the unobserved data using normal_lcdf, which gives the following Stan model (C[n] == 0 indicates that the nth observation is uncensored, and C[n] == 1 indicates that it is left-censored, in which case Y[n] holds the censoring limit).

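// cens_reg.stan
data {
  int N;             // number of observations
  array[N] real X;   // predictor
  array[N] real Y;   // response; equals the censoring limit when C[n] == 1
  array[N] int C;    // censoring indicator: 0 = uncensored, 1 = left-censored
}

parameters {
  real a;
  real b;
  real<lower=0> sigma;
}

transformed parameters {
  // pointwise log-likelihood, used by the model block and by PSIS-LOO
  array[N] real log_lik;
  for (n in 1:N) {
    if (C[n] == 0) {
      // uncensored: density at the observed value
      log_lik[n] = normal_lpdf(Y[n] | a * X[n] + b, sigma);
    } else {
      // left-censored: probability of falling below the limit stored in Y[n]
      log_lik[n] = normal_lcdf(Y[n] | a * X[n] + b, sigma);
    }
  }
}

model {
  target += log_lik;
}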
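As a quick sanity check outside Stan, the two branches of log_lik can be reproduced in plain Python. This is only a minimal sketch: the helper names normal_lpdf, normal_lcdf and pointwise_log_lik are hypothetical and simply mirror the Stan functions, and the naive log of the normal CDF used here breaks down for very negative z-scores, where a dedicated routine such as scipy.stats.norm.logcdf would be preferable.

```python
import math


def normal_lpdf(y, mu, sigma):
    """Log density of N(mu, sigma) at y."""
    z = (y - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


def normal_lcdf(y, mu, sigma):
    """Log of P(Y <= y) for Y ~ N(mu, sigma), using Phi(z) = erfc(-z / sqrt(2)) / 2."""
    z = (y - mu) / sigma
    return math.log(0.5 * math.erfc(-z / math.sqrt(2.0)))


def pointwise_log_lik(x, y, c, a, b, sigma):
    """Log-likelihood of one observation; c == 1 means left-censored at y."""
    mu = a * x + b
    if c == 1:
        return normal_lcdf(y, mu, sigma)
    return normal_lpdf(y, mu, sigma)


# A censored point far below the limit (x = -3) versus one far above it (x = 2),
# using the simulation values a = 0.4, b = 0.1, sigma = 0.2, limit = -0.5.
for x in (-3.0, 2.0):
    print(x, pointwise_log_lik(x, -0.5, 1, 0.4, 0.1, 0.2))
```

For the censored point at x = -3 the mean lies well below the limit, so the log-probability is very close to 0, which is the situation described in the question.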


I now want to compare this model with other models using PSIS-LOO. However, this leads to “very bad” Pareto k values. These belong to some of the left-censored observations, i.e. observations whose true value lies below the observation limit (equal to -0.5), and their log-probability is close to 0. The following figure shows the (potentially) censored observations (black) and the “ground truth” values (red), together with the fitted linear model (blue). The bottom panel is made with plot_khat from the ArviZ package.

My question is: Does anyone know how to resolve this? Can we just ignore such warnings and trust the loo results?

This is the Python code I used to create the figure:

import cmdstanpy
import arviz as az
import matplotlib.pyplot as plt
import scipy.stats as sts
import numpy as np

sm = cmdstanpy.CmdStanModel(stan_file="cens_reg.stan")

a = 0.4
b = 0.1
sigma = 0.2
N = 100

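# simulate predictors (sorted, so the observation index follows X)
X = sts.norm.rvs(size=N)
X.sort()
# uncensored ("ground truth") responses around the true regression line
Y_unc = sts.norm.rvs(loc=a * X + b, scale=sigma)

# left-censoring limit
lb = -0.5

# C[n] = 1 marks a left-censored observation; its Y is replaced by the limit
C = np.array([0 if y >= lb else 1 for y in Y_unc])
Y = np.array([y if c == 0 else lb for y, c in zip(Y_unc, C)])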

sam = sm.sample(data={"X" : X, "Y" : Y, "N" : N, "C" : C})
azsam = az.from_cmdstanpy(sam)

# pointwise=True keeps the per-observation Pareto k values needed by plot_khat
loo = az.loo(azsam, pointwise=True)
print(loo)

fig, (ax, bx) = plt.subplots(2, 1, figsize=(7,7))

ax.scatter(X, Y_unc, color='r', s=1, zorder=1)
ax.scatter(X, Y, color='k', s=2, zorder=2)
ax.set_xlabel("X")
ax.set_ylabel("Y")

# posterior band of the regression line on a grid of X values
xs = np.linspace(np.min(X), np.max(X), 100)
a_est = sam.stan_variable("a")
b_est = sam.stan_variable("b")
# one row per posterior draw: a_est[s] * xs + b_est[s]
yss = np.outer(a_est, xs) + b_est[:, None]
lys, mys, uys = np.percentile(yss, axis=0, q=[2.5, 50, 97.5])
ax.plot(xs, mys)
ax.fill_between(xs, lys, uys, alpha=0.3)

az.plot_khat(loo, ax=bx, show_hlines=True)
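To see which observations trigger the warning, the Pareto k values can be ranked directly. The following is a rough sketch rather than part of the original script: flag_high_khat is a hypothetical helper, the 0.7 cut-off is the commonly used threshold for "bad" values and is an assumption here, and the demo array is made up so that the snippet runs on its own.

```python
import numpy as np


def flag_high_khat(khat, threshold=0.7):
    """Return indices of observations with Pareto k above threshold, worst first."""
    khat = np.asarray(khat, dtype=float)
    idx = np.flatnonzero(khat > threshold)
    # sort the flagged indices by decreasing k
    return idx[np.argsort(khat[idx])[::-1]]


# With the script above, the values would come from ArviZ, e.g.
#   loo = az.loo(azsam, pointwise=True)
#   khat = loo.pareto_k.values
# Here a made-up array stands in for them.
khat_demo = np.array([1.2, 0.9, 0.3, 0.1, 0.75, 0.2])
worst = flag_high_khat(khat_demo)
print(worst, khat_demo[worst])
```

Because X is sorted in the simulation, the returned indices can be compared directly with the positions in the plot_khat panel.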


I’m confused. Left or right?

Can you show histograms of a) log_lik values and b) 1/exp(log_lik) values for the one with worst khat?

Usually censored observations are weakly informative, although there is of course more leverage in the extreme x values.

Apologies: that should be left-censored (I edited the post). This figure shows the log_lik and 1/exp(log_lik) for observations 1,…,10 (they are sorted by X, so these are the worst).

Can you show histograms of a) log_lik values and b) 1/exp(log_lik) values for the data point with the worst khat?
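The requested histograms could be produced along the following lines. This is a sketch under stated assumptions: ratio_histograms is a hypothetical helper, the synthetic draws only stand in for real posterior draws, and the binning is arbitrary. The quantity 1/exp(log_lik) matters because the raw leave-one-out importance ratios for observation i are proportional to 1/p(y_i | θ) evaluated at each posterior draw, and khat describes the right tail of those ratios.

```python
import numpy as np


def ratio_histograms(log_lik_draws, bins=30):
    """Histogram log_lik draws and the ratios 1/exp(log_lik) for one observation."""
    ll = np.asarray(log_lik_draws, dtype=float).ravel()
    ratios = np.exp(-ll)  # 1 / exp(log_lik)
    ll_counts, ll_edges = np.histogram(ll, bins=bins)
    r_counts, r_edges = np.histogram(ratios, bins=bins)
    return {
        "log_lik": (ll_counts, ll_edges),
        "ratios": (r_counts, r_edges),
        # share of the total weight carried by the single largest ratio
        "max_ratio_share": ratios.max() / ratios.sum(),
    }


# With the fitted model, the draws for observation i would be
#   sam.stan_variable("log_lik")[:, i]
# and plt.hist(...) could be used in place of np.histogram for plotting.
# Synthetic draws (an assumption) keep this snippet self-contained.
rng = np.random.default_rng(0)
ll_demo = -0.01 * np.abs(rng.normal(size=4000))
out = ratio_histograms(ll_demo)
print(out["log_lik"][0].sum(), out["max_ratio_share"])
```

A single draw carrying a large share of the total weight would be one sign of the heavy right tail that a high khat points to.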