JIT Mode
Overview
XAD provides an optional JIT (Just-In-Time) mode that records computations into a graph which can be compiled and executed multiple times with different inputs. This is particularly useful when:
- The same computation needs to be evaluated repeatedly with varying inputs
- You want to avoid the overhead of re-recording a tape for each evaluation
- Performance is critical for repeated derivative calculations
Compile-time feature flag
JIT mode is only available when XAD is compiled with XAD_ENABLE_JIT. This is a compile-time configuration option.
Tape vs. JIT: Key Differences
Understanding the difference between tape-based adjoint mode and JIT mode is essential:
| Aspect | Tape | JIT |
|---|---|---|
| Recording | Re-records every run | Records once, replays many times |
| Control flow | Evaluated per run | Baked in at record time (unless using ABool::If) |
| Memory | Tape grows with computation | Fixed graph size after recording |
| Use case | Single or varying computations | Repeated evaluations with same structure |
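The "records once, replays many times" and "fixed graph size" rows can be made concrete with a toy model in plain C++. This is not XAD's implementation, and the names Graph, Op, Node and recordExample are hypothetical. Recording appends one node per operation; forward then replays the same node list with new input values, so no matter how often the graph is evaluated, it never grows.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy "record once, replay many" graph (hypothetical; not XAD internals).
enum class Op { Input, Add, Mul };

struct Node {
    Op op;
    std::size_t a;  // Input: input index; Add/Mul: first operand node
    std::size_t b;  // Add/Mul: second operand node (unused for Input)
};

struct Graph {
    std::vector<Node> nodes;
    std::size_t numInputs = 0;

    // Recording: each call appends exactly one node and returns its index.
    std::size_t input() {
        nodes.push_back({Op::Input, numInputs++, 0});
        return nodes.size() - 1;
    }
    std::size_t add(std::size_t a, std::size_t b) {
        nodes.push_back({Op::Add, a, b});
        return nodes.size() - 1;
    }
    std::size_t mul(std::size_t a, std::size_t b) {
        nodes.push_back({Op::Mul, a, b});
        return nodes.size() - 1;
    }

    // Replay: walk the fixed node list with new inputs; nothing is re-recorded.
    double forward(const std::vector<double>& inputs, std::size_t out) const {
        assert(inputs.size() == numInputs);
        std::vector<double> v(nodes.size(), 0.0);
        for (std::size_t i = 0; i < nodes.size(); ++i) {
            const Node& n = nodes[i];
            switch (n.op) {
                case Op::Input: v[i] = inputs[n.a]; break;
                case Op::Add:   v[i] = v[n.a] + v[n.b]; break;
                case Op::Mul:   v[i] = v[n.a] * v[n.b]; break;
            }
        }
        return v[out];
    }
};

// Record z = x * y + x once; 'out' receives the index of the output node.
inline Graph recordExample(std::size_t& out) {
    Graph g;
    std::size_t x = g.input();
    std::size_t y = g.input();
    out = g.add(g.mul(x, y), x);
    return g;
}
```

In this model, replaying with forward({2.0, 3.0}, out) gives 8 and forward({1.0, 5.0}, out) gives 6, while nodes.size() stays at 4 (two inputs, one multiply, one add).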
Basic JIT Usage
Here is a simple example showing how to use JIT mode:
```cpp
#include <XAD/XAD.hpp>
#include <iostream>

int main()
{
    using AD = xad::AReal<double, 1>;
    xad::JITCompiler<double, 1> jit;

    // Create and register inputs
    AD x = 2.0;
    AD y = 3.0;
    jit.registerInput(x);
    jit.registerInput(y);

    // Record the computation
    AD z = x * y + x;
    jit.registerOutput(z);

    // Compile the recorded graph
    jit.compile();

    // Execute forward pass
    double out = 0.0;
    jit.forward(&out, 1);
    // out now holds z = 2 * 3 + 2 = 8

    // Seed the output adjoint and compute adjoints
    jit.setDerivative(z.getSlot(), 1.0);
    jit.computeAdjoints();
    double dz_dx = jit.getDerivative(x.getSlot());  // y + 1 = 4
    double dz_dy = jit.getDerivative(y.getSlot());  // x = 2

    std::cout << "z=" << out << " dz/dx=" << dz_dx << " dz/dy=" << dz_dy << "\n";
}
```
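With x = 2 and y = 3, the expected results are z = 2 * 3 + 2 = 8, dz/dx = y + 1 = 4 and dz/dy = x = 2. The plain C++ sketch below involves no XAD at all (f, dfdx and dfdy are illustrative names) and reproduces those numbers with central finite differences, a cheap independent cross-check for adjoints produced by any AD tool.

```cpp
#include <cassert>
#include <cmath>

// Passive re-implementation of the recorded function z = x * y + x.
inline double f(double x, double y) { return x * y + x; }

// Central differences: (f(p + h) - f(p - h)) / (2h), truncation error O(h^2).
inline double dfdx(double x, double y, double h = 1e-6) {
    assert(h > 0.0);
    return (f(x + h, y) - f(x - h, y)) / (2.0 * h);
}

inline double dfdy(double x, double y, double h = 1e-6) {
    assert(h > 0.0);
    return (f(x, y + h) - f(x, y - h)) / (2.0 * h);
}
```

Because finite differences carry rounding error, compare against the JIT adjoints with a small tolerance rather than exact equality.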
Handling Branching: The Control Flow Problem
A critical difference between tape and JIT modes is how they handle control flow.
Consider a piecewise function:
```cpp
template <class AD>
AD piecewise(const AD& x)
{
    if (xad::value(x) < 2.0)
        return 1.0 * x;
    return 7.0 * x;
}
```
With Tape Mode
Tape mode re-records the computation for each run. The if statement is evaluated fresh each time based on the current input value:
```cpp
#include <XAD/XAD.hpp>
#include <iostream>

// piecewise() as defined above

int main()
{
    using mode = xad::adj<double>;
    using tape_type = mode::tape_type;
    using AD = mode::active_type;

    auto run = [](double x0) {
        tape_type tape;              // fresh tape for every run
        AD x = x0;
        tape.registerInput(x);
        tape.newRecording();
        AD y = piecewise(x);         // if is evaluated NOW with current x
        tape.registerOutput(y);
        xad::derivative(y) = 1.0;
        tape.computeAdjoints();
        std::cout << "x=" << x0 << " y=" << xad::value(y)
                  << " dy/dx=" << xad::derivative(x) << "\n";
    };

    run(1.0);  // x < 2, so y = 1*x = 1, dy/dx = 1
    run(3.0);  // x >= 2, so y = 7*x = 21, dy/dx = 7
}
```
Output:
```
x=1 y=1 dy/dx=1
x=3 y=21 dy/dx=7
```
With JIT Mode (Plain if - Incorrect)
JIT mode records the computation once. A plain C++ if is evaluated at record time and the chosen branch is baked into the graph:
```cpp
#include <XAD/XAD.hpp>
#include <iostream>

// piecewise() as defined above

int main()
{
    using AD = xad::AReal<double, 1>;
    xad::JITCompiler<double, 1> jit;

    AD x = 1.0;               // Record with x=1
    jit.registerInput(x);
    AD y = piecewise(x);      // if evaluated NOW: x < 2 is true, records "return 1*x"
    jit.registerOutput(y);
    jit.compile();

    // First run: x=1
    double out = 0.0;
    jit.forward(&out, 1);
    jit.setDerivative(y.getSlot(), 1.0);
    jit.computeAdjoints();
    std::cout << "x=1 y=" << out << " dy/dx=" << jit.getDerivative(x.getSlot()) << "\n";

    // Second run: x=3 (but graph still has the x < 2 branch!)
    x = 3.0;
    jit.clearDerivatives();
    jit.forward(&out, 1);
    jit.setDerivative(y.getSlot(), 1.0);
    jit.computeAdjoints();
    std::cout << "x=3 y=" << out << " dy/dx=" << jit.getDerivative(x.getSlot()) << "\n";
}
```
Output (wrong for x=3):
```
x=1 y=1 dy/dx=1
x=3 y=3 dy/dx=1
```
The graph was recorded with the x < 2 branch, so even when x=3 it still computes 1*x = 3 instead of 7*x = 21, and reports dy/dx = 1 instead of 7.
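The same failure can be reproduced without XAD in a toy model. This is not how XAD stores its graph, and Recorded, recordPlainIf and demoBakedIn are hypothetical names. Recording runs the function once, so the plain if is resolved with the record-time value and only the chosen multiplier ends up in the "graph"; every replay reuses it.

```cpp
#include <cassert>

// Toy stand-in for a recorded graph of piecewise(): all it keeps is the
// multiplier of the branch that was taken while recording.
struct Recorded {
    double coeff;
    double value(double x) const { return coeff * x; }
    double derivative(double /*x*/) const { return coeff; }
};

// Mimics recording piecewise() with a plain if: the bool is computed once,
// here, from the record-time input x0, and the other branch is never seen.
inline Recorded recordPlainIf(double x0) {
    if (x0 < 2.0)
        return Recorded{1.0};
    return Recorded{7.0};
}

// Replays the sequence from the JIT example: record at x = 1, replay at x = 3.
inline void demoBakedIn() {
    Recorded g = recordPlainIf(1.0);
    assert(g.value(1.0) == 1.0 && g.derivative(1.0) == 1.0);  // correct
    assert(g.value(3.0) == 3.0);       // 1*x, not 7*x = 21
    assert(g.derivative(3.0) == 1.0);  // not 7
}
```

The effect is symmetric in this model: recording at x = 3 bakes in the 7*x branch, so replaying at x = 1 would return 7 instead of 1.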
The Solution: ABool::If
To handle branching correctly in JIT mode, use ABool::If, which records a conditional node that evaluates the condition at runtime:
```cpp
template <class AD>
AD piecewise_abool(const AD& x)
{
    // Create a trackable boolean from the comparison
    auto cond = xad::less(x, 2.0);
    // Compute BOTH branches
    AD true_branch = 1.0 * x;
    AD false_branch = 7.0 * x;
    // Record a conditional selection node
    return cond.If(true_branch, false_branch);
}
```
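The toy model below illustrates why a recorded conditional node survives replay. It is plain C++ and not XAD's implementation; replayPiecewise and SelectResult are hypothetical names. On each run the node evaluates the condition with the current input, computes both branches, forwards the selected value, and in the reverse pass sends the output adjoint only into the selected branch.

```cpp
#include <cassert>

struct SelectResult {
    double value;  // y
    double dydx;   // adjoint of x after seeding ybar = 1
};

// One forward + reverse replay of the graph recorded by piecewise_abool().
inline SelectResult replayPiecewise(double x) {
    // Forward pass: every node is evaluated with the current input.
    bool cond = x < 2.0;      // condition node, re-evaluated on each replay
    double t = 1.0 * x;       // true-branch node
    double f = 7.0 * x;       // false-branch node
    double y = cond ? t : f;  // conditional selection node

    // Reverse pass: the selection node routes ybar to one branch only.
    double ybar = 1.0;
    double tbar = cond ? ybar : 0.0;
    double fbar = cond ? 0.0 : ybar;
    assert(tbar + fbar == ybar);            // exactly one branch gets the adjoint
    double xbar = 1.0 * tbar + 7.0 * fbar;  // d(1*x)/dx = 1, d(7*x)/dx = 7
    return {y, xbar};
}
```

In this model replayPiecewise(1.0) returns {1, 1} and replayPiecewise(3.0) returns {21, 7}, matching the tape results. Both branches are computed on every replay, mirroring the "Compute BOTH branches" step in piecewise_abool.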
Now JIT mode works correctly:
```cpp
#include <XAD/XAD.hpp>
#include <iostream>

// piecewise_abool() as defined above

int main()
{
    using AD = xad::AReal<double, 1>;
    xad::JITCompiler<double, 1> jit;

    AD x = 1.0;
    jit.registerInput(x);
    AD y = piecewise_abool(x);  // Records conditional node with both branches
    jit.registerOutput(y);
    jit.compile();

    // Run with x=1
    double out = 0.0;
    x = 1.0;
    jit.clearDerivatives();
    jit.forward(&out, 1);
    jit.setDerivative(y.getSlot(), 1.0);
    jit.computeAdjoints();
    std::cout << "x=1 y=" << out << " dy/dx=" << jit.getDerivative(x.getSlot()) << "\n";

    // Run with x=3: the condition is re-evaluated inside the compiled graph
    x = 3.0;
    jit.clearDerivatives();
    jit.forward(&out, 1);
    jit.setDerivative(y.getSlot(), 1.0);
    jit.computeAdjoints();
    std::cout << "x=3 y=" << out << " dy/dx=" << jit.getDerivative(x.getSlot()) << "\n";
}
```
Output (correct):
```
x=1 y=1 dy/dx=1
x=3 y=21 dy/dx=7
```
Comparison Summary
The following table summarizes the behavior across all combinations:
| Scenario | x | y | dy/dx | Note |
|---|---|---|---|---|
| Tape with plain if | 1 | 1 | 1 | Correct |
| Tape with plain if | 3 | 21 | 7 | Correct |
| JIT with plain if (record) | 1 | 1 | 1 | Correct |
| JIT with plain if (replay) | 3 | 3 | 1 | Wrong - branch baked in |
| Tape with ABool::If | 1 | 1 | 1 | Correct |
| Tape with ABool::If | 3 | 21 | 7 | Correct - ABool works with tape too |
| JIT with ABool::If (record) | 1 | 1 | 1 | Correct |
| JIT with ABool::If (replay) | 3 | 21 | 7 | Correct - conditional node |
Best Practice
When writing code that may be used with JIT mode, use ABool::If for all conditional logic that depends on active variables. This ensures correct behavior in both tape and JIT modes.
Available Comparison Functions
XAD provides comparison functions that return ABool for use with JIT:
- xad::less(a, b) - returns ABool for a < b
- xad::less_equal(a, b) - returns ABool for a <= b
- xad::greater(a, b) - returns ABool for a > b
- xad::greater_equal(a, b) - returns ABool for a >= b
- xad::equal(a, b) - returns ABool for a == b
- xad::not_equal(a, b) - returns ABool for a != b
Reference
For detailed API documentation, see:
See also
This tutorial is based on the included XAD sample (jit_tutorial).