Chains not mixing in MRP problem

I’m currently working on an MRP (multilevel regression and poststratification) problem using the framework from this paper by Yajuan Si. I know some of the people who work on MRP with Stan participate in this forum, so I’m hoping they can take a look at my code.

The main idea is that I want to estimate an outcome y at a subnational level (like county) where the full joint distribution of poststratification variables is not available; the above paper shows how to estimate the N-sizes using only marginal distributions, which are known. In order to make sure I understand how this model works, I’m applying it in a situation where I know the joint distribution of the variables at a higher level (here, by estimating y at the state level).

Let \mathbf{N}_{..} be a D-dimensional vector containing marginal N-sizes for each level of the poststratification variables (like age, sex, etc.) and \mathbf{N} be a J-dimensional vector containing N-sizes for every poststratification cell. In the general problem, \mathbf{N}_{..} is known while \mathbf{N} needs to be estimated.

My full model is:

\mathbf{N}_{..} \sim \textrm{Poisson}(\mathbf{L}\mathbf{N}) \\
n_j \sim \textrm{Poisson}\left(\dfrac{n N_j p_j}{\sum_{k} N_k p_k}\right) \\
\textrm{logit}(p_j) = \mathbf{X}^{j}\boldsymbol\alpha \\
y_i \sim \textrm{Binomial}(\textrm{trials}_i, \theta_{j[i]}) \\
\textrm{logit}(\theta_j) = \beta_0 + \mathbf{X}^{j}\boldsymbol\beta

where n_j is the observed count for cell j in the survey, n is the total number of respondents in the survey, p_j is the probability that a member of cell j is included in the survey, and \mathbf{L} is a loading matrix relating the poststratification cells to the margins.
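To make the role of \mathbf{L} concrete, here is a minimal Python sketch with made-up numbers (2 age levels by 3 states, so J = 6 and D = 5). The level names, cell sizes, inclusion probabilities, and helper functions are all hypothetical and are not taken from the paper. It only illustrates how a 0/1 loading matrix maps cell sizes to margins and how the Poisson rate for n_j is formed.

```python
# Toy illustration with hypothetical numbers: 2 age levels x 3 states = 6 cells.
ages = ["18-44", "45+"]
states = ["A", "B", "C"]
cells = [(a, s) for s in states for a in ages]  # J = 6 cells

# One row per marginal level (D = 2 + 3 = 5), one column per cell.
margins = [("age", a) for a in ages] + [("state", s) for s in states]


def build_loading_matrix(margins, cells):
    """Return a D x J 0/1 matrix: entry is 1 if the cell belongs to the margin."""
    L = []
    for kind, level in margins:
        pos = 0 if kind == "age" else 1  # position of the variable in a cell tuple
        L.append([1 if cell[pos] == level else 0 for cell in cells])
    return L


def matvec(M, v):
    """Plain matrix-vector product, so that L N gives the marginal sizes."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def expected_sample_counts(n, N, p):
    """Poisson rates n * N_j * p_j / sum_k(N_k * p_k) for the observed cell counts."""
    w = [N_j * p_j for N_j, p_j in zip(N, p)]
    total = sum(w)
    return [n * w_j / total for w_j in w]


L = build_loading_matrix(margins, cells)
N = [100, 200, 300, 400, 500, 600]  # made-up cell sizes
N_marginal = matvec(L, N)           # age margins first, then state margins
rates = expected_sample_counts(21, N, [0.01] * 6)

print(N_marginal)
print(rates)
```

With equal inclusion probabilities the rates are simply proportional to the cell sizes, and they always sum to n by construction.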

I’m coding this model with the following Stan code. In this simple example, my poststratification variables are age and state. Because I actually do know the true N_j when getting state estimates, I put priors on the N_j reflecting this knowledge. Again, I want to make sure I’m coding the model correctly before I use it for the full county-level model, where this information will be truly missing and weaker priors will be put on the N_j.

data {
  int<lower=0> D; // number of marginals
  int<lower=0> J; // number of total poststratification cells
  int<lower=0> J_obs; // number of observed cells in the survey
  
  matrix[D, J] L; // loading matrix
  array[D] int<lower=0> N_marginal;
  
  array[J_obs] int<lower=0> trials; // size of binomial trials
  array[J_obs] int<lower=0> y_obs; // number of successes
  array[J] int obs_cell_size; // total of responses in sample cells, 0 if not observed
  
  array[J_obs] int age_idx; //index of age levels in sample
  array[J] int age_idx_p; // index of age levels in full joint distribution of ps variables
  array[J_obs] int state_idx;
  array[J] int state_idx_p;
  
  int<lower=1> full_age; // total number of age levels
  int<lower=1> full_state;
  
  int<lower=1> ps; // number of ps outcomes, for state ps = 51 (50 states + DC)
  matrix[ps, J] L_ps;
  
  array[J] int true_n; // to put priors on Nhat

}

transformed data{
  int n = sum(trials);
}

parameters {
  vector<lower=0>[J] Nhat;
  
  vector[full_age] raw_age_p;
  real<lower=0> sigma_age_p;
  vector[full_age] raw_age;
  real<lower=0> sigma_age;
  
  vector[full_state] raw_state_p;
  real<lower=0> sigma_state_p;
  vector[full_state] raw_state;
  real<lower=0> sigma_state;
  
  real b0;
}

transformed parameters{
  
  vector[full_age] age_p = raw_age_p*sigma_age_p;
  vector[full_age] age = raw_age*sigma_age;
  vector[full_state] state_p = raw_state_p*sigma_state_p;
  vector[full_state] state = raw_state*sigma_state;
  
  vector[J] theta_p = inv_logit(b0 + age_p[age_idx_p] + state_p[state_idx_p]);
  vector[J_obs] theta_y = inv_logit(age[age_idx] + state[state_idx]);
}

model {
  
  b0 ~ normal(-10, 3); // any one person is unlikely to be in the survey, so keep theta_p small
  
  raw_age_p ~ std_normal();
  raw_age ~ std_normal();
  raw_state_p ~ std_normal();
  raw_state ~ std_normal();
  
  for(i in 1:J){
    if(true_n[i] == 0){
      Nhat[i] ~ std_normal(); // none should be 0, but if they are missing from the Census, assumed close to 0
    } else{
      Nhat[i] ~ normal(true_n[i], 2);
    }
  }
  
  N_marginal ~ poisson(L*Nhat);
  vector[J] obs_cell_lambda = n*(theta_p .* Nhat) / sum(theta_p .* Nhat);
  obs_cell_size ~ poisson(obs_cell_lambda);
  y_obs ~ binomial(trials, theta_y);

}

generated quantities {
  vector[J] theta_y_pred = inv_logit(age[age_idx_p] + state[state_idx_p]);
  real y_mean = sum(theta_y_pred .* Nhat) / sum(Nhat);
  vector[ps] y_ps = L_ps*(theta_y_pred .* Nhat) ./ (L_ps*Nhat);
  vector[D] Nmarginal_rep = L*(theta_y_pred .* Nhat) ./ (L*Nhat);
}
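
For reference, the poststratification step in generated quantities (the `y_ps` line) is a population-weighted average of the cell-level probabilities within each group. The following Python sketch shows the same arithmetic on made-up numbers; it assumes the rows of `L_ps` are 0/1 group indicators, and the function name and values are hypothetical.

```python
def poststratify(theta, N, groups):
    """Population-weighted mean of theta within each group.

    theta[j] is the cell-level probability, N[j] the cell size, and
    groups[j] the label of the group (e.g. state) that cell j belongs to.
    """
    num, den = {}, {}
    for t, n_j, g in zip(theta, N, groups):
        num[g] = num.get(g, 0.0) + t * n_j  # sum of theta_j * N_j in the group
        den[g] = den.get(g, 0.0) + n_j      # sum of N_j in the group
    return {g: num[g] / den[g] for g in num}


# Made-up example: two states with two age cells each.
theta = [0.2, 0.4, 0.1, 0.3]
N = [100, 300, 200, 200]
groups = ["A", "A", "B", "B"]

estimates = poststratify(theta, N, groups)
print(estimates)
```

Here state A works out to (0.2 * 100 + 0.4 * 300) / 400 = 0.35 and state B to (0.1 * 200 + 0.3 * 200) / 400 = 0.2.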

The problem I’m having is that I’m getting absolutely no mixing in the chains.

[Image: Rplot]

I didn’t want to post large Census files to make a reproducible example, but I’m hoping there’s something obviously wrong with my Stan code that’s causing this behavior. What am I missing?

Your model is definitely outside my domain knowledge, but when I see chains not mixing I usually go back to the priors first. I see a lot of std_normal(), and I would ask: are those reasonable priors?
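
One quick way to sanity-check a prior on the logit scale is to look at what it implies on the probability scale. This is only a rough sketch in Python, not a diagnosis of the model above; the helper names are made up, and the plus or minus two standard deviation band is an arbitrary choice.

```python
import math


def inv_logit(x):
    """Logistic function: maps a real number to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def implied_range(mu, sigma, k=2.0):
    """Probabilities implied by a normal(mu, sigma) prior at mu - k*sigma, mu, mu + k*sigma."""
    return inv_logit(mu - k * sigma), inv_logit(mu), inv_logit(mu + k * sigma)


# Prior on the inclusion intercept from the posted model: b0 ~ normal(-10, 3).
low, mid, high = implied_range(-10.0, 3.0)
print(f"b0 prior: {low:.2e}, {mid:.2e}, {high:.2e}")

# A standard normal on the logit scale, for comparison.
low0, mid0, high0 = implied_range(0.0, 1.0)
print(f"std_normal: {low0:.3f}, {mid0:.3f}, {high0:.3f}")
```

If the implied probabilities are far from what is plausible for the survey, that is a hint the prior scale deserves a second look.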