PSA: where possible, use columns_dot_product rather than rows_dot_product

Hey folks, just thought I’d share a neat performance optimization that I haven’t seen mentioned directly anywhere. As the manual notes, matrices are stored in column-major order, so retrieving whole columns is faster than retrieving whole rows. I realized this weekend that regression models using rows_dot_product() in the standard way, especially in hierarchical contexts, might get a free performance improvement by switching to columns_dot_product(), and indeed this seems to be the case!

To prove this to myself, I used the awesome new(ish) profiling features with a model that computes dot products between a data variable and a parameter variable (which I’d guess is the most common use case):

// Benchmark model: times rows_dot_product() against columns_dot_product()
// inside profile() blocks. The same random data are used in both layouts
// (X and its transpose tX), each paired with its own parameter matrix.
// Z and tZ never enter the target; they exist only so the two calls can be timed.
data{
	int n ; // number of dot-products to compute
	int m ; // length of each dot-product (columns of X, rows of tX)
	int<lower=0,upper=1> order ; // 0 = profile rows first, 1 = columns first; alternated across runs so order effects (e.g. caching) don't bias the comparison
}
transformed data{
	matrix[n,m] X ; // to be filled with random "data"
	// fill column by column (inner loop over rows), matching column-major storage
	for(i_m in 1:m){
		for(i_n in 1:n){
			X[i_n,i_m] = std_normal_rng() ;
		}
	}
	matrix[m,n] tX = transpose(X) ; // same values, laid out so each dot-product input is a column
}
parameters{
	matrix[n,m] Y ; // parameter paired with X in rows_dot_product()
	matrix[m,n] tY ; // separate parameter paired with tX in columns_dot_product()
}
model{
	// "priors" necessary so sampler doesn't balk
	to_vector(Y) ~ std_normal() ; 
	to_vector(tY) ~ std_normal() ;
	//pre-allocate output variables
	vector[n] Z ; // one result per row of X / Y
	row_vector[n] tZ ; // one result per column of tX / tY
	if(order==0){ // do rows first
		profile("rows"){
			Z = rows_dot_product(X,Y);
		}
		profile("cols"){
			tZ = columns_dot_product(tX,tY);
		}
	}else{ // do columns first
		profile("cols"){
			tZ = columns_dot_product(tX,tY);
		}
		profile("rows"){
			Z = rows_dot_product(X,Y);
		}
	}
}

And here is some R code that does single-iteration, no-adaptation “sampling” runs, so that we’re in a reasonably close-to-real-world scenario:

library(tidyverse)
# Pipe-through helper: creates a global progress bar `.pp` whose total is the
# number of elements in x, then hands x back unchanged. It is meant to sit in
# a pipe just before a map() whose function calls `.pp$tick()` per element.
initialize_purrr_progress <- function(x){
	bar <- progress::progress_bar$new(
		format = "[:bar] :percent eta: :eta"
		, total = length(x)
		, clear = FALSE
		, width = 60
	)
	.pp <<- bar
	x
}
# Pipe-through companion to initialize_purrr_progress(): prints the time
# elapsed since the global progress bar `.pp` started, then returns the piped
# input unchanged so it can sit at the end of a pipe.
#
# x: the result of the upstream pipe (e.g. the data frame from map_dfr()).
# Returns x.
#
# NOTE(review): the start time is read from the progress bar's private R6
# state, which is not part of the progress package's public API.
finalize_purrr_progress <- function(x){
	# Evaluate the piped input before reading `.pp`: with lazily evaluated
	# pipes the upstream steps (which create `.pp` and do the timed work)
	# would otherwise not have run yet.
	force(x)
	dt <- difftime(Sys.time(),.pp$.__enclos_env__$private$start)
	# sep = '' so cat()'s default ' ' separator doesn't double the literal
	# spaces; the trailing newline terminates the message.
	cat('Operations took ',round(dt),' ',attr(dt,'units'),'\n',sep='')
	x
}
# compile the model
mod <- cmdstanr::cmdstan_model('stan/profiling_dots.stan')
# Simulate and return timings. Each design point (n, m, iter) is one
# single-draw, no-adaptation run with 2^n dot-products of length 2^m. The
# parity of iter sets the Stan `order` flag, so the "rows" and "cols"
# profiles alternate which one runs first.
# tidyr::expand_grid() replaces the deprecated purrr::cross() (purrr >= 1.0.0).
# Listing the columns slowest-varying first reproduces cross()'s ordering
# (n varies fastest). purrr::transpose() then turns the grid's rows into the
# list of named lists that the mapped function expects.
all_times <- (
	tidyr::expand_grid(
		iter = 1:1e2
		, m = 1:10
		, n = 1:10
	)
	%>% purrr::transpose()
	%>% initialize_purrr_progress()
	%>% map_dfr(
		.f = function(x){
			this_times <- NULL
			# capture.output() keeps cmdstanr's console output out of the way
			capture.output({
				this_times <- tryCatch(
					{
						fit <- mod$sample(
							data=list(
								n = 2^x$n
								, m = 2^x$m
								, order = x$iter%%2
							)
							, chains = 1
							, iter_warmup = 0
							, iter_sampling = 1
							, adapt_engaged = FALSE
							, refresh = 0
							, show_messages = FALSE
						)
						(
							fit$profiles()[[1]]
							%>% as_tibble()
							%>% select(name,total_time)
							%>% rename(value=total_time)
							%>% mutate(n=2^x$n,m=2^x$m,iter=x$iter)
						)
					}
					# a failed run contributes no rows, but is reported as a
					# warning instead of being dropped silently
					, error = function(e){
						warning(
							sprintf(
								'Run skipped (n = %g, m = %g, iter = %d): %s'
								, 2^x$n, 2^x$m, x$iter, conditionMessage(e)
							)
							, call. = FALSE
						)
						NULL
					}
				)
			})
			.pp$tick()
			this_times
		}
	)
	%>% finalize_purrr_progress()
)

