Arviz.reloo: am I doing it right?

I’m continuing a discussion from arviz github with @OriolAbril here for better exposure.

I would like to compare the out-of-sample predictive performance of some statistical models that I’ve been fitting and analysing with cmdstanpy and arviz. My data represents cell densities in a series of timecourse experiments. Each experiment is referred to as a ‘replicate’ and the goal is to predict unseen replicates.

When I tried to find the leave-one-replicate-out elpd using arviz’s loo function, there were some pareto-k warnings, indicating that the pareto-smoothed importance sampling estimation method didn’t work in all cases. I was getting ready to do the cross-validation analysis manually, but then I came across the experimental arviz function reloo, which seems to have solved my problem. Still I want to check if I’ve used it correctly as I’m not sure.

Here is the relevant part of my models’ generated quantities block:

generated quantities {
  vector[R] llik = rep_vector(0, R);
  for (n in 1:N_test){
    int r = replicate_test[n];
    int c = clone[r];
    ...
    real yhat_test =
      yt(t_test[n], R0[r], mu, exp(log_kq[c]), exp(log_td[c]), exp(log_kd[c]));

    llik[r] += lognormal_lpdf(y_test[n] | log(yhat_test), err_test);
  }
}

The vector llik stores the total log likelihood for each replicate that appears in the test data, and is zero for the other replicates. I think the total log likelihood per replicate is the right number to use for estimating the leave-one-replicate-out lpd in this case.
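
For context, here is roughly how I compute the initial PSIS-LOO estimate (fit is the CmdStanPy result; the exact keyword values are placeholders from my own scripts):

import arviz as az

idata = az.from_cmdstanpy(fit, log_likelihood="llik")
loo_orig = az.loo(idata, pointwise=True)  # pointwise values are per replicate
print(loo_orig.pareto_k)  # the khat diagnostics behind the warnings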

In order to use reloo I made the following subclass of arviz’s SamplingWrapper:

class CustomSamplingWrapper(az.SamplingWrapper):
    def __init__(self, msmts, priors, x_cols, **super_kwargs):
        self.msmts = msmts
        self.priors = priors
        self.x_cols = x_cols
        super().__init__(**super_kwargs)

    def sample(self, data):
        """Call CmdStanModel.sample."""
        return self.model.sample(data=data, **self.sample_kwargs)

    def get_inference_data(self, mcmc):
        """Call arviz.from_cmdstanpy."""
        return az.from_cmdstanpy(mcmc, **self.idata_kwargs)

    def log_likelihood__i(self, excluded_obs, idata__i):
        """Get the out-of-sample log likelihoods from idata__i."""
        ll = idata__i.log_likelihood["llik"]
        return ll.where(ll != 0, drop=True)

    def sel_observations(self, idx):
        """Construct a stan input where replicate idx is out-of-sample."""
        original = get_stan_input(self.msmts, self.priors, self.x_cols)
        m_test = self.msmts.loc[lambda df: df["replicate_fct"].eq(idx[0] + 1)]
        m_train = self.msmts.drop(m_test.index)
        d_test = original.copy()
        d_test["t"] = m_train["day"].values
        d_test["t_test"] = m_test["day"].values
        d_test["y"] = m_train["y"].values
        d_test["y_test"] = m_test["y"].values
        d_test["replicate"] = m_train["replicate_fct"].values
        d_test["replicate_test"] = m_test["replicate_fct"].values
        d_test["N"] = len(m_train)
        d_test["N_test"] = len(m_test)
        return d_test, {}
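
To run reloo I then construct the wrapper and call it roughly like this (the sampler settings are placeholders from my own scripts):

wrapper = CustomSamplingWrapper(
    msmts=msmts,
    priors=priors,
    x_cols=x_cols,
    model=model,                              # the compiled CmdStanModel
    idata_orig=idata,                         # InferenceData from the original fit
    sample_kwargs={"chains": 4},              # placeholder sampler settings
    idata_kwargs={"log_likelihood": "llik"},
)
loo_relooed = az.reloo(wrapper, loo_orig=loo_orig)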

This seems to work, in that the reloo function runs and gives pretty much the results I expected. However, I’m not sure that I implemented sel_observations and log_likelihood__i in the way intended. I think the second return value of sel_observations should be a potential Stan input like the first, which log_likelihood__i would then use to find the test predictions, but I ended up not using it at all.

My questions are

  • Is there a better way I could write my sel_observations or log_likelihood__i methods?
  • Could I have made my life easier with a log_lik_fun or similar? I couldn’t really work out from the SamplingWrapper code how this is supposed to work.

Anyway thanks very much to the arviz developers for this useful feature!


Everything looks good, and the best indicator is probably that the function runs.

This should be determined by how the pointwise log likelihood values were originally calculated before calling loo, so I can’t really comment on that without knowing and understanding the model. As an outsider, though, the code does make sense.

What I can offer is a bit of advice to be completely sure it is working, at the cost of adding one extra refit to the list. As is done in the sampling wrappers documentation, you can manually modify a good khat value so it exceeds the threshold. reloo will then calculate the exact CV for that replicate. Given that the khat was good, the loo_i value corresponding to that replicate should be virtually the same in both the original loo object and the reloo one. If it is not, the calculation of llik for the test values is incorrect.
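
In code, that check could look roughly like this, reusing your wrapper and idata (the index 3 is arbitrary; pick any replicate whose khat is good):

loo_orig = az.loo(idata, pointwise=True)
loo_orig.pareto_k[3] = 0.8  # pretend this (actually good) khat is above the threshold
loo_check = az.reloo(wrapper, loo_orig=loo_orig)  # triggers one exact refit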

You can check that in the docs example, where all khat values are lower than 0.7 and we manually increase 4 of them so that reloo is triggered and refits the model 4 times. You can use plot_elpd to compare the pointwise elpd values from the original loo and the reloo output and see that the PSIS approximation is indeed really close to the exact value for all 4 refitted replicates:

az.plot_elpd({"orig": loo_orig, "reloo": loo_relooed})

[plot_elpd output: the pointwise elpd values of orig and reloo are virtually identical]

This is perfectly right. Ideally the class should be able to wrap sampling from all the libraries that integrate with ArviZ, so sel_observations splits the data between train and test: the first return value is passed to the function that refits the model on the new train data, and the second is passed to the function that calculates the pointwise log likelihood on the test data.

In many libraries, the log likelihood on the test data is calculated by taking parameters from the posterior to define the likelihood function and then calling it on the test data.
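
As a rough sketch of that pattern, assuming a normal likelihood with posterior variables mu and sigma (none of which are from your model) and held-out data stored as an xarray DataArray:

import xarray as xr
from scipy import stats

def log_likelihood__i(self, excluded_obs, idata__i):
    # evaluate the likelihood at every posterior draw of mu and sigma;
    # xarray broadcasts the (chain, draw) dims against the observation dim
    return xr.apply_ufunc(
        stats.norm.logpdf,
        excluded_obs["y"],
        idata__i.posterior["mu"],
        idata__i.posterior["sigma"],
    )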

In Stan, however, the data is split between train and test but both are merged into the same dict, because the log lik calculation happens during fitting; the function that refits the model therefore also needs the test data. It is then perfectly fine to pass a dummy variable as the 2nd argument, as long as log_likelihood__i knows about it and ignores it.

We haven’t gotten around to writing a pseudo-generic cmdstanpy wrapper class to act as a template, but we do have one for pystan. You can see in its example that the 2nd argument is only a string indicating the name of the Stan variable that stores the log lik of the test data. My hope is that this will be general enough that most users will be able to subclass the pystan wrapper and only write the sel_observations method, as is done in the example. In your case it would not actually help, since you need the where to drop the zeros, so log_likelihood__i has to be rewritten anyway, but I hope it sheds some light on why there are unused returns and arguments lying around.
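
For illustration, a sel_observations for the pystan wrapper could look roughly like this (MyWrapper, data_orig and log_lik_test are made-up names, not the actual docs example):

import numpy as np
import arviz as az

class MyWrapper(az.PyStanSamplingWrapper):
    def sel_observations(self, idx):
        y = self.data_orig["y"]  # hypothetical attribute holding the full data
        mask = np.zeros(len(y), dtype=bool)
        mask[idx] = True
        data__i = {
            **self.data_orig,
            "y": y[~mask], "y_test": y[mask],
            "N": int((~mask).sum()), "N_test": int(mask.sum()),
        }
        # the 2nd return value is just the name of the Stan variable
        # that stores the log likelihood of the test data
        return data__i, "log_lik_test"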

They look fine to me. My only comment would be that it does not look like get_stan_input needs to be called for every refit: you could probably store original as an attribute alongside msmts and skip the 1st line of sel_observations. Nothing that actually matters, though.
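
That is, something like this in __init__, with sel_observations then starting from self.original.copy() (just a sketch of the suggestion):

def __init__(self, msmts, priors, x_cols, **super_kwargs):
    self.msmts = msmts
    self.priors = priors
    self.x_cols = x_cols
    self.original = get_stan_input(msmts, priors, x_cols)  # computed once
    super().__init__(**super_kwargs)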

This is actually a question that I still ask myself, and it may even be a matter of personal preference or depend on the model. The docs have the same pystan reloo example written twice: one version calculating the test log lik in the Stan code, and another using a Python function and leveraging xarray so you don’t even need to care about broadcasting.

At least until there is some feedback on using both approaches, and unless a clear preference for one of them emerges, I plan on keeping both alternatives in the docs. Also, due to the Python-class nature of the wrappers, both will continue to work; you’d just have to overwrite the right method of the wrapper.


Thanks a lot for the detailed response. I hadn’t seen those examples in the documentation so that’s really helpful. Also great to get an insight into the thinking behind the interface!
