jax_sgmc.alias.sghmc

jax_sgmc.alias.sghmc(potential_fn, data_loader, cache_size=512, batch_size=32, integration_steps=10, friction=1.0, mass=None, first_step_size=0.05, last_step_size=0.001, burn_in=0, accepted_samples=1000, adapt_noise_model=False, diagonal_noise=True, save_to_numpy=True, progress_bar=True)

Stochastic Gradient Hamiltonian Monte Carlo.

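SGHMC improves the exploratory power of stochastic gradient Langevin dynamics (SGLD) by introducing momentum, together with a friction term that counteracts the noise of the stochastic gradients [1].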

[1] https://arxiv.org/abs/1402.4102
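For reference, the discretized dynamics proposed in [1] are written below. Here θ is the sample, r the momentum, Ũ the potential estimated on a minibatch, M the mass matrix, C the friction, ε_t the step size, and B̂ an estimate of the gradient-noise covariance. Reading the friction, mass and adapt_noise_model parameters as C, M and B̂ is our interpretation; the exact update used by jax_sgmc may differ.

```latex
\begin{aligned}
\theta_i &= \theta_{i-1} + \epsilon_t M^{-1} r_{i-1}, \\
r_i &= r_{i-1} - \epsilon_t \nabla \tilde{U}(\theta_i) - \epsilon_t C M^{-1} r_{i-1}
       + \mathcal{N}\!\left(0,\; 2\,(C - \hat{B})\,\epsilon_t\right).
\end{aligned}
```

Setting B̂ = 0 gives the simplest variant, in which the friction alone has to absorb the gradient noise.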

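import jax.numpy as jnp
from jax_sgmc import alias

sghmc_run = alias.sghmc(...)  # remaining arguments as documented under Parameters
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}  # N: number of weights, defined elsewhere
# [0] selects the results of the first (here, the only) chain
results = sghmc_run(sample, init_model_state=0, iterations=5000)[0]['samples']['variables']
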
Parameters:
  • potential_fn (minibatch_potential) – Stochastic potential over a minibatch of data

  • data_loader (DataLoader) – Data loader, e. g. numpy data loader

  • cache_size (int) – Number of mini-batches held in device memory

  • batch_size (int) – Number of observations per batch

  • integration_steps (int) – Number of leapfrog steps before resampling the momentum

  • friction (Union[float, Any]) – Friction that counteracts the noise introduced by stochastic gradients. Either a scalar applied to all variables or a separate value for each variable

  • mass (Optional[Any]) – Diagonal mass matrix for the Hamiltonian dynamics

  • first_step_size (float) – Step size at the first iteration

  • last_step_size (float) – Step size at the last iteration

  • burn_in (int) – Number of initial samples discarded as burn-in before samples are collected

  • accepted_samples (int) – Total number of samples to collect. If accepted_samples < iterations - burn_in, the collected samples are selected by random thinning

  • adapt_noise_model (bool) – Estimate the gradient noise to speed up convergence

  • diagonal_noise (bool) – Restrict the gradient-noise estimate to a diagonal covariance

  • save_to_numpy (bool) – Store samples on the host as NumPy arrays instead of in device memory

  • progress_bar (bool) – Print the progress of the solver
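The following is a minimal, self-contained sketch in plain JAX of how friction, mass, integration_steps, the step-size range, burn_in and accepted_samples interact for a single array-valued variable. It is not jax_sgmc's implementation. Assumptions: the update rule of [1] with no gradient-noise estimate (B̂ = 0), momentum resampled from N(0, M) at the start of every block of integration_steps steps, a polynomial decay from first_step_size to last_step_size, and random thinning by uniform sampling without replacement. The helpers step_size_schedule, run_sghmc, random_thinning and noisy_grad are hypothetical.

```python
import jax
import jax.numpy as jnp


def step_size_schedule(first, last, iterations, gamma=0.33):
    """Hypothetical decay eps_n = a * (b + n) ** -gamma from `first` to `last`.

    ASSUMPTION: polynomial decay; requires first > last and iterations > 1.
    a and b are chosen so that eps_0 == first and eps_{iterations-1} == last.
    """
    ratio = (first / last) ** (1.0 / gamma)  # equals (b + iterations - 1) / b
    b = (iterations - 1) / (ratio - 1.0)
    a = first * b ** gamma
    return a * (b + jnp.arange(iterations)) ** (-gamma)


def run_sghmc(key, theta0, grad_fn, iterations, integration_steps=10,
              friction=1.0, mass=1.0, first_step_size=0.05,
              last_step_size=0.001, burn_in=0):
    """Hypothetical SGHMC chain for one array variable (update rule of [1], B_hat = 0).

    friction and mass may be scalars or arrays broadcastable to theta0
    (a diagonal mass matrix).
    """
    eps_schedule = step_size_schedule(first_step_size, last_step_size, iterations)

    def body(carry, inputs):
        theta, r = carry
        i, eps, step_key = inputs
        k_resample, k_grad, k_noise = jax.random.split(step_key, 3)
        # ASSUMPTION: resample r ~ N(0, M) at the start of every integration cycle.
        fresh_r = jnp.sqrt(mass) * jax.random.normal(k_resample, theta.shape)
        r = jnp.where(i % integration_steps == 0, fresh_r, r)
        # Position update with the previous momentum.
        theta = theta + eps * r / mass
        # Momentum update: stochastic gradient, friction, and injected noise
        # with variance 2 * C * eps (no gradient-noise estimate, B_hat = 0).
        g = grad_fn(k_grad, theta)
        noise = jnp.sqrt(2.0 * friction * eps) * jax.random.normal(k_noise, theta.shape)
        r = r - eps * g - eps * friction * r / mass + noise
        return (theta, r), theta

    xs = (jnp.arange(iterations), eps_schedule, jax.random.split(key, iterations))
    _, thetas = jax.lax.scan(body, (theta0, jnp.zeros_like(theta0)), xs)
    return thetas[burn_in:]  # drop the burn-in samples


def random_thinning(key, samples, accepted_samples):
    """Hypothetical thinning: keep a uniform random subset, in chronological order."""
    n = samples.shape[0]
    if accepted_samples >= n:
        return samples
    idx = jax.random.choice(key, n, shape=(accepted_samples,), replace=False)
    return samples[jnp.sort(idx)]


def noisy_grad(key, theta):
    # Toy potential U(theta) = theta**2 / 2 (standard normal target); the added
    # Gaussian noise mimics minibatch gradient noise.
    return theta + 0.1 * jax.random.normal(key, theta.shape)


k_run, k_thin = jax.random.split(jax.random.PRNGKey(0))
samples = run_sghmc(k_run, jnp.zeros(2), noisy_grad, iterations=3000, burn_in=500)
samples = random_thinning(k_thin, samples, accepted_samples=1000)
print(samples.shape)  # (1000, 2)
```

Passing arrays of the same shape as theta0 for friction or mass gives per-coordinate friction and a diagonal mass. In jax_sgmc the gradient would presumably come from potential_fn evaluated on minibatches served by data_loader rather than from a toy function.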

Returns:

A solver function that can be applied to multiple chains, each starting at its own init_sample.