jax_sgmc.alias.sggmc
- jax_sgmc.alias.sggmc(stochastic_potential_fn, full_potential_fn, data_loader, cache_size=512, batch_size=32, integration_steps=10, friction_coefficient=1.0, first_step_size=0.001, last_step_size=0.001, adaptive_step_size=False, stabilization_constant=10, decay_constant=0.75, speed_constant=0.05, target_acceptance_rate=0.25, burn_in=0, accepted_samples=None, mass=None, save_to_numpy=True, progress_bar=True)
Stochastic gradient guided Monte Carlo (SGGMC).
The SGGMC solver is based on the OBABO integrator, which remains reversible when using stochastic gradients. Moreover, the full potential needs to be evaluated only once per Metropolis-Hastings (MH) correction step, and this correction can be applied after multiple integration steps [1].
[1] https://arxiv.org/abs/2102.01691
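To make the OBABO splitting concrete, the following is a minimal one-dimensional sketch in plain Python. It is a generic textbook OBABO step, not the jax_sgmc implementation: the toy potential U(q) = q**2 / 2, the helper names grad_potential and obabo_step, and the exact form of the momentum refresh are assumptions made for illustration. In SGGMC the gradient would be a mini-batch estimate and an MH correction would follow a block of such steps.

```python
import math
import random


def grad_potential(q):
    # Hypothetical toy potential U(q) = q**2 / 2, so the gradient is q.
    return q


def obabo_step(q, p, step_size, friction, rng, mass=1.0, grad_fn=grad_potential):
    """One generic OBABO step (illustrative sketch, not the library's code)."""
    # O: partial momentum refresh over half a step (Ornstein-Uhlenbeck).
    a = math.exp(-friction * step_size / 2.0)
    noise_scale = math.sqrt((1.0 - a * a) * mass)
    p = a * p + noise_scale * rng.gauss(0.0, 1.0)
    # B: half-step momentum update with the gradient.
    p = p - 0.5 * step_size * grad_fn(q)
    # A: full-step position update.
    q = q + step_size * p / mass
    # B: second half-step momentum update.
    p = p - 0.5 * step_size * grad_fn(q)
    # O: second partial momentum refresh.
    p = a * p + noise_scale * rng.gauss(0.0, 1.0)
    return q, p


rng = random.Random(0)
q, p = 1.0, 0.0
for _ in range(10):  # e.g. integration_steps=10 between MH corrections
    q, p = obabo_step(q, p, step_size=0.1, friction=1.0, rng=rng)
```

With friction set to zero the two O steps leave the momentum unchanged and the update reduces to a plain leapfrog step, which is a convenient sanity check for the sketch.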
import jax.numpy as jnp
from jax_sgmc import alias

sggmc_run = alias.sggmc(...)

sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}

results = sggmc_run(
  sample,
  init_model_state=0,
  iterations=5000
)[0]['samples']['variables']
- Parameters:
stochastic_potential_fn (StochasticPotential) – Stochastic potential over a minibatch of data
full_potential_fn (FullPotential) – Potential from the full dataset
data_loader (DataLoader) – Data loader, e.g. numpy data loader
cache_size (int) – Number of mini-batches in device memory
batch_size (int) – Number of observations per batch
integration_steps (int) – Number of leapfrog steps before each MH-correction step
friction_coefficient (float) – Positive parameter controlling the amount of refreshed momentum
first_step_size (float) – First step size for the polynomial and adaptive step size schedules
last_step_size (float) – Final step size of the polynomial step size schedule
adaptive_step_size (bool) – Adapt the step size to optimize the acceptance rate during burn-in
stabilization_constant (int) – Larger numbers reduce the impact of the initial steps on the step size
decay_constant (float) – Larger values reduce the impact of later steps
speed_constant (float) – Speed of adaptation of the step size
target_acceptance_rate (float) – Target acceptance rate for the step size adaptation
burn_in (int) – Number of samples to skip before collecting samples
accepted_samples (Optional[int]) – Total number of samples to collect. If accepted_samples < iterations - burn_in, the samples are selected by random thinning. If None, no thinning is applied.
save_to_numpy (bool) – Save on the host in a numpy array instead of in device memory
progress_bar (bool) – Print the progress of the solver
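The interaction of iterations, burn_in and accepted_samples can be illustrated with a short sketch in plain Python. It only shows the idea of skipping the burn-in phase and then keeping a random subset of the remaining iterations; the helper select_sample_indices is hypothetical, and the selection rule actually used by jax_sgmc may differ.

```python
import random


def select_sample_indices(iterations, burn_in, accepted_samples, rng):
    """Pick which iterations to keep (hypothetical helper, not library code)."""
    # Iterations before burn_in are never collected.
    kept = range(burn_in, iterations)
    # No thinning if nothing is requested or if all remaining samples fit.
    if accepted_samples is None or accepted_samples >= len(kept):
        return list(kept)
    # Random thinning: keep a uniformly chosen subset, in chain order.
    return sorted(rng.sample(kept, accepted_samples))


indices = select_sample_indices(
    iterations=5000, burn_in=1000, accepted_samples=100, rng=random.Random(0)
)
```

Here 4000 iterations remain after burn-in, so 100 of them are retained; passing accepted_samples=None would keep all 4000.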
- Returns:
A solver function which can be applied to multiple chains starting at init_sample.