Extend Adapted Quantities

Extension of Adaption Strategies

Each adaption strategy is expected to return three functions:

@adaption(quantity=SomeQuantity)
def some_adaption(minibatch_potential: Callable = None):
  ...
  return init_adaption, update_adaption, get_adaption

The decorator adaption() wraps all three functions to flatten pytrees to 1D-arrays and unflatten the results of get_adaption().

All arguments passed by position are expected to have the same shape as the sample pytree and are flattened to 1D-arrays. Arguments that should not be raveled must be passed by keyword.
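The positional/keyword rule can be illustrated with a small, library-independent sketch. The helpers ravel and ravel_positional below are hypothetical stand-ins written for this example; they are not the adaption() decorator itself, and nested dicts and lists of floats take the place of array pytrees.

```python
from functools import wraps


def ravel(tree):
    """Flatten a nested dict/list/tuple of numbers into a flat list.

    Toy stand-in for pytree raveling; dict keys are visited in sorted order.
    """
    if isinstance(tree, dict):
        return [x for key in sorted(tree) for x in ravel(tree[key])]
    if isinstance(tree, (list, tuple)):
        return [x for item in tree for x in ravel(item)]
    return [tree]


def ravel_positional(fn):
    """Hypothetical decorator: ravel positional arguments, pass keywords as-is."""
    @wraps(fn)
    def wrapped(*args, **kwargs):
        flat_args = [ravel(arg) for arg in args]
        return fn(*flat_args, **kwargs)
    return wrapped


@ravel_positional
def show_arguments(sample, momentum, parameter=0.5):
    # sample and momentum arrive flattened, parameter is untouched.
    return sample, momentum, parameter


sample = {"w": [1.0, 2.0], "b": 3.0}
momentum = {"w": [0.1, 0.2], "b": 0.3}
flat_sample, flat_momentum, parameter = show_arguments(
    sample, momentum, parameter=0.9)
```

In this sketch, sample and momentum reach the function as flat lists with one entry per latent variable, while parameter passes through unchanged because it was given by keyword.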

  1. init_adaption()

This function initializes the state of the adaption as well as the ravel and unravel functions. Therefore, it must accept at least one positional argument with the shape of the sample pytree.

...
def init_adaption(sample, momentum, parameter=0.5):
  ...

In the example above, the sample and the momentum are 1D-arrays with size equal to the latent variable count. The argument parameter is a scalar; it is not raveled as long as it is passed by keyword.

  2. update_adaption()

This function updates the state of the adaption. It must accept at least one positional argument, the state, even if the adaption is stateless.

...
# This is a stateless adaption
def update_adaption(state, *args, **kwargs):
  del state, args, kwargs
  return None

If the factory function of the adaption strategy is called with a potential function as a keyword argument (minibatch_potential = some_fun), then update_adaption() is additionally called with the keyword arguments flat_potential and mini_batch. flat_potential is a wrapped version of the original potential function and can be called with the raveled sample.
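A minimal sketch of an update function that uses these two keyword arguments might look as follows. It assumes, beyond what is stated above, that the raveled sample is the second positional argument and that flat_potential is called as flat_potential(flat_sample, mini_batch); toy_flat_potential and the (count, mean) state layout are hypothetical and exist only for this example.

```python
def update_adaption(state, sample, *args,
                    flat_potential=None, mini_batch=None, **kwargs):
    """Track a running mean of the potential at the visited samples."""
    del args, kwargs
    if flat_potential is None:
        # The factory was called without a potential: nothing to evaluate.
        return state
    count, mean = state
    # Assumed call signature: raveled sample first, then the mini-batch.
    value = flat_potential(sample, mini_batch)
    count += 1
    mean += (value - mean) / count
    return count, mean


def toy_flat_potential(flat_sample, mini_batch):
    # Hypothetical stand-in for the wrapped potential function.
    total = sum((x - y) ** 2 for x in flat_sample for y in mini_batch)
    return total / len(mini_batch)


state = (0, 0.0)
state = update_adaption(state, [1.0, 2.0],
                        flat_potential=toy_flat_potential,
                        mini_batch=[1.0, 3.0])
state = update_adaption(state, [1.0, 1.0],
                        flat_potential=toy_flat_potential,
                        mini_batch=[1.0, 3.0])
unchanged = update_adaption(state, [5.0, 5.0])
```

Defaulting both keyword arguments to None keeps the same function usable when the factory is called without a potential.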

  3. get_adaption()

This function calculates the desired quantity. Its argument signature is the same as that of update_adaption(). It should return a flat tuple of values in the right order, such that the quantity of type NamedTuple can be created by passing the values as positional arguments. For example, if the quantity has the fields q = namedtuple('q', ['a', 'b', 'c']), the get function should look like

...
def get_adaption(state, *args, **kwargs):
  ...
  return a, b, c

The returned arrays can have dimension 1 or 2.
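The ordering requirement can be checked with plain Python. In the sketch below, the dict-based state layout is a hypothetical choice for the example, and the quantity is built by hand with q(*values) to mimic construction from positional arguments.

```python
from collections import namedtuple

q = namedtuple("q", ["a", "b", "c"])


def get_adaption(state, *args, **kwargs):
    del args, kwargs
    # Hypothetical state layout: a dict holding the three values.
    a = state["a"]
    b = state["b"]
    c = state["c"]
    # The order of the returned values must match the field order of q.
    return a, b, c


state = {"c": [3.0], "a": [1.0], "b": [2.0]}
quantity = q(*get_adaption(state))
```

Returning the values in a different order, for example b, a, c, would not raise an error but would silently assign b to the field a.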

Extension of Quantities

Quantities simplify the use of an adaption inside an integrator or solver, because all adapted values are bundled in a single object with named fields.

For example, adapting a manifold \(G\) for SGLD requires the calculation of \(G^{-1},\ G^{-\frac{1}{2}},\ \text{and}\ \Gamma\). If get_adaption() returns all three quantities in the order

@adaption(quantity=Manifold)
def some_adaption():
  ...
  def get_adaption(state, *args, **kwargs):
    ...
    return g_inv, g_inv_sqrt, gamma

the manifold should be defined as follows, where the correct order of the field names is important:

class Manifold(NamedTuple):
  g_inv: PyTree
  g_inv_sqrt: PyTree
  gamma: PyTree

The wrapped get_adaption() then returns only a single value of type Manifold.

init_adaption, update_adaption, get_adaption = some_adaption()
...
G = get_adaption(state, ...)
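As a concrete but simplified illustration, the sketch below fills a Manifold for a constant diagonal \(G\). Several things here are assumptions made for the example: the state stores only the diagonal of \(G\), plain lists stand in for arrays, Any replaces the PyTree annotation, and \(\Gamma\) is set to zero on the assumption that it consists of derivatives of \(G^{-1}\) with respect to the sample, which vanish for a constant \(G\). The Manifold is also constructed by hand here, whereas in the library the decorator takes care of this step.

```python
import math
from typing import Any, NamedTuple


class Manifold(NamedTuple):
    g_inv: Any
    g_inv_sqrt: Any
    gamma: Any


def get_adaption(state, *args, **kwargs):
    del args, kwargs
    diag_g = state  # hypothetical: the state is the diagonal of G
    g_inv = [1.0 / g for g in diag_g]
    g_inv_sqrt = [1.0 / math.sqrt(g) for g in diag_g]
    # Assumed to vanish because G does not depend on the sample.
    gamma = [0.0 for _ in diag_g]
    # Same order as the fields of Manifold.
    return g_inv, g_inv_sqrt, gamma


G = Manifold(*get_adaption([4.0, 16.0]))
step_scale = G.g_inv_sqrt  # a solver can access the fields by name
```

Because the fields are named, an integrator can read G.g_inv, G.g_inv_sqrt and G.gamma directly instead of relying on tuple positions.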