jax_sgmc.solver

Solvers for Stochastic Gradient Bayesian Inference.

MCMC

Run multiple chains of a solver in parallel or vectorized and save the results.

class jax_sgmc.solver.mcmc(solver, scheduler, strategy='map', saving=None, loading=None)

Runs the solver for multiple chains and saves the collected samples.

Parameters:
  • solver – Computes the next state from a given state.

  • scheduler – Schedules solver parameters such as temperature and burn-in.

  • strategy – How to run multiple chains, e.g. in parallel or vectorized (default: 'map').

  • saving – Saves samples via host_callback, which is useful when the collected samples require a lot of memory.

  • loading – Restores a previously saved checkpoint (scheduler state and solver state).

Returns:

Returns a function that runs the solver for a given number of iterations.
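The difference between running chains one after another and running them vectorized can be sketched in plain JAX. The toy update toy_step and the helper run_chain are hypothetical stand-ins for a solver; they do not use jax_sgmc's solver or scheduler interfaces, and which pattern each strategy value corresponds to is not stated here.

```python
import jax
import jax.numpy as jnp


def toy_step(state, key, step_size=0.1):
    # Hypothetical stand-in for a solver: one Langevin step on U(x) = 0.5 * ||x||^2
    grad_u = state
    noise = jax.random.normal(key, state.shape)
    return state - step_size * grad_u + jnp.sqrt(2.0 * step_size) * noise


def run_chain(init_state, key, iterations=100):
    keys = jax.random.split(key, iterations)

    def body(state, k):
        new_state = toy_step(state, k)
        return new_state, new_state  # carry the state forward and collect it as a sample

    _, samples = jax.lax.scan(body, init_state, keys)
    return samples


init_states = jnp.zeros((4, 2))  # 4 chains with 2-dimensional parameters
chain_keys = jax.random.split(jax.random.PRNGKey(0), 4)

# One chain after another (each chain is a slice along the leading axis)
samples_seq = jax.lax.map(lambda args: run_chain(*args), (init_states, chain_keys))
# All chains batched into a single vectorized computation
samples_vec = jax.vmap(run_chain)(init_states, chain_keys)
```

Both variants draw the same samples for the same keys; they differ only in how the work is laid out, which matters for memory use and throughput.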

Solvers

jax_sgmc.solver.sgmc(integrator)

Initializes the standard SGLD sampler.

This sampler simply integrates without acceptance/rejection or parallel tempering.

Parameters:

integrator – An SGLD or leapfrog integrator.

Return type:

Tuple[Callable, Callable, Callable]

Returns:

Returns a solver which depends only on external variables from the scheduler, such as the step size.
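The SGLD update that such a sampler integrates can be sketched as follows: each step moves the parameters against a mini-batch estimate of the potential's gradient and adds Gaussian noise scaled by the square root of twice the step size, with no acceptance test. The Gaussian model, data, batch size and step size are hypothetical, and the code is a plain-JAX illustration rather than the integrator that sgmc() wraps.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: x_i ~ N(2, 1); model x_i ~ N(theta, 1) with a N(0, 10^2) prior
data = jax.random.normal(jax.random.PRNGKey(0), (1000,)) + 2.0


def stochastic_potential(theta, batch, n_total):
    # U~(theta) = -log p(theta) - (N / n) * sum_i log p(x_i | theta), constants dropped
    log_prior = -0.5 * theta**2 / 10.0**2
    log_lik = -0.5 * jnp.sum((batch - theta) ** 2)
    return -log_prior - (n_total / batch.shape[0]) * log_lik


def sgld_step(theta, key, step_size, batch_size=32):
    batch_key, noise_key = jax.random.split(key)
    idx = jax.random.choice(batch_key, data.shape[0], (batch_size,), replace=False)
    grad_u = jax.grad(stochastic_potential)(theta, data[idx], data.shape[0])
    noise = jax.random.normal(noise_key, theta.shape)
    # No accept/reject step: the Langevin update is always applied
    return theta - step_size * grad_u + jnp.sqrt(2.0 * step_size) * noise


@jax.jit
def run(theta0, key):
    keys = jax.random.split(key, 2000)

    def body(theta, k):
        theta = sgld_step(theta, k, step_size=1e-4)
        return theta, theta

    return jax.lax.scan(body, theta0, keys)[1]


samples = run(jnp.zeros(()), jax.random.PRNGKey(1))
```

After a short burn-in, the samples fluctuate around the posterior mean, which for this model is close to the data mean.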

jax_sgmc.solver.amagold(integrator_fn, full_potential_fn, full_data_map, mass_adaption=None)

Initializes AMAGOLD integration.

Parameters:
  • integrator_fn – Reversible leapfrog integrator.

  • full_potential_fn (FullPotential) – Function to calculate the true potential.

  • full_data_map (Tuple[Any, FullDataMapFunction, Callable[[], None]]) – Tuple returned by jax_sgmc.data.full_reference_data().

  • mass_adaption (Optional[Callable]) – Function to adapt a constant mass during the burn-in phase.
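The true potential evaluated by full_potential_fn covers the whole dataset rather than a mini-batch. The sketch below illustrates the idea of accumulating it batch by batch in plain JAX so that only one batch is processed at a time. It does not use full_reference_data() or the FullPotential interface; the model, the data and the batch size are hypothetical.

```python
import jax
import jax.numpy as jnp


def log_likelihood(theta, x):
    # Hypothetical model: x_i ~ N(theta, 1), constants dropped
    return -0.5 * (x - theta) ** 2


def log_prior(theta):
    # Hypothetical prior: theta ~ N(0, 10^2), constants dropped
    return -0.5 * theta**2 / 10.0**2


def full_potential(theta, data, batch_size=100):
    # Split the whole dataset into batches (assumes len(data) % batch_size == 0)
    batches = data.reshape(-1, batch_size)
    # Evaluate one batch at a time and keep only its summed log-likelihood
    per_batch = jax.lax.map(lambda b: jnp.sum(log_likelihood(theta, b)), batches)
    # U(theta) = -log p(theta) - sum_i log p(x_i | theta)
    return -log_prior(theta) - jnp.sum(per_batch)


data = jnp.arange(1000.0) / 1000.0
theta = jnp.array(0.5)
u_full = full_potential(theta, data)
```

The batched result agrees with evaluating all data points at once, up to floating-point rounding.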

Return type:

Tuple[Callable, Callable, Callable]

Returns:

Returns the AMAGOLD solver, a combination of reversible stochastic Hamiltonian dynamics and amortized MH correction steps.
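The MH correction compares true potentials rather than stochastic estimates. A generic Metropolis-Hastings accept/reject on the total energy H = U + K can be sketched as follows. This is not the library's implementation: any correction terms AMAGOLD accumulates along the stochastic-gradient trajectory are omitted, and the function names and unit-mass default are assumptions.

```python
import jax
import jax.numpy as jnp


def kinetic_energy(momentum, mass=1.0):
    # K(p) = 0.5 * p^T M^{-1} p for a diagonal (here scalar) mass
    return 0.5 * jnp.sum(momentum**2 / mass)


def mh_accept(key, potential_old, momentum_old, potential_new, momentum_new):
    # Total energies H = U + K at the start and end of the trajectory
    h_old = potential_old + kinetic_energy(momentum_old)
    h_new = potential_new + kinetic_energy(momentum_new)
    # Log of the acceptance probability min(1, exp(H_old - H_new))
    log_alpha = jnp.minimum(0.0, h_old - h_new)
    accept = jnp.log(jax.random.uniform(key)) < log_alpha
    return accept, jnp.exp(log_alpha)


key = jax.random.PRNGKey(0)
accepted, ratio = mh_accept(key, 1.0, jnp.zeros(2), 0.5, jnp.zeros(2))
```

A proposal that lowers the total energy is always accepted; one that raises it is accepted with probability exp(H_old - H_new).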

jax_sgmc.solver.sggmc(integrator_fn, full_potential_fn, full_data_map, mass_adaption=None)

Gradient Guided Monte Carlo using Stochastic Gradients.

The OBABO integration scheme is reversible even when using stochastic gradients and provides second-order accuracy. Therefore, a Metropolis-Hastings (MH) acceptance step can be applied to sample from the correct posterior distribution [1].

[1] https://arxiv.org/abs/2102.01691
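A single step of the OBABO splitting can be sketched as follows, assuming unit mass, unit temperature and a user-supplied gradient function (all hypothetical here). This plain-JAX illustration is not jax_sgmc's ObaboState-based integrator and omits the MH acceptance step.

```python
import jax
import jax.numpy as jnp


def obabo_step(theta, momentum, key, grad_fn, step_size, friction=1.0):
    # Assumes unit mass and unit temperature
    a = jnp.exp(-0.5 * friction * step_size)  # Ornstein-Uhlenbeck decay over half a step
    noise_scale = jnp.sqrt(1.0 - a**2)
    key_o1, key_o2 = jax.random.split(key)
    # O: first half-step partial momentum refresh
    momentum = a * momentum + noise_scale * jax.random.normal(key_o1, momentum.shape)
    # B: half kick with the (possibly stochastic) gradient of the potential
    momentum = momentum - 0.5 * step_size * grad_fn(theta)
    # A: full drift of the position
    theta = theta + step_size * momentum
    # B: second half kick at the new position
    momentum = momentum - 0.5 * step_size * grad_fn(theta)
    # O: second half-step partial momentum refresh
    momentum = a * momentum + noise_scale * jax.random.normal(key_o2, momentum.shape)
    return theta, momentum


def run(key, n_steps=500, dim=1000, step_size=0.1):
    grad_fn = lambda x: x  # gradient of U(x) = 0.5 * ||x||^2, a standard normal target

    def body(carry, k):
        theta, momentum = carry
        return obabo_step(theta, momentum, k, grad_fn, step_size), None

    init = (jnp.zeros(dim), jnp.zeros(dim))
    (theta, _), _ = jax.lax.scan(body, init, jax.random.split(key, n_steps))
    return theta


final = run(jax.random.PRNGKey(0))
```

On this standard normal target, the coordinates of the final state have a mean close to 0 and a variance close to 1.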

Parameters:
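  • integrator_fn – Reversible OBABO integrator.

  • full_potential_fn (FullPotential) – Function to calculate the true potential.

  • full_data_map (Tuple[Any, FullDataMapFunction, Callable[[], None]]) – Tuple returned by jax_sgmc.data.full_reference_data().

  • mass_adaption (Optional[Callable]) – Function to predict the mass during warmup.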

Return type:

Tuple[Callable, Callable, Callable]

Returns:

Returns the SGGMC solver.

jax_sgmc.solver.parallel_tempering(integrator, sa_schedule=<function <lambda>>)

Exchanges samples between a normal and a tempered chain.

This solver runs an additional tempered chain, from which no samples are drawn. The normal chain and the additional chain exchange samples at random by a reversible jump process [1].

[1] https://arxiv.org/abs/2008.05367v3
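A replica-exchange swap between a low-temperature chain and a tempered chain can be sketched as below. The optional correction term follows the idea of the reSGLD bias correction for noisy potentials, derived here under the assumption that each potential estimate carries independent Gaussian noise with variance sigma_sq. The helper names and parameters are illustrative, not jax_sgmc's API.

```python
import jax
import jax.numpy as jnp


def swap_log_prob(u_low, u_high, temp_low, temp_high, sigma_sq=0.0):
    # Standard replica-exchange log swap ratio: (1/T_low - 1/T_high) * (U_low - U_high)
    beta_diff = 1.0 / temp_low - 1.0 / temp_high
    # Optional correction for noisy potentials, assuming each estimate carries
    # independent Gaussian noise with variance sigma_sq
    log_s = beta_diff * (u_low - u_high - beta_diff * sigma_sq)
    return jnp.minimum(0.0, log_s)


def maybe_swap(key, x_low, x_high, u_low, u_high, temp_low, temp_high, sigma_sq=0.0):
    log_s = swap_log_prob(u_low, u_high, temp_low, temp_high, sigma_sq)
    swap = jnp.log(jax.random.uniform(key)) < log_s
    # Exchange the two samples if the swap is accepted
    new_low = jnp.where(swap, x_high, x_low)
    new_high = jnp.where(swap, x_low, x_high)
    return new_low, new_high, swap


new_low, new_high, swapped = maybe_swap(
    jax.random.PRNGKey(0), jnp.array(1.0), jnp.array(2.0),
    u_low=10.0, u_high=1.0, temp_low=1.0, temp_high=5.0)
```

When the tempered chain has found a state with lower potential, the swap is always accepted and the low-temperature chain continues from that state. The correction term makes swaps rarer as the estimated noise grows.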

Parameters:
  • integrator – Standard Langevin diffusion integrator.

  • sa_schedule (Callable) – Learning rate schedule used to estimate the standard deviation of the stochastic potential.
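A learning rate schedule for such an estimate can be illustrated with a stochastic-approximation update, in which the variance estimate moves a schedule-dependent fraction toward each new observation. The schedule 1/(k+1), the helper functions and the observations are hypothetical; the default schedule and the exact update used by parallel_tempering() are not documented here.

```python
import jax
import jax.numpy as jnp


def sa_schedule(k):
    # Hypothetical schedule; gamma_k = 1 / (k + 1) makes the update a running mean
    return 1.0 / (k + 1.0)


def update_variance(sigma_sq, observed_sq, k):
    # Stochastic approximation: move the estimate a fraction gamma_k toward the observation
    gamma = sa_schedule(k)
    return (1.0 - gamma) * sigma_sq + gamma * observed_sq


def estimate_variance(observations):
    def body(sigma_sq, inputs):
        k, obs = inputs
        new_sigma_sq = update_variance(sigma_sq, obs, k)
        return new_sigma_sq, new_sigma_sq

    ks = jnp.arange(observations.shape[0], dtype=jnp.float32)
    init = jnp.zeros((), dtype=jnp.float32)
    return jax.lax.scan(body, init, (ks, observations))


observations = jnp.array([4.0, 2.0, 6.0, 8.0])  # hypothetical noisy variance observations
final_sigma_sq, history = estimate_variance(observations)
sigma = jnp.sqrt(final_sigma_sq)  # standard deviation estimate
```

A schedule that decays more slowly than 1/(k+1) weights recent observations more heavily, which can help if the noise level changes as the chain moves.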

Return type:

Tuple[Callable, Callable, Callable]

Returns:

Returns the reSGLD solver.

Solver States

class jax_sgmc.solver.AMAGOLDState(full_data_state: CacheState, potential: Array, integrator_state: LeapfrogState, key: Array, acceptance_ratio: Array, mass_state: Any)

State of the AMAGOLD solver.

full_data_state

Cache state for full data mapping function.

Type:

jax_sgmc.data.core.CacheState

potential

True potential of the current sample

Type:

jax.Array

integrator_state

State of the reversible leapfrog integrator

Type:

jax_sgmc.integrator.LeapfrogState

key

PRNGKey for MH correction step

Type:

jax.Array

acceptance_ratio

Acceptance ratio used in the last step.

Type:

jax.Array

mass_state

State of the mass adaptation.

Type:

Any
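The state containers bundle everything a solver carries from one step to the next. The sketch below uses a hypothetical, reduced container (ToyMHState) with a subset of the fields listed above. It shows why storing the true potential of the current sample, a PRNG key and the last acceptance ratio is convenient: a NamedTuple is a JAX pytree, so it passes through jit or lax.scan with its structure intact. The sketch makes no assumption about how AMAGOLDState itself is implemented.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyMHState(NamedTuple):
    # Hypothetical, reduced analogue of the solver states documented here
    position: jax.Array
    potential: jax.Array         # true potential of the current sample
    key: jax.Array               # PRNG key for the MH correction step
    acceptance_ratio: jax.Array  # acceptance ratio used in the last step


def potential_fn(x):
    return 0.5 * jnp.sum(x**2)  # hypothetical potential


@jax.jit
def mh_update(state: ToyMHState, proposal: jax.Array) -> ToyMHState:
    key, accept_key = jax.random.split(state.key)
    new_potential = potential_fn(proposal)
    # The stored potential means only the proposal has to be evaluated
    ratio = jnp.minimum(1.0, jnp.exp(state.potential - new_potential))
    accept = jax.random.uniform(accept_key) < ratio
    return ToyMHState(
        position=jnp.where(accept, proposal, state.position),
        potential=jnp.where(accept, new_potential, state.potential),
        key=key,
        acceptance_ratio=ratio,
    )


state = ToyMHState(jnp.zeros(2), potential_fn(jnp.zeros(2)), jax.random.PRNGKey(0), jnp.array(1.0))
new_state = mh_update(state, jnp.ones(2))
```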

class jax_sgmc.solver.SGGMCState(full_data_state: CacheState, potential: Array, integrator_state: ObaboState, key: Array, acceptance_ratio: Array, mass_state: Any)

State of the SGGMC solver.

full_data_state

Cache state for full data mapping function.

Type:

jax_sgmc.data.core.CacheState

potential

True potential of the current sample

Type:

jax.Array

integrator_state

State of the reversible OBABO integrator

Type:

jax_sgmc.integrator.ObaboState

key

PRNGKey for MH correction step

Type:

jax.Array

acceptance_ratio

Acceptance ratio used in the last step.

Type:

jax.Array

mass_state

State of the mass adaptation.

Type:

Any