I've had some thoughts about non-centered parameterisations and how they can prevent the sampler from adapting to the posterior geometry, and I suspect there might be a simple fix for that.
Suppose we have 100 measurements in 10 groups and want to estimate the mean. That's a simple hierarchical model with 10 groups and 10 observations in each group. The parameters are the global mean, the group offsets, and the standard deviation of the group offsets. Let's say the dependent variable is Poisson-distributed, so we don't have to worry about estimating the sd of the observation-level error. This makes 12 parameters in total.
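In formulas:
y[g,i] ~ Poisson(exp(globalMean + groupOffsets[g]))
groupOffsets[g] ~ Normal(0, groupSd)
The non-centered parameterisation used below instead samples groupOffsets_offCenter[g] ~ Normal(0, 1) and sets groupOffsets[g] = groupSd*groupOffsets_offCenter[g].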
Here's some R code to create example data:
library(tidyverse)
library(rstan)
library(tidybayes)

set.seed(2)

# simulation settings
nGroups=10
nWithinGroup=10
globalMean=10
groupSd=.1

# simulate the group offsets and the data
groupOffsets=rnorm(nGroups,0,groupSd)
data=
  crossing(
    groupId=seq_len(nGroups)
    ,withinGroupId=seq_len(nWithinGroup) # one row per observation within each group
  ) %>%
  mutate(
    datapointId=seq_len(n())
    ,linearPredictor=globalMean+groupOffsets[groupId]
    ,y=rpois(n(),exp(linearPredictor))
  )

# plot the simulated data
data %>%
  ggplot()+
  geom_point(
    aes(
      datapointId
      ,y
    )
  )
And the resulting graphic showing the data:
Here's the code of the Stan model:
data {
  int nGroups;
  int nWithinGroup;
  int groupId[nGroups*nWithinGroup];
  int y[nGroups*nWithinGroup];
}
parameters {
  real globalMean;
  real<lower=0> groupSd;
  vector[nGroups] groupOffsets_offCenter;
}
transformed parameters {
  // variable definitions
  real linearPredictor[nGroups*nWithinGroup];
  vector[nGroups] groupOffsets;
  // off-center parameterisation
  groupOffsets=groupOffsets_offCenter*groupSd;
  // linear predictor
  for(currDatapointId in 1:(nGroups*nWithinGroup))
  {
    linearPredictor[currDatapointId]=
      globalMean+
      groupOffsets[groupId[currDatapointId]];
  }
}
model {
  // high level priors
  // use implicit improper uniform priors
  // latent terms
  groupOffsets_offCenter~normal(0,1);
  // likelihood
  y~poisson(exp(linearPredictor));
}
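(Aside: y~poisson_log(linearPredictor); would be an equivalent and numerically safer way to write this likelihood.)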
This code fits the model and makes some graphics of the results:
# fit model
post=
  stan(
    "stanmodel1.stan"
    ,data=
      list(
        nGroups=nGroups
        ,nWithinGroup=nWithinGroup
        ,groupId=data$groupId
        ,y=data$y
      )
  )

# extract some results
postExtractions=
  post %>%
  spread_draws(
    globalMean
    ,`groupOffsets_offCenter[1]`
    ,`groupOffsets[1]`
    ,groupSd
    ,divergent__
    ,treedepth__
    ,lp__
  )

# plot
par(bg="#303030")
postExtractions %>%
  mutate( # add some spread to the discrete variables to reduce overdraw in the plot
    .chain=.chain+runif(n(),-.3,.3)
    ,divergent__=divergent__+runif(n(),-.3,.3)
    ,treedepth__=treedepth__+runif(n(),-.3,.3)
  ) %>%
  # select( # activate this section to plot only the most interesting variables
  #   globalMean
  #   ,`groupOffsets_offCenter[1]`
  #   ,`groupOffsets[1]`
  #   ,groupSd
  # ) %>%
  {pairs(
    .
    ,col=hsv({ # colourize the points according to the estimated standard deviation of the group offsets
      a=(.$groupSd)
      a=a-min(a)
      a=a/max(a)
      a=a/3
    })
    ,pch=46
    ,upper.panel=NULL
  )}
par(bg="white")
Here's the resulting full model diagnostic plot, just for the sake of completeness (yes, I know there are 4 divergences):
And here's a closer look at the most interesting variables:
The samples are colored according to the estimated standard deviation of the group offsets.
Only the first of the group offsets is shown, both before (groupOffsets_offCenter[1]) and after (groupOffsets[1]) the non-centering transform.
The plot shows that in this model there is a strong linear correlation between the estimated mean and the group offsets (the middle-left panel). By itself this should not be a problem for Stan, as it can adapt the mass matrix to take such a correlation into account. However, since I have introduced a non-centered parameterisation for the group offsets, this is not actually the posterior shape that Stan sees internally; what the sampler has to deal with instead is the shape shown in the upper-left panel. The colouring of the samples shows that the correlation between the mean and "groupOffsets_offCenter[1]" depends on the estimate of "groupSd". As far as I'm aware, Stan cannot adapt to such a situation, where the correlation between two parameters changes depending on a third parameter.
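To spell out why this happens: the Poisson counts pin down each group's linear predictor quite precisely, so the posterior concentrates near globalMean + groupSd*groupOffsets_offCenter[g] ≈ constant. In the plane of globalMean vs. groupOffsets_offCenter[g] this is a ridge along a line with slope -groupSd, so the orientation of the ridge (and with it the correlation) rotates as groupSd changes.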
(I suspect this is why this particular example model samples badly and throws divergences. And indeed, the model can be "repaired" by reparameterizing "globalMean" so that it is no longer correlated with the group offsets.)
I think there are many models out there where variables are simply linearly correlated, and where the non-centering transform changes this into a type of correlation that Stan can't adapt to. Conversely, I suspect that the number of cases where non-centering changes the correlation structure into something more favourable for adaptation is much smaller in practice. Granted, I selected this example dataset specifically to make the issue pronounced, but I still think it's a common problem in practice.
This led me to the core idea of this post. Since Stan nowadays has dedicated syntax for declaring non-centering transformations (as multipliers in the variable declarations of the parameters block), it would be possible to change the way these are handled internally.
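For reference, the manual non-centering above could be expressed with that syntax roughly as follows (just a sketch of the changed blocks, using the offset/multiplier syntax available since Stan 2.19):

parameters {
  real globalMean;
  real<lower=0> groupSd;
  // Stan samples groupOffsets/groupSd internally and rescales,
  // equivalent to the manual groupOffsets_offCenter above
  vector<multiplier=groupSd>[nGroups] groupOffsets;
}
model {
  // with the multiplier, the prior is stated directly on the centered scale
  groupOffsets~normal(0,groupSd);
  // ... likelihood as before
}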
My idea was to make it so that the effect the multiplier has on the correlations is respected by the sampler, possibly as a per-parameter opt-in setting declared in the parameters block if there are concerns about applying it globally (e.g. because of backwards compatibility).
I'm not deep enough into the math to be sure how simple an implementation would be, or whether one is possible at all. But I suspect it could be very simple. I had the thought that maybe simply multiplying/dividing the row/column of the mass matrix (except for the diagonal element) that corresponds to the non-centered variable by its multiplier might already be enough. But I could be on the wrong track with this.
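To make that guess concrete, here's a small R sketch (purely illustrative, this is not Stan's internal code) of how a covariance-like inverse metric transforms under a non-centering x = multiplier*z. The exact change-of-variables rule multiplies the corresponding row and column by the multiplier, which hits the diagonal element twice:

# purely illustrative sketch, not Stan internals:
# M is a dense inverse metric (~ posterior covariance) adapted on the
# non-centered scale; parameter k is declared with the given multiplier.
# Under the exact change of variables x[k]=multiplier*z[k], a covariance
# matrix transforms as D%*%M%*%D with D=diag(1,...,multiplier,...,1):
rescaleInverseMetric=function(M,k,multiplier)
{
  M[k,]=M[k,]*multiplier # scale row k
  M[,k]=M[,k]*multiplier # scale column k (the diagonal element gets scaled twice)
  M
}

# numerical check: the rescaled covariance of z exactly matches the covariance of x
set.seed(1)
z=matrix(rnorm(2e5),ncol=2)
z[,2]=.8*z[,1]+sqrt(1-.8^2)*z[,2] # correlated standard normals
x=z
x[,2]=x[,2]*3 # apply a multiplier of 3 to the second parameter
max(abs(cov(x)-rescaleInverseMetric(cov(z),2,3))) # ~ 0 up to floating point

Whether this exact rule (or the "except for the diagonal" variant I guessed above) is the right thing to feed back into the sampler is exactly the part of the math I'm unsure about.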
So my questions are:
Is it possible?
Do you think it is valuable?
edit: I should note that I'm making a few assumptions in this post about how Stan works, and I'm not completely 100 % sure they're correct, just 99 %. If any of them are wrong, I apologise and would welcome corrections.
(Some background info on my motivation to write this post:
I often encounter this issue in the form of correlations between group offsets / Gaussian fields / etc. and observation-level errors. Recently I've been struggling with it in the form of correlations between the two parallel Markov chains in a structural time-series model (one chain representing random fluctuations, one representing a systematic trend). The latter example causes me a lot of headache at the moment, because I need the non-centered parameterisations but couldn't find a reparameterisation that removes all the offending correlations.)