I adatpted the code with categorical_logit_glm_lpmf. Same result. Below is the code. I tried with the Device #1: Intel(R) UHD Graphics 630 and the Device #2: AMD Radeon Pro 5500M Compute Engine (btw, which one should i use?), and in both cases GPU usage was 0%. I certainly did something wrong, but cannot figure what. Any cue would be greatly appreciated!
functions {
real partial_sum(array[] int slice_n,
int start, int end,
array[] int Tsubj,
array[,] int sample,
array[,] int choice,
array[,,] int color,
array[,,] real proba,
matrix params,
int T_max,
int I_max) {
real log_prob = 0;
// Loop over the slice of data
for (n in start:end) {
int trials_for_subject = Tsubj[n]; // Extract number of trials for subject n
for (t in 1:trials_for_subject) { // Loop through trials for this subject
// Skip if sample data is missing/invalid
if (sample[n,t] < 0) continue;
vector[2] evidence = rep_vector(0.0, 2);
for (i in 1:sample[n,t]) { // Loop through samples
real p = proba[n, t, i]; // Probability for the current sample
real l = log(p / (1 - p)); // Log-likelihood computation
real log_odds = params[n,1] * l + params[n, 2];
evidence[color[n, t, i]] += exp(params[n,3]*(i-sample[n,t])) * log_odds; // Add to the evidence
}
// Skip if choice data is missing/invalid
if (choice[n, t] < 0) continue;
log_prob += categorical_logit_glm_lpmf(
{ choice[n, t] } | // Array of length 1
rep_matrix(1, 1, 1), // Predictor matrix (1x1)
rep_vector(0, 2), // Intercept vector of length 2
to_matrix(evidence') // Logits as a 1x2 matrix
);
}
}
return log_prob;
}
}
data {
int<lower=1> N; // Number of subjects
int<lower=1> T_max; // Maximum number of trials across all subjects
int<lower=1> I_max; // Maximum number of samples across all trials
array[N] int<lower=1> Tsubj; // Number of trials per subject
array[N, T_max] int<lower=-1> sample; // Number of samples per subj and trial, -1 indicates missing data
array[N, T_max, I_max] int<lower=-1, upper=2> color; // Info item colors, -1 indicates missing data
array[N, T_max, I_max] real<lower=-1, upper=1> proba; // Probabilities for info items, -1 indicates missing data
array[N, T_max] int<lower=-1, upper=2> choice; // Choices for each subject, -1 indicates missing data
}
parameters {
vector[3] mu_pr; // Group-level means
vector<lower=0>[3] sigma_pr; // Group-level SDs
matrix[N, 3] param_raw; // Subject-level raw parameters
}
transformed parameters {
matrix[N,5] params; // Subject-level parameters
// Transform subject-level parameters
for (n in 1:N) {
params[n, 1] = Phi_approx(mu_pr[1] + sigma_pr[1] * param_raw[n, 1]) * 10; // alpha (slope)
params[n, 2] = mu_pr[2] + sigma_pr[2] * param_raw[n, 2]; // beta (intercept)
params[n, 3] = Phi_approx(mu_pr[3] + sigma_pr[3] * param_raw[n, 3]) * 10; // lambda (discount)
}
}
model {
// Priors
int grainsize = 1; // grainsize for parallelization
mu_pr ~ std_normal();
sigma_pr ~ std_normal();
to_vector(param_raw) ~ std_normal();
// Create an array with indices that will be sliced
array[N] int indices;
for (n in 1:N) {
indices[n] = n;
}
// Likelihood using reduce_sum with the correct signature
target += reduce_sum(partial_sum, indices, grainsize,
Tsubj, sample, choice, color, proba, params,
T_max, I_max);
}
generated quantities {
array[N, T_max] real log_lik;
// Group-level transformed parameters
real mu_alpha = Phi_approx(mu_pr[1]) * 10;
real mu_beta = mu_pr[2];
real mu_lambda = Phi_approx(mu_pr[3])*10;
// Log-likelihood computation
for (n in 1:N) {
for (t in 1:Tsubj[n]) {
// Initialize to missing value indicator
log_lik[n, t] = -999;
// Skip if sample data is missing
if (sample[n,t] < 0) continue;
// Skip if choice data is missing
if (choice[n, t] < 0) continue;
vector[2] evidence = rep_vector(0.0, 2);
for (i in 1:sample[n,t]) {
// Skip if color or proba data is missing
if (color[n, t, i] < 0 || proba[n, t, i] < 0) continue;
real p = proba[n, t, i];
real l = log(p / (1 - p));
real log_odds = params[n,1] * l + params[n, 2];
evidence[color[n, t, i]] += exp(params[n,3]*(i-sample[n,t])) * log_odds;
}
vector[2] prob = softmax(evidence);
log_lik[n, t] = categorical_logit_lpmf(choice[n, t] | evidence);
}
}
}