jax_sgmc.util.tree_scale

jax_sgmc.util.tree_scale(alpha, tree)

Scalar-PyTree product, computed with tree_map.

Parameters:
  • alpha (Array) – Scalar factor.

  • tree – Arbitrary PyTree.

Return type:

Any

Returns:

A PyTree with the same structure as tree, with every leaf multiplied by alpha.
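Because the function is described as a scalar-PyTree product built on tree_map, its behavior can be sketched with jax.tree_util.tree_map directly. The function tree_scale_sketch below is a hypothetical stand-in written for illustration. It is not the library's source, and it assumes that scaling means plain elementwise multiplication of each leaf by alpha.

```python
import jax.numpy as jnp
from jax import tree_util


def tree_scale_sketch(alpha, tree):
    """Hypothetical re-implementation: multiply every leaf of `tree` by `alpha`."""
    # tree_map applies the lambda to each leaf and rebuilds a tree
    # with the same structure (dict keys, list order, etc.).
    return tree_util.tree_map(lambda leaf: alpha * leaf, tree)


params = {"w": jnp.ones((2, 2)), "b": jnp.array([1.0, -2.0])}
scaled = tree_scale_sketch(0.5, params)
# scaled["w"] is a 2x2 array filled with 0.5; scaled["b"] is [0.5, -1.0]
```

The container structure is left untouched and only the leaves change, so the result can be passed anywhere the original PyTree was accepted, for example as a rescaled gradient or parameter update.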