JIT Backend Interface¶
Overview¶
JIT backends execute a recorded JITGraph. XAD exposes:
JITBackend<Scalar>: abstract execution interfaceJITGraphInterpreter<Scalar>: reference backend that interprets the graph
Compile-time feature flag
This API is only available when XAD is compiled with XAD_ENABLE_JIT.
JITBackend¶
template <class Scalar> class JITBackend
Abstract base class for JIT execution backends. The Scalar template parameter specifies the floating-point type used for computations and must match the type used by JITCompiler<Real, N>:
JITBackend<double>for use withJITCompiler<double, N>JITBackend<float>for use withJITCompiler<float, N>
Most applications use double precision. The float option is available for performance-sensitive applications where single precision is sufficient.
Backends are responsible for:
- compiling a
JITGraphinto an executable form (optional) - executing forward evaluation
- optionally executing forward+backward to produce input adjoints
Main virtual functions¶
compile¶
virtual void compile(const JITGraph& graph) = 0;
Prepare the backend for executing the given graph.
vectorWidth¶
virtual std::size_t vectorWidth() const = 0;
Returns the number of parallel evaluations per execution (1 for scalar backends, 4 for AVX2).
numInputs / numOutputs¶
virtual std::size_t numInputs() const = 0; virtual std::size_t numOutputs() const = 0;
Return the number of inputs/outputs in the compiled graph.
setInput¶
virtual void setInput(std::size_t inputIndex, const Scalar* values) = 0;
Set input values for an input variable. The values array must contain vectorWidth() elements.
forward¶
virtual void forward(Scalar* outputs) = 0;
Run a forward pass. The outputs array must have space for numOutputs() * vectorWidth() elements.
forwardAndBackward¶
virtual void forwardAndBackward(Scalar* outputs, Scalar* inputGradients) = 0;
Run forward and backward passes combined. The outputs array must have space for numOutputs() * vectorWidth() elements, and inputGradients must have space for numInputs() * vectorWidth() elements.
reset¶
virtual void reset() = 0;
Reset any cached backend state.
JITGraphInterpreter¶
template <class Scalar> class JITGraphInterpreter : public JITBackend<Scalar>
Reference backend that interprets the graph directly. It is mainly intended as:
- a correctness reference implementation
- a simple fallback backend
- a baseline for testing and debugging
Example Usage¶
For double backend:
auto backend = std::make_unique<xad::JITGraphInterpreter<double>>();
xad::JITCompiler<double, 1> jit(std::move(backend));
// ... record graph ...
jit.compile();
For float backend:
auto backend = std::make_unique<xad::JITGraphInterpreter<float>>();
xad::JITCompiler<float, 1> jit(std::move(backend));
// ... record graph ...
jit.compile();