lox.logdict

class lox.logdict(data, **steps)

Bases: dict[str, Any]

A dictionary that stores log values together with the steps at which they were recorded. It is the underlying data structure used for all logging in lox. A logdict behaves exactly like a standard dictionary, but it additionally carries a steps attribute that stores the timestamps at which the data was logged. Internally, steps is a two-level dictionary: the first level is keyed by the name of the timestamp (e.g. "step", "episode", etc.), and the second level maps each log key to its actual timestamps. The steps attribute is not meant to be accessed directly. Instead, logdict exposes each kind of timestamp as an attribute:

>>> _, logs = lox.spool(f)()
>>> logs["loss"]
[1.0, 0.8, 0.6, 0.5, 0.4, 0.3, 0.25, 0.2, 0.2]
>>> logs.step["loss"]
[0, 1, 2, 3, 4, 5, 6, 7, 8]
>>> logs.episode["loss"]
[0, 0, 0, 1, 1, 1, 2, 2, 2]
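To make the two-level layout of steps concrete, here is a deliberately simplified stand-in written in plain Python. The class name ToyLogdict and its internals are hypothetical and are not lox's implementation; the sketch only mirrors the behavior described above, with dictionary access for values and attribute access for each kind of timestamp.

```python
from typing import Any


class ToyLogdict(dict):
    """Hypothetical, simplified stand-in for lox.logdict (not the real implementation).

    Values live in the dict itself; `steps` maps a step name (e.g. "step",
    "episode") to a dict from log key to the timestamps for that key.
    """

    def __init__(self, data: dict[str, Any], **steps: dict[str, Any]):
        super().__init__(data)
        # Stored as a regular instance attribute, so normal lookup finds it.
        self.steps = dict(steps)

    def __getattr__(self, name: str) -> dict[str, Any]:
        # Only called when normal attribute lookup fails, so `logs.step`
        # resolves to self.steps["step"].
        steps = self.__dict__.get("steps", {})
        if name in steps:
            return steps[name]
        raise AttributeError(name)


logs = ToyLogdict(
    {"loss": [1.0, 0.8, 0.6]},
    step={"loss": [0, 1, 2]},
    episode={"loss": [0, 0, 1]},
)
print(logs["loss"])          # [1.0, 0.8, 0.6]
print(logs.step["loss"])     # [0, 1, 2]
print(logs.episode["loss"])  # [0, 0, 1]
```

Routing attribute access through `__getattr__` keeps the timestamp names out of the dictionary's own keys, so iterating over the object still yields only the logged values.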
__init__(data, **steps)

Methods

__add__(other)

Adds two logdicts together, concatenating their data and steps. This is essentially identical to executing two functions in sequence and then collecting their logs. Assuming f and g are two pure functions containing arbitrary logs, the variable logs will have the same contents in the following two code blocks.

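# Variant 1: spool f and g separately, then add their logs.
_, logs_f = lox.spool(f)()
_, logs_g = lox.spool(g)()
logs = logs_f + logs_g

# Variant 2: spool a single function that calls f and then g.
def h():
    f()
    g()

_, logs = lox.spool(h)()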

Note that this is not the same as updating the dict. If both logdicts contain values for the same key, the values will be concatenated, assuming they have the same structure. If they do not, an error will be raised.

Parameters:

other (logdict) – Another logdict to add.

Returns:

A new logdict containing the concatenated data and steps.

Return type:

logdict
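The concatenation semantics of __add__ can be sketched on plain dictionaries. This is an illustration under assumptions, not lox's code: values and timestamps are modeled as Python lists (lox presumably stores arrays), the helper names add_logs and _concat are hypothetical, keys present in only one operand are assumed to be carried over unchanged, and the structure check that raises an error is omitted.

```python
from typing import Any

Data = dict[str, list[Any]]
Steps = dict[str, dict[str, list[Any]]]


def _concat(x: Data, y: Data) -> Data:
    """Concatenate per-key lists; keys found in only one input are copied."""
    out = {k: list(v) for k, v in x.items()}
    for k, v in y.items():
        out[k] = out.get(k, []) + list(v)
    return out


def add_logs(data_a: Data, steps_a: Steps, data_b: Data, steps_b: Steps) -> tuple[Data, Steps]:
    """Hypothetical sketch of logdict.__add__ on plain dicts: a's entries come first."""
    data = _concat(data_a, data_b)
    steps = {
        name: _concat(steps_a.get(name, {}), steps_b.get(name, {}))
        for name in steps_a.keys() | steps_b.keys()
    }
    return data, steps


# Logs from f (steps 0-1) followed by logs from g (step 2).
data, steps = add_logs(
    {"loss": [1.0, 0.8]}, {"step": {"loss": [0, 1]}},
    {"loss": [0.6], "acc": [0.9]}, {"step": {"loss": [2], "acc": [2]}},
)
print(data)   # {'loss': [1.0, 0.8, 0.6], 'acc': [0.9]}
print(steps)  # {'step': {'loss': [0, 1, 2], 'acc': [2]}}

# A plain dict update would instead keep only g's loss values.
print({**{"loss": [1.0, 0.8]}, **{"loss": [0.6]}})  # {'loss': [0.6]}
```

The last line shows why addition is not the same as updating: an update replaces the whole history for a shared key, whereas addition appends to it.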

__or__(other)

Merges two logdicts. If a key exists in both, the value from other overwrites the value from self. Steps are merged in the same way.

Parameters:

other (logdict) – Another logdict to merge with.

Returns:

A new logdict containing the merged data and steps.

Return type:

logdict

Raises:

TypeError – If the other object is not a logdict.
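The overwrite semantics of __or__ can be sketched with Python's own dict unpacking, which already lets the right-hand operand win. The helper merge_logs is hypothetical, and merging steps per step name and then per log key is an assumption about what "the same happens for the steps" means; the TypeError check is omitted because the sketch works on plain dicts.

```python
from typing import Any

Data = dict[str, Any]
Steps = dict[str, dict[str, Any]]


def merge_logs(data_a: Data, steps_a: Steps, data_b: Data, steps_b: Steps) -> tuple[Data, Steps]:
    """Hypothetical sketch of logdict.__or__: entries from b overwrite entries from a."""
    data = {**data_a, **data_b}
    # Assumption: steps are merged per step name, then per log key, with b winning.
    steps = {
        name: {**steps_a.get(name, {}), **steps_b.get(name, {})}
        for name in steps_a.keys() | steps_b.keys()
    }
    return data, steps


data, steps = merge_logs(
    {"loss": [1.0, 0.8], "acc": [0.5]}, {"step": {"loss": [0, 1], "acc": [1]}},
    {"loss": [0.6]}, {"step": {"loss": [2]}},
)
print(data)   # {'loss': [0.6], 'acc': [0.5]}
print(steps)  # {'step': {'loss': [2], 'acc': [1]}}
```

Compared with addition, the earlier loss history [1.0, 0.8] is discarded rather than extended.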

property data: dict[str, Any]

Returns:

The data stored in the logdict as a standard dictionary.

Return type:

dict

filter(predicate)

Filters the logdict values and steps based on a predicate function.

Parameters:

predicate (Callable[[str, Any], bool]) – A function that takes a key and value and returns True if the item should be kept, False otherwise.

Returns:

A new logdict containing only the items that satisfy the predicate.

Return type:

logdict
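A minimal sketch of the filtering behavior on plain dictionaries follows. The helper filter_logs is hypothetical; it assumes the predicate is applied to the data only and that the timestamps of every dropped key are removed from each step dictionary as well.

```python
from typing import Any, Callable

Data = dict[str, Any]
Steps = dict[str, dict[str, Any]]


def filter_logs(data: Data, steps: Steps, predicate: Callable[[str, Any], bool]) -> tuple[Data, Steps]:
    """Hypothetical sketch of logdict.filter on plain dicts."""
    kept = {k: v for k, v in data.items() if predicate(k, v)}
    # Drop the timestamps of every key that was filtered out of the data.
    new_steps = {
        name: {k: v for k, v in per_key.items() if k in kept}
        for name, per_key in steps.items()
    }
    return kept, new_steps


data = {"train/loss": [1.0, 0.8], "eval/loss": [0.9]}
steps = {"step": {"train/loss": [0, 1], "eval/loss": [1]}}
kept, kept_steps = filter_logs(data, steps, lambda key, value: key.startswith("train/"))
print(kept)        # {'train/loss': [1.0, 0.8]}
print(kept_steps)  # {'step': {'train/loss': [0, 1]}}
```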

prefix(prefix)

Adds a prefix to all keys in the logdict.

Parameters:

prefix (str) – The prefix to add to each key.

Returns:

A new logdict with prefixed keys.

Return type:

logdict
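Prefixing is a key rename, which can be sketched as below. The helper prefix_logs is hypothetical. Two assumptions are worth stating: the prefix is prepended verbatim (no separator is inserted, so the caller includes one such as "/"), and the keys inside each step dictionary are renamed too, so that timestamps still line up with their values.

```python
from typing import Any

Data = dict[str, Any]
Steps = dict[str, dict[str, Any]]


def prefix_logs(data: Data, steps: Steps, prefix: str) -> tuple[Data, Steps]:
    """Hypothetical sketch of logdict.prefix on plain dicts."""
    new_data = {prefix + k: v for k, v in data.items()}
    # Rename the keys inside each step dictionary too, so lookups such as
    # steps["step"][key] keep working for the prefixed keys.
    new_steps = {
        name: {prefix + k: v for k, v in per_key.items()}
        for name, per_key in steps.items()
    }
    return new_data, new_steps


data, steps = prefix_logs({"loss": [1.0, 0.8]}, {"step": {"loss": [0, 1]}}, "train/")
print(data)   # {'train/loss': [1.0, 0.8]}
print(steps)  # {'step': {'train/loss': [0, 1]}}
```

A typical use is namespacing logs from different phases (for example "train/" and "eval/") before combining them, so that their keys do not collide.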

reduce(mode='mean', keep_steps=True)

Reduces the logdict values using the specified mode.

Parameters:
  • mode (str) – The reduction mode to apply. Can be one of “mean”, “first” or “last”.

  • keep_steps (bool) – Whether to keep the steps in the reduced logdict.

Returns:

A new logdict containing the reduced values and optionally the steps.

Return type:

logdict

Note that reduction over specific steps is not implemented yet.

Raises:

ValueError – If the reduction mode is not recognized.
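The three reduction modes can be sketched with a small dispatch table. The helper reduce_logs is hypothetical, and the documentation above does not say how timestamps are reduced. This sketch assumes they are reduced with the same mode as the values when keep_steps is True and dropped otherwise.

```python
from statistics import fmean
from typing import Any, Callable, Sequence

Data = dict[str, Sequence[float]]
Steps = dict[str, dict[str, Sequence[float]]]

_REDUCERS: dict[str, Callable[[Sequence[float]], Any]] = {
    "mean": fmean,
    "first": lambda xs: xs[0],
    "last": lambda xs: xs[-1],
}


def reduce_logs(data: Data, steps: Steps, mode: str = "mean", keep_steps: bool = True):
    """Hypothetical sketch of logdict.reduce on plain dicts of lists."""
    if mode not in _REDUCERS:
        raise ValueError(f"Unknown reduction mode: {mode!r}")
    reduce_fn = _REDUCERS[mode]
    new_data = {k: reduce_fn(v) for k, v in data.items()}
    if not keep_steps:
        return new_data, {}
    # Assumption: timestamps are reduced with the same mode as the values.
    new_steps = {
        name: {k: reduce_fn(v) for k, v in per_key.items()}
        for name, per_key in steps.items()
    }
    return new_data, new_steps


data = {"loss": [1.0, 0.5]}
steps = {"step": {"loss": [0, 1]}}
print(reduce_logs(data, steps, "last"))                    # ({'loss': 0.5}, {'step': {'loss': 1}})
print(reduce_logs(data, steps, "mean", keep_steps=False))  # ({'loss': 0.75}, {})
```

Keeping the modes in one table means that validating the mode and dispatching on it use the same source of truth, so adding a mode cannot leave the ValueError check out of date.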

property slice: _SliceProxy

Provides a convenient interface for slicing the logdict with the standard slicing syntax.

Returns:

A proxy object that allows slicing the logdict.

Return type:

_SliceProxy
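A separate proxy object is presumably needed because square brackets on a dictionary already mean key lookup (logs["loss"]), so slicing has to go through another object. The sketch below is hypothetical and is not lox's _SliceProxy. It assumes a slice is applied along the logging axis of every value and of every matching timestamp sequence.

```python
from typing import Any

Data = dict[str, Any]
Steps = dict[str, dict[str, Any]]


class SliceProxy:
    """Hypothetical sketch of a slicing proxy like logdict.slice (not lox's _SliceProxy)."""

    def __init__(self, data: Data, steps: Steps):
        self._data = data
        self._steps = steps

    def __getitem__(self, index: slice) -> tuple[Data, Steps]:
        # Assumption: the slice is applied along the logging axis of every
        # value and of every matching timestamp sequence.
        data = {k: v[index] for k, v in self._data.items()}
        steps = {
            name: {k: v[index] for k, v in per_key.items()}
            for name, per_key in self._steps.items()
        }
        return data, steps


proxy = SliceProxy({"loss": [1.0, 0.8, 0.6, 0.5]}, {"step": {"loss": [0, 1, 2, 3]}})
print(proxy[1:3])  # ({'loss': [0.8, 0.6]}, {'step': {'loss': [1, 2]}})
print(proxy[::2])  # ({'loss': [1.0, 0.6]}, {'step': {'loss': [0, 2]}})
```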

steps: dict[str, stepdict]

tree_flatten()

Function that flattens the logdict into a flat list of data and steps. This is used to register the logdict as a pytree in JAX.

Returns:

A tuple containing the flattened data and steps. The first element is a flat list of data values, and the second element is a tuple containing the structure of the data and steps.

Return type:

tuple

classmethod tree_unflatten(structure, logs_flat)

Function that reconstructs the logdict from a flat list of data and steps. This is used to register the logdict as a pytree in JAX.

Parameters:
  • structure (tuple) – A tuple containing the structure of the data and steps.

  • logs_flat – A flat list of data values and steps.

Returns:

A new logdict instance containing the reconstructed data and steps.

Return type:

logdict
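JAX's jax.tree_util.register_pytree_node_class expects this pair of methods: a tree_flatten() method returning (children, aux_data) and a tree_unflatten(aux_data, children) classmethod. The sketch below shows one way such a pair can round-trip both data and steps. ToyLogdict is hypothetical. Two details are assumptions rather than lox's documented behavior: the step sequences are placed among the leaves instead of in the structure, and keys are sorted. JAX is optional here, and the registration only runs if it is installed.

```python
from typing import Any

try:
    from jax import tree_util  # optional: only needed to register with JAX
except ImportError:
    tree_util = None


class ToyLogdict(dict):
    """Hypothetical stand-in for lox.logdict, illustrating the pytree protocol only."""

    def __init__(self, data: dict[str, Any], **steps: dict[str, Any]):
        super().__init__(data)
        self.steps = dict(steps)

    def tree_flatten(self):
        # Sort keys so that equal structures produce equal (hashable) aux data.
        keys = tuple(sorted(self))
        step_keys = tuple((name, tuple(sorted(self.steps[name]))) for name in sorted(self.steps))
        # Leaves: all data values first, then every timestamp sequence.
        leaves = [self[k] for k in keys]
        leaves += [self.steps[name][k] for name, ks in step_keys for k in ks]
        return leaves, (keys, step_keys)

    @classmethod
    def tree_unflatten(cls, structure, logs_flat):
        keys, step_keys = structure
        logs_flat = list(logs_flat)
        data = dict(zip(keys, logs_flat[: len(keys)]))
        rest = iter(logs_flat[len(keys):])
        # Consume the remaining leaves in the same order tree_flatten emitted them.
        steps = {name: {k: next(rest) for k in ks} for name, ks in step_keys}
        return cls(data, **steps)


if tree_util is not None:
    tree_util.register_pytree_node_class(ToyLogdict)

logs = ToyLogdict({"loss": [1.0, 0.8]}, step={"loss": [0, 1]})
leaves, structure = logs.tree_flatten()
rebuilt = ToyLogdict.tree_unflatten(structure, leaves)
print(leaves)                                        # [[1.0, 0.8], [0, 1]]
print(rebuilt == logs, rebuilt.steps == logs.steps)  # True True
```

One consequence of keeping timestamps among the leaves is that functions mapped over the pytree, such as jax.tree_util.tree_map, also see the step values.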