PSA: Sufficient-Stats for Gaussian Likelihoods should speed up most Hierarchical models with IID observation subgroups

I previously posted on a speedup for Gaussian likelihoods that is akin to the sufficient-stats trick for binomial likelihoods, but at the time I thought it was a trick that would rarely prove useful. Yesterday I realized that it might be worthwhile for common hierarchical scenarios, so I threw together a benchmark, and indeed my hunch was correct:

Now, this speeds up only the likelihood part of a model, so don’t expect total run times of real-world models to achieve the above speedup factors, but it’s a pretty easy trick to apply & worth adding in most scenarios.
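
For anyone who missed the earlier post, the underlying math is just the usual sufficiency result for iid normals (my summary, not part of the benchmark files): for y_1, ..., y_n iid normal with mean \mu and sd \sigma, with sample mean \bar{y} and sample variance s^2,

\bar{y} \sim \mathrm{N}\!\left(\mu,\ \frac{\sigma}{\sqrt{n}}\right)
\qquad\text{and}\qquad
\frac{(n-1)\,s^2}{\sigma^2} \sim \chi^2_{n-1},

and p(y \mid \mu, \sigma) \propto p(\bar{y} \mid \mu, \sigma)\, p(s^2 \mid \sigma) with a proportionality constant that depends only on the data, so the n-term likelihood for each group collapses to just two terms.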

I’ll paste the Stan model used for the benchmark below, but here are links to the full Stan and R code:
main.r (2.4 KB)
plot.r (2.0 KB)
profiling_hierarchical_std_vs_suf.stan (3.4 KB)
sample_and_save.r (1.3 KB)

data{
	// k: number of nominal levels of observation-groupings
	int<lower=1> k ;
	// n: number of observations made per k
	int<lower=1> n ;
	// y: matrix of k-by-n observations
	array[n,k] real y ;
}
transformed data{
	// to make for a comparison that is biased in favor of the standard/traditional
	// approach, we'll flatten y to a vector for a single call to normal_lupdf.
	vector[k*n] y_vec = to_vector(to_array_1d(y)) ;
	// Vectorization of the standard approach requires use of an index variable
	// that keeps track of which observation-group each observation is associated
	// with. This would be passed in by the user, but we'll create it here by
	// first creating a 2D array akin to y but containing entries consisting of
	// the column index repeated across rows, then flattening this array in the
	// same way we flattened y:
	array[n,k] int k_for_y ;
	for(i_k in 1:k){
		k_for_y[,i_k] = rep_array(i_k,n) ;
	}
	array[k*n] int k_for_y_vec = to_array_1d(k_for_y) ;
	// The new "sufficient" approach requires collapsing the data y to a mean &
	// variance per observation-group. To further bias the comparison in favor
	// of the standard approach, we'll pretend the user is relating these
	// summaries to the model parameters via an indexing array, which is useful
	// in the context of more complex models & missing data. Here the indexing
	// array will simply be of length k and consist of the sequence 1 through k.
	vector[k] obs_mean ;
	vector[k] obs_var ;
	array[k] int one_thru_k ;
	for(i_k in 1:k){
		one_thru_k[i_k] = i_k ;
		obs_mean[i_k] = mean(y[,i_k]) ;
		obs_var[i_k] = variance(y[,i_k]) ;
	}
	// The new sufficient approach uses pre-computable quantities involving the
	// number of observations contributing to each summary. While this is constant
	// in this demo at n, we'll account for potential slowdowns associated with
	// user data having a unique number of observations per summary by using a
	// vector of values for each of the following quantities, repeating the
	// quantity k times.
	vector[k] sqrt_n = rep_vector(sqrt(n*1.0),k) ;
	vector[k] sd_gamma_shape = rep_vector((n - 1.0) / 2.0,k) ;
}
parameters{
	// plain ol' mean-and-sd model for each observation-group with no pooling
	vector[k] pop_mean ;
	vector<lower=0>[k] pop_sd ;
}
model{
	// priors for the parameters:
	pop_mean ~ std_normal() ;
	pop_sd ~ weibull(2,1) ;
	// so that we don't increment target twice, collect the log-likelihoods in
	// two separate local variables inside profile statements to time each:
	real target1 = 0 ;
	real target2 = 0 ;
	profile("std"){
		target1 += normal_lupdf(y_vec |pop_mean[k_for_y_vec],pop_sd[k_for_y_vec]) ;
	}
	profile("suf"){
		target2 += normal_lupdf( obs_mean | pop_mean[one_thru_k] , pop_sd[one_thru_k]./sqrt_n ) ;
		target2 += gamma_lupdf(obs_var | sd_gamma_shape , sd_gamma_shape ./ pow(pop_sd[one_thru_k],2) ) ;
	}
	// only use the standard-computed log-likelihood to increment the target:
	target += target1 ;
}

// If we wanted to check that target1 & target2 are precisely proportional, we
// could have the following GQ section recomputing them:

// generated quantities{
// 	real target1 = 0 ;
// 	real target2 = 0 ;
// 	target1 += normal_lpdf(y_vec |pop_mean[k_for_y_vec],pop_sd[k_for_y_vec]) ;
// 	target2 += normal_lpdf( obs_mean | pop_mean[one_thru_k] , pop_sd[one_thru_k]./sqrt_n ) ;
// 	target2 += gamma_lpdf(obs_var | sd_gamma_shape , sd_gamma_shape ./ pow(pop_sd[one_thru_k],2) ) ;
// }
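
To make the hierarchical use-case a bit more concrete, here's a minimal sketch (mine, not one of the benchmark files; names like group_mean and n_per_group are purely illustrative) of how the same two-statement likelihood might slot into an ordinary partial-pooling model, assuming the per-group summaries are computed in R and passed in as data:

data{
	// k: number of observation-groups
	int<lower=1> k ;
	// per-group sample means, sample variances (n-1 denominator), & counts
	vector[k] obs_mean ;
	vector[k] obs_var ;
	array[k] int<lower=2> n_per_group ;
}
transformed data{
	vector[k] sqrt_n = sqrt(to_vector(n_per_group)) ;
	vector[k] gamma_shape = (to_vector(n_per_group) - 1) / 2 ;
}
parameters{
	real mu ; // population mean
	real<lower=0> tau ; // between-group sd
	vector[k] group_mean ; // partially-pooled group means
	vector<lower=0>[k] group_sd ; // per-group observation-level sd
}
model{
	mu ~ std_normal() ;
	tau ~ weibull(2,1) ;
	group_mean ~ normal(mu,tau) ;
	group_sd ~ weibull(2,1) ;
	// sufficient-stats likelihood: two statements regardless of how many raw observations there are
	obs_mean ~ normal(group_mean, group_sd ./ sqrt_n) ;
	obs_var ~ gamma(gamma_shape, gamma_shape ./ square(group_sd)) ;
}

The last two statements replace what would otherwise be a single normal() statement over all of the raw observations.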


Nice PSA! Sufficient statistics are a cool way to speed up likelihood computation. Indeed Gaussian linear regression also admits sufficient statistics. You can just precompute \hat{\beta} = (X^T X)^{-1} X^T y instead of doing the computation for each individual data point. Useful if you’re doing a hierarchical regression over different groups. See page 355 of BDA3 for the derivation.
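
To spell that out a bit (my addition; see the BDA3 derivation for the full treatment), the regression log-likelihood can be written entirely in terms of the precomputed quantities X^T X, X^T y, and y^T y:

\log p(y \mid \beta, \sigma) = -\frac{n}{2}\log\!\left(2\pi\sigma^2\right) - \frac{1}{2\sigma^2}\left(y^\top y - 2\,\beta^\top X^\top y + \beta^\top X^\top X\,\beta\right),

so once those quantities are stored, each log-density evaluation costs O(p^2) in the number of predictors rather than O(np) in the number of observations.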


Minor update: I worked out the appropriate jacobian correction necessary to use the chi-square rather than gamma distribution for the variance, which achieves a further 1.2x-1.3x speedup:

And here’s the code for the chi-square likelihood:

	target += (
		normal_lupdf(
			obs_mean
			| pop_mean[one_thru_k]
			, pop_sd[one_thru_k]./sqrt_n
		)
		+ (
			(
				chi_square_lupdf(
					obs_var ./ pow(pop_sd[one_thru_k],2)
					| 1
				)
				- 2*sum(log(pop_sd[one_thru_k]))
			)
			* k_minus_one
		)
	)  ;


Quick update: the version using chi-square is not accurate, but the gamma remains accurate. Sorry for any misleading results in the interim. For posterity, here’s how I am now checking these things as I work to add sufficient-stats representations for a number of standard distributions:

First, I have a Stan model where I express the standard likelihood in the model block, plus candidate replacements in generated quantities. Here's one for the Gaussian with a bunch of attempts at different “sufficient” representations (gamma2_lp and chisq_lp being the ones above):

data {
	int<lower=1> n_obs ;
	vector[n_obs] obs ;
}

transformed data{
	// frequently-used quantities involving n_obs
	real sqrt_n_obs = sqrt(n_obs) ;
	real n_obs_div_2 = n_obs/2.0  ;
	real n_obs_minus_one = (n_obs - 1)  ;
	real n_obs_minus_one_div_two = n_obs_minus_one/2.0  ;

	real obs_mean = mean(obs) ;
	real obs_sum = sum(obs) ;
	real obs_sum_squared_deviations = sum(square(obs-obs_mean)) ;
	real obs_mean_squared_deviations = obs_sum_squared_deviations / n_obs ; // note denominator!
	real obs_variance = obs_sum_squared_deviations / n_obs_minus_one; // note denominator!
}

parameters {
	real mu ;
	real<lower=0> sigma ;
}

transformed parameters{
	real likelihood_lp = normal_lpdf(obs|mu,sigma) ;
}

model {
	mu ~ std_normal() ;
	sigma ~ weibull(2,1) ;
	target += likelihood_lp ;
}

generated quantities{
	// computing lp's for sufficient-statistics alternatives to the standard approach used in the model block
	//		If an alternative expresses information equivalent to the standard approach, then the value
	//		of its lp should be precisely proportional to the standard lp, the latter being accessed as the
	//		`lp__` quantity in the posterior. A check for precise proportionality is achieved by computing
	//		the correlation across draws between a candidate alternative's lp and lp__; when precisely
	//		proportional, this correlation should be nearly 1, with any deviation from 1 being minimal and
	//		due to numeric representation error.
	//		Note: we'll compute quantities with a `_` suffix first, then combine them to derive the lp values
	//		that are to be used to test precise proportionality to lp__


	// computing the term for the mean as it is used by all sufficient-stats alternatives
	real obs_mean_lp_ = normal_lpdf( obs_mean | mu, sigma / sqrt_n_obs ) ;

	// computing the variance term for the gamma representation
	real gamma_lp_nojacobian = (
		gamma_lpdf( obs_sum_squared_deviations / square(sigma) | n_obs_minus_one_div_two , 2 )
	) ;

	// computing the variance term for the inverse-gamma representation
	real invgamma_lp_nojacobian = (
		inv_gamma_lpdf( square(sigma) / obs_sum_squared_deviations | n_obs_minus_one_div_two, .5)
	) ;

	// computing the variance term for the chi-square representation
	real chisq_lp_nojacobian = (
		// chi_square_lpdf( n_obs * obs_sum_squared_deviations / square(sigma) | n_obs )
		chi_square_lpdf( obs_variance / square(sigma) | 1 )
	) ;

	// computing the variance term for the inverse-chi-square representation
	real invchisq_lp_nojacobian = (
		inv_chi_square_lpdf( square(sigma) | n_obs)
	) ;

	real invgamma2_lp_	= inv_gamma_lpdf( obs_sum_squared_deviations | n_obs_div_2, n_obs_div_2 * square(sigma)) ;
	real invgamma3_lp_	= inv_gamma_lpdf( obs_variance | n_obs_minus_one_div_two , n_obs_minus_one_div_two * square(sigma) ) ;
	real gamma2_lp_	= gamma_lpdf( obs_variance | n_obs_minus_one_div_two , n_obs_minus_one_div_two / square(sigma) ) ;

	// When a given lp above passes as the first argument to the lpdf a quantity that involves a
	// non-reducing, non-linear transform of a parameter, the Jacobian determinant of that transform
	// must be expressed here and added so that precise proportionality to the standard lp__ can
	// be achieved
	// real gamma_lp_jacobian = log(2 * obs_sum_squared_deviations / pow(sigma,3)) ;
	real gamma_lp_jacobian = -log(pow(sigma,3)) ; // excludes constants: log(2 * obs_sum_squared_deviations / pow(sigma,3)) = log( 2 * obs_sum_squared_deviations) -log(pow(sigma,3)) ~~ -log(pow(sigma,3))

	// real invgamma_lp_jacobian = log(2 * sigma / obs_sum_squared_deviations) ;
	real invgamma_lp_jacobian = log(sigma) ; // excludes constants; log(2 * sigma / obs_sum_squared_deviations) = log(sigma) + log(2 / obs_sum_squared_deviations) ~~ log(sigma)

	real chisq_lp_jacobian = -log(pow(sigma,3)) ; // excludes constants; log(2 * n_obs * obs_sum_squared_deviations / pow(sigma,3)) = log(2 * n_obs * obs_sum_squared_deviations) - log(pow(sigma,3))

	real invchisq_lp_jacobian = log(sigma); // excludes constants; log(2*sigma)  = log(2) + log(sigma) ~~ log(sigma)


	// adding the mean term to the variance terms for each alternative, resulting in the lp quantities
	// to be tested for precise proportionality to lp__
	real gamma_lp = obs_mean_lp_ + gamma_lp_nojacobian + gamma_lp_jacobian ;
	real invgamma_lp = obs_mean_lp_ + invgamma_lp_nojacobian + invgamma_lp_jacobian ;
	real chisq_lp = obs_mean_lp_ + chisq_lp_nojacobian + chisq_lp_jacobian ;
	real invchisq_lp = obs_mean_lp_ + invchisq_lp_nojacobian + invchisq_lp_jacobian ;
	real gamma2_lp = obs_mean_lp_ + gamma2_lp_;
	real invgamma2_lp = obs_mean_lp_ + invgamma2_lp_;
	real invgamma3_lp = obs_mean_lp_ + invgamma3_lp_;
}
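
As a gloss on why gamma2_lp is the representation that can come out exactly proportional (my note, consistent with the comments above): the sampling distribution of the sample variance is

s^2 \sim \mathrm{Gamma}\!\left(\frac{n-1}{2},\ \frac{n-1}{2\sigma^2}\right)
\quad\Longleftrightarrow\quad
\frac{(n-1)\,s^2}{\sigma^2} \sim \chi^2_{n-1},

and in gamma2_lp_ the first argument to the lpdf is the data statistic obs_variance itself, with sigma entering only through the rate parameter, so no Jacobian term arises; the alternatives whose first argument mixes the data statistic with sigma do require one.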

Then I sample at a minimal n_obs and at a large n_obs and check correlations with likelihood_lp:


library(tidyverse)
chains = parallel::detectCores()/2

mod = cmdstanr::cmdstan_model(
	'stan/tmp2.stan'
	# compile options could be passed here via cpp_options/stanc_options if defined
	# , force = T
)

fit2e0 = mod$sample(
	data = lst(
		n_obs = 2e0
		, obs = rnorm(n_obs)
	)
	# , seed = 1
	, chains = chains*10
	, parallel_chains = chains
	, refresh = 0
	, diagnostics = NULL
	, show_messages = FALSE
	, show_exceptions = FALSE
	, sig_figs = 18
)

fit2e2 = mod$sample(
	data = lst(
		n_obs = 2e2
		, obs = rnorm(n_obs)
	)
	# , seed = 1
	, chains = chains*10
	, parallel_chains = chains
	, refresh = 0
	, diagnostics = NULL
	, show_messages = FALSE
	, show_exceptions = FALSE
	, sig_figs = 18
)

get_correlation_col = function(fit,fit_name){
	(
		fit$draws(format = 'draws_df')
		%>% as_tibble()
		%>% select('likelihood_lp',ends_with('_lp'))
		%>% as.matrix()
		%>% cor()
		%>% as_tibble()
		%>% slice(1)
		%>% select(-likelihood_lp)
		%>% pivot_longer(everything())
		%>% mutate(fit=fit_name)
	)
}

(
	bind_rows(
		get_correlation_col(fit2e0,'fit2e0')
		, get_correlation_col(fit2e2,'fit2e2')
	)
	%>% pivot_wider(names_from=fit)
	%>% arrange(desc(fit2e0),desc(fit2e2))
) %>% print()

In this case, yielding:

# A tibble: 7 × 3
  name         fit2e0  fit2e2
  <chr>         <dbl>   <dbl>
1 gamma2_lp     1      1     
2 invgamma_lp   0.948  0.999 
3 chisq_lp      0.948  0.717 
4 invchisq_lp   0.901  0.157 
5 invgamma3_lp  0.880  0.992 
6 invgamma2_lp  0.771 -0.0570
7 gamma_lp      0.545 -0.0237