jax_sgmc.util.list_vmap

jax_sgmc.util.list_vmap(fun)

Vectorizes (vmaps) a function over several pytrees that share the same tree structure.
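One way to picture this is to stack the corresponding leaves of all input pytrees along a new leading axis, apply jax.vmap once, and split the batched result back into a list. The block below is a minimal sketch under that assumption, not jax_sgmc's actual implementation; `list_vmap_sketch` and `squared_norm` are hypothetical names.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map


def list_vmap_sketch(fun):
  """Hypothetical stack-then-vmap sketch; not jax_sgmc's own code."""

  def vmapped(*pytrees):
    # Stack corresponding leaves of all inputs along a new axis 0.
    # This assumes identical tree structures and matching leaf shapes.
    stacked = tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)
    # Evaluate fun on every slice along axis 0 in one vectorized call.
    batched_out = jax.vmap(fun)(stacked)
    # Split the batched output into one result per input pytree.
    return [tree_map(lambda leaf: leaf[i], batched_out)
            for i in range(len(pytrees))]

  return vmapped


def squared_norm(pytree):
  # Hypothetical example function: sum of squares over all leaves.
  return sum(jnp.sum(leaf ** 2) for leaf in tree_leaves(pytree))


norms = list_vmap_sketch(squared_norm)(
    {"a": 1.0, "b": jnp.array([1.0, 2.0])},
    {"a": 0.0, "b": jnp.array([3.0, 4.0])},
)
print(float(norms[0]), float(norms[1]))  # 6.0 25.0
```

Because all inputs are processed in a single vmapped call, fun is traced once rather than once per pytree, which is the usual motivation for vmapping instead of looping in Python.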

Example usage:

>>> from jax.tree_util import tree_map
>>> import jax.numpy as jnp
>>> import jax_sgmc.util.list_map as lm
>>>
>>> tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
>>> tree_b = {"a": 1.0, "b": jnp.ones((2,))}
>>>
>>> @lm.list_vmap
... def tree_subtract(pytree):
...   return tree_map(jnp.subtract, pytree, tree_b)
...
>>> print(tree_subtract(tree_a, tree_b))
[{'a': Array(-1., dtype=float32, weak_type=True), 'b': Array([-1., -1.], dtype=float32)}, {'a': Array(0., dtype=float32, weak_type=True), 'b': Array([0., 0.], dtype=float32)}]

Parameters:

fun – Function accepting a single pytree as first argument.

Returns:

A vmapped function that accepts multiple pytree arguments with the same tree structure and returns a list containing the output of fun for each argument.
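The returned function expects all of its pytree arguments to share one tree structure. Assuming the leaves are combined by stacking, as in a stack-then-vmap approach, mismatched structures make that combination step fail, so a cheap pre-check with jax.tree_util.tree_structure can report the problem earlier. `structures_match` is a hypothetical helper, not part of jax_sgmc.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_structure


def structures_match(*pytrees):
  """Hypothetical helper: True if all pytrees share the first one's structure."""
  first = tree_structure(pytrees[0])
  return all(tree_structure(tree) == first for tree in pytrees[1:])


tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
tree_b = {"a": 1.0, "b": jnp.ones((2,))}
tree_c = {"a": 0.0}  # Missing key "b", so its tree structure differs.

print(structures_match(tree_a, tree_b))  # True
print(structures_match(tree_a, tree_c))  # False

# Stacking leaves across pytrees with different structures raises ValueError.
try:
  tree_map(lambda *leaves: jnp.stack(leaves), tree_a, tree_c)
  stacking_failed = False
except ValueError:
  stacking_failed = True
```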