Calculate a Bayesian mean using Stan with output matching the dynamic input tensor shape

I am working with pystan==3.6.0 and Python 3.8.11.
I have a list of numpy.ndarray tensors called tensor_values and a list of float weights called weights_norm_1. The tensor shape varies from run to run, but all tensors within one list share the same shape (the full range of possible shapes is listed at the end).
This is how I prepare the input data:

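    import numpy as np

    # Number of elements in ONE tensor (all tensors in the list share a shape).
    # np.size also handles 0-d arrays and plain floats; len() raises a TypeError
    # on shape () and only counts the first axis of a multi-dimensional tensor.
    max_value = int(np.size(tensor_values[0]))

    stan_data = {
        'N': len(tensor_values),
        'n': len(weights_norm_1),
        'P': max_value,  # Flattened size of one input tensor
        # reshape(-1) flattens every tensor and turns a 0-d array into shape (1,)
        'p': [np.asarray(t, dtype=float).reshape(-1) for t in tensor_values],
        'lambda': 1,
        'mu_c': 1,
        'weights': np.array(weights_norm_1),
    }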
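A minimal sketch of the packing step, assuming (as stated above) that every element of tensor_values has the same shape and may be a 0-d array or a plain float. The helper name pack_tensors is hypothetical. It stacks the flattened tensors into an (N, P) array, which lines up with the N-by-P `p` data in the Stan program, and keeps the shared shape so that a length-P result can be reshaped back later.

```python
import numpy as np


def pack_tensors(tensor_values):
    """Flatten same-shaped tensors into an (N, P) float array.

    Hypothetical helper (not part of pystan). Returns the packed array and
    the shared original shape so a flat result can be reshaped back.
    """
    arrays = [np.asarray(t, dtype=float) for t in tensor_values]
    shape = arrays[0].shape
    if any(a.shape != shape for a in arrays):
        raise ValueError("all tensors must share one shape")
    # reshape(-1) turns a 0-d array into shape (1,), so P is always >= 1
    packed = np.stack([a.reshape(-1) for a in arrays])
    return packed, shape


# Example: three 0-d tensors, then three (2, 3) tensors
scalars, s_shape = pack_tensors([np.array(1.0), 2.0, np.array(3.0)])
print(scalars.shape, s_shape)  # (3, 1) ()
blocks, b_shape = pack_tensors([np.ones((2, 3)) * k for k in range(3)])
print(blocks.shape, b_shape)  # (3, 6) (2, 3)
print(blocks[2].reshape(b_shape).shape)  # (2, 3)
```

With this layout, `packed` would go in as `'p'` and `packed.shape[1]` as `'P'`.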

My Stan code is as follows:

    data {
      int<lower=1> N;           // Number of input tensors
      int<lower=1> n;           // Length of weights_norm_1 (must equal N, see model block)
      int<lower=1> P;           // Number of elements in each flattened tensor
      array[N] vector[P] p;     // Flattened values of each input tensor
      real<lower=0> lambda;     // Standard deviation (scale) of each tensor's values
      real mu_c;                // Mean of the bias for each input tensor
      vector[n] weights;        // Weights for each input tensor, computed externally (weights_norm_1)
    }

    parameters {
      vector[P] p_m;            // Global mean parameter
      vector[N] b_c;            // Bias for each input tensor
    }

    model {
      vector[n] sigma_c;        // Standard deviations for input tensor biases

      // Priors
      p_m ~ normal(0, 1);       // Prior on the global mean parameters

      // Compute sigma_c from the weights; each weight must be < 1 so that sigma_c > 0
      for (i in 1:n) {
        sigma_c[i] = 1 - weights[i];
      }

      b_c ~ normal(mu_c, sigma_c);  // Prior on biases; sizes must match (n == N)

      // Likelihood
      for (i in 1:N) {
        p[i] ~ normal(p_m * b_c[i], lambda);  // Model for each input tensor's values
      }
    }
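Two constraints in this model are easy to trip over: `b_c ~ normal(mu_c, sigma_c)` needs `sigma_c` (length n) to match `b_c` (length N), and every `sigma_c[i] = 1 - weights[i]` must be strictly positive because it is a normal scale. A hypothetical pre-flight check in Python, assuming the stan_data dictionary layout shown above, can catch both (plus a mis-sized `p` or a zero `lambda`) before compiling and sampling:

```python
import numpy as np


def check_stan_data(stan_data):
    """Raise ValueError if stan_data violates the model's implicit constraints.

    Hypothetical helper; it mirrors the Stan program above, not pystan itself.
    """
    N, n, P = stan_data["N"], stan_data["n"], stan_data["P"]
    weights = np.asarray(stan_data["weights"], dtype=float)
    p = np.asarray(stan_data["p"], dtype=float)

    if weights.shape != (n,):
        raise ValueError(f"weights has shape {weights.shape}, expected ({n},)")
    # sigma_c has length n but is the scale for b_c, which has length N
    if n != N:
        raise ValueError(f"n ({n}) must equal N ({N}) for b_c ~ normal(mu_c, sigma_c)")
    # sigma_c[i] = 1 - weights[i] must be > 0
    if np.any(weights >= 1):
        raise ValueError("every weight must be < 1 so that 1 - weight > 0")
    if p.shape != (N, P):
        raise ValueError(f"p has shape {p.shape}, expected ({N}, {P})")
    # lambda is the scale of the likelihood; <lower=0> alone still allows 0
    if not stan_data["lambda"] > 0:
        raise ValueError("lambda must be > 0")


check_stan_data({
    "N": 2, "n": 2, "P": 3,
    "p": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
    "lambda": 1, "mu_c": 1, "weights": [0.3, 0.7],
})  # passes silently
```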

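Python code that handles the output of Stan:

    import stan

    posterior = stan.build(model_code, data=stan_data)
    fit = posterior.sample(num_chains=4, num_samples=1000)
    # Printing a pystan 3 Fit shows parameter names, dimensions and draw count;
    # pystan 3 has no fit.summary() method
    print(fit)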

I want to calculate the Bayesian mean, and the output should have the same dimensions as the tensors in tensor_values. How can I fix the code so that it handles the input and output as required?
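One way to get an output with the shape of a single tensor, assuming the quantity of interest is the global mean `p_m` (a length-P vector), is to average its draws and reshape. In pystan 3, `fit["p_m"]` returns the draws with the draws axis last, i.e. shape (P, num_chains * num_samples); cmdstanpy's `fit.stan_variable("p_m")` puts the draws axis first. The helper posterior_mean_as_tensor is hypothetical and is shown on synthetic draws, since running Stan is outside this snippet.

```python
import numpy as np


def posterior_mean_as_tensor(draws, shape, draws_axis=-1):
    """Average MCMC draws of a flattened vector and restore the tensor shape.

    draws      : draws of a length-P Stan vector
    shape      : original tensor shape, e.g. tensor_values[0].shape
    draws_axis : -1 for pystan 3's fit["p_m"] with shape (P, n_draws),
                 0 for cmdstanpy's fit.stan_variable("p_m") with shape (n_draws, P)
    """
    mean_flat = np.asarray(draws, dtype=float).mean(axis=draws_axis)
    # reshape raises if P does not equal the tensor's element count
    return mean_flat.reshape(shape)


# Synthetic stand-in for fit["p_m"]: P = 2 * 3 = 6 elements, 4000 draws
rng = np.random.default_rng(0)
true_mean = np.arange(6.0)
fake_draws = true_mean[:, None] + 0.01 * rng.standard_normal((6, 4000))

result = posterior_mean_as_tensor(fake_draws, (2, 3))
print(result.shape)  # (2, 3)
```

With a real fit this would be `posterior_mean_as_tensor(fit["p_m"], tensor_values[0].shape)`. If the intended output is instead one mean per input tensor, the model's mean for tensor i is `p_m * b_c[i]`, so that would need the draws of both `p_m` and `b_c` rather than `p_m` alone.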

Here is the list of all potential input tensor shapes:

Shape of element: ()
Shape of element: (1,)
Shape of element: (128,)
Shape of element: (128, 128, 3, 3, 3)
Shape of element: (128, 256, 1, 1, 1)
Shape of element: (128, 256, 3, 3, 3)
Shape of element: (128, 64, 3, 3, 3)
Shape of element: (256,)
Shape of element: (256, 128, 3, 3, 3)
Shape of element: (256, 256, 3, 3, 3)
Shape of element: (256, 512, 1, 1, 1)
Shape of element: (256, 512, 3, 3, 3)
Shape of element: (32,)
Shape of element: (32, 32, 3, 3, 3)
Shape of element: (32, 4, 3, 3, 3)
Shape of element: (32, 64, 1, 1, 1)
Shape of element: (4,)
Shape of element: (4, 64, 3, 3, 3)
Shape of element: (512,)
Shape of element: (512, 256, 3, 3, 3)
Shape of element: (512, 512, 3, 3, 3)
Shape of element: (64,)
Shape of element: (64, 128, 1, 1, 1)
Shape of element: (64, 128, 3, 3, 3)
Shape of element: (64, 32, 3, 3, 3)
Shape of element: (64, 64, 3, 3, 3)
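Because `P` is the flattened element count, it is the product of the dimensions, not `len()` of the first axis. A quick check over a few of the shapes listed above (the subset is chosen only for illustration) shows how far the two differ and how large `P` gets:

```python
import math

# A few of the shapes listed above (chosen for illustration)
shapes = [(), (1,), (128,), (32, 64, 1, 1, 1), (512, 512, 3, 3, 3)]

# Flattened element count, i.e. the value P should take for that shape;
# math.prod(()) == 1, which matches a 0-d array holding a single value
sizes = {shape: math.prod(shape) for shape in shapes}

for shape, size in sizes.items():
    # len() only sees the first axis (and raises TypeError on a 0-d array)
    first_axis = shape[0] if shape else "TypeError"
    print(f"{str(shape):20} P = {size:>9,}  len() = {first_axis}")
```

For the largest listed shape, (512, 512, 3, 3, 3), P is 7,077,888 while len() would report 512, so `p_m` in the model above becomes a parameter vector with about seven million elements.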

Hi @Irfan_khan, and thanks for moving the discussion here from the cmdstanpy GitHub discussion “dynamic and variable tensor shape handling in stan” (stan-dev/cmdstanpy Discussion #748).

This post says “pystan” but the issue was opened for cmdstanpy. I’d suggest using cmdstanpy for applied problems if you only need the draws and a summary.

Stan requires every parameter to have the same shape and size on every iteration. You supplied some relatively simple Stan code and then said there’s a list of potential input tensor shapes. How do those shapes relate to either data or parameter values in your Stan code?

I’m afraid the answer is probably going to be that you can’t do it, at least not without a heroic amount of low-level array packing and unpacking. Stan just isn’t designed to let you write code like numpy that’s polymorphic over different array shapes.
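For reference, the “array packing and unpacking” mentioned above usually means concatenating differently sized pieces into one flat vector plus arrays of start positions and lengths, then slicing each piece back out inside the model with Stan's `segment(v, start, length)`. A minimal Python-side sketch of that bookkeeping, with hypothetical helper names and 1-based starts to match Stan indexing:

```python
import numpy as np


def pack_ragged(tensors):
    """Concatenate tensors of different shapes into one flat vector.

    Returns the flat vector, 1-based start indices and lengths (suitable for
    Stan's segment(flat, start[k], length[k])), and the shapes for unpacking.
    """
    arrays = [np.asarray(t, dtype=float) for t in tensors]
    lengths = np.array([a.size for a in arrays], dtype=int)
    starts = np.concatenate(([1], 1 + np.cumsum(lengths)[:-1])).astype(int)
    flat = np.concatenate([a.reshape(-1) for a in arrays])
    shapes = [a.shape for a in arrays]
    return flat, starts, lengths, shapes


def unpack_ragged(flat, starts, lengths, shapes):
    """Inverse of pack_ragged: slice (converting to 0-based) and reshape."""
    return [
        flat[s - 1 : s - 1 + n].reshape(shape)
        for s, n, shape in zip(starts, lengths, shapes)
    ]


pieces = [np.array(5.0), np.arange(4.0), np.ones((2, 3))]
flat, starts, lengths, shapes = pack_ragged(pieces)
print(starts.tolist(), lengths.tolist())  # [1, 2, 6] [1, 4, 6]
restored = unpack_ragged(flat, starts, lengths, shapes)
print([r.shape for r in restored])  # [(), (4,), (2, 3)]
```

Any parameter whose size depends on the piece would need the same offset bookkeeping inside the Stan program, which is where most of the effort lies.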