lox.tap

lox.tap(fun, callback=None, argnames=None, tags=None, prefix='')

A function transformation that taps into the execution of a JAX function and passes selected logged values to a callback. Only values logged with lox.log() can be tapped. The transformation modifies the function so that it executes a callback with the tapped values; by default, this callback displays them in the console. This is useful for debugging and inspecting intermediate values while a JAX function runs. You can restrict which values are tapped by passing their names via argnames; if no names are provided, all logged values are tapped. You can also provide a custom callback, which should accept a single argument: a logdict containing the tapped values.

>>> import lox
>>> def callback(logs):
...     print(logs)
...
>>> def f(x, y):
...     lox.log({"x": x, "y": y})
...     return x + y, x * y
...
>>> x, y = 1.0, 2.0
>>> out = lox.tap(f, argnames=["y"], callback=callback)(x, y)
{"y": 2.0}

Note that this transformation can introduce significant overhead, especially if the function is called frequently or with large inputs. Use it only for debugging, and remove it before any performance-critical execution. Values must be logged with lox.log() to be tapped; see also lox.spool().
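One way to follow this advice without editing call sites is to apply the transformation conditionally. The helper below is a hypothetical sketch, not part of lox: it takes any function transformation (such as lox.tap) and returns the original function untouched when debugging is disabled, so no wrapper and no tapping overhead exist on that path.

```python
from typing import Any, Callable


def maybe_transform(transform: Callable[..., Callable], fun: Callable,
                    enabled: bool, **kwargs: Any) -> Callable:
    """Return transform(fun, **kwargs) when enabled, otherwise fun itself.

    With enabled=False no wrapper is created, so calls go straight to the
    original function.
    """
    return transform(fun, **kwargs) if enabled else fun


# Hypothetical usage (DEBUG and step are illustrative names):
# step = maybe_transform(lox.tap, step, enabled=DEBUG, argnames=["y"])
```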

Parameters:
  • fun (Callable) – The function you want to tap into.

  • callback (Optional[Callable[[logdict], None]]) – A callback function to be called with the tapped values. If None, the default callback will be used to display the values.

  • argnames (Union[str, Iterable[str], None]) – A string or iterable of strings naming the logged values to tap. If None, all logged values are tapped.

  • prefix (str) – An optional prefix to add to the log keys.

Returns:

A wrapped function that executes the original function and passes the tapped values to the callback (by default, printing them).

Return type:

Callable
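The argnames and prefix parameters determine which entries reach the callback and under what keys. The function below is a hypothetical sketch of that selection step using a plain dict in place of a logdict. It assumes that unknown names are silently ignored and that the prefix is prepended after filtering; neither behavior is confirmed by this page.

```python
from typing import Any, Dict, Iterable, Union


def select_logs(logs: Dict[str, Any],
                argnames: Union[str, Iterable[str], None] = None,
                prefix: str = "") -> Dict[str, Any]:
    """Hypothetical illustration of how argnames and prefix could shape the tapped dict."""
    if argnames is None:
        keep = dict(logs)  # None means: tap every logged value
    else:
        if isinstance(argnames, str):
            # A single name must be wrapped first; set("xy") would split it into characters.
            argnames = [argnames]
        names = set(argnames)
        keep = {k: v for k, v in logs.items() if k in names}
    # Assumption: the prefix is prepended to each key after filtering.
    return {prefix + k: v for k, v in keep.items()}


logs = {"x": 1.0, "y": 2.0}
print(select_logs(logs, argnames="y", prefix="debug/"))  # {'debug/y': 2.0}
```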