Hi, I am new to using Stan and am trying to perform a multinomial (multi-class) logistic regression.
Here is my Stan code:
```stan
data {
  int<lower=0> N;
  int<lower=0> D;
  int<lower=0> K;
  int<lower=0> y[N];
  matrix[N, D] x;
}
parameters {
  matrix[D, K] beta;
}
model {
  matrix[N, K] x_beta = x * beta;
  // prior
  to_vector(x_beta) ~ normal(0, 2);
  // likelihood
  for (n in 1:N) {
    y[n] ~ categorical_logit(x_beta[n]');
  }
}
```
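For reference, the likelihood loop above can be mirrored in NumPy. This is a minimal sketch, and `categorical_logit_lpmf` here is a hypothetical Python helper, not part of PyStan. `categorical_logit(eta)` describes the same distribution as `categorical(softmax(eta))`, so each row contributes `eta[n, y[n]] - log(sum_k exp(eta[n, k]))`. Like the Stan functions, the sketch assumes every label lies in 1..K; a label of 0 (for example from 0-based class codes) has no valid log probability, which is one plausible way to end up with an initialization failure.

```python
import numpy as np


def categorical_logit_lpmf(y, eta):
    """Sum of log p(y[n]) under categorical_logit(eta[n]) for 1-based labels y."""
    y = np.asarray(y)
    eta = np.asarray(eta, dtype=float)
    n_rows, k = eta.shape
    if y.shape != (n_rows,):
        raise ValueError("y must have one label per row of eta")
    # Stan's categorical distributions only accept outcomes in 1..K.
    if y.min() < 1 or y.max() > k:
        raise ValueError(f"labels must be in 1..{k}, got range {y.min()}..{y.max()}")
    # Numerically stable log-softmax of each row.
    shifted = eta - eta.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Convert 1-based labels to 0-based column indices and sum over rows.
    return log_probs[np.arange(n_rows), y - 1].sum()


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 9))
# Stan's default inits draw unconstrained parameters uniformly from (-2, 2).
beta = rng.uniform(-2, 2, size=(9, 10))
eta = x @ beta
print(np.isfinite(categorical_logit_lpmf(np.array([1, 3, 10, 2, 5]), eta)))  # True
try:
    categorical_logit_lpmf(np.array([0, 3, 9, 2, 5]), eta)
except ValueError as err:
    print(err)  # labels must be in 1..10, got range 0..9
```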
I define my data as follows:
```python
stan_data = {
    'N': len(song_test),  # number of songs
    'D': 9,               # number of features
    'K': 10,              # number of classes
    'y': target,          # the response
    'x': song_test,       # model matrix
}
```
Basically, `target` holds the labels, and `song_test` is a pandas DataFrame containing 9 features for each Spotify song (such as danceability and energy).
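A minimal sketch of building this dict so that it matches the data block, assuming `target` holds either string genre labels or 0-based integers (which one is not stated above). `make_stan_data` is a hypothetical helper: it converts the DataFrame to a float NumPy array and recodes the labels to 1..K with `np.unique(..., return_inverse=True)`, because Stan's categorical distributions index outcomes from 1. Note that `K` here is the number of distinct labels actually present, which could be fewer than 10 if some classes are missing from the data.

```python
import numpy as np


def make_stan_data(features, labels):
    """Build a PyStan data dict for the multinomial logit model.

    features: array-like of shape (N, D), e.g. a pandas DataFrame
    labels:   length-N sequence of class labels (strings or 0-based ints)
    """
    # np.asarray also accepts a pandas DataFrame; force float for matrix[N, D].
    x = np.asarray(features, dtype=float)
    if x.ndim != 2:
        raise ValueError("features must be 2-D (N rows, D columns)")
    # Map each distinct label to a 0-based code, then shift to Stan's 1..K.
    classes, codes = np.unique(np.asarray(labels), return_inverse=True)
    y = (codes + 1).astype(int)
    if len(y) != x.shape[0]:
        raise ValueError("need exactly one label per row of features")
    data = {
        "N": x.shape[0],
        "D": x.shape[1],
        "K": len(classes),
        "y": y,
        "x": x,
    }
    return data, classes


features = [[0.5, 0.8], [0.2, 0.1], [0.9, 0.4]]
labels = ["rock", "jazz", "rock"]
stan_data, classes = make_stan_data(features, labels)
print(classes)         # ['jazz' 'rock']
print(stan_data["y"])  # [2 1 2]
```

The `classes` array maps the 1-based codes back to the original labels (code `i` corresponds to `classes[i - 1]`).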
I have tried changing the way I define my data, as well as a few variations of the categorical logit, such as:

```stan
y[n] ~ categorical(softmax((x[n] * beta)'));
```

but I can't seem to solve the problem. Here is the full error message:
```
The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-97-39ec286e6e86> in <module>()
----> 1 results = stan_model.sampling(data=stan_data)
      2 print(results)

/usr/local/anaconda3/envs/stanenv/lib/python3.7/site-packages/pystan/model.py in sampling(self, data, pars, chains, iter, warmup, thin, seed, init, sample_file, diagnostic_file, verbose, algorithm, control, n_jobs, **kwargs)
    811         call_sampler_args = izip(itertools.repeat(data), args_list, itertools.repeat(pars))
    812         call_sampler_star = self.module._call_sampler_star
--> 813         ret_and_samples = _map_parallel(call_sampler_star, call_sampler_args, n_jobs)
    814         samples = [smpl for _, smpl in ret_and_samples]
    815 

/usr/local/anaconda3/envs/stanenv/lib/python3.7/site-packages/pystan/model.py in _map_parallel(function, args, n_jobs)
     83         try:
     84             pool = multiprocessing.Pool(processes=n_jobs)
---> 85             map_result = pool.map(function, args)
     86         finally:
     87             pool.close()

/usr/local/anaconda3/envs/stanenv/lib/python3.7/multiprocessing/pool.py in map(self, func, iterable, chunksize)
    266         in a list that is returned.
    267         '''
--> 268         return self._map_async(func, iterable, mapstar, chunksize).get()
    269 
    270     def starmap(self, func, iterable, chunksize=None):

/usr/local/anaconda3/envs/stanenv/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
    655             return self._value
    656         else:
--> 657             raise self._value
    658 
    659     def _set(self, i, obj):

RuntimeError: Initialization failed.
```
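"Initialization failed" generally means Stan could not find a starting point with a finite log density after repeated random attempts, so it can help to check the data against the declared constraints before calling `sampling`. A minimal sketch; `check_stan_data` is a hypothetical helper that only knows about this particular model (the sizes in the data block plus the 1..K support of `categorical_logit`).

```python
import numpy as np


def check_stan_data(data):
    """Return a list of problems found in a data dict for the model above."""
    problems = []
    n, d, k = data["N"], data["D"], data["K"]
    x = np.asarray(data["x"], dtype=float)
    y = np.asarray(data["y"])
    # Sizes must match matrix[N, D] x and int y[N].
    if x.shape != (n, d):
        problems.append(f"x has shape {x.shape}, expected ({n}, {d})")
    if y.shape != (n,):
        problems.append(f"y has shape {y.shape}, expected ({n},)")
    # categorical_logit needs integer outcomes in 1..K.
    if not np.issubdtype(y.dtype, np.integer):
        problems.append(f"y has dtype {y.dtype}, expected integers")
    elif y.size and (y.min() < 1 or y.max() > k):
        problems.append(f"y ranges over {y.min()}..{y.max()}, expected 1..{k}")
    # NaN or inf in the predictors would make the log density non-finite.
    if not np.all(np.isfinite(x)):
        problems.append("x contains NaN or infinite values")
    return problems


bad = {"N": 2, "D": 2, "K": 3, "y": [0, 2], "x": [[0.1, 0.2], [0.3, 0.4]]}
for problem in check_stan_data(bad):
    print(problem)  # y ranges over 0..2, expected 1..3
```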
Thank you for your help.