jax_sgmc.util.pytree_leaves_to_list
- jax_sgmc.util.pytree_leaves_to_list(pytree)
Splits a pytree into a list of pytrees.
Splits every leaf of the pytree along the first dimension, thus undoing the
pytree_list_to_leaves() transformation.
Example usage:
>>> import jax.numpy as jnp
>>> import jax_sgmc.util.list_map as lm
>>>
>>> tree = {"a": jnp.array([0.0, 1.0]), "b": jnp.zeros((2, 2))}
>>>
>>> tree_list = lm.pytree_leaves_to_list(tree)
>>> print(tree_list)
[{'a': Array(0., dtype=float32), 'b': Array([0., 0.], dtype=float32)}, {'a': Array(1., dtype=float32), 'b': Array([0., 0.], dtype=float32)}]
- Parameters:
pytree – A single pytree whose leaves all have the same size along the first dimension (equal leaf.shape[0]).
- Returns:
A list of leaf.shape[0] pytrees, each with the same tree structure as pytree.
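To make the splitting concrete, here is a minimal sketch built only on jax.tree_util, assuming every leaf is a JAX array that shares the same leading dimension. The names leaves_to_list_sketch and list_to_leaves_sketch are hypothetical and not part of jax_sgmc, and this is not necessarily how the library implements pytree_leaves_to_list. The inverse helper only assumes that pytree_list_to_leaves stacks corresponding leaves along a new first axis, which is what "undoing" the split implies.

```python
import jax
import jax.numpy as jnp


def leaves_to_list_sketch(pytree):
    """Hypothetical sketch: split every leaf of `pytree` along axis 0."""
    leaves = jax.tree_util.tree_leaves(pytree)
    n = leaves[0].shape[0]
    # All leaves must share the same leading dimension, as the docs require.
    if any(leaf.shape[0] != n for leaf in leaves):
        raise ValueError("all leaves must have equal leaf.shape[0]")
    # Build one pytree per index i, keeping the original tree structure.
    return [jax.tree_util.tree_map(lambda leaf: leaf[i], pytree) for i in range(n)]


def list_to_leaves_sketch(tree_list):
    """Hypothetical inverse (assumption): stack matching leaves along a new axis 0."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *tree_list)


# Usage with the same tree as in the doctest above.
tree = {"a": jnp.array([0.0, 1.0]), "b": jnp.zeros((2, 2))}
parts = leaves_to_list_sketch(tree)  # list of 2 pytrees
restored = list_to_leaves_sketch(parts)  # same shapes and values as `tree`
```

Splitting with a Python list comprehension over tree_map keeps the tree structure intact for each element, and stacking with a multi-tree tree_map restores the original leading dimension, so the two helpers round-trip.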