jax_sgmc.adaption
Adapt quantities to the local or global geometry.
Adaption Strategies
Mass Matrix
Manifold
- class jax_sgmc.adaption.Manifold(g_inv: Tensor, sqrt_g_inv: Tensor, gamma: Tensor)[source]
Adapted manifold.
- g_inv
Inverse of the manifold metric.
- Type:
jax_sgmc.util.tree_util.Tensor
- sqrt_g_inv
Square root of the inverse manifold metric.
- Type:
jax_sgmc.util.tree_util.Tensor
- gamma
Diffusion to correct for positional dependence of manifold.
- Type:
jax_sgmc.util.tree_util.Tensor
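The sketch below shows one way the three fields could enter a sampler: a single Langevin step on a Riemannian manifold. It makes several assumptions. The metric is diagonal and stored as 1D arrays. The update has the form θ' = θ + ε (G⁻¹ ∇log p(θ) + Γ) + √(2ε) √G⁻¹ ξ. The names `Manifold` (a NumPy stand-in) and `manifold_langevin_step` are hypothetical and not part of jax_sgmc. The library's own solvers may combine the fields differently.

```python
from typing import NamedTuple

import numpy as np


class Manifold(NamedTuple):
    """Hypothetical NumPy stand-in for jax_sgmc.adaption.Manifold (diagonal case)."""
    g_inv: np.ndarray       # diagonal of the inverse metric G^{-1}
    sqrt_g_inv: np.ndarray  # elementwise square root of that diagonal
    gamma: np.ndarray       # correction for the position dependence of G


def manifold_langevin_step(theta, grad_log_prob, manifold, step_size, noise):
    """One Langevin step preconditioned by a diagonal metric (assumed form).

    theta' = theta + eps * (G^{-1} grad log p + gamma)
             + sqrt(2 eps) * sqrt(G^{-1}) * noise
    """
    drift = manifold.g_inv * grad_log_prob + manifold.gamma
    diffusion = np.sqrt(2.0 * step_size) * manifold.sqrt_g_inv * noise
    return theta + step_size * drift + diffusion


# Usage: a constant metric does not depend on theta, so gamma is zero.
g_inv = np.array([1.0, 0.5])
manifold = Manifold(g_inv=g_inv, sqrt_g_inv=np.sqrt(g_inv), gamma=np.zeros(2))
rng = np.random.default_rng(0)
new_theta = manifold_langevin_step(
    np.zeros(2), np.array([1.0, 2.0]), manifold, 0.1, rng.standard_normal(2)
)
```

The step keeps `sqrt_g_inv` separate from `g_inv` so the square root is computed once per adaption, not once per step.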
Noise Model
- class jax_sgmc.adaption.NoiseModel(cb_diff_sqrt: Any, b_sqrt: Any)[source]
Approximation of the gradient noise.
- cb_diff_sqrt
Square root of the difference between the friction term and the noise model.
- Type:
Any
- b_sqrt
Square root of the noise model.
- Type:
Any
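The fields relate to the SGHMC update of Chen et al. (2014). There, the injected noise is scaled by C - B̂, the friction minus the estimated gradient noise. This means the combined noise matches the friction. The sketch below assumes diagonal C and B̂ stored as 1D arrays and an identity mass matrix. It also clips C - B̂ at zero, which is an assumption of this sketch. `make_noise_model` and `sghmc_step` are hypothetical helpers, and `NoiseModel` is a NumPy stand-in rather than the library class.

```python
from typing import NamedTuple

import numpy as np


class NoiseModel(NamedTuple):
    """Hypothetical NumPy stand-in for jax_sgmc.adaption.NoiseModel (diagonal case)."""
    cb_diff_sqrt: np.ndarray  # sqrt(C - B_hat)
    b_sqrt: np.ndarray        # sqrt(B_hat)


def make_noise_model(friction, b_hat):
    """Build the square-root factors; C - B_hat is clipped at zero (assumption)."""
    return NoiseModel(
        cb_diff_sqrt=np.sqrt(np.maximum(friction - b_hat, 0.0)),
        b_sqrt=np.sqrt(b_hat),
    )


def sghmc_step(theta, momentum, grad_potential, friction, noise_model, step_size, noise):
    """One SGHMC step with identity mass matrix (assumed form)."""
    theta = theta + step_size * momentum
    # The friction acts on the old momentum; injected noise uses sqrt(C - B_hat).
    momentum = (
        momentum
        - step_size * grad_potential(theta)
        - step_size * friction * momentum
        + np.sqrt(2.0 * step_size) * noise_model.cb_diff_sqrt * noise
    )
    return theta, momentum


# Usage: standard normal target, so the potential gradient is theta itself.
friction = np.array([1.0, 1.0])
noise_model = make_noise_model(friction, b_hat=np.array([0.36, 0.5]))
rng = np.random.default_rng(0)
theta, momentum = np.zeros(2), np.zeros(2)
for _ in range(10):
    theta, momentum = sghmc_step(theta, momentum, lambda t: t, friction,
                                 noise_model, 0.1, rng.standard_normal(2))
```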
- jax_sgmc.adaption.fisher_information(minibatch_potential=None, diagonal=True)[source]
Adapts the empirical Fisher information.
Uses the empirical Fisher information as a noise model for SGHMC. The empirical Fisher information is approximated according to [1].
- Return type:
- Returns:
A noise model approximation strategy.
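The sketch below shows how an empirical Fisher information can be estimated from a minibatch of per-sample gradients. It assumes the centered form, which is the sample covariance of the gradients with a 1 / (n - 1) normalization. Whether this matches the approximation of [1] used by `fisher_information` is not confirmed here. `empirical_fisher` is a hypothetical helper, and its `diagonal` flag only mimics the parameter of the same name. The mapping from this estimate to the noise model fields is not shown.

```python
import numpy as np


def empirical_fisher(per_sample_grads, diagonal=True):
    """Empirical Fisher information from per-sample gradients (assumed centered form).

    per_sample_grads has shape (batch_size, num_params).
    """
    grads = np.asarray(per_sample_grads, dtype=float)
    centered = grads - grads.mean(axis=0)
    n = grads.shape[0]
    if diagonal:
        # Per-parameter variances only: one entry per parameter.
        return np.sum(centered ** 2, axis=0) / (n - 1)
    # Full (num_params, num_params) covariance matrix.
    return centered.T @ centered / (n - 1)


grads = [[1.0, 2.0], [3.0, 4.0]]
fisher_diag = empirical_fisher(grads)
fisher_full = empirical_fisher(grads, diagonal=False)
```

The diagonal variant needs memory linear in the number of parameters, whereas the full matrix grows quadratically.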
Developer Information
- class jax_sgmc.adaption.AdaptionState(state: Any, ravel_fn: Callable, unravel_fn: Callable, flat_potential: Callable)[source]
Extended adaption state returned by the adaption decorator.
In addition to the adaption state, this tuple stores functions to ravel and unravel the parameter and gradient pytrees.
- state
State of the adaption strategy.
- Type:
Any
- ravel_fn
Jax Partial function that transforms a pytree into a 1D array.
- Type:
Callable
- unravel_fn
Jax Partial function that undoes the ravelling of a pytree.
- Type:
Callable
- flat_potential
Potential function on the flattened pytree.
- Type:
Callable
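The sketch below shows how the four fields fit together. It assumes a flat dict of NumPy arrays in place of a general pytree. In JAX, `jax.flatten_util.ravel_pytree` provides the ravel and unravel pair for arbitrary pytrees. `make_ravel_fns` and the NumPy `AdaptionState` stand-in are hypothetical, and the library wraps its functions as jax Partial objects.

```python
from typing import Any, Callable, Dict, NamedTuple, Tuple

import numpy as np


class AdaptionState(NamedTuple):
    """Hypothetical NumPy stand-in for jax_sgmc.adaption.AdaptionState."""
    state: Any
    ravel_fn: Callable
    unravel_fn: Callable
    flat_potential: Callable


def make_ravel_fns(template: Dict[str, np.ndarray]) -> Tuple[Callable, Callable]:
    """Ravel/unravel for a flat dict of arrays, using sorted key order."""
    keys = sorted(template)
    shapes = [np.shape(template[k]) for k in keys]
    sizes = [int(np.prod(s)) for s in shapes]

    def ravel_fn(tree):
        return np.concatenate([np.ravel(tree[k]) for k in keys])

    def unravel_fn(flat):
        out, start = {}, 0
        for k, shape, size in zip(keys, shapes, sizes):
            out[k] = np.reshape(flat[start:start + size], shape)
            start += size
        return out

    return ravel_fn, unravel_fn


def potential(params):
    """Example potential on the pytree: half the sum of squares."""
    return 0.5 * sum(np.sum(v ** 2) for v in params.values())


params = {"w": np.ones((2, 2)), "b": np.array([3.0])}
ravel_fn, unravel_fn = make_ravel_fns(params)
adaption_state = AdaptionState(
    state=None,
    ravel_fn=ravel_fn,
    unravel_fn=unravel_fn,
    # The flat potential unravels its 1D input before evaluating the potential.
    flat_potential=lambda flat: potential(unravel_fn(flat)),
)
```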
- jax_sgmc.adaption.get_unravel_fn(tree)[source]
Calculates the unravel function.
- Parameters:
tree (Any) – Parameter pytree.
- Returns:
A jax Partial object, so that the function can be passed as a valid argument (e.g. to jit-compiled functions).
- jax_sgmc.adaption.adaption(quantity=<class 'tuple'>)[source]
Decorator to make adaption strategies operate on 1D arrays.
Positional arguments are flattened while keyword arguments are passed unchanged.
- Parameters:
quantity (namedtuple) – Namedtuple specifying which fields are returned by get_adaption.
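The sketch below is a simplified analogue of the decorator's documented behaviour. Positional pytree arguments are flattened to 1D arrays, while keyword arguments pass through unchanged. The 1D results are then mapped back and packed into the `quantity` namedtuple. It makes several assumptions. Pytrees are flat dicts of NumPy arrays. All positional arguments share the first one's structure. The strategy returns a tuple of 1D arrays. `adaption_sketch`, `DiagonalPreconditioner` and `inverse_magnitude_preconditioner` are hypothetical names. The real decorator's interface (including `get_adaption`) is not reproduced here.

```python
import functools
from typing import NamedTuple

import numpy as np


class DiagonalPreconditioner(NamedTuple):
    """Hypothetical quantity with a single field."""
    g_inv: dict


def _ravel(tree, keys):
    return np.concatenate([np.ravel(tree[k]) for k in keys])


def _unravel(flat, keys, shapes):
    out, start = {}, 0
    for k, shape in zip(keys, shapes):
        size = int(np.prod(shape))
        out[k] = np.reshape(flat[start:start + size], shape)
        start += size
    return out


def adaption_sketch(quantity):
    """Hypothetical, simplified analogue of the adaption decorator."""
    def decorator(strategy):
        @functools.wraps(strategy)
        def wrapped(*args, **kwargs):
            keys = sorted(args[0])
            shapes = [np.shape(args[0][k]) for k in keys]
            # Positional arguments are flattened to 1D arrays.
            flat_args = [_ravel(a, keys) for a in args]
            # Keyword arguments are forwarded unchanged.
            flat_results = strategy(*flat_args, **kwargs)
            # Each 1D result is mapped back and packed into the namedtuple.
            return quantity(*[_unravel(r, keys, shapes) for r in flat_results])
        return wrapped
    return decorator


@adaption_sketch(quantity=DiagonalPreconditioner)
def inverse_magnitude_preconditioner(grad, eps=1e-8):
    """Example strategy on 1D arrays: inverse magnitude of the gradient."""
    return (1.0 / (np.abs(grad) + eps),)


result = inverse_magnitude_preconditioner(
    {"w": np.array([[2.0, 4.0]]), "b": np.array([0.5])}, eps=0.0
)
```

The strategy body only sees 1D arrays. This is the point of the decorator: a strategy can be written once and reused for any parameter structure.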