jax_sgmc.potential
Utility to evaluate the stochastic or the true (full) potential.
This module transforms the likelihood function for a single observation or a batch of observations into a function calculating the stochastic or the full potential, making use of map, vmap and pmap.
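The strategies change how the per-observation likelihood is applied to a batch, not the result. The NumPy sketch below illustrates this with a hypothetical Gaussian likelihood. It is not jax_sgmc code: the Python loop stands in for sequential map, and NumPy broadcasting stands in for vmap. pmap, which splits the work across devices, is not shown.

```python
import numpy as np

def log_likelihood(sample, observation):
    # Hypothetical single-observation model: x ~ N(mu, 1)
    return -0.5 * (observation - sample["mu"]) ** 2 - 0.5 * np.log(2 * np.pi)

def evaluate_sequential(sample, batch):
    # Analogue of the 'map' strategy: one observation after the other
    return np.array([log_likelihood(sample, obs) for obs in batch])

def evaluate_vectorized(sample, batch):
    # Analogue of the 'vmap' strategy: all observations at once
    # (here via NumPy broadcasting instead of jax.vmap)
    return log_likelihood(sample, batch)

sample = {"mu": 0.5}
batch = np.array([0.0, 1.0, 2.0])
print(np.allclose(evaluate_sequential(sample, batch),
                  evaluate_vectorized(sample, batch)))  # True
```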
Stochastic Potential
- jax_sgmc.potential.minibatch_potential(prior, likelihood, strategy='map', has_state=False, is_batched=False, temperature=1.0)
Initializes the potential function for a minibatch of data.
- Parameters:
  prior (Callable[[Any], Array]) – Log-prior function which is evaluated for a single sample.
  likelihood (Callable) – Log-likelihood function. If has_state = True, then the first argument is the model state, otherwise the arguments are sample, reference_data.
  strategy (AnyStr) – Determines how to evaluate the model function with respect to the sample:
    'map': sequential evaluation
    'vmap': parallel evaluation via vectorization
    'pmap': parallel evaluation on multiple devices
  has_state (bool) – Whether an additional state is provided for the model evaluation.
  is_batched (bool) – Whether the likelihood expects a batch of observations instead of a single observation. If the likelihood is batched, the choice of strategy has no influence on the computation.
  temperature (float) – Posterior temperature. T = 1 is the Bayesian posterior.
- Return type:
  StochasticPotential
- Returns:
Returns a function which evaluates the stochastic potential for a mini-batch of data. The first argument holds the latent variables (the sample) and the second the mini-batch.
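The page does not spell out the estimator. A common choice, assumed in the NumPy sketch below (not jax_sgmc code), rescales the summed log-likelihood of the n batch entries by N/n and divides by the temperature T: U(θ) ≈ -(log p(θ) + (N/n) Σ log p(x_i | θ)) / T, where N is the number of observations in the dataset. The prior, likelihood and stochastic_potential helper are hypothetical.

```python
import numpy as np

LOG_2PI = np.log(2 * np.pi)

def log_prior(sample):
    # Hypothetical standard-normal prior on the single parameter "mu"
    return -0.5 * sample["mu"] ** 2 - 0.5 * LOG_2PI

def log_likelihood(sample, observations):
    # Hypothetical model x ~ N(mu, 1); works on one or many observations
    return -0.5 * (observations - sample["mu"]) ** 2 - 0.5 * LOG_2PI

def stochastic_potential(sample, batch, total_count, temperature=1.0):
    """Mini-batch estimate of the tempered potential (assumed form).

    The summed log-likelihood of the n batch entries is rescaled by
    total_count / n, which makes the estimate unbiased when the batch is
    drawn uniformly from a dataset with total_count observations.
    """
    n = batch.shape[0]
    scaled = total_count / n * np.sum(log_likelihood(sample, batch))
    return -(log_prior(sample) + scaled) / temperature

data = np.array([0.1, -0.4, 1.3, 0.7])
sample = {"mu": 0.2}
# With the whole dataset as the "batch", the estimate equals the full potential
full = -(log_prior(sample) + np.sum(log_likelihood(sample, data)))
print(np.isclose(stochastic_potential(sample, data, len(data)), full))  # True
```

Lowering the temperature below 1 scales the potential up and therefore sharpens the sampled distribution.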
- class jax_sgmc.potential.StochasticPotential(*args, **kwargs)
- __call__(sample, reference_data, state=None, mask=None, likelihoods=False)
Calculates the stochastic potential for a mini-batch of data.
- Parameters:
  sample (Any) – Model parameters.
  reference_data (Union[Tuple[Any], Tuple[Any, MiniBatchInformation], Tuple[Any, MiniBatchInformation, Array]]) – Batch of observations.
  state (Optional[Any]) – Special parameters of the model which should not change the result of a model evaluation.
  mask (Optional[Array]) – Marks invalid (e.g. duplicate) samples.
  likelihoods (bool) – Whether to return the likelihoods of all model evaluations separately.
- Return type:
- Returns:
Returns an approximation of the true potential based on a mini-batch of reference data. If likelihoods = True, the likelihood of every single observation is returned as well.
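One plausible way to honour the mask, assumed here rather than taken from the jax_sgmc source, is to zero out invalid entries and rescale by the number of valid observations instead of the padded batch size. The hypothetical masked_potential helper below does this and optionally returns the per-observation values, analogous to likelihoods = True.

```python
import numpy as np

def masked_potential(log_prior_value, log_likelihoods, mask, total_count,
                     temperature=1.0, return_likelihoods=False):
    """Hypothetical helper: mini-batch potential that ignores masked entries.

    Invalid (e.g. duplicate) observations have mask == 0. They are zeroed
    out and the rescaling uses the number of valid observations, so a padded
    batch gives the same estimate as the unpadded one.
    """
    mask = np.asarray(mask, dtype=float)
    masked = np.asarray(log_likelihoods) * mask
    n_valid = mask.sum()
    potential = -(log_prior_value + total_count / n_valid * masked.sum()) / temperature
    if return_likelihoods:
        # Per-observation values, analogous to likelihoods=True
        return potential, masked
    return potential

lls = np.array([-1.0, -2.0, -3.0, -1.0])  # last entry duplicates the first
padded = masked_potential(0.0, lls, [1, 1, 1, 0], total_count=6)
unpadded = masked_potential(0.0, lls[:3], [1, 1, 1], total_count=6)
print(padded, unpadded)  # 12.0 12.0
```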
Full Potential
- jax_sgmc.potential.full_potential(prior, likelihood, strategy='map', has_state=False, is_batched=False, temperature=1.0)
Transforms a probability density function (pdf) to compute the full potential over all reference data.
- Parameters:
  prior (Callable[[Any], Array]) – Log-prior function which is evaluated for a single sample.
  likelihood (Callable[[Any, Any], Array]) – Log-likelihood function. If has_state = True, then the first argument is the model state, otherwise the arguments are sample, reference_data.
  strategy (AnyStr) – Determines how to evaluate the model function with respect to the sample:
    'map': sequential evaluation
    'vmap': parallel evaluation via vectorization
    'pmap': parallel evaluation on multiple devices
  has_state (bool) – Whether an additional state is provided for the model evaluation.
  is_batched (bool) – Whether the likelihood expects a batch of observations instead of a single observation. If the likelihood is batched, the choice of strategy has no influence on the computation. In this case, the last argument of the likelihood should be an optional mask: an array with ones for valid observations and zeros for invalid observations.
  temperature (float) – Posterior temperature. T = 1 is the Bayesian posterior.
- Return type:
  FullPotential
- Returns:
Returns a function which evaluates the potential over the full dataset via a dataset mapping from the jax_sgmc.data module.
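Because the full potential visits every observation exactly once, no N/n rescaling is needed. The NumPy sketch below assumes, hypothetically, that the dataset mapping hands out fixed-size chunks and pads the last one, and that a batched likelihood (is_batched = True) takes the mask as its last argument, with ones for valid and zeros for padded observations. batched_log_likelihood and full_potential_chunked are illustrative names, not part of jax_sgmc.

```python
import numpy as np

def batched_log_likelihood(sample, observations, mask=None):
    # Hypothetical batched likelihood: one value per observation,
    # with padded entries (mask == 0) set to zero
    lls = -0.5 * (observations - sample["mu"]) ** 2 - 0.5 * np.log(2 * np.pi)
    if mask is not None:
        lls = np.where(mask > 0, lls, 0.0)
    return lls

def full_potential_chunked(sample, data, log_prior, chunk_size, temperature=1.0):
    total = 0.0
    for start in range(0, len(data), chunk_size):
        chunk = data[start:start + chunk_size]
        n_valid = len(chunk)
        # Pad the last chunk with copies of its first element so that every
        # chunk has the same shape; the mask removes the padding again
        padding = np.repeat(chunk[:1], chunk_size - n_valid)
        padded = np.concatenate([chunk, padding])
        mask = (np.arange(chunk_size) < n_valid).astype(float)
        total += np.sum(batched_log_likelihood(sample, padded, mask))
    # Every observation is visited exactly once, so no N/n rescaling
    return -(log_prior(sample) + total) / temperature

sample = {"mu": 0.2}
data = np.linspace(-1.0, 1.0, 7)
direct = -np.sum(batched_log_likelihood(sample, data))
print(np.isclose(full_potential_chunked(sample, data, lambda s: 0.0, 3), direct))  # True
```

Fixed-size chunks matter for JAX because a jit-compiled function is retraced for every new input shape; padding plus a mask keeps the shape constant.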
- class jax_sgmc.potential.FullPotential(*args, **kwargs)
- __call__(sample, data_state, full_data_map_fn, state=None)
Calculates the potential over the full dataset.
- Parameters:
  sample (Any) – Model parameters.
  data_state (CacheState) – State of the full_data_map functional.
  full_data_map_fn (Callable) – Functional mapping a function over the complete dataset.
  state (Optional[Any]) – Special parameters of the model which should not change the result of a model evaluation.
- Return type:
  Tuple[Array, Tuple[CacheState, Any]]
- Returns:
Returns the potential of the current sample computed over the full dataset, together with the updated data_state and the model state.
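The return type follows a functional state-passing pattern: the data state goes in and an updated data state comes out next to the result. The sketch below mimics only that shape. toy_full_data_map is a hypothetical stand-in whose data state is a plain (data, call counter) tuple rather than a real CacheState, its fn(chunk, carry) signature is an assumption, and the prior term is omitted for brevity.

```python
import numpy as np

def toy_full_data_map(fn, data_state, carry):
    """Hypothetical stand-in for a full-data mapping functional.

    Applies fn chunk by chunk, threading a carry through, and returns the
    final carry together with an updated data state (here: a call counter).
    """
    data, n_calls = data_state
    for chunk in np.array_split(data, 3):
        carry = fn(chunk, carry)
    return carry, (data, n_calls + 1)

def full_potential_call(sample, data_state, full_data_map_fn, state=None):
    # Accumulate the (unnormalized) Gaussian log-likelihood chunk by chunk
    def accumulate(chunk, carry):
        return carry + np.sum(-0.5 * (chunk - sample["mu"]) ** 2)

    total, new_data_state = full_data_map_fn(accumulate, data_state, 0.0)
    # Same shape as the documented return type:
    # Tuple[Array, Tuple[CacheState, Any]]
    return -total, (new_data_state, state)

data = np.arange(6.0)
potential, (data_state, model_state) = full_potential_call(
    {"mu": 0.0}, (data, 0), toy_full_data_map)
print(potential, data_state[1], model_state)  # 27.5 1 None
```

Returning the new data state instead of mutating it keeps the call free of side effects, which is what jit-compiled JAX code requires.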