CmdStan with GPU much slower than PyStan on CPU

The configuration of my computer is

  • Ubuntu 20.04
  • NVIDIA 3090
  • i7-9700KF
  • CUDA Toolkit 11.5 (upgraded from 10.1, which behaved the same)
  • latest CmdStan and PyStan 3

I run the code on the CPU with PyStan 3, which takes about 2 minutes. Since this data is only 1% of my full data, I want to improve the speed with the GPU. But the code nearly gets stuck with CmdStan on the GPU (it did run some iterations this morning, but now it is essentially stuck: it had finished 100 iterations when I wrote this, after 19.24 minutes). I tried the example code provided by CmdStan and the code here (https://github.com/bstatcomp/stan_gpu_install_docs.git), and both work fine, so the working environment should be right. I tried my best to vectorize my code, so what am I misunderstanding about Stan? I also commented out the prior for sigma, but it seems to make no difference.


significance.stan (4.4 KB)
data.txt (175.7 KB)


Welcome to Discourse! My guess is that the vectorized normal_lpdf calls are very cheap even on the CPU, and that most of your computation is happening inside the for-loops in transformed parameters, which are not amenable to speedup on the GPU. You could use profiling with cmdstan to check.
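For example, something like this (a minimal sketch with placeholder data and names, not your actual model; profile statements require Stan 2.26 or newer):

data {
    int<lower=0> N;
    matrix[N, 2] X;
    vector[N] y;
}
parameters {
    vector[2] beta;
    real<lower=0> sigma;
}
transformed parameters {
    vector[N] mu;
    profile("transformed_parameters") {
        mu = X * beta;
    }
}
model {
    profile("priors") {
        target += normal_lpdf(beta | 0, 1);
        target += normal_lpdf(sigma | 0, 1);
    }
    profile("likelihood") {
        target += normal_lpdf(y | mu, sigma);
    }
}

With CmdStan you can then write the per-section timings to a CSV with something like

./significance sample data file=data.txt output file=output.csv profile_file=profile.csv

and compare how much time is spent in each labelled block.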

There is some overhead for using the GPU, though I don’t have a feeling for whether it should be expected to approach the slowdown you are reporting here. Perhaps @rok_cesnovar or @stevebronder has more intuition or advice.

3 Likes

How about profiling your model first before embarking on the GPU?

Also try the normal GLM function, normal_id_glm, which is heavily optimised for the GPU.
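An untested sketch, assuming I am reading your attached model right; the whole design matrix is data here, so it can be assembled once in transformed data (X_glm is just a name I made up):

transformed data {
    // first N_TREATMENT columns: intercept part;
    // last N_TREATMENT columns: slope part, each row scaled by concentration
    matrix[N, 2 * N_TREATMENT] X_glm = append_col(
        observation_treatment_onehot,
        diag_pre_multiply(concentration, observation_treatment_onehot));
}

model {
    // same mean as observation_treatment_onehot * mu_observation_0
    //             + (observation_treatment_onehot * mu_observation_1) .* concentration
    target += normal_id_glm_lpdf(y | X_glm, 0.0,
        append_row(mu_observation_0, mu_observation_1), sigma_obs);
}

normal_id_glm_lpdf(y | x, alpha, beta, sigma) is normal_lpdf(y | alpha + x * beta, sigma) with analytic gradients and an OpenCL/GPU implementation.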

Another friend is reduce_sum for CPU-based within-chain parallelism.
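A generic untested sketch of that pattern for a normal likelihood (again placeholder data and names, not your model); it needs STAN_THREADS=true in make/local, with the number of threads set via STAN_NUM_THREADS or the num_threads argument depending on the CmdStan version:

functions {
    // log density of one slice of observations; start/end index into the full data
    real partial_normal(array[] real y_slice, int start, int end,
                        vector mu, real sigma) {
        return normal_lpdf(to_vector(y_slice) | mu[start:end], sigma);
    }
}
data {
    int<lower=0> N;
    matrix[N, 2] X;
    vector[N] y;
}
transformed data {
    array[N] real y_arr = to_array_1d(y);  // reduce_sum slices arrays, not vectors
    int grainsize = 1;                     // 1 lets the scheduler pick slice sizes
}
parameters {
    vector[2] beta;
    real<lower=0> sigma;
}
model {
    target += normal_lpdf(beta | 0, 1);
    target += normal_lpdf(sigma | 0, 1);
    target += reduce_sum(partial_normal, y_arr, grainsize, X * beta, sigma);
}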

5 Likes

I would also suggest starting with profiling. Let's first see which parts of the code run the slowest. If it's the transformed parameters, then the current state of the GPU backend will not be of much help, I'm afraid.

If it's the model block, then we should be able to figure something out.

1 Like

Thanks! I’ll try profiling and @rok_cesnovar’s and @wds15’s suggestions. I have already tried getting rid of all the for loops; on the GPU it takes more than one hour to finish, which is far slower than the 2 minutes on the CPU.

The updated model is

data {
    int<lower=0> N_CONDITION;
    int<lower=0> N_TREATMENT;
    int<lower=0> N; 
    
    //int<lower=0, upper=N_TREATMENT> disease_condition_seg[N_CONDITION+1];
    //int<lower=0, upper=N> treatment_seg[N_TREATMENT+1];
    matrix<lower=0, upper=1>[N_TREATMENT, N_CONDITION] treatment_condition_onehot;
    matrix<lower=0, upper=1>[N, N_TREATMENT] observation_treatment_onehot;

    vector<lower=0, upper=3>[N] concentration;
    vector[N] y;
} 

parameters {
    real mu_overarch_0;
    real mu_overarch_1;
    
    vector[N_CONDITION] mu_condition_0;
    vector[N_CONDITION] mu_condition_1;
    
    vector[N_TREATMENT] mu_observation_0;
    vector[N_TREATMENT] mu_observation_1;

    real<lower=0> sigma_overarch_0;
    real<lower=0> sigma_overarch_1;
    real<lower=0> sigma_treatment_0;
    real<lower=0> sigma_treatment_1;
    real<lower=0> sigma_obs;

    // vector<lower=0, upper=2>[N_CONDITION] sigma_overarch_0;
    // vector<lower=0, upper=2>[N_CONDITION] sigma_overarch_1;
    // // vector<lower=1, upper=50>[N_CONDITION] nu_overarch_0;
    // // vector<lower=1, upper=50>[N_CONDITION] nu_overarch_1;
    
    // vector<lower=0, upper=2>[N_TREATMENT] sigma_treatment_0;
    // vector<lower=0, upper=2>[N_TREATMENT] sigma_treatment_1;
    // // vector<lower=1, upper=50>[N_TREATMENT] nu_treatment_0;
    // // vector<lower=1, upper=50>[N_TREATMENT] nu_treatment_1;
    
} 

transformed parameters {

    
    // vector[N_TREATMENT] mu_treat_0;
    // vector[N_TREATMENT] mu_treat_1;
    
    // vector[N] mu_obs_0;
    // vector[N] mu_obs_1;
    // vector[N] mu_obs;
    
    // vector[N] nu_obs;
    // vector[N] sigma_obs;

    
    
    // treatment_condition_onehot with shape (N_TREATMENT, N_CONDITION)
    // mu_treat_0 = treatment_condition_onehot * mu_condition_0;
    // mu_treat_1 = treatment_condition_onehot * mu_condition_1;


    // observation_treatment_onehot with shape (N, N_TREATMENT)
    // mu_obs_0 = observation_treatment_onehot * mu_observation_0;
    // mu_obs_1 = observation_treatment_onehot * mu_observation_1;

    // sigma_obs = observation_treatment_onehot * sigma_treatment_0;
        
    
    // mu_obs = mu_obs_0 + mu_obs_1 .* concentration;
    // mu_obs = observation_treatment_onehot * mu_observation_0 + (observation_treatment_onehot * mu_observation_1) .* concentration;
}

model {

    // mu_overarch_0 ~ normal(0, 1);
    target += normal_lpdf(mu_overarch_0 | 0, 1);
    // mu_overarch_1 ~ normal(0, 1);
    target += normal_lpdf(mu_overarch_1 | 0, 1);

    
    //sigma_overarch_0 ~ inv_gamma(1, 1);
    target += inv_gamma_lpdf(sigma_overarch_0 | 1, 1);
    // sigma_overarch_1 ~ inv_gamma(1, 1);
    target += inv_gamma_lpdf(sigma_overarch_1 | 1, 1);
    // sigma_treatment_0 ~ inv_gamma(1, 1);
    target += inv_gamma_lpdf(sigma_treatment_0 | 1, 1);
    // sigma_treatment_1 ~ inv_gamma(1, 1);
    target += inv_gamma_lpdf(sigma_treatment_1 | 1, 1);

    // beta_0
    // mu_condition_0 ~ normal(mu_overarch_0, sigma_overarch_0);
    target += normal_lpdf(mu_condition_0 | mu_overarch_0, sigma_overarch_0);
    // mu_observation_0 ~ normal(mu_treat_0, sigma_treatment_0);
    // target += normal_lpdf(mu_observation_0 | mu_treat_0, sigma_treatment_0);
    target += normal_lpdf(mu_observation_0 | treatment_condition_onehot * mu_condition_0, sigma_treatment_0);
    
    // beta_1
    // mu_condition_1 ~ normal(mu_overarch_1, sigma_overarch_1);
    target += normal_lpdf(mu_condition_1 | mu_overarch_1, sigma_overarch_1);
    // mu_observation_1 ~ normal(mu_treat_1, sigma_treatment_1);
    // target += normal_lpdf(mu_observation_1 | mu_treat_1, sigma_treatment_1);
    target += normal_lpdf(mu_observation_1 | treatment_condition_onehot * mu_condition_1, sigma_treatment_1);
    
    // beta_0 + beta_1 * concentration
    // y ~ normal(mu_obs, sigma_obs);
    // target += normal_lpdf(y | mu_obs, sigma_obs);
    target += normal_lpdf(y | observation_treatment_onehot * mu_observation_0 + (observation_treatment_onehot * mu_observation_1) .* concentration, sigma_obs);

}

Thus, the transformed parameters block no longer does any work.