This yields the following plot of the timings (one facet per dot-product length m; lines join the medians, thick bars span the central 50% interval and thin bars the central 80% interval):

and this one, which compares the times via their ratio (cols/rows; values below 1 mean columns_dot_product() was faster):

Now, obviously there’s some noise in these results (I ran it overnight, and my system may have had automated background tasks causing intermittent delays). For small data there isn’t a substantial speedup, but for folks struggling with big models where the dot-product operation is a substantial share of the compute time, switching to columns_dot_product() should yield a decent speedup.

Note, however, that the comparison above assumes the model can be set up to use columns_dot_product() from the outset, so if you make the switch you’ll want to declare your variables in the transposed layout. That is, make sure your model samples tY directly so you can simply compute columns_dot_product(tY,tX). Don’t sample Y and then compute columns_dot_product(transpose(Y),tX), as the transpose operation will likely eat up some, if not all, of the performance gain.

Code for plots:

# Median profiled time per method against n, one facet per dot-product
# length m. Thick bars span the central 50% interval (25th-75th percentiles)
# and thin bars the central 80% interval (10th-90th) across replicate runs.
# Both axes are log-scaled.
(
	all_times
	%>% group_by(n,m,name)
	%>% summarise(
		med = median(value)
		, lo80 = quantile(value,.1)
		, hi80 = quantile(value,.9)
		# a central 50% interval runs from the 25th to the 75th percentile
		, lo50 = quantile(value,.25)
		, hi50 = quantile(value,.75)
		, .groups = 'drop'
	)
	%>% ggplot()
	+ facet_wrap(~m)
	+ geom_line(
		aes(
			x = n
			, y = med
			, color = name
		)
	)
	+ geom_linerange(
		aes(
			x = n
			, ymin = lo50
			, ymax = hi50
			, color = name
		)
		, size = 3
		, alpha = .5
	)
	+ geom_linerange(
		aes(
			x = n
			, ymin = lo80
			, ymax = hi80
			, color = name
		)
		, size = 1
		, alpha = .5
	)
	+ scale_x_log10()
	+ scale_y_log10()
	+ labs(
		y = 'Seconds'
		, x = 'n (number of dot-products)'
		, color = 'method'
	)
	+ theme(
		aspect.ratio = 1
	)
)
# Per-run ratio of "cols" to "rows" profiled time (values below 1 mean
# columns_dot_product() was faster), summarised across replicate runs:
# the line joins medians, thick bars span the central 50% interval
# (25th-75th percentiles) and thin bars the central 80% interval (10th-90th).
(
	all_times
	# one row per (n, m, iter) run, with the two profile timings side by side
	%>% pivot_wider(names_from = name, values_from = value)
	%>% mutate(
		value = cols/rows
	)
	%>% group_by(n,m)
	%>% summarise(
		med = median(value)
		, lo80 = quantile(value,.1)
		, hi80 = quantile(value,.9)
		# a central 50% interval runs from the 25th to the 75th percentile
		, lo50 = quantile(value,.25)
		, hi50 = quantile(value,.75)
		, .groups = 'drop'
	)
	%>% ggplot()
	+ facet_wrap(~m)
	+ geom_line(
		aes(
			x = n
			, y = med
		)
	)
	+ geom_linerange(
		aes(
			x = n
			, ymin = lo50
			, ymax = hi50
		)
		, size = 3
		, alpha = .5
	)
	+ geom_linerange(
		aes(
			x = n
			, ymin = lo80
			, ymax = hi80
		)
		, size = 1
		, alpha = .5
	)
	+ scale_x_log10()
	+ geom_hline(yintercept=1,linetype=3)
	+ labs(
		y = 'Cols time / Rows time'
		, x = 'n (number of dot-products)'
	)
	+ theme(
		aspect.ratio = 1
	)
)

9 Likes

Very nice PSA. This post somehow escaped me. Thanks Mike!

1 Like