Design Decisions

lox is designed to make logging in JAX seamless: it follows JAX’s functional paradigm while keeping boilerplate to a minimum. This document outlines the key architectural choices and the design philosophy behind the library.

Motivation

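Logging in JAX is notoriously difficult. JAX programs are pure functions, and transformations and control-flow primitives such as jit, vmap, and scan trace those functions into an abstract representation before running them. An ordinary print or list append inside the function therefore sees tracers rather than concrete values. Common workarounds include: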

  1. Explicit Return: Manually plumbing metrics through every function return. This clutters function signatures and logic.

  2. Side Effects (host_callback or debug.callback): Calling host callbacks directly. This works well for monitoring, but the values end up in ad-hoc host-side handlers. That makes them difficult to collect and process in a structured way (e.g., for reduction or saving).

lox aims to offer the best of both worlds: the simplicity of “add a log statement anywhere” combined with the power of structured collection and real-time monitoring.

Core Philosophy: Jaxpr Manipulation

The central design decision of lox is to treat logging as a program transformation rather than a side-effecting operation. Instead of relying on global state or immediate execution of callbacks, lox.log inserts a custom JAX primitive (lox_p) into the function’s intermediate representation (jaxpr).

By deferring the interpretation of these log points, lox can decide how to handle them based on the desired transformation:

  • lox.tap: Rewrites the jaxpr to insert jax.debug.callback at each lox_p site.

  • lox.spool: Rewrites the jaxpr to collect the logged values and return them as additional function outputs.

  • lox.strip: Removes all lox_p primitives from the jaxpr, effectively neutralizing logging with zero runtime overhead.
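One way to picture this design is a toy version in plain JAX. The sketch below is not lox’s implementation. The names toy_log_p, toy_log, toy_strip, toy_tap, and toy_spool are hypothetical, and the sketch assumes a JAX version that exposes Primitive and Literal under jax.extend.core or jax.core. It traces the user function with jax.make_jaxpr and then re-evaluates the jaxpr equation by equation, handing each log site to a mode-specific handler. For brevity it only finds log sites at the top level of the jaxpr. A complete version would also need to recurse into the sub-jaxprs of higher-order primitives such as scan.

```python
import functools

import jax
import jax.numpy as jnp

try:  # newer JAX releases expose these under jax.extend.core
    from jax.extend.core import Literal, Primitive
except ImportError:  # older releases keep them in jax.core
    from jax.core import Literal, Primitive

# Hypothetical stand-in for lox_p: an identity primitive that marks a log site.
toy_log_p = Primitive("toy_log")
toy_log_p.multiple_results = True
toy_log_p.def_impl(lambda *xs, names: xs)
toy_log_p.def_abstract_eval(lambda *avals, names: avals)


def toy_log(**values):
    """Hypothetical log call: marks named values as a log site, returns them unchanged."""
    names = tuple(values)
    outs = toy_log_p.bind(*values.values(), names=names)
    return dict(zip(names, outs))


def _interpret(f, args, on_log):
    """Trace f to a jaxpr, then re-evaluate it, routing toy_log sites to on_log."""
    closed, out_shape = jax.make_jaxpr(f, return_shape=True)(*args)
    jaxpr, env = closed.jaxpr, {}

    def read(v):
        return v.val if isinstance(v, Literal) else env[v]

    env.update(zip(jaxpr.constvars, closed.consts))
    env.update(zip(jaxpr.invars, jax.tree_util.tree_leaves(args)))
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive is toy_log_p:
            outvals = on_log(eqn.params["names"], invals)
        else:  # re-bind every other primitive unchanged (as jax.core.eval_jaxpr does)
            subfuns, params = eqn.primitive.get_bind_params(eqn.params)
            outvals = eqn.primitive.bind(*subfuns, *invals, **params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
        env.update(zip(eqn.outvars, outvals))
    outs = [read(v) for v in jaxpr.outvars]
    return jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(out_shape), outs)


def toy_strip(f):
    # Log sites just forward their inputs, so nothing is recorded.
    return lambda *args: _interpret(f, args, lambda names, vals: vals)


def toy_tap(f, sink):
    # Each logged value is sent to sink(name, value) through jax.debug.callback.
    def on_log(names, vals):
        for name, val in zip(names, vals):
            jax.debug.callback(functools.partial(sink, name), val)
        return vals
    return lambda *args: _interpret(f, args, on_log)


def toy_spool(f):
    # Logged values are collected and returned next to the normal output.
    def wrapped(*args):
        logs = {}
        def on_log(names, vals):
            logs.update(zip(names, vals))
            return vals
        return _interpret(f, args, on_log), logs
    return wrapped


def f(x):
    y = toy_log(y=x * 2.0)["y"]  # the log site shows up as a toy_log equation
    return y + 1.0
```

Because the marker primitive is the identity, the stripped function performs exactly the arithmetic of the original, and the wrapped functions can themselves be passed to jax.jit.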

The lox_p Primitive

The lox primitive is designed to be as transparent as possible:

  • Functional Purity: By default, it returns its input values, maintaining the functional purity of the user’s code.

  • Effectful Abstract Evaluation: It is registered with DebugEffect, which signals to JAX that this primitive may have side effects once transformed (important for tap). JAX does not discard effectful operations as dead code, even when their outputs go unused.

  • Transformation Support: It implements batching (vmap) and JVP rules, ensuring it works seamlessly with JAX’s core transformations.
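Batching and JVP support can be sketched for any identity-like primitive. The rules below are illustrative assumptions, not lox’s source. They register a hypothetical toy_log_p (a stand-in for lox_p) with JAX’s rule tables batching.primitive_batchers and ad.primitive_jvps, and they omit the DebugEffect registration for brevity. Because the primitive returns its inputs unchanged, the batching rule can re-bind it on the batched arrays and leave every batch dimension where it was. The JVP rule logs only the primal values and passes the tangents straight through.

```python
import jax
import jax.numpy as jnp
from jax.interpreters import ad, batching

try:  # newer JAX releases
    from jax.extend.core import Primitive
except ImportError:  # older releases
    from jax.core import Primitive

# Hypothetical identity primitive standing in for lox_p.
toy_log_p = Primitive("toy_log")
toy_log_p.multiple_results = True
toy_log_p.def_impl(lambda *xs, names: xs)
toy_log_p.def_abstract_eval(lambda *avals, names: avals)


def toy_log(**values):
    """Hypothetical log call returning its inputs unchanged."""
    names = tuple(values)
    return dict(zip(names, toy_log_p.bind(*values.values(), names=names)))


def _toy_log_batch(batched_args, batch_dims, *, names):
    # Identity under vmap: re-bind on the batched arrays; each output keeps
    # the batch dimension (or lack of one) of the matching input.
    return toy_log_p.bind(*batched_args, names=names), batch_dims


def _toy_log_jvp(primals, tangents, *, names):
    # Only primal values are logged; the derivative of the identity is the
    # identity, so tangents (including symbolic zeros) pass through as-is.
    return toy_log_p.bind(*primals, names=names), list(tangents)


batching.primitive_batchers[toy_log_p] = _toy_log_batch
ad.primitive_jvps[toy_log_p] = _toy_log_jvp


def g(x):
    y = toy_log(y=jnp.sin(x))["y"]
    return y * x
```

Because the tangents never flow through the primitive, the linearized computation does not contain it, so this sketch needs no transpose rule for jax.grad.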

Dual API: Tap vs. Spool

lox provides two primary ways to consume logs, catering to different use cases.

lox.tap (Real-time Monitoring)

tap is designed for live inspection. It uses jax.debug.callback to send data back to the Python host as it is encountered during execution. This is ideal for:

  • Console logging of progress.

  • Live plotting/monitoring.

  • Debugging values inside jit or scan.
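For comparison, here is the kind of host callback that tap wires in for you, written by hand with standard JAX only. It does not use lox, and record and halve_three_times are illustrative names. A jax.debug.callback inside a jitted scan fires once per iteration with concrete values, even though the loop body is traced only once.

```python
import jax
import jax.numpy as jnp

seen = []  # host-side list filled in by the callback


def record(step, value):
    # Runs on the host with concrete (non-traced) values.
    seen.append((int(step), float(value)))


@jax.jit
def halve_three_times(x):
    def body(carry, step):
        carry = carry / 2.0
        # Staged once at trace time, executed once per iteration at run time.
        jax.debug.callback(record, step, carry)
        return carry, None

    final, _ = jax.lax.scan(body, x, jnp.arange(3))
    return final


result = halve_three_times(jnp.float32(8.0))
jax.effects_barrier()  # wait until all pending callbacks have run
```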

lox.spool (Structured Collection)

spool is designed for efficiency and batch processing. It transforms the function so that it returns a logdict containing all logged values alongside its regular outputs.

  • Efficiency: Within scan, spool automatically stacks the logs across iterations and returns them as a single array.

  • Post-processing: Collected logs can be sliced, reduced (e.g., jnp.mean), or saved after the function execution.

  • Static Shape Requirement: spool adds outputs to the transformed function, and JAX requires output shapes to be fixed at trace time, so the number of log events must be known statically. A while_loop whose termination depends on runtime values cannot guarantee this, which is why spool does not support it.
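Stacking across iterations is the same mechanism scan already uses for its per-iteration outputs. The hand-written example below is plain JAX without lox, and fit and the quadratic loss are illustrative. Every iteration of a jitted scan returns a loss and a step counter, which produces arrays with a leading axis of length 5 that can be reduced after the run.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fit(w):
    def step(w, i):
        loss = (w - 3.0) ** 2
        w = w - 0.1 * 2.0 * (w - 3.0)  # one gradient-descent step on loss
        # Per-iteration logs become scan outputs, which scan stacks for us.
        return w, {"loss": loss, "step": i}

    return jax.lax.scan(step, w, jnp.arange(5))


w, logs = fit(jnp.float32(0.0))
mean_loss = jnp.mean(logs["loss"])  # post-processing after the run
```

This also shows where the static-shape requirement comes from. The stacked outputs of scan are sized by its fixed trip count, whereas lax.while_loop returns only its final carry. A while_loop therefore has no output whose length could follow a runtime stopping condition.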

Data Structure: logdict

Logs in lox are not just raw values; they are encapsulated in a logdict.

  • Pytree Compatibility: logdict and its helper stepdict are registered JAX pytrees, allowing them to be passed in and out of JAX-transformed functions.

  • Metadata Association: Every log value can be associated with multiple “steps” (e.g., training step, epoch, episode). These steps are tracked alongside the data and automatically handled during transformations like vmap or scan.

  • Rich Interface: logdict provides a dictionary-like interface and adds slice and reduce methods plus merging with +, which makes post-processing straightforward.
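A logdict-like container only needs the pytree protocol to cross jit and vmap boundaries. The class below is a hypothetical sketch, not lox’s logdict or stepdict. ToyLogDict stores named values whose leading axis indexes log events, plus step counters aligned with that axis. Its reduce and + semantics (reduce over the event axis, concatenate on merge) are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class ToyLogDict:
    """Hypothetical logdict-like container (not lox's class).

    values: name -> array whose leading axis indexes log events
    steps:  step name (e.g. "step", "epoch") -> array with the same leading axis
    """

    def __init__(self, values, steps):
        self.values = dict(values)
        self.steps = dict(steps)

    # Pytree protocol: arrays are leaves, key names are static aux data.
    def tree_flatten(self):
        vkeys, skeys = tuple(sorted(self.values)), tuple(sorted(self.steps))
        leaves = [self.values[k] for k in vkeys] + [self.steps[k] for k in skeys]
        return leaves, (vkeys, skeys)

    @classmethod
    def tree_unflatten(cls, aux, leaves):
        vkeys, skeys = aux
        n = len(vkeys)
        return cls(zip(vkeys, leaves[:n]), zip(skeys, leaves[n:]))

    def __getitem__(self, name):
        return self.values[name]

    def reduce(self, fn):
        # Collapse the event axis of every value; steps keep their last entry.
        return ToyLogDict({k: fn(v, axis=0) for k, v in self.values.items()},
                          {k: s[-1] for k, s in self.steps.items()})

    def __add__(self, other):
        # Merge by appending other's events after ours (assumes identical keys).
        cat = lambda a, b: jnp.concatenate([a, b], axis=0)
        return ToyLogDict({k: cat(v, other.values[k]) for k, v in self.values.items()},
                          {k: cat(s, other.steps[k]) for k, s in self.steps.items()})


a = ToyLogDict({"loss": jnp.array([4.0, 2.0])}, {"step": jnp.array([0, 1])})
b = ToyLogDict({"loss": jnp.array([1.0])}, {"step": jnp.array([2])})
merged = a + b
same = jax.jit(lambda ld: ld)(merged)  # passes through jit as a pytree
```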

Logger Abstraction

lox separates the instrumentation (lox.log) from the output backend.

  • The Logger base class defines a standard interface for backend integrations (e.g., ConsoleLogger, WandBLogger, SaveLogger).

  • Loggers can be easily swapped or combined using MultiLogger without changing the underlying function logic.

  • State Management: Loggers follow a LoggerState pattern consistent with JAX’s explicit state-handling conventions: the state is created by init and passed explicitly through subsequent log calls.
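The separation between instrumentation and backends can be sketched with a small state-passing interface. Everything here is hypothetical: ToyLogger, ToyConsoleLogger, ToyHistoryLogger, and ToyMultiLogger are not lox’s classes. The state handling is one reading of the LoggerState pattern, in the spirit of optax-style init/update functions. init creates a state, and each log call takes the previous state and returns a new one.

```python
from typing import Any, Mapping


class ToyLogger:
    """Hypothetical backend interface: explicit state in, new state out."""

    def init(self) -> Any:
        raise NotImplementedError

    def log(self, state: Any, logs: Mapping[str, float]) -> Any:
        raise NotImplementedError


class ToyConsoleLogger(ToyLogger):
    def init(self):
        return 0  # state: number of log calls so far

    def log(self, state, logs):
        fields = ", ".join(f"{k}={float(v):.4g}" for k, v in logs.items())
        print(f"[{state}] {fields}")
        return state + 1


class ToyHistoryLogger(ToyLogger):
    """Keeps every call in its state; a stand-in for a save-to-disk backend."""

    def init(self):
        return ()

    def log(self, state, logs):
        return state + (dict(logs),)


class ToyMultiLogger(ToyLogger):
    """Fans each call out to several loggers; its state is a tuple of their states."""

    def __init__(self, *loggers: ToyLogger):
        self.loggers = loggers

    def init(self):
        return tuple(lg.init() for lg in self.loggers)

    def log(self, state, logs):
        return tuple(lg.log(s, logs) for lg, s in zip(self.loggers, state))


logger = ToyMultiLogger(ToyConsoleLogger(), ToyHistoryLogger())
state = logger.init()
for loss in (0.5, 0.25):
    state = logger.log(state, {"loss": loss})  # prints "[0] loss=0.5", then "[1] loss=0.25"
```

Swapping the console backend for another ToyLogger subclass changes only the line that builds logger. The code that produces the logs stays the same.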

Support for JAX Transformations

A non-negotiable goal for lox is full support for JAX’s core transformations.

  • vmap: lox handles batching by automatically adding the batch dimension to both the logged data and its steps.

  • scan / fori_loop: lox handles the sequential nature of these loops, either by streaming (in tap) or stacking (in spool).

  • jit: lox transformations work at the jaxpr level before compilation, ensuring that logging logic is baked into the compiled artifact.
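The shape bookkeeping described by the vmap and scan bullets can be seen with plain JAX outputs. No lox is involved, and run is an illustrative name. Stacking over a 4-step scan and then batching over 3 initial values gives every logged array a shape of (3, 4), including the step counter, which is not itself batched.

```python
import jax
import jax.numpy as jnp


def run(w0):
    def step(w, i):
        loss = (w - 3.0) ** 2
        return w - 0.2 * (w - 3.0), {"loss": loss, "step": i}

    return jax.lax.scan(step, w0, jnp.arange(4))


# vmap adds a leading batch axis in front of scan's iteration axis
# (unbatched outputs such as "step" are broadcast along it), and jit
# compiles the whole computation, logging outputs included.
final, logs = jax.jit(jax.vmap(run))(jnp.array([0.0, 1.0, 2.0]))
```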