jax_sgmc.util.list_pmap

jax_sgmc.util.list_pmap(fun)

Maps a function in parallel, via jax.pmap, over several pytrees with similar tree structure.

Parameters:

fun – Function accepting a single pytree as first argument.

Returns:

A pmapped function that accepts multiple pytree arguments with the same tree structure.
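The behavior described above can be illustrated with a minimal, self-contained sketch. It assumes, without confirmation from this page, that the input pytrees are stacked leaf by leaf along a new leading axis, passed through jax.pmap, and then split back into a list with one result per input. The names list_pmap_sketch and tree_sum are hypothetical and not part of jax_sgmc. Because pmap needs one device per mapped tree, the example forces two virtual CPU devices through XLA_FLAGS, which only takes effect if it is set before JAX is imported and the default backend is the CPU.

```python
import os

# Demo-only assumption: expose two virtual CPU devices so that pmap can map
# over two trees on a CPU-only machine. Must be set before importing jax.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=2"
)

import functools

import jax
import jax.numpy as jnp


def list_pmap_sketch(fun):
    """Hypothetical sketch: pmap `fun` over several pytrees with the same structure."""

    @functools.wraps(fun)
    def pmapped(*pytrees):
        # Stack corresponding leaves along a new leading (device) axis.
        stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)
        # Each device receives one slice, i.e. one of the original pytrees.
        result = jax.pmap(fun)(stacked)
        # Split the stacked result back into one pytree per input tree.
        return [
            jax.tree_util.tree_map(lambda leaf, i=i: leaf[i], result)
            for i in range(len(pytrees))
        ]

    return pmapped


def tree_sum(tree):
    # Written for a single, unstacked pytree.
    return tree["a"] + jnp.sum(tree["b"])


tree_a = {"a": 0.0, "b": jnp.zeros((2,))}
tree_b = {"a": 1.0, "b": jnp.ones((2,))}

results = list_pmap_sketch(tree_sum)(tree_a, tree_b)
print(results)  # one result per input tree
```

With the real function, the equivalent call would be list_pmap(tree_sum)(tree_a, tree_b); the exact format of its return value should be checked against the jax_sgmc source rather than inferred from this sketch.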