Need help debugging sparse matmul error with non-trivial `map_rect` involved

I am having trouble debugging the following model:

functions {
    // Per-shard log-likelihood, written to be mapped over shards by map_rect
    // (see the model block).
    //
    // Arguments:
    //   phi   - shared parameters, packed in this order:
    //           [int_coefs (n_int), fe_coefs (n_fe), re_coefs (n_re),
    //            sf (n_batches), iodisp (n_feats)]
    //   theta - shard-specific parameters; not read by this function
    //   x_r   - shard real data: nonzero values of the intercept, fixed-effect
    //           and random-effect design matrices, concatenated in that order
    //   x_i   - shard integer data:
    //           [1..9]  n_obs_shard, n_int, n_fe, n_re, n_nz_int_shard,
    //                   n_nz_fe_shard, n_nz_re_shard, n_batches, n_feats
    //           then    batch ids, feature ids, counts (n_obs_shard each),
    //           then    for each design matrix: column indices (n_nz entries)
    //                   followed by row-start offsets (n_obs_shard + 1 entries);
    //                   the fe / re segments are only read when n_fe / n_re > 0.
    //           Entries past the last slice are never read.
    //
    // Returns a 1-element vector holding the summed negative binomial
    // log-likelihood of the shard's observations.
    vector shard_ll(vector phi,
                        vector theta,
                        array[] real x_r,
                        array[] int x_i) {
        // Unpack Metadata ----
        // Shard Dimensions ----
        int n_obs_shard = x_i[1];
        int n_int = x_i[2];
        int n_fe = x_i[3];
        int n_re = x_i[4];
        int n_nz_int_shard = x_i[5];
        int n_nz_fe_shard = x_i[6];
        int n_nz_re_shard = x_i[7];
        int n_batches = x_i[8];
        int n_feats = x_i[9];

        // Unpack Parameters ----
        // `pos` is a running cursor into phi; each slice advances it by the
        // slice length, so the order here must match how phi is assembled.
        int pos = 1;
        vector[n_int] int_coefs = phi[pos:(pos + n_int - 1)];
        pos += n_int;
        vector[n_fe] fe_coefs = phi[pos:(pos + n_fe - 1)];
        pos += n_fe;
        vector[n_re] re_coefs = phi[pos:(pos + n_re - 1)];

        pos += n_re;
        vector[n_batches] sf = phi[pos:(pos + n_batches - 1)];
        pos += n_batches;
        vector[n_feats] iodisp = phi[pos:(pos + n_feats - 1)];

        // Unpack Shard Data ----
        // Integers ----
        // `i_pos` is a running cursor into x_i, advanced after every slice.
        int i_pos = 10; // Start after metadata (7) + n_batches/n_feats (2)
        array[n_obs_shard] int batch_id_shard = x_i[i_pos:(i_pos + n_obs_shard - 1)];
        i_pos += n_obs_shard;
        array[n_obs_shard] int feat_id_shard = x_i[i_pos:(i_pos + n_obs_shard - 1)];
        i_pos += n_obs_shard;
        array[n_obs_shard] int counts_shard = x_i[i_pos:(i_pos + n_obs_shard - 1)];
        i_pos += n_obs_shard;

        // Intercept design matrix: column indices, then row-start offsets.
        array[n_nz_int_shard] int int_design_j = x_i[i_pos:(i_pos + n_nz_int_shard - 1)];
        i_pos += n_nz_int_shard;
        array[n_obs_shard + 1] int int_design_p = x_i[i_pos:(i_pos + n_obs_shard)];
        i_pos += n_obs_shard + 1;

        // Fixed-effect design matrix; left unassigned (and the cursor left
        // where it is) when there are no fixed effects.
        array[n_nz_fe_shard] int fe_design_j;
        array[n_obs_shard + 1] int fe_design_p;
        if (n_fe > 0) {
            fe_design_j = x_i[i_pos:(i_pos + n_nz_fe_shard - 1)];
            i_pos += n_nz_fe_shard;
            fe_design_p = x_i[i_pos:(i_pos + n_obs_shard)];
            i_pos += n_obs_shard + 1;
        }

        // Random-effect design matrix; this is the last integer segment, so
        // the cursor is not advanced after it.
        array[n_nz_re_shard] int re_design_j;
        array[n_obs_shard + 1] int re_design_p;
        if (n_re > 0) {
            re_design_j = x_i[i_pos:(i_pos + n_nz_re_shard - 1)];
            i_pos += n_nz_re_shard;
            re_design_p = x_i[i_pos:(i_pos + n_obs_shard)];
        }

        // Reals ----
        // `r_pos` is a running cursor into x_r (nonzero matrix values).
        int r_pos = 1;
        vector[n_nz_int_shard] int_design_x = to_vector(x_r[r_pos:(r_pos + n_nz_int_shard - 1)]);
        r_pos += n_nz_int_shard;

        vector[n_nz_fe_shard] fe_design_x;
        if (n_fe > 0) {
            fe_design_x = to_vector(x_r[r_pos:(r_pos + n_nz_fe_shard - 1)]);
            r_pos += n_nz_fe_shard;
        }

        vector[n_nz_re_shard] re_design_x;
        if (n_re > 0) {
            re_design_x = to_vector(x_r[r_pos:(r_pos + n_nz_re_shard - 1)]);
        }

        // Compute negative binomial likelihood
        vector[n_obs_shard] log_mu;
        real log_lik = 0;

        // Compute log_mu from design matrices
        // NOTE: csr_matrix_times_vector expects 1-based column indices and
        // row-start offsets. Indices exported from 0-based sparse formats
        // (e.g. the @j / @p slots of R's Matrix classes) must be shifted by 1
        // before being passed in, otherwise its size consistency check fails.
        log_mu = csr_matrix_times_vector(n_obs_shard, n_int, int_design_x, int_design_j, int_design_p, int_coefs);
        if (n_fe > 0) {
            log_mu += csr_matrix_times_vector(n_obs_shard, n_fe, fe_design_x, fe_design_j, fe_design_p, fe_coefs);
        }
        if (n_re > 0) {
            log_mu += csr_matrix_times_vector(n_obs_shard, n_re, re_design_x, re_design_j, re_design_p, re_coefs);
        }

        // Likelihood loop
        for (i in 1:n_obs_shard) {
            // Adjust for batch-effect size factor
            log_mu[i] += sf[batch_id_shard[i]];

            // Inverse overdispersion is looked up per feature.
            log_lik += neg_binomial_2_log_lpmf(counts_shard[i] | log_mu[i], iodisp[feat_id_shard[i]]);
        }

        // map_rect requires a vector return, so wrap the scalar.
        return [log_lik]';
    }
}
// Inputs: global dimensions plus one packed row of real / integer data per
// shard, in the layout that shard_ll unpacks.
data {
    // Dimensions ----
    int<lower=0> n_int;       // # of distinct intercept parameters
    int<lower=0> n_fe;        // # of distinct fixed-effects parameters
    int<lower=0> n_re;        // # of distinct random-effects parameters
    int<lower=0> n_re_terms;  // # of random-effects terms in design formula
    int<lower=1> n_batches;   // # of batches
    int<lower=1> n_feats;     // # of original features
    int<lower=1> n_threads;   // # of threads to use

    // Index Variables ----
    // For each random-effects coefficient, the index of the term (and hence
    // the re_sigma entry) it belongs to.
    array[n_re] int<lower=1, upper=n_re_terms> re_id;

    // Thread-specific Data ----
    // Row widths of x_r / x_i; every shard's row has this same length.
    int<lower=1> reals_per_thread;
    int<lower=1> ints_per_thread;
    array[n_threads, reals_per_thread] real x_r;
    array[n_threads, ints_per_thread] int x_i;
}
transformed data {
    // map_rect takes one shard-specific parameter vector per shard; this model
    // has none, so each shard gets an empty vector.
    array[n_threads] vector[0] theta;
}
parameters {
    // Shrinkage ----
    real<lower=0> tau;              // global scale, shared by fe_coefs and re_coefs
    vector<lower=0>[n_fe] lambda;   // per-coefficient local scales for fe_coefs

    // Feature Expression ----
    vector[n_int] int_coefs;
    vector[n_fe] fe_coefs;
    vector[n_re] z_re; // Raw random-effects coefficients
    vector<lower=0>[n_re_terms] re_sigma; // one scale per random-effects term

    // Size Factors ----
    simplex[n_batches] raw_sf;

    // Feature-level Dispersion ----
    vector<lower=0>[n_feats] iodisp;  // inverse overdispersion per feature
    real iodisp_mu;                   // lognormal location for iodisp
    real<lower=0> iodisp_sigma;       // lognormal scale for iodisp
    }

