Setup Schedulers

A scheduler combines specific schedulers, each of which controls a single parameter, for example the step size. The function jax_sgmc.scheduler.init_scheduler() combines the specific schedulers for different parameters into a basic scheduler. The basic scheduler updates all specific schedulers and provides default values for parameters that have no specific scheduler.
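
The idea of combining specific schedules with defaults can be sketched in plain Python. This is a minimal, hypothetical illustration and not the jax_sgmc implementation: `combine_schedules`, `DEFAULTS` and the harmonic step size schedule are made up for this sketch. The temperature default of 1.0 matches the example in this section, but the step size default is a placeholder.

```python
from typing import Callable, Dict

# A specific schedule maps the iteration index n to the value of one parameter.
Schedule = Callable[[int], float]

# Hypothetical defaults: temperature = 1.0 as in this section's example; the
# step size default is a placeholder, not jax_sgmc's actual value.
DEFAULTS: Dict[str, float] = {"step_size": 0.001, "temperature": 1.0}


def combine_schedules(specific: Dict[str, Schedule],
                      defaults: Dict[str, float] = DEFAULTS
                      ) -> Callable[[int], Dict[str, float]]:
    """Combine specific schedules into one function returning all parameters."""
    unknown = set(specific) - set(defaults)
    if unknown:
        raise ValueError(f"Unknown parameters: {sorted(unknown)}")

    def params_at(n: int) -> Dict[str, float]:
        # Evaluate the specific schedule if one exists, else use the default.
        return {name: specific[name](n) if name in specific else default
                for name, default in defaults.items()}

    return params_at


# Only the step size gets a specific (here: harmonic) schedule.
params = combine_schedules({"step_size": lambda n: 0.1 / (n + 1)})
print(params(0))  # {'step_size': 0.1, 'temperature': 1.0}
print(params(1))  # {'step_size': 0.05, 'temperature': 1.0}
```

Parameters without a specific schedule simply stay at their default for every iteration.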

Specific Schedulers

>>> from jax_sgmc import scheduler
>>>
>>> step_size_schedule_unused = scheduler.polynomial_step_size(
...   a=0.1, b=1.0, gamma=0.33)

Here, all required arguments are already provided. However, it is also possible to provide only the arguments that should stay the same across all chains. For example, we can omit gamma here and specify a different value of gamma for each chain when initializing a schedule with the basic scheduler:

>>> step_size_schedule_partial = scheduler.polynomial_step_size(
...   a=0.1, b=1.0)
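
Conceptually, a partially initialized schedule behaves like a function with some arguments fixed in advance. The sketch below uses `functools.partial` as an analogy; it assumes the polynomial form a * (b + n) ** (-gamma) for the step size and is not how jax_sgmc implements `polynomial_step_size`.

```python
import functools


def polynomial_step_size(n: int, *, a: float, b: float, gamma: float) -> float:
    """Assumed polynomial decay of the step size: a * (b + n) ** (-gamma)."""
    return a * (b + n) ** (-gamma)


# Fix the arguments shared by all chains; gamma is still missing.
partial_schedule = functools.partial(polynomial_step_size, a=0.1, b=1.0)

try:
    partial_schedule(0)
except TypeError as err:
    # The incomplete schedule cannot be evaluated until gamma is supplied.
    print("incomplete:", err)

# Each chain completes the schedule with its own gamma.
chain_a = functools.partial(partial_schedule, gamma=0.1)
chain_b = functools.partial(partial_schedule, gamma=1.0)
print(chain_a(0), chain_b(0))  # 0.1 0.1
```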

Basic Scheduler

It is not necessary to set up a scheduler for every parameter, because the basic scheduler provides default values. Therefore, we can initialize the basic scheduler with only the partially initialized step size schedule from above:

>>> init_fn, next_fn, get_fn = scheduler.init_scheduler(
...   step_size=step_size_schedule_partial, progress_bar=False)

After creating the basic scheduler, we can initialize a schedule. The first argument is the number of iterations. In addition, we have to provide the missing values for the partially initialized schedulers, here a different gamma for each chain:

>>> sched_a, static_information = init_fn(10, step_size={'gamma': 0.1})
>>> sched_b, _ = init_fn(10, step_size={'gamma': 1.0})
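
The triple `init_fn`, `next_fn`, `get_fn` follows a functional pattern: the schedule state is a value, `next_fn` returns a new state instead of modifying the old one, and `get_fn` evaluates the parameters for a given state. The following is a hypothetical stand-in that only mimics this pattern; unlike the real `init_fn`, it takes no iteration count and returns no static information, and it assumes the polynomial step size form a * (b + n) ** (-gamma).

```python
from typing import NamedTuple


class ScheduleState(NamedTuple):
    """Hypothetical scheduler state: an iteration counter and this chain's gamma."""
    iteration: int
    gamma: float


class Params(NamedTuple):
    step_size: float
    temperature: float


A, B = 0.1, 1.0  # step size arguments shared by all chains


def init_fn(gamma: float) -> ScheduleState:
    # Start a new schedule at iteration 0 with the chain-specific gamma.
    return ScheduleState(iteration=0, gamma=gamma)


def next_fn(state: ScheduleState) -> ScheduleState:
    # Return a new state instead of mutating the old one.
    return state._replace(iteration=state.iteration + 1)


def get_fn(state: ScheduleState) -> Params:
    # Evaluate all parameters at the current iteration; temperature uses its default.
    step_size = A * (B + state.iteration) ** (-state.gamma)
    return Params(step_size=step_size, temperature=1.0)


sched_a = init_fn(gamma=0.1)
sched_b = init_fn(gamma=1.0)
new_sched_a = next_fn(sched_a)
print(sched_a.iteration, new_sched_a.iteration)  # 0 1
print(get_fn(sched_a) == get_fn(sched_b))  # True
```

Because each chain owns its own state, advancing `sched_a` leaves `sched_b` untouched, and the old state remains valid after `next_fn`.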

In addition to the scheduler state, init_fn returns static information, e.g., the total number of iterations or the expected number of collected samples. The io module needs this information, for example, to allocate sufficient memory for the samples to be saved.

>>> print(static_information)
static_information(samples_collected=10)
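
Knowing the number of collected samples in advance allows the storage to be allocated once, with a fixed size. This matters in JAX in particular, because arrays inside `jax.jit` must have shapes that are known at trace time. The sketch below is hypothetical: `SampleBuffer` is not part of jax_sgmc, and `StaticInformation` is only a stand-in that mirrors the printed field.

```python
from typing import List, NamedTuple


class StaticInformation(NamedTuple):
    # Stand-in mirroring the printed field; not the jax_sgmc class.
    samples_collected: int


class SampleBuffer:
    """Hypothetical fixed-size sample store, sized from the static information."""

    def __init__(self, info: StaticInformation, dim: int):
        # Allocate every slot up front, so the buffer never grows during sampling.
        self.dim = dim
        self.samples: List[List[float]] = [[0.0] * dim
                                           for _ in range(info.samples_collected)]
        self.count = 0

    def save(self, sample: List[float]) -> None:
        if len(sample) != self.dim:
            raise ValueError(f"Expected {self.dim} values, got {len(sample)}")
        if self.count >= len(self.samples):
            raise IndexError("More samples than announced by the static information")
        # Write into the preallocated slot instead of appending.
        self.samples[self.count][:] = sample
        self.count += 1


buffer = SampleBuffer(StaticInformation(samples_collected=10), dim=2)
for i in range(10):
    buffer.save([float(i), -float(i)])
print(buffer.count, buffer.samples[-1])  # 10 [9.0, -9.0]
```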

In this example, the temperature parameter has been assigned the default value of 1.0, and the two step size schedules are updated with different gamma parameters. Both start at a step size of 0.10 but decay at different rates:

>>> curr_sched_a = get_fn(sched_a)
>>> curr_sched_b = get_fn(sched_b)
>>> print(f"Scheduler a\n===========\n"
...       f"  Step-Size = {curr_sched_a.step_size : .2f}\n"
...       f"  Temperature = {curr_sched_a.temperature : .2f}")
Scheduler a
===========
  Step-Size =  0.10
  Temperature =  1.00
>>> print(f"Scheduler b\n===========\n"
...       f"  Step-Size = {curr_sched_b.step_size : .2f}\n"
...       f"  Temperature = {curr_sched_b.temperature : .2f}")
Scheduler b
===========
  Step-Size =  0.10
  Temperature =  1.00
>>> # Get the parameters at the next iteration
>>> sched_a = next_fn(sched_a)
>>> sched_b = next_fn(sched_b)
>>> curr_sched_a = get_fn(sched_a)
>>> curr_sched_b = get_fn(sched_b)
>>> print(f"Scheduler a\n===========\n"
...       f"  Step-Size = {curr_sched_a.step_size : .2f}\n"
...       f"  Temperature = {curr_sched_a.temperature : .2f}")
Scheduler a
===========
  Step-Size =  0.09
  Temperature =  1.00
>>> print(f"Scheduler b\n===========\n"
...       f"  Step-Size = {curr_sched_b.step_size : .2f}\n"
...       f"  Temperature = {curr_sched_b.temperature : .2f}")
Scheduler b
===========
  Step-Size =  0.05
  Temperature =  1.00
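
The printed step sizes are consistent with a polynomial schedule of the form a * (b + n) ** (-gamma), which is assumed here rather than taken from the library source. With a = 0.1 and b = 1.0, iteration n = 1 gives 0.1 * 2 ** (-0.1) ≈ 0.093 for gamma = 0.1 and 0.1 * 2 ** (-1) = 0.05 for gamma = 1.0. The following sketch reproduces the values with the same format specification:

```python
def polynomial_step_size(n: int, *, a: float = 0.1, b: float = 1.0,
                         gamma: float) -> float:
    """Assumed form of the schedule: a * (b + n) ** (-gamma)."""
    return a * (b + n) ** (-gamma)


# Reproduce the step sizes printed above for iterations n = 0 and n = 1.
for label, gamma in [("a", 0.1), ("b", 1.0)]:
    values = [f"{polynomial_step_size(n, gamma=gamma): .2f}" for n in range(2)]
    print(f"Scheduler {label}: {values}")
# Scheduler a: [' 0.10', ' 0.09']
# Scheduler b: [' 0.10', ' 0.05']
```

A larger gamma therefore makes the step size decay faster, which is why scheduler b drops to half its initial step size after a single iteration.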