# Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Popular solvers ready to use.
While JaxSGMC has been designed to be flexible, starting with full
flexibility can be complicated. Therefore, this file contains some popular
solvers with preset properties, which can be applied directly to the problem or
used as a guide to set up a custom solver.
"""
import warnings
from functools import partial
from typing import Any, Union
from jax_sgmc import data, potential, adaption, integrator, scheduler, solver, io
Pytree = Any
[docs]def sgld(potential_fn: potential.minibatch_potential,
data_loader: data.DataLoader,
cache_size: int = 512,
batch_size: int = 32,
first_step_size: float = 0.05,
last_step_size: float = 0.001,
burn_in: int = 0,
accepted_samples: int = 1000,
rms_prop: bool = False,
alpha: float = 0.9,
lmbd: float = 1e-5,
save_to_numpy: bool = True,
progress_bar: bool = True):
"""Stochastic Gradient Langevin Dynamics.
SGLD with a polynomial step size schedule and optional speed up via RMS-prop
adaption [1].
[1] https://arxiv.org/abs/1512.07666
::
rms_run = alias.sgld(...)
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(10.0)}
results = rms_run(sample, init_model_state=0, iterations=50000)[0]['samples']['variables']
Args:
potential_fn: Stochastic potential over a minibatch of data
data_loader: Data loader, e. g. numpy data loader
cache_size: Number of mini_batches in device memory
batch_size: Number of observations per batch
first_step_size: First step size
last_step_size: Final step size
burn_in: Number of samples to skip before collecting samples
accepted_samples: Total number of samples to collect, will be determined by
random thinning if accepted samples < iterations - burn_in
rms_prop: Whether to adapt a manifold via the RMSprop strategy
alpha: Decay speed of previous manifold approximations
lmbd: Stabilization parameter
save_to_numpy: Save on host in numpy array instead of in device memory
progress_bar: Print the progress of the solver
Returns:
Returns a solver function which can be applied to multiple chains starting
at ``init_sample``. If the likelihood is stateful, an initial state must be
provided.
"""
random_data = data.random_reference_data(data_loader, cache_size, batch_size)
if rms_prop:
rms_prop = adaption.rms_prop()
else:
rms_prop = None
rms_integrator = integrator.langevin_diffusion(
potential_fn, random_data, adaption=rms_prop)
step_size_schedule = scheduler.polynomial_step_size_first_last(
first=first_step_size, last=last_step_size)
burn_in_schedule = scheduler.initial_burn_in(burn_in)
random_thinning_schedule = scheduler.random_thinning(
step_size_schedule, burn_in_schedule, selections=accepted_samples)
schedule = scheduler.init_scheduler(
step_size=step_size_schedule,
burn_in=burn_in_schedule,
thinning=random_thinning_schedule,
progress_bar=progress_bar)
if save_to_numpy:
data_collector = io.MemoryCollector()
saving = io.save(data_collector)
else:
saving = None
sgld_solver = solver.sgmc(rms_integrator)
mcmc = solver.mcmc(sgld_solver, schedule, strategy='map', saving=saving)
def run_fn(*init_samples, init_model_state: Pytree = None, iterations = 1000):
init_with_adaption_kwargs = partial(
sgld_solver[0],
adaption_kwargs={
'alpha': alpha,
'lmbd': lmbd,
},
init_model_state=init_model_state)
states = map(init_with_adaption_kwargs, init_samples)
return mcmc(*states, iterations=iterations)
return run_fn
[docs]def re_sgld(potential_fn: potential.minibatch_potential,
data_loader: data.DataLoader,
cache_size: int = 512,
batch_size: int = 32,
temperature: float = 1000.0,
first_step_size: float = 0.05,
last_step_size: float = 0.001,
burn_in: int = 0,
accepted_samples: int = 100,
save_to_numpy: bool = True,
progress_bar: bool = True):
"""Replica Exchange Stochastic Gradient Langevin Diffusion.
reSGLD simulates a tempered and a default chain in parallel, which
exchange samples at random following a (biased) markov jump process [1].
[1] https://arxiv.org/abs/2008.05367v3
::
resgld_run = alias.re_sgld(...)
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}
init_samples = [(sample, sample), (sample, sample), (sample, sample)]
results = resgld_run(
*init_samples, init_model_state=0, iterations=50000
)[0]['samples']['variables']
Args:
potential_fn: Stochastic potential over a minibatch of data
data_loader: Data loader, e. g. numpy data loader
cache_size: Number of mini_batches in device memory
batch_size: Number of observations per batch
temperature: Temperature at which the helper chain should run
first_step_size: First step size
last_step_size: Final step size
burn_in: Number of samples to skip before collecting samples
accepted_samples: Total number of samples to collect, will be determined by
random thinning if accepted samples < iterations - burn_in
save_to_numpy: Save on host in numpy array instead of in device memory
progress_bar: Print the progress of the solver
Returns:
Returns a solver function which can be applied to multiple chains starting
at ``start_chain_{idx}``.
"""
del progress_bar
random_data = data.random_reference_data(data_loader, cache_size, batch_size)
resgld_integrator = integrator.langevin_diffusion(
potential_fn, random_data)
step_size_schedule = scheduler.polynomial_step_size_first_last(
first=first_step_size, last=last_step_size)
burn_in_schedule = scheduler.initial_burn_in(burn_in)
random_thinning_schedule = scheduler.random_thinning(
step_size_schedule, burn_in_schedule, selections=accepted_samples)
temperature_schedule = scheduler.constant_temperature(1.0)
schedule = scheduler.init_scheduler(
step_size=step_size_schedule,
burn_in=burn_in_schedule,
thinning=random_thinning_schedule,
temperature=temperature_schedule,
progress_bar=False)
if save_to_numpy:
data_collector = io.MemoryCollector()
saving = io.save(data_collector)
else:
saving = None
resgld_solver = solver.parallel_tempering(resgld_integrator)
mcmc = solver.mcmc(resgld_solver, schedule, strategy='map', saving=saving)
def run_fn(*init_samples, init_model_state: Pytree = None, iterations = 1000):
init_resgld_fn = partial(
resgld_solver[0],
init_model_state=init_model_state)
states = map(init_resgld_fn, *zip(*init_samples))
return mcmc(*states,
iterations=iterations,
schedulers=[{'temperature': {'tau': temperature}}])
return run_fn
[docs]def amagold(stochastic_potential_fn: potential.StochasticPotential,
full_potential_fn: potential.FullPotential,
data_loader: data.DataLoader,
cache_size: int = 512,
batch_size: int = 32,
integration_steps: int = 10,
friction: float = 0.25,
first_step_size: float = 0.001,
last_step_size: float = 0.001,
adaptive_step_size: bool = False,
stabilization_constant: int = 10,
decay_constant: float = 0.75,
speed_constant: float = 0.05,
target_acceptance_rate: float = 0.25,
burn_in: int = 0,
accepted_samples: Union[int, None] = None,
mass: Pytree = None,
save_to_numpy: bool = True,
progress_bar: bool = True):
"""Amortized Metropolis Adjustment for Efficient Stochastic Gradient MCMC.
The AMAGOLD solver constructs a skew-reversible markov chain, such that
MH-correction steps can be applied periodically to sample from the correct
distribution [1].
[1] https://arxiv.org/abs/2003.00193
::
amagold_run = alias.amagold(...)
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}
results = amagold_run(
sample, init_model_state=0, iterations=5000
)[0]['samples']['variables']
Args:
stochastic_potential_fn: Stochastic potential over a minibatch of data
full_potential_fn: Potential from full dataset
data_loader: Data loader, e. g. numpy data loader
cache_size: Number of mini_batches in device memory
batch_size: Number of observations per batch
integration_steps: Number of leapfrog-steps before each MH-correction step
friction: Parameter between 0.0 and 1.0 controlling decay of momentum
first_step_size: First step size for polynomial and adaptive step size
schedule
last_step_size: Final step size of the polynomial step size schedule
adaptive_step_size: Adapt the step size to optimize the acceptance rate
during burn in
stabilization_constant: Larger numbers reduce the impact of the initial
steps on the step size
decay_constant: Larger values reduce impact of later steps
speed_constant: Speed of adaption of the step size
target_acceptance_rate: Target of the adption of the step sizes
burn_in: Number of samples to skip before collecting samples
accepted_samples: Total number of samples to collect, will be determined by
random thinning if accepted samples < iterations - burn_in.
If None, no thinning wil be applied.
mass: Diagonal mass for HMC-dynamics
save_to_numpy: Save on host in numpy array instead of in device memory
progress_bar: Print the progress of the solver
Returns:
Returns a solver function which can be applied to multiple chains starting
at ``init_sample``.
"""
random_data = data.random_reference_data(data_loader, cache_size, batch_size)
full_data_map = data.full_reference_data(data_loader, cache_size, batch_size)
reversible_leapfrog = integrator.reversible_leapfrog(
stochastic_potential_fn, random_data, integration_steps, friction, mass)
amagold_solver = solver.amagold(
reversible_leapfrog, full_potential_fn, full_data_map)
burn_in_schedule = scheduler.initial_burn_in(burn_in)
if adaptive_step_size:
step_size_schedule = scheduler.adaptive_step_size(
burn_in=burn_in,
initial_step_size=first_step_size,
stabilization_constant=stabilization_constant,
decay_constant=decay_constant,
speed_constant=speed_constant,
target_acceptance_rate=target_acceptance_rate)
random_thinning_schedule = None
assert accepted_samples is None, ('Thinning currently not supported for'
' adaptive step size.')
else:
step_size_schedule = scheduler.polynomial_step_size_first_last(
first=first_step_size,
last=last_step_size)
if accepted_samples is None:
random_thinning_schedule = None
else:
random_thinning_schedule = scheduler.random_thinning(
step_size_schedule, burn_in_schedule, selections=accepted_samples)
schedule = scheduler.init_scheduler(
step_size=step_size_schedule,
burn_in=burn_in_schedule,
thinning=random_thinning_schedule,
progress_bar=progress_bar)
if save_to_numpy:
data_collector = io.MemoryCollector()
saving = io.save(data_collector)
else:
saving = None
mcmc = solver.mcmc(amagold_solver, schedule, strategy='map', saving=saving)
def run_fn(*init_samples, init_model_state: Pytree = None, iterations=1000):
init_amagold_fn = partial(
amagold_solver[0],
init_model_state = init_model_state)
states = map(init_amagold_fn, init_samples)
return mcmc(*states, iterations=iterations)
return run_fn
[docs]def sggmc(stochastic_potential_fn: potential.StochasticPotential,
full_potential_fn: potential.FullPotential,
data_loader: data.DataLoader,
cache_size: int = 512,
batch_size: int = 32,
integration_steps: int = 10,
friction_coefficient: float = 1.0,
first_step_size: float = 0.001,
last_step_size: float = 0.001,
adaptive_step_size: bool = False,
stabilization_constant: int = 10,
decay_constant: float = 0.75,
speed_constant: float = 0.05,
target_acceptance_rate: float = 0.25,
burn_in: int = 0,
accepted_samples: Union[int, None] = None,
mass: Pytree = None,
save_to_numpy: bool = True,
progress_bar: bool = True):
"""Stochastic gradient guided monte carlo.
The SGGMC solver is based on the OBABO integrator, which is reversible
when using stochastic gradients. Moreover, the calculation of the full
potential is only necessary once per MH-correction step, which can be applied
after multiple iterations [1].
[1] https://arxiv.org/abs/2102.01691
::
sggmc_run = alias.sggmc(...)
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}
results = sggmc_run(
sample, init_model_state=0, iterations=5000
)[0]['samples']['variables']
Args:
stochastic_potential_fn: Stochastic potential over a minibatch of data
full_potential_fn: Potential from full dataset
data_loader: Data loader, e. g. numpy data loader
cache_size: Number of mini_batches in device memory
batch_size: Number of observations per batch
integration_steps: Number of leapfrog-steps before each MH-correction step
friction_coefficient: Positive parameter controling amount of refreshed
momentum
first_step_size: First step size for polynomial and adaptive step size
schedule
last_step_size: Final step size of the polynomial step size schedule
adaptive_step_size: Adapt the step size to optimize the acceptance rate
during burn in
stabilization_constant: Larger numbers reduce the impact of the initial
steps on the step size
decay_constant: Larger values reduce impact of later steps
speed_constant: Speed of adaption of the step size
target_acceptance_rate: Target of the adption of the step sizes
burn_in: Number of samples to skip before collecting samples
accepted_samples: Total number of samples to collect, will be determined by
random thinning if accepted samples < iterations - burn_in.
If None, no thinning wil be applied.
mass: Diagonal mass for HMC-dynamics
save_to_numpy: Save on host in numpy array instead of in device memory
progress_bar: Print the progress of the solver
Returns:
Returns a solver function which can be applied to multiple chains starting
at ``init_sample``.
"""
random_data = data.random_reference_data(data_loader, cache_size, batch_size)
full_data_map = data.full_reference_data(data_loader, cache_size, batch_size)
obabo = integrator.obabo(
stochastic_potential_fn, random_data, integration_steps,
friction_coefficient, mass)
sggmc_solver = solver.sggmc(
obabo, full_potential_fn, full_data_map)
burn_in_schedule = scheduler.initial_burn_in(burn_in)
if adaptive_step_size:
step_size_schedule = scheduler.adaptive_step_size(
burn_in=burn_in,
initial_step_size=first_step_size,
stabilization_constant=stabilization_constant,
decay_constant=decay_constant,
speed_constant=speed_constant,
target_acceptance_rate=target_acceptance_rate)
random_thinning_schedule = None
assert accepted_samples is None, ('Thinning currently not supported for'
' adaptive step size.')
else:
step_size_schedule = scheduler.polynomial_step_size_first_last(
first=first_step_size,
last=last_step_size)
if accepted_samples is None:
random_thinning_schedule = None
else:
random_thinning_schedule = scheduler.random_thinning(
step_size_schedule, burn_in_schedule, selections=accepted_samples)
schedule = scheduler.init_scheduler(
step_size=step_size_schedule,
burn_in=burn_in_schedule,
thinning=random_thinning_schedule,
progress_bar=progress_bar)
if save_to_numpy:
data_collector = io.MemoryCollector()
saving = io.save(data_collector)
else:
saving = None
mcmc = solver.mcmc(sggmc_solver, schedule, strategy='map', saving=saving)
def run_fn(*init_samples, init_model_state: Pytree = None, iterations=1000):
init_sggmc_fn = partial(sggmc_solver[0], init_model_state=init_model_state)
states = map(init_sggmc_fn, init_samples)
return mcmc(*states, iterations=iterations)
return run_fn
[docs]def sghmc(potential_fn: potential.minibatch_potential,
data_loader: data.DataLoader,
cache_size: int = 512,
batch_size: int = 32,
integration_steps: int = 10,
friction: Union[float, Pytree] = 1.0,
mass: Pytree = None,
first_step_size: float = 0.05,
last_step_size: float = 0.001,
burn_in: int = 0,
accepted_samples: int = 1000,
adapt_noise_model: bool = False,
diagonal_noise: bool = True,
save_to_numpy: bool = True,
progress_bar: bool = True):
"""Stochastic Gradient Hamiltonian Monte Carlo.
SGHMC improves the exploratory power of SGLD by introducing momentum [1].
[1] https://arxiv.org/abs/1402.4102
::
sghmc_run = alias.sghmc(...)
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}
results = sghmc_run(sample, init_model_state=0, iterations=5000)[0]['samples']['variables']
Args:
potential_fn: Stochastic potential over a minibatch of data
data_loader: Data loader, e. g. numpy data loader
cache_size: Number of mini_batches in device memory
batch_size: Number of observations per batch
integration_steps: Number of leapfrog steps before resampling the momentum
friction: Friction to counteract noise introduced by stochastic gradients.
Can be specified for each variable or for all variables (scalar value)
mass: Diagonal mass to be used for hamiltonian dynamics
first_step_size: First step size
last_step_size: Final step size
burn_in: Number of samples to skip before collecting samples
accepted_samples: Total number of samples to collect, will be determined by
random thinning if accepted samples < iterations - burn_in
adapt_noise_model: Estimate the gradient noise to speed up the convergence.
diagonal_noise: Restrict the noise estimate to be diagonal.
save_to_numpy: Save on host in numpy array instead of in device memory
progress_bar: Print the progress of the solver
Returns:
Returns a solver function which can be applied to multiple chains starting
at ``init_sample``.
"""
random_data = data.random_reference_data(data_loader, cache_size, batch_size)
if adapt_noise_model:
noise_model = adaption.fisher_information(minibatch_potential=potential_fn,
diagonal=diagonal_noise)
else:
noise_model = None
friction_leapfrog = integrator.friction_leapfrog(
potential_fn, random_data, friction=friction, const_mass=mass,
steps=integration_steps, noise_model=noise_model)
step_size_schedule = scheduler.polynomial_step_size_first_last(
first=first_step_size, last=last_step_size)
burn_in_schedule = scheduler.initial_burn_in(burn_in)
random_thinning_schedule = scheduler.random_thinning(
step_size_schedule, burn_in_schedule, selections=accepted_samples)
schedule = scheduler.init_scheduler(
step_size=step_size_schedule,
burn_in=burn_in_schedule,
thinning=random_thinning_schedule,
progress_bar=progress_bar)
if save_to_numpy:
data_collector = io.MemoryCollector()
saving = io.save(data_collector)
else:
saving = None
sghmc_solver = solver.sgmc(friction_leapfrog)
mcmc = solver.mcmc(sghmc_solver, schedule, strategy='map', saving=saving)
def run_fn(*init_samples, init_model_state: Pytree = None, iterations = 1000):
init_sghmc_fn = partial(sghmc_solver[0], init_model_state=init_model_state)
states = map(init_sghmc_fn, init_samples)
return mcmc(*states, iterations=iterations)
return run_fn
[docs]def obabo(potential_fn: potential.minibatch_potential,
data_loader: data.DataLoader,
cache_size: int = 512,
batch_size: int = 32,
integration_steps: int = 10,
friction: Union[float, Pytree] = 1.0,
mass: Pytree = None,
first_step_size: float = 0.05,
last_step_size: float = 0.001,
burn_in: int = 0,
accepted_samples: int = 1000,
save_to_numpy: bool = True,
progress_bar: bool = True):
"""Langevin Monte Carlo with partial momentum refreshment.
[1] https://arxiv.org/abs/2102.01691
::
sghmc_run = alias.obabo(...)
sample = {"w": jnp.zeros((N, 1)), "sigma": jnp.array(2.0)}
results = sghmc_run(sample, init_model_state=0, iterations=5000)[0]['samples']['variables']
Args:
potential_fn: Stochastic potential over a minibatch of data
data_loader: Data loader, e. g. numpy data loader
cache_size: Number of mini_batches in device memory
batch_size: Number of observations per batch
integration_steps: Number of leapfrog steps before resampling the momentum
friction: Positive parameter controling amount of refreshed momentum
mass: Diagonal mass to be used for hamiltonian dynamics
first_step_size: First step size
last_step_size: Final step size
burn_in: Number of samples to skip before collecting samples
accepted_samples: Total number of samples to collect, will be determined by
random thinning if accepted samples < iterations - burn_in
save_to_numpy: Save on host in numpy array instead of in device memory
progress_bar: Print the progress of the solver
Returns:
Returns a solver function which can be applied to multiple chains starting
at ``init_sample``.
"""
random_data = data.random_reference_data(data_loader, cache_size, batch_size)
obabo_integrator = integrator.obabo(
potential_fn=potential_fn, batch_fn=random_data, steps=integration_steps,
friction=friction, const_mass=mass
)
step_size_schedule = scheduler.polynomial_step_size_first_last(
first=first_step_size, last=last_step_size)
burn_in_schedule = scheduler.initial_burn_in(burn_in)
random_thinning_schedule = scheduler.random_thinning(
step_size_schedule, burn_in_schedule, selections=accepted_samples)
schedule = scheduler.init_scheduler(
step_size=step_size_schedule,
burn_in=burn_in_schedule,
thinning=random_thinning_schedule,
progress_bar=progress_bar)
if save_to_numpy:
data_collector = io.MemoryCollector()
saving = io.save(data_collector)
else:
saving = None
sghmc_solver = solver.sgmc(obabo_integrator)
mcmc = solver.mcmc(sghmc_solver, schedule, strategy='map', saving=saving)
def run_fn(*init_samples, init_model_state: Pytree = None, iterations = 1000):
init_sghmc_fn = partial(sghmc_solver[0], init_model_state=init_model_state)
states = map(init_sghmc_fn, init_samples)
return mcmc(*states, iterations=iterations)
return run_fn