Marginalizing over unobserved discrete value fails to converge

I am trying to model an experiment where the treatment depends on the roll of a die and an unknown “condition”. The outcome is whether the person recovers or not and depends on the treatment and condition.

In the data-generating Python script, I assume that recovery depends only on condition, not on treatment, and I want Stan to arrive at that conclusion too.

The estimates for p_recover_given_condition_and_treatment don't match the values used to generate the data. How can I fix this model?

from random import randint, random
from operator import itemgetter
import json

p_treatment2_given_condition_and_dieroll = {  # keyed by (condition, dieroll)
    (1,1): .2,
    (1,2): .6,
    (2,1): .4,
    (2,2): .8,
}

p_recover_given_condition_and_treatment = {  # keyed by (condition, treatment)
    (1,1): .3,
    (1,2): .5,
    (2,1): .7,
    (2,2): .9,
}

N = 10000
D = 2
C = 2
T = 2

def make_observation():
    dieroll = randint(1,D)    # observed die roll
    condition = randint(1,C)  # latent condition (not exported to Stan)
    p_treatment2 = p_treatment2_given_condition_and_dieroll[condition,dieroll]
    treatment = (random()<p_treatment2)+1  # treatment is 1 or 2
    p_recover = p_recover_given_condition_and_treatment[condition,treatment]
    recover = int(random()<p_recover)
    return locals()           # one observation as a dict of all locals

obs = [make_observation() for _ in range(N)]
d = dict(
    D=D,C=C,T=T,N=N,
    dieroll=list(map(itemgetter("dieroll"), obs)),
    treatment=list(map(itemgetter("treatment"), obs)),
    recover=list(map(itemgetter("recover"), obs)),
)
print(json.dumps(d, indent=2))
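
As a quick sanity check (not needed for the model; it assumes the script above has just run, so obs and the probability tables are in scope), the empirical recovery rates can be tabulated and compared with the generating table:

from collections import Counter

# Tabulate empirical recovery rates by (condition, treatment) and compare them
# against p_recover_given_condition_and_treatment.
counts = Counter()
recovered = Counter()
for o in obs:
    key = (o["condition"], o["treatment"])
    counts[key] += 1
    recovered[key] += o["recover"]

for key in sorted(counts):
    print(key, round(recovered[key] / counts[key], 3),
          p_recover_given_condition_and_treatment[key])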

Marginalizing over the unobserved condition c (writing d for dieroll, t for treatment, r for recover):

p(d, t, r) = \sum_{c} p(r \mid c, t)\, p(t \mid c, d)\, p(c)\, p(d) \\
\ell(d, t, r) = \text{logsumexp}_{c} \left[ \log p(r \mid c, t) + \log p(t \mid c, d) + \log p(c) + \log p(d) \right]
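
Just to make the formula concrete (this is only an illustration, not the Stan model; it reuses the true generating tables and the uniform p(c) = 1/C, p(d) = 1/D from the Python script above), the log-likelihood of a single observation would be:

from math import log, exp

def logsumexp(xs):
    m = max(xs)
    return m + log(sum(exp(x - m) for x in xs))

# Illustration only: evaluate ell(d, t, r) for one observation using the true
# generating tables; p(c) and p(d) are uniform in the script (randint).
def loglik(d, t, r):
    terms = []
    for c in (1, 2):
        p_t2 = p_treatment2_given_condition_and_dieroll[c, d]
        p_t = p_t2 if t == 2 else 1.0 - p_t2            # p(t | c, d)
        p_r1 = p_recover_given_condition_and_treatment[c, t]
        p_r = p_r1 if r == 1 else 1.0 - p_r1            # p(r | c, t)
        terms.append(log(p_r) + log(p_t) + log(1.0 / C) + log(1.0 / D))
    return logsumexp(terms)

print(loglik(d=1, t=2, r=1))

The Stan model: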
data {
  int<lower=0> N;
  int<lower=1> D;
  int<lower=1> T;
  int<lower=1> C;
  int<lower=1,upper=D> dieroll[N];
  int<lower=1,upper=T> treatment[N];
  int<lower=0,upper=1> recover[N];
}
parameters {
  simplex[D] p_dieroll;
  simplex[T] p_treatment_given_condition_and_dieroll[C,D];
  simplex[C] p_condition;
  real<lower=0,upper=1> p_recover_given_condition_and_treatment[C,T];
}
model {
  for (n in 1:N) {
    dieroll[n] ~ categorical(p_dieroll);
    vector[C] lp;
    for (condition in 1:C) {
      lp[condition] = categorical_lpmf(condition | p_condition);
      treatment[n] ~ categorical(p_treatment_given_condition_and_dieroll[condition,dieroll[n]]);
      recover[n] ~ bernoulli(p_recover_given_condition_and_treatment[condition,treatment[n]]);
    }
    target += log_sum_exp(lp);
  }
}

% stansummary train*.csv | grep -v 'p_treatment_given.*1\]'
Inference for Stan model: my2_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (64, 65, 67, 62) seconds, 4.3 minutes total
Sampling took (55, 50, 55, 55) seconds, 3.6 minutes total

                                                  Mean     MCSE   StdDev     5%    50%    95%    N_Eff  N_Eff/s    R_hat

lp__                                            -31799  5.9e-02  2.3e+00 -31803 -31798 -31796     1522      7.1      1.0
accept_stat__                                     0.89  1.0e-02     0.12   0.65   0.93    1.0  1.4e+02  6.3e-01  1.0e+00
stepsize__                                        0.67  5.4e-02    0.076   0.61   0.65   0.79  2.0e+00  9.3e-03  2.7e+13
treedepth__                                        2.9  5.5e-02     0.35    2.0    3.0    3.0  4.0e+01  1.9e-01  1.0e+00
n_leapfrog__                                       6.6  1.8e-01      1.2    3.0    7.0    7.0  4.2e+01  1.9e-01  1.0e+00
divergent__                                       0.00      nan     0.00   0.00   0.00   0.00      nan      nan      nan
energy__                                         31804  8.3e-02      3.2  31799  31803  31809  1.5e+03  6.9e+00  1.0e+00

p_dieroll[1]                                      0.51  5.8e-05  5.1e-03   0.50   0.51   0.52     7708       36     1.00
p_dieroll[2]                                      0.49  5.8e-05  5.1e-03   0.48   0.49   0.50     7708       36     1.00
p_treatment_given_condition_and_dieroll[1,1,2]    0.29  7.0e-05  6.5e-03   0.28   0.29   0.30     8569       40     1.00
p_treatment_given_condition_and_dieroll[1,2,2]    0.70  7.1e-05  6.6e-03   0.69   0.70   0.71     8437       39     1.00
p_treatment_given_condition_and_dieroll[2,1,2]    0.29  6.8e-05  6.2e-03   0.28   0.29   0.30     8308       39     1.00
p_treatment_given_condition_and_dieroll[2,2,2]    0.70  7.8e-05  6.5e-03   0.69   0.70   0.71     6885       32     1.00
p_condition[1]                                    0.50  3.2e-03  2.9e-01  0.045   0.50   0.95     8287       39     1.00
p_condition[2]                                    0.50  3.2e-03  2.9e-01  0.049   0.50   0.95     8287       39     1.00
p_recover_given_condition_and_treatment[1,1]      0.47  7.8e-05  7.1e-03   0.46   0.47   0.48     8426       39     1.00
p_recover_given_condition_and_treatment[1,2]      0.73  7.2e-05  6.2e-03   0.72   0.74   0.74     7247       34     1.00
p_recover_given_condition_and_treatment[2,1]      0.47  8.2e-05  7.1e-03   0.46   0.47   0.48     7444       35     1.00
p_recover_given_condition_and_treatment[2,2]      0.73  6.8e-05  6.3e-03   0.72   0.73   0.75     8660       40     1.00

If I’m reading your model/description correctly, there is an unobserved ‘condition’, and this influences the probability of recovery for each treatment?

So, given J treatments and K conditions, something like:

cond \sim \text{categorical}(\pi) \\
P_{recover} = \lambda_{treat_{[cond]}} \\
recover \sim \text{Bern}(P_{recover})

Is that right?

Or is it that the probability of which treatment is received depends on the condition? Like:

condition \sim \text{categorical}(\pi) \\
treat \sim \text{categorical}(\gamma_{cond}) \\
P_{recover} = \lambda_{treat_{j}} \\
recover \sim \text{Bern}(P_{recover})

It’s both of those:

  • Your treatment depends on your (unobserved) condition and the (observed) roll of a die
  • Your recovery depends on your (unobserved) condition and (observed) treatment

Alright, in that case I believe you have a model like so:

treat \sim \text{categorical}(\alpha_{[k]} + roll \cdot \beta_{roll}) \\
recovery \sim \text{bernoulli}(\gamma_{j[k]})

So the probability of being assigned to a treatment has an intercept that differs between the conditions, plus an effect of the dice roll (with the effect of the dice roll on treatment assignment separate from the condition).

Then, each treatment has a different probability of recovery, and within each treatment that probability also differs between conditions.

As some pseudo-Stan with your data, this kind of model would look like:


data {
  int<lower=0> N;
  int<lower=0> D; // Dice outcomes
  int<lower=0> J; // Treatments
  int<lower=0> K; // Conditions
  int rolls[N];
  int treatment[N];
  int recovery[N];
}

parameters {
  // Condition probability
  simplex[K] p_condition;

  // Condition-specific treatment probability means
  vector[K] p_treat_mu;
  // Treatment probabilities for each individual
  vector[J] p_treat[N];

  // Condition-specific recovery probability means
  vector[K] treat_beta_mu;

  // Treatment-specific recovery probability means
  vector[J] treat_beta;

  // Coefficient for effect of dice roll
  real roll_beta;
}

model {
  vector[K] log_p_condition = log(p_condition);

  p_treat_mu ~ std_normal();
  treat_beta_mu ~ std_normal();
  roll_beta ~ std_normal();

  for(n in 1:N) {
    // Intercept treatment probability for each individual comes from a N(treat_mu_k, 5) distribution,
    //   with condition probabilities defined by p_condition
    vector[K] lp;
    for(k in 1:K) {
      lp[k] = log_p_condition[k] + normal_lpdf(p_treat[n] | p_treat_mu[k], 5.0);
    }
    target += log_sum_exp(lp);

    // Treatment log-probability given the condition-specific mean as well as the
    //   dice-roll effect
    treatment[n] ~ categorical_logit(p_treat[n] + rolls[n] * roll_beta);
  }

  for(j in 1:J) {
    // Intercept recovery probability for each treatment comes from a N(beta_mu_k, 5),
    //  again with probability p_condition
    vector[K] lp;
    for(k in 1:K) {
      lp[k] = log_p_condition[k] + normal_lpdf(treat_beta[j] | treat_beta_mu[k], 5.0);
    }
    target += log_sum_exp(lp);
  }

  // Probability of recovery for each treatment now differs by condition
  recovery ~ bernoulli_logit(treat_beta[treatment]);
}

Note that this example shows the conditions differing only in the means of their probabilities, not in the variances, and that I’ve treated the effect of the dice roll as continuous (for simplicity). It also more than likely has issues with identifiability, so this is more just a rough outline of how your model could look.
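
If it helps, your generating script's output could be mapped onto this data block along these lines (just a sketch, reusing obs, N, D, T and C from that script; field names follow the data block above):

# Hypothetical data prep for the pseudo-Stan model above.
stan_data = dict(
    N=N, D=D, J=T, K=C,
    rolls=[o["dieroll"] for o in obs],
    treatment=[o["treatment"] for o in obs],
    recovery=[o["recover"] for o in obs],
)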


Thanks for posting that. I’m still working on understanding and using your example.


In the meantime I’ve made some progress on my original model and updated the top post with what I’ve learned so far. The sampler is now converging but I think it’s giving the wrong answer with high precision. The main thing I’m uncertain about in my model is the line

lp[condition] = categorical_lpmf(condition | p_condition);

The other parts seem straightforward and hard to mess up.
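
From the docs my understanding is that categorical_lpmf(k | theta) is just log(theta[k]), so that line seeds lp[condition] with the log prior of that condition. A tiny Python sketch (made-up simplex values) of what each lp entry holds right after that line:

from math import log

# Sketch with a hypothetical simplex: categorical_lpmf(k | theta) in Stan
# evaluates to log(theta[k]), so each lp entry starts out as that condition's
# log prior and nothing else.
p_condition = [0.5, 0.5]       # hypothetical value of the simplex parameter
lp = [log(p) for p in p_condition]
print(lp)                      # two copies of log(0.5)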