jax_sgmc.data.core.tree_index
- jax_sgmc.data.core.tree_index(pytree, index)
Indexes every leaf of the tree along its first dimension.
- Parameters:
pytree (Any) – Tree to index with array-like leaves
index – Selects which slice to return
- Returns:
A tree with the same structure as pytree, in which each leaf has one dimension fewer because its first dimension has been indexed.
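The following is a minimal sketch of the behaviour described above. It is not the jax_sgmc implementation. It assumes that tree_index is equivalent to applying `leaf[index]` to every leaf with `jax.tree_util.tree_map`, and that `index` is an integer, which is the case where each leaf loses exactly one dimension. The name `tree_index_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_index_sketch(pytree, index):
    """Hypothetical stand-in for tree_index (assumed behaviour, not the
    library's code): select slice `index` along the first axis of every
    leaf while leaving the tree structure unchanged."""
    return jax.tree_util.tree_map(lambda leaf: leaf[index], pytree)


# A tree whose leaves share a leading dimension of size 3.
tree = {
    "x": jnp.arange(6).reshape(3, 2),    # shape (3, 2)
    "y": jnp.array([10.0, 20.0, 30.0]),  # shape (3,)
}

# With an integer index, each leaf drops its first dimension.
sliced = tree_index_sketch(tree, 1)
# sliced["x"] has shape (2,) and holds [2, 3]
# sliced["y"] has shape () and holds 20.0
```

The dictionary keys are preserved and only the leaves change. This matches the documented guarantee that the result has the same structure as pytree.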