RuntimeError: Initialization failed in categorical logit model

Hi, I am new to Stan and am trying to fit a multinomial (multi-class) logistic regression.

Here is my Stan code:

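data {
    int<lower=0> N;      // number of songs (observations)
    int<lower=0> D;      // number of features
    int<lower=0> K;      // number of classes
    int<lower=0> y[N];   // class label for each song
    matrix[N, D] x;      // model matrix
}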

parameters {
    matrix[D, K] beta;
}

model {
    matrix[N, K] x_beta = x * beta;
    
    // prior
    to_vector(x_beta) ~ normal(0, 2);
    
    // likelihood
    for (n in 1:N) {
        y[n] ~ categorical_logit(x_beta[n]');
    }
}

I am defining my data as the following:

stan_data = {
    'N': len(song_test), # number of songs
    'D': 9, # number of features
    'K': 10, # number of classes
    'y': target, # the response
    'x': song_test, # model matrix
}

Basically, target holds the class labels, and song_test is a pandas DataFrame containing 9 features for each Spotify song (such as danceability and energy).

I have tried changing the way I define my data and a few variations of the categorical logit like:

y[n] ~ categorical(softmax(x[n] * beta));

but I can’t seem to solve the problem. The full error message is:

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-97-39ec286e6e86> in <module>()
----> 1 results = stan_model.sampling(data=stan_data)
      2 print(results)

/usr/local/anaconda3/envs/stanenv/lib/python3.7/site-packages/pystan/model.py in sampling(self, data, pars, chains, iter, warmup, thin, seed, init, sample_file, diagnostic_file, verbose, algorithm, control, n_jobs, **kwargs)
    811         call_sampler_args = izip(itertools.repeat(data), args_list, itertools.repeat(pars))
    812         call_sampler_star = self.module._call_sampler_star
--> 813         ret_and_samples = _map_parallel(call_sampler_star, call_sampler_args, n_jobs)
    814         samples = [smpl for _, smpl in ret_and_samples]
    815 

/usr/local/anaconda3/envs/stanenv/lib/python3.7/site-packages/pystan/model.py in _map_parallel(function, args, n_jobs)
     83         try:
     84             pool = multiprocessing.Pool(processes=n_jobs)
---> 85             map_result = pool.map(function, args)
     86         finally:
     87             pool.close()

/usr/local/anaconda3/envs/stanenv/lib/python3.7/multiprocessing/pool.py in map(self, func, iterable, chunksize)
    266         in a list that is returned.
    267         '''
--> 268         return self._map_async(func, iterable, mapstar, chunksize).get()
    269 
    270     def starmap(self, func, iterable, chunksize=None):

/usr/local/anaconda3/envs/stanenv/lib/python3.7/multiprocessing/pool.py in get(self, timeout)
    655             return self._value
    656         else:
--> 657             raise self._value
    658 
    659     def _set(self, i, obj):

RuntimeError: Initialization failed.

Thank you for your help.

Figured it out: my target labels ran from 0 to 9 instead of 1 to 10!
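
For anyone hitting the same error, here is a minimal sketch of the fix on the Python side, assuming the labels are plain integers that start at 0. The helper name to_one_based is made up for illustration and is not part of PyStan, and the toy list only stands in for the real target.

```python
def to_one_based(labels, num_classes):
    """Shift 0-based class labels (0..K-1) to the 1-based range (1..K)."""
    labels = [int(label) for label in labels]
    if not labels:
        return []
    if min(labels) < 0 or max(labels) > num_classes - 1:
        raise ValueError("labels must lie in 0..num_classes-1 before shifting")
    return [label + 1 for label in labels]


# Toy labels standing in for `target` (hypothetical values).
target = [0, 3, 9, 1]
shifted = to_one_based(target, num_classes=10)
print(shifted)  # [1, 4, 10, 2]
```

The shifted list is what would go into the 'y' entry of stan_data. The range check is there so that labels which are already 1-based do not get shifted a second time without anyone noticing.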

Welcome!

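Ah, yeah, unlike Python, Stan indexes from 1. categorical_logit expects outcomes in 1..K, so a label of 0 is rejected every time the log density is evaluated, and initialization fails. Glad you figured it out.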
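
To illustrate why a label of 0 cannot work, here is a rough pure-Python stand-in for the categorical-logit log-probability. It is only a sketch of the idea, not Stan's actual implementation; the function is named after the Stan one for readability, and the range check is my assumption about how to mirror the 1..K support.

```python
import math


def categorical_logit_lpmf(y, eta):
    """Log-probability of outcome y (1-based) under softmax(eta)."""
    k = len(eta)
    if not 1 <= y <= k:
        # Outcomes outside 1..K are not in the support of the distribution.
        raise ValueError(f"outcome {y} is outside 1..{k}")
    # Log-sum-exp with the maximum subtracted for numerical stability.
    m = max(eta)
    log_norm = m + math.log(sum(math.exp(e - m) for e in eta))
    return eta[y - 1] - log_norm


eta = [0.0, 0.0, 0.0]  # three equally likely classes
print(categorical_logit_lpmf(1, eta))  # equals -log(3)

try:
    categorical_logit_lpmf(0, eta)
except ValueError as err:
    print("rejected:", err)
```

With 0-based labels, every observation labelled 0 trips the range check no matter what beta is, which is why no set of initial values could succeed.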
