I’m continuing a discussion from the arviz GitHub with @OriolAbril here for better exposure.
I would like to compare the out-of-sample predictive performance of some statistical models that I’ve been fitting and analysing with cmdstanpy and arviz. My data represents cell densities in a series of timecourse experiments. Each experiment is referred to as a ‘replicate’ and the goal is to predict unseen replicates.
When I tried to find the leave-one-replicate-out elpd using arviz’s `loo` function, there were some Pareto-k warnings, indicating that the Pareto-smoothed importance sampling estimate was unreliable for some replicates. I was getting ready to do the cross-validation analysis manually, but then I came across the experimental arviz function `reloo`, which seems to have solved my problem. Still, I want to check whether I’ve used it correctly, as I’m not sure.
Here is the relevant part of my models’ generated quantities block:
generated quantities {
  vector[R] llik = rep_vector(0, R);
  for (n in 1:N_test) {
    int r = replicate_test[n];
    int c = clone[r];
    ...
    real yhat_test =
      yt(t_test[n], R0[r], mu, exp(log_kq[c]), exp(log_td[c]), exp(log_kd[c]));
    llik[r] += lognormal_lpdf(y_test[n] | log(yhat_test), err_test);
  }
}
The vector `llik` stores the total log likelihood for replicates that appear in the test data, and is zero for other replicates. I think the total log likelihood is the right number to use for estimating leave-one-replicate-out lpd in this case.
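To spell out what the `llik` accumulation does for a single posterior draw, here is a minimal Python sketch. The toy values and the function names are made up for illustration, and the predictions `yhat_test` are simply supplied rather than computed by my `yt` function. Summing within a replicate gives the joint log density of that replicate's measurements, assuming they are conditionally independent given the parameters.

```python
import math


def lognormal_lpdf(y, mu, sigma):
    """Log density of a lognormal with log-scale location mu and scale sigma."""
    z = (math.log(y) - mu) / sigma
    return -math.log(y) - math.log(sigma) - 0.5 * math.log(2 * math.pi) - 0.5 * z * z


def total_llik_by_replicate(y_test, yhat_test, replicate_test, err, n_replicates):
    """Sum pointwise log likelihoods within each replicate (1-based labels)."""
    llik = [0.0] * n_replicates
    for y, yhat, r in zip(y_test, yhat_test, replicate_test):
        # Same update as in the Stan loop: llik[r] += lognormal_lpdf(...)
        llik[r - 1] += lognormal_lpdf(y, math.log(yhat), err)
    return llik


# Hypothetical toy data: three test measurements, all from replicate 2 of 3.
y_test = [1.0, 2.5, 4.0]
yhat_test = [1.1, 2.0, 4.2]
replicate_test = [2, 2, 2]
llik = total_llik_by_replicate(y_test, yhat_test, replicate_test, err=0.5, n_replicates=3)
print(llik)  # only the entry for replicate 2 is non-zero
```

In the Stan program this happens once per posterior draw, so `llik` ends up with one value per draw and replicate.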
In order to use `reloo` I made the following subclass of arviz’s `SamplingWrapper`:
import arviz as az


class CustomSamplingWrapper(az.SamplingWrapper):
    def __init__(self, msmts, priors, x_cols, **super_kwargs):
        self.msmts = msmts
        self.priors = priors
        self.x_cols = x_cols
        super().__init__(**super_kwargs)

    def sample(self, data):
        """Call CmdStanModel.sample."""
        return self.model.sample(data=data, **self.sample_kwargs)

    def get_inference_data(self, mcmc):
        """Call arviz.from_cmdstanpy."""
        return az.from_cmdstanpy(mcmc, **self.idata_kwargs)

    def log_likelihood__i(self, excluded_obs, idata__i):
        """Get the out-of-sample log likelihoods from idata__i."""
        ll = idata__i.log_likelihood["llik"]
        # llik is zero for every replicate that is not in the test data
        return ll.where(ll != 0, drop=True)

    def sel_observations(self, idx):
        """Construct a stan input where replicate idx is out-of-sample."""
        original = get_stan_input(self.msmts, self.priors, self.x_cols)
        # idx is zero-based, whereas replicate_fct starts at 1
        m_test = self.msmts.loc[lambda df: df["replicate_fct"].eq(idx[0] + 1)]
        # training data is everything except the held-out replicate
        m_train = self.msmts.drop(m_test.index)
        d_test = original.copy()
        d_test["t"] = m_train["day"].values
        d_test["t_test"] = m_test["day"].values
        d_test["y"] = m_train["y"].values
        d_test["y_test"] = m_test["y"].values
        d_test["replicate"] = m_train["replicate_fct"].values
        d_test["replicate_test"] = m_test["replicate_fct"].values
        d_test["N"] = len(m_train)
        d_test["N_test"] = len(m_test)
        # the second return value (the excluded observations) is left empty
        return d_test, {}
This seems to work, in that the `reloo` function runs and gives pretty much the results I expected. However, I’m not sure that I implemented `sel_observations` and `log_likelihood__i` in the way intended. I think the second return value of `sel_observations` should be a potential Stan input like the first, which `log_likelihood__i` will use to find the test predictions, but I ended up not using it at all.
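For reference, this is roughly how I understand the values returned by `log_likelihood__i` to be used. It is a plain-Python sketch of the standard exact cross-validation estimate, not arviz's actual code. The function names, the toy numbers and the `fake_refit` callback are all hypothetical, and I'm assuming a Pareto-k threshold of 0.7. For each replicate whose k exceeds the threshold, the PSIS estimate is replaced by the log of the mean, over draws from the refit, of the held-out likelihood.

```python
import math


def log_mean_exp(values):
    """Numerically stable log(mean(exp(values)))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values) / len(values))


def reloo_sketch(psis_elpd_i, pareto_k, refit_llik_draws, k_thresh=0.7):
    """Replace unreliable PSIS estimates with exact held-out estimates.

    refit_llik_draws(i) stands in for the refit steps: it should return
    the per-draw total log likelihood of replicate i under a model that
    was fitted without replicate i.
    """
    elpd_i = list(psis_elpd_i)
    refitted = []
    for i, k in enumerate(pareto_k):
        if k > k_thresh:
            elpd_i[i] = log_mean_exp(refit_llik_draws(i))
            refitted.append(i)
    return elpd_i, refitted


def fake_refit(i):
    """Hypothetical refit result: two posterior draws of the held-out log likelihood."""
    return [math.log(0.1), math.log(0.3)]


# Toy inputs: only the second replicate has a Pareto k above the threshold.
elpd_i, refitted = reloo_sketch([-3.0, -2.5, -4.0], [0.2, 0.9, 0.5], fake_refit)
print(refitted, sum(elpd_i))
```

If this picture is right, then all `reloo` needs from `log_likelihood__i` is one total log likelihood per draw for the held-out replicate, which is what my `llik` filter returns.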
My questions are:
- Is there a better way I could write my `sel_observations` or `log_likelihood__i` methods?
- Could I have made my life easier with a `log_lik_fun` or similar? I couldn’t really work out from the `SamplingWrapper` code how this is supposed to work.
Anyway thanks very much to the arviz developers for this useful feature!