jax_sgmc.util.pytree_leaves_to_list

jax_sgmc.util.pytree_leaves_to_list(pytree)

Splits a pytree into a list of pytrees.

Splits every leaf of the pytree along the first dimension, thus undoing the pytree_list_to_leaves() transformation.
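For illustration, the splitting can be sketched with jax.tree_util directly. This is a minimal sketch, not the jax_sgmc implementation: split_leading_axis and stack_leading_axis are hypothetical names, and stack_leading_axis only assumes that pytree_list_to_leaves() stacks the leaves along a new first axis, as the description above implies.

```python
import jax
import jax.numpy as jnp


def split_leading_axis(tree):
    """Hypothetical sketch: split every leaf of `tree` along axis 0.

    Assumes all leaves share the same leaf.shape[0]; not the jax_sgmc code.
    """
    leaves = jax.tree_util.tree_leaves(tree)
    n = leaves[0].shape[0]  # number of pytrees to produce
    # The i-th output pytree holds the i-th slice of every leaf.
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], tree) for i in range(n)]


def stack_leading_axis(tree_list):
    """Hypothetical inverse: stack same-structure pytrees along a new axis 0."""
    return jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *tree_list)


tree = {"a": jnp.array([0.0, 1.0]), "b": jnp.zeros((2, 2))}
parts = split_leading_axis(tree)     # list of 2 pytrees
restored = stack_leading_axis(parts)  # same shapes as `tree`
```

Stacking the split pieces recovers the original leaves, which is the round trip the description refers to.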

Example usage:

>>> import jax.numpy as jnp
>>> import jax_sgmc.util.list_map as lm
>>>
>>> tree = {"a": jnp.array([0.0, 1.0]), "b": jnp.zeros((2, 2))}
>>>
>>> tree_list = lm.pytree_leaves_to_list(tree)
>>> print(tree_list)
[{'a': Array(0., dtype=float32), 'b': Array([0., 0.], dtype=float32)}, {'a': Array(1., dtype=float32), 'b': Array([0., 0.], dtype=float32)}]

Parameters:

pytree – A single pytree whose leaves all have the same size along the first dimension, i.e. equal leaf.shape[0].

Returns:

A list of leaf.shape[0] pytrees, each with the same tree structure as the input.
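The requirement on pytree, that every leaf has the same leaf.shape[0], can be checked before calling the function. The helper has_equal_leading_dim is a hypothetical name used only for this sketch and is not part of jax_sgmc; it also assumes that scalar leaves, which have no first dimension, should be rejected.

```python
import jax
import jax.numpy as jnp


def has_equal_leading_dim(tree):
    """Hypothetical helper: True if every leaf has the same leaf.shape[0]."""
    leaves = jax.tree_util.tree_leaves(tree)
    # Scalar leaves (ndim == 0) have no first dimension to split along.
    if any(jnp.ndim(leaf) == 0 for leaf in leaves):
        return False
    # All leading dimensions must agree.
    return len({jnp.shape(leaf)[0] for leaf in leaves}) <= 1


ok = has_equal_leading_dim({"a": jnp.zeros(2), "b": jnp.zeros((2, 3))})
bad = has_equal_leading_dim({"a": jnp.zeros(2), "b": jnp.zeros((3, 2))})
```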