transformed parameters {
    // Size Factors ----
    // Log size factors; since raw_sf is a simplex, exp(sf) averages to 1
    // across batches.
    vector[n_batches] sf = log(raw_sf) + log(n_batches);

    // Feature Expression ----
    // Scale the raw (standard-normal) random effects by their term-level
    // sigma and the global tau.
    vector[n_re] re_coefs;
    for (i in 1:n_re) {
        re_coefs[i] = z_re[i] * (re_sigma[re_id[i]] * tau);
    }

    // Pack all shared parameters into one vector for map_rect. The order here
    // must match the order in which shard_ll slices phi.
    vector[n_int + n_fe + n_re + n_batches + n_feats] phi =
        append_row(int_coefs,
            append_row(fe_coefs,
                append_row(re_coefs,
                    append_row(sf, iodisp))));
}
model {
    // Priors ----
    // NOTE(review): tau, re_sigma, iodisp_mu and iodisp_sigma have no sampling
    // statement in this block, so they only get the implicit flat prior over
    // their declared support - confirm this is intended.
    z_re ~ std_normal();

    // Horseshoe prior for fixed effects
    // lambda is declared with lower=0, so this is a half-Cauchy.
    lambda ~ cauchy(0, 1);
    fe_coefs ~ normal(0, lambda * tau);

    // Inverse overdispersion regularization
    iodisp ~ lognormal(iodisp_mu, iodisp_sigma);

    // Likelihood ----
    // Each shard returns a 1-element vector with its log-likelihood; the
    // shard results are summed into the target.
    target += sum(map_rect(shard_ll, phi, theta, x_r, x_i));
}

