jax_sgmc.util

Array base class for JAX.

Maps elementwise product over two vectors.

Maps elementwise sum over pytrees.

Scalar-pytree product via tree_map.
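
The elementwise product, elementwise sum, and scalar product listed above can all be expressed with `jax.tree_util.tree_map`. The following is a minimal sketch under that assumption; the helper names `elementwise_product`, `elementwise_sum`, and `scalar_product` are hypothetical and are not taken from `jax_sgmc.util`, whose actual names and signatures may differ.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers for illustration; not the jax_sgmc.util API.
def elementwise_product(a, b):
    # Multiply matching leaves of two pytrees with the same structure.
    return jax.tree_util.tree_map(jnp.multiply, a, b)

def elementwise_sum(a, b):
    # Add matching leaves of two pytrees with the same structure.
    return jax.tree_util.tree_map(jnp.add, a, b)

def scalar_product(alpha, tree):
    # Scale every leaf of a pytree by the scalar alpha.
    return jax.tree_util.tree_map(lambda leaf: alpha * leaf, tree)

a = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
b = {"w": jnp.array([4.0, 5.0]), "b": jnp.array(0.5)}

prod = elementwise_product(a, b)
total = elementwise_sum(a, b)
scaled = scalar_product(2.0, a)
```

Because `tree_map` requires matching tree structures, passing two pytrees with different keys or nesting raises an error instead of silently broadcasting.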

Vmaps a function over similar pytrees.
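
One way to vmap a function over several pytrees of the same structure is to stack their leaves, apply `jax.vmap`, and unstack the result. The sketch below assumes that approach; `vmap_over_list` and `squared_norm` are hypothetical names used only for illustration and do not describe how `jax_sgmc.util` implements it.

```python
import jax
import jax.numpy as jnp

def vmap_over_list(fun):
    """Hypothetical wrapper: fun takes one pytree, the result takes a list."""
    def wrapped(trees):
        # Stack matching leaves along a new leading axis.
        stacked = jax.tree_util.tree_map(
            lambda *leaves: jnp.stack(leaves), *trees)
        # Apply fun to every slice along the leading axis.
        batched = jax.vmap(fun)(stacked)
        # Split the batched result back into one result per input pytree.
        return [jax.tree_util.tree_map(lambda leaf: leaf[i], batched)
                for i in range(len(trees))]
    return wrapped

def squared_norm(tree):
    # Sum of squares over all leaves of a single pytree.
    return sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(tree))

trees = [
    {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.0)},
    {"w": jnp.array([3.0, 4.0]), "b": jnp.array(1.0)},
]
norms = vmap_over_list(squared_norm)(trees)
```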

Pmaps a function over similar pytrees.
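
The same idea carries over to `jax.pmap`, with the constraint that the leading axis must not exceed the number of available devices. The sketch below is an assumption about typical usage rather than the library's implementation; it builds one pytree per local device so that it also runs on a single-device machine. The function `shift` is hypothetical.

```python
import jax
import jax.numpy as jnp

# One pytree per local device, so the mapped axis matches the device count.
n_dev = jax.local_device_count()
trees = [{"w": jnp.full((2,), float(i)), "b": jnp.array(float(i))}
         for i in range(n_dev)]

# Stack matching leaves along a new leading (device) axis.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)

def shift(tree):
    # Hypothetical per-device computation: add one to every leaf.
    return jax.tree_util.tree_map(lambda leaf: leaf + 1.0, tree)

# Each device processes one slice of the stacked pytree.
result = jax.pmap(shift)(stacked)
```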

Transforms a list of pytrees to allow pmap/vmap.
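
`jax.vmap` and `jax.pmap` map over a leading array axis, so a list of pytrees first has to become a single pytree whose leaves carry that axis. A minimal sketch, assuming all pytrees share the same structure and leaf shapes; `stack_pytrees` is a hypothetical name, not the function documented here.

```python
import jax
import jax.numpy as jnp

def stack_pytrees(trees):
    # Hypothetical helper: combine a list of pytrees into one pytree
    # whose leaves have a new leading axis of length len(trees).
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *trees)

trees = [
    {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)},
    {"w": jnp.array([3.0, 4.0]), "b": jnp.array(1.5)},
    {"w": jnp.array([5.0, 6.0]), "b": jnp.array(2.5)},
]
stacked = stack_pytrees(trees)
```

The result keeps the dictionary structure of the inputs; only the leaves gain a batch dimension.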

Splits a pytree into a list of pytrees.
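
Splitting is the inverse step: a pytree whose leaves share a leading axis is cut into one pytree per index along that axis. The sketch below assumes every leaf has the same leading dimension; `split_pytree` is a hypothetical name and the library's own function may behave differently.

```python
import jax
import jax.numpy as jnp

def split_pytree(stacked):
    # Hypothetical helper: flatten once, then rebuild one pytree per index
    # along the leading axis of the leaves.
    leaves, treedef = jax.tree_util.tree_flatten(stacked)
    n = leaves[0].shape[0]
    return [jax.tree_util.tree_unflatten(treedef, [leaf[i] for leaf in leaves])
            for i in range(n)]

stacked = {"w": jnp.arange(6.0).reshape(3, 2), "b": jnp.arange(3.0)}
pieces = split_pytree(stacked)
```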