Saving & reusing adaptation in cmdstanr (cmdstanpy example)

Mean (and standard deviation) of the number of divergences per chain, 5000 draws per chain, over 300 runs x 4 chains:

group: (mean, sd)
pre-warm sample (adaptation reused): (3.40, 10.18)
with warmup (standard adaptation): (3.57, 13.23)

import json
from gc import collect

import arviz as az
import cmdstanpy
import matplotlib.pyplot as plt
import pandas as pd


def get_inits(fit, checkpoint_name="init_checkpoint"):
    # Model variables: columns whose names do not end in "__" (sampler diagnostics do).
    params_bool = [not item.endswith("__") for item in fit.column_names]
    params = [item for item in fit.column_names if not item.endswith("__")]
    names = []
    # fit.draws(inc_warmup=True)[-1] is the last (warmup) draw, one row per chain.
    for i, init in enumerate([item[params_bool].tolist() for item in fit.draws(inc_warmup=True)[-1]], 1):
        # Regroup flat columns ("eta[1]", ...) into containers: CmdStan init
        # JSON expects {"eta": [...]}, not per-element keys.
        init_dict = {}
        for name, value in zip(params, init):
            var = name.split("[")[0].split(".")[0]
            init_dict.setdefault(var, []).append(value)
        init_dict = {k: v[0] if len(v) == 1 else v for k, v in init_dict.items()}
        names.append(f"{checkpoint_name}_chain_{i}.json")
        with open(names[-1], "w") as f:
            json.dump(init_dict, f)
    return names
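
Each checkpoint is ordinary Stan JSON, so it can be inspected directly. A minimal sketch, assuming round 0 of the loop further down has already run (the file name is hypothetical until then):

with open("init_checkpoint_round_0_chain_1.json") as f:
    inits = json.load(f)
print(list(inits))  # ["mu", "tau", "eta", "theta"]; CmdStan simply ignores
                    # the non-parameter entry "theta" when reading inits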


def get_stepsize(fit):
    # One adapted step size per chain; newer cmdstanpy exposes this as
    # fit.step_size (older releases used fit.stepsize).
    return fit.step_size.tolist()


def get_metric(fit, checkpoint_name="metric_checkpoint"):
    names = []
    # fit.metric holds each chain's adapted inverse metric (diagonal by default).
    for i, metric in enumerate([item.tolist() for item in fit.metric], 1):
        names.append(f"{checkpoint_name}_chain_{i}.json")
        with open(names[-1], "w") as f:
            # CmdStan reads the inverse metric from the key "inv_metric".
            json.dump({"inv_metric": metric}, f)
    return names
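
The metric files are just as easy to sanity-check. A small sketch, again assuming round 0 of the loop below has run; for this model the default diagonal metric should have 10 entries (mu, tau, and the 8 components of eta; the transformed parameter theta is not sampled):

with open("metric_checkpoint_round_0_chain_1.json") as f:
    inv_metric = json.load(f)["inv_metric"]
print(len(inv_metric))  # expect 10 for the non-centered 8-schools model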


schools_code = """
data {
  int<lower=0> J;         // number of schools
  array[J] real y;              // estimated treatment effects
  array[J] real<lower=0> sigma; // standard error of effect estimates
}
parameters {
  real mu;                // population treatment effect
  real<lower=0> tau;      // standard deviation in treatment effects
  vector[J] eta;          // unscaled deviation from mu by school
}
transformed parameters {
  vector[J] theta = mu + tau * eta;        // school treatment effects
}
model {
  target += normal_lpdf(eta | 0, 1);       // prior log-density
  target += normal_lpdf(y | theta, sigma); // log-likelihood
}
"""


with open("schools_code.stan", "w") as f:
    print(schools_code, file=f)


model = cmdstanpy.CmdStanModel(stan_file="schools_code.stan", exe_file="schools_code.exe")


schools_data = {"J": 8,
                "y": [28,  8, -3,  7, -1,  1, 18, 12],
                "sigma": [15, 10, 16, 11,  9, 11, 10, 18]
               }


total_sum_divergences = []
summaries = []
for seed in range(300):
    fit_warmup = model.sample(
        iter_warmup=1000,
        seed=seed+1,
        iter_sampling=0,
        data=schools_data,
        save_warmup=True,
        show_progress=True,  # newer cmdstanpy expects a bool ("notebook" worked in older versions)
    )

    inits = get_inits(fit_warmup, checkpoint_name=f"init_checkpoint_round_{seed}")
    stepsize = get_stepsize(fit_warmup)
    metric = get_metric(fit_warmup, checkpoint_name=f"metric_checkpoint_round_{seed}")

    fit_samples = model.sample(
        data=schools_data,
        seed=seed+2,
        inits=inits,
        step_size=stepsize,
        metric=metric,
        iter_warmup=0,
        adapt_engaged=False,
        iter_sampling=5000,
    )

    idata = az.from_cmdstanpy(fit_samples)
    # Divergent-transition count per chain for this run.
    sum_divergences = dict(enumerate(idata.sample_stats.diverging.sum(dim="draw").values.tolist()))

    total_sum_divergences.append(sum_divergences)
    summaries.append(az.summary(idata, var_names=["mu", "tau"]))
    # Free memory before the next run.
    del fit_samples
    del idata
    collect()


# Spread (max - min) of each summary statistic across the 300 runs.
summary_min_max = (
    pd.concat(summaries)
    .reset_index()
    .groupby("index")
    .apply(lambda x: x.max(numeric_only=True) - x.min(numeric_only=True))
)

total_sum_divergences_all = []
summaries_all = []
for seed in range(300):
    fit = model.sample(
        data=schools_data,
        seed=seed+1,
        iter_warmup=1000,
        save_warmup=True,
        adapt_engaged=True,
        iter_sampling=5000,
    )

    idata = az.from_cmdstanpy(fit)
    sum_divergences = dict(enumerate(idata.sample_stats.diverging.sum(dim="draw").values.tolist()))

    total_sum_divergences_all.append(sum_divergences)
    summaries_all.append(az.summary(idata, var_names=["mu", "tau"]))
    del fit
    del idata
    collect()

print(
    "pre-warm sample: mean, sd of divergences per chain:",
    pd.DataFrame(total_sum_divergences).values.ravel().mean(),
    pd.DataFrame(total_sum_divergences).values.ravel().std(),
)


print(
    "with warmup: mean, sd of divergences per chain:",
    pd.DataFrame(total_sum_divergences_all).values.ravel().mean(),
    pd.DataFrame(total_sum_divergences_all).values.ravel().std(),
)


plt.figure(figsize=(10, 4), dpi=100)
plt.hist(pd.DataFrame(total_sum_divergences).values.ravel(), bins=100, label="fit: pre-warm sample")
plt.hist(pd.DataFrame(total_sum_divergences_all).values.ravel(), bins=100, histtype="step", label="fit: with warmup")
plt.yscale("log")
plt.xscale("log")
plt.legend()
plt.title("Number of divergences, 5000 draws per chain, 8-schools non-centered (300 x 4 chains)")
# Hide the top and right spines.
for loc in ["top", "right"]:
    plt.gca().spines[loc].set_visible(False)
plt.savefig("restart_divergences.png", dpi=300, bbox_inches="tight")


summary_min_max_all = (
    pd.concat(summaries_all)
    .reset_index()
    .groupby("index")
    .apply(lambda x: x.max(numeric_only=True) - x.min(numeric_only=True))
)

print("Spread (max - min) of summary stats across runs, pre-warm sample:")
print(summary_min_max)

print("Spread (max - min) of summary stats across runs, with warmup:")
print(summary_min_max_all)