Marginalizing over unobserved category

This is an attempt to express my toy model of condition, treatment, and recovery more clearly.

  • Treatment depends on (unobserved) condition and the (observed) roll of a die
  • Recovery depends on (unobserved) condition and (observed) treatment
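
In other words, the joint distribution of a single observation factors as

p(dieroll) · p(condition) · p(treatment | condition, dieroll) · p(recover | condition, treatment),

with dieroll and condition drawn independently.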

To simplify, first assume all values are observed. The model is:

data {
  int<lower=0> N;                  // number of observations
  int<lower=1> D;                  // number of die faces
  int<lower=1> T;                  // number of treatments
  int<lower=1> C;                  // number of conditions
  int<lower=1,upper=C> condition[N];
  int<lower=1,upper=D> dieroll[N];
  int<lower=1,upper=T> treatment[N];
  int<lower=0,upper=1> recover[N];
}
parameters {
  simplex[D] p_dieroll;
  simplex[T] p_treatment_given_condition_and_dieroll[C,D];
  simplex[C] p_condition;
  real<lower=0,upper=1> p_recover_given_condition_and_treatment[C,T];
}
model {
  for (n in 1:N) {
    dieroll[n] ~ categorical(p_dieroll);
    condition[n] ~ categorical(p_condition);
    treatment[n] ~ categorical(p_treatment_given_condition_and_dieroll[condition[n],dieroll[n]]);
    recover[n] ~ bernoulli(p_recover_given_condition_and_treatment[condition[n],treatment[n]]);
  }
}
Generate example data
from random import randint, random
from operator import itemgetter
import json

# (condition, dieroll) -> P(treatment = 2); P(treatment = 1) is the complement
p_treatment2_given_condition_and_dieroll = {
    (1,1): .2,
    (1,2): .4,
    (2,1): .6,
    (2,2): .8,
}

# (condition, treatment) -> P(recover = 1)
p_recover_given_condition_and_treatment = {
    (1,1): .3,
    (1,2): .5,
    (2,1): .7,
    (2,2): .9,
}

N = 100  # observations
D = 2    # die faces
C = 2    # conditions
T = 2    # treatments


def make_observation():
    dieroll = randint(1,D)
    condition = randint(1,C)
    p_treatment2 = p_treatment2_given_condition_and_dieroll[condition,dieroll]
    treatment = int(random()<p_treatment2)+1  # 1 or 2
    p_recover = p_recover_given_condition_and_treatment[condition,treatment]
    recover = int(random()<p_recover)
    return locals()  # dict of all locals; the relevant keys are picked out below


obs = [make_observation() for _ in range(N)]
d = dict(
    D=D,C=C,T=T,N=N,
    condition=list(map(itemgetter("condition"), obs)),
    dieroll=list(map(itemgetter("dieroll"), obs)),
    treatment=list(map(itemgetter("treatment"), obs)),
    recover=list(map(itemgetter("recover"), obs)),
)
print(json.dumps(d, indent=2))
Sample and summarize
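The exact commands aren't shown here; assuming CmdStan, with the generated JSON saved as train.json and the compiled model named observed_model (both names hypothetical), the four chains could be run along these lines:

% for i in 1 2 3 4; do ./observed_model sample data file=train.json output file=train$i.csv; done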
% stansummary train*.csv | grep -v 'p_treatment_given.*1\]'
Inference for Stan model: observed_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (24, 23, 24, 24) seconds, 1.6 minutes total
Sampling took (25, 24, 24, 25) seconds, 1.6 minutes total

                                                  Mean     MCSE   StdDev     5%    50%    95%    N_Eff  N_Eff/s    R_hat

lp__                                            -25049  5.3e-02  2.2e+00 -25053 -25048 -25046     1760       18     1.00
accept_stat__                                     0.89  1.7e-03     0.11   0.67   0.93    1.0  4.6e+03  4.7e+01  1.0e+00
stepsize__                                        0.68  2.9e-02    0.041   0.62   0.72   0.72  2.0e+00  2.1e-02  1.4e+13
treedepth__                                        2.9  6.5e-03     0.35    2.0    3.0    3.0  3.0e+03  3.0e+01  1.0e+00
n_leapfrog__                                       6.6  2.3e-02      1.2    3.0    7.0    7.0  2.5e+03  2.6e+01  1.0e+00
divergent__                                       0.00      nan     0.00   0.00   0.00   0.00      nan      nan      nan
energy__                                         25054  7.8e-02      3.1  25049  25053  25059  1.6e+03  1.6e+01  1.0e+00

p_dieroll[1]                                      0.50  5.1e-05  5.0e-03   0.49   0.50   0.51     9518       98     1.00
p_dieroll[2]                                      0.50  5.1e-05  5.0e-03   0.49   0.50   0.51     9518       98     1.00
p_treatment_given_condition_and_dieroll[1,1,2]    0.20  8.9e-05  8.0e-03   0.19   0.20   0.22     8129       84     1.00
p_treatment_given_condition_and_dieroll[1,2,2]    0.40  1.0e-04  9.8e-03   0.38   0.40   0.42     8746       90     1.00
p_treatment_given_condition_and_dieroll[2,1,2]    0.61  1.0e-04  9.8e-03   0.59   0.61   0.63     8785       90     1.00
p_treatment_given_condition_and_dieroll[2,2,2]    0.80  9.3e-05  7.9e-03   0.78   0.80   0.81     7291       75     1.00
p_condition[1]                                    0.50  5.5e-05  5.0e-03   0.49   0.50   0.51     8335       86     1.00
p_condition[2]                                    0.50  5.5e-05  5.0e-03   0.49   0.50   0.51     8335       86     1.00
p_recover_given_condition_and_treatment[1,1]      0.30  8.4e-05  7.7e-03   0.28   0.30   0.31     8511       87     1.00
p_recover_given_condition_and_treatment[1,2]      0.51  1.4e-04  1.3e-02   0.49   0.51   0.53     8075       83     1.00
p_recover_given_condition_and_treatment[2,1]      0.70  1.4e-04  1.2e-02   0.68   0.70   0.72     7612       78     1.00
p_recover_given_condition_and_treatment[2,2]      0.89  5.8e-05  5.4e-03   0.88   0.89   0.90     8686       89     1.00

The results show that all of the true parameter values are recovered correctly.

Now, moving to the main model: condition is unobserved, so I want to marginalize over it. I'm not sure I did this correctly. The line lp[c] = categorical_lpmf(c | p_condition) + log_p_condition[c]; seems especially suspicious.

Regardless, the results show that the parameter estimates are incorrect and exclude the true values, even though there are 0 divergences and R_hat is 1.0. For instance, the true value of p_recover_given_condition_and_treatment[2,2] is 0.9, but its estimate is 0.72-0.75 (5%-95% interval).

Is the model specified correctly? Why is it giving the wrong answer? How can I fix it?

data {
  int<lower=0> N;
  int<lower=1> D;
  int<lower=1> T;
  int<lower=1> C;
  int<lower=1,upper=D> dieroll[N];
  int<lower=1,upper=T> treatment[N];
  int<lower=0,upper=1> recover[N];
}
parameters {
  simplex[D] p_dieroll;
  simplex[T] p_treatment_given_condition_and_dieroll[C,D];
  simplex[C] p_condition;
  real<lower=0,upper=1> p_recover_given_condition_and_treatment[C,T];
}
model {
  vector[C] log_p_condition = log(p_condition);
  for (n in 1:N) {
    dieroll[n] ~ categorical(p_dieroll);
    vector[C] lp;
    for (c in 1:C) {
      lp[c] = categorical_lpmf(c | p_condition) + log_p_condition[c]; // XXX Unsure about this.
      treatment[n] ~ categorical(p_treatment_given_condition_and_dieroll[c,dieroll[n]]);
      recover[n] ~ bernoulli(p_recover_given_condition_and_treatment[c,treatment[n]]);
    }
    target += log_sum_exp(lp);
  }
}
% stansummary train*.csv | grep -v 'p_treatment_given.*1\]'
Inference for Stan model: my2_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (66, 71, 67, 74) seconds, 4.6 minutes total
Sampling took (58, 59, 57, 58) seconds, 3.9 minutes total

                                                    Mean     MCSE   StdDev       5%      50%      95%    N_Eff  N_Eff/s    R_hat

lp__                                            -3.2e+04  5.4e-02  2.2e+00 -3.2e+04 -3.2e+04 -3.2e+04     1678      7.2      1.0
accept_stat__                                       0.88  7.3e-03     0.14     0.62     0.93      1.0  3.9e+02  1.7e+00  1.0e+00
stepsize__                                          0.60  2.3e-02    0.032     0.55     0.61     0.63  2.0e+00  8.6e-03  1.2e+13
treedepth__                                          2.9  5.2e-03     0.27      2.0      3.0      3.0  2.8e+03  1.2e+01  1.0e+00
n_leapfrog__                                         6.8  1.5e-02     0.77      7.0      7.0      7.0  2.7e+03  1.2e+01  1.0e+00
divergent__                                         0.00      nan     0.00     0.00     0.00     0.00      nan      nan      nan
energy__                                           31813  7.9e-02      3.1    31808    31813    31819  1.6e+03  6.8e+00  1.0e+00

p_dieroll[1]                                     5.1e-01  6.2e-05  5.0e-03  5.0e-01  5.1e-01  5.2e-01     6570       28     1.00
p_dieroll[2]                                     4.9e-01  6.2e-05  5.0e-03  4.8e-01  4.9e-01  5.0e-01     6570       28     1.00
p_treatment_given_condition_and_dieroll[1,1,2]   2.9e-01  6.9e-05  6.4e-03  2.8e-01  2.9e-01  3.0e-01     8505       37     1.00
p_treatment_given_condition_and_dieroll[1,2,2]   7.0e-01  7.7e-05  6.4e-03  6.9e-01  7.0e-01  7.1e-01     6906       30     1.00
p_treatment_given_condition_and_dieroll[2,1,2]   2.9e-01  7.1e-05  6.4e-03  2.8e-01  2.9e-01  3.0e-01     8171       35     1.00
p_treatment_given_condition_and_dieroll[2,2,2]   7.0e-01  7.7e-05  6.6e-03  6.9e-01  7.0e-01  7.1e-01     7434       32     1.00
p_condition[1]                                   4.9e-05  6.5e-07  4.9e-05  2.6e-06  3.4e-05  1.5e-04     5725       25     1.00
p_condition[2]                                   1.0e+00  6.5e-07  4.9e-05  1.0e+00  1.0e+00  1.0e+00     5726       25     1.00
p_recover_given_condition_and_treatment[1,1]     4.7e-01  8.1e-05  7.0e-03  4.6e-01  4.7e-01  4.8e-01     7339       32     1.00
p_recover_given_condition_and_treatment[1,2]     7.3e-01  7.4e-05  6.4e-03  7.2e-01  7.3e-01  7.5e-01     7400       32     1.00
p_recover_given_condition_and_treatment[2,1]     4.7e-01  7.8e-05  6.9e-03  4.6e-01  4.7e-01  4.8e-01     7843       34     1.00
p_recover_given_condition_and_treatment[2,2]     7.3e-01  7.3e-05  6.3e-03  7.2e-01  7.3e-01  7.5e-01     7538       32     1.00

This doesn’t look like a correct marginalization. To marginalize over the unobserved condition, we need to sum up the contributions to the likelihood associated with each possible value of condition. (Note that the possible values of condition are mutually exclusive, and mutually exclusive probabilities add.) Thus:

for (n in 1:N) {
  vector[C] log_lik;
  for (i in 1:C) {
    log_lik[i] = categorical_lpmf(i | p_condition)
      + categorical_lpmf(treatment[n] | p_treatment_given_condition_and_dieroll[i, dieroll[n]])
      + bernoulli_lpmf(recover[n] | p_recover_given_condition_and_treatment[i, treatment[n]]);
  }
  target += log_sum_exp(log_lik);
}
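In equation form, each observation n adds

log( sum_c p_condition[c] * p(treatment[n] | c, dieroll[n]) * p(recover[n] | c, treatment[n]) )

to the target, which is exactly what log_sum_exp of the per-c log terms computes.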

That seems more understandable, but it too produces estimates that exclude the true values. I also expected more posterior uncertainty than it reports.

data {
  int<lower=0> N;
  int<lower=1> D;
  int<lower=1> T;
  int<lower=1> C;
  int<lower=1,upper=D> dieroll[N];
  int<lower=1,upper=T> treatment[N];
  int<lower=0,upper=1> recover[N];
}
parameters {
  simplex[D] p_dieroll;
  simplex[T] p_treatment_given_condition_and_dieroll[C,D];
  simplex[C] p_condition;
  real<lower=0,upper=1> p_recover_given_condition_and_treatment[C,T];
}
model {
  for (n in 1:N) {
    dieroll[n] ~ categorical(p_dieroll);
    vector[C] log_lik;
    for (c in 1:C) {
      treatment[n] ~ categorical(p_treatment_given_condition_and_dieroll[c,dieroll[n]]);
      recover[n] ~ bernoulli(p_recover_given_condition_and_treatment[c,treatment[n]]);
      log_lik[c] =
        categorical_lpmf(c | p_condition)
        + categorical_lpmf(treatment[n] | p_treatment_given_condition_and_dieroll[c,  dieroll[n]])
        + bernoulli_lpmf(recover[n] | p_recover_given_condition_and_treatment[c, treatment[n]]);
    }
    target += log_sum_exp(log_lik);
  }
}
% stansummary train*.csv | grep -v 'p_treatment_given.*1\]'
Inference for Stan model: my2_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.

Warmup took (102, 115, 125, 101) seconds, 7.4 minutes total
Sampling took (89, 90, 87, 90) seconds, 5.9 minutes total

                                                  Mean     MCSE   StdDev     5%    50%    95%    N_Eff  N_Eff/s    R_hat

lp__                                            -45357  5.1e-02  2.2e+00 -45362 -45357 -45354     1874      5.3      1.0
accept_stat__                                     0.88  2.4e-02     0.13   0.62   0.92    1.0  2.9e+01  8.2e-02  1.0e+00
stepsize__                                        0.59  3.5e-02    0.049   0.54   0.58   0.67  2.0e+00  5.6e-03  2.5e+13
treedepth__                                        2.9  3.8e-02     0.28    2.0    3.0    3.0  5.2e+01  1.5e-01  1.0e+00
n_leapfrog__                                       6.8  6.1e-02     0.81    7.0    7.0    7.0  1.7e+02  4.9e-01  1.0e+00
divergent__                                       0.00      nan     0.00   0.00   0.00   0.00      nan      nan      nan
energy__                                         45363  7.8e-02      3.1  45358  45362  45368  1.6e+03  4.6e+00  1.0e+00

p_dieroll[1]                                      0.50  6.0e-05  4.9e-03   0.49   0.50   0.51     6589       19     1.00
p_dieroll[2]                                      0.50  6.0e-05  4.9e-03   0.49   0.50   0.51     6589       19     1.00
p_treatment_given_condition_and_dieroll[1,1,2]    0.41  8.4e-05  6.2e-03   0.40   0.41   0.42     5403       15     1.00
p_treatment_given_condition_and_dieroll[1,2,2]    0.60  7.7e-05  6.2e-03   0.59   0.60   0.61     6526       18     1.00
p_treatment_given_condition_and_dieroll[2,1,2]    0.41  7.4e-05  6.3e-03   0.40   0.41   0.42     7059       20     1.00
p_treatment_given_condition_and_dieroll[2,2,2]    0.60  8.1e-05  6.2e-03   0.59   0.60   0.61     5871       16      1.0
p_condition[1]                                    0.50  3.8e-03  2.7e-01  0.072   0.50   0.93     4817       14     1.00
p_condition[2]                                    0.50  3.8e-03  2.7e-01  0.067   0.50   0.93     4817       14     1.00
p_recover_given_condition_and_treatment[1,1]      0.42  7.5e-05  6.4e-03   0.41   0.42   0.43     7087       20     1.00
p_recover_given_condition_and_treatment[1,2]      0.78  6.6e-05  5.4e-03   0.77   0.78   0.78     6573       18     1.00
p_recover_given_condition_and_treatment[2,1]      0.42  7.4e-05  6.5e-03   0.41   0.42   0.43     7699       22     1.00
p_recover_given_condition_and_treatment[2,2]      0.78  6.0e-05  5.2e-03   0.77   0.78   0.78     7477       21     1.00

Your two sampling statements inside the inner loop, beginning treatment[n] ~ and recover[n] ~, should be deleted. Each one increments the target unconditionally for every value of c, on top of the lpmf terms already accumulated in log_lik, so those likelihood contributions are counted twice, and outside the marginalization.
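With those two lines removed, the model block becomes (a minimal sketch; the data and parameters blocks are unchanged):

model {
  for (n in 1:N) {
    dieroll[n] ~ categorical(p_dieroll);
    vector[C] log_lik;
    for (c in 1:C) {
      log_lik[c] =
        categorical_lpmf(c | p_condition)
        + categorical_lpmf(treatment[n] | p_treatment_given_condition_and_dieroll[c, dieroll[n]])
        + bernoulli_lpmf(recover[n] | p_recover_given_condition_and_treatment[c, treatment[n]]);
    }
    target += log_sum_exp(log_lik);
  }
}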
