# Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines integrators which form the core of the solvers."""
from functools import partial
from typing import Callable, Any, Tuple, Dict, NamedTuple
import jax.numpy as jnp
import numpy as onp
from jax import random, tree_unflatten, tree_flatten, grad, lax, tree_util, \
value_and_grad, named_call
from jax_sgmc.adaption import AdaptionStrategy, MassMatrix
from jax_sgmc.data import RandomBatch, CacheState
from jax_sgmc.potential import StochasticPotential
from jax_sgmc.scheduler import schedule
from jax_sgmc.util import Array, tree_scale, tree_add, tensor_matmul, \
tree_dot, Tensor, tree_multiply
PyTree = Any
[docs]class LeapfrogState(NamedTuple):
"""State of the reversible and friction leapfrog integrator.
Attributes:
positions: Latent variables in hamiltonian dynamics formulation
momentum: Momentum with the same shape as the latent variables
potential: Accumulated energy to calculate MH-correction step
model_state: State of the model in the likelihood
data_state: State of the random data function
key: PRNGKey
"""
positions: PyTree
momentum: PyTree
potential: Array
model_state: PyTree
data_state: CacheState
key: Array
extra_fields: PyTree = None
[docs]class ObaboState(NamedTuple):
"""State of the OBABO integrator.
Attributes:
positions: Latent variables in hamiltonian dynamics formulation
momentum: Momentum with the same shape as the latent variables
potential: Stochastic potential
model_state: State of the model in the likelihood
data_state: State of the random data function
key: PRNGKey
kinetic_energy_start: Kinetic energy after the first 1/4-step
kinetic_energy_end: Kinetic energy after the last 3/4-step
"""
positions: PyTree
momentum: PyTree
potential: Array
model_state: PyTree
data_state: CacheState
key: Array
kinetic_energy_start: Array
kinetic_energy_end: Array
[docs]class LangevinState(NamedTuple):
"""State of the langevin diffusion integrator.
Attributes:
latent_variables: Current latent variables
key: PRNGKey
adapt_state: Containing quantities such as momentum for adaption
data_state: State of the reference data cache
model_state: Variables not considered during inference
potential: Stochastic potential from last evaluation
variance: Variance of stochastic potential over mini-batch
"""
latent_variables: PyTree
model_state: PyTree
key: Array
adapt_state: Any
data_state: CacheState
potential: Array
variance: Array
[docs]def init_mass(mass) -> MassMatrix:
"""Initializes a diagonal mass tensor.
Args:
mass: Diagonal mass which has the same tree structure as the sample.
Returns:
Returns a diagonal mass matrix.
"""
inv_mass = tree_util.tree_map(
lambda x: jnp.power(x, -1.0),
mass)
sqrt_mass = tree_util.tree_map(
jnp.sqrt,
mass)
inv_mass = Tensor(ndim=1, tensor=inv_mass)
sqrt_mass = Tensor(ndim=1, tensor=sqrt_mass)
return MassMatrix(inv=inv_mass, sqrt=sqrt_mass)
[docs]def random_tree(key, a):
"""Build a tree shaped like a where all nodes are normally distributed.
Arguments:
key: PRNGKey
a: PyTree defining the shape of the output
Returns:
Tree shaped like a with normal distributed leaves.
"""
leaves, tree_def = tree_flatten(a)
splits = random.split(key, len(leaves))
noise_leaves = [random.normal(split, leaf.shape)
for split, leaf in zip(splits, leaves)]
noise_tree = tree_unflatten(tree_def, noise_leaves)
return noise_tree
[docs]def obabo(potential_fn: StochasticPotential,
batch_fn: RandomBatch,
steps: Array = 10,
friction: Array = 1.0,
const_mass: PyTree = None,
) -> Tuple[Callable, Callable, Callable]:
"""Initializes the OBABO integration scheme.
The OBABO integration scheme is reversible even when using stochastic
gradients and provides second order accuracy.
[1] https://arxiv.org/abs/2102.01691
Args:
potential_fn: Likelihood and prior applied over a minibatch of data
batch_fn: Function to draw a mini-batch of reference data
steps: Number of integration steps.
friction: Controls impact of momentum from previous step
const_mass: Mass matrix if no matrix is adapted. Must have the same tree
structure as the sample
Returns:
Returns a function running the time OBABO integrator for T
steps.
"""
init_data, get_data, _ = batch_fn
stochastic_gradient = value_and_grad(potential_fn, argnums=0, has_aux=True)
# Calculate the inverse and the square root
if const_mass:
const_mass = init_mass(const_mass)
else:
const_mass = None
# Helper functions to calculate the kinetic energy, update the position,
# refresh and update the momentum.
def _kinetic_energy(mass, momentum):
scaled_momentum = tensor_matmul(mass.inv, momentum)
return 0.5 * tree_dot(momentum, scaled_momentum)
def _position_update(scale, mass, position, momentum):
scaled_momentum = tensor_matmul(mass.inv, momentum)
scaled_momentum = tree_scale(scale, scaled_momentum)
return tree_add(position, scaled_momentum)
# Half step if scale == 0.5 * step_size
def _momentum_update(scale, gradient, momentum):
scaled_gradient = tree_scale(-1.0 * scale, gradient)
return tree_add(scaled_gradient, momentum)
# Add noise to momentum
def _momentum_resampling(parameters, mass, momentum, split):
noise = random_tree(split, momentum)
scaled_noise = tensor_matmul(mass.sqrt, noise)
permanence = jnp.exp(-friction * parameters.step_size)
momentum_noise = tree_scale(
jnp.sqrt((1 - permanence) * parameters.temperature),
scaled_noise)
decayed_momentum = tree_scale(jnp.sqrt(permanence), momentum)
return tree_add(decayed_momentum, momentum_noise)
# A single OBABO-step of the integrator
def _leapfrog_steps(state: ObaboState,
step: jnp.array,
mass: PyTree,
parameters: schedule = None):
del step
key, split1, split2 = random.split(state.key, num=3)
refreshed_momentum = _momentum_resampling(
parameters,
mass,
state.momentum,
split1)
# The kinetic energy from the first step is necessary to calculate the
# acceptance probability
# start_energy = lax.select(
# step == 0,
# _kinetic_energy(mass, refreshed_momentum),
# state.kinetic_energy_start)
start_energy = state.kinetic_energy_start + _kinetic_energy(mass, refreshed_momentum)
# Momentum update with stochastic gradient
data_state, mini_batch = get_data(state.data_state, information=True)
(pot_before, model_state), gradient = stochastic_gradient(
state.positions,
mini_batch,
state=state.model_state)
first_updated_momentum = _momentum_update(
0.5 * parameters.step_size,
gradient,
refreshed_momentum)
# Position update with momentum
updated_positions = _position_update(
parameters.step_size,
mass,
state.positions,
first_updated_momentum)
# Momentum update with stochastic gradient
data_state, mini_batch = get_data(data_state, information=True)
(pot_after, model_state), gradient = stochastic_gradient(
updated_positions,
mini_batch,
state=model_state )
second_updated_momentum = _momentum_update(
0.5 * parameters.step_size,
gradient,
first_updated_momentum)
final_refreshed_momentum = _momentum_resampling(
parameters,
mass,
second_updated_momentum,
split2)
# The kinetic energy of the last step is necessary to
# calculate the acceptance probability for the MH step.
end_energy = state.kinetic_energy_end + _kinetic_energy(mass, second_updated_momentum)
new_state = ObaboState(
potential=0.5 * (pot_before + pot_after),
positions=updated_positions,
momentum=final_refreshed_momentum,
key=key,
data_state=data_state,
model_state=model_state,
kinetic_energy_start=start_energy,
kinetic_energy_end=end_energy)
return new_state, None
def init_fn(init_sample: PyTree,
key: Array = None,
batch_kwargs: Dict = None,
init_model_state: PyTree = None):
"""Initializes the initial state of the integrator.
Args:
init_sample: Initial latent variables
key: Initial PRNGKey
batch_kwargs: Determine the initial state of the random data chain
init_model_state: State of the model.
mass: Mass matrix
Returns:
Returns the initial state of the integrator.
"""
# Initializing the initial state here makes it easier to add additional
# variables which might be only necessary in special case
if batch_kwargs is None:
batch_kwargs = {}
reference_data_state = init_data(**batch_kwargs)
if key is None:
key = random.PRNGKey(0)
momentum = tree_util.tree_map(jnp.zeros_like, init_sample)
init_state = ObaboState(
kinetic_energy_start=jnp.array(0.0),
kinetic_energy_end=jnp.array(0.0),
potential=jnp.array(0.0),
key=key,
positions=init_sample,
momentum=momentum,
data_state=reference_data_state,
model_state=init_model_state)
return init_state
@partial(named_call, name='obabo_integration')
def integrate(state: ObaboState,
parameters: schedule,
mass: PyTree = None
) -> ObaboState:
# If the mass is not adapted, take the constant mass if provided
if const_mass is None:
cms = init_mass(tree_util.tree_map(jnp.ones_like, state.positions))
else:
cms = const_mass
if mass is None:
mass = cms
# Leapfrog integration
state, _ = lax.scan(
partial(_leapfrog_steps, parameters=parameters, mass=mass),
state,
onp.arange(steps))
return state
def get_fn(state: LeapfrogState) -> Dict[str, PyTree]:
"""Returns the latent variables."""
return {"variables": state.positions,
"energy": state.potential,
"model_state": state.model_state}
return init_fn, integrate, get_fn
[docs]def reversible_leapfrog(potential_fn: StochasticPotential,
batch_fn: RandomBatch,
steps: int = 10,
friction: [float, Array] = 0.25,
const_mass: PyTree = None
) -> Tuple[Callable, Callable, Callable]:
"""Initializes a reversible leapfrog integrator.
AMAGOLD requires a reversible leapfrog integrator with half step at the
beginning and end.
Args:
potential_fn: Likelihood and prior applied over a minibatch of data
batch_fn: Function to draw a mini-batch of reference data
steps: Number of intermediate leapfrog steps
friction: Decay of momentum to counteract induced noise due to stochastic
gradients
const_mass: Mass matrix to be used when no mass matrix is adapted
Returns:
Returns a function running the time reversible leapfrog integrator for T
steps.
"""
init_data, get_data, _ = batch_fn
stochastic_gradient = grad(potential_fn, has_aux=True)
# Calculate the inverse and the square root
if const_mass:
const_mass = init_mass(const_mass)
else:
const_mass = None
def _position_update(scale, mass, position, momentum):
scaled_momentum = tensor_matmul(mass.inv, momentum)
# Scale is 0.5 of step size for half momentum update, otherwise it is just
# the step size.
scaled_momentum = tree_scale(scale, scaled_momentum)
return tree_add(position, scaled_momentum)
def _cov_scaled_noise(split, mass, tree):
noise = random_tree(split, tree)
noise = tensor_matmul(mass.sqrt, noise)
return noise
def _energy(old_momentum, new_momentum, mass, gradient, scale):
# Accumulate the energy
momentum_sum = tree_add(old_momentum, new_momentum)
scaled_gradient = tensor_matmul(mass.inv, gradient)
unscaled_energy = tree_dot(momentum_sum, scaled_gradient)
return scale * unscaled_energy
def _body_fun(state: LeapfrogState,
step: jnp.array,
parameters: schedule,
mass: MassMatrix):
# Full step not required in first iteration because of the half step at the
# beginning
positions = lax.cond(step == 0,
lambda pos: pos,
lambda pos: _position_update(
parameters.step_size,
mass,
pos,
state.momentum),
state.positions)
key, split = random.split(state.key)
noise = _cov_scaled_noise(split, mass, state.momentum)
scaled_noise = tree_scale(
jnp.sqrt(4 * friction * parameters.step_size),
noise)
data_state, mini_batch = get_data(state.data_state, information=True)
gradient, model_state = stochastic_gradient(
positions,
mini_batch,
state=state.model_state)
decayed_momentum = tree_scale(
1 - parameters.step_size * friction,
state.momentum)
negative_scaled_gradient = tree_scale(
-1.0 * parameters.step_size,
gradient)
unscaled_momentum = tree_add(
tree_add(decayed_momentum, negative_scaled_gradient),
scaled_noise)
updated_momentum = tree_scale(
1 / (1 + parameters.step_size * friction),
unscaled_momentum)
energy = _energy(
state.momentum,
updated_momentum,
mass,
gradient,
0.5 * parameters.step_size)
accumulated_energy = energy + state.potential
new_state = LeapfrogState(
potential=accumulated_energy,
key=key,
positions=positions,
momentum=updated_momentum,
data_state=data_state,
model_state=model_state)
return new_state, None
def init_fn(init_sample: PyTree,
key: Array = None,
batch_kwargs: Dict = None,
init_model_state: PyTree = None,
mass: PyTree = None):
"""Initializes the initial state of the integrator.
Args:
init_sample: Initial latent variables
key: Initial PRNGKey
batch_kwargs: Determine the initial state of the random data chain
init_cov: Initial covariance.
init_model_state: State of the model.
Returns:
Returns the initial state of the integrator.
"""
# Initializing the initial state here makes it easier to add additional
# variables which might be only necessary in special case
if batch_kwargs is None:
batch_kwargs = {}
reference_data_state = init_data(**batch_kwargs)
# Use constant mass if provided
if not mass:
if const_mass:
mass = const_mass
else:
mass = init_mass(tree_util.tree_map(jnp.ones_like, init_sample))
# Sample initial momentum
if key is None:
key = random.PRNGKey(0)
key, split = random.split(key)
momentum = _cov_scaled_noise(split, mass, init_sample)
init_state = LeapfrogState(
potential=jnp.array(0.0),
key=key,
positions=init_sample,
momentum=momentum,
data_state=reference_data_state,
model_state=init_model_state)
return init_state
@partial(named_call, name='leapfrog_integration')
def integrate(state: LeapfrogState,
parameters: schedule,
mass: PyTree = None
) -> LeapfrogState:
# Use default values if mass matrix not provided
if not mass:
if const_mass:
mass = const_mass
else:
mass = init_mass(tree_util.tree_map(jnp.ones_like, state.positions))
# Half step for leapfrog integration
positions = _position_update(
0.5 * parameters.step_size, mass, state.positions, state.momentum)
# Do the leapfrog steps
state = LeapfrogState(
positions=positions,
momentum=state.momentum,
key=state.key,
potential=jnp.array(0.0),
model_state=state.model_state,
data_state=state.data_state)
# Leapfrog integration
state, _ = lax.scan(
partial(_body_fun, parameters=parameters, mass=mass),
state,
onp.arange(steps))
# Final half step
positions = _position_update(
0.5 * parameters.step_size, mass, state.positions, state.momentum)
final_state = LeapfrogState(
positions=positions,
momentum=state.momentum,
key=state.key,
potential=state.potential,
model_state=state.model_state,
data_state=state.data_state)
return final_state
def get_fn(state: LeapfrogState) -> Dict[str, PyTree]:
return {"variables": state.positions,
"energy": state.potential,
"model_state": state.model_state}
return init_fn, integrate, get_fn
[docs]def friction_leapfrog(potential_fn: StochasticPotential,
batch_fn: RandomBatch,
steps: int = 10,
friction: Array = 0.25,
const_mass: PyTree = None,
noise_model = None
) -> Tuple[Callable, Callable, Callable]:
"""Initializes the original SGHMC leapfrog integrator.
Original SGHMC from [1].
[1] https://arxiv.org/pdf/1402.4102.pdf
Args:
potential_fn: Likelihood and prior applied over a minibatch of data
batch_fn: Function to draw a mini-batch of reference data
steps: Number of intermediate leapfrog steps
friction: Decay of momentum to counteract induced noise due to stochastic
gradients
const_mass: Mass matrix of the hamiltonian process
noise_model: Stateless adaption of the noise (e. g. via the empirical fisher
information)
Returns:
Returns a function running the non-conservative leapfrog integrator for T
steps.
"""
init_data, get_data, _ = batch_fn
stochastic_gradient = value_and_grad(potential_fn, has_aux=True)
if noise_model:
init_noise_model, update_noise_model, get_noise_model = noise_model
# Calculate the inverse and the square root
if const_mass:
const_mass = init_mass(const_mass)
else:
const_mass = None
def _body_fun(state: LeapfrogState,
step: jnp.array,
parameters: schedule,
friction: PyTree,
mass: MassMatrix):
del step
# Update the position with the momentum
scaled_momentum = tensor_matmul(mass.inv, state.momentum)
position_update = tree_scale(parameters.step_size, scaled_momentum)
new_positions = tree_add(state.positions, position_update)
# Update the momentum in three steps
# 1. Momentum decays from friction
scaled_friction = tree_scale(-parameters.step_size, friction)
momentum_decay = tree_multiply(scaled_friction, scaled_momentum)
new_momentum = tree_add(state.momentum, momentum_decay)
# 2. Momentum changes due to forces (gradient of stochastic potential)
data_state, mini_batch = get_data(state.data_state, information=True)
(pot, model_state), gradient = stochastic_gradient(
new_positions,
mini_batch,
state=state.model_state)
scaled_gradient = tree_scale(-parameters.step_size, gradient)
new_momentum = tree_add(new_momentum, scaled_gradient)
# 3. Injection of noise
key, split = random.split(state.key)
noise = random_tree(split, state.momentum)
if noise_model:
noise_state = update_noise_model(
state.extra_fields,
new_positions,
gradient,
friction,
mini_batch=mini_batch,
step_size=parameters.step_size)
noise_correction = get_noise_model(
noise_state,
new_positions,
gradient,
friction,
mini_batch=mini_batch,
step_size=parameters.step_size,
model_state=state.model_state)
reduced_noise = tensor_matmul(noise_correction.cb_diff_sqrt, noise)
scaled_noise = tree_scale(jnp.sqrt(2 * parameters.step_size), reduced_noise)
else:
noise_state = None
scaled_noise = tree_scale(jnp.sqrt(2 * parameters.step_size), noise)
scaled_noise = tree_multiply(friction, scaled_noise)
new_momentum = tree_add(new_momentum, scaled_noise)
new_state = LeapfrogState(
potential=pot,
key=key,
positions=new_positions,
momentum=new_momentum,
data_state=data_state,
model_state=model_state,
extra_fields=noise_state)
return new_state, None
def init_fn(init_sample: PyTree,
key: Array = None,
batch_kwargs: Dict = None,
init_model_state: PyTree = None):
"""Initializes the initial state of the integrator.
Args:
init_sample: Initial latent variables
key: Initial PRNGKey
batch_kwargs: Determine the initial state of the random data chain
init_model_state: State of the model.
Returns:
Returns the initial state of the integrator.
"""
# Initializing the initial state here makes it easier to add additional
# variables which might be only necessary in special case
if batch_kwargs is None:
batch_kwargs = {}
if noise_model:
noise_state = init_noise_model(init_sample)
else:
noise_state = None
reference_data_state = init_data(**batch_kwargs)
# Only shape of momentum is important, as momentum is resampled in each
# integration step
if key is None:
key = random.PRNGKey(0)
momentum = init_sample
init_state = LeapfrogState(
potential=jnp.array(0.0),
key=key,
positions=init_sample,
momentum=momentum,
data_state=reference_data_state,
model_state=init_model_state,
extra_fields=noise_state)
return init_state
@partial(named_call, name='leapfrog_integration')
def integrate(state: LeapfrogState,
parameters: schedule,
mass: PyTree = None
) -> LeapfrogState:
# Use default values if mass matrix not provided
if not mass:
if const_mass:
mass = const_mass
else:
mass = init_mass(tree_util.tree_map(jnp.ones_like, state.positions))
# If the friction is scalar
if tree_util.treedef_is_leaf(tree_util.tree_structure(friction)):
multiscalar_friction = tree_util.tree_map(
partial(jnp.full_like, fill_value=friction), state.positions)
else:
multiscalar_friction = friction
# Resample momentum
key, split = random.split(state.key)
noise = random_tree(split, state.momentum)
momentum = tensor_matmul(mass.sqrt, noise)
# Do the leapfrog steps
state = LeapfrogState(
positions=state.positions,
momentum=momentum,
key=key,
potential=jnp.array(0.0),
model_state=state.model_state,
data_state=state.data_state,
extra_fields=state.extra_fields)
final_state, _ = lax.scan(
partial(_body_fun,
parameters=parameters,
mass=mass,
friction=multiscalar_friction),
state,
onp.arange(steps))
return final_state
def get_fn(state: LeapfrogState) -> Dict[str, PyTree]:
"""Returns the latent variables."""
return {"variables": state.positions,
"energy": state.potential,
"model_state": state.model_state}
return init_fn, integrate, get_fn
[docs]def langevin_diffusion(
potential_fn: StochasticPotential,
batch_fn: RandomBatch,
adaption: AdaptionStrategy = None,
) -> Tuple[Callable, Callable, Callable]:
"""Initializes langevin diffusion integrator.
Arguments:
potential_fn: Likelihood and prior applied over a minibatch of data
batch_fn: Function to draw a mini-batch of reference data
adaption: Adaption of manifold for faster inference
Returns:
Returns a tuple consisting of ``ini_fn``, ``update_fn``, ``get_fn``.
The init_fn takes the arguments
- key: Initial PRNGKey
- adaption_kwargs: Additional arguments to determine the initial manifold
state
- batch_kwargs: Determine the state of the random data chain
"""
if adaption is not None:
adapt_init, adapt_update, adapt_get = adaption
batch_init, batch_get, _ = batch_fn
stochastic_gradient = value_and_grad(potential_fn, argnums=0, has_aux=True)
# We need to define an update function. All array operations must be
# implemented via tree_map. This is probably going to change with the
# introduction of the tree vectorizing transformation
# --> https://github.com/google/jax/pull/3263
# This function is intended to generate initial states. Jax key,
# adaption, etc. can be initialized to a default value if not explicitely
# provided
def init_fn(init_sample: PyTree,
key: Array = random.PRNGKey(0),
adaption_kwargs: Dict = None,
batch_kwargs: Dict = None,
init_model_state: PyTree = None):
"""Initializes the state of the integrator.
Args:
init_sample: Initial latent variables
key: Initial PRNGKey
adaption_kwargs: Determine the initial state of the adaption
batch_kwargs: Determine the initial state of the random data chain
Returns:
Returns the initial state of the integrator.
"""
# Initializing the initial state here makes it easier to add additional
# variables which might be only necessary in special case
if adaption_kwargs is None:
adaption_kwargs = {}
if batch_kwargs is None:
batch_kwargs = {}
# Adaption is not required in the most general case
if adaption is None:
adaption_state = None
else:
adaption_state = adapt_init(init_sample, **adaption_kwargs)
reference_data_state = batch_init(**batch_kwargs)
init_state = LangevinState(
key=key,
latent_variables=init_sample,
adapt_state=adaption_state,
data_state=reference_data_state,
model_state=init_model_state,
potential=jnp.array(0.0),
variance=jnp.array(1.0)
)
return init_state
# Returns the important parameters of a state and excludes. Makes information
# hiding possible
def get_fn(state: LangevinState) -> Dict[str, PyTree]:
"""Returns the latent variables."""
return {"variables": state.latent_variables,
"likelihood": -state.potential,
"model_state": state.model_state}
# Update according to the integrator update rule
@partial(named_call, name='langevin_diffusion_step')
def update_fn(state: LangevinState,
parameters: schedule):
"""Updates the integrator state according to a schedule.
Args:
state: Integrator state
parameters: Schedule containing step_size and temperature
Returns:
Returns a new step calculated by applying langevin diffusion.
"""
key, split = random.split(state.key)
data_state, mini_batch = batch_get(state.data_state, information=True)
noise = random_tree(split, state.latent_variables)
(potential, (likelihoods, new_model_state)), gradient = stochastic_gradient(
state.latent_variables,
mini_batch,
state=state.model_state,
likelihoods=True)
variance = jnp.var(likelihoods)
scaled_gradient = tree_scale(-parameters.step_size, gradient)
scaled_noise = tree_scale(
jnp.sqrt(2 * parameters.temperature * parameters.step_size), noise)
if adaption is None:
update_step = tree_add(scaled_gradient, scaled_noise)
adapt_state = None
else:
# Update the adaption
adapt_state = adapt_update(
state.adapt_state,
state.latent_variables,
gradient,
mini_batch)
# Get the adaption
manifold = adapt_get(
adapt_state,
state.latent_variables,
gradient,
mini_batch)
adapted_gradient = tensor_matmul(manifold.g_inv, scaled_gradient)
adapted_noise = tensor_matmul(manifold.sqrt_g_inv, scaled_noise)
scaled_gamma = tree_scale(parameters.step_size, manifold.gamma.tensor)
update_step = tree_add(
tree_add(scaled_gamma, adapted_gradient),
adapted_noise)
# Conclude the variable update by adding the step to the current samples
new_sample = tree_add(state.latent_variables, update_step)
new_state = LangevinState(
key=key,
latent_variables=new_sample,
adapt_state=adapt_state,
data_state=data_state,
model_state=new_model_state,
potential=jnp.array(potential, dtype=state.potential.dtype),
variance=variance)
return new_state
return init_fn, update_fn, get_fn