jax_sgmc.data

jax_sgmc.data.core

Data input in jit-compiled functions

Limited device memory does not allow storing all reference data on the device for large datasets. The following functions make it possible to request mini-batches of data in jit-compiled functions without loading the entire dataset into device memory.

In the background, a cache of mini-batches is sequentially requested by the Host Callback Wrappers from the Data Loaders and loaded on the device via the jax_sgmc.util.host_callback module.
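The caching idea can be illustrated without JAX. The following is a minimal sketch that assumes a hypothetical `ToyHostCache` class and a hypothetical `make_ordered_loader` helper standing in for a data loader. It is not the jax_sgmc implementation, which moves the cache to the device through host callbacks.

```python
import numpy as np


class ToyHostCache:
    """Hypothetical stand-in for a cache of mini-batches on the consumer side.

    Batches are served one at a time. A new block of `cached_batches_count`
    batches is requested from the loader only when the cache is exhausted.
    """

    def __init__(self, load_batches, cached_batches_count):
        self._load_batches = load_batches  # callable: count -> array of batches
        self._count = cached_batches_count
        self._cache = None
        self._current_line = cached_batches_count  # forces a load on first use
        self.refills = 0  # number of round-trips to the loader

    def next_batch(self):
        if self._current_line == self._count:
            self._cache = self._load_batches(self._count)
            self._current_line = 0
            self.refills += 1
        batch = self._cache[self._current_line]
        self._current_line += 1
        return batch


def make_ordered_loader(data, mb_size):
    """Hypothetical host-side loader returning consecutive batches."""
    position = 0

    def load_batches(count):
        nonlocal position
        idx = (position + np.arange(count * mb_size)) % len(data)
        position = (position + count * mb_size) % len(data)
        return data[idx].reshape(count, mb_size, *data.shape[1:])

    return load_batches


cache = ToyHostCache(make_ordered_loader(np.arange(12), mb_size=2), cached_batches_count=3)
batches = [cache.next_batch() for _ in range(4)]  # the 4th draw triggers a refill
```

A larger cache means fewer round-trips to the loader per drawn batch but uses more memory. This is the trade-off behind the cached_batches_count parameter.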

Host Callback Wrappers

jax_sgmc.data.core.random_reference_data(data_loader, cached_batches_count, mb_size, verify_calls=False)

Initializes reference data access in jit-compiled functions.

Randomly draw batches from a given dataset on the host or the device.

Parameters:
  • data_loader (DataLoader) – Reads data from storage.

  • cached_batches_count (int) – Number of batches in the cache. A larger number is faster, but requires more memory.

  • mb_size (int) – Size of the data batch.

  • verify_calls (bool) – Verify calls to the host when using pmap. This might not be necessary if batches are assembled randomly without extra conditions.

Return type:

Tuple[Any, GetBatchFunction, Callable[[], None]]

Returns:

Returns a tuple of functions to initialize a new reference data state, get a mini-batch from the reference data state and release the data loader after the last computation.
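The returned triple follows a functional pattern. The state is created once, passed explicitly through every call, and the loader is released at the end. The numpy sketch below assumes a hypothetical `make_random_reference_data` helper and uses a seed-based generator state. It only mirrors the call pattern: jax_sgmc's actual state is a CacheState, and its functions run inside jit-compiled code.

```python
import numpy as np


def make_random_reference_data(data, mb_size):
    """Hypothetical sketch returning (init_fn, batch_fn, release_fn)."""

    def init_fn(seed=0):
        # The state is a plain value (the generator state), not hidden storage.
        return np.random.default_rng(seed).bit_generator.state

    def batch_fn(state):
        # Rebuild a generator from the state so that batch_fn has no side effects.
        rng = np.random.default_rng()
        rng.bit_generator.state = state
        idx = rng.integers(0, len(data), size=mb_size)  # draw with replacement
        return rng.bit_generator.state, data[idx]

    def release_fn():
        # A real loader would free host resources (files, worker threads) here.
        return None

    return init_fn, batch_fn, release_fn


init_fn, batch_fn, release_fn = make_random_reference_data(np.arange(100), mb_size=5)
state = init_fn(seed=42)
state, batch = batch_fn(state)  # pass the returned state to the next call
state, next_batch = batch_fn(state)
release_fn()
```

Because the state is returned rather than mutated, calling batch_fn twice with the same state yields the same batch. This property is what makes such functions usable inside jit-compiled code.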

jax_sgmc.data.core.full_reference_data(data_loader, cached_batches_count=100, mb_size=None)

Initializes reference data access in jit-compiled functions.

Map a function batch-wise over a dataset on the host or the device.

Parameters:
  • data_loader (DataLoader) – Reads data from storage.

  • cached_batches_count (int) – Number of batches in the cache. A larger number is faster, but requires more memory.

  • mb_size (Optional[int]) – Size of the data batch.

Return type:

Tuple[Any, FullDataMapFunction, Callable[[], None]]

Returns:

Returns a tuple of functions to initialize a new reference data state, map a function over the complete dataset and release the data loader after the last computation.

class jax_sgmc.data.core.GetBatchFunction(*args, **kwargs)
__call__(data_state, information=False, device_count=1)

Draws a batch of data.

Parameters:
  • data_state (CacheState) – State of the chain containing id and cached batches

  • information (bool) – Include namedtuple containing information about the data and batch

  • device_count (int) – Number of devices on which this function is going to be called with replicated data states.

Return type:

Union[Tuple[CacheState, Any], Tuple[CacheState, Tuple[Any, MiniBatchInformation]]]

Returns:

Returns the new state of the random chain and a batch. Optionally a namedtuple containing information about the batch and dataset can be returned.

class jax_sgmc.data.core.FullDataMapFunction(*args, **kwargs)
__call__(fun, data_state, carry, masking=False, information=False, device_count=1)

Maps a function over the complete dataset.

Parameters:
  • fun (Union[MaskedMappedFunction, UnmaskedMappedFunction]) – Function to be mapped over the dataset

  • data_state (CacheState) – Namedtuple containing the id of the chain and cached batches

  • carry (Any) – Variables which are carried over to the next evaluation of fun

  • masking (bool) – If true, an array marking invalid samples is passed to the function such that a single result for a batch of data can be calculated. If false, then a result for each observation must be returned and the invalid results are discarded after the computation.

  • information (bool) – Pass the batch information together with the batch

  • device_count (int) – Number of devices on which this function is going to be called with replicated data states.

Return type:

Tuple[Any, Any]

Returns:

Returns the new data state and the results of the computation including the carry of the last computation:

(data_state, (results, carry)) = full_data_map(...)
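The two masking modes can be illustrated with a loop that runs only on the host. The following minimal sketch uses a hypothetical `toy_full_data_map`. It assumes that the last batch is filled up by wrapping around to the start of the dataset and that these duplicates are marked invalid. The real function also passes a CacheState through the calls and runs inside jit-compiled code.

```python
import numpy as np


def toy_full_data_map(fun, data, carry, mb_size, masking=False):
    """Hypothetical sketch: map `fun` batch-wise over all observations in `data`.

    The last batch is filled up by wrapping around to the start of the
    dataset; these duplicated observations get mask 0.
    """
    n = len(data)
    n_batches = -(-n // mb_size)  # ceiling division
    results = []
    for b in range(n_batches):
        positions = b * mb_size + np.arange(mb_size)
        batch = data[positions % n]
        valid = positions < n
        if masking:
            # One result per batch; fun is responsible for applying the mask.
            result, carry = fun(batch, valid.astype(float), carry)
        else:
            # One result per observation; results of invalid samples are dropped.
            result, carry = fun(batch, carry)
            result = result[valid]
        results.append(result)
    return (np.stack(results) if masking else np.concatenate(results)), carry


def masked_sum(batch, mask, carry):
    s = np.sum(batch * mask)
    return s, carry + s


def squared(batch, carry):
    return batch ** 2, carry


data = np.arange(1.0, 8.0)  # 7 observations: 1.0, 2.0, ..., 7.0
sums, total = toy_full_data_map(masked_sum, data, 0.0, mb_size=3, masking=True)
squares, _ = toy_full_data_map(squared, data, None, mb_size=3)
```

With masking, the padded samples in the last batch contribute zero, so the carry ends at the exact dataset sum. Without masking, the per-observation results of the padded samples are discarded.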

class jax_sgmc.data.core.MaskedMappedFunction(*args, **kwargs)
__call__(batch, mask, state)

Function which can be mapped over the whole dataset.

A function of this form must be passed to the full data map function if it is called with masking = True.

Parameters:
  • batch (Any) – Batch of data

  • mask (Array) – Array marking invalid (duplicated) samples

  • state (Any) – Variables whose values are used in the computation of the next batch

Return type:

Tuple[Any, Any]

Returns:

Must return a tuple consisting of the computation results and the state which should be used in the computation of the next batch.
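A function of this form receives the batch, the mask and the carried state, and returns (result, new_state). The minimal sketch below uses a hypothetical function name and assumes that the batch is a dictionary with a key "x", as a data loader constructed with x=... would provide.

```python
import numpy as np


def masked_sum_of_squares(batch, mask, state):
    """Hypothetical MaskedMappedFunction: one result per batch.

    Invalid (duplicated) samples are zeroed out by the mask before summing,
    so they do not bias the batch result.
    """
    result = np.sum(mask * batch["x"] ** 2)
    # The state carries the running total to the next batch.
    return result, state + result


batch = {"x": np.array([1.0, 2.0, 3.0])}
mask = np.array([1.0, 1.0, 0.0])  # the last sample is a duplicate
result, state = masked_sum_of_squares(batch, mask, state=0.0)
```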

class jax_sgmc.data.core.UnmaskedMappedFunction(*args, **kwargs)
__call__(batch, state)

Function which can be mapped over the whole dataset.

A function of this form must be passed to the full data map function if it is called with masking = False.

Parameters:
  • batch (Any) – Batch of data

  • state (Any) – Variables whose values are used in the computation of the next batch

Return type:

Tuple[Any, Any]

Returns:

Must return a tuple consisting of the computation results and the state which should be used in the computation of the next batch.
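Without a mask, the function must return one result per observation, and the mapper removes the entries of invalid samples afterwards. The following minimal sketch uses a hypothetical function name and hypothetical batch keys "x" and "y". The final masking step only imitates what the mapper does.

```python
import numpy as np


def per_observation_residuals(batch, state):
    """Hypothetical UnmaskedMappedFunction: one result per observation.

    No mask is passed, so the function must not reduce over the batch.
    """
    residuals = batch["y"] - state["slope"] * batch["x"]
    return residuals, state  # the state is passed on unchanged


batch = {"x": np.array([1.0, 2.0, 3.0]), "y": np.array([2.0, 4.0, 7.0])}
residuals, state = per_observation_residuals(batch, {"slope": 2.0})

# Sketch of the mapper's post-processing: keep only valid observations.
mask = np.array([1, 1, 0])
valid_residuals = residuals[mask.astype(bool)]
```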

jax_sgmc.data.core.full_data_mapper(data_loader=None, cached_batches_count=1, mb_size=1)

Initializes a functional to map a function over a complete dataset.

This function extends the functionality of full_reference_data() by loading the data state from the host before each mapping.

Parameters:
  • data_loader (Optional[DataLoader]) – Reads data from storage

  • cached_batches_count (int) – Number of batches in the cache. A larger number is faster, but requires more memory

  • mb_size (int) – Size of the data batch

Return type:

Tuple[FullDataMapperFunction, Callable]

Returns:

Returns a tuple consisting of a function that maps another function over the complete dataset of an appropriate DataLoader and a function that releases the data loader after the last computation.

class jax_sgmc.data.core.FullDataMapperFunction(*args, **kwargs)
__call__(fun, carry, masking=False, information=False, batched=True, device_count=1)

Maps a function over the complete dataset.

This function differs from FullDataMapFunction in that it acquires a CacheState before each mapping over the full dataset.

Parameters:
  • fun (Union[MaskedMappedFunction, UnmaskedMappedFunction]) – Function to be mapped over the dataset

  • carry (Any) – Variables which are carried over to the next evaluation of fun

  • masking (bool) – If true, an array marking invalid samples is passed to the function such that a single result for a batch of data can be calculated. If false, then a result for each observation must be returned and the invalid results are discarded automatically after the computation.

  • information (bool) – Pass the batch information together with the batch

  • batched (bool) – Whether the function to be mapped over the full dataset is vectorized. If false, the function is vmapped so that it can process a batch of observations.

  • device_count (int) – Number of devices on which this function is going to be called with replicated data states.

Return type:

Any

Returns:

Returns the results of the computation including the carry of the last computation:

(results, carry) = full_data_mapper(...)
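With batched=False, the mapped function only needs to handle a single observation, and the mapper vectorizes it with vmap. The numpy sketch below shows the shapes involved. It uses a hypothetical `vectorize_over_batch` helper, which applies the function to each observation in a Python loop as a stand-in for jax.vmap. It also uses a hypothetical per-observation function `log_likelihood_single`.

```python
import numpy as np


def log_likelihood_single(observation):
    """Hypothetical per-observation function (batched=False style)."""
    return -0.5 * (observation["y"] - observation["x"]) ** 2


def vectorize_over_batch(fun):
    """Hypothetical stand-in for jax.vmap over the leading axis of every leaf."""

    def batched_fun(batch):
        size = len(next(iter(batch.values())))
        return np.stack([fun({k: v[i] for k, v in batch.items()}) for i in range(size)])

    return batched_fun


batch = {"x": np.array([0.0, 1.0, 2.0]), "y": np.array([0.0, 3.0, 2.0])}
values = vectorize_over_batch(log_likelihood_single)(batch)
```

In JAX, jax.vmap(log_likelihood_single)(batch) maps over the leading axis of each dictionary entry in the same way and should produce the same values without the explicit loop.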

States

class jax_sgmc.data.core.MiniBatchInformation(observation_count: Array, mask: Array, batch_size: Array)

Bundles all information about the reference data.

Parameters:
  • observation_count – Total number of observations

  • mask – Array marking the valid samples of the batch (mask = 0 for invalid, e.g. duplicated, samples)

  • batch_size – Number of observations in the mini-batch
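The information bundle is useful, for example, for scaling a mini-batch quantity to the full dataset. The minimal sketch below assumes a hypothetical `ToyMiniBatchInformation` that mirrors the three fields in the signature above. It also assumes that the mask uses 1 for valid samples and 0 for invalid ones.

```python
from typing import NamedTuple

import numpy as np


class ToyMiniBatchInformation(NamedTuple):
    """Hypothetical mirror of the fields of MiniBatchInformation."""
    observation_count: np.ndarray
    mask: np.ndarray
    batch_size: np.ndarray


info = ToyMiniBatchInformation(
    observation_count=np.array(7),
    mask=np.array([1.0, 0.0, 0.0]),  # one valid sample, two duplicates
    batch_size=np.array(3),
)
per_sample = np.array([2.0, 5.0, 5.0])  # e.g. per-observation loss values

valid_in_batch = int(np.sum(info.mask))
# Estimate a sum over the full dataset: average over the valid samples only,
# then scale by the total number of observations.
estimate = info.observation_count * np.sum(info.mask * per_sample) / np.sum(info.mask)
```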

class jax_sgmc.data.core.CacheState(callback_uuid: Optional[JaxUUID] = None, cached_batches: Optional[Any] = None, cached_batches_count: Optional[Array] = None, current_line: Optional[Array] = None, chain_id: Optional[Array] = None, state: Optional[Any] = None, valid: Optional[Array] = None, token: Optional[JaxUUID] = None)

Caches several batches of randomly batched reference data.

Parameters:
  • cached_batches – An array of mini-batches

  • cached_batches_count – Number of cached mini-batches. Equals the first dimension of the cached batches

  • current_line – Marks the next batch to be returned.

  • chain_id – Identifier of the chain

  • state – Additional information

  • valid – Array containing information about the validity of individual samples
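The role of current_line can be shown with an immutable state that is replaced rather than modified. The following minimal sketch uses a hypothetical `ToyCacheState` with a subset of the fields above and a hypothetical `toy_get_batch` function. Unlike a real implementation, it simply wraps around instead of requesting a fresh cache from the host.

```python
from typing import Any, NamedTuple

import numpy as np


class ToyCacheState(NamedTuple):
    """Hypothetical subset of the CacheState fields."""
    cached_batches: Any       # leading axis has length cached_batches_count
    cached_batches_count: int
    current_line: int


def toy_get_batch(state):
    """Return a new state pointing at the next batch, plus the current batch.

    The input state is never modified. Wrapping back to 0 marks the point
    where a real implementation would request a fresh cache from the host.
    """
    batch = state.cached_batches[state.current_line]
    next_line = (state.current_line + 1) % state.cached_batches_count
    return state._replace(current_line=next_line), batch


state = ToyCacheState(np.arange(6).reshape(3, 2), cached_batches_count=3, current_line=0)
drawn = []
for _ in range(4):
    state, batch = toy_get_batch(state)
    drawn.append(batch.tolist())
```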

Base Classes

class jax_sgmc.data.core.DataLoader

Abstract class to define required methods of a DataLoader.

This class defines common methods of a DataLoader, such as returning an all-zero batch with correct shape to initialize the model.

initializer_batch(mb_size=None)

Returns a zero-like mini-batch.

Parameters:

mb_size (Optional[int]) – Number of observations in a batch. If None, the returned pytree has the shape of a single observation.

Return type:

Any
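The shape rule for the zero batch can be sketched for a dataset stored as a dictionary of arrays whose leading axis indexes observations. `toy_initializer_batch` is a hypothetical name. This numpy sketch keeps the input dtypes unchanged.

```python
import numpy as np


def toy_initializer_batch(reference_data, mb_size=None):
    """Hypothetical sketch: zero batch with the shape and dtype of the data.

    The leading (observation) axis is dropped for a single observation
    (mb_size=None) and replaced by mb_size otherwise.
    """
    batch = {}
    for key, value in reference_data.items():
        obs_shape = value.shape[1:]
        shape = obs_shape if mb_size is None else (mb_size,) + obs_shape
        batch[key] = np.zeros(shape, dtype=value.dtype)
    return batch


data = {"x": np.arange(10), "y": np.zeros((10, 4, 3))}
single = toy_initializer_batch(data)
batch = toy_initializer_batch(data, mb_size=4)
```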

abstract property static_information: Dict

Information about the dataset such as the total observation count.

Return type:

Dict

class jax_sgmc.data.core.DeviceDataLoader[source]

Abstract class to define required methods of a DeviceDataLoader.

A class implementing the data loader must have the functionality to return the complete dataset as a dictionary of arrays.

abstract get_full_data()

Returns the whole dataset as a dictionary of arrays.

Return type:

Dict

abstract get_random_data(state, batch_size)

Returns a random batch of the data.

This function must be jit-able and free of side effects.

Return type:

Tuple[Any, Tuple[Any, MiniBatchInformation]]

abstract init_random_data(*args, **kwargs)

Initializes the state necessary to randomly draw data.

Return type:

Any

class jax_sgmc.data.core.HostDataLoader

Abstract class to define required methods of a HostDataLoader.

A class implementing the data loader must have the functionality to load data from storage in an ordered and a random fashion.

batch_format(cache_size, mb_size)

Returns dtype and shape of cached mini-batches.

Parameters:

  • cache_size (int) – Number of cached mini-batches

  • mb_size – Number of observations per mini-batch

Return type:

Tuple[Any, MiniBatchInformation]

Returns:

Returns a pytree with the same tree structure as the random data cache but with jax.ShapeDtypeStruct as leaves.
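The returned shapes follow from the cache layout. Each leaf gets two leading axes, one for the cached batches and one for the observations within a batch. The minimal sketch below assumes a dataset stored as a dictionary of arrays. It uses a hypothetical `ToyShapeDtype` named tuple in place of jax.ShapeDtypeStruct so that only numpy is needed, and it omits the MiniBatchInformation part of the return value.

```python
from typing import NamedTuple

import numpy as np


class ToyShapeDtype(NamedTuple):
    """Hypothetical stand-in for jax.ShapeDtypeStruct."""
    shape: tuple
    dtype: np.dtype


def toy_batch_format(reference_data, cache_size, mb_size):
    # Every leaf becomes (cache_size, mb_size, *observation_shape).
    return {
        key: ToyShapeDtype((cache_size, mb_size) + value.shape[1:], value.dtype)
        for key, value in reference_data.items()
    }


data = {
    "image": np.zeros((50, 32, 32, 3), dtype=np.uint8),
    "label": np.zeros(50, dtype=np.int32),
}
fmt = toy_batch_format(data, cache_size=10, mb_size=8)
```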

abstract get_batches(chain_id)

Return batches from an ordered or random chain.

Return type:

Tuple[Any, Optional[Array]]

load_state(chain_id, data)

Restores the dataloader state from a previously computed checkpoint.

Parameters:
  • chain_id (int) – ID of the chain whose state is restored.

  • data – Data from save_state() to restore state of the chain.

abstract register_ordered_pipeline(cache_size=1, mb_size=None, **kwargs)

Register a chain which assembles batches in an ordered manner.

Parameters:
  • cache_size (int) – The number of drawn batches.

  • mb_size (Optional[int]) – The number of observations per batch.

Return type:

int

Returns:

Returns the id of the new chain.

abstract register_random_pipeline(cache_size=1, mb_size=None, **kwargs)

Register a new chain which assembles batches randomly.

Parameters:
  • cache_size (int) – The number of drawn batches.

  • mb_size (Optional[int]) – The number of observations per batch.

  • seed – Set the random seed to start the chain at a well-defined state.

Return type:

int

Returns:

Returns the id of the new chain.

save_state(chain_id)

Returns all necessary information to restore the dataloader state.

Parameters:

chain_id (int) – Each chain can be checkpointed independently.

Returns:

Returns necessary information to restore the state of the chain via load_state().
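For a randomly sampling chain, the random generator state is enough to replay future batches exactly. The following minimal sketch uses a hypothetical `ToyCheckpointableLoader` with a simplified `register_chain` method, which stands in for the register_*_pipeline methods. Each chain has its own numpy generator, so each chain can be checkpointed independently.

```python
import copy

import numpy as np


class ToyCheckpointableLoader:
    """Hypothetical loader with one random generator per chain."""

    def __init__(self, data):
        self._data = data
        self._chains = {}

    def register_chain(self, mb_size, seed=0):
        chain_id = len(self._chains)
        self._chains[chain_id] = (np.random.default_rng(seed), mb_size)
        return chain_id

    def get_batch(self, chain_id):
        rng, mb_size = self._chains[chain_id]
        return self._data[rng.integers(0, len(self._data), size=mb_size)]

    def save_state(self, chain_id):
        # The generator state determines all future draws of this chain
        # (defensive copy of the returned dict).
        return copy.deepcopy(self._chains[chain_id][0].bit_generator.state)

    def load_state(self, chain_id, data):
        self._chains[chain_id][0].bit_generator.state = data


loader = ToyCheckpointableLoader(np.arange(100))
chain = loader.register_chain(mb_size=4, seed=1)
checkpoint = loader.save_state(chain)
first = loader.get_batch(chain)
loader.load_state(chain, checkpoint)
replayed = loader.get_batch(chain)  # same batch as `first`
```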

Utility Functions

tree_index(pytree, index)

Indexes the leaves of the tree in the first dimension.

tree_dtype_struct(pytree)

Returns a tree whose leaves only represent the shape and dtype of the original leaves.
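The two utilities can be sketched for the common case of a flat dictionary of arrays. `toy_tree_index` and `toy_tree_dtype_struct` are hypothetical names. For arbitrary pytrees, jax.tree_util.tree_map would apply the same per-leaf operations.

```python
import numpy as np


def toy_tree_index(pytree, index):
    """Hypothetical dict-only sketch: index every leaf along its first axis."""
    return {key: leaf[index] for key, leaf in pytree.items()}


def toy_tree_dtype_struct(pytree):
    """Hypothetical dict-only sketch: keep only (shape, dtype) of each leaf."""
    return {key: (leaf.shape, leaf.dtype) for key, leaf in pytree.items()}


cache = {"x": np.arange(12).reshape(3, 4), "y": np.ones((3, 4, 2), dtype=np.float32)}
second_batch = toy_tree_index(cache, 1)
structure = toy_tree_dtype_struct(cache)
```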

jax_sgmc.data.numpy_loader

Load numpy arrays in jit-compiled functions.

The numpy data loader is easy to use if the whole dataset fits into RAM and is already available as numpy arrays.

class jax_sgmc.data.numpy_loader.NumpyBase(on_device=True, copy=True, **reference_data)
property reference_data

Returns the reference data as a dictionary.

property static_information

Returns information about the total sample count and batch size.

class jax_sgmc.data.numpy_loader.NumpyDataLoader(copy=True, **reference_data)

Load the complete dataset into memory from multiple numpy arrays.

This data loader supports checkpointing, starting chains from a well-defined state and true random access.

The pipeline can be constructed directly from numpy arrays:

>>> import numpy as onp
>>> from jax_sgmc.data.numpy_loader import NumpyDataLoader
>>>
>>> x, y = onp.arange(10), onp.zeros((10, 4, 3))
>>>
>>> data_loader = NumpyDataLoader(name_for_x=x, name_for_y=y)
>>>
>>> zero_batch = data_loader.initializer_batch(4)
>>> for key, value in zero_batch.items():
...   print(f"{key}: shape={value.shape}, dtype={value.dtype}")
name_for_x: shape=(4,), dtype=int32
name_for_y: shape=(4, 4, 3), dtype=float32

Parameters:
  • reference_data – Each kwarg-pair is an entry in the returned data-dict.

  • copy – Whether to copy the reference data (default True) or only create a reference.

get_batches(chain_id)

Draws a batch from a chain.

Parameters:

chain_id (int) – ID of the chain, which holds the information about the form of the batch and how it is assembled.

Return type:

Any

Returns:

Returns a batch of batches as registered by register_random_pipeline() or register_ordered_pipeline() with cache_size batches holding mb_size observations.

load_state(chain_id, data)

Restores the dataloader state from a previously computed checkpoint.

Parameters:
  • chain_id (int) – ID of the chain whose state is restored.

  • data – Data from save_state() to restore state of the chain.

Return type:

None

register_ordered_pipeline(cache_size=1, mb_size=None, **kwargs)

Register a chain which assembles batches in an ordered manner.

Parameters:
  • cache_size (int) – The number of drawn batches.

  • mb_size (Optional[int]) – The number of observations per batch.

  • seed – Set the random seed to start the chain at a well-defined state.

Return type:

int

Returns:

Returns the id of the new chain.

register_random_pipeline(cache_size=1, mb_size=None, in_epochs=False, shuffle=False, **kwargs)

Register a new chain which draws samples randomly.

Parameters:
  • cache_size (int) – The number of drawn batches.

  • mb_size (Optional[int]) – The number of observations per batch.

  • shuffle (bool) – Shuffle dataset instead of drawing randomly from the observations.

  • in_epochs (bool) – Samples returned twice per epoch are marked via mask = 0 (only if shuffle = True).

  • seed – Set the random seed to start the chain at a well-defined state.

Return type:

int

Returns:

Returns the id of the new chain.
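One possible way to combine shuffling with epoch-wise masking is sketched below. This sketch does not necessarily match the scheme NumpyDataLoader uses internally. The hypothetical generator `toy_epoch_batches` draws one permutation per epoch and fills the last batch with observations that were already returned in that epoch. These repeated observations are marked with mask 0.

```python
import numpy as np


def toy_epoch_batches(n_observations, mb_size, seed=0):
    """Hypothetical generator of (indices, mask) pairs for one shuffled epoch.

    Each observation is returned once with mask 1. The last batch is filled
    up with observations already returned in this epoch, which get mask 0.
    """
    rng = np.random.default_rng(seed)
    permutation = rng.permutation(n_observations)
    n_batches = -(-n_observations // mb_size)  # ceiling division
    for b in range(n_batches):
        positions = b * mb_size + np.arange(mb_size)
        indices = permutation[positions % n_observations]
        mask = (positions < n_observations).astype(np.int32)
        yield indices, mask


epoch = list(toy_epoch_batches(n_observations=10, mb_size=4, seed=0))
valid = np.concatenate([idx[m == 1] for idx, m in epoch])
```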

save_state(chain_id)

Returns all necessary information to restore the dataloader state.

Parameters:

chain_id (int) – Each chain can be checkpointed independently.

Return type:

Any

Returns:

Returns necessary information to restore the state of the chain via load_state().

class jax_sgmc.data.numpy_loader.DeviceNumpyDataLoader(copy=True, **reference_data)

Load the complete dataset into memory from multiple numpy arrays.

This data loader supports checkpointing, starting chains from a well-defined state and true random access.

The pipeline can be constructed directly from numpy arrays:

>>> import numpy as onp
>>> from jax_sgmc.data.numpy_loader import DeviceNumpyDataLoader
>>>
>>> x, y = onp.arange(10), onp.zeros((10, 4, 3))
>>>
>>> data_loader = DeviceNumpyDataLoader(name_for_x=x, name_for_y=y)
>>>
>>> zero_batch = data_loader.initializer_batch(4)
>>> for key, value in zero_batch.items():
...   print(f"{key}: shape={value.shape}, dtype={value.dtype}")
name_for_x: shape=(4,), dtype=int32
name_for_y: shape=(4, 4, 3), dtype=float32

Parameters:
  • reference_data – Each kwarg-pair is an entry in the returned data-dict.

  • copy – Whether to copy the reference data (default True) or only create a reference.

get_full_data()

Returns the whole dataset as a dictionary of arrays.

Return type:

Dict

get_random_data(state, batch_size)

Returns a random batch of the data.

This function must be jit-able and free of side effects.

Return type:

Tuple[Any, Tuple[Any, MiniBatchInformation]]

init_random_data(*args, **kwargs)

Initializes the state necessary to randomly draw data.

Return type:

Any

jax_sgmc.data.tensorflow_loader

Load TensorFlow datasets in jit-compiled functions.

The TensorFlow data loader supports TensorFlow datasets, e.g. from the tensorflow_datasets package.

Note

This submodule requires that tensorflow and tensorflow_datasets are installed. Additional information can be found in the installation instructions.

class jax_sgmc.data.tensorflow_loader.TensorflowDataLoader(pipeline, mini_batch_size=None, shuffle_cache=100, exclude_keys=None)

Load data from a tensorflow dataset object.

The tensorflow_datasets package provides a large number of ready-to-use datasets, which can be passed directly to the TensorflowDataLoader.

import tensorflow_datasets as tfds
from jax_sgmc.data.tensorflow_loader import TensorflowDataLoader

# Load the training split of CIFAR-10 as a tf.data pipeline
pipeline = tfds.load("cifar10", split="train")
# Exclude the 'id' feature from the returned batches
data_loader = TensorflowDataLoader(pipeline, shuffle_cache=100, exclude_keys=['id'])

Parameters:

pipeline (DatasetV2) – A tensorflow data pipeline, which can be obtained from the tensorflow dataset package

get_batches(chain_id)

Draws a batch from a chain.

Parameters:

chain_id (int) – ID of the chain, which holds the information about the form of the batch and how it is assembled.

Return type:

Any

Returns:

Returns a batch of batches as registered by register_random_pipeline() or register_ordered_pipeline() with cache_size batches holding mb_size observations.

register_ordered_pipeline(cache_size=1, mb_size=None, **kwargs)

Register a chain which assembles batches in an ordered manner.

Parameters:
  • cache_size (int) – The number of drawn batches.

  • mb_size (Optional[int]) – The number of observations per batch.

Return type:

int

Returns:

Returns the id of the new chain.

register_random_pipeline(cache_size=1, mb_size=None, **kwargs)

Register a new chain which draws samples randomly.

Parameters:
  • cache_size (int) – The number of drawn batches.

  • mb_size (Optional[int]) – The number of observations per batch.

Return type:

int

Returns:

Returns the id of the new chain.

property static_information

Returns information about the total sample count and batch size.

jax_sgmc.data.hdf5_loader

Use samples saved with jax_sgmc.io.HDF5Collector as reference data.

class jax_sgmc.data.hdf5_loader.HDF5Loader(file, subdir='/chain~0/variables/', sample=None)

Load reference data from HDF5 files.

This data loader can load reference data stored in HDF5 files. This makes it possible to use the jax_sgmc.data module to evaluate samples saved via the jax_sgmc.io.HDF5Collector.

Parameters:
  • file – Path to the HDF5 file containing the reference data

  • subdir – Path to the subset of the data set which should be loaded

  • sample – PyTree to specify the original shape of the sub-pytree before it has been saved by the jax_sgmc.io.HDF5Collector