Design Decisions
lox is designed to provide a seamless logging experience in JAX, adhering to its functional paradigm while minimizing boilerplate. This document outlines the key architectural choices and design philosophies behind the library.
Motivation
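Logging in JAX is notoriously difficult because JAX programs are pure functions that are traced and then transformed by jit, vmap, and scan: an ordinary Python print or list append runs once at trace time and sees abstract tracers rather than concrete values. Common approaches include: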
Explicit Return: Manually plumbing metrics through every function return. This clutters function signatures and logic.
Side-Effects (host_callback or debug.callback): Using callbacks directly. While useful for monitoring, it can be difficult to collect and process these logs in a structured way (e.g., for reduction or saving).
lox aims to provide the best of both worlds: the simplicity of “add a log statement anywhere” with the power of structured collection and real-time monitoring.
Core Philosophy: Jaxpr Manipulation
The central design decision of lox is to treat logging as a program transformation rather than a side-effecting operation.
Instead of relying on global state or immediate execution of callbacks, lox.log inserts a custom JAX primitive (lox_p) into the function’s intermediate representation (jaxpr).
By deferring the interpretation of these log points, lox can decide how to handle them based on the desired transformation:
lox.tap: Rewrites the jaxpr to insert jax.debug.callback at each lox_p site.
lox.spool: Rewrites the jaxpr to collect the logged values and return them as additional function outputs.
lox.strip: Removes all lox_p primitives from the jaxpr, effectively neutralizing logging with zero runtime overhead.
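To make the idea of deferred interpretation concrete, here is a minimal sketch: a toy identity primitive that shows up in the jaxpr, plus a small jaxpr interpreter that applies one of three policies at each log site. It assumes nothing about lox's internals: demo_log_p, demo_log, interpret, demo_tap, demo_spool, and demo_strip are hypothetical stand-ins, and the interpreter re-evaluates equations eagerly instead of emitting a rewritten jaxpr. It only visits top-level equations, so log calls nested inside scan, cond, or an inner jit are not handled.

```python
import functools

import jax
import jax.numpy as jnp

try:  # recent JAX releases expose Primitive under jax.extend
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

# Hypothetical stand-in for lox_p: an identity primitive tagged with a name.
demo_log_p = Primitive("demo_log")
demo_log_p.def_impl(lambda x, *, name: x)
demo_log_p.def_abstract_eval(lambda aval, *, name: aval)


def demo_log(name, x):
    """Mark x as a log site in the traced jaxpr and return it unchanged."""
    return demo_log_p.bind(x, name=name)


def interpret(fun, on_log):
    """Evaluate fun's jaxpr, calling on_log(name, value) at every log site.

    Only top-level equations are visited: log sites nested inside
    scan, cond, or an inner jit are not handled by this sketch.
    """
    def wrapped(*args):
        closed, out_shape = jax.make_jaxpr(fun, return_shape=True)(*args)
        jaxpr, env = closed.jaxpr, {}

        def read(v):
            # Literals carry their value in .val; variables live in env.
            return v.val if hasattr(v, "val") else env[v]

        env.update(zip(jaxpr.constvars, closed.consts))
        env.update(zip(jaxpr.invars, jax.tree_util.tree_leaves(args)))
        for eqn in jaxpr.eqns:
            invals = [read(v) for v in eqn.invars]
            if eqn.primitive is demo_log_p:
                on_log(eqn.params["name"], invals[0])
                outvals = invals  # identity: forward the logged value
            else:
                out = eqn.primitive.bind(*invals, **eqn.params)
                outvals = out if eqn.primitive.multiple_results else [out]
            env.update(zip(eqn.outvars, outvals))
        outs = [read(v) for v in jaxpr.outvars]
        return jax.tree_util.tree_unflatten(
            jax.tree_util.tree_structure(out_shape), outs)
    return wrapped


def demo_strip(fun):
    """Ignore every log site; only the computation itself runs."""
    return interpret(fun, lambda name, value: None)


def demo_tap(fun, callback):
    """Send each logged value to callback(name, value) on the host."""
    return interpret(fun, lambda name, value: jax.debug.callback(
        functools.partial(callback, name), value))


def demo_spool(fun):
    """Return (output, logs); logs keeps the last value seen for each name."""
    def wrapped(*args):
        logs = {}
        out = interpret(fun, logs.__setitem__)(*args)
        return out, logs
    return wrapped


f = lambda x: demo_log("y", jnp.sin(x) * 2.0) + 1.0
print(jax.make_jaxpr(f)(0.5))  # the printed jaxpr includes a demo_log equation
out, logs = demo_spool(f)(jnp.float32(0.5))
```

All three policies share one interpreter and differ only in on_log, so switching between monitoring, collection, and removal never touches f itself. Unlike this sketch, the spool described in the text stacks repeated logs instead of keeping only the last value.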
The lox_p Primitive
The lox primitive is designed to be as transparent as possible:
Functional Purity: By default, it returns its input values, maintaining the functional purity of the user’s code.
Effectful Abstract Evaluation: It is registered with DebugEffect, which signals to JAX that this primitive might have side effects when transformed (important for tap).
Transformation Support: It implements batching (vmap) and JVP rules, ensuring it works seamlessly with JAX’s core transformations.
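A toy identity primitive can be given the same kinds of rules through JAX's extension points: a JVP rule, a batching rule, and an MLIR lowering so that it also works under jit. This is a sketch under stated assumptions, not lox_p itself: demo_log_p and the rule functions are hypothetical, and the DebugEffect registration is omitted because JAX's effect types live in its internal modules.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad, batching, mlir

try:  # recent JAX releases expose Primitive under jax.extend
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

# Hypothetical identity primitive standing in for lox_p (no effects registered).
demo_log_p = Primitive("demo_log")
demo_log_p.def_impl(lambda x, *, name: x)
demo_log_p.def_abstract_eval(lambda aval, *, name: aval)

# jit: lower to the identity, so the compiled program simply forwards x.
mlir.register_lowering(
    demo_log_p, mlir.lower_fun(lambda x, *, name: x, multiple_results=False))


def _demo_log_jvp(primals, tangents, *, name):
    # Log the primal; the tangent of an identity is the incoming tangent.
    (x,), (t,) = primals, tangents
    return demo_log_p.bind(x, name=name), t


def _demo_log_batch(args, dims, *, name):
    # Log the whole batched array and keep its batch dimension in place.
    (x,), (d,) = args, dims
    return demo_log_p.bind(x, name=name), d


ad.primitive_jvps[demo_log_p] = _demo_log_jvp
batching.primitive_batchers[demo_log_p] = _demo_log_batch


def demo_log(name, x):
    return demo_log_p.bind(x, name=name)


f = lambda x: demo_log("y", jnp.sin(x)) * 3.0
print(jax.grad(f)(0.0))                   # gradient flows through the log site
print(jax.vmap(f)(jnp.arange(4.0)).shape)  # (4,)
print(jax.jit(f)(1.0))
```

Because the primitive is an identity, both rules pass their input through: the tangent flows unchanged, and a vmapped log site receives the whole batch at once.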
Dual API: Tap vs. Spool
lox provides two primary ways to consume logs, catering to different use cases.
lox.tap (Real-time Monitoring)
tap is designed for live inspection. It uses jax.debug.callback to send data back to the Python host as it is encountered during execution. This is ideal for:
Console logging of progress.
Live plotting/monitoring.
Debugging values inside jit or scan.
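The mechanism tap is described as inserting, jax.debug.callback, can be used by hand to see how values stream to the host from inside a jitted scan. In this sketch, on_host and running_sum are illustrative names, not part of lox; ordered=True asks JAX to run the callbacks in program order.

```python
import jax
import jax.numpy as jnp

received = []  # host-side buffer filled by the callback


def on_host(step, value):
    # Runs on the Python host each time execution reaches the callback.
    received.append((int(step), float(value)))


@jax.jit
def running_sum(xs):
    def body(carry, x):
        step, total = carry
        total = total + x
        jax.debug.callback(on_host, step, total, ordered=True)
        return (step + 1, total), None

    (_, total), _ = jax.lax.scan(body, (jnp.int32(0), jnp.float32(0.0)), xs)
    return total


total = running_sum(jnp.arange(1.0, 4.0))
jax.effects_barrier()  # wait until all pending callbacks have run
print(received)  # [(0, 1.0), (1, 3.0), (2, 6.0)]
```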
lox.spool (Structured Collection)
spool is designed for efficiency and batch processing. It transforms the function so that it returns a logdict containing all logged values.
Efficiency: Within scan, spool automatically stacks logs across iterations, returning them as a single array.
Post-processing: Collected logs can be sliced, reduced (e.g., jnp.mean), or saved after the function executes.
Static Shape Requirement: Since spool modifies the output signature of a JAX function, it requires the number of log events to be statically known. This is why spool does not support while_loop with dynamic termination: the iteration count is only known at run time, so the logs cannot be given a fixed output shape.
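Within scan, the stacking that spool automates can be reproduced manually: return the per-step values as scan's second output and JAX stacks them along a new leading axis. The sketch below only illustrates that shape behavior; train and the metric names are hypothetical, and lox's logdict format is not modeled.

```python
import jax
import jax.numpy as jnp


def train(xs):
    """Toy loop that returns its per-step metrics as extra scan outputs."""
    def body(w, x):
        w = w - 0.1 * (w - x)  # toy update moving w toward x
        metrics = {"loss": (w - x) ** 2, "w": w}
        return w, metrics      # scan stacks metrics along axis 0

    w_final, logs = jax.lax.scan(body, jnp.float32(0.0), xs)
    return w_final, logs


w_final, logs = jax.jit(train)(jnp.ones(5))
print(logs["loss"].shape)      # (5,): one entry per iteration
print(jnp.mean(logs["loss"]))  # post-hoc reduction
print(logs["w"][-2:])          # slice of the last two steps
```

This works because scan's trip count equals xs.shape[0], which is fixed at trace time. lax.while_loop, by contrast, returns only its final carry, so there is no fixed-size output into which per-iteration values could be stacked.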
Data Structure: logdict
Logs in lox are not just raw values; they are encapsulated in a logdict.
Pytree Compatibility: logdict and its helper stepdict are registered JAX pytrees, allowing them to be passed in and out of JAX-transformed functions.
Metadata Association: Every log value can be associated with multiple “steps” (e.g., training step, epoch, episode). These steps are tracked alongside the data and automatically handled during transformations like vmap or scan.
Rich Interface: logdict provides a dictionary-like interface but adds methods for slice, reduce, and merging (+), facilitating easy post-processing.
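Pytree registration and merge, slice, and reduce behavior can be sketched with a small container class. MiniLogDict is hypothetical and deliberately simpler than lox's logdict: it keeps a single step array per name (instead of several step counters in a stepdict) and implements + as concatenation of entries that share a name.

```python
import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class MiniLogDict:
    """Hypothetical, simplified log container: name -> values, with matching steps."""

    def __init__(self, values, steps):
        self.values = dict(values)  # name -> array of logged values
        self.steps = dict(steps)    # name -> array of step indices, same length

    def tree_flatten(self):
        keys = tuple(sorted(self.values))
        children = (tuple(self.values[k] for k in keys),
                    tuple(self.steps[k] for k in keys))
        return children, keys  # the key tuple is static metadata

    @classmethod
    def tree_unflatten(cls, keys, children):
        values, steps = children
        return cls(dict(zip(keys, values)), dict(zip(keys, steps)))

    def __add__(self, other):
        """Merge two logs, concatenating entries that share a name."""
        values, steps = dict(self.values), dict(self.steps)
        for k, v in other.values.items():
            if k in values:
                values[k] = jnp.concatenate([values[k], v])
                steps[k] = jnp.concatenate([steps[k], other.steps[k]])
            else:
                values[k], steps[k] = v, other.steps[k]
        return MiniLogDict(values, steps)

    def slice(self, start, stop):
        return MiniLogDict({k: v[start:stop] for k, v in self.values.items()},
                           {k: s[start:stop] for k, s in self.steps.items()})

    def reduce(self, fn=jnp.mean):
        return {k: fn(v) for k, v in self.values.items()}


a = MiniLogDict({"loss": jnp.array([4.0, 2.0])}, {"loss": jnp.array([0, 1])})
b = MiniLogDict({"loss": jnp.array([0.0])}, {"loss": jnp.array([2])})
merged = jax.jit(lambda d: d)(a + b)  # crosses the jit boundary as a pytree
print(merged.steps["loss"])           # [0 1 2]
print(merged.reduce()["loss"])        # 2.0
```

Registering the class as a pytree is what lets it cross the jit boundary: its arrays become traced leaves while the sorted key tuple is static metadata, so a change in the set of names triggers retracing.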
Logger Abstraction
lox separates the instrumentation (lox.log) from the output backend.
The Logger base class defines a standard interface for backend integrations (e.g., ConsoleLogger, WandBLogger, SaveLogger).
Loggers can be easily swapped or combined using MultiLogger without changing the underlying function logic.
State Management: Loggers use a LoggerState pattern, consistent with JAX’s state-handling conventions (e.g., for init and log calls).
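The separation between instrumentation and backend can be illustrated with a small pure-Python interface in which each logger returns a new state from init and log instead of mutating itself. Every class here (BaseLogger, PrintLogger, RecordLogger, FanOutLogger) and the init/log signatures are hypothetical; they only mimic the roles the text assigns to Logger, ConsoleLogger, MultiLogger, and LoggerState.

```python
from typing import NamedTuple


class CountState(NamedTuple):
    n_calls: int


class RecordState(NamedTuple):
    records: tuple  # ((step, logs), ...)


class BaseLogger:
    """Hypothetical backend interface: init() -> state, log(state, logs, step) -> state."""

    def init(self):
        raise NotImplementedError

    def log(self, state, logs, step):
        raise NotImplementedError


class PrintLogger(BaseLogger):
    def init(self):
        return CountState(0)

    def log(self, state, logs, step):
        items = ", ".join(f"{k}={float(v):.4f}" for k, v in sorted(logs.items()))
        print(f"[step {step}] {items}")
        return CountState(state.n_calls + 1)


class RecordLogger(BaseLogger):
    def init(self):
        return RecordState(())

    def log(self, state, logs, step):
        # Return a new state instead of mutating the logger object.
        return RecordState(state.records + ((step, dict(logs)),))


class FanOutLogger(BaseLogger):
    """Forward every call to several loggers; its state is a tuple of their states."""

    def __init__(self, loggers):
        self.loggers = tuple(loggers)

    def init(self):
        return tuple(lg.init() for lg in self.loggers)

    def log(self, state, logs, step):
        return tuple(lg.log(s, logs, step) for lg, s in zip(self.loggers, state))


logger = FanOutLogger([PrintLogger(), RecordLogger()])
state = logger.init()
state = logger.log(state, {"loss": 0.25}, step=1)  # prints "[step 1] loss=0.2500"
```

Swapping PrintLogger for another backend changes only the list passed to FanOutLogger; the code that produces the logs is untouched.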
Support for JAX Transformations
A non-negotiable goal for lox is full support for JAX’s core transformations.
vmap: lox handles batching by automatically expanding the dimensions of logged data and steps.
scan/fori_loop: lox handles the sequential nature of these loops, either by streaming (in tap) or stacking (in spool).
jit: lox transformations work at the jaxpr level before compilation, ensuring that logging logic is baked into the compiled artifact.
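The combined effect of vmap and scan on logged data can be illustrated without lox: a hand-written rollout returns its per-step values as scan outputs and is then vmapped over a batch of initial states. This only mirrors the described behavior; rollout and the dict keys are hypothetical, and lox's own logdict layout may differ.

```python
import jax
import jax.numpy as jnp


def rollout(x0, n_steps=3):
    """Toy rollout returning its per-step value and step index as outputs."""
    def body(x, t):
        x = 0.5 * x + 1.0
        return x, {"x": x, "step": t}

    _, logs = jax.lax.scan(body, x0, jnp.arange(n_steps))
    return logs


batched = jax.jit(jax.vmap(rollout))(jnp.array([0.0, 2.0, 4.0, 6.0]))
print(batched["x"].shape)  # (4, 3): batch axis from vmap, step axis from scan
print(batched["step"][0])  # [0 1 2]
```

The step indices do not depend on the batched input, yet vmap still broadcasts them to shape (4, 3), so every logged value keeps a matching step entry.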