jax_sgmc.potential

Utilities to evaluate the stochastic or the true (full) potential.

This module transforms a log-likelihood function for a single observation or a batch of observations into a function that calculates the stochastic or full potential, using map, vmap, and pmap.
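To illustrate what mapping a single-observation likelihood over a batch means, the sketch below evaluates a hypothetical Gaussian log-likelihood `log_likelihood_single` once per observation in a Python loop (analogous to the 'map' strategy) and once on the whole batch via array broadcasting (analogous to 'vmap'). NumPy stands in for the JAX primitives here, and none of the function names come from jax_sgmc; the point is only that both routes yield the same per-observation values.

```python
import numpy as np

def log_likelihood_single(sample, observation):
    # Hypothetical model: observation ~ Normal(sample["mu"], sample["sigma"]**2)
    mu, sigma = sample["mu"], sample["sigma"]
    return (-0.5 * ((observation - mu) / sigma) ** 2
            - np.log(sigma) - 0.5 * np.log(2.0 * np.pi))

def evaluate_sequential(sample, observations):
    # 'map'-like: one likelihood call per observation, one after another
    return np.array([log_likelihood_single(sample, x) for x in observations])

def evaluate_vectorized(sample, observations):
    # 'vmap'-like: a single call on the whole batch; broadcasting vectorizes it
    return log_likelihood_single(sample, observations)

sample = {"mu": 0.5, "sigma": 2.0}
batch = np.array([0.0, 1.0, -1.5, 3.0])
seq = evaluate_sequential(sample, batch)
vec = evaluate_vectorized(sample, batch)
print(np.allclose(seq, vec))  # True
```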

Stochastic Potential

jax_sgmc.potential.minibatch_potential(prior, likelihood, strategy='map', has_state=False, is_batched=False, temperature=1.0)

Initializes the potential function for a minibatch of data.

Parameters:
  • prior (Callable[[Any], Array]) – Log-prior function which is evaluated for a single sample.

  • likelihood (Callable) – Log-likelihood function. If has_state = True, the first argument is the model state, followed by sample and reference_data; otherwise, the arguments are sample and reference_data.

  • strategy (AnyStr) –

    Determines how to evaluate the model function with respect to the sample:

    • 'map' sequential evaluation

    • 'vmap' parallel evaluation via vectorization

    • 'pmap' parallel evaluation on multiple devices

  • has_state (bool) – Whether an additional state is provided for the model evaluation.

  • is_batched (bool) – Whether the likelihood expects a batch of observations instead of a single observation. If the likelihood is batched, the choice of strategy has no influence on the computation.

  • temperature (float) – Posterior temperature. T = 1 corresponds to the Bayesian posterior.

Return type:

StochasticPotential

Returns:

Returns a function which evaluates the stochastic potential for a mini-batch of data. The first argument is the sample (the latent variables) and the second is the mini-batch.
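A common convention in SG-MCMC is to estimate the potential from a mini-batch as U(θ) ≈ -(1/T) [log p(θ) + (N/n) Σ_{i in batch} log p(x_i | θ)], where N is the dataset size, n the mini-batch size and T the temperature. The sketch below assumes exactly this convention (it is not taken from the jax_sgmc source) and uses a hypothetical scalar Gaussian model to show why the rescaling by N/n is needed: averaged over a partition of the data into equal, disjoint mini-batches, the estimate recovers the full potential.

```python
import numpy as np

def log_prior(theta):
    # Hypothetical standard normal prior on a scalar parameter
    return -0.5 * theta ** 2 - 0.5 * np.log(2.0 * np.pi)

def log_likelihood(theta, x):
    # Hypothetical unit-variance Gaussian likelihood, vectorized over x
    return -0.5 * (x - theta) ** 2 - 0.5 * np.log(2.0 * np.pi)

def stochastic_potential(theta, batch, dataset_size, temperature=1.0):
    # Assumed convention: rescale the mini-batch log-likelihood sum by N / n
    scale = dataset_size / batch.shape[0]
    log_post = log_prior(theta) + scale * np.sum(log_likelihood(theta, batch))
    return -log_post / temperature

def full_potential(theta, data, temperature=1.0):
    # Exact potential using every observation
    return -(log_prior(theta) + np.sum(log_likelihood(theta, data))) / temperature

rng = np.random.default_rng(0)
data = rng.normal(loc=1.0, size=12)
theta = 0.3

# Average the estimate over a partition into 3 disjoint mini-batches of size 4
batches = data.reshape(3, 4)
estimates = [stochastic_potential(theta, b, data.shape[0]) for b in batches]
print(np.isclose(np.mean(estimates), full_potential(theta, data)))  # True
```

Under this convention the temperature only rescales the potential, so T = 2 halves every value compared with T = 1.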

class jax_sgmc.potential.StochasticPotential(*args, **kwargs)
__call__(sample, reference_data, state=None, mask=None, likelihoods=False)

Calculates the stochastic potential for a mini-batch of data.

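Parameters:
  • sample – Model parameters (latent variables).

  • reference_data – Mini-batch of reference data.

  • state – Model state, used if the potential was initialized with has_state=True.

  • mask – Optional array with ones for valid observations and zeros for invalid observations.

  • likelihoods – If True, the likelihood of every single observation is returned in addition to the potential.
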
Return type:

Union[Tuple[Array, Any], Tuple[Array, Tuple[Array, Any]]]

Returns:

Returns an approximation of the true potential based on a mini-batch of reference data. If likelihoods=True, the likelihood of every single observation is returned as well.
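The two shapes in the return type can be unpacked as sketched below. `stochastic_potential_call` is a hypothetical stand-in, not the class itself: it assumes a flat prior, the N/n rescaling of the mini-batch sum, and that the trailing Any element is the (unchanged) model state. It only mirrors the documented structure Tuple[Array, Any] versus Tuple[Array, Tuple[Array, Any]].

```python
import numpy as np

def log_likelihood(theta, x):
    # Hypothetical unit-variance Gaussian likelihood, vectorized over x
    return -0.5 * (x - theta) ** 2 - 0.5 * np.log(2.0 * np.pi)

def stochastic_potential_call(sample, batch, dataset_size, state=None,
                              likelihoods=False):
    # Hypothetical stand-in for StochasticPotential.__call__ (flat prior assumed)
    per_obs = log_likelihood(sample, batch)
    potential = -dataset_size / batch.shape[0] * np.sum(per_obs)
    if likelihoods:
        # Mirrors Tuple[Array, Tuple[Array, Any]]: (potential, (likelihoods, state))
        return potential, (per_obs, state)
    # Mirrors Tuple[Array, Any]: (potential, state)
    return potential, state

batch = np.array([0.2, 1.1, 0.7])
potential, state = stochastic_potential_call(0.5, batch, dataset_size=30)
same_potential, (per_obs, state) = stochastic_potential_call(
    0.5, batch, dataset_size=30, likelihoods=True)
print(per_obs.shape)  # (3,)
```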

Full Potential

jax_sgmc.potential.full_potential(prior, likelihood, strategy='map', has_state=False, is_batched=False, temperature=1.0)

Transforms the log-prior and log-likelihood into a function that computes the full potential over all reference data.

Parameters:
  • prior (Callable[[Any], Array]) – Log-prior function which is evaluated for a single sample.

  • likelihood (Callable[[Any, Any], Array]) – Log-likelihood function. If has_state = True, the first argument is the model state, followed by sample and reference_data; otherwise, the arguments are sample and reference_data.

  • strategy (AnyStr) –

    Determines how to evaluate the model function with respect to the sample:

    • 'map' sequential evaluation

    • 'vmap' parallel evaluation via vectorization

    • 'pmap' parallel evaluation on multiple devices

  • has_state (bool) – Whether an additional state is provided for the model evaluation.

  • is_batched (bool) – Whether the likelihood expects a batch of observations instead of a single observation. If the likelihood is batched, the choice of strategy has no influence on the computation. In this case, the last argument of the likelihood should be an optional mask: an array with ones for valid observations and zeros for invalid observations.

  • temperature (float) – Posterior temperature. T = 1 corresponds to the Bayesian posterior.
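A batched likelihood that takes the optional mask as its last argument might look like the sketch below. `batched_log_likelihood` is hypothetical, and it assumes the function returns the summed log-likelihood of the valid observations; the actual return convention expected by jax_sgmc may differ. The usage shows the purpose of the mask: padding a batch to a fixed size does not change the result as long as the padded entries are masked out.

```python
import numpy as np

def batched_log_likelihood(sample, observations, mask=None):
    # Hypothetical batched unit-variance Gaussian likelihood;
    # the mask is the last, optional argument
    per_obs = -0.5 * (observations - sample) ** 2 - 0.5 * np.log(2.0 * np.pi)
    if mask is None:
        return np.sum(per_obs)
    # Invalid (e.g. padded) observations are multiplied by zero
    return np.sum(per_obs * mask)

valid = np.array([0.4, -0.2, 1.3])
padded = np.concatenate([valid, np.zeros(2)])  # padded to a batch size of 5
mask = np.array([1.0, 1.0, 1.0, 0.0, 0.0])
print(np.isclose(batched_log_likelihood(0.1, padded, mask),
                 batched_log_likelihood(0.1, valid)))  # True
```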

Return type:

FullPotential

Returns:

Returns a function which evaluates the potential over the full dataset via a dataset mapping from the jax_sgmc.data module.

class jax_sgmc.potential.FullPotential(*args, **kwargs)
__call__(sample, data_state, full_data_map_fn, state=None)

Calculates the potential over the full dataset.

Parameters:
  • sample (Any) – Model parameters

  • data_state (CacheState) – State of the full_data_map functional

  • full_data_map_fn (Callable) – Functional mapping a function over the complete dataset

  • state (Optional[Any]) – Special parameters of the model which should not change the result of a model evaluation.

Return type:

Tuple[Array, Tuple[CacheState, Any]]

Returns:

Returns the potential of the current sample using the full dataset.
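The idea behind evaluating the potential through a full-data mapping can be sketched as follows: every observation is visited exactly once in fixed-size chunks, the last chunk is padded and masked, and the per-chunk log-likelihoods are accumulated in a carried value. `full_data_map` here is a hypothetical, eager NumPy stand-in for the functional from jax_sgmc.data (it does not reproduce its signature or its CacheState handling), and the N/n rescaling is unnecessary because the sum already runs over all data.

```python
import numpy as np

def log_prior(theta):
    # Hypothetical standard normal prior
    return -0.5 * theta ** 2 - 0.5 * np.log(2.0 * np.pi)

def masked_log_likelihood(theta, x, mask):
    # Hypothetical batched unit-variance Gaussian likelihood with a validity mask
    per_obs = -0.5 * (x - theta) ** 2 - 0.5 * np.log(2.0 * np.pi)
    return np.sum(per_obs * mask)

def full_data_map(fn, data, batch_size, carry):
    # Hypothetical stand-in: visits every (1-D) observation exactly once in
    # fixed-size chunks, zero-padding the last chunk and masking the padding
    n_batches = -(-data.shape[0] // batch_size)  # ceiling division
    padded = np.zeros(n_batches * batch_size)
    padded[:data.shape[0]] = data
    mask = np.zeros(n_batches * batch_size)
    mask[:data.shape[0]] = 1.0
    for i in range(n_batches):
        chunk = slice(i * batch_size, (i + 1) * batch_size)
        carry = fn(padded[chunk], mask[chunk], carry)
    return carry

def full_potential(theta, data, batch_size, temperature=1.0):
    def accumulate(x, mask, total):
        # Carry the running sum of log-likelihoods across chunks
        return total + masked_log_likelihood(theta, x, mask)
    total = full_data_map(accumulate, data, batch_size, carry=0.0)
    return -(log_prior(theta) + total) / temperature

rng = np.random.default_rng(1)
data = rng.normal(size=10)
direct = -(log_prior(0.2)
           + np.sum(-0.5 * (data - 0.2) ** 2 - 0.5 * np.log(2.0 * np.pi)))
print(np.isclose(full_potential(0.2, data, batch_size=4), direct))  # True
```

Because the padded entries are masked, the result does not depend on the chunk size, which is what allows the dataset to be processed in memory-friendly pieces.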