jax_sgmc.alias.obabo
- jax_sgmc.alias.obabo(potential_fn, data_loader, cache_size=512, batch_size=32, integration_steps=10, friction=1.0, mass=None, first_step_size=0.05, last_step_size=0.001, burn_in=0, accepted_samples=1000, save_to_numpy=True, progress_bar=True)
Langevin Monte Carlo with partial momentum refreshment, using the OBABO splitting scheme [1].
[1] https://arxiv.org/abs/2102.01691
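For illustration, one OBABO update can be sketched in JAX. This is a minimal sketch under stated assumptions, not the jax_sgmc implementation: `obabo_step` and `run_gaussian` are hypothetical names, the gradient `grad_u` is exact (jax_sgmc would estimate it from `potential_fn` on a minibatch), `mass` is a scalar or diagonal array, and the `integration_steps` leapfrog (B-A-B) steps are assumed to sit between the two momentum-refreshing O-steps. With `integration_steps=1` it reduces to the plain O-B-A-B-O sequence.

```python
import jax
import jax.numpy as jnp


def obabo_step(key, x, p, grad_u, step_size, friction=1.0, mass=1.0,
               integration_steps=1):
    """Hypothetical O (BAB)^L O update with diagonal mass, L = integration_steps.

    Assumption: grad_u is an exact gradient; a stochastic-gradient sampler
    would replace it with a minibatch estimate.
    """
    # Damping factor of the exact Ornstein-Uhlenbeck solve over half a step
    c = jnp.exp(-friction * step_size / 2)
    key_start, key_end = jax.random.split(key)

    def o_step(k, p):
        # Partial momentum refreshment; leaves N(0, mass) invariant for any c
        noise = jax.random.normal(k, p.shape)
        return c * p + jnp.sqrt(1.0 - c ** 2) * jnp.sqrt(mass) * noise

    def leapfrog(carry, _):
        x, p = carry
        p = p - 0.5 * step_size * grad_u(x)  # B: half kick
        x = x + step_size * p / mass         # A: full drift
        p = p - 0.5 * step_size * grad_u(x)  # B: half kick
        return (x, p), None

    p = o_step(key_start, p)
    (x, p), _ = jax.lax.scan(leapfrog, (x, p), None, length=integration_steps)
    p = o_step(key_end, p)
    return x, p


def run_gaussian(key, dim=2000, iterations=500, step_size=0.1):
    """Sample a standard normal; each coordinate acts as an independent chain."""
    grad_u = jax.grad(lambda x: 0.5 * jnp.sum(x ** 2))
    x, p = jnp.zeros(dim), jnp.zeros(dim)

    def body(carry, k):
        x, p = carry
        x, p = obabo_step(k, x, p, grad_u, step_size, integration_steps=3)
        return (x, p), None

    (x, _), _ = jax.lax.scan(body, (x, p), jax.random.split(key, iterations))
    return x
```

Each O-step keeps a fraction `c = exp(-friction * step_size / 2)` of the old momentum and replaces the rest with fresh noise, so a larger `friction` refreshes more of the momentum.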
Example usage:

    import jax.numpy as jnp
    from jax_sgmc import alias

    obabo_run = alias.obabo(...)
    sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}
    results = obabo_run(sample, init_model_state=0, iterations=5000)[0]['samples']['variables']
- Parameters:
potential_fn (minibatch_potential) – Stochastic potential over a minibatch of data
data_loader (DataLoader) – Data loader, e.g. a numpy data loader
cache_size (int) – Number of mini-batches kept in device memory
batch_size (int) – Number of observations per batch
integration_steps (int) – Number of leapfrog steps before the momentum is resampled
friction (Union[float, Any]) – Positive parameter controlling the amount of refreshed momentum
mass (Optional[Any]) – Diagonal mass used for the Hamiltonian dynamics
first_step_size (float) – First step size
last_step_size (float) – Final step size
burn_in (int) – Number of samples to skip before collecting samples
accepted_samples (int) – Total number of samples to collect; if accepted_samples < iterations - burn_in, the collected samples are selected by random thinning
save_to_numpy (bool) – Save samples on the host in numpy arrays instead of in device memory
progress_bar (bool) – Print the progress of the solver
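How `first_step_size`, `last_step_size`, `burn_in` and `accepted_samples` interact can be sketched as follows. The polynomial decay `a * (b + t) ** -gamma` with `gamma = 0.33` is an assumption borrowed from common SGLD practice (the parameter list does not specify the shape of the schedule), and `step_size_schedule` and `thinning_indices` are hypothetical helpers, not jax_sgmc functions.

```python
import jax
import jax.numpy as jnp


def step_size_schedule(first_step_size, last_step_size, iterations, gamma=0.33):
    """Hypothetical decay a * (b + t) ** -gamma from first to last step size.

    Assumes first_step_size > last_step_size and iterations > 1.
    """
    ratio = (first_step_size / last_step_size) ** (1.0 / gamma)
    b = (iterations - 1) / (ratio - 1.0)  # makes the last value equal last_step_size
    a = first_step_size * b ** gamma      # makes the first value equal first_step_size
    t = jnp.arange(iterations)
    return a * (b + t) ** (-gamma)


def thinning_indices(key, iterations, burn_in, accepted_samples):
    """Hypothetical random thinning: iteration indices kept after burn-in."""
    available = iterations - burn_in
    if accepted_samples >= available:
        # Fewer post-burn-in samples than requested: keep all of them
        return jnp.arange(burn_in, iterations)
    picked = jax.random.choice(key, available, shape=(accepted_samples,),
                               replace=False)
    return burn_in + jnp.sort(picked)
```

Drawing the kept indices without replacement and sorting them preserves the chain order of the collected samples.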
- Returns:
Returns a solver function which can be applied to multiple chains starting at init_sample.
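Running several chains from different initial samples can be sketched with `jax.vmap` over a stacked pytree of samples, using the `{"w", "sigma"}` sample layout from the example usage. Whether the returned solver batches chains this way is an assumption; a plain unadjusted Langevin step stands in for the OBABO kernel to keep the sketch short, and `potential`, `langevin_step` and `run_chains` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def potential(sample):
    # Hypothetical toy potential: standard normal on every entry
    return 0.5 * jnp.sum(sample["w"] ** 2) + 0.5 * sample["sigma"] ** 2


def langevin_step(sample, key, step_size=0.01):
    """Stand-in kernel for one chain: unadjusted Langevin on `potential`."""
    grads = jax.grad(potential)(sample)
    leaves, treedef = jax.tree_util.tree_flatten(sample)
    keys = jax.random.split(key, len(leaves))
    noise = jax.tree_util.tree_unflatten(
        treedef, [jax.random.normal(k, leaf.shape) for k, leaf in zip(keys, leaves)])
    return jax.tree_util.tree_map(
        lambda s, g, n: s - step_size * g + jnp.sqrt(2 * step_size) * n,
        sample, grads, noise)


def run_chains(key, init_samples, iterations=100):
    """Run one chain per initial sample and return the stacked trace."""
    n_chains = len(init_samples)
    # Stack per-chain dicts into one pytree with a leading chain axis
    batched = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *init_samples)

    def body(samples, k):
        chain_keys = jax.random.split(k, n_chains)  # independent noise per chain
        samples = jax.vmap(langevin_step)(samples, chain_keys)
        return samples, samples

    _, trace = jax.lax.scan(body, batched, jax.random.split(key, iterations))
    return trace  # leaves shaped (iterations, n_chains, ...)
```

For example, three chains started from `{"w": jnp.zeros((4, 1)), "sigma": jnp.array(2.0)}` yield a trace whose `"w"` leaf has shape `(iterations, 3, 4, 1)`.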