This is an attempt to express my toy model, involving condition, treatment, and recovery, more clearly.
- Treatment depends on the (unobserved) condition and the (observed) roll of a die
- Recovery depends on the (unobserved) condition and the (observed) treatment
To simplify, first assume all values are observed. The model is:
data {
  int<lower=0> N;
  int<lower=1> D;
  int<lower=1> T;
  int<lower=1> C;
  int<lower=1,upper=C> condition[N];
  int<lower=1,upper=D> dieroll[N];
  int<lower=1,upper=T> treatment[N];
  int<lower=0,upper=1> recover[N];
}
parameters {
  simplex[D] p_dieroll;
  simplex[T] p_treatment_given_condition_and_dieroll[C,D];
  simplex[C] p_condition;
  real<lower=0,upper=1> p_recover_given_condition_and_treatment[C,T];
}
model {
  for (n in 1:N) {
    dieroll[n] ~ categorical(p_dieroll);
    condition[n] ~ categorical(p_condition);
    treatment[n] ~ categorical(p_treatment_given_condition_and_dieroll[condition[n],dieroll[n]]);
    recover[n] ~ bernoulli(p_recover_given_condition_and_treatment[condition[n],treatment[n]]);
  }
}
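To spell out the factorization this model encodes, here is the per-observation joint probability in plain Python, using the true parameter values from the simulation below (the `p_*` names mirror the Stan parameters; this block is just for exposition, not part of the pipeline):

```python
# Joint probability p(dieroll, condition, treatment, recover) =
#   p(dieroll) * p(condition)
#   * p(treatment | condition, dieroll) * p(recover | condition, treatment)
p_dieroll = {1: .5, 2: .5}       # randint(1, D) is uniform
p_condition = {1: .5, 2: .5}     # randint(1, C) is uniform
p_treatment2 = {(1, 1): .2, (1, 2): .4, (2, 1): .6, (2, 2): .8}
p_recover = {(1, 1): .3, (1, 2): .5, (2, 1): .7, (2, 2): .9}

def joint(dieroll, condition, treatment, recover):
    p_t2 = p_treatment2[condition, dieroll]
    p_t = p_t2 if treatment == 2 else 1.0 - p_t2
    p_r = p_recover[condition, treatment]
    p_rec = p_r if recover == 1 else 1.0 - p_r
    return p_dieroll[dieroll] * p_condition[condition] * p_t * p_rec

# Sanity check: the 16 possible outcomes sum to 1.
total = sum(joint(d, c, t, r)
            for d in (1, 2) for c in (1, 2) for t in (1, 2) for r in (0, 1))
print(total)  # -> 1.0 (up to floating point)
```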
Generate example data
from random import randint, random
from operator import itemgetter
import json

# P(treatment == 2 | condition, dieroll)
p_treatment2_given_condition_and_dieroll = {
    (1, 1): .2,
    (1, 2): .4,
    (2, 1): .6,
    (2, 2): .8,
}

# P(recover == 1 | condition, treatment)
p_recover_given_condition_and_treatment = {
    (1, 1): .3,
    (1, 2): .5,
    (2, 1): .7,
    (2, 2): .9,
}

N = 100
D = 2
C = 2
T = 2

def make_observation():
    dieroll = randint(1, D)
    condition = randint(1, C)
    p_treatment2 = p_treatment2_given_condition_and_dieroll[condition, dieroll]
    treatment = (random() < p_treatment2) + 1
    p_recover = p_recover_given_condition_and_treatment[condition, treatment]
    recover = int(random() < p_recover)
    return locals()

obs = [make_observation() for _ in range(N)]
d = dict(
    D=D, C=C, T=T, N=N,
    condition=list(map(itemgetter("condition"), obs)),
    dieroll=list(map(itemgetter("dieroll"), obs)),
    treatment=list(map(itemgetter("treatment"), obs)),
    recover=list(map(itemgetter("recover"), obs)),
)
print(json.dumps(d, indent=2))
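As a quick sanity check on the generator (not part of the pipeline above): with condition and dieroll each uniform on {1,2}, the marginal probability of treatment == 2 is the average of the four table entries, (.2 + .4 + .6 + .8) / 4 = 0.5. A standalone simulation confirms this:

```python
from random import randint, random

p_treatment2_given_condition_and_dieroll = {
    (1, 1): .2, (1, 2): .4, (2, 1): .6, (2, 2): .8,
}

# Simulate many (condition, dieroll) pairs and count treatment == 2.
M = 100_000
hits = 0
for _ in range(M):
    c, d = randint(1, 2), randint(1, 2)
    if random() < p_treatment2_given_condition_and_dieroll[c, d]:
        hits += 1
print(hits / M)  # should be close to 0.5
```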
Sample and summarize
% stansummary train*.csv | grep -v 'p_treatment_given.*1\]'
Inference for Stan model: observed_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.
Warmup took (24, 23, 24, 24) seconds, 1.6 minutes total
Sampling took (25, 24, 24, 25) seconds, 1.6 minutes total
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
lp__ -25049 5.3e-02 2.2e+00 -25053 -25048 -25046 1760 18 1.00
accept_stat__ 0.89 1.7e-03 0.11 0.67 0.93 1.0 4.6e+03 4.7e+01 1.0e+00
stepsize__ 0.68 2.9e-02 0.041 0.62 0.72 0.72 2.0e+00 2.1e-02 1.4e+13
treedepth__ 2.9 6.5e-03 0.35 2.0 3.0 3.0 3.0e+03 3.0e+01 1.0e+00
n_leapfrog__ 6.6 2.3e-02 1.2 3.0 7.0 7.0 2.5e+03 2.6e+01 1.0e+00
divergent__ 0.00 nan 0.00 0.00 0.00 0.00 nan nan nan
energy__ 25054 7.8e-02 3.1 25049 25053 25059 1.6e+03 1.6e+01 1.0e+00
p_dieroll[1] 0.50 5.1e-05 5.0e-03 0.49 0.50 0.51 9518 98 1.00
p_dieroll[2] 0.50 5.1e-05 5.0e-03 0.49 0.50 0.51 9518 98 1.00
p_treatment_given_condition_and_dieroll[1,1,2] 0.20 8.9e-05 8.0e-03 0.19 0.20 0.22 8129 84 1.00
p_treatment_given_condition_and_dieroll[1,2,2] 0.40 1.0e-04 9.8e-03 0.38 0.40 0.42 8746 90 1.00
p_treatment_given_condition_and_dieroll[2,1,2] 0.61 1.0e-04 9.8e-03 0.59 0.61 0.63 8785 90 1.00
p_treatment_given_condition_and_dieroll[2,2,2] 0.80 9.3e-05 7.9e-03 0.78 0.80 0.81 7291 75 1.00
p_condition[1] 0.50 5.5e-05 5.0e-03 0.49 0.50 0.51 8335 86 1.00
p_condition[2] 0.50 5.5e-05 5.0e-03 0.49 0.50 0.51 8335 86 1.00
p_recover_given_condition_and_treatment[1,1] 0.30 8.4e-05 7.7e-03 0.28 0.30 0.31 8511 87 1.00
p_recover_given_condition_and_treatment[1,2] 0.51 1.4e-04 1.3e-02 0.49 0.51 0.53 8075 83 1.00
p_recover_given_condition_and_treatment[2,1] 0.70 1.4e-04 1.2e-02 0.68 0.70 0.72 7612 78 1.00
p_recover_given_condition_and_treatment[2,2] 0.89 5.8e-05 5.4e-03 0.88 0.89 0.90 8686 89 1.00
The results show that all of the parameter values are recovered correctly.

Now to the main model, in which condition is unobserved, so I want to marginalize over it. I'm not sure I did this correctly; the line

lp[c] = categorical_lpmf(c | p_condition) + log_p_condition[c];

seems especially suspicious. Regardless, the results show that the parameter estimates are incorrect and exclude the true values, while reporting 0 divergences and R_hat of 1.0. For instance, the true value of p_recover_given_condition_and_treatment[2,2] is 0.9, but it is estimated at 0.72-0.75.

Is the model specified correctly? Why is it giving the wrong answer? How can I fix it?
data {
  int<lower=0> N;
  int<lower=1> D;
  int<lower=1> T;
  int<lower=1> C;
  int<lower=1,upper=D> dieroll[N];
  int<lower=1,upper=T> treatment[N];
  int<lower=0,upper=1> recover[N];
}
parameters {
  simplex[D] p_dieroll;
  simplex[T] p_treatment_given_condition_and_dieroll[C,D];
  simplex[C] p_condition;
  real<lower=0,upper=1> p_recover_given_condition_and_treatment[C,T];
}
model {
  vector[C] log_p_condition = log(p_condition);
  for (n in 1:N) {
    dieroll[n] ~ categorical(p_dieroll);
    vector[C] lp;
    for (c in 1:C) {
      lp[c] = categorical_lpmf(c | p_condition) + log_p_condition[c]; // XXX Unsure about this.
      treatment[n] ~ categorical(p_treatment_given_condition_and_dieroll[c,dieroll[n]]);
      recover[n] ~ bernoulli(p_recover_given_condition_and_treatment[c,treatment[n]]);
    }
    target += log_sum_exp(lp);
  }
}
% stansummary train*.csv | grep -v 'p_treatment_given.*1\]'
Inference for Stan model: my2_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.
Warmup took (66, 71, 67, 74) seconds, 4.6 minutes total
Sampling took (58, 59, 57, 58) seconds, 3.9 minutes total
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
lp__ -3.2e+04 5.4e-02 2.2e+00 -3.2e+04 -3.2e+04 -3.2e+04 1678 7.2 1.0
accept_stat__ 0.88 7.3e-03 0.14 0.62 0.93 1.0 3.9e+02 1.7e+00 1.0e+00
stepsize__ 0.60 2.3e-02 0.032 0.55 0.61 0.63 2.0e+00 8.6e-03 1.2e+13
treedepth__ 2.9 5.2e-03 0.27 2.0 3.0 3.0 2.8e+03 1.2e+01 1.0e+00
n_leapfrog__ 6.8 1.5e-02 0.77 7.0 7.0 7.0 2.7e+03 1.2e+01 1.0e+00
divergent__ 0.00 nan 0.00 0.00 0.00 0.00 nan nan nan
energy__ 31813 7.9e-02 3.1 31808 31813 31819 1.6e+03 6.8e+00 1.0e+00
p_dieroll[1] 5.1e-01 6.2e-05 5.0e-03 5.0e-01 5.1e-01 5.2e-01 6570 28 1.00
p_dieroll[2] 4.9e-01 6.2e-05 5.0e-03 4.8e-01 4.9e-01 5.0e-01 6570 28 1.00
p_treatment_given_condition_and_dieroll[1,1,2] 2.9e-01 6.9e-05 6.4e-03 2.8e-01 2.9e-01 3.0e-01 8505 37 1.00
p_treatment_given_condition_and_dieroll[1,2,2] 7.0e-01 7.7e-05 6.4e-03 6.9e-01 7.0e-01 7.1e-01 6906 30 1.00
p_treatment_given_condition_and_dieroll[2,1,2] 2.9e-01 7.1e-05 6.4e-03 2.8e-01 2.9e-01 3.0e-01 8171 35 1.00
p_treatment_given_condition_and_dieroll[2,2,2] 7.0e-01 7.7e-05 6.6e-03 6.9e-01 7.0e-01 7.1e-01 7434 32 1.00
p_condition[1] 4.9e-05 6.5e-07 4.9e-05 2.6e-06 3.4e-05 1.5e-04 5725 25 1.00
p_condition[2] 1.0e+00 6.5e-07 4.9e-05 1.0e+00 1.0e+00 1.0e+00 5726 25 1.00
p_recover_given_condition_and_treatment[1,1] 4.7e-01 8.1e-05 7.0e-03 4.6e-01 4.7e-01 4.8e-01 7339 32 1.00
p_recover_given_condition_and_treatment[1,2] 7.3e-01 7.4e-05 6.4e-03 7.2e-01 7.3e-01 7.5e-01 7400 32 1.00
p_recover_given_condition_and_treatment[2,1] 4.7e-01 7.8e-05 6.9e-03 4.6e-01 4.7e-01 4.8e-01 7843 34 1.00
p_recover_given_condition_and_treatment[2,2] 7.3e-01 7.3e-05 6.3e-03 7.2e-01 7.3e-01 7.5e-01 7538 32 1.00
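For completeness, here is what I think the marginalized per-observation likelihood should compute, written in plain Python with the true simulation values (the `p_*` names mirror the Stan parameters; this is only my understanding of the math, and if my Stan loop doesn't implement this sum over c, that may be where it goes wrong):

```python
import math

p_dieroll = {1: .5, 2: .5}
p_condition = {1: .5, 2: .5}
p_treatment2 = {(1, 1): .2, (1, 2): .4, (2, 1): .6, (2, 2): .8}
p_recover = {(1, 1): .3, (1, 2): .5, (2, 1): .7, (2, 2): .9}

def log_lik(dieroll, treatment, recover):
    # log p(dieroll)
    #   + log sum_c p(c) * p(treatment | c, dieroll) * p(recover | c, treatment)
    marg = 0.0
    for c in (1, 2):
        p_t2 = p_treatment2[c, dieroll]
        p_t = p_t2 if treatment == 2 else 1.0 - p_t2
        p_r = p_recover[c, treatment]
        p_rec = p_r if recover == 1 else 1.0 - p_r
        marg += p_condition[c] * p_t * p_rec
    return math.log(p_dieroll[dieroll]) + math.log(marg)

# Sanity check: the marginal over all observed outcomes sums to 1.
total = sum(math.exp(log_lik(d, t, r))
            for d in (1, 2) for t in (1, 2) for r in (0, 1))
print(total)  # -> 1.0 (up to floating point)
```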