Compute Potential from Likelihood
Stochastic gradient MCMC evaluates the potential, and thus the model, either for a mini-batch or for all observations. This module therefore acts as an interface between the different likelihoods and the integrators. The likelihood can be implemented for a single observation or for a batch of observations.
Setup DataLoaders
For demonstration purposes, we set up a data loader to compute the potential for a random batch of data as well as for the full dataset. Note that the keyword arguments used to initialize the data loader (here ‘mean’) are also the keys under which the observations are accessed in the likelihood.
>>> from functools import partial
>>> import jax.numpy as jnp
>>> import jax.scipy as jscp
>>> from jax import random, vmap
>>> from jax_sgmc import data, potential
>>> from jax_sgmc.data.numpy_loader import NumpyDataLoader
>>> mean = random.normal(random.PRNGKey(0), shape=(100, 5))
>>> data_loader = NumpyDataLoader(mean=mean)
>>>
>>> test_sample = {'mean': jnp.zeros(5), 'std': jnp.ones(1)}
Stochastic Potential
The stochastic potential is an estimate of the true potential. It is calculated over a mini-batch and rescaled to the full dataset. To this end, we need to initialize functions that retrieve mini-batches of the data.
>>> batch_init, batch_get, _ = data.random_reference_data(
...     data_loader, cached_batches_count=50, mb_size=5)
>>> random_data_state = batch_init()
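The rescaling described above can be illustrated without the library. The following is a minimal sketch, assuming the potential is the negative log-likelihood summed over the mini-batch, scaled by N / n, minus the log-prior. The helper names `log_likelihood` and `minibatch_potential_sketch` are hypothetical and not part of jax_sgmc.

```python
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm


def log_likelihood(sample, x):
    # Log-likelihood of one observation (a vector with 5 entries).
    return jnp.sum(norm.logpdf(x, loc=sample["mean"], scale=sample["std"]))


def minibatch_potential_sketch(sample, batch, n_total):
    # Hypothetical helper: sum the log-likelihoods of the mini-batch and
    # rescale by N / n so the result estimates the full-data potential.
    n_batch = batch.shape[0]
    log_liks = jnp.stack([log_likelihood(sample, x) for x in batch])
    log_prior = 0.0  # flat prior, as in the example of this section
    return -(n_total / n_batch) * jnp.sum(log_liks) - log_prior


observations = random.normal(random.PRNGKey(0), shape=(100, 5))
sample = {"mean": jnp.zeros(5), "std": jnp.ones(1)}

# Estimate from the first 5 observations, rescaled to 100 observations.
estimate = minibatch_potential_sketch(sample, observations[:5], n_total=100)
# Using all observations gives a scale factor of 1, i.e. the exact value.
exact = minibatch_potential_sketch(sample, observations, n_total=100)
```

The estimate differs from the exact value because only 5 of the 100 observations enter the sum. Averaged over many random mini-batches it is expected to approach the exact potential.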
Full Potential
In combination with the jax_sgmc.data module, it is possible to calculate the true potential over the full dataset. If we specify a batch size of 3, the likelihood is calculated sequentially over batches of size 3.
>>> init_fun, fmap_fun, _ = data.full_reference_data(
...     data_loader, cached_batches_count=50, mb_size=3)
>>> data_state = init_fun()
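With 100 observations and a batch size of 3, the data splits into 34 batches, the last of which holds only a single valid observation. How the library treats this incomplete batch is not described here. The sketch below assumes padding together with a boolean mask, which is an illustrative assumption and not necessarily the internal mechanism of jax_sgmc.

```python
import math

import jax.numpy as jnp

n_obs, mb_size = 100, 3
n_batches = math.ceil(n_obs / mb_size)  # 34 batches
n_padded = n_batches * mb_size          # 102 slots, 2 of them padding

# Stand-in for per-observation values such as log-likelihoods.
values = jnp.arange(n_obs, dtype=jnp.float32)
padded = jnp.pad(values, (0, n_padded - n_obs))
mask = jnp.arange(n_padded) < n_obs

batches = padded.reshape(n_batches, mb_size)
batch_masks = mask.reshape(n_batches, mb_size)

# Process the batches one after another; padded entries contribute nothing.
total = sum(jnp.sum(jnp.where(m, b, 0.0)) for b, m in zip(batches, batch_masks))
```

Here `total` equals the plain sum over the 100 valid values, so the batch-wise evaluation reproduces the unbatched result.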
Unbatched Likelihood
In the simplest case, the likelihood and model function only accept a single observation and parameter set. The likelihood can then be written for a single observation, and the jax_sgmc.potential module evaluates it for a batch of reference data, either sequentially via map or in parallel via vmap or pmap.
The first input to the likelihood function is the sample, i.e. the model
parameters. You can access all parameters of the dict via the keywords defined
in the initial sample (e.g. ‘test_sample’ above). The second input is the
observation from the dataset, where the data can be accessed with the same
keyword arguments used during the initialization of the DataLoader.
>>> def likelihood(sample, observation):
...     likelihoods = jscp.stats.norm.logpdf(observation['mean'],
...                                          loc=sample['mean'],
...                                          scale=sample['std'])
...     return jnp.sum(likelihoods)
>>> prior = lambda unused_sample: 0.0
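The difference between the sequential and the parallel strategy can be sketched with plain JAX. The example below is only an illustration of `jax.lax.map` and `jax.vmap` applied to a single-observation likelihood. It does not show how jax_sgmc implements its `strategy` argument, and the name `single_likelihood` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax, random, vmap
from jax.scipy.stats import norm


def single_likelihood(sample, x):
    # Log-likelihood of a single observation x with 5 entries.
    return jnp.sum(norm.logpdf(x, loc=sample["mean"], scale=sample["std"]))


batch = random.normal(random.PRNGKey(1), shape=(5, 5))
sample = {"mean": jnp.zeros(5), "std": jnp.ones(1)}

# Sequential evaluation: one observation after the other.
sequential = lax.map(lambda x: single_likelihood(sample, x), batch)

# Vectorized evaluation: the sample is shared, the batch axis is mapped.
parallel = vmap(single_likelihood, in_axes=(None, 0))(sample, batch)
```

Both variants return one value per observation and agree up to floating-point error. The sequential variant typically needs less memory, while the vectorized one is usually faster for small models.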
Stochastic Potential
The stochastic potential is computed automatically from the prior and likelihood of a single observation.
>>> stochastic_potential_fn = potential.minibatch_potential(
...     prior, likelihood, strategy='map')
>>> new_random_data_state, random_batch = batch_get(random_data_state, information=True)
>>> potential_eval, unused_state = stochastic_potential_fn(test_sample, random_batch)
>>>
>>> print(round(potential_eval))
838
For debugging purposes, it is recommended to check with a test sample and a test observation whether the potential is evaluated correctly. This simplifies the search for bugs without the overhead from the SG-MCMC sampler.
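Such a check can be as small as the following sketch, which calls the likelihood directly with a hand-made observation and compares the result with a value computed by hand. The observation `test_observation` is a made-up example, not data from the loader.

```python
import math

import jax.numpy as jnp
from jax.scipy.stats import norm


def likelihood(sample, observation):
    log_probs = norm.logpdf(observation["mean"],
                            loc=sample["mean"],
                            scale=sample["std"])
    return jnp.sum(log_probs)


test_sample = {"mean": jnp.zeros(5), "std": jnp.ones(1)}
# Hypothetical observation that coincides with the mean of the sample.
test_observation = {"mean": jnp.zeros(5)}

value = likelihood(test_sample, test_observation)

# At the mean of a unit-variance Gaussian, each of the 5 entries
# contributes -0.5 * log(2 * pi) to the log-likelihood.
expected = -0.5 * 5 * math.log(2.0 * math.pi)
```

If `value` and `expected` disagree, the bug is in the likelihood itself and not in the data loading or the sampler.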
Full Potential
Here, the likelihood written for a single observation can be re-used.
>>> potential_fn = potential.full_potential(prior, likelihood, strategy='vmap')
>>>
>>> potential_eval, (data_state, unused_state) = potential_fn(
... test_sample, data_state, fmap_fun)
>>>
>>> print(round(potential_eval))
707
Batched Likelihood
Some models already accept a batch of reference data. In this case, the potential function can be constructed by setting is_batched=True. The likelihood is then expected to return a vector of shape (N,), where N is the batch size.
>>> @partial(vmap, in_axes=(None, 0))
... def batched_likelihood(sample, observation):
...     likelihoods = jscp.stats.norm.logpdf(observation['mean'],
...                                          loc=sample['mean'],
...                                          scale=sample['std'])
...     # Only valid samples contribute to the likelihood
...     return jnp.sum(likelihoods)
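The shape requirement can be verified before the likelihood is passed to the potential. The sketch below rebuilds the batched likelihood and applies it to a made-up batch of 5 observations. The batch is generated directly here and does not come from the data loader.

```python
from functools import partial

import jax.numpy as jnp
from jax import random, vmap
from jax.scipy.stats import norm


@partial(vmap, in_axes=(None, 0))
def batched_likelihood(sample, observation):
    # Inside vmap, observation['mean'] is a single row with 5 entries.
    log_probs = norm.logpdf(observation["mean"],
                            loc=sample["mean"],
                            scale=sample["std"])
    return jnp.sum(log_probs)


sample = {"mean": jnp.zeros(5), "std": jnp.ones(1)}
# Hypothetical batch with N = 5 observations of 5 entries each.
batch = {"mean": random.normal(random.PRNGKey(2), shape=(5, 5))}

log_liks = batched_likelihood(sample, batch)
batch_size = batch["mean"].shape[0]
```

The result `log_liks` has one entry per observation, so its shape is `(batch_size,)`.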
Stochastic Potential
To compute the correct potential, the function needs to know that the likelihood is batched, which is done by setting is_batched=True. The strategy setting then has no effect and can be left at its default value.
>>> stochastic_potential_fn = potential.minibatch_potential(
...     prior, batched_likelihood, is_batched=True, strategy='map')
>>>
>>> new_random_data_state, random_batch = batch_get(random_data_state, information=True)
>>> potential_eval, unused_state = stochastic_potential_fn(test_sample, random_batch)
>>>
>>> print(round(potential_eval))
838
>>>
>>> _, (likelihoods, _) = stochastic_potential_fn(
...     test_sample, random_batch, likelihoods=True)
>>>
>>> print(round(jnp.var(likelihoods)))
7
Full Potential
The batched likelihood can also be used to calculate the full potential.
>>> prior = lambda unused_sample: 0.0
>>>
>>> potential_fn = potential.full_potential(prior, batched_likelihood, is_batched=True)
>>>
>>> potential_eval, (data_state, unused_state) = potential_fn(
... test_sample, data_state, fmap_fun)
>>>
>>> print(round(potential_eval))
707
Likelihoods with States
By setting the argument has_state=True, the likelihood accepts an additional state as its first positional argument. This state should not influence the results of the computation.
>>> def stateful_likelihood(state, sample, observation):
...     n, mean = state
...     n += 1
...     new_mean = (n-1)/n * mean + 1/n * observation['mean']
...
...     likelihoods = jscp.stats.norm.logpdf((observation['mean'] - new_mean),
...                                          loc=(sample['mean'] - new_mean),
...                                          scale=sample['std'])
...     return jnp.sum(likelihoods), (n, new_mean)
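The state in this example is a counter and a running mean. Because the running mean is subtracted from both the observation and the location, it cancels in the Gaussian log-density and does not change the likelihood. The following sketch shows both points in isolation. The helper `update_running_mean` is hypothetical and only mirrors the update rule used above.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def update_running_mean(state, value):
    # Same update rule as in the stateful likelihood above.
    n, mean = state
    n = n + 1
    new_mean = (n - 1) / n * mean + value / n
    return n, new_mean


state = (jnp.array(0), jnp.zeros(3))
for value in (jnp.full(3, 1.0), jnp.full(3, 2.0), jnp.full(3, 3.0)):
    state = update_running_mean(state, value)
count, running_mean = state  # three updates, mean of 1, 2 and 3

# Shifting observation and location by the same amount leaves the
# Gaussian log-density unchanged.
x = jnp.array([0.5, -1.0, 2.0])
loc = jnp.zeros(3)
plain = norm.logpdf(x, loc=loc, scale=1.0)
shifted = norm.logpdf(x - running_mean, loc=loc - running_mean, scale=1.0)
```

After three updates the counter is 3 and every entry of the running mean is 2, while `plain` and `shifted` agree up to floating-point error.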
Note
If the likelihood is not batched (is_batched=False), only the state corresponding to the computation with the first sample of the batch is returned.
Stochastic Potential
>>> potential_fn = potential.minibatch_potential(
...     prior, stateful_likelihood, has_state=True)
>>>
>>> potential_eval, new_state = potential_fn(
...     test_sample, random_batch, state=(jnp.array(2), jnp.ones(5)))
>>>
>>> print(round(potential_eval))
838
>>> print(f"n: {new_state[0] : d}")
n: 3
Full Potential
>>> full_potential_fn = potential.full_potential(
...     prior, stateful_likelihood, has_state=True)
>>>
>>> potential_eval, (cache_state, new_state) = full_potential_fn(
...     test_sample, data_state, fmap_fun, state=(jnp.array(2), jnp.ones(5)))
>>>
>>> print(f"n: {new_state[0] : d}")
n: 36