Mean (and standard deviation) of the number of divergences per chain, with 5000 draws per chain over 300 runs x 4 chains. "Pre-warm sample" restarts sampling from a checkpointed warmup (last warmup draw, step size, and inverse metric) with adaptation disabled; "warmup+sample" is a standard single run.

group: (mean, sd)
pre-warm sample: (3.40, 10.18)
warmup+sample: (3.57, 13.23)
import json
from gc import collect
import arviz as az
import cmdstanpy
import matplotlib.pyplot as plt
import pandas as pd
def get_inits(fit, checkpoint_name="init_checkpoint"):
    # keep model parameters only, dropping sampler diagnostics (lp__, divergent__, ...)
    params_bool = [not item.endswith("__") for item in fit.column_names]
    params = [item for item in fit.column_names if not item.endswith("__")]
    names = []
    # write the last saved draw of each chain to a per-chain JSON init file
    for i, init in enumerate([item[params_bool].tolist() for item in fit.draws(inc_warmup=True)[-1]], 1):
        names.append(checkpoint_name + f"_chain_{i}" + ".json")
        with open(names[-1], "w") as f:
            json.dump(dict(zip(params, init)), f)
    return names
def get_stepsize(fit):
    # per-chain step sizes from the adapted sampler
    return fit.stepsize.tolist()
def get_metric(fit, checkpoint_name="metric_checkpoint"):
    names = []
    # write each chain's adapted inverse metric to a per-chain JSON file
    for i, metric in enumerate([item.tolist() for item in fit.metric], 1):
        names.append(checkpoint_name + f"_chain_{i}" + ".json")
        with open(names[-1], "w") as f:
            json.dump({"inv_metric": metric}, f)
    return names
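# Intended use of the three helpers above (a sketch only, kept as comments so the
# script below is unchanged; `warm`, `fit`, and `data` are placeholder names):
#   warm = model.sample(data=data, iter_warmup=1000, iter_sampling=0, save_warmup=True)
#   fit = model.sample(data=data, iter_warmup=0, adapt_engaged=False, iter_sampling=5000,
#                      inits=get_inits(warm), step_size=get_stepsize(warm),
#                      metric=get_metric(warm))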
schools_code = """
data {
int<lower=0> J; // number of schools
real y[J]; // estimated treatment effects
real<lower=0> sigma[J]; // standard error of effect estimates
}
parameters {
real mu; // population treatment effect
real<lower=0> tau; // standard deviation in treatment effects
vector[J] eta; // unscaled deviation from mu by school
}
transformed parameters {
vector[J] theta = mu + tau * eta; // school treatment effects
}
model {
target += normal_lpdf(eta | 0, 1); // prior log-density
target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""
with open("schools_code.stan", "w") as f:
print(schools_code, file=f)
model = cmdstanpy.CmdStanModel(stan_file="schools_code.stan", exe_file="schools_code.exe")
schools_data = {"J": 8,
"y": [28, 8, -3, 7, -1, 1, 18, 12],
"sigma": [15, 10, 16, 11, 9, 11, 10, 18]
}
total_sum_divergences = []
summaries = []
for seed in range(300):
    # Stage 1: warmup only; keep the warmup draws so the last one can be checkpointed
    fit_warmup = model.sample(
        iter_warmup=1000,
        seed=seed+1,
        iter_sampling=0,
        data=schools_data,
        save_warmup=True,
        show_progress="notebook",
    )
    inits = get_inits(fit_warmup, checkpoint_name=f"init_checkpoint_round_{seed}")
    stepsize = get_stepsize(fit_warmup)
    metric = get_metric(fit_warmup, checkpoint_name=f"metric_checkpoint_round_{seed}")
    # Stage 2: restart sampling from the checkpoint with adaptation disabled
    fit_samples = model.sample(
        data=schools_data,
        seed=seed+2,
        inits=inits,
        step_size=stepsize,
        metric=metric,
        iter_warmup=0,
        adapt_engaged=False,
        iter_sampling=5000,
    )
    idata = az.from_cmdstanpy(fit_samples)
    # divergence count per chain for this run
    sum_divergences = dict(enumerate(idata.sample_stats.diverging.sum(dim="draw").values.tolist()))
    total_sum_divergences.append(sum_divergences)
    summaries.append(az.summary(idata, var_names=["mu", "tau"]))
    del fit_samples
    del idata
    collect()
# per-parameter spread (max - min) of the posterior summaries across the 300 restarted runs
summary_min_max = pd.concat(summaries).reset_index().groupby("index").apply(lambda x: x.max(numeric_only=True) - x.min(numeric_only=True))
total_sum_divergences_all = []
summaries_all = []
for seed in range(300):
    # Reference run: warmup and sampling in a single call, adaptation enabled
    fit = model.sample(
        data=schools_data,
        seed=seed+1,
        iter_warmup=1000,
        save_warmup=True,
        adapt_engaged=True,
        iter_sampling=5000,
    )
    idata = az.from_cmdstanpy(fit)
    sum_divergences = dict(enumerate(idata.sample_stats.diverging.sum(dim="draw").values.tolist()))
    total_sum_divergences_all.append(sum_divergences)
    summaries_all.append(az.summary(idata, var_names=["mu", "tau"]))
    del fit
    del idata
    collect()
# mean and sd of per-chain divergence counts: restarted runs, then standard runs
print(
    pd.DataFrame(total_sum_divergences).values.ravel().mean(),
    pd.DataFrame(total_sum_divergences).values.ravel().std()
)
print(
    pd.DataFrame(total_sum_divergences_all).values.ravel().mean(),
    pd.DataFrame(total_sum_divergences_all).values.ravel().std()
)
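# Histogram of per-chain divergence counts for both experiments (log-log axes)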
plt.figure(figsize=(10,4), dpi=100)
plt.hist(pd.DataFrame(total_sum_divergences).values.ravel(), bins=100, label="fit: pre-warm sample");
plt.hist(pd.DataFrame(total_sum_divergences_all).values.ravel(), bins=100, histtype="step", label="fit: with warmup");
plt.yscale("log")
plt.xscale("log")
plt.legend()
plt.title("Number of divergences, 5000 draws per chain, 8-schools non-centered (300 x 4 chains)")
[spine.set_visible(False) for loc, spine in plt.gca().spines.items() if loc in ["top", "right"]]
plt.savefig("restart_divergences", dpi=300, bbox_inches="tight")
# same spread computed for the standard warmup+sampling runs
summary_min_max_all = pd.concat(summaries_all).reset_index().groupby("index").apply(lambda x: x.max(numeric_only=True) - x.min(numeric_only=True))
print(summary_min_max)
print(summary_min_max_all)