Although the priors are set to match the posterior, sampling from the priors gives invalid predictions

I am trying to understand how to set priors correctly in a shifted log-normal mixed model.

Let’s consider the following model with one fixed predictor and Species as a random intercept for the outcome as well as for the auxiliary parameters (with the correlations between these random effects also estimated):

library(ggplot2)
library(brms)

formula <- brms::brmsformula(
  Sepal.Length ~ Sepal.Width + (1|G|Species),
  ndt ~ 1 + (1|G|Species), 
  sigma ~ 1 + (1|G|Species),
  family = shifted_lognormal()
)

m <- brms::brm(formula, data = iris, refresh = 0, iter = 2000, seed = 33)
pp_check(m, nsamples = 500)

Here are the parameters of that model:

# Print Parameters
p <- as.data.frame(parameters::parameters(m, effects = "all", dispersion = TRUE))
p[c("Parameter", "Median", "MAD", "Rhat")]
                                     Parameter   Median     MAD Rhat
1                                  b_Intercept  1.3e+00 1.8e-01  1.0
2                            b_sigma_Intercept -2.6e+00 2.9e-01  1.0
3                              b_ndt_Intercept -2.8e+06 2.5e+06  1.3
4                                b_Sepal.Width  1.4e-01 1.5e-02  1.0
23                       sd_Species__Intercept  3.0e-01 1.9e-01  1.0
24                   sd_Species__ndt_Intercept  2.0e+00 1.8e+00  1.0
25                 sd_Species__sigma_Intercept  5.1e-01 3.5e-01  1.0
26       cor_Species__Intercept__ndt_Intercept -7.0e-03 5.7e-01  1.0
27     cor_Species__Intercept__sigma_Intercept  4.0e-01 5.4e-01  1.0
28 cor_Species__ndt_Intercept__sigma_Intercept  4.7e-03 6.0e-01  1.0

(Note that one of them has an Rhat > 1.1 and a very high dispersion, but the model seems to make sensible predictions nonetheless.)

Problem: I want to understand how the priors can impact the model, and investigate more informative priors, which is tricky given the relatively abstract nature of that model.

However, I cannot sample from the model’s priors above (to see how the model performs with only priors), as it has flat priors. So I have to define informative priors manually.

So my approach would be to start with a model whose priors are the same as the “working” posteriors found above. From there, I’d investigate how modulating them impacts the model’s predictions. My assumption was that sampling from the priors of a model with very informative priors would behave similarly to sampling from the posterior of a model whose posteriors match those priors.

So I went on to set priors following the results found above:

# To see default priors: get_prior(formula, data = iris)
priors <- brms::validate_prior(c(
  # Intercept -----------------
  set_prior('normal(1.3, 0.18)', class = 'Intercept'), # b_Intercept  
  set_prior('normal(-2.6, 0.29)', class = 'Intercept', dpar = 'sigma'), # b_sigma_Intercept
  set_prior('normal(0, 250000)', class = 'Intercept', dpar = 'ndt'), # b_ndt_Intercept
  
  # SDs -----------------
  set_prior('normal(0.3, 0.19)', class = 'sd', group = c("", "Species")), # sd_Species__Intercept
  set_prior('normal(2, 1.8)', class = 'sd', dpar = 'ndt', group = c("", "Species")), # sd_Species__ndt_Intercept
  set_prior('normal(0.51, 0.35)', class = 'sd', dpar = 'sigma', group = c("", "Species")), # sd_Species__sigma_Intercept
  
  # Fixed -----------------
  set_prior('normal(0.14, 0.015)', class = 'b', coef = "Sepal.Width") # b_Sepal.Width
), formula, data = iris)

priors
                prior     class        coef   group resp  dpar nlpar bound       source
               (flat)         b                                                 default
  normal(0.14, 0.015)         b Sepal.Width                                        user
    normal(1.3, 0.18) Intercept                                                    user
    normal(0, 250000) Intercept                            ndt                     user
   normal(-2.6, 0.29) Intercept                          sigma                     user
 lkj_corr_cholesky(1)         L                                                 default
 lkj_corr_cholesky(1)         L             Species                        (vectorized)
    normal(0.3, 0.19)        sd                                                    user
       normal(2, 1.8)        sd                            ndt                     user
   normal(0.51, 0.35)        sd                          sigma                     user
    normal(0.3, 0.19)        sd             Species                                user
    normal(0.3, 0.19)        sd   Intercept Species                        (vectorized)
       normal(2, 1.8)        sd             Species        ndt                     user
       normal(2, 1.8)        sd   Intercept Species        ndt             (vectorized)
   normal(0.51, 0.35)        sd             Species      sigma                     user
   normal(0.51, 0.35)        sd   Intercept Species      sigma             (vectorized)

The problem is that sampling from the priors of such a model gives either Inf or extremely high values… I tried tinkering with the parameters, but I cannot seem to set them so that the predictions fall within a sensible range…

# Model fitting
m_priors <- brms::brm(formula, data = iris, 
                      prior = priors, sample_prior = "only", 
                      refresh = 0, seed = 33)
pp_check(m_priors, nsamples = 500)

      Estimate Est.Error Q2.5 Q97.5
 [1,]      Inf       NaN  1.9   Inf
 [2,]      Inf       NaN  1.8   Inf
 [3,]      Inf       NaN  1.8   Inf
 [4,]      Inf       NaN  1.8   Inf
 [5,]      Inf       NaN  2.0   Inf
 [6,]      Inf       NaN  2.1   Inf
 [7,]      Inf       NaN  1.9   Inf
 [8,]      Inf       NaN  1.9   Inf
 [9,]      Inf       NaN  1.8   Inf
[10,]      Inf       NaN  1.8   Inf

Any help or suggestions on how to set informative priors for such a model are more than welcome!

Perhaps the marginal posterior distributions for the model parameters fail to capture the joint posterior well (e.g. if there are important covariances or nonlinear relationships among the parameters). Or perhaps normal approximations to the margins are really bad. It might be enlightening to check a pairs plot of the model posterior.

Additionally, it looks like you’ve set a prior of normal(0, 2.5e+05) for a parameter whose posterior has median -2.8e+06 and MAD 2.5e+06. So assuming that the normal approximation is ok-ish, your mean is off by about 3 million and your sd is off by an order of magnitude. However, that’s already getting ahead of things, because you don’t have convergence for that parameter, and given the absurd mean and sd the parameter is clearly unidentified in your model. I would advise dealing with that before proceeding, even if the model seems to be making sensible predictions.
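
One possible contributor to the Inf values can be sketched in base R. It assumes that ndt is modelled on the log scale: as far as I know, brms's shifted_lognormal() family uses a log link for ndt by default (link_ndt = "log") when ndt is predicted, so the shift added to the response would be exp(intercept). Under that assumption, draws from normal(0, 250000) overflow for roughly half of the samples. This is an illustration, not a diagnosis of the fitted model.

```r
# Illustration only: assumes ndt is on the log scale (log link),
# so the shift actually added to the response is exp(intercept).
set.seed(33)

n_draws <- 10000
ndt_intercept <- rnorm(n_draws, mean = 0, sd = 250000)  # prior used in the thread
ndt_shift <- exp(ndt_intercept)                          # back-transform to the response scale

# exp() overflows to Inf above roughly 709 and underflows to 0 below roughly -745
frac_inf <- mean(is.infinite(ndt_shift))
frac_zero <- mean(ndt_shift == 0)
print(c(frac_inf = frac_inf, frac_zero = frac_zero))

# The posterior median reported above, back-transformed, underflows to a shift of 0
posterior_shift <- exp(-2.8e6)
print(posterior_shift)
```

If this reading is right, about half of the prior draws carry an infinite shift, which would be consistent with the Inf estimates and Inf upper quantiles next to finite lower quantiles. It would also mean the fitted intercept of about -2.8e+06 corresponds to a shift of essentially zero, so any very negative value fits equally well, which is one way a parameter ends up unidentified.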

I second that assessment (the rest of the reply as well, but that deals with specific issues of possible lack of convergence). The posterior distribution is a multidimensional distribution, while we (usually) specify unidimensional priors for each parameter, so the higher the dimension, the more covariances there are and the more likely it is that information is missing from the marginal priors.

You can think of it another way: technically it doesn’t matter if the parameter values come from a “prior” or a “posterior”, you are just picking individual sets of parameters from a distribution and putting them through the same model. Up to sampling error, if these sets come from the same distribution they will give the same output.
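
A toy illustration of this point, in base R: the numbers below are made up and are not taken from the iris model. Two sets of draws for an intercept and a slope share exactly the same marginals, but only one set keeps the (hypothetical) negative correlation between them. Pushing both sets through the same linear predictor gives clearly different spreads.

```r
# Toy example with made-up values: intercept a and slope b of a linear
# predictor eta = a + b * x, evaluated at x = 3.
set.seed(33)
n_draws <- 20000
x <- 3

rho <- -0.95                     # hypothetical posterior correlation
z1 <- rnorm(n_draws)
z2 <- rho * z1 + sqrt(1 - rho^2) * rnorm(n_draws)

# "Joint" draws: correlated, as a posterior often is
a_joint <- 1.0 + 0.50 * z1
b_joint <- 0.2 + 0.15 * z2

# "Marginal" draws: same values, but shuffling b breaks the dependence
a_indep <- a_joint
b_indep <- sample(b_joint)

eta_joint <- a_joint + b_joint * x
eta_indep <- a_indep + b_indep * x

# Identical marginals, very different predictive spread
print(c(cor_joint = cor(a_joint, b_joint), cor_indep = cor(a_indep, b_indep)))
print(c(sd_joint = sd(eta_joint), sd_indep = sd(eta_indep)))
print(quantile(exp(eta_joint), c(0.025, 0.5, 0.975)))
print(quantile(exp(eta_indep), c(0.025, 0.5, 0.975)))
```

Shuffling is used instead of drawing fresh normals so that the marginal distributions are exactly the same in both sets; any difference in the output then comes from the dependence alone.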

I see, so essentially, traditional priors over parameters alone are not enough to “mimic”, or rather reproduce, a model. In other words, two models could have the same marginal parameter distributions but different covariances, which would result in different predictions. That’s an interesting thought! Though a bit despairing with regard to my issue :)

Regarding the model above, following your suggestions, I have tried to:

  1. Estimate the weird parameter properly by increasing chains, iterations, adapt_delta and max_treedepth, as well as setting a starting value (following this).
inits <- list()
for(i in 1:5) inits[[i]] <- list(Intercept_ndt = -5)

# adapt_delta and max_treedepth are sampler settings and go in the control list
m <- brms::brm(formula, data = iris, iter = 5000, chains = 5, 
               control = list(adapt_delta = 0.9, max_treedepth = 15),
               inits = inits, refresh = 0, seed = 33)

Unfortunately that doesn’t help and the parameter just keeps getting more extreme:

      Parameter   Median     MAD Rhat
b_ndt_Intercept -2.6e+13 3.8e+13  1.8
  2. Get the “pairs plot of the model posterior”. I think you are referring to:
pairs(m)

I’m not sure what to look for in this plot though… are there any particular or typical patterns that I should be wary of?

  3. Set the priors to this big number… but it won’t allow me to (the SD is too big)…
Compiling Stan program...
|
Semantic error in 'C:/Users/user/AppData/Local/Temp/RtmpWSLGjV/model-45c70fb5cbd.stan', line 98, column 52 to column 62:
   -------------------------------------------------
    96:    target += normal_lpdf(Intercept | 1.3, 0.18);
    97:    target += normal_lpdf(Intercept_sigma | -2.6, 0.29);
    98:    target += normal_lpdf(Intercept_ndt | -260000000, 3800000000);
                                                             ^
    99:    target += normal_lpdf(sd_1[1] | 0.3, 0.19)
   100:      - 1 * normal_lccdf(0 | 0.3, 0.19);
   -------------------------------------------------

Integer literal cannot be larger than 2_147_483_647.
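
A possible workaround for this particular error, as a sketch: Stan reads a number written without a decimal point or exponent as an integer literal, and 3800000000 exceeds the limit quoted in the message. Writing the values as reals (for example 3.8e9) should let the program compile. The helper below is hypothetical (it is not part of brms) and only builds the prior string in base R; the commented line shows how it might be passed to set_prior().

```r
# Hypothetical helper (not part of brms): build a normal() prior string
# whose arguments are written in scientific notation, so that Stan parses
# them as reals rather than as integer literals.
normal_prior_string <- function(mean, sd) {
  sprintf("normal(%s, %s)",
          format(mean, scientific = TRUE),
          format(sd, scientific = TRUE))
}

prior_string <- normal_prior_string(-2.6e8, 3.8e9)
print(prior_string)

# Possible use with brms (assumes the same model formula as above):
# brms::set_prior(prior_string, class = "Intercept", dpar = "ndt")
```

This only addresses the compile error. It does not make the parameter any better identified.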

Well… I’m not sure what options I have left 😅 Thanks though for your input so far!

Yes, that’s the gist of it. The posterior distribution is likely to include correlations between its dimensions, and even when it doesn’t, it may not be well approximated by independent parametric distributions that match the marginal means and variances.

Priors are usually used to directly include previous information you have about the parameters, and sometimes to constrain the problem based on properties of the model or other observations that can be expressed as a simple probability statement. The main point is that, as the name says, this is the best you can do a priori, because you haven’t used any data to inform the parameters and it’s hard to know more than simple things about them beforehand (technically, this doesn’t have to mean information that was available before using the data, only information that is not in the data itself, because once you do inference with the data there is nothing more you can infer from it).
The decision to use more informative priors usually needs to be justified by better prior information, or by results that you realize are unrealistic (for instance, you estimate the density of a species to be more than the number of atoms in the solar system).

You can sample from uniform priors. The problem may be that you get values in the extremes of the distribution and model output that is completely off from the data. That just means that your prior knowledge is not great, and that’s fine; that’s why you need the data.

There’s a conceptual problem with that. You can compare more or less informative priors (if for instance you have a direct point estimate of one or more parameters and include priors with less or more variance), but if you are using the posterior as prior you are already using all information you had from the data, so there’s probably nothing more you can learn from that data – i.e. if you get that posterior with some non-informative priors, you’re only going to get “more of the same” posterior and falsely increase your confidence in those parameter values by using the information from the data more than once by repeatedly making the priors more like the posterior.
So if you are interested in exploring more or less informative priors, I suggest that you do not use any information that comes directly from the data (I guess that’s not necessarily a hard rule, but what you describe is clearly redundant use of data).

Thanks for your thorough reply! I think I fully understand the problem of circularity (not using the result of a model to inform its priors), but that’s not exactly what I had in mind here. I merely wanted to “play around” with the toy example (no real data, no real inference) to get an intuitive feeling for how the priors impact the predictions (as some of the parameters refer to non-straightforward properties of the model). Hence the goal of having a model that at least gives me non-invalid predictions… Hope it clarifies the whys and hows.

That said, another perspective on that would be that, still in the context of the toy example above, even though I have no strong prior information about the parameters, I do have a precise meta-prior that the model should make predictions that fall somewhat within the range of the original Sepal.Length (the response) data. I mean it could even be ± 3 times that range, but at least not Inf as is currently the case. In essence, I’m just trying to fulfil this meta-prior and get a model whose priors lead to a non-nonsensical outcome :)


Coming back to the example, I tried removing the specification for the ndt shift parameter, as it is the one causing trouble for the estimation.

library(ggplot2)
library(brms)

formula <- brms::brmsformula(
  Sepal.Length ~ Sepal.Width + (1|G|Species),
  # ndt ~ 1 + (1|G|Species), 
  sigma ~ 1 + (1|G|Species),
  family = shifted_lognormal()
)

m <- brms::brm(formula, data = iris, refresh = 0, iter = 1000, seed = 33)
pp_check(m, nsamples = 500)

# Print Parameters
p <- as.data.frame(parameters::parameters(m, effects = "all", dispersion = TRUE))
p[c("Parameter", "Median", "MAD", "Rhat")]
                                 Parameter     Median        MAD     Rhat
1                              b_Intercept  0.6361720 0.54560866 1.002123
2                        b_sigma_Intercept -2.1542450 0.34805518 1.002864
3                            b_Sepal.Width  0.2278215 0.06767031 1.001253
16                   sd_Species__Intercept  0.4746815 0.30775069 1.001787
17             sd_Species__sigma_Intercept  0.3515750 0.28001718 1.000480
18 cor_Species__Intercept__sigma_Intercept  0.5471270 0.54522615 1.001268

Estimation went ok. So again, set the priors accordingly:

# To see default priors: get_prior(formula, data = iris)
priors <- brms::validate_prior(c(
  # Intercept -----------------
  set_prior('normal(0.63, 0.54)', class = 'Intercept'), # b_Intercept  
  set_prior('normal(-2.15, 0.35)', class = 'Intercept', dpar = 'sigma'), # b_sigma_Intercept
  # set_prior('normal(0, 250000)', class = 'Intercept', dpar = 'ndt'), # b_ndt_Intercept
  
  # SDs -----------------
  set_prior('normal(0.47, 0.31)', class = 'sd', group = c("", "Species")), # sd_Species__Intercept
  # set_prior('normal(2, 1.8)', class = 'sd', dpar = 'ndt', group = c("", "Species")), # sd_Species__ndt_Intercept
  set_prior('normal(0.35, 0.28)', class = 'sd', dpar = 'sigma', group = c("", "Species")), # sd_Species__sigma_Intercept
  
  # Fixed -----------------
  set_prior('normal(0.23, 0.07)', class = 'b', coef = "Sepal.Width") # b_Sepal.Width
), formula, data = iris)


# Model fitting
m_priors <- brms::brm(formula, data = iris, 
                      prior = priors, sample_prior = "only", 
                      refresh = 0, seed = 33)
pp_check(m_priors, nsamples = 500)

aaand it works! Sure the prediction is not the same as the one of the “real” model, but that could be related to all the caveats mentioned. But at least it gives me some predictions within a meaningful range (sure, some very extreme observations but that’s fine).

From there I can start refining the priors to see intuitively (or visually) what their effect is.
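
One way to check the priors against the window of roughly three times the observed range mentioned earlier, without refitting anything, is to simulate the prior predictive distribution directly. The sketch below is base R only and makes several assumptions that are not in the thread: ndt is drawn from uniform(0, min(y)) (which I believe is the brms default when ndt is not predicted), the Intercept prior is treated as applying at the mean Sepal.Width (brms places it on the centered intercept), the group-level effects are drawn independently (ignoring their correlation), and the window is taken as one third of the minimum to three times the maximum.

```r
set.seed(33)
n_draws <- 2000
y <- iris$Sepal.Length
x_centered <- iris$Sepal.Width - mean(iris$Sepal.Width)

# Draw from a normal truncated below at zero (used for the group-level SDs)
rnorm_positive <- function(n, mean, sd) {
  out <- numeric(0)
  while (length(out) < n) {
    cand <- rnorm(n, mean, sd)
    out <- c(out, cand[cand > 0])
  }
  out[seq_len(n)]
}

# One prior draw per row, using the priors of the reduced model
draws <- data.frame(
  intercept       = rnorm(n_draws, 0.63, 0.54),
  slope           = rnorm(n_draws, 0.23, 0.07),
  sigma_intercept = rnorm(n_draws, -2.15, 0.35),
  sd_species      = rnorm_positive(n_draws, 0.47, 0.31),
  sd_species_sig  = rnorm_positive(n_draws, 0.35, 0.28),
  ndt             = runif(n_draws, 0, min(y))  # assumed default prior on ndt
)

# One simulated flower per prior draw, belonging to a new (unseen) species
x_new <- sample(x_centered, n_draws, replace = TRUE)
mu    <- draws$intercept + draws$slope * x_new +
  rnorm(n_draws, 0, draws$sd_species)
sigma <- exp(draws$sigma_intercept + rnorm(n_draws, 0, draws$sd_species_sig))
y_sim <- draws$ndt + rlnorm(n_draws, meanlog = mu, sdlog = sigma)

# Share of simulated values inside a generous window around the observed range
lower <- min(y) / 3
upper <- max(y) * 3
frac_in_window <- mean(y_sim >= lower & y_sim <= upper)
print(frac_in_window)
print(quantile(y_sim, c(0.025, 0.5, 0.975)))
```

Changing one prior at a time in the draws data frame and re-running the last few lines gives a quick, if rough, picture of how each prior moves the predictions.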

One final issue is that this whole thread was motivated by the difficulty of understanding priors and parameters in the context of somewhat more complex models, and ndt ~ 1 + (1|group) was precisely one of the parameters that I’d be interested in seeing in action.

I wonder why it is particularly hard to estimate it… (and it seems to happen also in other models with other data).

In conclusion, I’m not sure what can be done with the original “full” model, but lowering the complexity of the model seems to address the problem of invalid predictions…

You should try to find something on sensitivity analysis, which looks at uncertainty from a somewhat different angle, starting for instance from a single set of parameters (e.g. a MAP point estimate) and perturbing them one at a time.
It seems like you could accomplish what you want using that kind of approach, instead of trying to set up a (multi-dimensional) prior distribution by hand that will not go off the charts.
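
A minimal sketch of such a one-at-a-time analysis, in base R. It uses the posterior medians of the reduced model as the point estimate. The ndt value is hypothetical (it does not appear in the parameter table), the group-level effects are ignored, and the ± 20% perturbation size is arbitrary.

```r
# Point estimates: posterior medians from the reduced model above;
# ndt is a hypothetical value, since it is not listed in that table.
estimates <- c(
  intercept       = 0.636,
  slope           = 0.228,
  sigma_intercept = -2.154,  # log scale
  ndt             = 2.0      # hypothetical shift
)

# Expected response of a shifted lognormal at a given Sepal.Width
predict_mean <- function(par, sepal_width = 3) {
  mu <- par[["intercept"]] + par[["slope"]] * sepal_width
  sigma <- exp(par[["sigma_intercept"]])
  par[["ndt"]] + exp(mu + sigma^2 / 2)
}

baseline <- predict_mean(estimates)

# Perturb one parameter at a time by +/- 20% of its absolute value,
# keeping all the others at their point estimates
one_at_a_time <- function(par, rel_change = 0.2) {
  rows <- lapply(names(par), function(name) {
    lo <- par
    hi <- par
    lo[[name]] <- par[[name]] - rel_change * abs(par[[name]])
    hi[[name]] <- par[[name]] + rel_change * abs(par[[name]])
    data.frame(parameter = name,
               pred_low  = predict_mean(lo),
               pred_high = predict_mean(hi))
  })
  do.call(rbind, rows)
}

sensitivity <- one_at_a_time(estimates)
sensitivity$range <- sensitivity$pred_high - sensitivity$pred_low
print(baseline)
print(sensitivity[order(-sensitivity$range), ])
```

The range column ranks the parameters by how much the expected Sepal.Length moves when each one is perturbed alone. Interactions between parameters are not captured by this kind of analysis.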