jax_sgmc.adaption

Adapt quantities to the local or global geometry.

Adaption Strategies

Mass Matrix

class jax_sgmc.adaption.MassMatrix(inv: Any, sqrt: Any)

Mass matrix for HMC.

inv

Inverse of the mass matrix.

Type:

Any

sqrt

Square root of the mass matrix.

Type:

Any

jax_sgmc.adaption.mass_matrix(diagonal=True, burn_in=1000)

Adapt the mass matrix for HMC.

Parameters:
  • diagonal – If True, restrict the adapted matrix to be diagonal.

  • burn_in – Number of initial steps during which the matrix is updated.

Returns:

Returns an adaption strategy for the mass matrix.
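To illustrate what a diagonal mass-matrix adaption could compute, here is a NumPy sketch. It assumes, as in common HMC warm-up schemes, that the inverse mass matrix is estimated as the running sample variance over the first `burn_in` samples (Welford's algorithm) and frozen afterwards. `DiagMassMatrix`, `welford_update` and `adapt_diag_mass_matrix` are hypothetical names standing in for `MassMatrix` and the strategy returned by `mass_matrix`; this is not the library's implementation.

```python
from typing import NamedTuple

import numpy as np


class DiagMassMatrix(NamedTuple):
    """Hypothetical diagonal analogue of MassMatrix(inv, sqrt)."""
    inv: np.ndarray   # diagonal of M^{-1}
    sqrt: np.ndarray  # diagonal of M^{1/2}


class WelfordState(NamedTuple):
    count: int
    mean: np.ndarray
    m2: np.ndarray  # running sum of squared deviations from the mean


def welford_update(state: WelfordState, sample: np.ndarray) -> WelfordState:
    """Numerically stable online update of the mean and variance."""
    count = state.count + 1
    delta = sample - state.mean
    mean = state.mean + delta / count
    m2 = state.m2 + delta * (sample - mean)
    return WelfordState(count, mean, m2)


def adapt_diag_mass_matrix(samples: np.ndarray, burn_in: int = 1000,
                           eps: float = 1e-8) -> DiagMassMatrix:
    """Estimate M^{-1} = diag(Var[theta]) from the first `burn_in` samples."""
    dim = samples.shape[1]
    state = WelfordState(0, np.zeros(dim), np.zeros(dim))
    for sample in samples[:burn_in]:  # samples after burn-in leave M unchanged
        state = welford_update(state, sample)
    variance = state.m2 / max(state.count - 1, 1) + eps  # eps keeps M finite
    return DiagMassMatrix(inv=variance, sqrt=1.0 / np.sqrt(variance))


# Usage: two parameters with standard deviations 1 and 10.
rng = np.random.default_rng(0)
samples = rng.normal(size=(5000, 2)) * np.array([1.0, 10.0])
mass = adapt_diag_mass_matrix(samples, burn_in=1000)
print(mass.inv)  # close to the true variances [1, 100]
```

Estimating the inverse mass matrix from the posterior variance gives directions with large variance a light mass, so the momentum moves them further per step.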

Manifold

class jax_sgmc.adaption.Manifold(g_inv: Tensor, sqrt_g_inv: Tensor, gamma: Tensor)

Adapted manifold.

g_inv

Inverse of the manifold's metric tensor.

Type:

jax_sgmc.util.tree_util.Tensor

sqrt_g_inv

Square root of the inverse metric tensor.

Type:

jax_sgmc.util.tree_util.Tensor

gamma

Correction term that accounts for the positional dependence of the manifold.

Type:

jax_sgmc.util.tree_util.Tensor

jax_sgmc.adaption.rms_prop()

RMSprop adaption.

Adapt a diagonal matrix to the local curvature using only the stochastic gradient [1].

Return type:

Callable

Returns:

Returns an RMSprop adaption strategy.

[1] https://arxiv.org/abs/1512.07666
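The preconditioner in [1] keeps an exponential moving average V of squared stochastic gradients and uses G = diag(1 / (λ + √V)) as the inverse metric. The NumPy sketch below maps this onto the fields of `Manifold`. The names `DiagManifold`, `rms_prop_update`, `rms_prop_manifold`, `psgld_step`, `alpha` and `lmbd` are hypothetical, and `gamma` is set to zero on the assumption that the correction term is neglected, a common simplification in practice; it is not the library's implementation.

```python
from typing import NamedTuple

import numpy as np


class DiagManifold(NamedTuple):
    """Hypothetical diagonal analogue of Manifold(g_inv, sqrt_g_inv, gamma)."""
    g_inv: np.ndarray
    sqrt_g_inv: np.ndarray
    gamma: np.ndarray


def rms_prop_update(v: np.ndarray, grad: np.ndarray, alpha: float = 0.99) -> np.ndarray:
    """Exponential moving average of the squared stochastic gradient."""
    return alpha * v + (1.0 - alpha) * grad * grad


def rms_prop_manifold(v: np.ndarray, lmbd: float = 1e-5) -> DiagManifold:
    """Diagonal preconditioner G = 1 / (lmbd + sqrt(v)); lmbd avoids division by zero."""
    g_inv = 1.0 / (lmbd + np.sqrt(v))
    # Assumption: the correction term gamma is neglected (set to zero).
    return DiagManifold(g_inv=g_inv, sqrt_g_inv=np.sqrt(g_inv), gamma=np.zeros_like(v))


def psgld_step(theta, grad_log_post, manifold, step_size, rng):
    """Preconditioned Langevin step: drift (G * grad + gamma) / 2, noise N(0, step_size * G)."""
    drift = 0.5 * step_size * (manifold.g_inv * grad_log_post + manifold.gamma)
    noise = np.sqrt(step_size) * manifold.sqrt_g_inv * rng.normal(size=theta.shape)
    return theta + drift + noise


# Usage: the coordinate with the larger gradient gets the smaller preconditioner.
v = rms_prop_update(np.zeros(2), grad=np.array([2.0, 0.01]))
manifold = rms_prop_manifold(v)
theta = psgld_step(np.zeros(2), np.array([2.0, 0.01]), manifold,
                   step_size=1e-3, rng=np.random.default_rng(0))
```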

Noise Model

class jax_sgmc.adaption.NoiseModel(cb_diff_sqrt: Any, b_sqrt: Any)

Approximation of the gradient noise.

cb_diff_sqrt

Square root of the difference between the friction term and the noise model.

Type:

Any

b_sqrt

Square root of the noise model.

Type:

Any

jax_sgmc.adaption.fisher_information(minibatch_potential=None, diagonal=True)

Adapt the empirical Fisher information.

Use the empirical Fisher information, approximated as described in [1], as the noise model for SGHMC.

Return type:

Callable

Returns:

Returns a noise model approximation strategy.

[1] https://arxiv.org/abs/1206.6380
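The following NumPy sketch shows one way to estimate the empirical Fisher information from a minibatch: as the sample covariance of per-example gradients, in the spirit of [1] (any dataset-size scaling is omitted). It then fills a diagonal noise model with b = scale · I and C − b for a friction C. `DiagNoiseModel`, `empirical_fisher`, `diag_noise_model`, `friction` and `scale` are hypothetical; `scale` stands in for whatever step-size-dependent factor a sampler applies, and nothing here reproduces the library's implementation.

```python
from typing import NamedTuple

import numpy as np


class DiagNoiseModel(NamedTuple):
    """Hypothetical diagonal analogue of NoiseModel(cb_diff_sqrt, b_sqrt)."""
    cb_diff_sqrt: np.ndarray
    b_sqrt: np.ndarray


def empirical_fisher(per_example_grads: np.ndarray, diagonal: bool = True) -> np.ndarray:
    """Sample covariance of per-example gradients: (n, dim) -> (dim,) or (dim, dim)."""
    n = per_example_grads.shape[0]
    centered = per_example_grads - per_example_grads.mean(axis=0)
    if diagonal:
        return np.sum(centered ** 2, axis=0) / (n - 1)
    return centered.T @ centered / (n - 1)


def diag_noise_model(fisher_diag: np.ndarray, friction: float, scale: float) -> DiagNoiseModel:
    """Split the friction C into a modelled noise part B and the remainder C - B."""
    b = scale * fisher_diag
    cb_diff = friction - b
    if np.any(cb_diff < 0):
        raise ValueError("friction must be at least the noise estimate (C - B >= 0)")
    return DiagNoiseModel(cb_diff_sqrt=np.sqrt(cb_diff), b_sqrt=np.sqrt(b))


grads = np.array([[1.0, 2.0], [3.0, 2.0], [5.0, 2.0]])  # 3 examples, 2 parameters
fisher = empirical_fisher(grads)  # -> [4., 0.]
noise = diag_noise_model(fisher, friction=10.0, scale=0.5)
```

By construction, cb_diff_sqrt² + b_sqrt² recovers the friction C, which is why the friction must dominate the noise estimate.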

Developer Information

class jax_sgmc.adaption.AdaptionState(state: Any, ravel_fn: Callable, unravel_fn: Callable, flat_potential: Callable)

Extended adaption state returned by the adaption decorator.

In addition to the adaption state, this tuple stores functions to ravel and unravel the parameter and gradient pytrees.

state

State of the adaption strategy.

Type:

Any

ravel_fn

Function, wrapped in a JAX Partial, that flattens the pytree into a 1D array.

Type:

Callable

unravel_fn

Function, wrapped in a JAX Partial, that restores the pytree from the 1D array.

Type:

Callable

flat_potential

Potential function on the flattened pytree.

Type:

Callable

jax_sgmc.adaption.get_unravel_fn(tree)

Compute the function that maps a 1D array back to the parameter pytree.

Parameters:

tree (Any) – Parameter pytree

Returns:

Returns the unravel function wrapped in a JAX Partial object, so that it can be passed as a valid argument to JAX-transformed functions.
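To illustrate what ravelling and unravelling a parameter pytree involves, here is a pure NumPy sketch restricted to non-empty, possibly nested dicts of arrays. `ravel_dict` is a hypothetical helper; in JAX, `jax.flatten_util.ravel_pytree` does the same job for arbitrary pytrees. Unlike the function returned by `get_unravel_fn`, the closure here is a plain Python function, not a JAX Partial, so it could not be passed directly as an argument to a jit-compiled function.

```python
import numpy as np


def ravel_dict(tree: dict):
    """Flatten a nested dict of arrays into a 1D array.

    Returns the flat array and a function mapping any array of the same
    length back to a dict with the original structure and leaf shapes.
    """
    leaves = []  # (key path, shape) of each leaf, in flattening order
    chunks = []

    def collect(node, path):
        if isinstance(node, dict):
            for key in sorted(node):  # sorted keys give a deterministic order
                collect(node[key], path + (key,))
        else:
            array = np.asarray(node)
            leaves.append((path, array.shape))
            chunks.append(array.ravel())

    collect(tree, ())
    flat = np.concatenate(chunks)

    def unravel(vector: np.ndarray) -> dict:
        out, offset = {}, 0
        for path, shape in leaves:
            size = int(np.prod(shape))  # np.prod(()) == 1 for scalar leaves
            leaf = np.reshape(vector[offset:offset + size], shape)
            offset += size
            node = out
            for key in path[:-1]:
                node = node.setdefault(key, {})
            node[path[-1]] = leaf
        return out

    return flat, unravel


params = {"w": np.arange(6.0).reshape(2, 3), "b": np.array([7.0, 8.0]),
          "nested": {"s": np.array(9.0)}}
flat, unravel = ravel_dict(params)
print(flat)  # [7. 8. 9. 0. 1. 2. 3. 4. 5.]
```

Because `unravel` only records shapes and offsets, it can also map an updated flat vector (for example, parameters after a sampler step) back into the original structure.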

jax_sgmc.adaption.adaption(quantity=<class 'tuple'>)

Decorator to make adaption strategies operate on 1D arrays.

Positional arguments are flattened while keyword arguments are passed unchanged.

Parameters:

quantity (namedtuple) – Namedtuple specifying which fields are returned by get_adaption.
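A minimal sketch of the decorator pattern described above, assuming dict-shaped pytrees: positional arguments are flattened into 1D arrays before the wrapped strategy runs, keyword arguments pass through unchanged, and the result is packed into a namedtuple that plays the role of `quantity`. `flatten_positional`, `ravel`, `DiagQuantity` and `toy_diag_quantity` are hypothetical names, not part of jax_sgmc, and the toy quantity only demonstrates the data flow.

```python
import functools
from typing import NamedTuple

import numpy as np


class DiagQuantity(NamedTuple):
    """Hypothetical namedtuple playing the role of the `quantity` argument."""
    inv: np.ndarray
    sqrt: np.ndarray


def ravel(tree: dict) -> np.ndarray:
    """Flatten a dict of arrays into 1D, using sorted keys for a fixed order."""
    return np.concatenate([np.ravel(tree[key]) for key in sorted(tree)])


def flatten_positional(fn):
    """Ravel every positional argument; leave keyword arguments untouched."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        return fn(*(ravel(arg) for arg in args), **kwargs)
    return wrapper


@flatten_positional
def toy_diag_quantity(grad: np.ndarray, eps: float = 1e-8) -> DiagQuantity:
    # `grad` arrives as a 1D array; `eps` is a keyword and is passed as-is.
    second_moment = grad * grad + eps
    return DiagQuantity(inv=1.0 / second_moment, sqrt=np.sqrt(second_moment))


grads = {"w": np.array([[1.0, 2.0]]), "b": np.array([3.0])}
quantity = toy_diag_quantity(grads, eps=0.0)
print(quantity.sqrt)  # [3. 1. 2.]  (keys sorted: "b" before "w")
```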