jax_sgmc.alias.amagold
- jax_sgmc.alias.amagold(stochastic_potential_fn, full_potential_fn, data_loader, cache_size=512, batch_size=32, integration_steps=10, friction=0.25, first_step_size=0.001, last_step_size=0.001, adaptive_step_size=False, stabilization_constant=10, decay_constant=0.75, speed_constant=0.05, target_acceptance_rate=0.25, burn_in=0, accepted_samples=None, mass=None, save_to_numpy=True, progress_bar=True)
Amortized Metropolis Adjustment for Efficient Stochastic Gradient MCMC.
The AMAGOLD solver constructs a skew-reversible Markov chain. This allows Metropolis-Hastings (MH) correction steps to be applied only periodically while the chain still samples from the correct target distribution [1].
[1] https://arxiv.org/abs/2003.00193
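The skew-reversible MH step described above can be illustrated with a simplified, full-gradient analogue. The sketch below is not AMAGOLD and not the jax_sgmc implementation. It is generalized Hamiltonian Monte Carlo with partial momentum refreshment on a toy standard-normal target, chosen only to show the accept/reject structure: a rejected proposal restores the initial position and negates the momentum instead of discarding it. AMAGOLD instead uses minibatch gradients and friction during the `integration_steps` leapfrog steps, which this sketch omits. The names `potential`, `ghmc_step`, `run_chain` and the momentum-persistence parameter `alpha` are hypothetical.

```python
import jax
import jax.numpy as jnp


def potential(theta):
    # Hypothetical toy target: standard normal, U(theta) = ||theta||^2 / 2.
    return 0.5 * jnp.sum(theta ** 2)


grad_potential = jax.grad(potential)


def leapfrog(theta, r, step_size, n_steps):
    # Plain leapfrog with full gradients (assumption; AMAGOLD uses minibatches).
    def body(_, state):
        theta, r = state
        r = r - 0.5 * step_size * grad_potential(theta)
        theta = theta + step_size * r
        r = r - 0.5 * step_size * grad_potential(theta)
        return theta, r

    return jax.lax.fori_loop(0, n_steps, body, (theta, r))


def ghmc_step(key, theta, r, step_size=0.1, n_steps=10, alpha=0.9):
    key_refresh, key_accept = jax.random.split(key)
    # Partial momentum refreshment; leaves N(0, I) invariant.
    noise = jax.random.normal(key_refresh, r.shape)
    r = alpha * r + jnp.sqrt(1.0 - alpha ** 2) * noise
    theta_new, r_new = leapfrog(theta, r, step_size, n_steps)
    # MH correction on the total energy H = U + |r|^2 / 2.
    h_old = potential(theta) + 0.5 * jnp.sum(r ** 2)
    h_new = potential(theta_new) + 0.5 * jnp.sum(r_new ** 2)
    accept = jnp.log(jax.random.uniform(key_accept)) < h_old - h_new
    # Skew-reversible update: accepting keeps the forward momentum,
    # rejecting restores the old position and negates the momentum.
    theta = jnp.where(accept, theta_new, theta)
    r = jnp.where(accept, r_new, -r)
    return theta, r, accept


def run_chain(key, theta0, n_iter=2000):
    def body(carry, key):
        theta, r, accept = ghmc_step(key, *carry)
        return (theta, r), (theta, accept)

    keys = jax.random.split(key, n_iter)
    init = (theta0, jnp.zeros_like(theta0))
    _, (samples, accepts) = jax.lax.scan(body, init, keys)
    return samples, accepts


samples, accepts = run_chain(jax.random.PRNGKey(0), jnp.zeros(2))
```

Because a rejection flips the momentum rather than resetting it, the chain satisfies detailed balance only up to a momentum flip, which is what "skew-reversible" refers to.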
    amagold_run = alias.amagold(...)
    sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}
    results = amagold_run(
        sample,
        init_model_state=0,
        iterations=5000
    )[0]['samples']['variables']
- Parameters:
  - stochastic_potential_fn (StochasticPotential) – Stochastic potential over a minibatch of data
  - full_potential_fn (FullPotential) – Potential over the full dataset
  - data_loader (DataLoader) – Data loader, e.g. numpy data loader
  - cache_size (int) – Number of minibatches kept in device memory
  - batch_size (int) – Number of observations per batch
  - integration_steps (int) – Number of leapfrog steps before each MH-correction step
  - friction (float) – Parameter between 0.0 and 1.0 controlling the decay of the momentum
  - first_step_size (float) – Initial step size of the polynomial and the adaptive step size schedules
  - last_step_size (float) – Final step size of the polynomial step size schedule
  - adaptive_step_size (bool) – Adapt the step size during burn-in to optimize the acceptance rate
  - stabilization_constant (int) – Larger values reduce the impact of the initial steps on the step size
  - decay_constant (float) – Larger values reduce the impact of later steps
  - speed_constant (float) – Speed of the step size adaptation
  - target_acceptance_rate (float) – Target acceptance rate of the step size adaptation
  - burn_in (int) – Number of samples to skip before collecting samples
  - accepted_samples (Optional[int]) – Total number of samples to collect. If accepted_samples < iterations - burn_in, the samples are selected by random thinning. If None, no thinning is applied.
  - save_to_numpy (bool) – Save the samples on the host in a numpy array instead of in device memory
  - progress_bar (bool) – Print the progress of the solver
- Returns:
Returns a solver function that can be applied to multiple chains starting at init_sample.
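The defaults of `stabilization_constant` (10), `decay_constant` (0.75) and `speed_constant` (0.05) coincide with the recommended constants t0, κ and γ of the dual-averaging step size adaptation from Hoffman and Gelman's No-U-Turn Sampler paper. Assuming the adaptive schedule follows that scheme, which this page does not state, one adaptation step could look like the plain-Python sketch below. The state layout, the function names and the toy acceptance model are hypothetical; this is not the jax_sgmc implementation.

```python
import math
from typing import NamedTuple


class DualAveragingState(NamedTuple):
    mu: float            # shrinkage target, log(10 * first_step_size)
    h_bar: float         # running average of (target - observed acceptance)
    log_step: float      # log step size used for the next iterations
    log_step_bar: float  # averaged log step size, kept after burn-in
    m: int               # number of adaptation steps taken so far


def init_adaptation(first_step_size):
    return DualAveragingState(
        mu=math.log(10.0 * first_step_size),
        h_bar=0.0,
        log_step=math.log(first_step_size),
        log_step_bar=0.0,
        m=0,
    )


def adapt_step_size(state, acceptance, target_acceptance_rate=0.25,
                    stabilization_constant=10, decay_constant=0.75,
                    speed_constant=0.05):
    m = state.m + 1
    # t0 (stabilization_constant) damps the influence of the first updates.
    w = 1.0 / (m + stabilization_constant)
    h_bar = (1.0 - w) * state.h_bar + w * (target_acceptance_rate - acceptance)
    # gamma (speed_constant) sets how strongly log_step reacts to h_bar.
    log_step = state.mu - math.sqrt(m) / speed_constant * h_bar
    # kappa (decay_constant): larger values give later iterates less weight.
    eta = m ** (-decay_constant)
    log_step_bar = eta * log_step + (1.0 - eta) * state.log_step_bar
    return DualAveragingState(state.mu, h_bar, log_step, log_step_bar, m)


def toy_acceptance(step_size, scale=0.01):
    # Hypothetical acceptance model: acceptance falls as the step size grows.
    return math.exp(-step_size / scale)


state = init_adaptation(first_step_size=0.001)
for _ in range(2000):
    state = adapt_step_size(state, toy_acceptance(math.exp(state.log_step)))
adapted_step_size = math.exp(state.log_step_bar)
```

The averaged step size `exp(log_step_bar)` is the one that would be kept after burn-in. In the toy model it settles near 0.01 · ln 4 ≈ 0.0139, where the modeled acceptance equals the 0.25 target.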