This is a lot so I’ll highlight the important component:

        // Compute log_mu from design matrices
        // Debug output: shard dimensions and the sizes of the CSR inputs for
        // the intercept design matrix. Note that the final print reuses the
        // "n_int: " label but reports size(int_coefs).
        // These prints only check array sizes; they do not inspect the index
        // values held in int_design_j / int_design_p.
        print("n_obs: ", n_obs_shard);
        print("n_int: ", n_int);
        print("x: ", size(int_design_x));
        print("j: ", size(int_design_j));
        print("p: ", size(int_design_p));
        print("n_int: ", size(int_coefs));
        log_mu = csr_matrix_times_vector(n_obs_shard, n_int, int_design_x, int_design_j, int_design_p, int_coefs);
        if (n_fe > 0) {
            log_mu += csr_matrix_times_vector(n_obs_shard, n_fe, fe_design_x, fe_design_j, fe_design_p, fe_coefs);
        }
        if (n_re > 0) {
            log_mu += csr_matrix_times_vector(n_obs_shard, n_re, re_design_x, re_design_j, re_design_p, re_coefs);
        }

When I run this model with my test data I get this error:

n_obs: 142
n_int: 500
x: 142
j: 142
p: 143
n_int: 500

Chain 1 Unrecoverable error evaluating the log probability at the initial value.
Chain 1 Exception: Exception: csr_matrix_times_vector: u/z (141) and v (142) must match in size (in '/var/folders/21/t9hvtpkd72n1pjw4nvd3g0g80000gn/T/RtmpLysClz/model-53764869b2b8.stan', line 103, column 8 to column 114) (in '/var/folders/21/t9hvtpkd72n1pjw4nvd3g0g80000gn/T/RtmpLysClz/model-53764869b2b8.stan', line 208, column 4 to column 60)
Warning: Fitting finished unexpectedly! Use the $output() method for more information.

So I presume u/z is the vector of nonzero values of the sparse matrix and v are the column indices, as these both should have the same length, while the row start indices should have length n_rows + 1.

As you can see the print statements confirm the sizes all match, so I’m confused as to how I’m getting this error.

One possibility I’ve thought of is that, since each thread works on only a single block of data here, there could be a thread with a smaller amount of data that’s causing an issue (perhaps all threads are required to have all objects with the same shape), so here is the metadata for the matrix I pass to Stan:

> ints[,1:9]
     [,1] [,2] [,3] [,4] [,5] [,6] [,7] [,8] [,9]
