jax_sgmc.util

Array()

Base class for JAX arrays.

tree_multiply(tree_a, tree_b)

Maps the elementwise product over two pytrees.
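A minimal sketch of the elementwise product. It assumes the operation is equivalent to mapping `jnp.multiply` over matching leaves with `jax.tree_util.tree_map`. It re-implements the idea for illustration and is not the library's source. `params` and `mask` are hypothetical example trees.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map


def tree_multiply(tree_a, tree_b):
    # Multiply corresponding leaves; both trees must share one tree structure.
    return tree_map(jnp.multiply, tree_a, tree_b)


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
mask = {"w": jnp.array([0.5, 0.5]), "b": jnp.array(2.0)}
prod = tree_multiply(params, mask)  # {"w": [0.5, 1.0], "b": 6.0}
```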

tree_add(tree_a, tree_b)

Maps the elementwise sum over two pytrees.
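A minimal sketch, assuming `tree_add` maps `jnp.add` over the matching leaves of two identically structured pytrees. This is an illustration, not the library's implementation. `position` and `step` are hypothetical names.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map


def tree_add(tree_a, tree_b):
    # Sum corresponding leaves of two pytrees with the same structure.
    return tree_map(jnp.add, tree_a, tree_b)


position = {"x": jnp.zeros(3), "y": jnp.array(1.0)}
step = {"x": jnp.ones(3), "y": jnp.array(0.5)}
new_position = tree_add(position, step)  # {"x": [1., 1., 1.], "y": 1.5}
```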

tree_scale(alpha, tree)

Scalar-pytree product via tree_map: every leaf of tree is multiplied by alpha.
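A minimal sketch of the scalar-pytree product. It assumes a plain `tree_map` that multiplies each leaf by `alpha`, and it is not taken from the library source. `grads` is a hypothetical example tree.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map


def tree_scale(alpha, tree):
    # Multiply every leaf by the same scalar alpha.
    return tree_map(lambda leaf: alpha * leaf, tree)


grads = {"w": jnp.array([2.0, -4.0]), "b": jnp.array(1.0)}
scaled = tree_scale(0.5, grads)  # {"w": [1.0, -2.0], "b": 0.5}
```

Together with `tree_add`, this can express an update such as x + alpha * g directly on nested parameter structures, without flattening them first.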

list_vmap(fun)

Vmaps a function over several pytrees that share the same tree structure.
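One way such a wrapper could work: stack the leaves of the input trees along a new leading axis, apply `jax.vmap`, then split the batched result back into a list. This is a sketch under assumptions. The calling convention (trees passed as separate positional arguments, results returned as a Python list) is assumed, not confirmed by this reference. `squared_norm` is a hypothetical per-tree function.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map


def list_vmap(fun):
    # Sketch: `fun` takes a single pytree; the wrapper takes many.
    def vmapped(*pytrees):
        # Stack matching leaves along a new leading (batch) axis.
        stacked = tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)
        result = jax.vmap(fun)(stacked)
        # Split the batched result back into one pytree per input tree.
        return [tree_map(lambda leaf: leaf[i], result) for i in range(len(pytrees))]
    return vmapped


@list_vmap
def squared_norm(tree):
    # Hypothetical per-tree function: sum of squares over all leaves.
    return sum(jnp.sum(leaf ** 2) for leaf in tree_leaves(tree))


tree_a = {"a": jnp.array(1.0), "b": jnp.ones(2)}
tree_b = {"a": jnp.array(2.0), "b": jnp.zeros(2)}
norms = squared_norm(tree_a, tree_b)  # [3.0, 4.0]
```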

list_pmap(fun)

Pmaps a function over several pytrees that share the same tree structure.
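The same stacking idea can be sketched with `jax.pmap`. Here the leading axis is distributed across devices, so the number of trees may not exceed `jax.local_device_count()`. As with the vmap sketch, the calling convention is an assumption, and `double` is a hypothetical per-device function. The usage creates exactly one tree per local device so it also runs on a single-CPU machine.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_map


def list_pmap(fun):
    # Sketch: one input tree per device; the stacked axis is mapped by pmap.
    def pmapped(*pytrees):
        stacked = tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)
        result = jax.pmap(fun)(stacked)
        # Unstack the per-device results into a list of pytrees.
        return [tree_map(lambda leaf: leaf[i], result) for i in range(len(pytrees))]
    return pmapped


@list_pmap
def double(tree):
    # Hypothetical per-device computation.
    return tree_map(lambda leaf: 2.0 * leaf, tree)


n_dev = jax.local_device_count()
trees = [{"w": jnp.full((3,), float(k))} for k in range(n_dev)]
doubled = double(*trees)
```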

pytree_list_to_leaves(pytrees)

Transforms a list of pytrees into a single pytree to allow pmap/vmap.
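A minimal sketch of one plausible transformation. It assumes the trees are combined by stacking matching leaves along a new leading axis of length `len(pytrees)`, which is the axis `vmap` or `pmap` would then map over. This illustrates the idea and does not reproduce the library's code. `trees` is hypothetical example data.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map


def pytree_list_to_leaves(pytrees):
    # Combine same-structured pytrees into one pytree whose leaves gain
    # a new leading axis indexing the original list.
    return tree_map(lambda *leaves: jnp.stack(leaves), *pytrees)


trees = [
    {"w": jnp.zeros(2), "b": jnp.array(0.0)},
    {"w": jnp.ones(2), "b": jnp.array(1.0)},
    {"w": jnp.full(2, 2.0), "b": jnp.array(2.0)},
]
stacked = pytree_list_to_leaves(trees)  # w: shape (3, 2), b: shape (3,)
```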

pytree_leaves_to_list(pytree)

Splits a pytree into a list of pytrees.
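A minimal sketch of the inverse operation. It assumes every leaf shares the same leading dimension, and that each list entry takes index `i` along that axis from every leaf. The example data is hypothetical, and the code is an illustration rather than the library's implementation.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_map


def pytree_leaves_to_list(pytree):
    # Length of the shared leading axis (assumed equal for all leaves).
    n = tree_leaves(pytree)[0].shape[0]
    # Slice index i out of every leaf to rebuild the i-th pytree.
    return [tree_map(lambda leaf: leaf[i], pytree) for i in range(n)]


stacked = {"w": jnp.arange(6.0).reshape(3, 2), "b": jnp.arange(3.0)}
trees = pytree_leaves_to_list(stacked)  # trees[1] == {"w": [2., 3.], "b": 1.}
```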