jax_sgmc.integrator
Overview
Defines integrators which form the core of the solvers.
Integrators
- jax_sgmc.integrator.obabo(potential_fn, batch_fn, steps=10, friction=1.0, const_mass=None)
Initializes the OBABO integration scheme.
The OBABO integration scheme is reversible even when using stochastic gradients and provides second-order accuracy [1].
[1] https://arxiv.org/abs/2102.01691
- Parameters:
potential_fn (StochasticPotential) – Likelihood and prior applied over a minibatch of data.
batch_fn (Tuple[Any, GetBatchFunction, Callable[[], None]]) – Function to draw a mini-batch of reference data.
steps (Array) – Number of integration steps.
friction (Array) – Controls the impact of the momentum from the previous step.
const_mass (Optional[Any]) – Mass matrix used if no matrix is adapted. Must have the same tree structure as the sample.
- Return type:
- Returns:
Returns a function running the OBABO integrator for T steps.
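The letters in OBABO name the sub-steps of one integration step. O is an Ornstein-Uhlenbeck half step that damps the momentum and refreshes it with noise. B is a half kick of the momentum with the gradient. A is a full drift of the positions. The following is a minimal sketch under stated assumptions, not the library's implementation. The names obabo_step and run_obabo are hypothetical, the gradient is a plain function rather than a StochasticPotential, the mass is a scalar and the temperature is 1. The mapping of friction to the O-step damping factor a = exp(-friction * h / 2) is also an assumption.

```python
import jax
import jax.numpy as jnp


def obabo_step(key, q, p, grad_u, step_size, friction, mass=1.0):
    """One OBABO step with a scalar mass and unit temperature (sketch)."""
    # Assumed parameterization: damping factor of each O half step.
    a = jnp.exp(-friction * step_size / 2.0)
    noise_scale = jnp.sqrt((1.0 - a**2) * mass)
    key_o1, key_o2 = jax.random.split(key)

    # O: Ornstein-Uhlenbeck half step (partial momentum refresh).
    p = a * p + noise_scale * jax.random.normal(key_o1, jnp.shape(p))
    # B: half kick with the (possibly stochastic) potential gradient.
    p = p - 0.5 * step_size * grad_u(q)
    # A: full drift of the positions.
    q = q + step_size * p / mass
    # B: second half kick at the new positions.
    p = p - 0.5 * step_size * grad_u(q)
    # O: second Ornstein-Uhlenbeck half step.
    p = a * p + noise_scale * jax.random.normal(key_o2, jnp.shape(p))
    return q, p


def run_obabo(key, q, p, grad_u, steps, step_size, friction):
    """Applies `steps` OBABO steps with one fresh key per step."""
    def body(carry, step_key):
        q, p = carry
        return obabo_step(step_key, q, p, grad_u, step_size, friction), None

    (q, p), _ = jax.lax.scan(body, (q, p), jax.random.split(key, steps))
    return q, p


# Standard normal target: U(q) = q**2 / 2, hence grad U(q) = q.
grad_u = lambda q: q

# friction = 0: the O steps are the identity and BAB is velocity Verlet.
q0 = jnp.array(1.0, dtype=jnp.float32)
p0 = jnp.array(0.5, dtype=jnp.float32)
q1, p1 = run_obabo(jax.random.PRNGKey(0), q0, p0, grad_u, 100, 0.1, 0.0)
# Flip the momentum and integrate again to retrace the trajectory.
q2, p2 = run_obabo(jax.random.PRNGKey(1), q1, -p1, grad_u, 100, 0.1, 0.0)

# friction > 0: 2000 independent chains (one per array element) sample the target.
zeros = jnp.zeros(2000, dtype=jnp.float32)
samples, _ = run_obabo(jax.random.PRNGKey(2), zeros, zeros, grad_u, 300, 0.1, 1.0)
```

With zero friction, integrating forward, negating the momentum and integrating again recovers the starting point up to round-off. With positive friction, the sample variance across chains should be close to the target variance of 1.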
- jax_sgmc.integrator.reversible_leapfrog(potential_fn, batch_fn, steps=10, friction=0.25, const_mass=None)
Initializes a reversible leapfrog integrator.
AMAGOLD requires a reversible leapfrog integrator with a half step at the beginning and at the end.
- Parameters:
potential_fn (StochasticPotential) – Likelihood and prior applied over a minibatch of data.
batch_fn (Tuple[Any, GetBatchFunction, Callable[[], None]]) – Function to draw a mini-batch of reference data.
steps (int) – Number of intermediate leapfrog steps.
friction (float or Array) – Decay of momentum to counteract the noise induced by stochastic gradients.
const_mass (Optional[Any]) – Mass matrix to be used when no mass matrix is adapted.
- Return type:
- Returns:
Returns a function running the time-reversible leapfrog integrator for T steps.
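A leapfrog scheme with half steps at the beginning and end can be sketched in its position form. It makes a half step of the positions, alternates full momentum and full position steps, and closes with another half step of the positions. This is a hypothetical, simplified sketch. It assumes that the half steps act on the positions, it uses a scalar mass, and it omits the friction and gradient noise that reversible_leapfrog takes as inputs. The name reversible_leapfrog_sketch is illustrative only.

```python
import jax
import jax.numpy as jnp


def reversible_leapfrog_sketch(q, p, grad_u, steps, step_size, mass=1.0):
    """Leapfrog with position half steps at the beginning and end (sketch).

    Friction and gradient noise are omitted; the mass is a scalar.
    """
    # Half step of the positions at the beginning.
    q = q + 0.5 * step_size * p / mass

    def body(_, carry):
        q, p = carry
        # Full momentum step at the current (intermediate) positions.
        p = p - step_size * grad_u(q)
        # Full position step.
        q = q + step_size * p / mass
        return q, p

    # steps - 1 interleaved full steps, then the last momentum step.
    q, p = jax.lax.fori_loop(0, steps - 1, body, (q, p))
    p = p - step_size * grad_u(q)
    # Half step of the positions at the end.
    q = q + 0.5 * step_size * p / mass
    return q, p


grad_u = lambda q: q  # U(q) = q**2 / 2
q0 = jnp.array(1.0, dtype=jnp.float32)
p0 = jnp.array(0.5, dtype=jnp.float32)
q1, p1 = reversible_leapfrog_sketch(q0, p0, grad_u, steps=50, step_size=0.1)
# Flip the momentum and integrate again to undo the trajectory.
q2, p2 = reversible_leapfrog_sketch(q1, -p1, grad_u, steps=50, step_size=0.1)
```

Each step is symmetric, so negating the momentum and integrating again returns to the initial positions with the negated initial momentum. This is the reversibility property that AMAGOLD relies on.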
- jax_sgmc.integrator.friction_leapfrog(potential_fn, batch_fn, steps=10, friction=0.25, const_mass=None, noise_model=None)
Initializes the original SGHMC leapfrog integrator.
Implements the original SGHMC scheme from [1].
[1] https://arxiv.org/pdf/1402.4102.pdf
- Parameters:
potential_fn (StochasticPotential) – Likelihood and prior applied over a minibatch of data.
batch_fn (Tuple[Any, GetBatchFunction, Callable[[], None]]) – Function to draw a mini-batch of reference data.
steps (int) – Number of intermediate leapfrog steps.
friction (Array) – Decay of momentum to counteract the noise induced by stochastic gradients.
const_mass (Optional[Any]) – Mass matrix of the Hamiltonian process.
noise_model – Stateless adaptation of the noise (e.g., via the empirical Fisher information).
- Return type:
- Returns:
Returns a function running the non-conservative leapfrog integrator for T steps.
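In [1], one SGHMC update with step size h, mass M, friction C and gradient-noise estimate B_hat first moves the positions, theta <- theta + h M^-1 r. It then updates the momentum, r <- r - h grad U~(theta) - h C M^-1 r + N(0, 2 (C - B_hat) h), where the friction term uses the previous momentum. The sketch below implements this update for a scalar mass. Treating friction as C and setting B_hat = 0 when no noise model is given are assumptions. sghmc_step, stoch_grad and run_chains are hypothetical stand-ins for the library's potential and batch functions.

```python
import jax
import jax.numpy as jnp


def sghmc_step(key, q, p, stoch_grad, step_size, friction, mass=1.0, b_hat=0.0):
    """One SGHMC update in the form of Chen et al. (2014) (sketch)."""
    key_grad, key_noise = jax.random.split(key)
    # The friction term uses the momentum from the previous step.
    p_old = p
    # Position update: theta <- theta + h M^-1 r.
    q = q + step_size * p_old / mass
    # Injected noise N(0, 2 (C - B_hat) h); b_hat = 0 ignores gradient noise.
    noise_std = jnp.sqrt(2.0 * (friction - b_hat) * step_size)
    noise = noise_std * jax.random.normal(key_noise, jnp.shape(p_old))
    # Momentum update with stochastic gradient, friction C M^-1 r and noise.
    p = (p_old - step_size * stoch_grad(key_grad, q)
         - step_size * friction * p_old / mass + noise)
    return q, p


# Hypothetical stochastic gradient of U(q) = q**2 / 2 imitating mini-batch noise.
def stoch_grad(key, q):
    return q + 0.5 * jax.random.normal(key, jnp.shape(q))


def run_chains(key, n_chains=4000, steps=500, step_size=0.1, friction=1.0):
    """Runs independent chains (one per array element) from q = 3, r = 0."""
    q0 = jnp.full((n_chains,), 3.0, dtype=jnp.float32)
    p0 = jnp.zeros((n_chains,), dtype=jnp.float32)

    def body(carry, step_key):
        q, p = carry
        return sghmc_step(step_key, q, p, stoch_grad, step_size, friction), None

    (q, _), _ = jax.lax.scan(body, (q0, p0), jax.random.split(key, steps))
    return q


samples = run_chains(jax.random.PRNGKey(0))
```

The friction term dissipates the energy that the injected noise and the gradient noise add. After a burn-in, the chains should therefore have roughly the target variance of 1. The unmodeled gradient noise inflates it slightly.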
- jax_sgmc.integrator.langevin_diffusion(potential_fn, batch_fn, adaption=None)
Initializes the Langevin diffusion integrator.
- Parameters:
potential_fn (StochasticPotential) – Likelihood and prior applied over a minibatch of data.
batch_fn (Tuple[Any, GetBatchFunction, Callable[[], None]]) – Function to draw a mini-batch of reference data.
adaption (Optional[Callable]) – Adaptation of the manifold for faster inference.
- Return type:
- Returns:
Returns a tuple consisting of init_fn, update_fn, and get_fn. The init_fn takes the arguments:
key: Initial PRNGKey
adaption_kwargs: Additional arguments to determine the initial manifold state
batch_kwargs: Arguments that determine the state of the random data chain
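The pattern of returning an init_fn, an update_fn and a get_fn can be illustrated with plain stochastic gradient Langevin dynamics: x <- x - h grad U~(x) + sqrt(2 h) xi with xi ~ N(0, I). This is a hypothetical sketch. make_sgld and SGLDState are illustrative names, and there is no manifold adaptation and no data state. The function signatures are also simpler than those of langevin_diffusion.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SGLDState(NamedTuple):
    """Hypothetical minimal state without adaptation or data state."""
    latent_variables: jax.Array
    key: jax.Array


def make_sgld(stoch_grad, step_size):
    """Returns a hypothetical (init_fn, update_fn, get_fn) trio for SGLD."""

    def init_fn(key, latent_variables):
        return SGLDState(latent_variables=latent_variables, key=key)

    def update_fn(state):
        key, key_grad, key_noise = jax.random.split(state.key, 3)
        x = state.latent_variables
        # Euler-Maruyama step of dx = -grad U(x) dt + sqrt(2) dW.
        noise = jnp.sqrt(2.0 * step_size) * jax.random.normal(key_noise, jnp.shape(x))
        x = x - step_size * stoch_grad(key_grad, x) + noise
        return SGLDState(latent_variables=x, key=key)

    def get_fn(state):
        return state.latent_variables

    return init_fn, update_fn, get_fn


# Noisy gradient of U(x) = x**2 / 2, imitating mini-batch noise (hypothetical).
def stoch_grad(key, x):
    return x + 0.1 * jax.random.normal(key, jnp.shape(x))


init_fn, update_fn, get_fn = make_sgld(stoch_grad, step_size=0.05)
# 4000 independent chains, all starting at zero.
state = init_fn(jax.random.PRNGKey(0), jnp.zeros(4000, dtype=jnp.float32))
state, _ = jax.lax.scan(lambda s, _: (update_fn(s), None), state, None, length=300)
samples = get_fn(state)
```

Keeping the state in a separate value lets update_fn stay a pure function that jax.lax.scan or jax.jit can transform. get_fn then extracts only the quantity of interest.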
Integrator States
- class jax_sgmc.integrator.ObaboState(positions: Any, momentum: Any, potential: Array, model_state: Any, data_state: CacheState, key: Array, kinetic_energy_start: Array, kinetic_energy_end: Array)
State of the OBABO integrator.
- positions
Latent variables in the Hamiltonian dynamics formulation
- Type:
Any
- momentum
Momentum with the same shape as the latent variables
- Type:
Any
- model_state
State of the model in the likelihood
- Type:
Any
- data_state
State of the random data function
- class jax_sgmc.integrator.LeapfrogState(positions: Any, momentum: Any, potential: Array, model_state: Any, data_state: CacheState, key: Array, extra_fields: Optional[Any] = None)
State of the reversible and friction leapfrog integrator.
- positions
Latent variables in the Hamiltonian dynamics formulation
- Type:
Any
- momentum
Momentum with the same shape as the latent variables
- Type:
Any
- model_state
State of the model in the likelihood
- Type:
Any
- data_state
State of the random data function
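The momentum has the same shape as the latent variables, so it can be created leaf by leaf from the positions. The reduced state below is a hypothetical stand-in for LeapfrogState that keeps only positions, momentum and key. It is written as a NamedTuple, which JAX treats as a PyTree, so the whole state can pass through jax.jit. Whether the library's state classes are NamedTuples is not stated here.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class ToyLeapfrogState(NamedTuple):
    """Hypothetical reduced state; not the library's LeapfrogState."""
    positions: Any
    momentum: Any
    key: jax.Array


def init_toy_state(key, positions):
    # Zero momentum with the same tree structure, shapes and dtypes as positions.
    momentum = jax.tree_util.tree_map(jnp.zeros_like, positions)
    return ToyLeapfrogState(positions=positions, momentum=momentum, key=key)


def drift(state, step_size):
    # Leaf-wise position update q <- q + h p for an arbitrary PyTree.
    positions = jax.tree_util.tree_map(
        lambda q, p: q + step_size * p, state.positions, state.momentum)
    return state._replace(positions=positions)


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros((2,))}
state = init_toy_state(jax.random.PRNGKey(0), params)

# A NamedTuple is a PyTree, so the whole state can be passed through jax.jit.
unit_momentum = jax.tree_util.tree_map(jnp.ones_like, state.positions)
moved = jax.jit(drift)(state._replace(momentum=unit_momentum), 0.5)
```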
- class jax_sgmc.integrator.LangevinState(latent_variables: Any, model_state: Any, key: Array, adapt_state: Any, data_state: CacheState, potential: Array, variance: Array)
State of the langevin diffusion integrator.
- latent_variables
Current latent variables
- Type:
Any
- adapt_state
Contains quantities, such as the momentum, used for adaptation
- Type:
Any
- data_state
State of the reference data cache
- model_state
Variables not considered during inference
- Type:
Any
Utility
- jax_sgmc.integrator.random_tree(key, a)
Builds a tree shaped like a, in which all leaves are normally distributed.
- Parameters:
key – PRNGKey
a – PyTree defining the shape of the output
- Returns:
Tree shaped like a with normally distributed leaves.
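A function with this behavior can be sketched in three steps: flatten the tree, draw one standard normal array per leaf with an independent key, and unflatten. random_tree_sketch is a hypothetical name, and this is not necessarily how random_tree is implemented. In particular, the key-splitting order and the output dtype are assumptions, and here every leaf uses JAX's default floating-point dtype.

```python
import jax
import jax.numpy as jnp


def random_tree_sketch(key, a):
    """Returns a PyTree shaped like `a` whose leaves are standard normal."""
    leaves, treedef = jax.tree_util.tree_flatten(a)
    # One independent key per leaf, so equally shaped leaves get different draws.
    keys = jax.random.split(key, len(leaves))
    # Leaves use the default floating-point dtype, whatever the dtype in `a`.
    draws = [jax.random.normal(k, jnp.shape(leaf)) for k, leaf in zip(keys, leaves)]
    return jax.tree_util.tree_unflatten(treedef, draws)


tree = {"x": jnp.zeros(5), "y": [jnp.zeros((2, 3)), 1.0]}
noise = random_tree_sketch(jax.random.PRNGKey(0), tree)
```

A tree like this is what an integrator needs, for example, to draw a momentum or an injected-noise term with the same structure as the positions.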