I am trying to model an experiment where the treatment depends on the roll of a die and an unknown “condition”. The outcome is whether the person recovers or not and depends on the treatment and condition.
In the data-generating Python script, I assume that recovery depends only on condition, not on treatment, and I want Stan to arrive at that conclusion too.
p_recover_given_condition_and_treatment
gives the wrong answer according to the generated data. How can I fix this model?
from random import randint, random
from operator import itemgetter
import json
p_treatment2_given_condition_and_dieroll = {
(1,1): .2,
(1,2): .6,
(2,1): .4,
(2,2): .8,
}
p_recover_given_condition_and_treatment = {
(1,1): .3,
(1,2): .5,
(2,1): .7,
(2,2): .9,
}
N = 10000
D = 2
C = 2
T = 2
def make_observation():
dieroll = randint(1,D)
condition = randint(1,C)
p_treatment2 = p_treatment2_given_condition_and_dieroll[condition,dieroll]
treatment = (random()<p_treatment2)+1
p_recover = p_recover_given_condition_and_treatment[condition,treatment]
recover = int(random()<p_recover)
return locals()
obs = [make_observation() for _ in range(N)]
d = dict(
D=D,C=C,T=T,N=N,
dieroll=list(map(itemgetter("dieroll"), obs)),
treatment=list(map(itemgetter("treatment"), obs)),
recover=list(map(itemgetter("recover"), obs)),
)
print(json.dumps(d, indent=2))
Marginalizing:
data {
int<lower=0> N;
int<lower=1> D;
int<lower=1> T;
int<lower=1> C;
int<lower=1,upper=D> dieroll[N];
int<lower=1,upper=T> treatment[N];
int<lower=0,upper=1> recover[N];
}
parameters {
simplex[D] p_dieroll;
simplex[T] p_treatment_given_condition_and_dieroll[C,D];
simplex[C] p_condition;
real<lower=0,upper=1> p_recover_given_condition_and_treatment[C,T];
}
model {
for (n in 1:N) {
dieroll[n] ~ categorical(p_dieroll);
vector[C] lp;
for (condition in 1:C) {
lp[condition] = categorical_lpmf(condition | p_condition);
treatment[n] ~ categorical(p_treatment_given_condition_and_dieroll[condition,dieroll[n]]);
recover[n] ~ bernoulli(p_recover_given_condition_and_treatment[condition,treatment[n]]);
}
target += log_sum_exp(lp);
}
}
% stansummary train*.csv | grep -v 'p_treatment_given.*1\]'
Inference for Stan model: my2_model
4 chains: each with iter=(1000,1000,1000,1000); warmup=(0,0,0,0); thin=(1,1,1,1); 4000 iterations saved.
Warmup took (64, 65, 67, 62) seconds, 4.3 minutes total
Sampling took (55, 50, 55, 55) seconds, 3.6 minutes total
Mean MCSE StdDev 5% 50% 95% N_Eff N_Eff/s R_hat
lp__ -31799 5.9e-02 2.3e+00 -31803 -31798 -31796 1522 7.1 1.0
accept_stat__ 0.89 1.0e-02 0.12 0.65 0.93 1.0 1.4e+02 6.3e-01 1.0e+00
stepsize__ 0.67 5.4e-02 0.076 0.61 0.65 0.79 2.0e+00 9.3e-03 2.7e+13
treedepth__ 2.9 5.5e-02 0.35 2.0 3.0 3.0 4.0e+01 1.9e-01 1.0e+00
n_leapfrog__ 6.6 1.8e-01 1.2 3.0 7.0 7.0 4.2e+01 1.9e-01 1.0e+00
divergent__ 0.00 nan 0.00 0.00 0.00 0.00 nan nan nan
energy__ 31804 8.3e-02 3.2 31799 31803 31809 1.5e+03 6.9e+00 1.0e+00
p_dieroll[1] 0.51 5.8e-05 5.1e-03 0.50 0.51 0.52 7708 36 1.00
p_dieroll[2] 0.49 5.8e-05 5.1e-03 0.48 0.49 0.50 7708 36 1.00
p_treatment_given_condition_and_dieroll[1,1,2] 0.29 7.0e-05 6.5e-03 0.28 0.29 0.30 8569 40 1.00
p_treatment_given_condition_and_dieroll[1,2,2] 0.70 7.1e-05 6.6e-03 0.69 0.70 0.71 8437 39 1.00
p_treatment_given_condition_and_dieroll[2,1,2] 0.29 6.8e-05 6.2e-03 0.28 0.29 0.30 8308 39 1.00
p_treatment_given_condition_and_dieroll[2,2,2] 0.70 7.8e-05 6.5e-03 0.69 0.70 0.71 6885 32 1.00
p_condition[1] 0.50 3.2e-03 2.9e-01 0.045 0.50 0.95 8287 39 1.00
p_condition[2] 0.50 3.2e-03 2.9e-01 0.049 0.50 0.95 8287 39 1.00
p_recover_given_condition_and_treatment[1,1] 0.47 7.8e-05 7.1e-03 0.46 0.47 0.48 8426 39 1.00
p_recover_given_condition_and_treatment[1,2] 0.73 7.2e-05 6.2e-03 0.72 0.74 0.74 7247 34 1.00
p_recover_given_condition_and_treatment[2,1] 0.47 8.2e-05 7.1e-03 0.46 0.47 0.48 7444 35 1.00
p_recover_given_condition_and_treatment[2,2] 0.73 6.8e-05 6.3e-03 0.72 0.73 0.75 8660 40 1.00