# Source code for jax_sgmc.util.tree_util

# Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Defines types special to JAX or this library and pytree arithmetic helpers."""

from typing import Any, NamedTuple
from functools import partial

from jax import tree_util
from jax import flatten_util
import jax.numpy as jnp

Array = jnp.ndarray
PyTree = Any

class Tensor(NamedTuple):
  """Scalar, diagonal or full matrix used in products with pytrees.

  Instances are consumed by ``tensor_matmul``, which selects the kind of
  product from ``ndim``.

  Attributes:
    ndim: Kind of the tensor: 0 for a scalar factor, 1 for a diagonal matrix
      stored as a pytree (applied by an element-wise product), 2 for a full
      matrix applied to the flattened pytree
    tensor: Data of the tensor (scalar, pytree or matrix depending on ndim)

  """
  ndim: int  # 0: scalar, 1: diagonal (pytree), 2: full matrix
  tensor: PyTree

def tensor_matmul(matrix: Tensor, vector: PyTree):
  """Multiply a pytree by a matrix given in ``Tensor`` form.

  The kind of product is selected from ``matrix.ndim``: a scalar (0) scales
  every leaf, a diagonal matrix (1) is applied leaf-wise, and a full
  matrix (2) is applied to the flattened pytree.

  Args:
    matrix: Matrix in tensor format
    vector: PyTree, which is compatible to the tensor

  Returns:
    The product as a pytree.

  Raises:
    NotImplementedError: If ``matrix.ndim`` is not 0, 1 or 2.

  """
  if matrix.ndim == 0:
    # Scalar factor: scale every leaf.
    return tree_scale(matrix.tensor, vector)
  if matrix.ndim == 1:
    # Diagonal matrix stored as a pytree: element-wise product.
    return tree_multiply(matrix.tensor, vector)
  if matrix.ndim == 2:
    # Full matrix acting on the flattened pytree.
    return tree_matmul(matrix.tensor, vector)
  raise NotImplementedError(f"Cannot multiply matrix with dimension "
                            f"{matrix.ndim}")

def tree_multiply(tree_a: PyTree, tree_b: PyTree) -> PyTree:
  """Maps elementwise product over two pytrees.

  Args:
    tree_a: First pytree
    tree_b: Second pytree with the same tree structure as ``tree_a``; the
      leaves are combined with ``jnp.multiply``

  Returns:
    Returns a PyTree obtained by an element-wise product of all PyTree leaves.
  """
  return tree_util.tree_map(jnp.multiply, tree_a, tree_b)

def tree_scale(alpha: Array, tree: PyTree) -> PyTree:
  """Scalar-Pytree product via tree_map.

  Args:
    alpha: Scalar factor
    tree: Arbitrary PyTree

  Returns:
    Returns a PyTree with all leaves scaled by alpha.
  """
  return tree_util.tree_map(lambda leaf: alpha * leaf, tree)

def tree_add(tree_a: PyTree, tree_b: PyTree) -> PyTree:
  """Maps elementwise sum over PyTrees.

  Args:
    tree_a: First PyTree
    tree_b: Second PyTree with the same tree structure as ``tree_a``

  Returns:
    Returns a PyTree obtained by leaf-wise summation.
  """
  return tree_util.tree_map(lambda leaf_a, leaf_b: leaf_a + leaf_b,
                            tree_a, tree_b)

def tree_matmul(tree_mat: Array, tree_vec: PyTree):
  """Matrix tree product for LD on manifold.

  Arguments:
    tree_mat: Matrix to be multiplied with flattened tree. Expected to have
      shape (n, n), where n is the total number of elements in ``tree_vec``,
      so that the product can be un-flattened again.
    tree_vec: Tree representing vector

  Returns:
    Returns the un-flattened product of the matrix and the flattened tree.
  """
  # TODO: Redefine without need for flatten util
  vec_flat, unravel_fn = flatten_util.ravel_pytree(tree_vec)
  return unravel_fn(jnp.matmul(tree_mat, vec_flat))


def tree_dot(tree_a: PyTree, tree_b: PyTree):
  """Scalar product of two pytrees.

  Args:
    tree_a: First pytree
    tree_b: Second pytree with same tree structure and leaf shape as tree_a

  Returns:
    Returns a scalar, which is the sum of the element-wise product of all
    leaves.

  Raises:
    ValueError: If the two trees do not contain the same number of leaves.
  """
  leaves_a = tree_util.tree_leaves(tree_a)
  leaves_b = tree_util.tree_leaves(tree_b)
  # zip() silently truncates to the shorter input, which would drop
  # contributions and return a wrong result instead of failing.
  if len(leaves_a) != len(leaves_b):
    raise ValueError(
      f"Cannot compute the dot product of trees with {len(leaves_a)} and "
      f"{len(leaves_b)} leaves.")
  return sum(jnp.sum(jnp.multiply(a, b)) for a, b in zip(leaves_a, leaves_b))