jax_sgmc.data.core.tree_index

jax_sgmc.data.core.tree_index(pytree, index)

Indexes each leaf of the tree along its first dimension.

Parameters:
  • pytree (Any) – Tree to index with array-like leaves

  • index – Selects which slice to return

Returns:

A tree with the same structure as pytree, in which each leaf has one fewer dimension than the corresponding input leaf.
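
The behavior described above can be illustrated with a minimal sketch. It does not call jax_sgmc; it assumes the function is equivalent to mapping `leaf[index]` over all leaves with `jax.tree_util.tree_map`. The helper name `tree_index_sketch` and the example data are hypothetical.

```python
import jax
import jax.numpy as jnp


def tree_index_sketch(pytree, index):
    """Hypothetical stand-in: index every leaf along its first axis."""
    return jax.tree_util.tree_map(lambda leaf: leaf[index], pytree)


# A small pytree whose leaves share a leading dimension of size 4.
batch = {
    "x": jnp.arange(12.0).reshape(4, 3),  # shape (4, 3)
    "y": jnp.arange(4),                   # shape (4,)
}

# Select the slice at position 2 from every leaf.
sample = tree_index_sketch(batch, 2)

# The tree structure is unchanged; each leaf loses its leading axis.
print(jax.tree_util.tree_map(jnp.shape, sample))
```

With an integer index, the leaf `"x"` goes from shape `(4, 3)` to `(3,)` and the leaf `"y"` becomes a scalar, which matches the "one fewer dimension" description.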