lox.logdict#
- class lox.logdict(data, **steps)#
Bases:
dict[str, Any]
A dictionary that stores log values and the steps at which they were recorded. This class extends the standard dictionary with step information, allowing structured logging of data during computations. It is the underlying data structure used for all logging in lox. It behaves identically to a standard dictionary, but additionally has a steps attribute that stores the timestamps at which the data was logged. Internally, steps is a two-level dictionary: the first level holds the names of the timestamps (e.g. "step", "episode", etc.), and the second level holds the actual timestamps. However, it is not supposed to be accessed directly. Instead, logdict provides a convenient interface to access the steps as attributes.
>>> _, logs = lox.spool(f)()
>>> logs["loss"]
[1.0, 0.8, 0.6, 0.5, 0.4, 0.3, 0.25, 0.2, 0.2]
>>> logs.step["loss"]
[0, 1, 2, 3, 4, 5, 6, 7, 8]
>>> logs.episode["loss"]
[0, 0, 0, 1, 1, 1, 2, 2, 2]
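The two-level layout and the attribute-style access described above can be modeled in plain Python. The sketch below is a hypothetical stand-in (`ToyLogDict` is not part of lox, and plain lists replace whatever array type lox actually stores); it only illustrates how a dict subclass can expose the first level of a steps mapping as attributes.

```python
from typing import Any


class ToyLogDict(dict):
    """Hypothetical stand-in for lox.logdict; not the library's implementation."""

    def __init__(self, data: dict[str, Any], **steps: dict[str, Any]):
        super().__init__(data)
        # First level: step name ("step", "episode", ...).
        # Second level (assumed): logged key -> step values for that key.
        self.steps = dict(steps)

    def __getattr__(self, name: str) -> dict[str, Any]:
        # Only called when normal attribute lookup fails, so `steps` itself
        # and the regular dict methods are unaffected.
        try:
            return self.__dict__["steps"][name]
        except KeyError:
            raise AttributeError(name) from None


logs = ToyLogDict(
    {"loss": [1.0, 0.8, 0.6]},
    step={"loss": [0, 1, 2]},
    episode={"loss": [0, 0, 1]},
)
print(logs["loss"])          # values, as in a normal dict
print(logs.step["loss"])     # step index at which each value was logged
print(logs.episode["loss"])  # episode index at which each value was logged
```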
- __init__(data, **steps)#
Methods
- __add__(other)#
Adds two logdicts together, concatenating their data and steps. This is essentially identical to executing two functions in sequence and then collecting their logs. Assuming f and g are two pure functions containing arbitrary logs, the contents of the variable logs will be the same in the following two code blocks.
_, logs_f = lox.spool(f)()
_, logs_g = lox.spool(g)()
logs = logs_f + logs_g

def h():
    f()
    g()
_, logs = lox.spool(h)()
Note that this is not the same as updating the dict. If both logdicts contain values for the same key, the values will be concatenated, assuming they have the same structure. If they do not, an error will be raised.
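To make the difference between concatenating and updating concrete, here is a minimal sketch on plain dictionaries of lists. The helper `concat_logs` is hypothetical and only models the data part; according to the description above, lox concatenates the steps in the same way.

```python
def concat_logs(a: dict[str, list], b: dict[str, list]) -> dict[str, list]:
    """Concatenate per-key value lists; keys present on only one side are kept."""
    out = {key: list(values) for key, values in a.items()}
    for key, values in b.items():
        out[key] = out.get(key, []) + list(values)
    return out


logs_f = {"loss": [1.0, 0.8]}
logs_g = {"loss": [0.6], "acc": [0.9]}

combined = concat_logs(logs_f, logs_g)  # models logs_f + logs_g
updated = {**logs_f, **logs_g}          # plain dict update, for contrast

print(combined)  # "loss" holds the values of both inputs, in order
print(updated)   # "loss" holds only the values from logs_g
```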
- __or__(other)#
Merges two logdicts. If a key exists in both, the value from self is overwritten by the value from other. The same happens for the steps.
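The merge semantics can be sketched with the standard dict union operator (Python 3.9+). The helper `merge_logs` is hypothetical, and merging the steps separately for each step name is an assumption about how "the same happens for the steps" plays out.

```python
def merge_logs(data_a, steps_a, data_b, steps_b):
    """Merge data and steps; entries from the second operand win on conflicts."""
    data = data_a | data_b
    names = steps_a.keys() | steps_b.keys()
    # Assumed: steps are merged per step name with the same overwrite rule.
    steps = {name: steps_a.get(name, {}) | steps_b.get(name, {}) for name in names}
    return data, steps


data_a = {"loss": [1.0, 0.8], "lr": [0.1, 0.1]}
steps_a = {"step": {"loss": [0, 1], "lr": [0, 1]}}
data_b = {"loss": [0.5]}
steps_b = {"step": {"loss": [7]}}

data, steps = merge_logs(data_a, steps_a, data_b, steps_b)
print(data)   # "loss" comes from the second operand, "lr" from the first
print(steps)
```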
- property data: dict[str, Any]#
Returns: dict: The data stored in the logdict as a standard dictionary.
- filter(predicate)#
Filters the logdict values and steps based on a predicate function.
- Parameters:
predicate (Callable[[str, Any], bool]) – A function that takes a key and value and returns True if the item should be kept, False otherwise.
- Returns:
A new logdict containing only the items that satisfy the predicate.
- Return type:
logdict
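A predicate with the documented signature (key and value in, bool out) can be illustrated on plain dictionaries. The helper `filter_logs` is hypothetical; dropping the step entries of removed keys is an assumption based on the statement that both values and steps are filtered.

```python
from typing import Any, Callable


def filter_logs(
    data: dict[str, Any],
    steps: dict[str, dict[str, Any]],
    predicate: Callable[[str, Any], bool],
):
    """Keep only the items accepted by the predicate, in data and in steps."""
    kept = {key: value for key, value in data.items() if predicate(key, value)}
    kept_steps = {
        name: {key: s for key, s in per_key.items() if key in kept}
        for name, per_key in steps.items()
    }
    return kept, kept_steps


data = {"train/loss": [1.0, 0.5], "eval/loss": [0.9]}
steps = {"step": {"train/loss": [0, 1], "eval/loss": [1]}}

kept, kept_steps = filter_logs(data, steps, lambda key, value: key.startswith("train/"))
print(kept)
print(kept_steps)
```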
- prefix(prefix)#
Adds a prefix to all keys in the logdict.
- Parameters:
prefix (str) – The prefix to add to each key.
- Returns:
A new logdict with prefixed keys.
- Return type:
logdict
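As a rough illustration of key prefixing, the sketch below renames keys in plain dictionaries. The helper `prefix_logs` is hypothetical; it assumes the prefix is prepended by plain string concatenation (no separator is inserted) and that the keys inside the steps are renamed to match.

```python
def prefix_logs(data, steps, prefix: str):
    """Return copies of data and steps with the prefix prepended to every key."""
    new_data = {prefix + key: value for key, value in data.items()}
    new_steps = {
        name: {prefix + key: s for key, s in per_key.items()}
        for name, per_key in steps.items()
    }
    return new_data, new_steps


data = {"loss": [1.0, 0.5]}
steps = {"step": {"loss": [0, 1]}}

new_data, new_steps = prefix_logs(data, steps, "train/")
print(new_data)
print(new_steps)
```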
- reduce(mode='mean', keep_steps=True)#
Reduces the logdict values using the specified mode.
- Parameters:
mode (str) – The reduction mode to apply. Can be one of “mean”, “first” or “last”.
keep_steps (bool) – Whether to keep the steps in the reduced logdict.
- Returns:
A new logdict containing the reduced values and optionally the steps.
- Return type:
logdict
Note that reduction over specific steps is not implemented yet.
- Raises:
ValueError – If the reduction mode is not recognized.
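The three documented modes and the error case can be sketched on plain lists. The helper `reduce_values` is hypothetical and reduces a single sequence of values; how lox treats the steps when keep_steps is set is not modeled here.

```python
from statistics import fmean


def reduce_values(values, mode: str = "mean"):
    """Reduce one sequence of logged values to a single value."""
    if mode == "mean":
        return fmean(values)
    if mode == "first":
        return values[0]
    if mode == "last":
        return values[-1]
    raise ValueError(f"Unknown reduction mode: {mode!r}")


data = {"loss": [1.0, 0.5, 0.0], "acc": [0.25, 0.75]}

# Apply the same reduction to every key.
means = {key: reduce_values(values, "mean") for key, values in data.items()}
lasts = {key: reduce_values(values, "last") for key, values in data.items()}
print(means)
print(lasts)
```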
- property slice: _SliceProxy#
Provides a convenient interface for slicing the logdict using the standard slicing syntax.
- Returns:
A proxy object that allows slicing the logdict.
- Return type:
_SliceProxy
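A property that returns a proxy object is a common way to support subscript syntax on an attribute. The sketch below shows that pattern with hypothetical classes (`_ToySliceProxy`, `ToyLogs`); it assumes the slice is applied to every logged value along the logging axis, which the documentation above does not spell out.

```python
class _ToySliceProxy:
    """Hypothetical proxy: applies an index or slice to every value."""

    def __init__(self, data: dict):
        self._data = data

    def __getitem__(self, index):
        # Assumed behavior: the same slice is applied to each value.
        return {key: values[index] for key, values in self._data.items()}


class ToyLogs(dict):
    @property
    def slice(self) -> _ToySliceProxy:
        return _ToySliceProxy(self)


logs = ToyLogs({"loss": [1.0, 0.8, 0.6, 0.5], "acc": [0.1, 0.2, 0.3, 0.4]})
window = logs.slice[1:3]
print(window)
```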
- steps: dict[str, stepdict]#
- tree_flatten()#
Function that flattens the logdict into a flat list of data and steps. This is used to register the logdict as a pytree in JAX.
- Returns:
A tuple containing the flattened data and steps. The first element is a flat list of data values, and the second element is a tuple containing the structure of the data and steps.
- Return type:
tuple
- classmethod tree_unflatten(structure, logs_flat)#
Function that reconstructs the logdict from a flat list of data and steps. This is used to register the logdict as a pytree in JAX.
- Parameters:
structure (tuple) – A tuple containing the structure of the data and steps.
logs_flat – A flat list of data values and steps.
- Returns:
A new logdict instance containing the reconstructed data and steps.
- Return type:
logdict
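The flatten/unflatten pair follows the convention that JAX expects from classes registered as pytree nodes: flattening returns the leaves together with the static structure, and unflattening takes the structure first and the leaves second. The sketch below demonstrates a round trip with a hypothetical `ToyLogs` class and without importing JAX; it assumes every step name has an entry for every data key, which may not hold for the real logdict.

```python
class ToyLogs:
    """Hypothetical container with a data dict and a two-level steps dict."""

    def __init__(self, data, steps):
        self.data = data
        self.steps = steps

    def tree_flatten(self):
        keys = tuple(sorted(self.data))
        step_names = tuple(sorted(self.steps))
        # Leaves: first the data values, then the step values per step name.
        leaves = [self.data[k] for k in keys]
        leaves += [self.steps[name][k] for name in step_names for k in keys]
        structure = (keys, step_names)
        return leaves, structure

    @classmethod
    def tree_unflatten(cls, structure, leaves):
        keys, step_names = structure
        n = len(keys)
        data = dict(zip(keys, leaves[:n]))
        steps = {
            name: dict(zip(keys, leaves[n * (i + 1): n * (i + 2)]))
            for i, name in enumerate(step_names)
        }
        return cls(data, steps)


logs = ToyLogs(
    {"loss": [1.0, 0.8], "acc": [0.5, 0.6]},
    {"step": {"loss": [0, 1], "acc": [0, 1]},
     "episode": {"loss": [0, 0], "acc": [0, 0]}},
)
leaves, structure = logs.tree_flatten()
rebuilt = ToyLogs.tree_unflatten(structure, leaves)
print(structure)
print(rebuilt.data == logs.data and rebuilt.steps == logs.steps)
```

With JAX installed, a class that defines these two methods can be registered via the `jax.tree_util.register_pytree_node_class` decorator, after which functions such as `jax.tree_util.tree_map` can traverse its leaves.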