
JIT Mode

Overview

XAD provides an optional JIT (Just-In-Time) mode that records computations into a graph which can be compiled and executed multiple times with different inputs. This is particularly useful when:

  • The same computation needs to be evaluated repeatedly with varying inputs
  • You want to avoid the overhead of re-recording a tape for each evaluation
  • Performance is critical for repeated derivative calculations

Compile-time feature flag

JIT mode is only available when XAD is built with the XAD_ENABLE_JIT option enabled at compile time.

Tape vs. JIT: Key Differences

Understanding the difference between tape-based adjoint mode and JIT mode is essential:

Aspect       | Tape                            | JIT
-------------+---------------------------------+-------------------------------------------------
Recording    | Re-records every run            | Records once, replays many times
Control flow | Evaluated per run               | Baked in at record time (unless using ABool::If)
Memory       | Tape grows with computation     | Fixed graph size after recording
Use case     | Single or varying computations  | Repeated evaluations with same structure

Basic JIT Usage

Here is a simple example showing how to use JIT mode:

#include <XAD/XAD.hpp>

using AD = xad::AReal<double, 1>;

xad::JITCompiler<double, 1> jit;

// Create and register inputs
AD x = 2.0;
AD y = 3.0;
jit.registerInput(x);
jit.registerInput(y);

// Record the computation
AD z = x * y + x;
jit.registerOutput(z);

// Compile the recorded graph
jit.compile();

// Execute forward pass
double out = 0.0;
jit.forward(&out, 1);
// out now contains the result: z = 2 * 3 + 2 = 8

// Compute adjoints by seeding the output derivative with 1
jit.setDerivative(z.getSlot(), 1.0);
jit.computeAdjoints();
double dz_dx = jit.getDerivative(x.getSlot());  // y + 1 = 4
double dz_dy = jit.getDerivative(y.getSlot());  // x = 2

Handling Branching: The Control Flow Problem

A critical difference between tape and JIT modes is how they handle control flow.

Consider a piecewise function:

template <class AD>
AD piecewise(const AD& x)
{
    if (xad::value(x) < 2.0)
        return 1.0 * x;
    return 7.0 * x;
}

With Tape Mode

Tape mode re-records the computation for each run. The if statement is evaluated fresh each time based on the current input value:

#include <XAD/XAD.hpp>
#include <iostream>

using mode = xad::adj<double>;
using tape_type = mode::tape_type;
using AD = mode::active_type;

auto run = [](double x0) {
    tape_type tape;
    AD x = x0;
    tape.registerInput(x);
    tape.newRecording();

    AD y = piecewise(x);  // if is evaluated NOW with current x

    tape.registerOutput(y);
    xad::derivative(y) = 1.0;
    tape.computeAdjoints();

    std::cout << "x=" << x0 << " y=" << xad::value(y)
              << " dy/dx=" << xad::derivative(x) << "\n";
};

run(1.0);  // x < 2, so y = 1*x = 1,  dy/dx = 1
run(3.0);  // x >= 2, so y = 7*x = 21, dy/dx = 7

Output:

x=1 y=1 dy/dx=1
x=3 y=21 dy/dx=7

With JIT Mode (Plain if - Incorrect)

JIT mode records the computation once. A plain C++ if is evaluated at record time and the chosen branch is baked into the graph:

#include <XAD/XAD.hpp>
#include <iostream>

using AD = xad::AReal<double, 1>;

xad::JITCompiler<double, 1> jit;
AD x = 1.0;  // Record with x=1
jit.registerInput(x);

AD y = piecewise(x);  // if evaluated NOW: x < 2 is true, records "return 1*x"
jit.registerOutput(y);
jit.compile();

// First run: x=1
double out = 0.0;
jit.forward(&out, 1);
jit.setDerivative(y.getSlot(), 1.0);
jit.computeAdjoints();
std::cout << "x=1 y=" << out << " dy/dx=" << jit.getDerivative(x.getSlot()) << "\n";

// Second run: x=3 (but graph still has the x < 2 branch!)
x = 3.0;
jit.clearDerivatives();
jit.forward(&out, 1);
jit.setDerivative(y.getSlot(), 1.0);
jit.computeAdjoints();
std::cout << "x=3 y=" << out << " dy/dx=" << jit.getDerivative(x.getSlot()) << "\n";

Output (wrong for x=3):

x=1 y=1 dy/dx=1
x=3 y=3 dy/dx=1

The graph was recorded with the x < 2 branch, so at x=3 it still computes 1*x = 3 with dy/dx = 1, instead of 7*x = 21 with dy/dx = 7.

The Solution: ABool::If

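To handle branching correctly in JIT mode, use ABool::If. It records a conditional node whose condition is evaluated each time the compiled graph runs, rather than once at record time. Both branches are computed before the selection, so both become part of the recorded graph: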

template <class AD>
AD piecewise_abool(const AD& x)
{
    // Create a trackable boolean from the comparison
    auto cond = xad::less(x, 2.0);

    // Compute BOTH branches
    AD true_branch = 1.0 * x;
    AD false_branch = 7.0 * x;

    // Record a conditional selection node
    return cond.If(true_branch, false_branch);
}

Now JIT mode works correctly:

#include <XAD/XAD.hpp>
#include <iostream>

using AD = xad::AReal<double, 1>;

xad::JITCompiler<double, 1> jit;
AD x = 1.0;
jit.registerInput(x);

AD y = piecewise_abool(x);  // Records conditional node with both branches
jit.registerOutput(y);
jit.compile();

// Run with x=1
double out = 0.0;
x = 1.0;
jit.clearDerivatives();
jit.forward(&out, 1);
jit.setDerivative(y.getSlot(), 1.0);
jit.computeAdjoints();
std::cout << "x=1 y=" << out << " dy/dx=" << jit.getDerivative(x.getSlot()) << "\n";

// Run with x=3
x = 3.0;
jit.clearDerivatives();
jit.forward(&out, 1);
jit.setDerivative(y.getSlot(), 1.0);
jit.computeAdjoints();
std::cout << "x=3 y=" << out << " dy/dx=" << jit.getDerivative(x.getSlot()) << "\n";

Output (correct):

x=1 y=1 dy/dx=1
x=3 y=21 dy/dx=7

Comparison Summary

The following table summarizes the behavior across all combinations:

Scenario                     | x | y  | dy/dx | Note
-----------------------------+---+----+-------+------------------------------------
Tape with plain if           | 1 | 1  | 1     | Correct
Tape with plain if           | 3 | 21 | 7     | Correct
JIT with plain if (record)   | 1 | 1  | 1     | Correct
JIT with plain if (replay)   | 3 | 3  | 1     | Wrong: branch baked in
Tape with ABool::If          | 1 | 1  | 1     | Correct
Tape with ABool::If          | 3 | 21 | 7     | Correct: ABool works with tape too
JIT with ABool::If (record)  | 1 | 1  | 1     | Correct
JIT with ABool::If (replay)  | 3 | 21 | 7     | Correct: conditional node

Best Practice

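When writing code that may be used with JIT mode, use ABool::If for all conditional logic that depends on active variables. This gives correct results in both tape and JIT modes. A plain if is only safe in JIT mode when its condition has the same outcome on every replay of the compiled graph.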

Available Comparison Functions

XAD provides the following comparison functions, each returning an ABool that can be used with ABool::If:

  • xad::less(a, b) - returns ABool for a < b
  • xad::less_equal(a, b) - returns ABool for a <= b
  • xad::greater(a, b) - returns ABool for a > b
  • xad::greater_equal(a, b) - returns ABool for a >= b
  • xad::equal(a, b) - returns ABool for a == b
  • xad::not_equal(a, b) - returns ABool for a != b

Reference

For detailed API documentation, see:

See also

This tutorial is based on the included XAD sample (jit_tutorial).