Nested AD with references to the top of the AD stack

Hi!

I am wondering whether it is OK to do nested AD with vars that are part of the non-nested (outer) stack, as long as I make sure their adjoints are set to 0?

So, the use pattern I have in mind is:

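var theta = 123;                 // lives on the outer (non-nested) stack
start_nested();
theta.vi_->set_zero_adjoint();   // clear the outer adjoint before nested work
var y = 456;                     // lives on the nested stack
var dy_dt = f(theta, y);
dy_dt.grad();
...
recover_memory_nested();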
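To make the pattern concrete without pulling in Stan Math, here is a minimal sketch on a toy tape. The `Tape` type and its methods are hypothetical: they only borrow the names from the post (`start_nested`, `grad`, `set_zero_all_adjoints_nested`, `recover_memory_nested`), nodes are addressed by index instead of `vari*`, and the RHS `f(theta, y) = theta^2 * y` is made up. `theta` is created on the outer level; `y` and everything computed from it live on the nested level.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy reverse-mode tape (hypothetical, for illustration only; NOT Stan
// Math's implementation). Nodes are addressed by index, so growing the
// tape never invalidates a handle.
struct Node {
  double val = 0.0;
  double adj = 0.0;
  int parent[2] = {-1, -1};        // operand indices, -1 = unused
  double partial[2] = {0.0, 0.0};  // local partial derivatives
};

struct Tape {
  std::vector<Node> nodes;
  std::vector<std::size_t> marks;  // tape size at each start_nested()

  int push(double v, int p0 = -1, double d0 = 0.0, int p1 = -1,
           double d1 = 0.0) {
    Node n;
    n.val = v;
    n.parent[0] = p0;
    n.parent[1] = p1;
    n.partial[0] = d0;
    n.partial[1] = d1;
    nodes.push_back(n);
    return static_cast<int>(nodes.size()) - 1;
  }
  int leaf(double v) { return push(v); }
  int mul(int a, int b) {  // d(ab)/da = b, d(ab)/db = a
    return push(nodes[a].val * nodes[b].val, a, nodes[b].val, b, nodes[a].val);
  }
  std::size_t level_begin() const { return marks.empty() ? 0 : marks.back(); }
  void start_nested() { marks.push_back(nodes.size()); }
  void set_zero_all_adjoints_nested() {
    for (std::size_t i = level_begin(); i < nodes.size(); ++i) nodes[i].adj = 0.0;
  }
  // Drops the nodes of the current level; outer nodes are not touched.
  void recover_memory_nested() {
    assert(!marks.empty());
    nodes.resize(marks.back());
    marks.pop_back();
  }
  // Seeds `out` with 1 and sweeps the current level top down. Operands
  // below the level still get their adjoints incremented.
  void grad(int out) {
    nodes[out].adj = 1.0;
    for (std::size_t i = nodes.size(); i-- > level_begin();) {
      const Node& n = nodes[i];
      for (int k = 0; k < 2; ++k)
        if (n.parent[k] >= 0) nodes[n.parent[k]].adj += n.partial[k] * n.adj;
    }
  }
};

// The pattern from the post: theta lives on the outer level, y and
// everything computed from it live on the nested level.
double nested_grad_wrt_outer(Tape& tape, int theta, double y_value) {
  tape.start_nested();
  tape.nodes[theta].adj = 0.0;  // like theta.vi_->set_zero_adjoint()
  int y = tape.leaf(y_value);
  int dy_dt = tape.mul(tape.mul(theta, theta), y);  // theta^2 * y
  tape.grad(dy_dt);
  double d_theta = tape.nodes[theta].adj;  // 2 * theta * y
  tape.recover_memory_nested();            // frees y and the products only
  return d_theta;
}
```

With theta = 2 and y = 5 the function returns 2 * theta * y = 20. After `recover_memory_nested()` only the outer node is left on the tape, and it still carries that adjoint: discarding the nested level does not reset anything below it.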

I don’t see why this should not work, but I am getting weird behavior when doing this in the context of the ODE sub-system.

I have prepared a branch that uses the above pattern to avoid allocating the theta parameter vector of the ODE RHS functor every time the RHS is evaluated. Reallocating it is unnecessary, since the parameters do not change. This gives me a ~15-20% speedup on the SIR benchmark… so I would like to make the PR, but I am not sure if this is OK.

This PR passes all tests, but I recall that it is fragile: slightly changing the order of things breaks it.

So is this illegal?

Tagging a few people who hopefully know: @syclik, @Bob_Carpenter, @seantalts. Any suggestions?

To give everyone a bit more background: the branch feature/issue-1062-ode-speedup works fine and gives me a ~20% speedup for ODE integration, which is great. The only trick I applied there is to leave the parameter vector theta on the global AD stack instead of making it part of the nested AD stack, where it gets allocated with every call to the ODE RHS.

The tests work fine on that branch, e.g.

./runTests.py test/unit/math/rev/arr/functor/integrate_ode_rk45_grad_test.cpp

works just fine. However, things are fragile! If I use a slight variation of the code, which to my understanding should work just as well, then things fail. The branch feature/issue-1062-ode-speedup-failing is such a variation, and there the tests do fail; it is not clear to me what the difference is.

I would like to file this PR for a nice speedup, but I think we should understand this weirdness.

The diff between the two is really minimal:

diff --git a/stan/math/rev/arr/functor/coupled_ode_system.hpp b/stan/math/rev/arr/functor/coupled_ode_system.hpp
index 955d376042..d25a520005 100644
--- a/stan/math/rev/arr/functor/coupled_ode_system.hpp
+++ b/stan/math/rev/arr/functor/coupled_ode_system.hpp
@@ -477,10 +477,10 @@ struct coupled_ode_system<F, var, var> {
       check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
                        "states", N_);

-      for (size_t j = 0; j < M_; j++)
-        theta_[j].vi_->set_zero_adjoint();
-
       for (size_t i = 0; i < N_; i++) {
+        for (size_t j = 0; j < M_; j++)
+          theta_[j].vi_->set_zero_adjoint();
+
         dz_dt[i] = dy_dt_vars[i].val();
         dy_dt_vars[i].grad();

@@ -498,7 +498,9 @@ struct coupled_ode_system<F, var, var> {

         for (size_t j = 0; j < M_; j++) {
           double temp_deriv = theta_[j].adj();
-          theta_[j].vi_->set_zero_adjoint();
+          // not doing the zeroing here leads to test failures...not
+          // sure why!
+          // theta_[j].vi_->set_zero_adjoint();
           const size_t offset = N_ + N_ * N_ + N_ * j;
           for (size_t k = 0; k < N_; k++)
             temp_deriv += z[offset + k] * y_vars[k].adj();

I think it makes sense that you would need to zero, right? Let me try to state my mental model of what is happening here:

  1. You allocate a bunch of vars called theta.
  2. You start a new AD stack of chain methods with start_nested (it also starts a new memory pool and maybe does a few other bookkeeping things).
  3. You do a bunch of operations, all of which push chain methods onto the AD stack. Some of these operations involve theta, so theta will appear in chain methods on the stack, will have its adjoints read, and will have its adjoints modified.
  4. You call .grad(), which iterates through the chain methods and calls them.
  5. You clear the AD stack, go back to the old AD stack, and free any var memory that you allocated during the nested call.

So in my mental model, step 3 modifies the adjoints in a way that you need to clear in between each gradient call. recover_memory_nested frees the memory used by things allocated during the nested call, but it 1) doesn’t call any destructors, and 2) doesn’t do any cleanup outside of the nested arena and bookkeeping. In particular, it doesn’t set old adjoints to 0 or anything like that; it just marks all of that memory as available again, so any new stuff zeroes it out. Things allocated outside the nested call aren’t touched.
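The "clear in between" point can be shown on the same kind of toy tape (hypothetical, not Stan Math; method names only mirror the Stan functions). Two outputs of one nested level share the outer operand `a`, and `set_zero_all_adjoints_nested()` only clears the nested level, so `a` keeps the adjoint from the first sweep unless it is zeroed by hand. The made-up outputs are `a^2` and `a^3`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy reverse-mode tape (hypothetical, for illustration only; NOT Stan
// Math's implementation). Nodes are addressed by index.
struct Node {
  double val = 0.0;
  double adj = 0.0;
  int parent[2] = {-1, -1};        // operand indices, -1 = unused
  double partial[2] = {0.0, 0.0};  // local partial derivatives
};

struct Tape {
  std::vector<Node> nodes;
  std::vector<std::size_t> marks;  // tape size at each start_nested()

  int push(double v, int p0 = -1, double d0 = 0.0, int p1 = -1,
           double d1 = 0.0) {
    Node n;
    n.val = v;
    n.parent[0] = p0;
    n.parent[1] = p1;
    n.partial[0] = d0;
    n.partial[1] = d1;
    nodes.push_back(n);
    return static_cast<int>(nodes.size()) - 1;
  }
  int leaf(double v) { return push(v); }
  int mul(int a, int b) {  // d(ab)/da = b, d(ab)/db = a
    return push(nodes[a].val * nodes[b].val, a, nodes[b].val, b, nodes[a].val);
  }
  std::size_t level_begin() const { return marks.empty() ? 0 : marks.back(); }
  void start_nested() { marks.push_back(nodes.size()); }
  void set_zero_all_adjoints_nested() {
    for (std::size_t i = level_begin(); i < nodes.size(); ++i) nodes[i].adj = 0.0;
  }
  // Drops the nodes of the current level; outer nodes are not touched.
  void recover_memory_nested() {
    assert(!marks.empty());
    nodes.resize(marks.back());
    marks.pop_back();
  }
  // Seeds `out` with 1 and sweeps the current level top down. Operands
  // below the level still get their adjoints incremented.
  void grad(int out) {
    nodes[out].adj = 1.0;
    for (std::size_t i = nodes.size(); i-- > level_begin();) {
      const Node& n = nodes[i];
      for (int k = 0; k < 2; ++k)
        if (n.parent[k] >= 0) nodes[n.parent[k]].adj += n.partial[k] * n.adj;
    }
  }
};

// Returns d(res1)/da as read after the second sweep, optionally zeroing
// the outer operand a between the two sweeps.
double second_row(bool zero_outer) {
  Tape tape;
  int a = tape.leaf(2.0);  // outer operand
  tape.start_nested();
  int res0 = tape.mul(a, a);                   // a^2
  int res1 = tape.mul(tape.mul(a, a), a);      // a^3
  tape.grad(res0);                             // a's adjoint is now 2a = 4
  tape.set_zero_all_adjoints_nested();         // nested level only
  if (zero_outer) tape.nodes[a].adj = 0.0;     // like a.vi_->set_zero_adjoint()
  tape.grad(res1);
  double d = tape.nodes[a].adj;
  tape.recover_memory_nested();                // a's adjoint is left as is
  return d;
}
```

`second_row(true)` gives 3a^2 = 12 for a = 2, while `second_row(false)` gives 16: the 4 left over from the first sweep plus 12.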

Yes, it does make sense that I have to call set_zero_all_adjoints_nested() and also zero out the adjoints of the operands that live on the outer AD stack.

It just seems to matter at exactly which location in the source code I zero the outer AD stack variables. If I move the zeroing around a little, things stop working, and this is what I do not get right now. From my understanding, all that should matter is that the adjoints are zeroed just before I call grad. But that is not the case. It smells like a bug to me, or I am really doing something wrong here. So far I haven’t been able to write simple unit tests that trigger this behavior; only the two branches I point to above demonstrate a working and a non-working state (and I would really expect both branches to work just fine).

I am hesitant to merge this PR until we have figured this out. The PR gives a 15-20% speedup of the ODE integrators, which is a good motivation, I think.

Oh I see, you’re zeroing twice to get it to work. Where is the 2nd time and what happens in between the two? Can you add it to that pattern you posted in your first post? (I don’t know the ODE code that well).

Let me paste the code in question.

So this works ok:

void operator()(const std::vector<double>& z, std::vector<double>& dz_dt,
                  double t) const {
    using std::vector;

    try {
      start_nested();

      vector<var> y_vars(z.begin(), z.begin() + N_);

      vector<var> dy_dt_vars = f_(t, y_vars, theta_, x_, x_int_, msgs_);

      check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
                       "states", N_);
  
      // zero the outer operands as we don't know in what state these are
      for (size_t j = 0; j < M_; j++)
        theta_[j].vi_->set_zero_adjoint();

      for (size_t i = 0; i < N_; i++) {
        dz_dt[i] = dy_dt_vars[i].val();
        dy_dt_vars[i].grad();

        // irrelevant stuff

        for (size_t j = 0; j < M_; j++) {
          double temp_deriv = theta_[j].adj();
          // NOTE: zero out the outer operand straight after using it => this works!
          theta_[j].vi_->set_zero_adjoint();
          const size_t offset = N_ + N_ * N_ + N_ * j;
          for (size_t k = 0; k < N_; k++)
            temp_deriv += z[offset + k] * y_vars[k].adj();

          dz_dt[offset + i] = temp_deriv;
        }

        // zero the nested AD tree
        set_zero_all_adjoints_nested();
      }
    } catch (const std::exception& e) {
      recover_memory_nested();
      throw;
    }
    recover_memory_nested();
  }

And this variation fails for reasons I really do not understand, since all I did was move the zeroing of the adjoints to just before the grad call:

void operator()(const std::vector<double>& z, std::vector<double>& dz_dt,
                  double t) const {
    using std::vector;

    try {
      start_nested();

      vector<var> y_vars(z.begin(), z.begin() + N_);

      vector<var> dy_dt_vars = f_(t, y_vars, theta_, x_, x_int_, msgs_);

      check_size_match("coupled_ode_system", "dz_dt", dy_dt_vars.size(),
                       "states", N_);

      for (size_t i = 0; i < N_; i++) {
        // zero out all outer operands (just before the grad call)
        for (size_t j = 0; j < M_; j++)
          theta_[j].vi_->set_zero_adjoint();

        dz_dt[i] = dy_dt_vars[i].val();
        dy_dt_vars[i].grad();

        // irrelevant stuff

        for (size_t j = 0; j < M_; j++) {
          double temp_deriv = theta_[j].adj();
          // NOTE: not doing the zeroing here leads to test failures...not
          // sure why!
          // theta_[j].vi_->set_zero_adjoint();
          const size_t offset = N_ + N_ * N_ + N_ * j;
          for (size_t k = 0; k < N_; k++)
            temp_deriv += z[offset + k] * y_vars[k].adj();

          dz_dt[offset + i] = temp_deriv;
        }

        // zero nested AD stack
        set_zero_all_adjoints_nested();
      }
    } catch (const std::exception& e) {
      recover_memory_nested();
      throw;
    }
    recover_memory_nested();
  }

Do you see why I am confused? To me it looks like weird stuff is going on.
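One difference between the two versions follows directly from the code. The first (working) version zeroes each `theta_[j]` right after reading it, so the outer adjoints are back at zero when `operator()` returns. The second (failing) version leaves them holding the derivatives of the last output `dy_dt_vars[N_ - 1]`, because `set_zero_all_adjoints_nested()` does not touch the outer stack. Any later sweep on the outer stack that assumes fresh adjoints then starts from a wrong value. The sketch reproduces this on a hypothetical toy tape (not Stan Math) with a made-up RHS of two outputs, theta^2 and theta^3; it is not the actual coupled_ode_system.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy reverse-mode tape (hypothetical, for illustration only; NOT Stan
// Math's implementation). Nodes are addressed by index.
struct Node {
  double val = 0.0;
  double adj = 0.0;
  int parent[2] = {-1, -1};        // operand indices, -1 = unused
  double partial[2] = {0.0, 0.0};  // local partial derivatives
};

struct Tape {
  std::vector<Node> nodes;
  std::vector<std::size_t> marks;  // tape size at each start_nested()

  int push(double v, int p0 = -1, double d0 = 0.0, int p1 = -1,
           double d1 = 0.0) {
    Node n;
    n.val = v;
    n.parent[0] = p0;
    n.parent[1] = p1;
    n.partial[0] = d0;
    n.partial[1] = d1;
    nodes.push_back(n);
    return static_cast<int>(nodes.size()) - 1;
  }
  int leaf(double v) { return push(v); }
  int mul(int a, int b) {  // d(ab)/da = b, d(ab)/db = a
    return push(nodes[a].val * nodes[b].val, a, nodes[b].val, b, nodes[a].val);
  }
  std::size_t level_begin() const { return marks.empty() ? 0 : marks.back(); }
  void start_nested() { marks.push_back(nodes.size()); }
  void set_zero_all_adjoints_nested() {
    for (std::size_t i = level_begin(); i < nodes.size(); ++i) nodes[i].adj = 0.0;
  }
  void recover_memory_nested() {  // outer nodes are not touched
    assert(!marks.empty());
    nodes.resize(marks.back());
    marks.pop_back();
  }
  void grad(int out) {  // sweeps the current level only
    nodes[out].adj = 1.0;
    for (std::size_t i = nodes.size(); i-- > level_begin();) {
      const Node& n = nodes[i];
      for (int k = 0; k < 2; ++k)
        if (n.parent[k] >= 0) nodes[n.parent[k]].adj += n.partial[k] * n.adj;
    }
  }
};

// zero_after_read = true mimics the first version, false the second one.
void rhs_jacobian(Tape& tape, int theta, bool zero_after_read, double jac[2]) {
  tape.start_nested();
  int out[2];
  out[0] = tape.mul(theta, theta);                    // theta^2
  out[1] = tape.mul(tape.mul(theta, theta), theta);   // theta^3
  if (zero_after_read) tape.nodes[theta].adj = 0.0;   // zero before the loop
  for (int i = 0; i < 2; ++i) {
    if (!zero_after_read) tape.nodes[theta].adj = 0.0;  // zero before grad
    tape.grad(out[i]);
    jac[i] = tape.nodes[theta].adj;
    if (zero_after_read) tape.nodes[theta].adj = 0.0;   // zero after reading
    tape.set_zero_all_adjoints_nested();                // nested level only
  }
  tape.recover_memory_nested();
}

// Outer work done after the RHS call: d(theta^2)/dtheta. Like any fresh
// sweep, it assumes theta's adjoint starts at zero.
double outer_grad_after(Tape& tape, int theta) {
  tape.grad(tape.mul(theta, theta));
  return tape.nodes[theta].adj;
}
```

Both versions return the same Jacobian (4 and 12 at theta = 2). They differ only in what they leave behind on theta: 0 versus 12, so the subsequent outer gradient comes out as 4 versus 16.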

Uhm… actually, by now I think I am making a mistake here (in both versions above).

I not only need to zero out the adjoints from the outer AD stack… I also have to restore them to their original values! Otherwise I am messing with any grad call on the outer stack that may be in progress while the ODE integrator is being called. Right?

Let’s see whether this makes it work.


Woof, that sounds right to me if theta are used elsewhere.

It works now! Recovering the adjoints is what is needed in addition to manually zeroing them.
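A minimal sketch of that fix, again on a hypothetical toy tape rather than Stan Math: a small RAII guard (`AdjointGuard` is a made-up name) snapshots the adjoints of the outer operands, zeroes them for the nested work, and writes the snapshot back when it goes out of scope, including when an exception propagates. The RHS (theta^2, theta^3) is made up as well.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy reverse-mode tape (hypothetical, for illustration only; NOT Stan
// Math's implementation). Nodes are addressed by index.
struct Node {
  double val = 0.0;
  double adj = 0.0;
  int parent[2] = {-1, -1};        // operand indices, -1 = unused
  double partial[2] = {0.0, 0.0};  // local partial derivatives
};

struct Tape {
  std::vector<Node> nodes;
  std::vector<std::size_t> marks;  // tape size at each start_nested()

  int push(double v, int p0 = -1, double d0 = 0.0, int p1 = -1,
           double d1 = 0.0) {
    Node n;
    n.val = v;
    n.parent[0] = p0;
    n.parent[1] = p1;
    n.partial[0] = d0;
    n.partial[1] = d1;
    nodes.push_back(n);
    return static_cast<int>(nodes.size()) - 1;
  }
  int leaf(double v) { return push(v); }
  int mul(int a, int b) {  // d(ab)/da = b, d(ab)/db = a
    return push(nodes[a].val * nodes[b].val, a, nodes[b].val, b, nodes[a].val);
  }
  std::size_t level_begin() const { return marks.empty() ? 0 : marks.back(); }
  void start_nested() { marks.push_back(nodes.size()); }
  void set_zero_all_adjoints_nested() {
    for (std::size_t i = level_begin(); i < nodes.size(); ++i) nodes[i].adj = 0.0;
  }
  void recover_memory_nested() {  // outer nodes are not touched
    assert(!marks.empty());
    nodes.resize(marks.back());
    marks.pop_back();
  }
  void grad(int out) {  // sweeps the current level only
    nodes[out].adj = 1.0;
    for (std::size_t i = nodes.size(); i-- > level_begin();) {
      const Node& n = nodes[i];
      for (int k = 0; k < 2; ++k)
        if (n.parent[k] >= 0) nodes[n.parent[k]].adj += n.partial[k] * n.adj;
    }
  }
};

// Hypothetical RAII helper (not a Stan Math class): snapshots and zeroes the
// adjoints of outer operands, restores the snapshot on scope exit.
class AdjointGuard {
 public:
  AdjointGuard(Tape& tape, const std::vector<int>& ids)
      : tape_(tape), ids_(ids) {
    for (int id : ids_) {
      saved_.push_back(tape_.nodes[id].adj);
      tape_.nodes[id].adj = 0.0;
    }
  }
  ~AdjointGuard() {
    for (std::size_t k = 0; k < ids_.size(); ++k)
      tape_.nodes[ids_[k]].adj = saved_[k];
  }
  AdjointGuard(const AdjointGuard&) = delete;
  AdjointGuard& operator=(const AdjointGuard&) = delete;

 private:
  Tape& tape_;
  std::vector<int> ids_;
  std::vector<double> saved_;
};

// Jacobian of (theta^2, theta^3) w.r.t. the outer theta; theta's adjoint
// is left exactly as the guard found it.
void rhs_jacobian_restoring(Tape& tape, int theta, double jac[2]) {
  AdjointGuard guard(tape, std::vector<int>{theta});
  tape.start_nested();
  int out[2];
  out[0] = tape.mul(theta, theta);
  out[1] = tape.mul(tape.mul(theta, theta), theta);
  for (int i = 0; i < 2; ++i) {
    tape.nodes[theta].adj = 0.0;  // fresh start for each output
    tape.grad(out[i]);
    jac[i] = tape.nodes[theta].adj;
    tape.set_zero_all_adjoints_nested();
  }
  tape.recover_memory_nested();
}  // guard's destructor restores theta's adjoint here
```

In the test below, theta's adjoint is set to 7 before the call to stand in for whatever state the outer stack was in; it is 7 again afterwards, and the Jacobian entries are still 4 and 12. For brevity the sketch does not recover the nested level on exceptions the way the try/catch in the original code does; only the adjoint restore is exception safe.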

I just filed the PR… this gives me a 19% speedup for the SIR example. However, merging this PR probably needs some discussion due to the subtle things going on here. I’m not sure whether what I did is OK, but given a 19% speedup it is worth discussing.


This is also my understanding of the problem.

You have to do the zeroing after all the operations that modify the varis and before calculating adjoints.

I don’t think any of this is safe. The idea is that the nested stack lets you use a new memory space that’s completely partitioned from the original stack, i.e., a bunch of vari* on a stack with nesting:

outer
outer
...
outer
-------------------
nested 1
nested 1
...
nested 1
-------------------

When starting the nested autodiff, it shouldn’t cross over into the outer autodiff at all. The point is that it’s purely nested. When computing gradients, all of the nested varis need to be zeroed; then the nested grad call works only within the nesting.

If you start having things point from the nested space outward, all hell’s going to break loose. What I don’t understand is why you’d want to do that.

It is possible to have the nested thing reduce to a new vari that goes back on the outer stack.

Before calculating gradients, everything in the current nesting level can be zeroed if it’s not already zero by construction (as it is when built the first time). Then the top level result is set to 1 and the derivatives are propagated wholly within a nesting level.
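A sketch of the purely nested use described here, on a hypothetical toy tape (not Stan Math): only the value of the outer `theta` is copied into a fresh nested leaf, the gradient is computed wholly within the nested level, and the result is reduced to one new node on the outer level whose stored partial is the nested gradient, roughly in the spirit of a vari with precomputed gradients. The function name and the made-up `f(theta) = theta^3` are illustrative only.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy reverse-mode tape (hypothetical, for illustration only; NOT Stan
// Math's implementation). Nodes are addressed by index.
struct Node {
  double val = 0.0;
  double adj = 0.0;
  int parent[2] = {-1, -1};        // operand indices, -1 = unused
  double partial[2] = {0.0, 0.0};  // local partial derivatives
};

struct Tape {
  std::vector<Node> nodes;
  std::vector<std::size_t> marks;  // tape size at each start_nested()

  int push(double v, int p0 = -1, double d0 = 0.0, int p1 = -1,
           double d1 = 0.0) {
    Node n;
    n.val = v;
    n.parent[0] = p0;
    n.parent[1] = p1;
    n.partial[0] = d0;
    n.partial[1] = d1;
    nodes.push_back(n);
    return static_cast<int>(nodes.size()) - 1;
  }
  int leaf(double v) { return push(v); }
  int mul(int a, int b) {  // d(ab)/da = b, d(ab)/db = a
    return push(nodes[a].val * nodes[b].val, a, nodes[b].val, b, nodes[a].val);
  }
  std::size_t level_begin() const { return marks.empty() ? 0 : marks.back(); }
  void start_nested() { marks.push_back(nodes.size()); }
  void set_zero_all_adjoints_nested() {
    for (std::size_t i = level_begin(); i < nodes.size(); ++i) nodes[i].adj = 0.0;
  }
  void recover_memory_nested() {  // outer nodes are not touched
    assert(!marks.empty());
    nodes.resize(marks.back());
    marks.pop_back();
  }
  void grad(int out) {  // sweeps the current level only
    nodes[out].adj = 1.0;
    for (std::size_t i = nodes.size(); i-- > level_begin();) {
      const Node& n = nodes[i];
      for (int k = 0; k < 2; ++k)
        if (n.parent[k] >= 0) nodes[n.parent[k]].adj += n.partial[k] * n.adj;
    }
  }
};

// Evaluates f(theta) = theta^3 on a nested copy of theta's value, then
// pushes one outer node carrying the nested gradient as its partial.
int nested_then_reduce(Tape& tape, int theta) {
  double theta_val = tape.nodes[theta].val;  // copy the value, not the node
  tape.start_nested();
  int local = tape.leaf(theta_val);
  int f = tape.mul(tape.mul(local, local), local);
  tape.set_zero_all_adjoints_nested();  // zero the current level only
  tape.grad(f);                         // never reaches the outer theta
  double f_val = tape.nodes[f].val;
  double df = tape.nodes[local].adj;    // 3 * theta^2
  tape.recover_memory_nested();
  return tape.push(f_val, theta, df);   // single outer node
}
```

The outer `theta` is never referenced from the nested level, so its adjoint is untouched by the nested sweep; a later outer `grad` on the returned node propagates 3 * theta^2 = 12 (at theta = 2) back to it in one step.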

Hmm. I wonder how much of a performance increase you’d get by keeping var theta = 123; outside of the nested part and just copying its value into a new nested theta once you call start_nested? Or is that what the current version on develop is doing?

The bad pattern to me is the argument const std::vector<double>& z, which brings non-nested autodiff variables into scope with vector<var> y_vars(z.begin(), z.begin() + N_); I’m not sure what the intention of that is.

Nested autodiff was designed to just start again from scratch. So you could bring in values of other autodiff variables, but not the vari* themselves.

Not sure what the problem with that one is. Recall that we are interfacing with ODE solvers here, and these only work with vectors of doubles. So y_vars are fresh vars created on the nested stack from plain double values.

What my benchmarks have shown so far is that not starting from scratch for things that are immutable speeds up the ODE integration. To build more motivation for this, I could look into one more example if that helps. I do understand that what I did is potentially outside the intended use case of nested AD… but maybe this is a general thing we want to explore? I can easily imagine that a more tightly integrated nested AD approach would be very beneficial for the threaded map_rect, for example: the per-job parameters would not need to be re-created on a local AD stack; we could just use them where they are if they are properly zeroed and restored.

Currently the PR proposes to just use the parameters on the outer AD stack. Allocating a new copy of them on the outer AD stack (this is what @seantalts suggested, right?) is not ideal: every call to coupled_ode_system would grow the global AD stack, with no means to get rid of the copies. Thus, repeated integrate calls lead to an ever-growing AD stack => not good.

What could be done instead is to double-nest things. The coupled_ode_system constructor would start a nested AD arena and place copies of the theta varis into it, and the () operator would itself run nested AD on top of that. This way we can clean the theta varis off the AD stack in the destructor of coupled_ode_system. Maybe this is a good compromise here?
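The double-nesting compromise could look roughly like this, once more on a hypothetical toy tape rather than Stan Math (`CoupledSystemSketch` and its RHS theta^2 * y are made up): the constructor opens a nesting level holding a copy of theta that is created once, each call opens a second level for the per-call work, and the destructor releases the copy.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy reverse-mode tape (hypothetical, for illustration only; NOT Stan
// Math's implementation). Nodes are addressed by index.
struct Node {
  double val = 0.0;
  double adj = 0.0;
  int parent[2] = {-1, -1};        // operand indices, -1 = unused
  double partial[2] = {0.0, 0.0};  // local partial derivatives
};

struct Tape {
  std::vector<Node> nodes;
  std::vector<std::size_t> marks;  // tape size at each start_nested()

  int push(double v, int p0 = -1, double d0 = 0.0, int p1 = -1,
           double d1 = 0.0) {
    Node n;
    n.val = v;
    n.parent[0] = p0;
    n.parent[1] = p1;
    n.partial[0] = d0;
    n.partial[1] = d1;
    nodes.push_back(n);
    return static_cast<int>(nodes.size()) - 1;
  }
  int leaf(double v) { return push(v); }
  int mul(int a, int b) {  // d(ab)/da = b, d(ab)/db = a
    return push(nodes[a].val * nodes[b].val, a, nodes[b].val, b, nodes[a].val);
  }
  std::size_t level_begin() const { return marks.empty() ? 0 : marks.back(); }
  void start_nested() { marks.push_back(nodes.size()); }
  void recover_memory_nested() {  // outer nodes are not touched
    assert(!marks.empty());
    nodes.resize(marks.back());
    marks.pop_back();
  }
  void grad(int out) {  // sweeps the current level only
    nodes[out].adj = 1.0;
    for (std::size_t i = nodes.size(); i-- > level_begin();) {
      const Node& n = nodes[i];
      for (int k = 0; k < 2; ++k)
        if (n.parent[k] >= 0) nodes[n.parent[k]].adj += n.partial[k] * n.adj;
    }
  }
};

class CoupledSystemSketch {
 public:
  CoupledSystemSketch(Tape& tape, double theta_value) : tape_(tape) {
    tape_.start_nested();              // level 1: owned by this object
    theta_ = tape_.leaf(theta_value);  // created once, reused by every call
  }
  ~CoupledSystemSketch() { tape_.recover_memory_nested(); }  // frees the copy
  CoupledSystemSketch(const CoupledSystemSketch&) = delete;
  CoupledSystemSketch& operator=(const CoupledSystemSketch&) = delete;

  // d/dtheta of theta^2 * y, evaluated in a per-call level-2 nest.
  double operator()(double y_value) const {
    tape_.start_nested();
    int y = tape_.leaf(y_value);
    int f = tape_.mul(tape_.mul(theta_, theta_), y);
    tape_.nodes[theta_].adj = 0.0;  // private copy: nothing to restore
    tape_.grad(f);
    double d_theta = tape_.nodes[theta_].adj;
    tape_.recover_memory_nested();  // drops y and intermediates only
    return d_theta;
  }

 private:
  Tape& tape_;
  int theta_ = -1;
};
```

Because nothing outside the object refers to the copy, its adjoint can be zeroed freely without saving or restoring anything. The catch is that this relies on strict LIFO nesting: if another level is opened while the object is alive and not closed before the object is destroyed, the destructor discards the wrong level.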

Thanks, I didn’t realize that. Then how do you get pointers to autodiff variables that aren’t on the nested stack?

I’m clearly very confused about what’s going on here and what the intention is w.r.t. autodiff. I find all the ODE solving stuff really confusing. Is there a simpler example outside of an ODE context that illustrates the point?

I wrote some tests which illustrate the nesting of AD.

nested_test.cpp (4.7 KB)

Maybe they make it easier to follow along.

This is small enough to paste here. My Mac always wants to open files with the wrong app, which is a pain.

P.S. The general recommendation is to use line comments to comment out blocks of code; block comments are only for long doc chunks. In editors like emacs you can comment out a region with line comments in one command.

#include <stan/math.hpp>
#include <gtest/gtest.h>
#include <stan/math/prim/mat/fun/Eigen.hpp>  // only used for stack tests
#include <stan/math/rev/mat/fun/quad_form.hpp>
#include <test/unit/math/rev/mat/fun/util.hpp>
#include <string>
#include <vector>

struct AgradRev : public testing::Test {
  void SetUp() {
    // make sure memory's clean before starting each test
    stan::math::recover_memory();
  }
};

struct binary_functor {
  template <typename T1, typename T2>
  std::vector<typename stan::return_type<T1, T2>::type> operator()(T1& a,
                                                                   T2& b) {
    std::vector<typename stan::return_type<T1, T2>::type> res(2);
    res[0] = a * a + b * b;
    res[1] = a * a * a + b * b * b;
    return res;
  }
};

TEST_F(AgradRev, non_nested_jacobian) {
  using stan::math::var;
  using stan::math::vari;
  using std::vector;

  var a = 2.0;
  var b = 3.0;

  binary_functor f;
  vector<var> res = f(a, b);

  res[0].grad();

  EXPECT_FLOAT_EQ(2.0 * a.val(), a.adj());
  EXPECT_FLOAT_EQ(2.0 * b.val(), b.adj());

  stan::math::set_zero_all_adjoints();

  res[1].grad();

  EXPECT_FLOAT_EQ(3.0 * a.val() * a.val(), a.adj());
  EXPECT_FLOAT_EQ(3.0 * b.val() * b.val(), b.adj());
}

TEST_F(AgradRev, nested_jacobian) {
  using stan::math::var;
  using stan::math::vari;
  using std::vector;

  var a = 2.0;
  var b = 3.0;

  binary_functor f;

  stan::math::start_nested();

  vector<var> res = f(a, b);

  res[0].grad();

  EXPECT_FLOAT_EQ(2.0 * a.val(), a.adj());
  EXPECT_FLOAT_EQ(2.0 * b.val(), b.adj());

  stan::math::set_zero_all_adjoints_nested();

  a.vi_->set_zero_adjoint();
  b.vi_->set_zero_adjoint();

  res[1].grad();

  EXPECT_FLOAT_EQ(3.0 * a.val() * a.val(), a.adj());
  EXPECT_FLOAT_EQ(3.0 * b.val() * b.val(), b.adj());

  stan::math::recover_memory_nested();
}

TEST_F(AgradRev, double_nested_jacobian) {
  using stan::math::var;
  using stan::math::vari;
  using std::vector;

  stan::math::start_nested();

  var a = 2.0;

  binary_functor f;

  stan::math::start_nested();

  var b = 3.0;

  vector<var> res = f(a, b);

  res[0].grad();

  EXPECT_FLOAT_EQ(2.0 * a.val(), a.adj());
  EXPECT_FLOAT_EQ(2.0 * b.val(), b.adj());

  stan::math::set_zero_all_adjoints_nested();

  a.vi_->set_zero_adjoint();
  // b.vi_->set_zero_adjoint();

  res[1].grad();

  EXPECT_FLOAT_EQ(3.0 * a.val() * a.val(), a.adj());
  EXPECT_FLOAT_EQ(3.0 * b.val() * b.val(), b.adj());

  stan::math::recover_memory_nested();

  EXPECT_FALSE(stan::math::empty_nested());

  stan::math::recover_memory_nested();

  EXPECT_TRUE(stan::math::empty_nested());
}

TEST_F(AgradRev, repeated_double_nested_jacobian) {
  using stan::math::var;
  using stan::math::vari;
  using std::vector;

  stan::math::start_nested();

  var a = 2.0;

  binary_functor f;

  {
    stan::math::start_nested();

    var b = 3.0;

    vector<var> res = f(a, b);

    a.vi_->set_zero_adjoint();

    res[0].grad();

    EXPECT_FLOAT_EQ(2.0 * a.val(), a.adj());
    EXPECT_FLOAT_EQ(2.0 * b.val(), b.adj());

    stan::math::set_zero_all_adjoints_nested();

    a.vi_->set_zero_adjoint();
    // b.vi_->set_zero_adjoint();

    res[1].grad();

    EXPECT_FLOAT_EQ(3.0 * a.val() * a.val(), a.adj());
    EXPECT_FLOAT_EQ(3.0 * b.val() * b.val(), b.adj());

    stan::math::recover_memory_nested();

    EXPECT_FALSE(stan::math::empty_nested());
  }

  // repeat start

  {
    stan::math::start_nested();

    var b = 4.0;

    vector<var> res = f(a, b);

    a.vi_->set_zero_adjoint();

    res[0].grad();

    EXPECT_FLOAT_EQ(2.0 * a.val(), a.adj());
    EXPECT_FLOAT_EQ(2.0 * b.val(), b.adj());

    stan::math::set_zero_all_adjoints_nested();

    a.vi_->set_zero_adjoint();
    // b.vi_->set_zero_adjoint();

    res[1].grad();

    EXPECT_FLOAT_EQ(3.0 * a.val() * a.val(), a.adj());
    EXPECT_FLOAT_EQ(3.0 * b.val() * b.val(), b.adj());

    stan::math::recover_memory_nested();

    EXPECT_FALSE(stan::math::empty_nested());
  }
  // repeat end

  stan::math::recover_memory_nested();

  EXPECT_TRUE(stan::math::empty_nested());
}

/*
TEST_F(AgradRev, repeated_double_nested_jacobian) {
  using stan::math::var;
  using stan::math::vari;
  using std::vector;

  var a = 2.0;

  stan::math::start_nested();

  var b = 3.0;

  binary_functor f;

  stan::math::start_nested();

  vector<var> res = f(a, b);

  res[0].grad();

  EXPECT_FLOAT_EQ(2.0 * a.val(), a.adj());
  EXPECT_FLOAT_EQ(2.0 * b.val(), b.adj());

  stan::math::set_zero_all_adjoints_nested();

  a.vi_->set_zero_adjoint();
  b.vi_->set_zero_adjoint();

  res[1].grad();

  EXPECT_FLOAT_EQ(3.0 * a.val() * a.val(), a.adj());
  EXPECT_FLOAT_EQ(3.0 * b.val() * b.val(), b.adj());

  stan::math::recover_memory_nested();

  EXPECT_TRUE(!stan::math::empty_nested());

  stan::math::recover_memory_nested();

  EXPECT_TRUE(stan::math::empty_nested());

}
*/