jax_sgmc.data
jax_sgmc.data.core
Data input in jit-compiled functions
Limited device memory makes it impossible to store all reference data on the device for big datasets. With the following functions, mini-batches of data can be requested in jit-compiled functions without loading the entire dataset into the device's memory.
In the background, a cache of mini-batches is sequentially requested by the Host Callback Wrappers from the Data Loaders and loaded on the device via the jax_sgmc.util.host_callback module.
Host Callback Wrappers
- jax_sgmc.data.core.random_reference_data(data_loader, cached_batches_count, mb_size, verify_calls=False)[source]
Initializes reference data access in jit-compiled functions.
Randomly draw batches from a given dataset on the host or the device.
- Parameters:
  - data_loader (DataLoader) – Reads data from storage.
  - cached_batches_count (int) – Number of batches in the cache. A larger number is faster, but requires more memory.
  - mb_size (int) – Size of the data batch.
  - verify_calls (bool) – Verify calls to the host when using pmap. This might not be necessary if batches are assembled randomly without extra conditions.
- Return type:
  Tuple[Any, GetBatchFunction, Callable[[], None]]
- Returns:
  Returns a tuple of functions to initialize a new reference data state, get a mini-batch from the reference data state, and release the data loader after the last computation.
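A minimal usage sketch (assumptions: the NumpyDataLoader documented below is used, and keyword arguments such as seed passed to the init function are forwarded to the data loader):

import numpy as onp
from jax_sgmc.data.core import random_reference_data
from jax_sgmc.data.numpy_loader import NumpyDataLoader

data_loader = NumpyDataLoader(x=onp.arange(10))

# Returns (init_fn, get_batch_fn, release_fn)
init_fn, get_batch_fn, release_fn = random_reference_data(
    data_loader, cached_batches_count=2, mb_size=3)

data_state = init_fn(seed=0)  # seed kwarg assumed to reach the data loader
data_state, batch = get_batch_fn(data_state)  # batch is a dict, here {"x": ...}
release_fn()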
- jax_sgmc.data.core.full_reference_data(data_loader, cached_batches_count=100, mb_size=None)[source]
Initializes reference data access in jit-compiled functions.
Map a function batch-wise over a dataset on the host or the device.
- Parameters:
  - data_loader (DataLoader) – Reads data from storage.
  - cached_batches_count (int) – Number of batches in the cache. A larger number is faster, but requires more memory.
  - mb_size (int) – Size of the data batch.
- Return type:
  Tuple[Any, FullDataMapFunction, Callable[[], None]]
- Returns:
  Returns a tuple of functions to initialize a new reference data state, map a function over the complete dataset, and release the data loader after the last computation.
- class jax_sgmc.data.core.GetBatchFunction(*args, **kwargs)[source]
- __call__(data_state, information=False, device_count=1)[source]
Draws a batch of data.
- Parameters:
  - data_state (CacheState) – State of the chain containing id and cached batches
  - information (bool) – Include namedtuple containing information about the data and batch
  - device_count (int) – Number of devices on which this function is going to be called with replicated data states.
- Return type:
  Union[Tuple[CacheState, Any], Tuple[CacheState, Tuple[Any, MiniBatchInformation]]]
- Returns:
  Returns the new state of the random chain and a batch. Optionally, a namedtuple containing information about the batch and dataset can be returned.
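Because the batch has a fixed shape and the cache is refilled via a jit-compatible host callback, the get-batch function can be called inside jit-compiled code. A sketch, reusing the names from the example above:

import jax
import jax.numpy as jnp

@jax.jit
def batch_mean(data_state):
    # Drawing a batch advances the cache, so the state must be threaded through
    data_state, (batch, info) = get_batch_fn(data_state, information=True)
    return data_state, jnp.mean(batch["x"])

data_state, mean = batch_mean(data_state)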
- class jax_sgmc.data.core.FullDataMapFunction(*args, **kwargs)[source]
- __call__(fun, data_state, carry, masking=False, information=False, device_count=1)[source]
Maps a function over the complete dataset.
- Parameters:
  - fun (Union[MaskedMappedFunction, UnmaskedMappedFunction]) – Function to be mapped over the dataset
  - data_state (CacheState) – Namedtuple containing the id of the chain and cached batches
  - carry (Any) – Variables which are carried over to the next evaluation of fun
  - masking (bool) – If true, an array marking invalid samples is passed to the function such that a single result for a batch of data can be calculated. If false, a result for each observation must be returned and the invalid results are discarded after the computation.
  - information (bool) – Pass the batch information together with the batch
  - device_count (int) – Number of devices on which this function is going to be called with replicated data states.
- Returns:
  Returns the new data state and the results of the computation, including the carry of the last computation:
  (data_state, (results, carry)) = full_data_map(...)
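A sketch of the full pattern, reusing the data_loader from the first example; the required form of the mapped function is described under MaskedMappedFunction below (the mask is assumed to be 1 for valid and 0 for padded observations):

import jax.numpy as jnp
from jax_sgmc.data.core import full_reference_data

init_fn, full_data_map, release_fn = full_reference_data(
    data_loader, cached_batches_count=2, mb_size=3)

def batch_sum(batch, mask, state):
    # Masked form: one result per batch; the mask zeros out padded observations
    s = jnp.sum(batch["x"] * mask)
    return s, state + s  # (per-batch result, carry for the next batch)

data_state = init_fn()
data_state, (results, carry) = full_data_map(
    batch_sum, data_state, 0, masking=True)
release_fn()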
- class jax_sgmc.data.core.MaskedMappedFunction(*args, **kwargs)[source]
- __call__(batch, mask, state)[source]
Function which can be mapped over the whole dataset.
A function of this form must be passed to the full data map function if it is called with masking = True.
- Returns:
  Must return a tuple consisting of the computation results and the state which should be used in the computation of the next batch.
- class jax_sgmc.data.core.UnmaskedMappedFunction(*args, **kwargs)[source]
- jax_sgmc.data.core.full_data_mapper(data_loader=None, cached_batches_count=1, mb_size=1)[source]
Initializes a functional to map a function over a complete dataset.
This function extends the functionality of full_reference_data() by loading the data states from the host before each mapping.
- Parameters:
  - data_loader (Optional[DataLoader]) – Reads data from storage
  - cached_batches_count (int) – Number of batches in the cache. A larger number is faster, but requires more memory
  - mb_size (int) – Size of the data batch
- Returns:
  Returns a tuple of functions: one to map another function over a complete dataset of an appropriate DataLoader, and another to release the data loader after the last computation.
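A usage sketch, reusing the data_loader and the masked batch_sum function defined above:

from jax_sgmc.data.core import full_data_mapper

mapper_fn, release_fn = full_data_mapper(
    data_loader=data_loader, cached_batches_count=2, mb_size=3)

# No explicit data state is needed; it is acquired before each mapping
results, carry = mapper_fn(batch_sum, 0, masking=True)
release_fn()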
- class jax_sgmc.data.core.FullDataMapperFunction(*args, **kwargs)[source]
- __call__(fun, carry, masking=False, information=False, batched=True, device_count=1)[source]
Maps a function over the complete dataset.
This function differs from FullDataMapFunction in that it acquires a CacheState before each mapping over the full dataset.
- Parameters:
  - fun (Union[MaskedMappedFunction, UnmaskedMappedFunction]) – Function to be mapped over the dataset
  - carry (Any) – Variables which are carried over to the next evaluation of fun
  - masking (bool) – If true, an array marking invalid samples is passed to the function such that a single result for a batch of data can be calculated. If false, a result for each observation must be returned and the invalid results are discarded automatically after the computation.
  - information (bool) – Pass the batch information together with the batch
  - batched (bool) – Whether the function to be mapped over the full dataset is vectorized. If false, the function is vmapped such that it can process a batch of observations.
  - device_count (int) – Number of devices on which this function is going to be called with replicated data states.
- Returns:
  Returns the results of the computation, including the carry of the last computation:
  (results, carry) = full_data_mapper(...)
States
- class jax_sgmc.data.core.MiniBatchInformation(observation_count: Array, mask: Array, batch_size: Array)[source]
Bundles all information about the reference data.
- Parameters:
  - observation_count – Total number of observations
  - effective_observation_count – The number of observations excluding the samples discarded after e.g. shuffling.
  - mini_batch – List of tuples, where each tuple consists of (observations, parameters)
- class jax_sgmc.data.core.CacheState(callback_uuid: Optional[JaxUUID] = None, cached_batches: Optional[Any] = None, cached_batches_count: Optional[Array] = None, current_line: Optional[Array] = None, chain_id: Optional[Array] = None, state: Optional[Any] = None, valid: Optional[Array] = None, token: Optional[JaxUUID] = None)[source]
Caches several batches of randomly batched reference data.
- Parameters:
  - cached_batches – An array of mini-batches
  - cached_batches_count – Number of cached mini-batches; equals the first dimension of the cached batches
  - current_line – Marks the next batch to be returned.
  - chain_id – Identifier of the chain
  - state – Additional information
  - valid – Array containing information about the validity of individual samples
Base Classes
- class jax_sgmc.data.core.DataLoader[source]
Abstract class to define required methods of a DataLoader.
This class defines common methods of a DataLoader, such as returning an all-zero batch with the correct shape to initialize the model.
- class jax_sgmc.data.core.DeviceDataLoader[source]
Abstract class to define required methods of a DeviceDataLoader.
A class implementing the data loader must have the functionality to return the complete dataset as a dictionary of arrays.
- class jax_sgmc.data.core.HostDataLoader[source]
Abstract class to define required methods of a HostDataLoader.
A class implementing the data loader must have the functionality to load data from storage in an ordered and a random fashion.
- batch_format(cache_size, mb_size)[source]
Returns dtype and shape of cached mini-batches.
- Parameters:
  - cache_size (int) – Number of cached mini-batches
  - mb_size (int) – Size of the mini-batches
- Returns:
  Returns a pytree with the same tree structure as the random data cache but with jax.ShapeDtypeStruct as leaves.
- load_state(chain_id, data)[source]
Restores dataloader state from previously computed checkpoint.
- Parameters:
  - chain_id (int) – The chain to restore the state.
  - data – Data from save_state() to restore the state of the chain.
- abstract register_ordered_pipeline(cache_size=1, mb_size=None, **kwargs)[source]
Register a chain which assembles batches in an ordered manner.
- abstract register_random_pipeline(cache_size=1, mb_size=None, **kwargs)[source]
Register a new chain which assembles batches randomly.
- save_state(chain_id)[source]
Returns all necessary information to restore the dataloader state.
- Parameters:
  - chain_id (int) – Each chain can be checkpointed independently.
- Returns:
  Returns the necessary information to restore the state of the chain via load_state().
Utility Functions
- Indexes the leaves of the tree in the first dimension.
- Returns a tree with leaves only representing shape and type.
jax_sgmc.data.numpy_loader
Load numpy arrays in jit-compiled functions.
The numpy data loader is easy to use if the whole dataset fits into RAM and is already present as numpy arrays.
- class jax_sgmc.data.numpy_loader.NumpyBase(on_device=True, copy=True, **reference_data)[source]
- property reference_data
Returns the reference data as a dictionary.
- property static_information
Returns information about total samples count and batch size.
- class jax_sgmc.data.numpy_loader.NumpyDataLoader(copy=True, **reference_data)[source]
Load complete dataset into memory from multiple numpy arrays.
This data loader supports checkpointing, starting chains from a well-defined state and true random access.
The pipeline can be constructed directly from numpy arrays:
>>> import numpy as onp
>>> from jax_sgmc.data.numpy_loader import NumpyDataLoader
>>>
>>> x, y = onp.arange(10), onp.zeros((10, 4, 3))
>>>
>>> data_loader = NumpyDataLoader(name_for_x=x, name_for_y=y)
>>>
>>> zero_batch = data_loader.initializer_batch(4)
>>> for key, value in zero_batch.items():
...   print(f"{key}: shape={value.shape}, dtype={value.dtype}")
name_for_x: shape=(4,), dtype=int32
name_for_y: shape=(4, 4, 3), dtype=float32
- Parameters:
reference_data – Each kwarg-pair is an entry in the returned data-dict.
copy – Whether to copy the reference data (default True) or only create a reference.
- get_batches(chain_id)[source]
Draws a batch from a chain.
- Parameters:
  - chain_id (int) – ID of the chain, which holds the information about the form of the batch and the process of assembling.
- Returns:
  Returns a batch of batches as registered by register_random_pipeline() or register_ordered_pipeline(), with cache_size batches holding mb_size observations.
- load_state(chain_id, data)[source]
Restores dataloader state from previously computed checkpoint.
- Parameters:
  - chain_id (int) – The chain to restore the state.
  - data – Data from save_state() to restore the state of the chain.
- register_ordered_pipeline(cache_size=1, mb_size=None, **kwargs)[source]
Register a chain which assembles batches in an ordered manner.
- register_random_pipeline(cache_size=1, mb_size=None, in_epochs=False, shuffle=False, **kwargs)[source]
Register a new chain which draws samples randomly.
- Parameters:
  - cache_size (int) – The number of drawn batches.
  - mb_size (Optional[int]) – The number of observations per batch.
  - shuffle (bool) – Shuffle the dataset instead of drawing randomly from the observations.
  - in_epochs (bool) – Samples returned twice per epoch are marked via mask = 0 (only if shuffle = True).
  - seed – Set the random seed to start the chain at a well-defined state.
- Returns:
  Returns the id of the new chain.
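For example, a chain drawing shuffled batches could be registered like this (a sketch, reusing the data_loader constructed above):

chain_id = data_loader.register_random_pipeline(
    cache_size=2, mb_size=3, seed=42, shuffle=True)
cache = data_loader.get_batches(chain_id)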
- save_state(chain_id)[source]
Returns all necessary information to restore the dataloader state.
- Parameters:
  - chain_id (int) – Each chain can be checkpointed independently.
- Returns:
  Returns the necessary information to restore the state of the chain via load_state().
- class jax_sgmc.data.numpy_loader.DeviceNumpyDataLoader(copy=True, **reference_data)[source]
Load complete dataset into memory from multiple numpy arrays.
This data loader supports checkpointing, starting chains from a well-defined state and true random access.
The pipeline can be constructed directly from numpy arrays:
>>> import numpy as onp
>>> from jax_sgmc.data.numpy_loader import DeviceNumpyDataLoader
>>>
>>> x, y = onp.arange(10), onp.zeros((10, 4, 3))
>>>
>>> data_loader = DeviceNumpyDataLoader(name_for_x=x, name_for_y=y)
>>>
>>> zero_batch = data_loader.initializer_batch(4)
>>> for key, value in zero_batch.items():
...   print(f"{key}: shape={value.shape}, dtype={value.dtype}")
name_for_x: shape=(4,), dtype=int32
name_for_y: shape=(4, 4, 3), dtype=float32
- Parameters:
reference_data – Each kwarg-pair is an entry in the returned data-dict.
copy – Whether to copy the reference data (default True) or only create a reference.
jax_sgmc.data.tensorflow_loader
Load Tensorflow-Datasets in jit-compiled functions.
The tensorflow dataloader supports tensorflow Datasets, e.g. from the tensorflow_datasets package.
Note
This submodule requires that tensorflow and tensorflow_datasets are installed. Additional information can be found in the installation instructions.
- class jax_sgmc.data.tensorflow_loader.TensorflowDataLoader(pipeline, mini_batch_size=None, shuffle_cache=100, exclude_keys=None)[source]
Load data from a tensorflow dataset object.
The tensorflow datasets package provides a high number of ready to go datasets, which can be provided directly to the Tensorflow Data Loader.
import tensorflow_datasets as tfds
from jax_sgmc import data
from jax_sgmc.data.tensorflow_loader import TensorflowDataLoader

pipeline = tfds.load("cifar10", split="train")
data_loader = TensorflowDataLoader(pipeline, shuffle_cache=100, exclude_keys=['id'])
- Parameters:
  - pipeline (DatasetV2) – A tensorflow data pipeline, which can be obtained from the tensorflow_datasets package
- get_batches(chain_id)[source]
Draws a batch from a chain.
- Parameters:
  - chain_id (int) – ID of the chain, which holds the information about the form of the batch and the process of assembling.
- Returns:
  Returns a batch of batches as registered by register_random_pipeline() or register_ordered_pipeline(), with cache_size batches holding mb_size observations.
- register_ordered_pipeline(cache_size=1, mb_size=None, **kwargs)[source]
Register a chain which assembles batches in an ordered manner.
- register_random_pipeline(cache_size=1, mb_size=None, **kwargs)[source]
Register a new chain which draws samples randomly.
- property static_information
Returns information about total samples count and batch size.
jax_sgmc.data.hdf5_loader
Use samples saved with jax_sgmc.io.HDF5Collector as reference data.
- class jax_sgmc.data.hdf5_loader.HDF5Loader(file, subdir='/chain~0/variables/', sample=None)[source]
Load reference data from HDF5-files.
This data loader can load reference data stored in HDF5 files. This makes it possible to use the jax_sgmc.data module to evaluate samples saved via the jax_sgmc.io.HDF5Collector.
- Parameters:
  - file – Path to the HDF5 file containing the reference data
  - subdir – Path to the subset of the dataset which should be loaded
  - sample – PyTree specifying the original shape of the sub-pytree before it was saved by the jax_sgmc.io.HDF5Collector
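A construction sketch; the file name "samples.h5" and the template_sample pytree are hypothetical and must match what the collector actually wrote:

from jax_sgmc.data.hdf5_loader import HDF5Loader

# "samples.h5" is a hypothetical file written via jax_sgmc.io.HDF5Collector;
# template_sample mirrors the pytree structure of a single saved sample
data_loader = HDF5Loader(
    "samples.h5", subdir="/chain~0/variables/", sample=template_sample)

The resulting loader can then be combined with e.g. full_data_mapper() from jax_sgmc.data.core to map a function over the saved samples.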