jax_sgmc.util.pytree_list_to_leaves

jax_sgmc.util.pytree_list_to_leaves(pytrees)

Transform a list of pytrees into a single pytree that can be mapped over with pmap/vmap.

The trees must have the same tree structure and may differ only in the values of their leaves. This means that the trees may contain custom nodes, such as jax.tree_util.Partial, but those nodes must compare equal. For example,

>>> from jax.tree_util import Partial
>>>
>>> Partial(lambda x: x + 1) == Partial(lambda x: x + 1)
False

because the two Partials wrap different function objects, even though both functions perform the same computation. Trees containing these two nodes therefore do not share the same tree structure.
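In JAX, the function wrapped by a Partial is stored as static node data rather than as a leaf, so two Partials only produce the same tree structure when they wrap the same function object. A minimal check with jax.tree_util.tree_structure; `add_one` is a hypothetical helper introduced only for this illustration:

```python
import jax
from jax.tree_util import Partial


def add_one(x):
    # Hypothetical helper used only for this illustration.
    return x + 1


# Two Partials wrapping the *same* function object share one tree structure.
same_a = Partial(add_one)
same_b = Partial(add_one)

# Two Partials wrapping distinct (if functionally identical) lambdas do not.
diff_a = Partial(lambda x: x + 1)
diff_b = Partial(lambda x: x + 1)

same_structure = (
    jax.tree_util.tree_structure(same_a) == jax.tree_util.tree_structure(same_b)
)
diff_structure = (
    jax.tree_util.tree_structure(diff_a) == jax.tree_util.tree_structure(diff_b)
)

print(same_structure)  # True
print(diff_structure)  # False
```

Defining the function once and reusing it in every tree is therefore a simple way to satisfy the equal-structure requirement.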

Example usage:

>>> import jax.numpy as jnp
>>> import jax_sgmc.util.list_map as lm
>>>
>>> tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
>>> tree_b = {"a": 1.0, "b": jnp.ones((2,))}
>>>
>>> concat_tree = lm.pytree_list_to_leaves([tree_a, tree_b])
>>> print(concat_tree)
{'a': Array([0., 1.], dtype=float32, weak_type=True), 'b': Array([[0., 0.],
       [1., 1.]], dtype=float32)}
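The stacked tree is what makes vmap and pmap applicable: both map over the leading axis of every leaf, so a function written for a single tree is applied to each input tree in turn. The sketch below assumes the function stacks corresponding leaves, as the printed output above suggests, and approximates it with jax.tree_util.tree_map and jnp.stack; `stack_pytrees` and `score` are hypothetical names, not part of jax_sgmc.

```python
import jax
import jax.numpy as jnp


def stack_pytrees(pytrees):
    """Hypothetical stand-in: stack corresponding leaves along a new axis 0."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)


tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
tree_b = {"a": 1.0, "b": jnp.ones((2,))}
stacked = stack_pytrees([tree_a, tree_b])


def score(tree):
    # Written for a single, unstacked tree: scalar "a", vector "b".
    return tree["a"] + jnp.sum(tree["b"])


# vmap maps over the leading axis of every leaf, i.e. over the original trees.
scores = jax.vmap(score)(stacked)
print(scores)  # [0. 3.]
```

In practice the stacked tree would come from lm.pytree_list_to_leaves. With jax.pmap, the leading dimension must additionally not exceed the number of available devices.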

Parameters:

pytrees – A list of trees with identical tree structure and equally shaped leaves.

Returns:

A tree with the same tree structure, in which corresponding leaves are stacked along a new leading dimension, i.e. each leaf has shape (len(pytrees), *original_leaf_shape).
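Because every leaf of the result carries one entry per input tree along its leading axis, indexing all leaves at position i recovers the i-th input tree. A minimal sketch using the stacked tree from the example, written out by hand; `take_tree` is a hypothetical helper, not a jax_sgmc function:

```python
import jax
import jax.numpy as jnp

# Stacked tree matching the example output above, written out by hand.
stacked = {"a": jnp.array([0.0, 1.0]), "b": jnp.array([[0.0, 0.0], [1.0, 1.0]])}

# The leading dimension of any leaf equals the number of input trees.
n_trees = jax.tree_util.tree_leaves(stacked)[0].shape[0]


def take_tree(stacked_tree, i):
    """Hypothetical helper: select the i-th original tree from a stacked tree."""
    return jax.tree_util.tree_map(lambda leaf: leaf[i], stacked_tree)


recovered = [take_tree(stacked, i) for i in range(n_trees)]
print(n_trees)            # 2
print(recovered[1]["b"])  # [1. 1.]
```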