jax_sgmc.util.list_vmap
- jax_sgmc.util.list_vmap(fun)
Vectorizes (vmaps) a function over several pytrees that share the same tree structure.
Example usage:
>>> from jax.tree_util import tree_map
>>> import jax.numpy as jnp
>>> import jax_sgmc.util.list_map as lm
>>>
>>> tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
>>> tree_b = {"a": 1.0, "b": jnp.ones((2,))}
>>>
>>> @lm.list_vmap
... def tree_subtract(pytree):
...   return tree_map(jnp.subtract, pytree, tree_b)
>>>
>>> print(tree_subtract(tree_a, tree_b))
[{'a': Array(-1., dtype=float32, weak_type=True), 'b': Array([-1., -1.], dtype=float32)}, {'a': Array(0., dtype=float32, weak_type=True), 'b': Array([0., 0.], dtype=float32)}]
- Parameters:
fun – Function accepting a single pytree as first argument.
- Returns:
A vmapped function that accepts multiple pytrees with the same tree structure as positional arguments and returns a list with one result per input pytree.
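To make "vmapping over similar pytrees" concrete, here is a minimal sketch under stated assumptions. It is not jax_sgmc's actual source: `list_vmap_sketch` is a hypothetical name, and the stack/vmap/unstack strategy is an assumed approach that reproduces the behavior shown in the example above. Corresponding leaves of the input pytrees are stacked along a new leading axis, `jax.vmap` maps `fun` over that axis, and the batched output is split back into a list of pytrees.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map


def list_vmap_sketch(fun):
    """Hypothetical stand-in for list_vmap: stack, vmap, then unstack."""

    def vmapped(*pytrees):
        # Stack corresponding leaves along a new leading axis: two scalars
        # become shape (2,), two (2,)-arrays become shape (2, 2).
        stacked = tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)
        # Apply `fun` to each slice along the leading axis in parallel.
        batched = jax.vmap(fun)(stacked)
        # Split the batched output back into one pytree per input pytree.
        return [tree_map(lambda leaf: leaf[i], batched) for i in range(len(pytrees))]

    return vmapped


tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
tree_b = {"a": 1.0, "b": jnp.ones((2,))}


@list_vmap_sketch
def tree_subtract(pytree):
    # Subtract tree_b leaf-wise from a single (unbatched) pytree.
    return tree_map(jnp.subtract, pytree, tree_b)


result = tree_subtract(tree_a, tree_b)
```

In this sketch, stacking also requires corresponding leaves to have matching shapes, so "the same tree structure" alone would not be sufficient; whether the library imposes the same constraint is not stated here.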