jax_sgmc.util.pytree_list_to_leaves
- jax_sgmc.util.pytree_list_to_leaves(pytrees)
Transform a list of pytrees to allow pmap/vmap.
The trees must have the same tree structure and differ only in the values of their leaves. The trees may contain custom nodes, such as jax.tree_util.Partial, but those nodes must be equivalent. For example, the following comparison returns False because the two Partial objects are defined on different function objects, yet they are still equivalent because the functions perform the same computation:

>>> from jax.tree_util import Partial
>>>
>>> Partial(lambda x: x + 1) == Partial(lambda x: x + 1)
False
Example usage:
>>> import jax.numpy as jnp
>>> import jax_sgmc.util.list_map as lm
>>>
>>> tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
>>> tree_b = {"a": 1.0, "b": jnp.ones((2,))}
>>>
>>> concat_tree = lm.pytree_list_to_leaves([tree_a, tree_b])
>>> print(concat_tree)
{'a': Array([0., 1.], dtype=float32, weak_type=True), 'b': Array([[0., 0.], [1., 1.]], dtype=float32)}
- Parameters:
pytrees – A list of trees with the same tree structure and equally shaped leaves
- Returns:
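A tree with the same tree structure as the inputs, in which corresponding leaves are stacked along a new first dimension (as in the example above, where two leaves of shape (2,) become one leaf of shape (2, 2)).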
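The example output suggests that the returned tree can be passed directly to jax.vmap, which maps over the new leading axis of every leaf. The following is a minimal sketch of that idea in plain JAX. The helper stack_trees is a hypothetical stand-in written for this illustration, not the jax_sgmc implementation, and it assumes trees with identical structure (it does not handle the equivalence of custom nodes such as Partial). The function total is likewise only an illustrative per-tree computation.

```python
import jax
import jax.numpy as jnp


def stack_trees(pytrees):
    # Hypothetical stand-in for pytree_list_to_leaves: stack corresponding
    # leaves along a new leading axis. Assumes identical tree structures.
    return jax.tree_util.tree_map(
        lambda *leaves: jnp.stack([jnp.asarray(leaf) for leaf in leaves]),
        *pytrees,
    )


tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
tree_b = {"a": 1.0, "b": jnp.ones((2,))}

stacked = stack_trees([tree_a, tree_b])


def total(tree):
    # Operates on a single (unstacked) tree.
    return tree["a"] + jnp.sum(tree["b"])


# vmap maps `total` over the leading axis of every leaf, i.e. once per
# original tree in the list.
totals = jax.vmap(total)(stacked)
```

Here totals has one entry per input tree, so the function written for a single tree is evaluated for the whole list in one vectorized call.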