jax_sgmc.util.tree_multiply

jax_sgmc.util.tree_multiply(tree_a, tree_b)

Maps the element-wise product over the leaves of two pytrees.

Parameters:
  • tree_a – First pytree

  • tree_b – Second pytree; must have the same tree structure as tree_a, with leaves of matching shape

Return type:

Any

Returns:

A pytree with the same structure as the inputs, whose leaves are the element-wise products of the corresponding leaves of tree_a and tree_b.
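The behavior described above can be sketched as follows, assuming the function simply maps jnp.multiply over the matching leaves of both trees. The name tree_multiply_sketch is hypothetical and is used here instead of the library function, so this is an illustration of the semantics rather than the jax_sgmc implementation itself.

```python
import jax
import jax.numpy as jnp


def tree_multiply_sketch(tree_a, tree_b):
    """Hypothetical stand-in: multiply corresponding leaves of two pytrees."""
    # tree_map walks both trees in lockstep and applies jnp.multiply to each
    # pair of leaves; it raises an error if the tree structures differ.
    return jax.tree_util.tree_map(jnp.multiply, tree_a, tree_b)


# Two pytrees with identical structure and matching leaf shapes.
params = {"w": jnp.array([1.0, 2.0, 3.0]), "b": jnp.array(2.0)}
scale = {"w": jnp.array([0.5, 0.5, 2.0]), "b": jnp.array(3.0)}

product = tree_multiply_sketch(params, scale)
print(product["w"])
print(product["b"])
```

The result keeps the dictionary keys of the inputs; only the leaf values change.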