Reduce sum including slicing on transformed parameters

Hi Staniacs, The following Stan model implements a somewhat more complicated generalization of a logistic regression. I’m trying to use reduce sum (partial sum) to slice on both the bernoulli likelihood and it’s mean (probability) vector, named ptilde. So, I have included the generation of the mean vector, ptilde, inside the partial sum function. If I compile with multithreading turned off the model estimates well, though there are some early nan warnings about the mean vector ptilde that resolve. Once I compile with multithreading turned on, however, it won’t sample because it estimates nan for the mean vector ptilde; in particular, I receive the following error message:

Chain 1 Exception: Exception: bernoulli_lpmf: Probability parameter[1] is nan, but must be in the interval [0, 1] (in 'C:/Users/SAVITS~1/AppData/Local/Temp/RtmpG4hURT/model-4f6c533d56d3.stan', line 83, column 4 to line 84, column 65) (in 'C:/Users/SAVITS~1/AppData/Local/Temp/RtmpG4hURT/model-4f6c533d56d3.stan', line 197, column 1 to line 199, column 96)

Here is the Stan model. I know it’s rather long and difficult to parse (pun intended), so I;d appreciate your expert comments from focusing on my partial sum user-defined function. I’m using cmdstanr and the current production version of cmdstan.

functions{
  
  vector build_b_spline(vector t, vector ext_knots, int ind, int order);
  matrix build_mu(int start, int end, int n, int K_sp, 
                  matrix X, matrix[] G, matrix beta_x, matrix[] beta_w);
  vector build_b_spline(vector t, vector ext_knots, int ind, int order) {
    // INPUTS:
    //    t:          the points at which the b_spline is calculated
    //    ext_knots:  the set of extended knots
    //    ind:        the index of the b_spline
    //    order:      the order of the b-spline
    vector[num_elements(t)] b_spline;
    vector[num_elements(t)] w1 = rep_vector(0, num_elements(t));
    vector[num_elements(t)] w2 = rep_vector(0, num_elements(t));
    if (order==1)
      for (i in 1:num_elements(t)) // B-splines of order 1 are piece-wise constant
        b_spline[i] = (ext_knots[ind] <= t[i]) && (t[i] < ext_knots[ind+1]);
    else {
      if (ext_knots[ind] != ext_knots[ind+order-1])
        w1 = (to_vector(t) - rep_vector(ext_knots[ind], num_elements(t))) /
             (ext_knots[ind+order-1] - ext_knots[ind]);
      if (ext_knots[ind+1] != ext_knots[ind+order])
        w2 = 1 - (to_vector(t) - rep_vector(ext_knots[ind+1], num_elements(t))) /
                 (ext_knots[ind+order] - ext_knots[ind+1]);
      // Calculating the B-spline recursively as linear interpolation of two lower-order splines
      b_spline = w1 .* build_b_spline(t, ext_knots, ind, order-1) +
                 w2 .* build_b_spline(t, ext_knots, ind+1, order-1);
    }
    return b_spline;
  }
  
  matrix build_mu(int start, int end, int n, int K_sp, 
     matrix X, matrix[] G, matrix beta_x, matrix[] beta_w){
       
     matrix[n,2] mu_x;
    
     for( arm in 1:2 )
     {
       mu_x[start:end,arm]     = X[start:end,] * to_vector(beta_x[,arm]); /* n x l for each arm */
       // spline term
       for( k in 1:K_sp )
       {
         mu_x[start:end,arm]  += to_vector(beta_w[arm][,k]' * G[k][,start:end]); /* n x 1 */
       } /* end loop k over K predictors */
      
     }/* end loop arm over convenience and reference sample arms */
    
    // enforce sum to n constraint for p[,2]
    {
      vector[n] pi_r;
      real max_pir;
      pi_r[start:end]                 = inv_logit(mu_x[start:end,2]);
      pi_r[start:end]                 = (pi_r[start:end] / sum(pi_r)) * n;
      max_pir                         = max( pi_r );
      if( max_pir > 1 )
        pi_r  = pi_r / (max_pir + 1e-5);
      mu_x[start:end,2]    = logit( pi_r[start:end] );
    }
    
    return mu_x;
    
  } /* end function build_mu */
  
  real partial_sum(int[] s, 
                   int start, int end, real[] logit_pw, int K_sp, int n_c, int n,
                   matrix X, matrix[] G, matrix beta_x, matrix[] beta_w, 
                   real logit_h, real logit_q, real logit_t, real phi_w) {
     matrix[n,2] p;
     vector[n] p_tilde; // this pseudoprobability must be in [0,1]
     real ratio_h           = expm1(-logit_h) + 1; // (1-h)/h
     real qdt               = inv_logit(logit_q) / inv_logit(logit_t);
     real fred;
     matrix[n,2] mu_x       = build_mu(start, end, n, K_sp, X, G, beta_x, beta_w);
     
    for( arm in 1:2 )
    {
      p[start:end,arm]        = inv_logit( mu_x[start:end,arm] );
    }
  
    p_tilde[start:end]           = p[start:end,1] ./ ( p[start:end,1] + (qdt * ratio_h)*p[start:end,2] );
         
    fred =  bernoulli_lpmf(s | p_tilde[start:end]) 
            + normal_lpdf( logit_pw | mu_x[(n_c+1):n,2], phi_w );     
    return fred; 
  }
  

} /* end function block */

data{
    int<lower=1> n_c; // observed convenience (non-probability) sample size
	  int<lower=1> n_r; // observed reference (probability) sample size
	  int<lower=1> N; // estimate of population size underlying reference and convenience samples
	  int<lower=1> n; // total sample size, n = n_c + n_r
	  int<lower=1> K; // number of fixed effects
	  int<lower=0> K_sp; // number of predictors to model under a spline basis
	  int<lower=1> num_knots;
    int<lower=1> spline_degree;
    matrix[num_knots,K_sp] knots;
    real<lower=0> weights[n_r]; // sampling weights for n_r observed reference sample units
    matrix[n_c, K] X_c; //  *All* predictors - continuous and categorical -  for the convenience units
    matrix[n_r, K] X_r; // *All* predictors - continuous and categorical -  for the reference units
    matrix[n_c, K_sp] Xsp_c; //  *Continuous* predictors under a spline basis for convenience units
    matrix[n_r, K_sp] Xsp_r; // *Continuous* predictors under a spline basis for convenience units
    int<lower=1> n_df;
} /* end data block */

transformed data{
  // create indicator variable of membership in convenience or reference samples
  // indicator of observation membership in the convenience sample
  int grainsize                     = 1;
  real logit_pw[n_r]                = logit(inv(weights));
  int<lower=0, upper = 1> s[n]      = to_array_1d( append_array(rep_array(1,n_c),rep_array(0,n_r)) ); 
  matrix[n,K] X                     = append_row( X_c,X_r );
  matrix[n,K_sp] X_sp               = append_row( Xsp_c,Xsp_r );
  /* formulate spline basis matrix, B */
  int num_basis = num_knots + spline_degree - 1; // total number of B-splines
  matrix[spline_degree + num_knots,K_sp] ext_knots_temp;
  matrix[2*spline_degree + num_knots,K_sp] ext_knots;
  matrix[num_basis,n] G[K_sp]; /* basis for model on p_c */
  for(k in 1:K_sp)
  {
     ext_knots_temp[,k] = append_row(rep_vector(knots[1,k], spline_degree), knots[,k]);
    // set of extended knots
     ext_knots[,k] = append_row(ext_knots_temp[,k], rep_vector(knots[num_knots,k], spline_degree));
     for (ind in 1:num_basis)
     {
        G[k][ind,] = to_row_vector(build_b_spline(X_sp[,k], ext_knots[,k], ind, (spline_degree + 1)));
     }
     G[k][num_knots + spline_degree - 1, n] = 1;
  }
     
} /* end transformed data block */

parameters {
  real logit_h; // h = Pr(s_i = 1)
  real logit_t;
  real logit_q;
  real<lower=0> sigma_h; // scale parameter for prior on logit_h
  real<lower=0> sigma_t;
  real<lower=0> sigma_q;
  matrix<lower=0>[K,2] sigma_betax; /* standard deviations of K x 2, beta_x */
                                    /* first column is convenience sample, "c", and second column is "r" */

  matrix[K,2] betaraw_x; /* fixed effects coefficients - first colum for p_c; second column for p_r  */
  // spline coefficients
  vector<lower=0>[2] sigma_global; /* set this equal to 1 if having estimation difficulties */
  matrix<lower=0>[2,K_sp] sigma_w;
  matrix[num_basis,K_sp] betaraw_w[2];  // vector of B-spline regression coefficients for each predictor, k
                                        // and 2 sample arms
  real<lower=0> phi_w; /* scale parameter in model for -1og(weights) */                                     
} /* end parameters block */

transformed parameters {
  matrix[K,2] beta_x;
  matrix[num_basis,K_sp] beta_w[2];
  
  // for scale parameters for interaction effects from those for main effects to which they link
  for( arm in 1:2 )
  {
     beta_x[,arm]   = betaraw_x[,arm] .* sigma_betax[,arm]; /* Non-central parameterization */
  }/* end loop arm over convenience and reference sample arms */
  
  // spline regression coefficients
  for(arm in 1:2)
  {
    for( k in 1:K_sp ) 
    {
      beta_w[arm][,k]   = cumulative_sum(betaraw_w[arm][,k]);
      beta_w[arm][,k]  *= sigma_w[arm,k] * sigma_global[arm];
    } /* end loop k over K predictors */
  }
  
} /* end transformed parameters block */

 model {
  // L_beta          ~ lkj_corr_cholesky(6);
  to_vector(sigma_betax) ~ student_t(n_df,0,3);
  to_vector(sigma_w)     ~ student_t(n_df,0,1);
  sigma_global           ~ student_t(1,0,1);
  sigma_h                ~ student_t(n_df,0,1);
  sigma_q                ~ student_t(n_df,0,1);
  sigma_t                ~ student_t(n_df,0,1);
  phi_w                  ~ student_t(n_df,0,1);

  to_vector(betaraw_x)   ~ std_normal();
  for(arm in 1:2)
    to_vector(betaraw_w[arm])   ~ std_normal();

  logit_h   ~ student_t( n_df, logit(n_c/(n +0.0)), sigma_h );
  logit_t   ~ student_t( n_df, logit(n_r/(N +0.0)), sigma_t );
  logit_q   ~ student_t( n_df, logit(n_c/(N +0.0)), sigma_q );
  
  /* Model likelihood for y, logit_pw */
 // Sum terms 1 to n in the likelihood 
 target            += reduce_sum(partial_sum, s, grainsize, 
                                              logit_pw, K_sp, n_c, n, X, G, 
                                              beta_x, beta_w, logit_h, logit_q, logit_t, phi_w); 

} /* end model block */  

generated quantities{
  matrix[n,2] p;
  matrix[n,2] mu_x;
  
  for( arm in 1:2 )
     {
       mu_x[,arm]     = X[,] * to_vector(beta_x[,arm]); /* n x l for each arm */
       // spline term
       for( k in 1:K_sp )
       {
         mu_x[,arm]  += to_vector(beta_w[arm][,k]' * G[k][,1:n]); /* n x 1 */
       } /* end loop k over K predictors */
      
     }/* end loop arm over convenience and reference sample arms */
  
    // enforce sum to n constraint for p[,2]
    {
      vector[n] pi_r       = inv_logit(mu_x[,2]);
      real max_pir;
      pi_r                 = (pi_r / sum(pi_r)) * n;
      max_pir              = max( pi_r );
      if( max_pir > 1 )
        pi_r  = pi_r / (max_pir + 1e-5);
      mu_x[,2]             = logit( pi_r );
    }
  
    for( arm in 1:2 )
    {
      p[,arm]        = inv_logit( mu_x[,arm] );
    }
  
  // smoothed sampling weights for convenience and reference units
  vector[n] weights_smooth_c = inv(p[,1]);
  vector[n] weights_smooth_r = inv(p[,2]);
  // inclusion probabilities in convenience and reference units for convenience units
  // use for soft thresholding
  vector[n_c] pi_c                  = p[1:n_c,1];
  vector[n_c] pi_r_c                = p[1:n_c,2];
  // normalized weights
  weights_smooth_c                  *= ((n_c+0.0)/(n+0.0)) * (sum(weights_smooth_r)/sum(weights_smooth_c));
  weights_smooth_r                  *= ((n_r+0.0)/(n+0.0));
} /* end generated quantities block */


I have a sum-to-a-constant constraint on p[1:n,2] that, in turn, is used to compute the Bernoulli mean, p_tilde. Perhaps this standardization induces a dependence in the unit means that prevents within chain parallelization. Conditioned on the sum, the unit means are still independent.

One suggestion, both for efficiency and for ease of troubleshooting is that within partial_sum, right now you are building objects like mu_x, p, and p_tilde to be n elements long, and then you are only writing a subset of the elements from start:end. It would be preferable to just give each of these objects 1 + end - start elements (or rows) and then drop the ubiquitous start:end indexing except where needed to interact with data or parameters defined outside of partial_sum (also rewrite build_mu so that its output is of length 1 + end - start instead of length n).

In addition to some efficiency gains, this will ensure that you aren’t accidentally grabbing the wrong, unwritten elements of these objects anywhere in your code. For example, while I’m not 100% sure this is the problem, I wonder if when you do max( pi_r ), given that pi_r is mostly empty, whether max might be returning something that isn’t a number. Even if this isn’t exactly your problem, my guess is that your problem is something similar; somewhere you are accidentally using some of these superfluous empty matrix/array elements that you are carrying around.

1 Like

Thank you for the helpful suggestions @jsocolar. They worked and I attach the script the final script that successfully ran, which includes the sum-to-constant constraint for pi_r. I carefully implemented the sum-to-constant so that pi_r would be full before computing max(pi_r) or sum(pi_r).

This final version slices on all of parameter vectors mu_x[,arm], p[,arm], and p_tilde, in addition to data vectors s, logit_pw, which should produce a maximally efficient implentation; however, I am finding that this parallelized model always takes longer than the non-threaded/non-parallelized version, no matter how many threads I use. joaneelleouet also reports this same issue. This is particularly disappointing because I have implemented slicing on transformed parameters that form the mean of the data likelihoods such that it should be maximally scalable under multithreading.

I’d appreciate any further insights on why the parallelized model is always slower than the non-parallelized. Please see my final parallelized script, below:

functions{
  
  vector build_b_spline(vector t, vector ext_knots, int ind, int order);
  row_vector build_muxi(int i, int K_sp, int num_basis, row_vector x_i, vector[] g_i, 
                        matrix beta_x, matrix[] beta_w);
  vector build_b_spline(vector t, vector ext_knots, int ind, int order) {
    // INPUTS:
    //    t:          the points at which the b_spline is calculated
    //    ext_knots:  the set of extended knots
    //    ind:        the index of the b_spline
    //    order:      the order of the b-spline
    vector[num_elements(t)] b_spline;
    vector[num_elements(t)] w1 = rep_vector(0, num_elements(t));
    vector[num_elements(t)] w2 = rep_vector(0, num_elements(t));
    if (order==1)
      for (i in 1:num_elements(t)) // B-splines of order 1 are piece-wise constant
        b_spline[i] = (ext_knots[ind] <= t[i]) && (t[i] < ext_knots[ind+1]);
    else {
      if (ext_knots[ind] != ext_knots[ind+order-1])
        w1 = (to_vector(t) - rep_vector(ext_knots[ind], num_elements(t))) /
             (ext_knots[ind+order-1] - ext_knots[ind]);
      if (ext_knots[ind+1] != ext_knots[ind+order])
        w2 = 1 - (to_vector(t) - rep_vector(ext_knots[ind+1], num_elements(t))) /
                 (ext_knots[ind+order] - ext_knots[ind+1]);
      // Calculating the B-spline recursively as linear interpolation of two lower-order splines
      b_spline = w1 .* build_b_spline(t, ext_knots, ind, order-1) +
                 w2 .* build_b_spline(t, ext_knots, ind+1, order-1);
    }
    return b_spline;
  }
  
  row_vector build_muxi(int i, int K_sp, int num_basis,
     row_vector x_i, vector[] g_i, matrix beta_x, matrix[] beta_w){
       
     row_vector[2] mu_xi;
    
     for( arm in 1:2 )
     {
       mu_xi[arm]     = dot_product(x_i,beta_x[,arm]); /* scalar */
       // spline term
       for( k in 1:K_sp )
       {
         mu_xi[arm]   += dot_product(beta_w[arm][1:num_basis,k], g_i[k][1:num_basis]); /* scalar */
       } /* end loop k over K predictors */
      
     }/* end loop arm over convenience and reference sample arms */
    
    return mu_xi;
    
  } /* end function build_mu */
  
  real partial_sum(int[] s, 
                   int start, int end, real[] logit_pw, int K_sp, int n_c, int n,
                   int num_basis, matrix X, matrix[] G, matrix beta_x, matrix[] beta_w, 
                   real logit_h, real logit_q, real logit_t, real phi_w) {
     int N = end - start + 1;
     matrix[N,2] mu_x;
     matrix[N,2] p;
     vector[N] p_tilde; // this pseudoprobability must be in [0,1]
     real ratio_h           = expm1(-logit_h) + 1; // (1-h)/h
     real qdt               = inv_logit(logit_q) / inv_logit(logit_t);
     real fred              = 0;
     
     // memo: slicing on all of mu_x[li,arm], p[li,arm], p_tilde[li] for li in 1:(end-start+1)
     //       where p_tilde is the mean vector for binary data vector, s, and mu_x[,2]
     //       is the mean vector for data vector logit_pw.
     //       Also slicing data vectors s and logit_pw in their respective
     //       log-likelihood contributions.
     
     for( li in 1:N)
     {
       vector[num_basis] g_i[K_sp]       = G[1:K_sp,1:num_basis,li];
       mu_x[li,]                         = build_muxi(li, K_sp, num_basis, X[li,], g_i, beta_x, beta_w);
     }
               
    // enforce sum to n constraint for p[,2]
    { /* begin local block to normalize n x 1 p[,2] probability parameters to sum to n */
      vector[N] pi_r;
      real max_pir;
      real sum_pir;
      for( li in 1:N )
      {
        pi_r[li]            = inv_logit( mu_x[li,2] );
      } /* endloop i over all units */
      sum_pir               = sum(pi_r);
      
      for( li in 1:N )
      {
        pi_r[li]            = (pi_r[li] / sum_pir) * n;
      }/* endloop i over all units */
      max_pir                         = max( pi_r );
      
      for( li in 1:N )
      {
        if( max_pir > 1)
          pi_r[li]         = pi_r[li] / (max_pir + 1e-5);
      }/* endloop i over all units */
        
    } /* end local block to normalize n x 1 p[,2] probability parameters to sum to n */ 
    
    for( arm in 1:2 )
    {
      for( li in 1:N )
      {
        p[li,arm]       = inv_logit( mu_x[li,arm] );
      }
      
    } /* end loop i over all units */
    
    for( li in 1:N ) // N = end - start + 1
    {
      p_tilde[li]            = p[li,1] ./ ( p[li,1] + (qdt * ratio_h)*p[li,2] );
      fred                  += bernoulli_lpmf(s[li] | p_tilde[li]);
    } /* end loop i over all units */
    
    for( li in n_c+1:N )
    {
      // int i                 = li + start - 1;
      fred                  += normal_lpdf( logit_pw[li-n_c] | mu_x[li,2], phi_w );
    } /* end loop i ove n_r units */
  
  return fred;
  }/* end function partial_sum() */

} /* end function block */

data{
    int<lower=1> n_c; // observed convenience (non-probability) sample size
	  int<lower=1> n_r; // observed reference (probability) sample size
	  int<lower=1> N; // estimate of population size underlying reference and convenience samples
	  int<lower=1> n; // total sample size, n = n_c + n_r
	  int<lower=1> K; // number of fixed effects
	  int<lower=0> K_sp; // number of predictors to model under a spline basis
	  int<lower=1> num_knots;
    int<lower=1> spline_degree;
    matrix[num_knots,K_sp] knots;
    real<lower=0> weights[n_r]; // sampling weights for n_r observed reference sample units
    matrix[n_c, K] X_c; //  *All* predictors - continuous and categorical -  for the convenience units
    matrix[n_r, K] X_r; // *All* predictors - continuous and categorical -  for the reference units
    matrix[n_c, K_sp] Xsp_c; //  *Continuous* predictors under a spline basis for convenience units
    matrix[n_r, K_sp] Xsp_r; // *Continuous* predictors under a spline basis for convenience units
    int<lower=1> n_df;
} /* end data block */

transformed data{
  // create indicator variable of membership in convenience or reference samples
  // indicator of observation membership in the convenience sample
  int grainsize                     = 1;
  real logit_pw[n_r]                = logit(inv(weights));
  int<lower=0, upper = 1> s[n]      = to_array_1d( append_array(rep_array(1,n_c),rep_array(0,n_r)) ); 
  matrix[n,K] X                     = append_row( X_c,X_r );
  matrix[n,K_sp] X_sp               = append_row( Xsp_c,Xsp_r );
  /* formulate spline basis matrix, B */
  int num_basis = num_knots + spline_degree - 1; // total number of B-splines
  matrix[spline_degree + num_knots,K_sp] ext_knots_temp;
  matrix[2*spline_degree + num_knots,K_sp] ext_knots;
  matrix[num_basis,n] G[K_sp]; /* basis for model on p_c */
  for(k in 1:K_sp)
  {
     ext_knots_temp[,k] = append_row(rep_vector(knots[1,k], spline_degree), knots[,k]);
    // set of extended knots
     ext_knots[,k] = append_row(ext_knots_temp[,k], rep_vector(knots[num_knots,k], spline_degree));
     for (ind in 1:num_basis)
     {
        G[k][ind,] = to_row_vector(build_b_spline(X_sp[,k], ext_knots[,k], ind, (spline_degree + 1)));
     }
     G[k][num_knots + spline_degree - 1, n] = 1;
  }
     
} /* end transformed data block */

parameters {
  real logit_h; // h = Pr(s_i = 1)
  real logit_t;
  real logit_q;
  real<lower=0> sigma_h; // scale parameter for prior on logit_h
  real<lower=0> sigma_t;
  real<lower=0> sigma_q;
  matrix<lower=0>[K,2] sigma_betax; /* standard deviations of K x 2, beta_x */
                                    /* first column is convenience sample, "c", and second column is "r" */

  matrix[K,2] betaraw_x; /* fixed effects coefficients - first colum for p_c; second column for p_r  */
  // spline coefficients
  vector<lower=0>[2] sigma_global; /* set this equal to 1 if having estimation difficulties */
  matrix<lower=0>[2,K_sp] sigma_w;
  matrix[num_basis,K_sp] betaraw_w[2];  // vector of B-spline regression coefficients for each predictor, k
                                        // and 2 sample arms
  real<lower=0> phi_w; /* scale parameter in model for -1og(weights) */                                     
} /* end parameters block */

transformed parameters {
  matrix[K,2] beta_x;
  matrix[num_basis,K_sp] beta_w[2];
  
  // for scale parameters for interaction effects from those for main effects to which they link
  for( arm in 1:2 )
  {
     beta_x[,arm]   = betaraw_x[,arm] .* sigma_betax[,arm]; /* Non-central parameterization */
  }/* end loop arm over convenience and reference sample arms */
  
  // spline regression coefficients
  for(arm in 1:2)
  {
    for( k in 1:K_sp ) 
    {
      beta_w[arm][,k]   = cumulative_sum(betaraw_w[arm][,k]);
      beta_w[arm][,k]  *= sigma_w[arm,k] * sigma_global[arm];
    } /* end loop k over K predictors */
  }
  
} /* end transformed parameters block */

 model {
  to_vector(sigma_betax) ~ student_t(n_df,0,3);
  to_vector(sigma_w)     ~ student_t(n_df,0,1);
  sigma_global           ~ student_t(1,0,1);
  sigma_h                ~ student_t(n_df,0,1);
  sigma_q                ~ student_t(n_df,0,1);
  sigma_t                ~ student_t(n_df,0,1);
  phi_w                  ~ student_t(n_df,0,1);

  to_vector(betaraw_x)   ~ std_normal();
  for(arm in 1:2)
    to_vector(betaraw_w[arm])   ~ std_normal();

  logit_h   ~ student_t( n_df, logit(n_c/(n +0.0)), sigma_h );
  logit_t   ~ student_t( n_df, logit(n_r/(N +0.0)), sigma_t );
  logit_q   ~ student_t( n_df, logit(n_c/(N +0.0)), sigma_q );
  
  /* Model likelihood for y, logit_pw */
 // Sum terms 1 to n in the likelihood 
 target            += reduce_sum(partial_sum, s, grainsize, 
                                              logit_pw, K_sp, n_c, n, num_basis, X, G, 
                                              beta_x, beta_w, logit_h, logit_q, logit_t, phi_w); 

} /* end model block */  

generated quantities{
  matrix[n,2] p;
  matrix[n,2] mu_x;
  
  for( arm in 1:2 )
     {
       mu_x[,arm]     = X[,] * to_vector(beta_x[,arm]); /* n x l for each arm */
       // spline term
       for( k in 1:K_sp )
       {
         mu_x[,arm]  += to_vector(beta_w[arm][,k]' * G[k][,1:n]); /* n x 1 */
       } /* end loop k over K predictors */
      
     }/* end loop arm over convenience and reference sample arms */
  
    // enforce sum to n constraint for p[,2]
    {
      vector[n] pi_r       = inv_logit(mu_x[,2]);
      real max_pir;
      pi_r                 = (pi_r / sum(pi_r)) * n;
      max_pir              = max( pi_r );
      if( max_pir > 1 )
        pi_r  = pi_r / (max_pir + 1e-5);
      mu_x[,2]             = logit( pi_r );
    }
  
    for( arm in 1:2 )
    {
      p[,arm]        = inv_logit( mu_x[,arm] );
    }
  
  // smoothed sampling weights for convenience and reference units
  vector[n] weights_smooth_c = inv(p[,1]);
  vector[n] weights_smooth_r = inv(p[,2]);
  // inclusion probabilities in convenience and reference units for convenience units
  // use for soft thresholding
  vector[n_c] pi_c                  = p[1:n_c,1];
  vector[n_c] pi_r_c                = p[1:n_c,2];
  // normalized weights
  weights_smooth_c                  *= ((n_c+0.0)/(n+0.0)) * (sum(weights_smooth_r)/sum(weights_smooth_c));
  weights_smooth_r                  *= ((n_r+0.0)/(n+0.0));
} /* end generated quantities block */