import cmdstanpy
import os
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sn
import scipy.stats
import qute
# Directory where CmdStan writes its output CSVs. "~" is expanded here so the
# value is the real path under the user's home directory. Otherwise it is the
# literal string "~/samples", which any consumer that does not call
# expanduser itself would treat as a "~" directory under the current working
# directory.
sample_dir = os.path.expanduser(os.path.join('~', 'samples'))
gen_t_code = """
data {
int nsamp;
real<lower=0> nu;
}
transformed data {
matrix[2,2] S = [[1, 0], [0, 1]];
vector[2] mu = [0, 0]';
}
parameters {}
model {}
generated quantities {
array[nsamp] vector[2] mult_t;
array[nsamp] real uni_t;
mult_t = multi_student_t_rng(nu, rep_array(mu, nsamp), S);
uni_t = student_t_rng(nu, rep_array(0, nsamp), 1);
}
"""
# Persist the Stan source to stan_models/gen_minimal_t.stan, then compile it
# into a CmdStanModel. Order matters: the file must exist before compiling.
stanfg = os.path.join('stan_models', 'gen_minimal_t.stan')
# NOTE(review): qute.write_check is project-local. The captured log
# ("No model found, writing") suggests it writes the file when none exists;
# confirm what the positional False argument controls.
qute.write_check(stanfg, gen_t_code, False)
gmdl = cmdstanpy.CmdStanModel(stan_file=stanfg)
INFO:cmdstanpy:compiling stan program, exe file: /mnt/growler/barleyhome/bfiles/qutecollection1/analysis/Ben/stan_models/gen_minimal_t INFO:cmdstanpy:compiler options: stanc_options=None, cpp_options=None
No model found, writing
INFO:cmdstanpy:compiled model file: /mnt/growler/barleyhome/bfiles/qutecollection1/analysis/Ben/stan_models/gen_minimal_t
# Simulation settings: how many t draws to generate and their degrees of freedom.
nsamp = 100_000
nu = 5
# One fixed-parameter iteration on one chain is enough, because a single
# generated-quantities draw already contains all nsamp samples.
prpr = gmdl.sample(
    data={'nsamp': nsamp, 'nu': nu},
    seed=37,
    chains=1,
    iter_sampling=1,
    fixed_param=True,
    output_dir=sample_dir,
)
INFO:cmdstanpy:start chain 1 INFO:cmdstanpy:finish chain 1
# Convert the CmdStan fit to ArviZ InferenceData, labelling the sample index
# ("samp") and the 2-D component axis of the multivariate draws ("block").
gaz = qute.cs_to_az(
    prpr,
    ['mult_t', 'uni_t'],
    p_coords={'samp': np.arange(nsamp), 'block': [0, 1]},
    p_dims={'mult_t': ['samp', 'block'], 'uni_t': ['samp']},
)
# Bare expression so the notebook displays the result.
gaz
WARNING:cmdstanpy:method "sampler_diagnostics" will be deprecated, use method "sampler_variables" instead.
<xarray.Dataset> Dimensions: (block: 2, samp: 100000) Coordinates: * samp (samp) int64 0 1 2 3 4 5 6 ... 99994 99995 99996 99997 99998 99999 * block (block) int64 0 1 Data variables: mult_t (samp, block) float64 2.183 -0.9085 1.489 ... -0.8471 1.197 -2.52 uni_t (samp) float64 -3.206 1.225 -0.8353 1.364 ... -0.6722 0.6454 -2.398 Attributes: created_at: 2021-08-04T15:59:30.257722 arviz_version: 0.11.2
array([ 0, 1, 2, ..., 99997, 99998, 99999])
array([0, 1])
array([[ 2.18265 , -0.908539], [ 1.48948 , -0.820367], [-1.57917 , -0.762235], ..., [-0.400004, 1.30235 ], [ 1.18092 , -0.84711 ], [ 1.19682 , -2.51958 ]])
array([-3.20594 , 1.22515 , -0.835271, ..., -0.672176, 0.645384, -2.39811 ])
<xarray.Dataset> Dimensions: (chain: 1, draw: 1) Coordinates: * chain (chain) int64 0 * draw (draw) int64 0 Data variables: lp__ (chain, draw) float64 0.0 accept_stat__ (chain, draw) float64 0.0 Attributes: created_at: 2021-08-04T15:59:30.260006 arviz_version: 0.11.2
array([0])
array([0])
array([[0.]])
array([[0.]])
# Compare the analytic Student-t(nu) density from scipy with kernel density
# estimates of the simulated draws: each multivariate component and the
# univariate sample.
x = np.linspace(-15, 15, 201)
y = scipy.stats.t.pdf(x, nu)
f, ax = plt.subplots()
ax.plot(x, y, label='scipy', linewidth=3)
sn.kdeplot(gaz.posterior.mult_t.isel(block=0).values, ax=ax, label='mult1')
sn.kdeplot(gaz.posterior.mult_t.isel(block=1).values, ax=ax, label='mult2',
           linestyle=':')
sn.kdeplot(gaz.posterior.uni_t.values, ax=ax, label='uni', linestyle='--')
ax.legend()
<matplotlib.legend.Legend at 0x7f11a0eb3f70>