Proposal: add a method to get samples grouped by chain to the PyStan interface

It is often convenient to get samples grouped by the chain they came from (e.g. to identify whether different chains have found different modes of the posterior, or to evaluate Rhat). @ariddell pointed me to the forums to discuss this feature proposal and to see whether there is interest in adding such functionality.

I wasn’t aware of the feature discussion on the forum before implementing this, so here is a pull request. From the PR description:

  • adds a function get_samples(param: str, flatten_chains: bool = True) -> np.ndarray to the Fit object, which reproduces the behaviour of __getitem__ by default. Setting flatten_chains=False, however, leaves the trailing dimensions as (num_saved_samples, num_chains). This is useful for comparing initialisations of different chains, computing Rhat, etc.
  • adds a chain__ column to the data frame produced by Fit.to_frame which can be used to identify which chain generated a given sample.
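A minimal NumPy sketch of the two behaviours described above, run on a synthetic array rather than a real fit. It assumes draws are stored as `(num_flat_params, num_draws, num_chains)` (the layout of `fit._values` mentioned later in this thread) and that each parameter's elements are flattened in column-major order. `get_samples` and `chain_labels` here are hypothetical stand-ins, not the code in the PR.

```python
import numpy as np


def get_samples(values, param_indexes, param_dim, flatten_chains=True):
    """Hypothetical stand-in for the proposed Fit.get_samples.

    values: array of shape (num_flat_params, num_draws, num_chains).
    param_indexes: flat row indexes of the parameter's elements in `values`.
    param_dim: the parameter's Stan dimensions, e.g. [] for a scalar.
    """
    _, num_draws, num_chains = values.shape
    rows = values[param_indexes]  # (num_elements, num_draws, num_chains)
    # Flat elements are assumed to be in column-major (Fortran) order.
    per_chain = rows.reshape(list(param_dim) + [num_draws, num_chains], order="F")
    if not flatten_chains:
        return per_chain
    # Merge the trailing axes with draws varying fastest, so chain 0's
    # draws come first, then chain 1's, and so on.
    return per_chain.reshape(list(param_dim) + [num_draws * num_chains], order="F")


def chain_labels(num_draws, num_chains):
    """Chain index of every flattened draw (a hypothetical chain__ column)."""
    return np.repeat(np.arange(num_chains), num_draws)


# Synthetic draws: 7 flat parameters, 5 draws, 4 chains.
rng = np.random.default_rng(0)
values = rng.normal(size=(7, 5, 4))
# Pretend a 2x3 matrix parameter occupies flat rows 1..6.
matrix_rows = list(range(1, 7))
per_chain = get_samples(values, matrix_rows, [2, 3], flatten_chains=False)
flat = get_samples(values, matrix_rows, [2, 3])
labels = chain_labels(5, 4)
print(per_chain.shape, flat.shape)  # (2, 3, 5, 4) (2, 3, 20)
```

Because the labels line up with the flattened draws, they could be assigned directly to a data frame column such as `chain__`.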

Keen to hear your thoughts on whether this should be added to pystan.

The RStan 3 / PyStan 3 design document is useful background reading for this kind of discussion: User Interface Guidelines for Developers

It does look like this particular method of accessing draws was not considered.


I wonder if there isn’t a one-liner that involves creating a view into fit._values (shape: (num_sample_and_sampler_params + num_flat_params, num_draws, num_chains)) using existing public functions. NumPy has many powerful functions for this kind of thing.
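Short of a view into fit._values, a user-side reshape of the existing flattened output may already be enough. This sketch assumes the last axis of `fit[param]` holds all draws of chain 0 first, then chain 1, and so on; that ordering is an assumption to check against the PyStan source, and `split_chains` is a hypothetical helper.

```python
import numpy as np


def split_chains(flat_samples, num_chains):
    """Reshape (..., num_draws * num_chains) draws into (..., num_draws, num_chains).

    Assumes the last axis holds all draws of chain 0 first, then chain 1,
    and so on (draws varying fastest).
    """
    *param_dim, total = flat_samples.shape
    num_draws = total // num_chains
    # A C-order reshape splits only the last axis, giving
    # (..., num_chains, num_draws); swapaxes then moves chains last.
    return flat_samples.reshape(*param_dim, num_chains, num_draws).swapaxes(-1, -2)


# Simulate flattened output for a length-3 vector: 10 draws, 4 chains.
rng = np.random.default_rng(1)
per_chain = rng.normal(size=(3, 10, 4))
flat = per_chain.reshape(3, 40, order="F")  # draws vary fastest
recovered = split_chains(flat, num_chains=4)
print(recovered.shape)  # (3, 10, 4)
```

With a real fit this would be something like `split_chains(fit["mu"], num_chains)`, where `"mu"` is a placeholder name. Both `reshape` and `swapaxes` avoid copying whenever NumPy can express the result as a view.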

Just as a second vote: having the functionality to access samples by chain (not necessarily by default) is very useful for diagnosing an important class of computational problems!


Yes, I think a simple reshaping of the output can do the job. That’s what the proposed change in the PR does: rather than always flattening the (num_samples, num_chains) dimensions, it lets the user choose whether to flatten them (the default, to maintain backwards compatibility) or not, via the flatten_chains argument.

There is of course the question of whether this is a sufficiently common use case to add it to pystan or whether it should be handled by user code. My hunch is that it’s worth integrating because it is useful in a number of settings (Rhat, studying different modes, diagnosing sampling issues in different parts of the parameter space, etc.).
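As an example of the Rhat use case, a per-chain array of shape `(num_draws, num_chains)` is exactly the input the classic split-Rhat from Gelman et al., *Bayesian Data Analysis* (3rd ed.), needs. This sketch is that textbook version, not the rank-normalized variant reported by newer tools such as ArviZ, and `split_rhat` is a hypothetical helper name.

```python
import numpy as np


def split_rhat(samples):
    """Classic split-Rhat for one scalar quantity.

    samples: array of shape (num_draws, num_chains).
    """
    num_draws, _ = samples.shape
    half = num_draws // 2
    # Split every chain into two halves so non-stationarity within a chain
    # also inflates the statistic (a trailing odd draw is dropped).
    split = np.concatenate([samples[:half], samples[half:2 * half]], axis=1)
    n = split.shape[0]
    within = split.var(axis=0, ddof=1).mean()        # W: mean within-chain variance
    between = n * split.mean(axis=0).var(ddof=1)     # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n    # pooled variance estimate
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(2)
mixed = rng.normal(size=(1000, 4))              # four chains, same distribution
stuck = mixed + np.array([0.0, 0.0, 5.0, 5.0])  # two chains in a different mode
print(split_rhat(mixed), split_rhat(stuck))
```

Values close to 1 suggest the chains agree; the shifted chains, which mimic two chains stuck in another mode, push the statistic well above 1.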

I do think waiting to see if others have the same problem is a good idea. There are probably a few other things we would want to consider before adding it as well.

For now, perhaps we could document how to accomplish the reshaping in a FAQ item. For scalars and vectors of length > 1, I think there could be a one- or two-line solution. I’d also be open to adding a utility function which could greatly simplify the existing __getitem__ code by giving us some useful array indexes to work with. Such a function might have other uses.
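One possible shape for such a utility function, sketched under stated assumptions: a hypothetical `parameter_slices` helper that turns parameter names and dimensions into row slices of the flat draws array. It assumes each parameter's flattened elements are stored contiguously, in declaration order, after `offset` rows of sample and sampler parameters (such as `lp__`). Returning slices rather than index lists also means `values[s]` is a view, not a copy.

```python
import math


def parameter_slices(param_names, dims, offset=0):
    """Map each parameter name to the slice of flat rows it occupies.

    Hypothetical helper: assumes each parameter contributes prod(dims)
    consecutive rows, starting after `offset` rows of sampler output.
    """
    slices = {}
    start = offset
    for name, dim in zip(param_names, dims):
        size = math.prod(dim)  # math.prod([]) == 1, so a scalar takes one row
        slices[name] = slice(start, start + size)
        start += size
    return slices


slices = parameter_slices(["mu", "theta", "L"], [[], [3], [2, 2]], offset=7)
print(slices["theta"])  # slice(8, 11, None)
```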

A very general guide to reshaping fit._values might also be valuable, perhaps using einops? Getting per-chain means and standard deviations, for example?
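Per-chain summaries reduce to a reduction over the draws axis. This sketch uses plain NumPy and assumes the `(num_params, num_draws, num_chains)` layout quoted above for fit._values. The einops call in the comment should be equivalent if einops is installed, but the sketch does not depend on it, and `per_chain_summary` is a hypothetical name.

```python
import numpy as np


def per_chain_summary(values):
    """Per-chain mean and sample standard deviation of every flat parameter.

    values: array of shape (num_params, num_draws, num_chains).
    Returns two arrays of shape (num_params, num_chains).
    """
    # einops equivalent for the means: einops.reduce(values, "p d c -> p c", "mean")
    means = values.mean(axis=1)        # average over draws
    stds = values.std(axis=1, ddof=1)  # sample std over draws
    return means, stds


values = np.arange(12, dtype=float).reshape(2, 3, 2)  # 2 params, 3 draws, 2 chains
means, stds = per_chain_summary(values)
print(means.tolist())  # [[2.0, 3.0], [8.0, 9.0]]
```

Chains whose means disagree by much more than their standard deviations are a quick hint that they may be exploring different modes.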

Yes, that’s what the changes proposed in the PR implement. The interface is:

def get_samples(self, param: str, flatten_chains: bool = True):
    """
    Get samples for a given parameter.

    Args:
        param: Name of the parameter to get samples for.
        flatten_chains: Whether to combine all chains by flattening them.

    Returns:
        samples: Array of samples with shape `(stan_dimensions, num_chains * num_samples)` if
            `flatten_chains` is truthy and `(stan_dimensions, num_chains, num_samples)` if not.
    """

The call fit.get_samples(parameter_name, flatten_chains=True) is equivalent to fit[parameter_name] and __getitem__ simply calls get_samples internally.

Something we might also want to consider in the process is changing the code so that __getitem__ doesn’t copy the underlying memory. That’s not a problem for small models, but it could become costly for larger models or for many calls to __getitem__.
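Whether `__getitem__` copies depends on how the rows are selected. A small NumPy check, assuming rows are picked either with a list of indexes (advanced indexing) or with a slice (basic indexing); the array here is synthetic.

```python
import numpy as np

values = np.zeros((10, 100, 4))  # (num_flat_params, num_draws, num_chains)

# Advanced (list) indexing always copies the selected rows.
copied = values[[2, 3, 4]]
# Basic slicing of a contiguous block of rows returns a view instead.
viewed = values[2:5]

print(np.shares_memory(copied, values))  # False
print(np.shares_memory(viewed, values))  # True
```

A subsequent `reshape` keeps the view only when NumPy can express the new shape with strides; otherwise it silently falls back to a copy.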

Happy to wait and see other use cases/requirements.

It occurs to me that this could be implemented as a plugin. See the Plugins page of the pystan 3.3.0 documentation.

That is, instead of returning the normal Fit instance, the customized Fit instance (the version in the PR on GitHub) could be returned.

The plugin needs no explicit configuration. You just need to pip install it. That, at least, is the theory.