# Source code for jax_sgmc.scheduler

# Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Schedules parameters of the integrator and solver.

The scheduler organizes the independent variables of the update equation, such
as the temperature and the step size, which are organized by multiple specific
schedulers.

"""

# Todo: Correct typing

from collections import namedtuple
from typing import Callable, Tuple, Union

import jax.numpy as jnp
from jax import lax
from jax import random
from jax.experimental import host_callback as hcb

from jax_sgmc.util import Array

specific_scheduler = namedtuple("specific_scheduler",
                                ["init",
                                 "update",
                                 "get"])
"""Bundles the three functions which make up a specific scheduler.

Attributes:
  init: Function to initialize the state of the specific scheduler
  update: Function to calculate the next state of the specific scheduler
  get: Function to get the scheduled parameter from the current state

"""

schedule = namedtuple(
  "schedule",
  ("step_size", "temperature", "burn_in", "accept"))
"""Auxiliary variables handed to the integrator in each step.

Attributes:
  step_size: Learning rate
  temperature: Scale of the magnitude of the additional noise
  burn_in: Marks whether the current step is no longer subject to burn in
  accept: Marks whether the current sample should be saved

"""

scheduler_state = namedtuple(
  "scheduler_state",
  ("state",
   "step_size_state",
   "temperature_state",
   "burn_in_state",
   "thinning_state",
   "progress_bar_state"))
"""Container for the states of all specific schedulers.

Attributes:
  state: State of the base scheduler, e.g. to keep track of current iteration
  step_size_state: State of the step size scheduler
  temperature_state: State of the temperature scheduler
  burn_in_state: State of the burn in scheduler
  thinning_state: State of the thinning scheduler
  progress_bar_state: State of the progress bar
"""

static_information = namedtuple(
  "static_information",
  ("samples_collected",))
"""Information which does not change during the run.

Attributes:
  samples_collected: Number of samples saved.
