Circular inference model

I’m having some problems with a model that might be interesting to some of you.

We’re trying to re-implement a circular inference model for information integration in Stan.

(See Figure 2, especially panels 2c and 2d, in https://www.nature.com/articles/ncomms14218?WT.feed_name=subjects_neuroscience for intuitions about the model.)

Code + demo data on https://github.com/maltelau/circularinference_standemo

The core of the model is a weighting function F(information, weight) that produces the S-shapes seen in the figures linked above, together with the likelihood (with more intuitive parameter names here):

Choice ~ bernoulli_logit(
    intercept + 
    F(log_perceptual_information + loop_noise, weight_perceptual) +
    F(log_prior_information + loop_noise, weight_prior));

The log_*_information and the Choice are observed, while the loop noise is where the parameters enter:

loop_noise = F(loop_strength_perceptual .* log_perceptual_information, weight_perceptual) +
             F(loop_strength_prior .* log_prior_information, weight_prior);
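
Here F is the weighting function, as defined in the repo:

vector F(vector L, vector w) {
  return log((w .* exp(L) + 1 - w) ./ ((1 - w) .* exp(L) + w));
}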

We’d like to think we can improve on this by adding partial pooling over participants and by including symptom strength in the model (as well as by keeping all the uncertainty we get from using Stan in the first place).

A basic model with full pooling does fine, with few divergences. It has some problems recovering the loop strengths at any useful precision, but I suspect the original model couldn’t really do that either; it’s just hidden because they only reported point estimates.

Adding partial pooling on participants cost us some elegance (we had to do a logit transform; more on this below). It slows the model down a little, but the chains seem healthy, the handful of divergences doesn’t look systematic, and we still get an n_eff of about a third of the total draws.
(CircularInference_stamdemo_nosymptoms.stan)

// inverse-logit transform to the [0,1] interval, then map to the [0.5,1] interval
weight_perceptual = inv_logit(wSelf + wSelfP[Participant] + wSelfSymptoms * Symptoms)/2+.5;

(traceplots are on the GitHub above)

BUT, adding symptom strength totally destroys the sampling
(CircularInference_stamdemo_symptoms.stan)

loop_strength_perceptual = aSelf + aSelfP[Participant] + aSelfSymptoms * Symptoms;
> print(symptom_model, pars = c("fixed_effects", "symptom_effects", "participant_effects", "lp__"))
Warning messages:
1: There were 930 divergent transitions after warmup. Increasing adapt_delta above 0.99 may help. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
2: Examine the pairs() plot to diagnose sampling problems

Inference for Stan model: CircularInference_stamdemo_symptoms.
3 chains, each with iter=1000; warmup=500; thin=1; 
post-warmup draws per chain=500, total post-warmup draws=1500.

                            mean se_mean    sd    2.5%     25%     50%    75%  97.5% n_eff Rhat
fixed_effects[1]            0.37    0.04  0.52   -0.79    0.15    0.39   0.52   1.45   159 1.01
fixed_effects[2]            0.40    0.75  1.06   -1.57   -0.43    0.30   1.55   1.77     2 1.82
fixed_effects[3]            0.08    0.15  0.69   -1.56   -0.29    0.31   0.39   1.44    21 1.05
fixed_effects[4]            1.11    0.03  0.50    0.21    0.85    1.11   1.25   2.33   284 1.01
fixed_effects[5]            1.04    0.20  0.61    0.19    0.70    0.76   1.38   2.49     9 1.11
symptom_effects[1]         -0.04    0.00  0.06   -0.16   -0.07   -0.04  -0.02   0.07   212 1.01
symptom_effects[2]          0.63    0.46  0.58   -0.02    0.18    0.29   1.37   1.53     2 4.49
symptom_effects[3]         -0.51    0.68  0.84   -1.78   -1.65    0.01   0.12   0.28     2 9.44
symptom_effects[4]          0.28    0.20  0.41   -0.04   -0.02    0.15   0.41   1.53     4 1.21
symptom_effects[5]          0.69    0.07  0.57   -0.10    0.28    0.58   1.02   1.96    71 1.04
participant_effects[1,1]   -0.57    0.52  0.71   -1.60   -1.40   -0.33   0.00   0.58     2 2.11
participant_effects[1,2]    1.09    0.11  1.34   -1.04    0.40    1.01   1.44   4.53   137 1.02
participant_effects[1,3]   -0.81    0.16  0.97   -3.21   -1.08   -0.59  -0.47   0.92    35 1.05
participant_effects[1,4]    0.48    0.19  4.20   -0.81   -0.13   -0.01   0.28   3.69   501 1.01
participant_effects[1,5]    0.89    0.20  3.56   -0.91   -0.09    0.40   1.31   4.59   307 1.02
participant_effects[2,1]   -0.02    0.04  0.38   -0.81   -0.17   -0.07   0.13   0.84    98 1.05
participant_effects[2,2]   -0.30    0.10  1.20   -2.48   -0.62   -0.47   0.00   2.43   137 1.01
participant_effects[2,3]    0.09    0.30  1.03   -2.28   -0.45    0.27   0.65   1.97    12 1.08
participant_effects[2,4]    0.71    0.15  2.89   -0.89   -0.05    0.39   1.08   2.76   349 1.01
participant_effects[2,5]    0.46    0.29  4.32   -0.88   -0.42   -0.20   0.21   4.29   230 1.02
participant_effects[3,1]    0.08    0.19  0.42   -0.99   -0.09    0.13   0.37   0.67     5 1.19
participant_effects[3,2]   -0.10    0.68  1.32   -3.39   -0.79    0.04   0.66   1.51     4 1.32
participant_effects[3,3]    1.28    0.21  0.95   -0.03    0.76    0.97   1.69   3.63    21 1.09
participant_effects[3,4]    0.18    0.56  3.83   -0.93   -0.75   -0.15   0.29   3.85    47 1.02
participant_effects[3,5]    0.59    0.29  3.44   -0.84   -0.10    0.00   0.31   4.37   137 1.03
participant_effects[4,1]    0.11    0.02  0.34   -0.55   -0.04    0.04   0.23   0.98   308 1.03
participant_effects[4,2]   -0.98    0.05  0.77   -2.99   -1.14   -1.06  -0.55   0.32   207 1.02
participant_effects[4,3]   -0.85    0.48  0.86   -2.18   -1.58   -0.92  -0.16   0.74     3 1.33
participant_effects[4,4]    0.02    0.64  2.00   -1.18   -0.90   -0.18   0.23   3.73    10 1.07
participant_effects[4,5]    0.83    0.31  4.90   -0.78   -0.02    0.29   0.40   5.29   245 1.02
participant_effects[5,1]    0.62    0.60  0.79   -0.54   -0.01    0.28   1.54   1.76     2 2.58
participant_effects[5,2]   -0.16    0.62  1.01   -1.53   -1.05   -0.13   0.52   1.81     3 1.48
participant_effects[5,3]    0.22    0.46  0.82   -1.49   -0.34    0.34   0.74   1.53     3 1.32
participant_effects[5,4]    0.37    0.05  0.98   -0.80   -0.07    0.17   0.51   2.93   339 1.00
participant_effects[5,5]    0.85    0.23  3.88   -0.85   -0.08    0.38   0.98   3.92   276 1.02
lp__                     -122.16   31.82 39.63 -184.49 -172.99 -101.43 -91.51 -77.61     2 5.84

Samples were drawn using NUTS(diag_e) at Mon Oct 30 23:32:58 2017.

Things I tried that didn’t help:

  • Letting it run for 10000 iterations with warmup=8000
  • Fiddling with the priors between ~ normal(0, 0.1) and normal(0, 10)
  • z-scaling symptom data
  • Participant pooling on the symptom parameter. This doesn’t really make sense though as each participant has just one value for the symptom strength.
  • Fiddling with max_treedepth and adapt_delta

Ideas I still vaguely feel I haven’t explored enough:

  • Put very tight priors on the loop-strength parameters, due to them flattening out so much when far from zero
  • Do away with the logit transform on the w parameters. This is all to work around the specific boundaries the function expects. We didn’t need it for the no-symptom model, since we could do participant_effects ~ multi_normal(main_effects, sigma) and just put bounds on the main-effect and participant-effect parameters. I asked last month about working out those bounds, but now I can’t see how to do the same for three interdependent parameters (i.e. wSelf, wSelfP, wSelfSymptoms). A schematic sketch of that no-symptom structure follows this list.
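
Schematically, the no-symptom pooling structure looks something like this (a simplified sketch: the priors are placeholders, the likelihood is omitted, and in the actual model the main-effect and participant-effect declarations also carry bounds):

data {
  int<lower=1> P;  // number of participants
  int<lower=1> K;  // number of effects (5 here: bias, two weights, two loop strengths)
}
parameters {
  vector[K] main_effects;              // bounded in the actual model
  vector<lower=0>[K] tau;              // between-participant scales
  corr_matrix[K] Omega;
  vector[K] participant_effects[P];    // bounded in the actual model
}
model {
  tau ~ normal(0, 1);
  Omega ~ lkj_corr(2);
  for (p in 1:P)
    participant_effects[p] ~ multi_normal(main_effects,
                                          quad_form_diag(Omega, tau));
}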

So, in summary: we have five “fixed effects” parameters (bias, two weights, two loop strengths). We’d like to put GLM-style participant and symptom effects on each of them. Adding symptom strength ruined the sampling, so that may be the issue, or perhaps that in combination with other quirks of the model.

Thoughts?

While I can’t give good advice for this specific model, I don’t think you have exhausted some standard tricks that have worked for me in the past. Sorry if this is something you already know or have tried.

In particular, you should try simulating data from the exact model you have written in Stan and checking whether you can recover the (known) parameters, in the sense that the true parameters fall within the x% interval x% of the time, for multiple values of x. Since you had some divergences even before adding symptom strength, the model was likely already broken without symptom strength; it just didn’t manifest as strongly.
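
For instance, here’s a rough sketch of such a simulation program (untested; names follow your repo, the “true” values are passed in as data, and it would be run with algorithm = "Fixed_param" since it has no parameters):

functions {
  vector F(vector L, vector w) {
    return log((w .* exp(L) + 1 - w) ./ ((1 - w) .* exp(L) + w));
  }
}
data {
  int<lower=1> N;
  vector[N] lSelf;                   // log perceptual information
  vector[N] lOthers;                 // log prior information
  real a;                            // "true" intercept
  vector<lower=0.5, upper=1>[N] ws;  // "true" perceptual weights
  vector<lower=0.5, upper=1>[N] wo;  // "true" prior weights
  vector<lower=0>[N] as;             // "true" perceptual loop strengths
  vector<lower=0>[N] ao;             // "true" prior loop strengths
}
generated quantities {
  int Choice[N];
  {
    // same structure as the likelihood in the original post
    vector[N] I = F(as .* lSelf, ws) + F(ao .* lOthers, wo);
    vector[N] mu = a + F(lSelf + I, ws) + F(lOthers + I, wo);
    for (n in 1:N)
      Choice[n] = bernoulli_rng(inv_logit(mu[n]));
  }
}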

With simulated data you know the exact true values of the parameters, so you can also try putting tight priors centered at the true values on all parameters and then gradually loosening them to see when the model starts misbehaving.

Also, you should try the mcmc_parcoord function from the bayesplot package; it can be more useful for diagnosis than the pairs plot, since it can show you in which parts of the parameter space the divergences tend to occur.


Thanks for the suggestions - I’ve been chipping away at this and seem to have arrived at the deeper issue:

  • I somewhat fixed the bounds, which means I could get rid of the inv_logit transform on w. Chains seem a lot better (Rhat <= 1.01).
  • Tried mcmc_parcoord, which seems really nice, but in my case the divergences are spread almost randomly over the parameter space.
  • Ran one simulation on the (now bounded and logit-less) symptom model. It recovered the weights (including the symptom effect) okay-ish, but still has problems with the loop strengths. Plots of the results are in simulation/.

This is very similar to simulations I ran a while ago, before trying to add symptom strength (they recovered w but didn’t move a at all). So I guess that now that I’ve fixed the most obviously broken thing (the logit transform on the weights), the underlying issue is visible again: the loop-strength parameters aren’t being affected by the data at all (don’t ask me why I didn’t make a plot like this before).

In effect, the goal of the model is to condition (a non-standard specification of) uncertainty/bias on symptom strength. Are there any obvious reasons why that isn’t doing anything in my model? (stan/CircularInference_stamdemo_symptoms_bounds.stan on the GitHub above)

What did you mean when you wrote that the bounds aren’t perfect? Is there a way to express the model so that every legal set of parameters satisfying the declared constraints has finite log density?

You can make your blocks much simpler with some of our new constructs.

transformed parameters {
    vector[N_params] fixed_effects = [a, wSelf, wOthers, aSelf, aOthers]';

    vector[N_params] symptom_effects = [aSymptoms, wSelfSymptoms, ... ]';

    vector[N] intercept = a + aP[Participant] + aSymptoms * Symptoms;

    vector<lower=0.5, upper=1>[N] ws = 0.5 + 0.5 * inv_logit(wSelf + wSelfP[Participant] + wSelfSymptoms * Symptoms);

    vector<lower=0.5, upper=1>[N] wo = 0.5 + 0.5 * inv_logit(wOthers + wOthersP[Participant] + wOthersSymptoms * Symptoms);

    vector<lower=0>[N] as = aSelf   + aSelfP[Participant]   + aSelfSymptoms * Symptoms;

    vector<lower=0>[N] ao = aOthers + aOthersP[Participant] + aOthersSymptoms * Symptoms;

    vector[N] I = F(ao .* lOthers, wo) + F(as .* lSelf, ws); 

    vector[N] mu = intercept + F(lSelf + I, ws) +  F(lOthers + I, wo);

    vector[N_params] participant_effects[P];
    participant_effects[, 1] = aP;
    ...
    participant_effects[, 5] = aOthersP;
}

I’d be inclined to rework the function F to take its weight argument w before the 0.5 + 0.5 * inv_logit(x) transform, then push the log through and convert the internal sums into log_sum_exp. It’ll be more stable that way. But if stability isn’t an issue, it’s not worth the effort, as long as it works as is.
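
Something like this, for instance (an untested sketch; here v is the weight on the unconstrained scale, so w = 0.5 + 0.5 * inv_logit(v) = (0.5 + exp(v)) / (1 + exp(v))):

functions {
  // same F, but taking the weight before the inv_logit transform and
  // staying on the log scale throughout
  vector F_stable(vector L, vector v) {
    int N = num_elements(L);
    vector[N] out;
    for (n in 1:N) {
      // log(w) and log(1 - w) for w = (0.5 + exp(v)) / (1 + exp(v))
      real log_w = log_sum_exp(log(0.5), v[n]) - log1p_exp(v[n]);
      real log_1mw = log(0.5) - log1p_exp(v[n]);
      // F = log(w * exp(L) + (1 - w)) - log((1 - w) * exp(L) + w)
      out[n] = log_sum_exp(log_w + L[n], log_1mw)
               - log_sum_exp(log_1mw + L[n], log_w);
    }
    return out;
  }
}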

recovered w, but didn’t move a at all

Do you have control over L here:

vector F(vector L, vector w) {
  return log((w .* exp(L) + 1 - w) ./ ((1 - w) .* exp(L) + w));
}

? A parameter not moving at all is super fishy.

Maybe add a generated quantities block and record the maximum values of L that are being fed through F? Compare this with and without the symptom stuff? exp(x) gets big really quickly.
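
Something like this, say (a sketch; it reuses I, as, ao and the data lSelf, lOthers from the transformed parameters block above):

generated quantities {
  // largest inputs to F on this draw; exp() of these appears inside F,
  // and exp() overflows somewhere around 700
  real max_L_self = max(lSelf + I);
  real max_L_others = max(lOthers + I);
  real max_L_loops = max(append_row(as .* lSelf, ao .* lOthers));
}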