Hello to all,
I am new to Stan and am currently trying to implement, in R, the standard LDA example from the user's guide, which can be found here:
https://mc-stan.org/docs/2_18/stan-users-guide/latent-dirichlet-allocation.html
Although the data is only 50x643, sampling takes a lot of time (>1 hour for 2000 iterations) compared with other implementations in R (e.g. the topicmodels package takes <1 min for 10000 Gibbs samples). Also, I always get a lot of divergent transitions. I tried setting adapt_delta = 0.99, but it did not help.
So I guess that I am making some basic mistake here. I would be very thankful if anybody could help me.
Here’s some R-code that has the described problem:
library(tm)
library(rstan)
library(slam)
library(shinystan)
library(topicmodels)
# Set the mc.cores option to the number of cores detected on this machine.
options(mc.cores = parallel::detectCores())

# Number of topics to fit (passed to Stan as K).
N_TOPICS <- 2

# Load the Associated Press document-term matrix shipped with topicmodels.
data("AssociatedPress", package = "topicmodels")

# Keep the first 50 documents, then drop terms whose sparsity exceeds 0.95.
dtm <- removeSparseTerms(AssociatedPress[1:50, ], 0.95)

# Show the dimensions (documents x terms) of the reduced matrix.
dim(dtm)
# Stan program for latent Dirichlet allocation (LDA) with the topic
# assignments marginalised out.
#
# Data expected by the program (see the `data` block below):
#   K, V, M, N  - number of topics, vocabulary size, documents, word instances
#   w[n], doc[n] - word ID and document ID of the n-th word instance
#   alpha, beta  - Dirichlet prior parameters for theta and phi
#
# The log of theta and phi is computed once per document / per topic and
# reused inside the per-word loop, instead of calling log() on the same
# parameters again for every word instance. The log density is unchanged;
# only the repeated work is hoisted out of the loop over N.
model.code <- "
data {
  int<lower=2> K;               // num topics
  int<lower=2> V;               // num words
  int<lower=1> M;               // num docs
  int<lower=1> N;               // total word instances
  int<lower=1,upper=V> w[N];    // word n
  int<lower=1,upper=M> doc[N];  // doc ID for word n
  vector<lower=0>[K] alpha;     // topic prior
  vector<lower=0>[V] beta;      // word prior
}
parameters {
  simplex[K] theta[M];          // topic dist for doc m
  simplex[V] phi[K];            // word dist for topic k
}
model {
  vector[K] log_theta[M];       // log(theta), computed once per doc
  vector[V] log_phi[K];         // log(phi), computed once per topic
  for (m in 1:M) {
    theta[m] ~ dirichlet(alpha);  // prior
    log_theta[m] = log(theta[m]);
  }
  for (k in 1:K) {
    phi[k] ~ dirichlet(beta);     // prior
    log_phi[k] = log(phi[k]);
  }
  for (n in 1:N) {
    real gamma[K];
    for (k in 1:K)
      gamma[k] = log_theta[doc[n], k] + log_phi[k, w[n]];
    target += log_sum_exp(gamma);  // likelihood
  }
}
"
# Named list handed to Stan; element names match the model's data block.
# The triplet fields of dtm (i = document, j = term, v = count) are expanded
# so that each counted occurrence becomes one entry of w / doc.
data <- list(
  K = N_TOPICS,
  V = ncol(dtm),
  M = nrow(dtm),
  N = sum(dtm$v),
  w = rep(dtm$j, times = dtm$v),
  doc = rep(dtm$i, times = dtm$v),
  # Prior values follow Griffiths and Steyvers (2004).
  alpha = rep(50 / N_TOPICS, times = N_TOPICS),
  beta = rep(0.1, times = ncol(dtm))
)
# Compile the Stan program held in model.code.
stan.model <- stan_model(model_code = model.code)

# Sample from the compiled model: 5 chains, iter = 1000, fixed seed.
stan.model.fit <- sampling(
  stan.model,
  data = data,
  chains = 5,
  iter = 1000,
  # control = list(adapt_delta = 0.99), # this does not help
  seed = 2019
)