jax_sgmc.util.tree_add

jax_sgmc.util.tree_add(tree_a, tree_b)

Computes the elementwise sum of two PyTrees, leaf by leaf.

Parameters:
  • tree_a – First PyTree

  • tree_b – Second PyTree with the same tree structure as tree_a

Return type:

Any

Returns:

A PyTree obtained by leaf-wise summation of tree_a and tree_b.
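The following is a minimal sketch of leaf-wise PyTree addition, assuming it can be expressed with jax.tree_util.tree_map and jnp.add. The function name tree_add_sketch is hypothetical and is not the library's implementation; it only illustrates the behavior described above on a parameter-like dictionary.

```python
import jax
import jax.numpy as jnp


def tree_add_sketch(tree_a, tree_b):
    # Hypothetical illustration, not jax_sgmc's source: tree_map walks the
    # structure of tree_a and pairs each leaf with the leaf at the same
    # position in tree_b. Mismatched tree structures raise an error.
    return jax.tree_util.tree_map(jnp.add, tree_a, tree_b)


# Two PyTrees (dicts of arrays) with identical structure.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}
update = {"w": jnp.full((2, 2), 0.5), "b": jnp.arange(2.0)}

new_params = tree_add_sketch(params, update)
# new_params["w"] is a 2x2 array of 1.5; new_params["b"] is [0., 1.]
```

Because the sum is taken per leaf, the result has the same tree structure as the inputs, which makes this kind of helper convenient for applying parameter updates in sampling or optimization loops.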