jax_sgmc.solver
Solvers for Stochastic Gradient Bayesian Inference.
MCMC
Run multiple chains of a solver in parallel or vectorized and save the results.
- class jax_sgmc.solver.mcmc(solver, scheduler, strategy='map', saving=None, loading=None)
Runs the solver for multiple chains and saves the collected samples.
- Parameters:
solver – Computes the next state from a given state.
scheduler – Schedules solver parameters such as temperature and burn in.
strategy – Whether to run the chains in parallel or vectorized.
saving – Save samples via host_callback (if saving requires much memory)
loading – Restore a previously saved checkpoint (scheduler state and solver state)
- Returns:
Returns function which runs the solver for a given number of iterations.
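The difference between running chains sequentially and running them vectorized can be illustrated with plain JAX. The following is only a minimal sketch under stated assumptions: `potential` and `run_chain` are hypothetical helpers written for this example, and the snippet does not call `jax_sgmc.solver.mcmc` or reproduce how it dispatches on `strategy`.

```python
import jax
import jax.numpy as jnp

def potential(theta):
    # Hypothetical toy potential: negative log-density of a standard normal.
    return 0.5 * jnp.sum(theta ** 2)

def run_chain(key, num_steps=100, step_size=1e-2):
    """Run one toy Langevin chain and return every visited state."""
    def step(theta, step_key):
        noise = jax.random.normal(step_key, theta.shape)
        theta = (theta
                 - step_size * jax.grad(potential)(theta)
                 + jnp.sqrt(2.0 * step_size) * noise)
        return theta, theta

    step_keys = jax.random.split(key, num_steps)
    _, samples = jax.lax.scan(step, jnp.zeros(2), step_keys)
    return samples

# One PRNG key per chain.
chain_keys = jax.random.split(jax.random.PRNGKey(0), 4)

# Vectorized: all chains advance together as one batched computation.
vectorized = jax.vmap(run_chain)(chain_keys)

# Sequential: the chains are evaluated one after another.
sequential = jax.lax.map(run_chain, chain_keys)
```

Both calls return an array of shape `(chains, steps, dim)`. The vectorized variant typically trades memory for speed, while the sequential variant keeps only one chain in flight at a time.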
Solvers
- jax_sgmc.solver.sgmc(integrator)
Initializes the standard SGLD sampler.
This sampler simply integrates without acceptance/rejection or parallel tempering.
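To make "integrates without acceptance/rejection" concrete, here is a minimal sketch of a plain SGLD update written directly in JAX. It assumes a toy problem (inferring the mean of Gaussian data with a standard normal prior); `minibatch_potential` and `sgld_step` are hypothetical names for this example and are not part of the jax_sgmc API.

```python
import jax
import jax.numpy as jnp

# Synthetic observations with true mean 1.0 (assumption for this example).
data = 1.0 + jax.random.normal(jax.random.PRNGKey(1), (1000,))

def minibatch_potential(theta, batch, num_data):
    # Stochastic potential: rescaled batch log-likelihood plus log-prior, negated.
    log_lik = -0.5 * jnp.sum((batch - theta) ** 2)
    log_prior = -0.5 * theta ** 2
    return -(num_data / batch.shape[0] * log_lik + log_prior)

def sgld_step(theta, key, step_size=1e-4, batch_size=32):
    batch_key, noise_key = jax.random.split(key)
    # Draw a random mini-batch of observations.
    idx = jax.random.choice(batch_key, data.shape[0], (batch_size,), replace=False)
    grad = jax.grad(minibatch_potential)(theta, data[idx], data.shape[0])
    noise = jax.random.normal(noise_key)
    # Every proposal is kept: there is no accept/reject step.
    theta = theta - step_size * grad + jnp.sqrt(2.0 * step_size) * noise
    return theta, theta

keys = jax.random.split(jax.random.PRNGKey(0), 2000)
_, samples = jax.lax.scan(sgld_step, jnp.zeros(()), keys)
```

Because no correction step is applied, the quality of the samples depends on the step size and on the variance of the mini-batch gradient.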
- jax_sgmc.solver.amagold(integrator_fn, full_potential_fn, full_data_map, mass_adaption=None)
Initializes AMAGOLD integration.
- Parameters:
integrator_fn – Reversible leapfrog integrator.
full_potential_fn (FullPotential) – Function to calculate true potential.
full_data_map (Tuple[Any, FullDataMapFunction, Callable[[], None]]) – Tuple returned by jax_sgmc.data.full_reference_data().
mass_adaption (Optional[Callable]) – Function to adapt a constant mass during the burn in phase.
- Return type:
- Returns:
Returns the AMAGOLD solver, a combination of reversible stochastic Hamiltonian dynamics and amortized MH correction steps.
- jax_sgmc.solver.sggmc(integrator_fn, full_potential_fn, full_data_map, mass_adaption=None)
Gradient Guided Monte Carlo using Stochastic Gradients.
The OBABO integration scheme is reversible even when using stochastic gradients and provides second order accuracy. Therefore, a MH-acceptance step can be applied to sample from the correct posterior distribution.
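The structure of an OBABO step followed by an MH test can be sketched as below. This is a simplified illustration, not the library's implementation: it uses the exact gradient of a hypothetical toy potential, so the acceptance ratio reduces to the usual energy difference across the inner B-A-B part. The acceptance rule that SGGMC applies with stochastic gradients is not reproduced here, and `obabo_mh_step` is a name invented for this example.

```python
import jax
import jax.numpy as jnp

def potential(q):
    # Hypothetical toy potential: standard normal target.
    return 0.5 * jnp.sum(q ** 2)

def obabo_mh_step(state, key, step_size=0.2, friction=1.0):
    q, p = state
    key_o1, key_o2, key_mh = jax.random.split(key, 3)
    decay = jnp.exp(-0.5 * friction * step_size)
    noise_scale = jnp.sqrt(1.0 - decay ** 2)
    grad = jax.grad(potential)

    # O: partial momentum refresh (half step).
    p = decay * p + noise_scale * jax.random.normal(key_o1, p.shape)

    # B-A-B: half kick, full drift, half kick.
    p_new = p - 0.5 * step_size * grad(q)
    q_new = q + step_size * p_new
    p_new = p_new - 0.5 * step_size * grad(q_new)

    # MH test on the total energy change of the B-A-B part.
    energy_old = potential(q) + 0.5 * jnp.sum(p ** 2)
    energy_new = potential(q_new) + 0.5 * jnp.sum(p_new ** 2)
    accept_prob = jnp.minimum(1.0, jnp.exp(energy_old - energy_new))
    accept = jax.random.uniform(key_mh) < accept_prob

    # On rejection, keep the position and flip the momentum.
    q = jnp.where(accept, q_new, q)
    p = jnp.where(accept, p_new, -p)

    # O: second partial momentum refresh (half step).
    p = decay * p + noise_scale * jax.random.normal(key_o2, p.shape)
    return (q, p), (q, accept_prob)

keys = jax.random.split(jax.random.PRNGKey(0), 1000)
init_state = (jnp.ones(2), jnp.zeros(2))
_, (samples, accept_probs) = jax.lax.scan(obabo_mh_step, init_state, keys)
```

The momentum flip on rejection is what keeps the combined update valid for the target distribution in this exact-gradient setting.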
- jax_sgmc.solver.parallel_tempering(integrator, sa_schedule=<function <lambda>>)
Exchange samples from a normal and tempered chain.
This solver runs an additional tempered chain, from which no samples are drawn. The normal chain and the additional chain exchange samples at random by a reversible jump process [1].
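As an illustration of exchanging samples between a normal and a tempered chain, the sketch below uses the standard Metropolis replica-exchange criterion with exact potentials. This is a swapped-in textbook rule and not necessarily the swap rule of this solver, which works with stochastic estimates. The function names and the temperatures are assumptions made for the example.

```python
import jax
import jax.numpy as jnp

def swap_probability(potential_normal, potential_tempered,
                     temp_normal=1.0, temp_tempered=2.0):
    # Standard replica-exchange acceptance probability for targets
    # proportional to exp(-U / T).
    beta_diff = 1.0 / temp_normal - 1.0 / temp_tempered
    log_ratio = beta_diff * (potential_normal - potential_tempered)
    return jnp.minimum(1.0, jnp.exp(log_ratio))

def maybe_swap(key, sample_normal, sample_tempered, potential_fn):
    prob = swap_probability(potential_fn(sample_normal),
                            potential_fn(sample_tempered))
    do_swap = jax.random.uniform(key) < prob
    # Exchange the two samples if the swap is accepted.
    new_normal = jnp.where(do_swap, sample_tempered, sample_normal)
    new_tempered = jnp.where(do_swap, sample_normal, sample_tempered)
    return new_normal, new_tempered, prob

# Hypothetical toy potential and chain states.
potential_fn = lambda x: 0.5 * jnp.sum(x ** 2)
normal = jnp.array([2.0, 2.0])    # potential 4.0
tempered = jnp.array([0.0, 1.0])  # potential 0.5
new_normal, new_tempered, prob = maybe_swap(
    jax.random.PRNGKey(0), normal, tempered, potential_fn)
```

In this example the tempered chain holds the lower-potential sample, so the swap probability is 1 and the normal chain receives it.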
Solver States
- class jax_sgmc.solver.AMAGOLDState(full_data_state: CacheState, potential: Array, integrator_state: LeapfrogState, key: Array, acceptance_ratio: Array, mass_state: Any)
State of the AMAGOLD solver.
- full_data_state
Cache state for full data mapping function.
- integrator_state
State of the reversible leapfrog integrator.
- mass_state
State of the mass adaption.
- Type:
Any
- class jax_sgmc.solver.SGGMCState(full_data_state: CacheState, potential: Array, integrator_state: ObaboState, key: Array, acceptance_ratio: Array, mass_state: Any)
State of the SGGMC solver.
- full_data_state
Cache state for full data mapping function.
- integrator_state
State of the reversible OBABO integrator.
- mass_state
State of the mass adaption.
- Type:
Any