It is often convenient to be able to get samples grouped by the chain they came from (e.g. to check whether different chains have found different modes of the posterior, or to compute Rhat). @ariddell pointed me to the forums to discuss this feature proposal and see whether there is interest in adding such functionality.
I wasn’t aware of the forum's feature discussion before implementing this, so I have already opened a pull request. From the PR description:
- adds a function `get_samples(param: str, flatten_chains: bool = True) -> np.ndarray` to the `Fit` object, which reproduces the behaviour of `__getitem__` by default. However, setting `flatten_chains=False` leaves the trailing dimension as `(num_saved_samples, num_chains)`. This behaviour is useful for comparing initialisations of different chains, computing Rhat, etc.
- adds a `chain__` column to the data frame produced by `Fit.to_frame`, which can be used to identify which chain generated a given sample.
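To make the proposed `flatten_chains=False` layout concrete, here is a minimal NumPy sketch. It does not call PyStan at all. It assumes a flattened draws array whose trailing axis stores all saved draws of chain 0, then all draws of chain 1, and so on (chain-major order); that ordering is an assumption for illustration, not something stated in the PR description. `unflatten_chains`, `rhat`, and `chain_labels` are hypothetical helpers. `rhat` is the classic non-split Gelman-Rubin statistic, not necessarily the variant Stan reports, and `chain_labels` mimics a 0-based `chain__` column.

```python
import numpy as np


def unflatten_chains(draws: np.ndarray, num_chains: int) -> np.ndarray:
    """Split the trailing draws axis into (num_saved_samples, num_chains).

    Assumption (hypothetical layout): the trailing axis holds all draws of
    chain 0, then all draws of chain 1, and so on.
    """
    *param_shape, total = draws.shape
    if total % num_chains:
        raise ValueError("number of draws is not divisible by num_chains")
    num_draws = total // num_chains
    # C-order reshape gives (..., num_chains, num_draws); swap the last two
    # axes so the trailing dimension is (num_draws, num_chains).
    per_chain = draws.reshape(*param_shape, num_chains, num_draws)
    return np.swapaxes(per_chain, -1, -2)


def rhat(chains: np.ndarray) -> float:
    """Classic (non-split) Gelman-Rubin R-hat for a (num_draws, num_chains) array."""
    n, _ = chains.shape
    within = chains.var(axis=0, ddof=1).mean()       # W: mean within-chain variance
    between = n * chains.mean(axis=0).var(ddof=1)    # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n    # pooled variance estimate
    return float(np.sqrt(var_plus / within))


def chain_labels(num_draws: int, num_chains: int) -> np.ndarray:
    """0-based chain id for each flattened draw, like a `chain__` column."""
    return np.repeat(np.arange(num_chains), num_draws)


# Usage: fake draws for a scalar parameter where chain 3 sits in another mode.
rng = np.random.default_rng(0)
num_draws, num_chains = 1000, 4
per_chain = rng.normal(size=(num_draws, num_chains))
per_chain[:, 3] += 5.0
flat = per_chain.T.reshape(-1)  # chain-major flattening, as assumed above
chains = unflatten_chains(flat, num_chains)
print(rhat(chains))  # well above 1 because chain 3 disagrees with the others
```

With the per-chain layout, a chain stuck in a different mode shows up directly as an R-hat well above 1, which is much harder to spot once all chains are flattened together.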
Keen to hear your thoughts on whether this should be added to pystan.