jax_sgmc.integrator

Overview

Defines integrators which form the core of the solvers.

Integrators

jax_sgmc.integrator.obabo(potential_fn, batch_fn, steps=10, friction=1.0, const_mass=None)

Initializes the OBABO integration scheme.

The OBABO integration scheme is reversible even when using stochastic gradients and provides second-order accuracy [1].

[1] https://arxiv.org/abs/2102.01691

Parameters:
  • potential_fn (StochasticPotential) – Likelihood and prior applied over a minibatch of data

  • batch_fn (Tuple[Any, GetBatchFunction, Callable[[], None]]) – Function to draw a mini-batch of reference data

  • steps (Array) – Number of integration steps.

  • friction (Array) – Controls the impact of the momentum from the previous step

  • const_mass (Optional[Any]) – Mass matrix used if no mass matrix is adapted. Must have the same tree structure as the sample.

Return type:

Tuple[Callable, Callable, Callable]

Returns:

Returns a tuple of three functions used to initialize and run the OBABO integrator for the given number of steps.
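The OBABO splitting described above (a momentum refresh O, a half kick B, a full drift A, another half kick B and a final refresh O) can be sketched as follows. This is a minimal full-gradient illustration for a scalar or diagonal mass, written under our own assumptions: the names obabo_step and run_obabo are hypothetical, and the O step is taken as the exact Ornstein-Uhlenbeck solution over half a step with decay exp(-friction * step_size / 2). It is not jax_sgmc's implementation and ignores mini-batching, model state and mass adaption.

```python
import jax
import jax.numpy as jnp


def obabo_step(q, p, key, grad_u, step_size, friction, mass):
    """One OBABO step for a scalar or diagonal mass (illustrative sketch)."""
    # Assumed O step: exact OU solution over half a step of length step_size / 2.
    decay = jnp.exp(-friction * step_size / 2)
    noise_scale = jnp.sqrt((1.0 - decay ** 2) * mass)
    key_o1, key_o2 = jax.random.split(key)

    # O: partial momentum refresh.
    p = decay * p + noise_scale * jax.random.normal(key_o1, p.shape)
    # B: half kick with the (possibly stochastic) gradient.
    p = p - 0.5 * step_size * grad_u(q)
    # A: full drift of the positions.
    q = q + step_size * p / mass
    # B: second half kick at the new positions.
    p = p - 0.5 * step_size * grad_u(q)
    # O: second partial momentum refresh.
    p = decay * p + noise_scale * jax.random.normal(key_o2, p.shape)
    return q, p


def run_obabo(q, p, key, grad_u, step_size, friction, mass, steps=10):
    """Apply `steps` OBABO steps, using one fresh key per step."""
    def body(carry, step_key):
        q, p = carry
        return obabo_step(q, p, step_key, grad_u, step_size, friction, mass), None

    (q, p), _ = jax.lax.scan(body, (q, p), jax.random.split(key, steps))
    return q, p
```

With friction=0 both O steps become the identity and the scheme reduces to velocity Verlet, so integrating forward, flipping the momentum and integrating again returns to the starting point.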

jax_sgmc.integrator.reversible_leapfrog(potential_fn, batch_fn, steps=10, friction=0.25, const_mass=None)

Initializes a reversible leapfrog integrator.

AMAGOLD requires a reversible leapfrog integrator with a half step at the beginning and at the end.

Parameters:
  • potential_fn (StochasticPotential) – Likelihood and prior applied over a minibatch of data

  • batch_fn (Tuple[Any, GetBatchFunction, Callable[[], None]]) – Function to draw a mini-batch of reference data

  • steps (int) – Number of intermediate leapfrog steps

  • friction (Union[float, jax.Array]) – Decay of momentum to counteract induced noise due to stochastic gradients

  • const_mass (Optional[Any]) – Mass matrix to be used when no mass matrix is adapted

Return type:

Tuple[Callable, Callable, Callable]

Returns:

Returns a tuple of three functions used to initialize and run the reversible leapfrog integrator for the given number of steps.
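A frictionless sketch of a leapfrog integrator with momentum half steps at the beginning and at the end, together with the Metropolis-Hastings acceptance probability computed from the change in total energy (the quantity an MH-correction step needs). The function names, the full-gradient setting and the omission of friction are assumptions made for illustration; AMAGOLD's friction handling and energy accounting are not reproduced here.

```python
import jax
import jax.numpy as jnp


def reversible_leapfrog_sketch(q, p, grad_u, step_size, steps=10, mass=1.0):
    """Leapfrog with momentum half steps at the start and end (no friction)."""
    # Opening half step for the momentum.
    p = p - 0.5 * step_size * grad_u(q)
    for i in range(steps):
        # Full drift of the positions.
        q = q + step_size * p / mass
        # Full kicks between drifts, closing half kick after the last drift.
        scale = 1.0 if i < steps - 1 else 0.5
        p = p - scale * step_size * grad_u(q)
    return q, p


def mh_accept_prob(u, q0, p0, q1, p1, mass=1.0):
    """MH acceptance probability min(1, exp(H_old - H_new))."""
    def hamiltonian(q, p):
        # Potential plus kinetic energy for a scalar or diagonal mass.
        return u(q) + 0.5 * jnp.sum(p ** 2 / mass)

    return jnp.minimum(1.0, jnp.exp(hamiltonian(q0, p0) - hamiltonian(q1, p1)))
```

Because the update is time-reversible, flipping the final momentum and integrating again recovers the initial state, which is what allows the energy difference to be used in an MH correction.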

jax_sgmc.integrator.friction_leapfrog(potential_fn, batch_fn, steps=10, friction=0.25, const_mass=None, noise_model=None)

Initializes the original SGHMC leapfrog integrator.

Implements the SGHMC scheme originally proposed in [1].

[1] https://arxiv.org/pdf/1402.4102.pdf

Parameters:
  • potential_fn (StochasticPotential) – Likelihood and prior applied over a minibatch of data

  • batch_fn (Tuple[Any, GetBatchFunction, Callable[[], None]]) – Function to draw a mini-batch of reference data

  • steps (int) – Number of intermediate leapfrog steps

  • friction (Array) – Decay of momentum to counteract induced noise due to stochastic gradients

  • const_mass (Optional[Any]) – Mass matrix of the Hamiltonian process

  • noise_model – Stateless adaption of the noise (e.g. via the empirical Fisher information)

Return type:

Tuple[Callable, Callable, Callable]

Returns:

Returns a tuple of three functions used to initialize and run the non-conservative leapfrog integrator for the given number of steps.
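The SGHMC update of [1], without the optional momentum resampling, can be sketched as below: positions drift with the current momentum, then the momentum receives the gradient, a friction term and injected Gaussian noise with variance 2 (C - B̂) step_size. Here friction stands for the friction term C of [1], b_hat for the noise estimate B̂ (zero when no noise model is used) and mass is a scalar or diagonal mass; these correspondences to jax_sgmc's arguments are assumptions, and sghmc_steps is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def sghmc_steps(q, r, key, grad_u, step_size, friction, mass=1.0, steps=10, b_hat=0.0):
    """Run `steps` SGHMC updates (illustrative sketch, no MH correction)."""
    # Std of the injected noise; requires friction >= b_hat, otherwise NaN.
    noise_std = jnp.sqrt(2.0 * (friction - b_hat) * step_size)

    def body(carry, step_key):
        q, r = carry
        # Drift with the momentum from the previous step.
        q = q + step_size * r / mass
        noise = noise_std * jax.random.normal(step_key, jnp.shape(r))
        # Gradient at the new position, friction on the old momentum, plus noise.
        r = r - step_size * grad_u(q) - step_size * friction * r / mass + noise
        return (q, r), None

    (q, r), _ = jax.lax.scan(body, (q, r), jax.random.split(key, steps))
    return q, r
```

For a standard normal target and a small step size, the marginal variances of positions and momenta stay close to one, since the friction term removes the energy that the injected noise adds.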

jax_sgmc.integrator.langevin_diffusion(potential_fn, batch_fn, adaption=None)

Initializes the Langevin diffusion integrator.

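Parameters:
  • potential_fn – Likelihood and prior applied over a minibatch of data

  • batch_fn – Function to draw a mini-batch of reference data

  • adaption – Optional adaption of the manifold; its initial state is determined by the adaption_kwargs passed to init_fn
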
Return type:

Tuple[Callable, Callable, Callable]

Returns:

Returns a tuple consisting of init_fn, update_fn and get_fn. The init_fn takes the arguments:

  • key: Initial PRNGKey

  • adaption_kwargs: Additional arguments to determine the initial manifold state

  • batch_kwargs: Determine the state of the random data chain
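The init_fn/update_fn/get_fn triple is a functional pattern: the state is passed in and returned explicitly. The sketch below reproduces that pattern for unadjusted Langevin dynamics with full gradients. The names simple_langevin_diffusion and SimpleLangevinState, the simplified (sample, key) signature of init_fn and the Euler-Maruyama update are assumptions for illustration; adaption, mini-batching and model state are left out.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class SimpleLangevinState(NamedTuple):
    # Hypothetical, reduced version of LangevinState (no adaption or data state).
    latent_variables: Any
    key: Any


def simple_langevin_diffusion(potential_fn):
    """Hypothetical init/update/get triple for unadjusted Langevin dynamics."""
    grad_fn = jax.grad(potential_fn)

    def init_fn(sample, key):
        return SimpleLangevinState(latent_variables=sample, key=key)

    def update_fn(state, step_size):
        key, noise_key = jax.random.split(state.key)
        q = state.latent_variables
        noise = jax.random.normal(noise_key, jnp.shape(q))
        # Euler-Maruyama step of dq = -grad U(q) dt + sqrt(2) dW.
        q = q - step_size * grad_fn(q) + jnp.sqrt(2.0 * step_size) * noise
        return SimpleLangevinState(latent_variables=q, key=key)

    def get_fn(state):
        return state.latent_variables

    return init_fn, update_fn, get_fn
```

Because the state is an immutable NamedTuple, update_fn can be wrapped in jax.jit and iterated in a plain loop or a jax.lax.scan.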

Integrator States

class jax_sgmc.integrator.ObaboState(positions: Any, momentum: Any, potential: Array, model_state: Any, data_state: CacheState, key: Array, kinetic_energy_start: Array, kinetic_energy_end: Array)

State of the OBABO integrator.

positions

Latent variables in the Hamiltonian dynamics formulation

Type:

Any

momentum

Momentum with the same shape as the latent variables

Type:

Any

potential

Stochastic potential

Type:

jax.Array

model_state

State of the model in the likelihood

Type:

Any

data_state

State of the random data function

Type:

jax_sgmc.data.core.CacheState

key

PRNGKey

Type:

jax.Array

kinetic_energy_start

Kinetic energy after the first 1/4-step

Type:

jax.Array

kinetic_energy_end

Kinetic energy after the last 3/4-step

Type:

jax.Array

class jax_sgmc.integrator.LeapfrogState(positions: Any, momentum: Any, potential: Array, model_state: Any, data_state: CacheState, key: Array, extra_fields: Optional[Any] = None)

State of the reversible and friction leapfrog integrators.

positions

Latent variables in the Hamiltonian dynamics formulation

Type:

Any

momentum

Momentum with the same shape as the latent variables

Type:

Any

potential

Accumulated energy used to calculate the MH-correction step

Type:

jax.Array

model_state

State of the model in the likelihood

Type:

Any

data_state

State of the random data function

Type:

jax_sgmc.data.core.CacheState

key

PRNGKey

Type:

jax.Array

class jax_sgmc.integrator.LangevinState(latent_variables: Any, model_state: Any, key: Array, adapt_state: Any, data_state: CacheState, potential: Array, variance: Array)

State of the Langevin diffusion integrator.

latent_variables

Current latent variables

Type:

Any

key

PRNGKey

Type:

jax.Array

adapt_state

Contains quantities such as momentum used for the adaption

Type:

Any

data_state

State of the reference data cache

Type:

jax_sgmc.data.core.CacheState

model_state

Variables not considered during inference

Type:

Any

potential

Stochastic potential from last evaluation

Type:

jax.Array

variance

Variance of stochastic potential over mini-batch

Type:

jax.Array

Utility

jax_sgmc.integrator.random_tree(key, a)

Builds a tree shaped like the PyTree a whose leaves are normally distributed.

Parameters:
  • key – PRNGKey

  • a – PyTree defining the shape of the output

Returns:

Tree shaped like a with normally distributed leaves.
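One way to build such a tree is to flatten a, draw one standard normal array per leaf with an independent key, and unflatten the result. The sketch below does this; random_tree_sketch is a hypothetical name, it is not necessarily how jax_sgmc implements random_tree, and it assumes floating-point leaves.

```python
import jax
import jax.numpy as jnp


def random_tree_sketch(key, a):
    """Return a tree with the structure of `a` and standard normal leaves."""
    leaves, treedef = jax.tree_util.tree_flatten(a)
    # One independent key per leaf so that the leaves are uncorrelated.
    keys = jax.random.split(key, len(leaves))
    noise = [
        # Match shape and (floating-point) dtype of each reference leaf.
        jax.random.normal(k, jnp.shape(leaf), dtype=jnp.result_type(leaf))
        for k, leaf in zip(keys, leaves)
    ]
    return jax.tree_util.tree_unflatten(treedef, noise)
```

A tree of this kind is what is needed to draw momenta or diffusion noise with the same structure as the sample.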

jax_sgmc.integrator.init_mass(mass)

Initializes a diagonal mass tensor.

Parameters:

mass – Diagonal mass which has the same tree structure as the sample.

Return type:

MassMatrix

Returns:

Returns a diagonal mass matrix.
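A diagonal mass is convenient because its inverse and square root act elementwise on every leaf: the inverse enters the position update and the kinetic energy, the square root scales standard normal noise to momenta with covariance M. The sketch below precomputes both; the DiagonalMass container, its inv/sqrt fields and the helper kinetic_energy are hypothetical and may not match jax_sgmc's MassMatrix.

```python
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class DiagonalMass(NamedTuple):
    # Hypothetical container; jax_sgmc's MassMatrix fields may differ.
    inv: Any   # elementwise 1 / m, e.g. for the drift q += h * p / m
    sqrt: Any  # elementwise sqrt(m), e.g. to draw momenta p ~ N(0, M)


def init_mass_sketch(mass):
    """Precompute the inverse and square root of a diagonal mass tree."""
    return DiagonalMass(
        inv=jax.tree_util.tree_map(lambda m: 1.0 / m, mass),
        sqrt=jax.tree_util.tree_map(jnp.sqrt, mass),
    )


def kinetic_energy(momentum, mass_matrix):
    """0.5 * p^T M^{-1} p for a diagonal M, summed over all leaves."""
    terms = jax.tree_util.tree_map(
        lambda p, inv: 0.5 * jnp.sum(inv * p ** 2), momentum, mass_matrix.inv)
    return sum(jax.tree_util.tree_leaves(terms))
```

For the mass {"w": [1.0, 4.0], "b": 2.0} and momentum {"w": [1.0, 2.0], "b": 2.0}, each leaf contributes 1.0, giving a kinetic energy of 2.0.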