[1,]  142  500    0 1000  142    0  142    2  500
[2,]  143  500    0 1000  143    0  143    2  500
[3,]  143  500    0 1000  143    0  143    2  500
[4,]  143  500    0 1000  143    0  143    2  500
[5,]  143  500    0 1000  143    0  143    2  500
[6,]  143  500    0 1000  143    0  143    2  500
[7,]  143  500    0 1000  143    0  143    2  500

The first block has n_obs_shard = n_nz_int_shard = n_nz_re_shard = 142, while all others have 143.

But this doesn’t explain why the exception states u/z has length 141. Additionally, when I change n_threads so that it divides the number of observations evenly (such that all objects are the same size in all threads), I still get the same type of error, where the length of u/z is 1 less than what it should be.

Here is a snippet of the R code I use to generate the data I pass to Stan:

    # Concatenate relevant data for Stan
    # reals: [ |- int_design@x -| , |- fe_design@x -|, |- re_design@x -|]
    # ints:  [
    #   n_obs, n_ints, n_fe, n_re, n_nz_int, n_nz_fe, n_nz_re, n_batches, n_feats,
    #   |- batch_id -|, |- feat_id -|,
    #   |- counts -|,
    #   |- int_design@j -|, |- int_design@p -|,
    #   |- fe_design@j -|, |- fe_design@p -|,
    #   |- re_design@j -|, |- re_design@p -|
    # ]
    #
    # The @j / @p slots of R sparse matrices are 0-based, whereas Stan's
    # csr_matrix_times_vector() expects 1-based column indices and row-start
    # offsets, so the design-matrix index vectors are shifted by 1L below.
    # NOTE(review): assumes each design partition's "ints" holds only
    # c(@j, @p), as in the layout above - confirm against the partitioning code.
    data_partition <- lapply(seq_len(n_threads), function(i) {
        chunk_idxs <- chunk_idxs_list[[i]]

        # Nonzero values of the three design matrices, in unpack order
        reals <- c(
            int_design_partition[[i]][["reals"]],
            fe_design_partition[[i]][["reals"]],
            re_design_partition[[i]][["reals"]]
        )

        ints <- c(
            # Metadata header (9 entries)
            length(chunk_idxs),
            stan_data[["n_int"]],
            stan_data[["n_fe"]],
            stan_data[["n_re"]],
            int_design_partition[[i]][["n_nz_int"]],
            fe_design_partition[[i]][["n_nz_fe"]],
            re_design_partition[[i]][["n_nz_re"]],
            stan_data[["n_batches"]],
            stan_data[["n_feats"]],
            # Per-observation integers
            stan_data[["batch_id"]][chunk_idxs],
            stan_data[["feat_id"]][chunk_idxs],
            counts_partition[[i]][["ints"]],
            # Sparse index vectors, converted from 0-based to 1-based
            int_design_partition[[i]][["ints"]] + 1L,
            fe_design_partition[[i]][["ints"]] + 1L,
            re_design_partition[[i]][["ints"]] + 1L
        )

        list(reals = reals, ints = ints)
    })

    # Compute required padding
    # 2 x n_threads matrix: row 1 = # of reals, row 2 = # of ints per shard
    lengths <- vapply(data_partition, function(data_chunk) {
        c(
            length(data_chunk[["reals"]]),
            length(data_chunk[["ints"]])
        )
    }, integer(2))

    # Construct empty matrices for thread-specific real and integer data
    # (shorter shards stay zero-padded on the right)
    reals <- matrix(0, n_threads, max(lengths[1, ]))
    ints <- matrix(0, n_threads, max(lengths[2, ]))

    # Fill in matrix
    for (i in seq_len(n_threads)) {
        reals[i, seq_len(lengths[1, i])] <- data_partition[[i]][["reals"]]
        ints[i, seq_len(lengths[2, i])] <- data_partition[[i]][["ints"]]
    }

    stan_data[["reals_per_thread"]] <- max(lengths[1, ])
    stan_data[["ints_per_thread"]] <- max(lengths[2, ])
    stan_data[["x_r"]] <- reals
    stan_data[["x_i"]] <- ints

I’m not sure why this error is showing up.

Never mind! I forgot that mat@j and mat@p in R are 0-indexed, whereas Stan’s csr_matrix_times_vector expects 1-based indices; adding 1L to both fixed all the issues.