jax_sgmc.alias.amagold

jax_sgmc.alias.amagold(stochastic_potential_fn, full_potential_fn, data_loader, cache_size=512, batch_size=32, integration_steps=10, friction=0.25, first_step_size=0.001, last_step_size=0.001, adaptive_step_size=False, stabilization_constant=10, decay_constant=0.75, speed_constant=0.05, target_acceptance_rate=0.25, burn_in=0, accepted_samples=None, mass=None, save_to_numpy=True, progress_bar=True)

Amortized Metropolis Adjustment for Efficient Stochastic Gradient MCMC.

The AMAGOLD solver constructs a skew-reversible Markov chain, so that Metropolis-Hastings (MH) correction steps can be applied periodically rather than after every step while the chain still samples from the correct target distribution [1].

[1] https://arxiv.org/abs/2003.00193
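The periodic MH correction can be illustrated with a generic accept/reject step. This is a minimal sketch, not the jax_sgmc implementation: the function name mh_correction is hypothetical, the proposal and its log acceptance ratio are assumed to be computed already, and the rejection rule is assumed to follow generalized HMC, where a rejected proposal leaves the chain at the current position with its momentum negated. That momentum flip is what makes the chain skew-reversible instead of reversible.

```python
import jax
import jax.numpy as jnp


def mh_correction(key, log_accept_ratio, current, proposal):
    """Hypothetical MH accept/reject step for a (position, momentum) state.

    current, proposal: tuples (theta, r) of arrays with matching shapes.
    log_accept_ratio: log of the MH acceptance ratio of the proposal.
    """
    theta0, r0 = current
    theta1, r1 = proposal
    # Accept with probability min(1, exp(log_accept_ratio)).
    log_u = jnp.log(jax.random.uniform(key))
    accept = log_u < log_accept_ratio
    # Accepted: move to the proposal. Rejected: stay at theta0 but
    # negate the momentum (assumed skew-reversible rejection rule).
    theta = jnp.where(accept, theta1, theta0)
    r = jnp.where(accept, r1, -r0)
    return theta, r, accept


# A proposal with log acceptance ratio -inf is always rejected.
key = jax.random.PRNGKey(0)
current = (jnp.array([1.0, 2.0]), jnp.array([0.5, -0.5]))
proposal = (jnp.array([3.0, 4.0]), jnp.array([1.0, 1.0]))
theta, r, accepted = mh_correction(key, -jnp.inf, current, proposal)
# theta == [1., 2.], r == [-0.5, 0.5], accepted == False
```

In AMAGOLD the log acceptance ratio roughly combines evaluations of the full potential at the start and end of the trajectory with a correction term accumulated along the stochastic-gradient trajectory; see [1] for its exact form.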

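import jax.numpy as jnp
from jax_sgmc import alias

amagold_run = alias.amagold(...)

# Initial values of the sampled variables
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}

# Run a single chain for 5000 iterations and select its sampled variables
results = amagold_run(
    sample, init_model_state=0, iterations=5000
)[0]['samples']['variables']
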
Parameters:
  • stochastic_potential_fn (StochasticPotential) – Stochastic potential over a minibatch of data

  • full_potential_fn (FullPotential) – Potential over the full dataset

  • data_loader (DataLoader) – Data loader, e.g. a NumPy data loader

  • cache_size (int) – Number of minibatches cached in device memory

  • batch_size (int) – Number of observations per batch

  • integration_steps (int) – Number of leapfrog steps before each MH correction step

  • friction (float) – Parameter between 0.0 and 1.0 controlling the decay of the momentum

  • first_step_size (float) – Initial step size of the polynomial and adaptive step size schedules

  • last_step_size (float) – Final step size of the polynomial step size schedule

  • adaptive_step_size (bool) – Adapt the step size during burn-in to optimize the acceptance rate

  • stabilization_constant (int) – Larger numbers reduce the impact of the initial steps on the step size

  • decay_constant (float) – Larger values reduce the impact of later steps

  • speed_constant (float) – Speed of the step size adaptation

  • target_acceptance_rate (float) – Target acceptance rate for the step size adaptation

  • burn_in (int) – Number of samples to skip before collecting samples

  • accepted_samples (Optional[int]) – Total number of samples to collect. If accepted_samples < iterations - burn_in, the collected samples are selected by random thinning. If None, no thinning is applied.

  • mass (Optional[Any]) – Diagonal mass matrix for the HMC dynamics

  • save_to_numpy (bool) – Save samples on the host in NumPy arrays instead of in device memory

  • progress_bar (bool) – Print the progress of the solver
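The step size arguments can be read in terms of two common schedules, sketched below as hypothetical helpers rather than jax_sgmc code. polynomial_step_size is a polynomial decay eps_t = a (b + t)^(-gamma), a usual choice for SG-MCMC, with a and b solved so that the first and last entries equal first_step_size and last_step_size; the exponent gamma = 0.33 is an assumed value, and the sketch only supports non-increasing schedules. DualAveraging is the dual-averaging step size adaptation of Hoffman and Gelman (NUTS); mapping stabilization_constant, decay_constant and speed_constant onto its t0, kappa and gamma is an assumption suggested by the parameter descriptions and defaults (10, 0.75, 0.05), not something stated above.

```python
import math


def polynomial_step_size(first_step_size, last_step_size, iterations, gamma=0.33):
    """Hypothetical schedule a * (b + t) ** (-gamma) for t = 0 .. iterations - 1.

    a and b are chosen so the first entry equals first_step_size and the last
    entry equals last_step_size. gamma = 0.33 is an assumed default.
    """
    if last_step_size > first_step_size:
        raise ValueError("this sketch only supports non-increasing schedules")
    if first_step_size == last_step_size or iterations < 2:
        return [first_step_size] * iterations  # constant schedule
    n = iterations - 1
    # From first = a * b**-gamma and last = a * (b + n)**-gamma:
    ratio = (first_step_size / last_step_size) ** (1.0 / gamma)
    b = n / (ratio - 1.0)
    a = first_step_size * b ** gamma
    return [a * (b + t) ** (-gamma) for t in range(iterations)]


class DualAveraging:
    """Hypothetical dual-averaging step size adaptation (Hoffman & Gelman).

    Assumed mapping to the amagold arguments:
      stabilization_constant -> t0, decay_constant -> kappa,
      speed_constant -> gamma, target_acceptance_rate -> delta.
    """

    def __init__(self, first_step_size, target_acceptance_rate=0.25,
                 stabilization_constant=10, decay_constant=0.75,
                 speed_constant=0.05):
        self.mu = math.log(10.0 * first_step_size)  # shrinkage target for log step size
        self.delta = target_acceptance_rate
        self.t0 = stabilization_constant
        self.kappa = decay_constant
        self.gamma = speed_constant
        self.t = 0
        self.h_bar = 0.0        # running average of (delta - acceptance)
        self.log_eps_bar = 0.0  # averaged log step size, used after burn-in

    def update(self, acceptance_prob):
        """Record one acceptance probability and return the next step size."""
        self.t += 1
        w = 1.0 / (self.t + self.t0)
        self.h_bar = (1.0 - w) * self.h_bar + w * (self.delta - acceptance_prob)
        log_eps = self.mu - math.sqrt(self.t) / self.gamma * self.h_bar
        eta = self.t ** (-self.kappa)
        self.log_eps_bar = eta * log_eps + (1.0 - eta) * self.log_eps_bar
        return math.exp(log_eps)

    def final_step_size(self):
        """Step size to keep fixed once burn-in is over."""
        return math.exp(self.log_eps_bar)
```

Acceptance rates above the target drive h_bar negative and grow the step size; rates below it shrink the step size. A larger t0 damps the first updates, and a larger kappa makes the averaged step size forget early iterates faster.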

Returns:

A solver function that can be applied to multiple chains starting at init_sample.
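The random thinning controlled by accepted_samples can be sketched as index selection: discard the first burn_in iterations, draw accepted_samples of the remaining iteration indices without replacement, and keep them in chronological order. random_thinning_indices is a hypothetical helper, not jax_sgmc code. Selection is uniform unless a weight vector is passed; weighting (for example by step size when the step size decays) is shown only as an option, and whether jax_sgmc weights its selection is an assumption not covered by this page.

```python
import jax
import jax.numpy as jnp


def random_thinning_indices(key, iterations, burn_in, accepted_samples=None,
                            weights=None):
    """Hypothetical helper returning the iteration indices whose samples are kept.

    weights: optional non-negative array of length iterations - burn_in with
    relative selection weights; selection is uniform if None.
    """
    candidates = jnp.arange(burn_in, iterations)
    if accepted_samples is None or accepted_samples >= candidates.size:
        # No thinning: keep every iteration after burn-in.
        return candidates
    p = None if weights is None else weights / jnp.sum(weights)
    chosen = jax.random.choice(key, candidates, shape=(accepted_samples,),
                               replace=False, p=p)
    # Sort so the kept samples stay in the order the chain produced them.
    return jnp.sort(chosen)


# Keep 500 of the 4000 post-burn-in iterations of a 5000-iteration run.
kept = random_thinning_indices(jax.random.PRNGKey(42), iterations=5000,
                               burn_in=1000, accepted_samples=500)
```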