jax_sgmc.alias.sggmc

jax_sgmc.alias.sggmc(stochastic_potential_fn, full_potential_fn, data_loader, cache_size=512, batch_size=32, integration_steps=10, friction_coefficient=1.0, first_step_size=0.001, last_step_size=0.001, adaptive_step_size=False, stabilization_constant=10, decay_constant=0.75, speed_constant=0.05, target_acceptance_rate=0.25, burn_in=0, accepted_samples=None, mass=None, save_to_numpy=True, progress_bar=True)

Stochastic gradient guided Monte Carlo.

The SGGMC solver is based on the OBABO integrator, which remains reversible when stochastic gradients are used. Moreover, the full potential has to be evaluated only once per Metropolis-Hastings (MH) correction step, and this correction can be applied after multiple integration steps [1].

[1] https://arxiv.org/abs/2102.01691
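To make the integrator concrete, here is a minimal JAX sketch of one MH-corrected OBABO trajectory with a diagonal mass. It is not jax_sgmc's implementation: the function names, the symmetric split of the friction over two O half-steps, and the momentum flip on rejection (the generalized HMC convention) are assumptions. It illustrates why the full potential is needed only at the end points of a trajectory: the O steps preserve the Gaussian momentum distribution exactly, and the potential terms of the B-A-B energy changes telescope, so only kinetic energy changes are accumulated along the way. `grad_fn` may be a minibatch gradient, assumed fixed within each B-A-B step.

```python
import jax
import jax.numpy as jnp


def kinetic(p, mass):
    # Kinetic energy for a diagonal mass matrix.
    return 0.5 * jnp.sum(p ** 2 / mass)


def o_step(p, key, step_size, friction, mass):
    # Ornstein-Uhlenbeck half-step; leaves N(0, mass) momenta invariant.
    alpha = jnp.exp(-friction * step_size / 2)
    noise = jax.random.normal(key, p.shape)
    return alpha * p + jnp.sqrt((1 - alpha ** 2) * mass) * noise


def bab_step(q, p, grad_fn, step_size, mass):
    # Leapfrog (B-A-B) step using one gradient estimate for both kicks.
    p = p - 0.5 * step_size * grad_fn(q)
    q = q + step_size * p / mass
    p = p - 0.5 * step_size * grad_fn(q)
    return q, p


def obabo_trajectory(q, p, key, grad_fn, full_potential_fn, step_size,
                     friction, mass, integration_steps):
    """Runs several OBABO steps; returns the proposal and its log MH ratio."""
    u_start = full_potential_fn(q)
    delta_kinetic = 0.0
    for k in jax.random.split(key, integration_steps):
        k1, k2 = jax.random.split(k)
        p = o_step(p, k1, step_size, friction, mass)
        k_before = kinetic(p, mass)
        q, p = bab_step(q, p, grad_fn, step_size, mass)
        # Only the B-A-B part needs an energy correction.
        delta_kinetic += kinetic(p, mass) - k_before
        p = o_step(p, k2, step_size, friction, mass)
    # Potential differences telescope: full potential at start and end only.
    delta_h = full_potential_fn(q) - u_start + delta_kinetic
    return q, p, -delta_h


def mh_step(q, p, key, grad_fn, full_potential_fn, step_size, friction,
            mass, integration_steps):
    """One MH-corrected trajectory; returns the new state and the decision."""
    key_traj, key_acc = jax.random.split(key)
    q_new, p_new, log_ratio = obabo_trajectory(
        q, p, key_traj, grad_fn, full_potential_fn, step_size, friction,
        mass, integration_steps)
    accepted = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    # Accept: keep the endpoint. Reject: restore the position, flip momentum.
    return (jnp.where(accepted, q_new, q),
            jnp.where(accepted, p_new, -p),
            accepted)
```

In this sketch `integration_steps` and `friction` play the roles of the `integration_steps` and `friction_coefficient` parameters, and `mass` that of the diagonal `mass`.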

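import jax.numpy as jnp
from jax_sgmc import alias

# Configure the solver (see the parameters below).
sggmc_run = alias.sggmc(...)

# Initial sample; N is the number of weights of the model.
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}

# Results are returned per chain; [0] selects the only chain started here.
results = sggmc_run(
    sample, init_model_state=0, iterations=5000
)[0]['samples']['variables']
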
Parameters:
  • stochastic_potential_fn (StochasticPotential) – Stochastic potential over a minibatch of data

  • full_potential_fn (FullPotential) – Potential over the full dataset

  • data_loader (DataLoader) – Data loader, e.g. a NumPy data loader

  • cache_size (int) – Number of mini-batches kept in device memory

  • batch_size (int) – Number of observations per batch

  • integration_steps (int) – Number of leapfrog steps before each MH correction step

  • friction_coefficient (float) – Positive parameter controlling the amount of refreshed momentum

  • first_step_size (float) – Initial step size of the polynomial and adaptive step size schedules

  • last_step_size (float) – Final step size of the polynomial step size schedule

  • adaptive_step_size (bool) – Adapt the step size during burn-in to reach the target acceptance rate

  • stabilization_constant (int) – Larger values reduce the impact of the initial steps on the step size adaptation

  • decay_constant (float) – Larger values reduce the impact of later steps

  • speed_constant (float) – Speed of adaptation of the step size

  • target_acceptance_rate (float) – Target acceptance rate for the step size adaptation

  • burn_in (int) – Number of initial iterations to discard before collecting samples

  • accepted_samples (Optional[int]) – Total number of samples to collect. If accepted_samples < iterations - burn_in, the collected samples are selected by random thinning. If None, no thinning is applied.

  • mass (Optional[Any]) – Diagonal mass for the HMC dynamics

  • save_to_numpy (bool) – Store the samples on the host as NumPy arrays instead of in device memory

  • progress_bar (bool) – Print the progress of the solver
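The first_step_size and last_step_size parameters define a polynomial step size schedule. A common form, assumed here, is eps_n = a (b + n)^(-gamma), with a and b solved so that the first and the last iteration take the two given values. The function name and the exponent gamma = 0.33 are hypothetical choices for illustration; sggmc exposes no parameter for the exponent. With the defaults (both 0.001) the schedule is constant.

```python
def polynomial_step_sizes(first, last, iterations, gamma=0.33):
    """Step sizes a * (b + n) ** (-gamma) with eps_0 = first, eps_{N-1} = last."""
    if iterations < 2 or first == last:
        return [first] * iterations  # constant schedule
    if last > first:
        raise ValueError("this sketch only supports decreasing schedules")
    # Solve first = a * b**-gamma and last = a * (b + N - 1)**-gamma for a, b.
    b = (iterations - 1) / ((first / last) ** (1 / gamma) - 1)
    a = first * b ** gamma
    return [a * (b + n) ** (-gamma) for n in range(iterations)]
```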
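The adaptive_step_size option and its four constants share roles and default values (t0 = 10, kappa = 0.75, gamma = 0.05) with the dual-averaging step size adaptation from Hoffman and Gelman's No-U-Turn Sampler paper. The sketch below assumes that scheme; the class name and the mapping of stabilization_constant, decay_constant, speed_constant and target_acceptance_rate onto t0, kappa, gamma and delta are illustrative assumptions, not jax_sgmc's code. Each MH acceptance probability observed during burn-in nudges the log step size toward the target, and the averaged step size is the one a sampler would keep afterwards.

```python
import math


class DualAveraging:
    """Dual-averaging step size adaptation (an illustrative sketch)."""

    def __init__(self, first_step_size, target_acceptance_rate=0.25,
                 stabilization_constant=10, decay_constant=0.75,
                 speed_constant=0.05):
        self.mu = math.log(10 * first_step_size)  # value the log step is shrunk toward
        self.delta = target_acceptance_rate
        self.t0 = stabilization_constant
        self.kappa = decay_constant
        self.gamma = speed_constant
        self.m = 0
        self.h_bar = 0.0  # running average of (target - observed acceptance)
        self.log_eps = math.log(first_step_size)
        self.log_eps_bar = 0.0

    def update(self, acceptance_prob):
        """Feeds one MH acceptance probability; returns the next step size."""
        self.m += 1
        w = 1.0 / (self.m + self.t0)
        self.h_bar = (1 - w) * self.h_bar + w * (self.delta - acceptance_prob)
        self.log_eps = self.mu - math.sqrt(self.m) / self.gamma * self.h_bar
        eta = self.m ** (-self.kappa)
        self.log_eps_bar = eta * self.log_eps + (1 - eta) * self.log_eps_bar
        return math.exp(self.log_eps)

    def final_step_size(self):
        """Averaged step size to use after burn-in."""
        return math.exp(self.log_eps_bar)
```

Acceptance rates above the target push the step size up, rates below it push the step size down.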
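The random thinning described for accepted_samples can be sketched as drawing accepted_samples distinct iteration indices from the iterations - burn_in candidates after burn-in. Weighting the draw by each iteration's step size is an assumption of this sketch (it reduces to uniform selection when the step size is constant, as with the defaults); the function name and the seed argument are hypothetical.

```python
import numpy as np


def random_thinning(step_sizes, burn_in, accepted_samples, seed=0):
    """Sorted indices of kept iterations, drawn without replacement after burn-in."""
    candidates = np.arange(burn_in, len(step_sizes))
    if accepted_samples is None or accepted_samples >= candidates.size:
        return candidates  # no thinning
    weights = np.asarray(step_sizes, dtype=float)[burn_in:]
    probs = weights / weights.sum()
    rng = np.random.default_rng(seed)
    chosen = rng.choice(candidates, size=accepted_samples, replace=False, p=probs)
    return np.sort(chosen)
```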

Returns:

A solver function that can run multiple chains, each starting from an initial sample passed to it (such as sample in the example above).