Another long-winded post, sorry :) I really just need to start coding this up and see what we are missing. Hopefully tomorrow's Twitch session will give me a jump start on this.
I did give a bad explanation there, sorry. In the case of A=add(B,C), A=1, B=2 and C=3. But that is irrelevant now, please ignore.
Ah, I should have asked before diving into the tree structure. I just wanted to get something out there, and the idea sounded OK in my head.
I will try again with lists, using the same example. Since I am not really sure how you handle temporary variables, I am going to simplify this a bit by assigning names to everything.
matrix[N1, N1] K_1 = cov_exp_quad(x1, alpha, rho);
real sq = square(sigma);
vector[N1] r_v_1 = rep_vector(sq, N1);
matrix[N1, N1] K_2 = diag_matrix(r_v_1);
matrix[N1, N1] K = K_1 + K_2;
matrix[N1, N1] L_K = cholesky_decompose(K);
vector[N1] L_K_div_y1 = mdivide_left_tri_low(L_K, y1);
row_vector[N1] L_K_div_y1_T = L_K_div_y1';
row_vector[N1] K_div_y1 = mdivide_right_tri_low(L_K_div_y1_T, L_K);
vector[N1] K_div_y1_T = K_div_y1';
All indexes use the same convention: 0 denotes the result and function arguments start at 1. I use the term arguments for both in the following text. So for A = add(B, C): 0 = A, 1 = B, 2 = C.
We then have a predefined map of all OpenCL-supported functions with the following information:
- cl_args: list of indexes of arguments that are matrix_cl. If zero is not present, this is what I previously called a »selfish« function.
- req_cl_args: list of lists of indexes of arguments that must already be on the device in order to »approve« this function. A function matches if any one of the lists matches.
- fast: bool, denotes that the function is considered to be »fast« (I really need a better name here). If a function is considered fast, its arguments (and result) are put in the »maybe« list when the conditions for moving them to the OpenCL device are not met.
We will also need to somehow denote whether a function supports only doubles or also vars. This is omitted here for now, as the only matrix_cl overloads that support vars are the GLMs.
Here are a few examples.
normal_id_glm: cl_args = [1,2], req_cl_args = [[1,2]], fast
multiply: cl_args = [0,1,2], req_cl_args = [[0], [1], [2]], fast
cov_exp_quad: cl_args = [0,1] (the magnitude and length-scale arguments are not matrix_cl), req_cl_args = [[0], [1]], fast
add: cl_args = [0,1,2], req_cl_args = [[0,1], [0,2], [1,2]]
rep_matrix: cl_args = [0,1], req_cl_args = [[0]]
diag_matrix: cl_args = [0,1], req_cl_args = [[0]]
mdivide_left_tri_low: cl_args = [0,1,2], req_cl_args = [[0], [1], [2]], fast
mdivide_right_tri_low: same as mdivide_left_tri_low
cholesky_decompose: cl_args = [0,1], req_cl_args = [[0], [1]], fast
transpose: cl_args = [0,1], req_cl_args = [[0], [1]]
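A minimal sketch of this map as a Python dictionary, together with the matching rule: a function matches when every index in at least one of its req_cl_args lists refers to an argument already in var_CL or maybe_CL. ClInfo, CL_FUNCTIONS and matches are hypothetical names, and normal_id_glm's requirement is read as the single list [[1,2]] (both arguments needed).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ClInfo:
    cl_args: tuple          # indexes (0 = result) that are matrix_cl
    req_cl_args: tuple      # alternatives: the function matches if ANY list is fully known
    fast: bool = False      # if True, unmatched arguments go to the »maybe« list


# Hypothetical encoding of the examples above.
CL_FUNCTIONS = {
    "normal_id_glm":         ClInfo((1, 2), ((1, 2),), fast=True),
    "multiply":              ClInfo((0, 1, 2), ((0,), (1,), (2,)), fast=True),
    "cov_exp_quad":          ClInfo((0, 1), ((0,), (1,)), fast=True),
    "add":                   ClInfo((0, 1, 2), ((0, 1), (0, 2), (1, 2))),
    "rep_matrix":            ClInfo((0, 1), ((0,),)),
    "diag_matrix":           ClInfo((0, 1), ((0,),)),
    "mdivide_left_tri_low":  ClInfo((0, 1, 2), ((0,), (1,), (2,)), fast=True),
    "mdivide_right_tri_low": ClInfo((0, 1, 2), ((0,), (1,), (2,)), fast=True),
    "cholesky_decompose":    ClInfo((0, 1), ((0,), (1,)), fast=True),
    "transpose":             ClInfo((0, 1), ((0,), (1,))),
}


def matches(info, args, known):
    """True if some list in req_cl_args has all its arguments in `known`
    (var_CL and maybe_CL are treated the same here)."""
    present = {i for i in info.cl_args if args[i] in known}
    return any(set(req) <= present for req in info.req_cl_args)


add_info = CL_FUNCTIONS["add"]
print(matches(add_info, ["K", "K_1", "K_2"], {"K_1"}))       # False: only [1] present
print(matches(add_info, ["K", "K_1", "K_2"], {"K", "K_1"}))  # True: [0, 1] is satisfied
```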
Then we have two global lists: arguments that were moved to the device, and arguments that are good candidates to be moved. Both are empty at the start. Let's name them var_CL and maybe_CL.
We start by adding all data to var_CL. In this example there is no data, so we continue.
Then we start with the passes over the functions:
- check which arguments are in var_CL or maybe_CL (treat them both the same)
- if the conditions are met (a match in req_cl_args): add all cl_args of the function to var_CL and remove them from maybe_CL
- if the conditions are not met and the function is »fast«, put the arguments that are not already in var_CL or maybe_CL into maybe_CL; otherwise just go to the next function
Repeat this until var_CL stops growing. Clear maybe_CL after each pass.
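The passes can be sketched in Python roughly as follows, assuming the statement encoding and map entries used in this post. Two assumptions are mine: rep_vector gets the same entry as rep_matrix (the walkthrough treats it that way when it moves sq), and maybe_CL is cleared at the start of each pass, which is equivalent to clearing it after each pass. The function names are hypothetical; it returns the final var_CL plus the state after every pass so it can be compared with the walkthrough.

```python
# Hypothetical encoding: (function, [result, arg1, ...]).
PROGRAM = [
    ("cov_exp_quad",          ["K_1", "x1", "alpha", "rho"]),
    ("square",                ["sq", "sigma"]),
    ("rep_vector",            ["r_v_1", "sq", "N1"]),
    ("diag_matrix",           ["K_2", "r_v_1"]),
    ("add",                   ["K", "K_1", "K_2"]),
    ("cholesky_decompose",    ["L_K", "K"]),
    ("mdivide_left_tri_low",  ["L_K_div_y1", "L_K", "y1"]),
    ("transpose",             ["L_K_div_y1_T", "L_K_div_y1"]),
    ("mdivide_right_tri_low", ["K_div_y1", "L_K_div_y1_T", "L_K"]),
    ("transpose",             ["K_div_y1_T", "K_div_y1"]),
]

# (cl_args, req_cl_args, fast). ASSUMPTION: rep_vector behaves like rep_matrix.
CL_FUNCTIONS = {
    "cov_exp_quad":          ((0, 1), [(0,), (1,)], True),
    "rep_vector":            ((0, 1), [(0,)], False),
    "diag_matrix":           ((0, 1), [(0,)], False),
    "add":                   ((0, 1, 2), [(0, 1), (0, 2), (1, 2)], False),
    "cholesky_decompose":    ((0, 1), [(0,), (1,)], True),
    "mdivide_left_tri_low":  ((0, 1, 2), [(0,), (1,), (2,)], True),
    "mdivide_right_tri_low": ((0, 1, 2), [(0,), (1,), (2,)], True),
    "transpose":             ((0, 1), [(0,), (1,)], False),
}


def run_pass(program, var_cl, maybe_cl):
    """One pass over all statements, updating var_cl and maybe_cl in place."""
    for fn, args in program:
        if fn not in CL_FUNCTIONS:
            continue  # e.g. square: no OpenCL overload, go to the next function
        cl_args, req_cl_args, fast = CL_FUNCTIONS[fn]
        known = var_cl | maybe_cl  # both lists are treated the same
        present = {i for i in cl_args if args[i] in known}
        if any(set(req) <= present for req in req_cl_args):
            # Conditions met: every matrix_cl argument moves to the device.
            for i in cl_args:
                var_cl.add(args[i])
                maybe_cl.discard(args[i])
        elif fast:
            # Not met, but fast: remember the arguments as candidates.
            maybe_cl.update(args[i] for i in cl_args if args[i] not in var_cl)


def place_on_device(program, data=()):
    """Repeat passes until var_CL stops growing; return it and per-pass states."""
    var_cl = set(data)  # all data starts on the device
    history = []
    while True:
        size_before = len(var_cl)
        maybe_cl = set()  # maybe_CL is cleared for every pass
        run_pass(program, var_cl, maybe_cl)
        history.append((set(var_cl), set(maybe_cl)))
        if len(var_cl) == size_before:
            return var_cl, history


var_cl, history = place_on_device(PROGRAM)
for n, (on_device, maybe) in enumerate(history, start=1):
    print(n, sorted(on_device), sorted(maybe))
```

Since var_CL only ever grows and there are finitely many variables, the loop always terminates; for this example the sixth pass adds nothing and the loop stops there.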
First pass:
- K_1 = cov_exp_quad(x1, alpha, rho)
  The list of arguments on the device is empty, so there is no match. The function is »fast«, so we add both matrix_cl arguments (0 and 1) to maybe_CL.
  New state: var_CL = []; maybe_CL = [x1, K_1]
- real sq = square(sigma);
  No match in the map for square, so continue.
- vector[N1] r_v_1 = rep_vector(sq, N1);
  No matching arguments on the device, so continue.
- matrix[N1, N1] K_2 = diag_matrix(r_v_1);
  Same.
- matrix[N1, N1] K = K_1 + K_2;
  Only K_1 is in »maybe«, which gives [1]. That would need to match one of [[0,1], [0,2], [1,2]], and add is not »fast«, so we move on.
- matrix[N1, N1] L_K = cholesky_decompose(K);
  No match, but the function is »fast«, so we add K and L_K to maybe_CL.
  New state: var_CL = []; maybe_CL = [x1, K_1, K, L_K]
- vector[N1] L_K_div_y1 = mdivide_left_tri_low(L_K, y1);
  Matches, because L_K (index 1) is in maybe_CL and [1] is one of [[0], [1], [2]]. All its matrix_cl arguments move to var_CL, and L_K is removed from maybe_CL.
  New state: var_CL = [L_K, y1, L_K_div_y1]; maybe_CL = [x1, K_1, K]
- row_vector[N1] L_K_div_y1_T = L_K_div_y1';
  Matches, because L_K_div_y1 is on the device.
  New state: var_CL = [L_K, y1, L_K_div_y1, L_K_div_y1_T]; maybe_CL = [x1, K_1, K]
- row_vector[N1] K_div_y1 = mdivide_right_tri_low(L_K_div_y1_T, L_K);
  Matches, because both inputs are on the device.
  New state: var_CL = [L_K, y1, L_K_div_y1, L_K_div_y1_T, K_div_y1]; maybe_CL = [x1, K_1, K]
- vector[N1] K_div_y1_T = K_div_y1';
  Matches.
  New state: var_CL = [L_K, y1, L_K_div_y1, L_K_div_y1_T, K_div_y1, K_div_y1_T]; maybe_CL = [x1, K_1, K]
Second pass:
- Clear maybe_CL.
- K_1 = cov_exp_quad(x1, alpha, rho)
  Again adds K_1 and x1 to maybe_CL.
  New state: var_CL = [L_K, y1, L_K_div_y1, L_K_div_y1_T, K_div_y1, K_div_y1_T]; maybe_CL = [x1, K_1]
- real sq = square(sigma);
  Same as in the first pass.
- vector[N1] r_v_1 = rep_vector(sq, N1);
  Same as in the first pass.
- matrix[N1, N1] K_2 = diag_matrix(r_v_1);
  Same as in the first pass.
- matrix[N1, N1] K = K_1 + K_2;
  Only K_1 is in »maybe«, so again we move on.
- matrix[N1, N1] L_K = cholesky_decompose(K);
  L_K is in var_CL, so K is added to it.
  New state: var_CL = [L_K, y1, L_K_div_y1, L_K_div_y1_T, K_div_y1, K_div_y1_T, K]; maybe_CL = [x1, K_1]
- The rest is already on the device, so it just goes through.
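In the third pass K_1 and K_2 are added (K is now on the device and K_1 is in »maybe«, so add matches [0,1]). In the fourth pass r_v_1 is added (together with x1, via cov_exp_quad), and in the fifth pass sq. After that var_CL stops growing, so we stop.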
After the last pass, we go through the program once more and add to_matrix_cl at the first occurrence of every argument in var_CL. Once we find an occurrence, we remove that argument from the list.
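A minimal sketch of this last step, assuming the same list encoding: it records, for every variable in var_CL, the index of the first statement that mentions it, which is where a to_matrix_cl would go. VAR_CL is hard-coded to what the passes produce for this example, and first_occurrences is a hypothetical helper. The post does not say how a first occurrence as a result (index 0) should be handled, for example by declaring it as matrix_cl instead of copying, so the sketch only reports positions.

```python
# Hypothetical encoding: (function, [result, arg1, ...]).
PROGRAM = [
    ("cov_exp_quad",          ["K_1", "x1", "alpha", "rho"]),
    ("square",                ["sq", "sigma"]),
    ("rep_vector",            ["r_v_1", "sq", "N1"]),
    ("diag_matrix",           ["K_2", "r_v_1"]),
    ("add",                   ["K", "K_1", "K_2"]),
    ("cholesky_decompose",    ["L_K", "K"]),
    ("mdivide_left_tri_low",  ["L_K_div_y1", "L_K", "y1"]),
    ("transpose",             ["L_K_div_y1_T", "L_K_div_y1"]),
    ("mdivide_right_tri_low", ["K_div_y1", "L_K_div_y1_T", "L_K"]),
    ("transpose",             ["K_div_y1_T", "K_div_y1"]),
]

# var_CL as the passes end up with for this example (hard-coded here).
VAR_CL = {"L_K", "y1", "L_K_div_y1", "L_K_div_y1_T", "K_div_y1", "K_div_y1_T",
          "K", "K_1", "K_2", "x1", "r_v_1", "sq"}


def first_occurrences(program, var_cl):
    """Map each variable in var_cl to the index of the first statement using it."""
    remaining = set(var_cl)
    placement = {}
    for idx, (_fn, args) in enumerate(program):
        for name in args:  # index 0 (result) and the arguments alike
            if name in remaining:
                placement[name] = idx
                remaining.discard(name)  # once found, delete it from the list
    return placement


placement = first_occurrences(PROGRAM, VAR_CL)
for name, idx in sorted(placement.items(), key=lambda kv: (kv[1], kv[0])):
    print(f"statement {idx}: first use of {name}")
```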
It's not a perfect solution by any means, but it's a start, maybe :)