Feature proposal: Expose fast logp gradient function in httpstan

httpstan exposes the gradient function of the log-density as Model.log_prob_grad, but this interface comes with a very large call overhead.

If we want to evaluate lots of logp gradients from an external library (such as a different sampler), we could use a faster way to access the gradient function.

I would like to propose an interface that returns a C function pointer for this purpose. A POC implementation of this can be found here.

This provides several new methods for httpstan extension modules:

// Return a pointer to a context data structure that stores temporary data necessary
// for the logp evaluation:
void* new_logp_ctx(data)

// Destroy a context
void free_logp_ctx(void *ctx)

// Return a function pointer that computes logp gradients
void* logp_gradient_function_ctx(void *ctx)

// The signature of the returned function pointer from the previous function is
int logp_gradient(
    size_t ndim,
    const double *unconstrained_parameters,
    double *gradient,  // output variable
    double *logp,  // output variable
    void *ctx
);
// A return value of 0 indicates success. A positive return value indicates an
// invalid argument of some kind (i.e. a divergence). A negative value indicates
// an error where we should stop sampling.
// TODO what about the iparam argument in stan? Is it used?

// Not implemented yet...
// Return a pointer to an error retrieving function
void *retrieve_error_function_ctx(void *ctx)

// C-Function:
// This function should have signature
int retrieve_last_error_ctx(char *error, int maxlength, void *ctx);

// Returns a tuple of (py::array_t<double>, py::array_t<int>)
py::tuple write_array_ctx(
    void *ctx,
    py::array_t<double> unconstrained_parameters,
    bool include_tparams,
    bool include_gqs,
    int seed
);

On the python side, the void * types can be represented as an integer.
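A minimal sketch of what consuming such an integer could look like with ctypes. There is no compiled model here: a ctypes callback implementing a standard normal stands in for the function httpstan would return, and the names LogpGradFn, fn_address, ctx_address and evaluate are made up for illustration, not part of the POC.

```python
import ctypes

# Signature from the proposal:
# int logp_gradient(size_t ndim, const double *x, double *gradient, double *logp, void *ctx)
LogpGradFn = ctypes.CFUNCTYPE(
    ctypes.c_int,
    ctypes.c_size_t,
    ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_double),
    ctypes.POINTER(ctypes.c_double),
    ctypes.c_void_p,
)


# Stand-in for the compiled model (assumption): logp = -0.5 * sum(x**2).
@LogpGradFn
def _standin_logp_gradient(ndim, x, gradient, logp, ctx):
    total = 0.0
    for i in range(ndim):
        gradient[i] = -x[i]
        total -= 0.5 * x[i] * x[i]
    logp[0] = total
    return 0  # 0 means success


# What the extension module would hand back: plain Python integers.
fn_address = ctypes.cast(_standin_logp_gradient, ctypes.c_void_p).value
ctx_address = 0  # hypothetical; the stand-in ignores its context

# Consumer side: turn the integer address back into a callable.
logp_gradient = LogpGradFn(fn_address)


def evaluate(params):
    """Call the function pointer and return (status, logp, gradient)."""
    ndim = len(params)
    x = (ctypes.c_double * ndim)(*params)
    grad = (ctypes.c_double * ndim)()
    logp = ctypes.c_double()
    status = logp_gradient(ndim, x, grad, ctypes.byref(logp), ctx_address)
    return status, logp.value, list(grad)
```

A sampler written in C, Rust or numba would skip the ctypes call entirely and invoke the address directly, which is where the low overhead would come from.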

If I understand the limitations due to the global stack allocator correctly, we should also require that a ctx is only used in the thread where it was created, so that we can initialize the global allocator when we create the ctx.
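One way the calling side could enforce that rule, sketched under the assumption that the context is passed around as an integer address. ThreadBoundCtx is a hypothetical helper, not something httpstan provides.

```python
import threading


class ThreadBoundCtx:
    """Hypothetical wrapper that remembers which thread created the context."""

    def __init__(self, ctx_address):
        self._ctx_address = ctx_address
        self._owner = threading.get_ident()

    def address(self):
        # Refuse to hand out the pointer from any thread but the creating one.
        if threading.get_ident() != self._owner:
            raise RuntimeError("logp context used outside the thread that created it")
        return self._ctx_address
```

With a guard like this, a sampler that runs one chain per thread would create one context inside each worker thread instead of sharing a single one.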


@WardBrian @andrjohns @rok_cesnovar

I know @Bob_Carpenter has been working on repl which lets you instantiate a model and get things like gradients, with the exact purpose being things like allowing other algorithms to be developed in languages like Python.

This is based off some work @syclik and @dmuck did:

This repl approach will likely still have some overhead compared to exposing a function pointer due to the IO involved, but it also allows us to have a much nicer interface which is less tied to the C ABI. Either of them should be faster or easier to use than one which needs the data each time, though.


That looks like a cool project. :-)

I’m not sure I would really consider this easier to integrate into other samplers than a C-API though.

The logp gradient really is quite a simple thing after all: a function that needs to know something about the model (hence the context), needs to return a gradient (i.e. one writable double *), a logp value (i.e. another writable double *), and needs to know the point in parameter space (i.e. one const double *). It also might error out, hence the int return value. Any sampler pretty much has to deal with those things already, and I don’t see why they would change significantly in the future. Even if you involve GPUs I don’t see why you couldn’t still provide this interface; it just might not be the most efficient one anymore (but that would definitely also apply to the other one).
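To illustrate how little a sampler needs from that interface, here is a rough sketch of a leapfrog integrator driven by nothing but such a function. It is pure Python for readability; std_normal_logp_grad is a stand-in model and both function names are made up for this example. The status convention follows the proposal: 0 for success, positive for a recoverable problem, negative for a fatal error.

```python
import math


def std_normal_logp_grad(x):
    """Stand-in model (assumption). Returns (status, logp, gradient)."""
    if any(math.isnan(v) or math.isinf(v) for v in x):
        return 1, float("nan"), [0.0] * len(x)  # positive: recoverable, e.g. divergence
    return 0, -0.5 * sum(v * v for v in x), [-v for v in x]


def leapfrog(logp_grad, x, p, step_size, n_steps):
    """Integrate n_steps leapfrog steps; stop at the first non-zero status.

    Returns (status, x, p, logp).
    """
    status, logp, grad = logp_grad(x)
    if status != 0:
        return status, x, p, logp
    for _ in range(n_steps):
        # Half step in momentum, full step in position.
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, grad)]
        x = [xi + step_size * pi for xi, pi in zip(x, p)]
        status, logp, grad = logp_grad(x)
        if status != 0:
            return status, x, p, logp
        # Second half step in momentum with the new gradient.
        p = [pi + 0.5 * step_size * gi for pi, gi in zip(p, grad)]
    return 0, x, p, logp
```

The integrator never looks inside the model. It only passes a point in, reads the gradient and logp back, and reacts to the status code.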

If we want to use inter process communication to compute the logp gradient in a different process, we have to do a couple of things we don’t even need to think about in the C-API:

  • Start a server process, or possibly several if we want to sample in parallel
  • Think about on which cores those processes might be running. We want the process that samples to run on the same core as the process that computes the logp. Does the operating system figure that out on its own or do we need to specify this?
  • Initialize a pipe between the processes. Do we use loopback? UNIX sockets? inherited file descriptors?
  • Implement some protocol for talking to the server. How do we serialize our data? Do we use streaming or datagram sockets? If we use streaming sockets, we need to figure out where messages end. The protocol must also be able to handle errors. How much impact does serialization cost have on the runtime?
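As a rough illustration of the protocol question alone, a length-prefixed framing over a stream socket might look like the sketch below. The wire format (a signed status, a count, then that many float64 values, all little-endian) is invented for this example and is not what any existing Stan server uses.

```python
import socket
import struct


def send_msg(sock, status, values):
    """Frame: 4-byte signed status, 4-byte count, then count float64 values."""
    payload = struct.pack(f"<iI{len(values)}d", status, len(values), *values)
    sock.sendall(payload)


def recv_exact(sock, n):
    """Read exactly n bytes; a stream socket may deliver them in pieces."""
    buf = bytearray()
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("peer closed the connection mid-message")
        buf.extend(chunk)
    return bytes(buf)


def recv_msg(sock):
    """Read one framed message and return (status, values)."""
    status, count = struct.unpack("<iI", recv_exact(sock, 8))
    values = struct.unpack(f"<{count}d", recv_exact(sock, 8 * count))
    return status, list(values)
```

Even this small version has to decide on byte order, message boundaries and how errors travel back, none of which comes up when calling a function pointer in the same process.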

The fastest way to send messages between two processes without shared memory that I could figure out still takes about 1.5s to send 100_000 messages back and forth (I would actually have guessed that this would be quite a bit more…). This is without any serialization, parallelization or any other of the complications that will come on top of that. My WIP sampler using the C API can run 20 chains of a simple radon model in that time.


Process 1:

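# Start this process first: process 2 connects to the socket bound here.
import os
import socket

recv = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
send = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)

# Remove stale socket files left over from a previous run
for path in ('./tmp1.socket', './tmp2.socket'):
    try:
        os.remove(path)
    except FileNotFoundError:
        pass

recv.bind('./tmp2.socket')
recv.listen(1)
conn, addr = recv.accept()
send.connect('./tmp1.socket')

%%time
# Echo every message back until the peer closes its socket
while True:
    val = conn.recv(1024)
    if not val:
        break
    send.send(val)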

Process 2:

import socket

data = b"a" * (1024)

send = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
recv = socket.socket(socket.AF_UNIX, socket.SOCK_SEQPACKET)
recv.bind('./tmp1.socket')
recv.listen(1)
send.connect('./tmp2.socket')

conn, _ = recv.accept()

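# 100_000 round trips of a 1 KiB message
for i in range(100_000):
    send.send(data)
    result = conn.recv(1024)

# Closing the socket makes recv() in process 1 return b"" and ends its loop
send.close()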

My two cents on including this in httpstan. I’d strongly prefer that things like this get first included in Stan C++ and then exposed in httpstan.


Putting most of that into the stan library sounds fine to me as well. Where in the stan lib would you want this to live?


@mjcarter may be able to comment since he ran into this issue and fixed it locally.

It seems like there is nice progress on getting the logprob, gradients and transformed parameters out of a Stan model and its data. I’m having trouble understanding the exact state of things though: is it still pretty tough to do from an R package? If help is needed on the R side I’m happy to pitch in, only the C++ compiler stuff drives me bonkers…


Not a package, but I wrote an R client for @Bob_Carpenter’s Stan Model Server (mentioned by @WardBrian above). It is fully-featured and more developed than what we have (and will have) in ReddingStan. The R client complements the Python one.

Take a look; and please feel free to leave feedback here or in the Stan Model Server repo itself.


See here: GitHub - mjcarter95/PyBindStan: A Python interface to Stan, based on HTTPStan. @mjcarter has made a little Python interface that allows users to access logprob and gradient information.

@s.maskell it seems like PyBindStan still creates a new model object on each call, looking at the stan_services.cpp file.

Recently @roualdes has been working on BridgeStan (Edward Roualdes / BridgeStan · GitLab), which is very cmdstan-like but allows in-memory calls to log_prob on a model that is only instantiated once. So far in testing it has been very fast.


Amazing! Has anyone tested with Blackjax?


@WardBrian You beat me to it :) Thanks for the mention. I’ll write up a new post about BridgeStan so as to not hijack this thread.


BlackJAX’s doc says:

it integrates really well with PPLs as long as they can provide a (potentially unnormalized) log-probability density function compatible with JAX.

What does “compatible with JAX” mean here?

I don’t know! It’s a project I’ve been following but haven’t tested. It’s worth bringing up in their GitHub issues and asking how we can make BridgeStan or the Stan Model Server work with it (if it’s possible).

I just tried, and it doesn’t work out of the box in the way I was hoping. Both the server and the bridge raised errors of some sort when I tried them.

The signatures for their algorithms include a logprob_grad_fn argument to specify your own gradient, but the documentation never mentions it. For a few algorithms at least the argument seems to do nothing, and they still try to use JAX’s autodiff, which obviously fails.

Edit: After disabling JAX’s jit and manually editing part of the blackjax source to actually use the supplied logprob_grad_fn, I was able to get something based on BridgeStan to run, but it isn’t recovering the parameters, so something else is still wrong.


Sounds like it might be better to wrap the logp in a jax op if you want to use blackjax. A nice intro about doing that is here: GitHub - dfm/extending-jax: Extending JAX with custom C++ and CUDA code


Ah, yes, we do create a new object inside stan_services.

PyBindStan is a bit clunky at the moment and is a bit buggy. BridgeStan is a lot cleaner and addresses a number of the outstanding issues in PyBindStan. Very neat!

@s.maskell

Blackjax author here. There’s now a tutorial to use custom gradients with Blackjax.

If you encounter any difficulties we’re happy to take a look if you open an issue on the repository.

I never got around to re-sharing it back on this page, but I did eventually get blackjax (or more generally, just JAX) to play nicely: bridgestan/python/blackjax_example.py at examples/blackjax · WardBrian/bridgestan · GitHub

Unfortunately, if you don’t run primarily on the CPU, the communication overhead between your accelerator and the host can really kill you with this.