lox.spool
- lox.spool(fun, argnames=None, tags=None, keep_logs=False, interval=None, reduce=None, prefix='')
Spools a function to extract logs generated during its execution. Logs are generated by every call to lox.log within the function. Spooling is subject to the same constraints as any JAX transformation: the shapes and lengths of all logs must be inferable at compile time. As a direct consequence, loops with non-static lengths (e.g. while loops) are not supported. Another implication is that spooling an already jitted function that contains logging operations will trigger recompilation.
- Parameters:
fun (Callable) – The function to be spooled.
argnames (Iterable[str] | None) – An optional list of argument names to be spooled.
tags (Iterable[str] | None) – An optional list of tags to filter the logs.
keep_logs (bool) – Whether to keep logs in the jaxpr.
interval (int | None) – An optional interval to subsample the logs.
reduce (str | None) – An optional reduction method to apply to the logs.
prefix (str) – An optional prefix to add to the log keys.
- Returns:
A wrapped function that returns the output of fun together with the collected logs.
- Return type:
Callable
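The compile-time requirement on log shapes can be illustrated with plain JAX, independently of lox. The sketch below uses illustrative function names that are not part of lox. It contrasts jax.lax.scan, whose trip count is a Python integer so per-iteration values can be stacked into an array of static shape, with jax.lax.while_loop, which can only return a fixed-shape carry because its number of iterations depends on runtime values.

```python
import jax


def stacked_history(x0):
    # The trip count (5) is a static Python int, so JAX knows at trace time
    # that the stacked per-iteration output has shape (5,).
    def step(carry, _):
        carry = carry + 1.0
        return carry, carry  # (new carry, value stacked across iterations)

    return jax.lax.scan(step, x0, xs=None, length=5)


def count_until(limit):
    # The number of iterations depends on the runtime value of `limit`, so
    # while_loop can only return the final carry; there is no stacked output
    # whose length would be known at compile time.
    return jax.lax.while_loop(lambda c: c < limit, lambda c: c + 1.0, 0.0)


final, history = jax.jit(stacked_history)(0.0)
result = jax.jit(count_until)(3.5)
print(final, history)
print(result)
```

A value recorded once per scan iteration therefore has a well-defined shape, whereas the same value recorded inside a while loop would not, which is why the latter cannot be spooled.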
Example
Here is an example of how to use spool to extract logs from a function.
>>> def f(x):
...     lox.log({"x": x})
...     return x + 1.0
>>> spool(f)(1.0)
(2.0, {'x': 1.0})
Spooling works with arbitrarily nested functions and higher-order primitives such as jax.lax.scan or jax.lax.cond. The following example shows how to use spooling with a scan operation over the previously defined function f. Logs produced inside the scan body are stacked along the iteration axis.
>>> def g(xs):
...     carry, ys = jax.lax.scan(lambda c, x: (f(c), x), 0.0, xs)
...     lox.log({"carry": carry})
...     return ys
>>> spool(g)(jnp.arange(3.0))
(Array([0., 1., 2.]), {'carry': 3., 'x': Array([0., 1., 2.])})
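Conceptually, a log emitted inside a scan body behaves like an extra per-iteration output of the scan. The plain-JAX sketch below does not use lox; g_manual is a hypothetical name. It threads the value that f would log through the scan's stacked output by hand, and attaches the post-scan carry as a single, unstacked entry.

```python
import jax
import jax.numpy as jnp


def g_manual(xs):
    # Scan body: f(c) = c + 1.0 would log {"x": c}; here that value is
    # returned explicitly as part of the stacked per-iteration output.
    def body(c, x):
        return c + 1.0, (x, {"x": c})

    carry, (ys, logs) = jax.lax.scan(body, 0.0, xs)
    # A value logged after the scan is recorded once, not stacked.
    logs["carry"] = carry
    return ys, logs


ys, logs = g_manual(jnp.arange(3.0))
print(ys, logs)
```

Doing this manually requires changing every function signature along the call chain, which is the bookkeeping that spooling is meant to remove.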
When logging inside of jax.lax.cond, all branches must have the same log structure. In this case, both branches produce a log dict containing only the key "x".
>>> def h(x):
...     return jax.lax.cond(x > 0, lambda x: f(x), lambda x: f(x + 1), x)
>>> x = 3.0
>>> spool(h)(x)
(4.0, {'x': 3.0})
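The same-structure rule mirrors a general requirement of jax.lax.cond: every branch must return the same pytree structure with matching shapes and dtypes. The plain-JAX sketch below does not use lox; h_manual is a hypothetical name. It makes the logs an explicit part of each branch's return value, so both branches must agree on the log keys.

```python
import jax


def h_manual(x):
    # Each branch returns (output, logs). lax.cond requires both branches to
    # return the same pytree structure, so both must use the key "x".
    def true_branch(x):
        return x + 1.0, {"x": x}

    def false_branch(x):
        # Equivalent of f(x + 1): logs the shifted input, returns it plus one.
        y = x + 1.0
        return y + 1.0, {"x": y}

    return jax.lax.cond(x > 0, true_branch, false_branch, x)


out, logs = h_manual(3.0)
print(out, logs)
```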
The optional argument keep_logs can be set to True to keep the logs in the jaxpr. By default, spooling a function twice leaves the outer spool without logs, because the inner spool removes them from the jaxpr:
>>> def j(x):
...     y, logs = spool(f)(x)
...     return y
>>> spool(j)(1.0)
2.0