Kernel averaged covariates

I’m looking for any recommendations on speeding up the model below. The model fits well with good diagnostics, but it is quite slow, especially at the scale of data I’m hoping to use.

I realize the model is slower than it might otherwise be because of all the calculations done in the model block. I’ve tried to compensate by making them as efficient and as easily differentiable as possible.

Please also let me know if you think Stan is a poor fit for this kind of model.

functions {
    vector cwrgauss_kernel(matrix dist, matrix time_mat, real theta_one, real theta_two, vector ones){
        vector[rows(time_mat)] out;
        // 1.25331 is approximately sqrt(pi / 2)
        out = (exp( dist * inv_square(theta_one) ) * inv(theta_one) * ones) .* (1.25331 * erf(time_mat * inv(theta_two))) * ones;
        return out;
    }
}
data {
    int<lower=0> N; // number of subjects
    int<lower=0> J; // Number of examinations
    #include "prior_data_lmm.stan" // prior_data 
    int kernel_ix[q];
    int ef_freq[q];
    matrix[J,qt] spatial_mat[N]; 
    matrix<lower=0>[J,qt] time_mat[N]; // n,q,j element is the time exposure for the nth subject with the qth feature at jth time point
    matrix[J,p] Z[N];
    vector[J] Y[N]; // array of outcomes
}
transformed data{
    vector[10] ones = rep_vector(1.0,10);
    cov_matrix[J] I = diag_matrix(rep_vector(1.0,J));
}
parameters {
    real beta_naught;
    vector[p] beta_one;
    vector[q] beta_two;
    real<lower=0> thetas[2]; 
    real subj_int[N];
    real<lower=0> subj_sig;
    real<lower=0> pop_sig;
}
transformed parameters{
    cov_matrix[J] R;
    R = I * pop_sig;
}
model {
    #include "priors_stkap_lmm.stan"
    pop_sig ~ cauchy(0,25);
    subj_sig ~ cauchy(0,25);
    subj_int ~ normal(0,subj_sig);
    thetas ~ normal(rep_vector(sqrt(8)/2,2),rep_vector(.5,2));
    {
        vector[J] X_delta[N];
        vector[J] yhat[N];
        for(n in 1:N){
            X_delta[n] =  cwrgauss_kernel(spatial_mat[n], time_mat[n],thetas[1],thetas[2],ones);
            yhat[n] = beta_naught + Z[n] * beta_one  + X_delta[n] * beta_two[1] + subj_int[n];
        }
        Y ~ multi_normal(yhat,R);
    }
}

Don’t use multi_normal with a diagonal covariance matrix. Just use normal with standard deviation pop_sig, though you may have to rearrange things so that Y and yhat are long vectors.
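
A minimal sketch of that suggestion, with the mean cut down to an intercept plus the subject effect so the example stays short. The names Y_long and yhat_long are made up for illustration, and the kernel and covariate terms from the original model would be added into yhat_long inside the loop. It keeps the array syntax used in the question; recent Stan releases write the declaration as array[N] vector[J] Y instead.

```stan
data {
    int<lower=0> N;             // number of subjects
    int<lower=0> J;             // number of examinations
    vector[J] Y[N];             // outcomes, declared as in the original model
}
transformed data {
    vector[N * J] Y_long;       // hypothetical name: outcomes stacked once, up front
    for (n in 1:N) {
        Y_long[((n - 1) * J + 1):(n * J)] = Y[n];
    }
}
parameters {
    real beta_naught;
    real subj_int[N];
    real<lower=0> subj_sig;
    real<lower=0> pop_sig;      // used here as a residual standard deviation
}
model {
    vector[N * J] yhat_long;    // hypothetical name: stacked linear predictor
    for (n in 1:N) {
        // simplified mean; the kernel and covariate terms would be added here
        yhat_long[((n - 1) * J + 1):(n * J)] = rep_vector(beta_naught + subj_int[n], J);
    }
    pop_sig ~ cauchy(0, 25);
    subj_sig ~ cauchy(0, 25);
    subj_int ~ normal(0, subj_sig);
    // one vectorized univariate normal replaces multi_normal with a diagonal R
    Y_long ~ normal(yhat_long, pop_sig);
}
```

One thing to watch: in the original, R = I * pop_sig puts pop_sig on the diagonal of a covariance matrix, so there it acts as a variance. To keep exactly the same model, pass sqrt(pop_sig) as the scale. If pop_sig is used directly as above, it becomes a standard deviation and the cauchy(0,25) prior applies on that scale instead.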