"""

# The scheduler combines the specific scheduler. This makes it easier to
# implement only rarely used auxiliary variables by providing default values.
# The update functions are collected at a central state.

def init_scheduler(step_size: specific_scheduler = None,
                   temperature: specific_scheduler = None,
                   burn_in: specific_scheduler = None,
                   thinning: specific_scheduler = None,
                   progress_bar: bool = True,
                   progress_bar_steps: Array = 20
                   ) -> Tuple[Callable, Callable, Callable]:
  """Combines the specific schedulers into a single scheduler.

  The scheduler initializes, updates and queries every specific schedule at a
  central place. Schedules which are not provided are replaced by defaults.

  Args:
    step_size: Triplet from step-size scheduler initialization. Defaults to
      ``polynomial_step_size(a=1, b=1, gamma=0.0)``.
    temperature: Triplet from temperature scheduler initialization. Defaults
      to ``constant_temperature(tau=1.0)``.
    burn_in: Triplet from burn-in scheduler initialization. Defaults to
      ``initial_burn_in(n=0)``.
    thinning: Triplet from thinning scheduler initialization. By default,
      every sample is accepted.
    progress_bar: Show the percentage of completed steps
    progress_bar_steps: Default value forwarded as ``steps`` to the progress
      bar. It can be overridden by passing ``progress_bar_steps`` to the
      returned ``init_fn``.

  Returns:
    Returns a triplet of ``(init_fn, update_fn, get_fn)``.

  """
  # Replace every missing specific scheduler by its default
  if step_size is None:
    step_size = polynomial_step_size(a=1, b=1, gamma=0.0)
  if temperature is None:
    temperature = constant_temperature(tau=1.0)
  if burn_in is None:
    burn_in = initial_burn_in(n=0)
  if thinning is None:
    # Stateless thinning: all samples are accepted and all samples are saved
    thinning = specific_scheduler(
      lambda iterations: (None, iterations),
      lambda *args, **kwargs: None,
      lambda *args, **kwargs: True)

  if progress_bar:
    init_progress_bar, update_progress_bar = _progress_bar(burn_in, thinning)

  def init_fn(iterations: int,
              **scheduler_kwargs
              ) -> Tuple[scheduler_state, static_information]:
    """Initializes all specific schedulers for a run of ``iterations`` steps.

    Keyword arguments named ``thinning``, ``burn_in``, ``step_size`` and
    ``temperature`` are dicts forwarded to the matching ``init`` function.
    """

    def options_for(name):
      return scheduler_kwargs.get(name, {})

    thinning_state, total_samples = thinning.init(
      iterations, **options_for('thinning'))
    burn_in_state, collected_samples = burn_in.init(
      iterations, **options_for('burn_in'))

    # Without thinning, all samples not subject to burn in are collected
    total_samples = min(total_samples, collected_samples)

    progress_bar_state = None
    if progress_bar:
      pg_steps = scheduler_kwargs.get("progress_bar_steps",
                                      progress_bar_steps)
      pg_enabled = scheduler_kwargs.get("enabled", jnp.array(progress_bar))
      progress_bar_state = init_progress_bar(
        jnp.array(iterations), total_samples, pg_steps, pg_enabled)

    step_size_state = step_size.init(iterations, **options_for('step_size'))
    temperature_state = temperature.init(
      iterations, **options_for('temperature'))

    init_state = scheduler_state(
      state=(0, iterations),  # The run starts with iteration 0
      step_size_state=step_size_state,
      temperature_state=temperature_state,
      burn_in_state=burn_in_state,
      thinning_state=thinning_state,
      progress_bar_state=progress_bar_state)

    return init_state, static_information(samples_collected=total_samples)

  def update_fn(state: scheduler_state, **kwargs) -> scheduler_state:
    """Advances every specific scheduler and the iteration counter by one."""
    # The specific schedulers see the counter before it is incremented
    iteration, total_iterations = state.state

    new_step_size = step_size.update(
      state.step_size_state, iteration, **kwargs)
    new_temperature = temperature.update(
      state.temperature_state, iteration, **kwargs)
    new_burn_in = burn_in.update(
      state.burn_in_state, iteration, **kwargs)
    new_thinning = thinning.update(
      state.thinning_state, iteration, **kwargs)

    new_progress_bar = None
    if progress_bar:
      # The updated burn in and thinning states are required to count the
      # number of collected samples
      new_progress_bar = update_progress_bar(
        state.progress_bar_state,
        iteration,
        new_burn_in,
        new_thinning,
        **kwargs)

    return scheduler_state(
      state=(iteration + 1, total_iterations),
      step_size_state=new_step_size,
      temperature_state=new_temperature,
      burn_in_state=new_burn_in,
      thinning_state=new_thinning,
      progress_bar_state=new_progress_bar)

  def get_fn(state: scheduler_state, **kwargs) -> schedule:
    """Queries every specific scheduler for the current iteration."""
    iteration, _ = state.state

    def query(scheduler, scheduler_state_):
      return jnp.array(scheduler.get(scheduler_state_, iteration, **kwargs))

    # Query in a fixed order: step size, temperature, burn in, thinning
    current_step_size = query(step_size, state.step_size_state)
    current_temperature = query(temperature, state.temperature_state)
    current_burn_in = query(burn_in, state.burn_in_state)
    current_thinning = query(thinning, state.thinning_state)

    return schedule(
      step_size=current_step_size,
      temperature=current_temperature,
      burn_in=current_burn_in,
      accept=current_thinning)

  return init_fn, update_fn, get_fn
################################################################################
#
# Temperature
#
################################################################################
def constant_temperature(tau: Array = 1.0) -> specific_scheduler:
  """Keeps the scale of the added noise at a fixed value.

  Args:
    tau: Scale of the added noise

  Returns:
    Returns a triplet ``(init, update, get)`` bundled as specific scheduler.

  """

  def init_fn(iterations: int, tau: Array = tau) -> Array:
    # The temperature itself is the state, the run length is irrelevant
    del iterations
    return tau

  def _passthrough(state: Array, iteration: int, **kwargs) -> Array:
    # Updating and getting both simply hand back the stored temperature
    del iteration, kwargs
    return state

  return specific_scheduler(init=init_fn,
                            update=_passthrough,
                            get=_passthrough)
def cyclic_temperature(beta: Array=1.0, k: int=1) -> specific_scheduler:
  """Cyclic switch of the temperature between 0.0 and 1.0.

  Switches temperature from 0.0 (SGD) to 1.0 (SGLD) when ratio of initial step
  size and current step size drops below beta. This scheduler is intended to
  be used with the cyclic step size scheduler.

  Note:
    This scheduler is a placeholder and not implemented yet.

  Args:
    beta: Ratio of current step size to initial step size when transition to
      SGLD
    k: Number of cycles

  Returns:
    Returns a triplet as specific scheduler (once implemented).

  Raises:
    NotImplementedError: Always, as the scheduler is not implemented yet.

  """
  raise NotImplementedError(
    "The cyclic temperature scheduler is not implemented yet.")
################################################################################
#
# Progress bar
#
################################################################################

def _progress_bar(burn_in: specific_scheduler, thinning: specific_scheduler):
  """Prints the progress of the solver.

  Args:
    burn_in: Burn in scheduler to count accepted samples
    thinning: Thinning scheduler to count accepted samples

  Returns:
    Returns a tuple ``(init_fn, step_fn)`` to initialize and to update the
    state of the progress bar.

  """

  def _print_fn(info, _):
    # Runs on the host, the values of info are converted to python integers
    percentage = round(
      int(info['current_iteration']) / int(info['total_iterations']) * 100)
    total_samples = int(info["total_samples"])
    collected_samples = int(info["collected_samples"])
    current_iteration = int(info["current_iteration"])
    total_iterations = int(info["total_iterations"])

    print(f"[Step {current_iteration}/{total_iterations}]"
          f"({percentage}%) Collected {collected_samples} of "
          f"{total_samples} samples...")

  def init_fn(iterations: Array,
              num_samples: Array,
              steps: Array = jnp.array(20),
              enabled: Array = jnp.array(True)
              ) -> Tuple[Array, Array, Array, Array, Array]:
    # Set already collected samples to zero
    init_state = iterations, num_samples, jnp.zeros(1), steps, enabled
    return init_state

  def step_fn(state: Tuple[Array, Array, Array, Array, Array],
              iteration: Array,
              burn_in_state,
              thinning_state,
              **kwargs
              ):
    iterations, tot_samples, collected_samples, steps, enabled = state

    # A sample is going to be saved if it is not subject to burn in and accepted
    sample_burn_in = burn_in.get(burn_in_state, iteration, **kwargs)
    sample_accepted = thinning.get(thinning_state, iteration, **kwargs)
    saved = sample_burn_in * sample_accepted
    collected_samples += saved

    info = {
      "total_iterations": iterations,
      "current_iteration": iteration,
      "total_samples": tot_samples,
      "collected_samples": collected_samples,
      "kwargs": kwargs
    }

    # Calculate number of steps until the progress should be printed out. The
    # interval is at least one iteration, otherwise runs with fewer iterations
    # than steps would compute the modulo with a divisor of zero.
    num_its = jnp.maximum(jnp.int_(jnp.floor(iterations / steps)), 1)

    # Return the number of collected samples as result of id_tap
    collected_samples = lax.cond(
      jnp.logical_and(jnp.mod(iteration, num_its) == 0, enabled),
      lambda arg: hcb.id_tap(_print_fn, arg, result=collected_samples),
      lambda arg: info["collected_samples"],
      info
    )

    new_state = iterations, tot_samples, collected_samples, steps, enabled
    return new_state

  return init_fn, step_fn

################################################################################
#
# Step Size
#
################################################################################
def adaptive_step_size(burn_in = 0,
                       initial_step_size = 0.05,
                       stabilization_constant = 100,
                       decay_constant = 0.75,
                       speed_constant = 0.05,
                       target_acceptance_rate=0.02):
  """Dual averaging scheme to tune step size for schemes with MH-step.

  The adaptive step size uses the dual averaging scheme to optimize the
  acceptance rate, as proposed by [1].

  [1] https://arxiv.org/abs/1111.4246

  All arguments can be overridden by passing them as keyword arguments to the
  ``init`` function of the returned scheduler.

  Args:
    burn_in: Initial iterations, in which the step size should be tuned
    initial_step_size: Initial value of the step size
    stabilization_constant: Bigger constant stabilizes adaption against
      initial iterations
    decay_constant: Controls decay of learning rate of the step size
    speed_constant: Weights acceptance ratio statistics
    target_acceptance_rate: Desired acceptance rate

  Returns:
    Returns a specific step size scheduler.

  """

  def init(iterations: int,
           burn_in=burn_in,
           initial_step_size=initial_step_size,
           stabilization_constant=stabilization_constant,
           decay_constant=decay_constant,
           speed_constant=speed_constant,
           target_acceptance_rate=target_acceptance_rate):
    del iterations
    # The averaged iterate is stored as logarithm of the step size
    x_bar = jnp.log(initial_step_size)
    h_bar = 0.0
    init_state = (
      burn_in,
      x_bar,
      h_bar,
      target_acceptance_rate,
      stabilization_constant,
      decay_constant,
      speed_constant,
      jnp.log(10 * initial_step_size))
    return init_state

  def update(state: Array,
             iteration: int,
             acceptance_ratio=0.0,
             **kwargs):
    del kwargs
    burn_in, x_bar, h_bar, alpha, t0, kappa, gamma, mu = state
    m = iteration + 1

    # Running average of the deviation from the target acceptance rate. The
    # target is read from the state (alpha), such that a value passed to init
    # is respected like all other hyperparameters.
    h_bar *= (1 - 1/(m + t0))
    h_bar += 1/(m + t0) * (alpha - acceptance_ratio)

    # Proposed log step size and its weighted average
    x = mu - jnp.sqrt(m) / gamma * h_bar
    lr = jnp.power(m, -kappa)

    x_bar_old = x_bar
    x_bar *= (1 - lr)
    x_bar += lr * x

    # Only update during burn in
    x_bar = jnp.where(iteration < burn_in, x_bar, x_bar_old)

    return burn_in, x_bar, h_bar, alpha, t0, kappa, gamma, mu

  def get(state: Array, iteration: int, **kwargs):
    del iteration, kwargs
    # Transform the averaged log step size back to the step size
    return jnp.exp(state[1])

  return specific_scheduler(init, update, get)
def polynomial_step_size(a: Array = 1.0,
                         b: Array = 1.0,
                         gamma: Array = 0.33
                         ) -> specific_scheduler:
  r"""Polynomial decreasing step size schedule.

  Implements the original proposal of a polynomial step size schedule
  :math:`\epsilon = a(b + n)^{-\gamma}`, where :math:`n` is the iteration.

  The arguments can be overridden by passing them as keyword arguments to the
  ``init`` function of the returned scheduler.

  Args:
    a: Scale of all step sizes
    b: Stabilization constant
    gamma: Decay constant

  Returns:
    Returns a triplet ``(init, update, get)`` bundled as specific scheduler.

  """

  # The internal state is just an array, which holds the step size for all the
  # iterations

  def init_fn(iterations: int,
              a: Array = a,
              b: Array = b,
              gamma: Array = gamma
              ) -> Array:
    # gamma == 0 is allowed and yields a constant step size
    assert gamma >= 0, f"Gamma must be non-negative: gamma = {gamma}"
    assert a > 0, f"a must be positive: a = {a}"
    assert b > 0, f"b must be greater than zero: b = {b}"

    n = jnp.arange(iterations)
    unscaled = jnp.power(b + n, -gamma)
    scaled = jnp.multiply(a, unscaled)
    return scaled

  def update_fn(state: Array, iteration: int, **kwargs) -> Array:
    # All step sizes are calculated in advance, nothing to update
    del iteration, kwargs
    return state

  def get_fn(state: Array, iteration: int, **kwargs) -> Array:
    del kwargs
    return state[iteration]

  return specific_scheduler(init_fn, update_fn, get_fn)
def polynomial_step_size_first_last(first: Union[float, Array] = 1.0,
                                    last: Union[float, Array] = 1.0,
                                    gamma: Union[float, Array] = 0.33
                                    ) -> specific_scheduler:
  """Initializes polynomial step size schedule via first and last step.

  If the first and the last step size are equal, or if the schedule consists of
  at most a single iteration, a constant schedule with the value of ``first``
  is returned.

  Args:
    first: Step size in the first iteration
    last: Step size in the last iteration
    gamma: Rate of decay

  Returns:
    Returns a polynomial step size schedule defined via the first and the last
    step size.

  """

  # Calculates the required coefficients of the polynomial
  def find_ab(its, gamma, first, last):
    ginv = jnp.power(gamma, -1.0)
    fpow = jnp.power(first, -ginv) # pylint: disable=E1130
    lpow = jnp.power(last, -ginv) # pylint: disable=E1130
    apow = jnp.divide(lpow - fpow, its - 1)
    a = jnp.power(apow, -gamma)
    b = jnp.power(jnp.divide(first, a), -ginv) # pylint: disable=E1130
    return a, b

  def init_fn(iterations: int,
              first: Union[float, Array] = first,
              last: Union[float, Array] = last,
              gamma: Union[float, Array] = gamma
              ) -> Array:
    # Check for valid parameters
    assert gamma > 0, f"Gamma must be bigger than 0, is {gamma}"
    assert first >= last, f"The first step size must be larger than the last:" \
                          f" {first} !>= {last}"

    # The coefficients are not finite for a constant schedule (first == last)
    # and are undefined for a single iteration (division by its - 1 == 0), so
    # these cases are handled directly.
    if iterations <= 1 or first == last:
      return jnp.multiply(first, jnp.ones(iterations))

    a, b = find_ab(iterations, gamma, first, last)
    polynomial_init, _, _ = polynomial_step_size(a=a, b=b, gamma=gamma)
    return polynomial_init(iterations)

  def update_fn(state: Array, iteration: int, **kwargs) -> Array:
    # All step sizes are calculated in advance, nothing to update
    del iteration, kwargs
    return state

  def get_fn(state: Array, iteration: int, **kwargs) -> Array:
    del kwargs
    return state[iteration]

  return specific_scheduler(init_fn, update_fn, get_fn)
################################################################################
#
# Burn In
#
################################################################################

# Burn in: Return 1.0 if the sample should be accepted and 0.0 otherwise
def cyclic_burn_in(beta: Array=1.0, k:int=1):
  """Discards samples at the beginning of each cycle.

  Note:
    This scheduler is a placeholder and not implemented yet.

  Args:
    beta: Ratio of current and initial step size up to which burn in should be
      applied
    k: Number of cycles

  Returns:
    Returns a burn in schedule, which applies burn in to the beginning of each
    cycle (once implemented).

  Raises:
    NotImplementedError: Always, as the scheduler is not implemented yet.

  """
  raise NotImplementedError(
    "The cyclic burn in scheduler is not implemented yet.")
def initial_burn_in(n: Array = 0) -> specific_scheduler:
  """Applies burn in to the first n steps of the run.

  Args:
    n: Count of initial steps which should be discarded

  Returns:
    Returns a specific scheduler, for which ``get`` evaluates to 0.0 while the
    iteration is smaller than n and to 1.0 afterwards.

  """

  def init_fn(iterations: int, n: Array = n) -> Tuple[Array, Array]:
    # The state is the length of the burn in phase, the second entry is the
    # number of steps remaining after the burn in
    remaining = iterations - n
    return n, remaining

  def update_fn(state: Array, iteration: int, **kwargs) -> Array:
    # The length of the burn in phase is fixed
    del iteration, kwargs
    return state

  def get_fn(state: Array, iteration: int, **kwargs) -> Array:
    del kwargs
    burn_in_finished = state <= iteration
    return jnp.where(burn_in_finished, 1.0, 0.0)

  return specific_scheduler(init=init_fn, update=update_fn, get=get_fn)
################################################################################
#
# Thinning
#
################################################################################

# Thinning provides information about the number of samples which will be saved.
def random_thinning(step_size_schedule: specific_scheduler,
                    burn_in_schedule: specific_scheduler,
                    selections: int,
                    key: Array = None
                    ) -> specific_scheduler:
  """Random thinning weighted by the step size.

  Draws the iterations to be saved at random and without replacement. The
  selection weight of an iteration is the product of its step size and its burn
  in value, so samples subject to burn in are never selected and the decaying
  step size is accounted for. This only works for static step size and burn in
  schedules, as both are evaluated in advance during initialization.

  Args:
    step_size_schedule: Static step size schedule
    burn_in_schedule: Static burn in schedule
    selections: Number of selected samples
    key: PRNGKey for drawing selections, ``random.PRNGKey(0)`` if None

  Returns:
    Returns a scheduler marking the accepted samples.

  """

  def init_fn(iterations: int,
              step_size_schedule: specific_scheduler = step_size_schedule,
              burn_in_schedule: specific_scheduler = burn_in_schedule,
              selections: int = selections,
              key: Array = key
              ) -> Tuple[Array, Array]:
    if key is None:
      key = random.PRNGKey(0)

    initial_step_size_state = step_size_schedule.init(iterations)
    initial_burn_in_state, _ = burn_in_schedule.init(iterations)
    all_iterations = jnp.arange(iterations)

    def selection_weight(carry, step):
      # Replays both schedules to get the selection weight of a single step
      step_state, burn_state = carry
      current_step_size = step_size_schedule.get(step_state, step)
      current_burn_in = burn_in_schedule.get(burn_state, step)
      weight = current_step_size * current_burn_in
      next_carry = (step_size_schedule.update(step_state, step),
                    burn_in_schedule.update(burn_state, step))
      return next_carry, weight

    _, probs = lax.scan(
      selection_weight,
      (initial_step_size_state, initial_burn_in_state),
      all_iterations)

    # Check that a sufficient number of elements can be drawn
    assert jnp.count_nonzero(probs) >= selections, "Cannot select enough values"

    # Draw the iterations which should be accepted
    accepted_its = random.choice(
      key,
      jnp.arange(iterations),
      shape=(selections,),
      replace=False,
      p=probs)
    return accepted_its, selections

  def update_fn(state: Array, iteration: int, **kwargs) -> Array:
    # The accepted iterations are drawn in advance, nothing to update
    del iteration, kwargs
    return state

  def get_fn(state: Array, iteration: int, **kwargs) -> jnp.bool_:
    del kwargs
    # The sample is accepted if the iteration is among the drawn iterations
    is_selected = jnp.any(iteration == state)
    return jnp.where(is_selected, True, False)

  return specific_scheduler(init_fn, update_fn, get_fn)