lox.spool

lox.spool(fun, argnames=None, tags=None, keep_logs=False, interval=None, reduce=None, prefix='')

Spools a function to extract the logs generated during its execution. A log entry is produced by every call to lox.log within the function. Spooling has to obey the same rules as any other JAX transformation: the shapes of all logged values, and the number of times each value is logged, must be inferable at trace (compile) time. As a direct consequence, loops whose trip count is not static (e.g. jax.lax.while_loop) are not supported. Likewise, spooling an already jitted function that contains logging operations will trigger recompilation.
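The static-shape requirement is a general JAX constraint, not something specific to lox. The pure-JAX sketch below (it does not use lox; `body`, `run_scan` and `count_up` are illustrative names) uses jax.eval_shape to show that the stacked per-iteration outputs of jax.lax.scan have a shape known before anything runs, because the trip count is the length of the scanned array. A jax.lax.while_loop, whose trip count depends on a runtime value, only returns its final carry, which is presumably why per-iteration logs cannot be collected from it.

```python
import jax
import jax.numpy as jnp


def body(carry, x):
    # One scalar is emitted per iteration; scan stacks them along a new axis 0.
    return carry + x, carry * x


def run_scan(xs):
    return jax.lax.scan(body, jnp.zeros(()), xs)


# The trip count of scan is xs.shape[0], which is known at trace time, so
# eval_shape can infer the stacked output shape without executing anything.
carry_struct, ys_struct = jax.eval_shape(run_scan, jnp.zeros(5))
print(carry_struct.shape, ys_struct.shape)  # () (5,)


def count_up(n):
    # The trip count depends on the runtime value of n, so while_loop can
    # only return its final carry, never a stack of per-iteration values.
    return jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, 0)


steps = count_up(3)
print(steps.shape, int(steps))  # () 3
```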

Parameters:
  • fun (Callable) – The function to be spooled.

  • argnames (Iterable[str] | None) – An optional list of argument names to be spooled.

  • tags (Iterable[str] | None) – An optional list of tags to filter the logs.

  • keep_logs (bool) – Whether to keep logs in the jaxpr.

  • interval (int | None) – An optional interval to subsample the logs.

  • reduce (str | None) – An optional reduction method to apply to the logs.

  • prefix (str) – An optional prefix to add to the log keys.
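The descriptions above leave the exact semantics of interval, reduce and prefix open. Purely as an illustration, the hypothetical helper below (postprocess_logs is not part of lox) shows one plausible reading, applied to logs that are already stacked along a leading step axis: interval keeps every interval-th step, reduce collapses the step axis using an assumed set of reduction names ("mean", "max", "min"), and prefix is prepended to every key. lox's actual behaviour, including which reduce strings it accepts and the order in which the options are applied, may differ.

```python
import jax.numpy as jnp


def postprocess_logs(logs, interval=None, reduce=None, prefix=""):
    """Hypothetical sketch of interval/reduce/prefix; NOT lox's implementation.

    Assumes every value in `logs` is stacked along axis 0 (one entry per step).
    """
    # Assumed reduction names; lox may accept different strings.
    reducers = {"mean": jnp.mean, "max": jnp.max, "min": jnp.min}
    out = {}
    for key, value in logs.items():
        if interval is not None:
            value = value[::interval]  # keep every `interval`-th step
        if reduce is not None:
            value = reducers[reduce](value, axis=0)  # collapse the step axis
        out[prefix + key] = value
    return out


logs = {"loss": jnp.arange(6.0)}  # six logged steps with values 0..5
sub = postprocess_logs(logs, interval=2, prefix="train/")
print(list(sub), sub["train/loss"].tolist())  # ['train/loss'] [0.0, 2.0, 4.0]
print(float(postprocess_logs(logs, reduce="mean")["loss"]))  # 2.5
```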

Returns:

A wrapped function that, when called, returns a tuple (output, logs): the output of the original function and a dictionary of the collected logs.

Return type:

Callable

Example

Here is an example of how to use spool to extract logs from a function.

>>> import jax
>>> import jax.numpy as jnp
>>> import lox
>>> def f(x):
...     lox.log({"x": x})
...     return x + 1.0
>>> lox.spool(f)(1.0)
(2.0, {'x': 1.0})

Spooling works with arbitrarily nested functions and with higher-order primitives such as jax.lax.scan or jax.lax.cond. The following example applies the previously defined function f inside a jax.lax.scan; the values logged in each iteration are stacked along a leading axis.

>>> def g(xs):
...     carry, ys = jax.lax.scan(lambda c, x: (f(c), x), 0.0, xs)
...     lox.log({"carry": carry})
...     return ys
>>> lox.spool(g)(jnp.arange(3.0))
(Array([0., 1., 2.]), {'carry': 3., 'x': Array([0., 1., 2.])})
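How per-iteration values end up with a leading axis can be seen with scan alone. The pure-JAX sketch below does not use lox; it assumes, based on the shape of the logs above, that values logged inside a scan body are stacked the same way scan stacks its own per-iteration outputs. The carry starts at 0.0 and is advanced by the same + 1.0 as f, and each step emits the incoming carry, which is the value f would log. The helper name step is illustrative.

```python
import jax
import jax.numpy as jnp


def step(c, x):
    # Advance the carry with the same "+ 1.0" as f and emit the incoming
    # carry as this iteration's output; x itself is ignored here.
    return c + 1.0, c


final_carry, per_step = jax.lax.scan(step, 0.0, jnp.arange(3.0))
print(float(final_carry), per_step.tolist())  # 3.0 [0.0, 1.0, 2.0]
```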

When logging inside jax.lax.cond, all branches must produce logs with the same structure. In the example below, both branches call f, so each logs a dictionary containing only the key "x".

>>> def h(x):
...     return jax.lax.cond(x > 0, lambda x: f(x), lambda x: f(x + 1), x)
>>> x = 3.0
>>> lox.spool(h)(x)
(4.0, {'x': 3.0})
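The same-structure rule mirrors a requirement jax.lax.cond already places on its branch outputs. The pure-JAX sketch below (no lox involved; the function names are illustrative) returns dictionaries directly from the branches as a stand-in for logs: matching structures are accepted, while a branch that adds an extra key is rejected with a TypeError when the function is traced.

```python
import jax


def same_structure(x):
    # Both branches return a dict with the single key "x": accepted.
    return jax.lax.cond(x > 0, lambda v: {"x": v}, lambda v: {"x": v + 1.0}, x)


def different_structure(x):
    # The false branch adds a second key, so the output pytrees differ.
    return jax.lax.cond(x > 0, lambda v: {"x": v}, lambda v: {"x": v, "y": v}, x)


ok = jax.jit(same_structure)(3.0)

try:
    jax.jit(different_structure)(3.0)
    mismatch_rejected = False
except TypeError:
    # JAX checks that both branches produce identical pytree structures.
    mismatch_rejected = True

print(float(ok["x"]), mismatch_rejected)  # 3.0 True
```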

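The optional argument keep_logs can be set to True to keep the logs in the jaxpr. By default (keep_logs=False), the logs are removed from the jaxpr once they have been spooled, so wrapping an already spooled function in a second spool collects no further logs: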

>>> def j(x):
...     y, logs = lox.spool(f)(x)
...     return y
>>> lox.spool(j)(1.0)
(2.0, {})