# Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data input in jit-compiled functions
Limited device memory does not allow storing all reference data on the device
for big datasets. With the following functions, mini-batches of data can be
requested in ``jit``-compiled functions without loading the entire dataset into
the device's memory.
In the background, a cache of mini-batches is sequentially requested by the
Host Callback Wrappers from the Data Loaders and loaded on the device via the
:mod:`jax_sgmc.util.host_callback` module.
"""
import abc
import threading
import warnings
from functools import partial
import itertools
from typing import Tuple, Any, Callable, Union, NamedTuple, Dict, Protocol, List
import jax
from jax import tree_util, lax
import jax.numpy as jnp
import jax.experimental as jxp
import numpy as onp
from jax_sgmc.util import Array, stop_vmap
from jax_sgmc.util.uuid import JaxUUID
# Alias used in annotations for an arbitrary pytree (nested containers whose
# leaves are mapped with ``tree_util`` throughout this module).
PyTree = Any
# Possible forms of a mini-batch: the batch alone, the batch together with its
# ``MiniBatchInformation``, or both followed by an additional array.
# NOTE(review): ``MiniBatchInformation`` is not defined in the visible part of
# this file; it must already be defined at this point for the module to
# import — confirm.
MiniBatch = Union[Tuple[PyTree],
                  Tuple[PyTree, MiniBatchInformation],
                  Tuple[PyTree, MiniBatchInformation, Array]]
class DataLoader(metaclass=abc.ABCMeta):
  """Abstract base class shared by all data loaders.

  Declares the properties every data loader has to provide and implements a
  helper that creates an all-zero batch of the correct shape, e.g. to
  initialize a model.
  """

  @property
  @abc.abstractmethod
  def static_information(self) -> Dict:
    """Static properties of the dataset, e.g. the total observation count."""

  @property
  @abc.abstractmethod
  def _format(self):
    """Shape and dtype of one observation."""

  def initializer_batch(self, mb_size: int = None) -> PyTree:
    """Creates a batch filled with zeros.

    Args:
      mb_size: Number of observations stacked along a new leading axis. If
        ``None``, the result has the shape of a single observation.

    Returns:
      A pytree with the structure of ``_format`` whose leaves are zero arrays
      of the corresponding dtype.
    """
    observation_spec = self._format
    leading_axes = () if mb_size is None else (mb_size,)

    def zeros_for(spec):
      full_shape = tuple(
        int(dim) for dim in itertools.chain(leading_axes, spec.shape))
      return jnp.zeros(shape=full_shape, dtype=spec.dtype)

    return tree_util.tree_map(zeros_for, observation_spec)
class DeviceDataLoader(DataLoader, metaclass=abc.ABCMeta):
  """Interface of a data loader that keeps the full dataset on the device.

  Implementations must be able to return the complete dataset as a dictionary
  of arrays and to draw random batches without side effects.
  """

  @abc.abstractmethod
  def init_random_data(self, *args, **kwargs) -> PyTree:
    """Creates the state required for drawing random batches."""

  @abc.abstractmethod
  def get_random_data(self,
                      state,
                      batch_size
                      ) -> Tuple[PyTree, Tuple[PyTree, MiniBatchInformation]]:
    """Draws a random batch of the data.

    Implementations must be jit-able and free of side effects.

    Args:
      state: State for random drawing, e.g. from :meth:`init_random_data`.
      batch_size: Number of observations in the batch.

    Returns:
      The updated state together with the batch and its information.
    """

  @abc.abstractmethod
  def get_full_data(self) -> Dict:
    """Returns the complete dataset as a dictionary of arrays."""
class HostDataLoader(DataLoader, metaclass=abc.ABCMeta):
  """Interface of a data loader that assembles batches on the host.

  Implementations load data from storage either in random order or
  sequentially. Jit-compiled code reaches them through host callbacks.
  """

  def save_state(self, chain_id: int):
    """Exports everything required to restore the state of a chain.

    Args:
      chain_id: Identifier of the chain; every chain is checkpointed on its
        own.

    Returns:
      Checkpoint data which can be passed to :func:`load_state`.

    Raises:
      NotImplementedError: Unless a subclass overrides this method.
    """
    del chain_id
    raise NotImplementedError("This method must be overwritten to allow "
                              "checkpointing of the data loader.")

  def load_state(self, chain_id: int, data):
    """Restores the state of a chain from a checkpoint.

    Args:
      chain_id: Identifier of the chain to restore.
      data: Checkpoint data previously produced by :func:`save_state`.

    Raises:
      NotImplementedError: Unless a subclass overrides this method.
    """
    del chain_id, data
    raise NotImplementedError("This method must be overwritten to allow "
                              "checkpointing of the data loader.")

  @abc.abstractmethod
  def register_random_pipeline(self,
                               cache_size: int = 1,
                               mb_size: int = None,
                               **kwargs
                               ) -> int:
    """Registers a chain that assembles its batches randomly.

    Args:
      cache_size: Number of batches drawn per request.
      mb_size: Number of observations per batch.
      seed: Random seed to start the chain at a well-defined state (passed
        via ``kwargs``).

    Returns:
      Identifier of the new chain.
    """

  @abc.abstractmethod
  def register_ordered_pipeline(self,
                                cache_size: int = 1,
                                mb_size: int = None,
                                **kwargs
                                ) -> int:
    """Registers a chain that assembles its batches in order.

    Args:
      cache_size: Number of batches drawn per request.
      mb_size: Number of observations per batch.

    Returns:
      Identifier of the new chain.
    """

  @abc.abstractmethod
  def get_batches(self, chain_id: int) -> Tuple[PyTree, Union[Array, None]]:
    """Returns the next batches of a random or ordered chain.

    The second element is a mask of valid samples, or ``None`` if all
    samples are valid.
    """
class CacheState(NamedTuple):
  """State of a reference data chain, e.g. several cached mini-batches.

  All fields default to ``None``; each data access path only fills the fields
  it needs.

  Args:
    callback_uuid: Key under which the host data loader serving this chain is
      registered.
    cached_batches: Pytree of cached mini-batches stacked along the first axis.
    cached_batches_count: Number of cached mini-batches, i.e. the size of the
      first axis of ``cached_batches``.
    current_line: Position of the next batch to be returned. The on-device
      full-data loader stores its observation offset here instead.
    chain_id: Identifier of the chain assigned by the data loader.
    state: Additional state, e.g. the random state of a device data loader.
    valid: Boolean mask marking the valid samples of every cached batch.
    token: Token of the last host request. The host uses it to detect
      invalid or repeated requests.
  """
  # The field order defines the positional constructor and the pytree
  # flattening order, so it must not change.
  callback_uuid: JaxUUID = None
  cached_batches: PyTree = None
  cached_batches_count: Array = None
  current_line: Array = None
  chain_id: Array = None
  state: PyTree = None
  valid: Array = None
  token: JaxUUID = None
# Alias of CacheState; not referenced in the visible code, presumably kept for
# backward compatibility — confirm before removing.
random_data_state = CacheState
# Return value of a batch function: the new state with either the bare batch or
# the batch together with its ``MiniBatchInformation``.
Batch = Union[Tuple[CacheState, PyTree],
              Tuple[CacheState, Tuple[PyTree, MiniBatchInformation]]]
class GetBatchFunction(Protocol):
  """Call signature of the functions which draw batches from a data state."""

  def __call__(self,
               data_state: CacheState,
               information: bool = False,
               device_count: int = 1) -> Batch:
    """Returns the next batch of data.

    Args:
      data_state: Chain state holding the chain id and the cached batches.
      information: Whether to also return a namedtuple with information about
        the data and the batch.
      device_count: Number of devices calling this function with replicated
        data states.

    Returns:
      The updated chain state and the batch, optionally paired with the
      namedtuple describing the batch and the dataset.
    """
class MaskedMappedFunction(Protocol):
  """Call signature of a function mapped over the dataset with masking."""

  def __call__(self,
               batch: PyTree,
               mask: Array,
               state: PyTree
               ) -> Tuple[PyTree, PyTree]:
    """Processes one batch of the dataset.

    The full data map function requires this form when called with
    ``masking = True``.

    Args:
      batch: Batch of data.
      mask: Array flagging the invalid (duplicated) samples of the batch.
      state: Carry produced by the computation on the previous batch.

    Returns:
      A tuple of the result for this batch and the carry for the next batch.
    """
class UnmaskedMappedFunction(Protocol):
  """Call signature of a function mapped over the dataset without masking."""

  def __call__(self,
               batch: PyTree,
               state: PyTree
               ) -> Tuple[PyTree, PyTree]:
    """Processes one batch of the dataset.

    The full data map function requires this form when called with
    ``masking = False``.

    Args:
      batch: Batch of data.
      state: Carry produced by the computation on the previous batch.

    Returns:
      A tuple of the result for this batch and the carry for the next batch.
    """
# A function mapped over the full dataset, either with or without a mask
# argument (selected by the ``masking`` flag of the map function).
MappedFunction = Union[MaskedMappedFunction, UnmaskedMappedFunction]
class FullDataMapFunction(Protocol):
  """Call signature of the function mapping over the full dataset."""

  def __call__(self,
               fun: MappedFunction,
               data_state: CacheState,
               carry: PyTree,
               masking: bool = False,
               information: bool = False,
               device_count: int = 1
               ) -> Tuple[PyTree, PyTree]:
    """Applies a function batch-wise to the complete dataset.

    Args:
      fun: Function applied to every batch.
      data_state: Namedtuple with the chain id and the cached batches.
      carry: Values passed from one evaluation of ``fun`` to the next.
      masking: If true, ``fun`` also receives an array marking the invalid
        samples, so it can compute a single result per batch. If false, ``fun``
        must return one result per observation and the results of invalid
        samples are dropped afterwards.
      information: Whether to pass the batch information along with the batch.
      device_count: Number of devices calling this function with replicated
        data states.

    Returns:
      The new data state and the results together with the carry of the last
      evaluation:
      ::

        (data_state, (results, carry)) = full_data_map(...)
    """
class FullDataMapperFunction(Protocol):
  """Call signature of the stateless full-dataset mapper."""

  def __call__(self,
               fun: MappedFunction,
               carry: PyTree,
               masking: bool = False,
               information: bool = False,
               batched: bool = True,
               device_count: int = 1
               ) -> PyTree:
    """Applies a function batch-wise to the complete dataset.

    In contrast to :class:`FullDataMapFunction`, a :class:`CacheState` is
    acquired before every pass over the full dataset, so no data state has to
    be passed in.

    Args:
      fun: Function applied to every batch.
      carry: Values passed from one evaluation of ``fun`` to the next.
      masking: If true, ``fun`` also receives an array marking the invalid
        samples, so it can compute a single result per batch. If false, ``fun``
        must return one result per observation and the results of invalid
        samples are dropped automatically afterwards.
      information: Whether to pass the batch information along with the batch.
      batched: Whether ``fun`` already processes a whole batch. If false, it is
        vmapped so that it can handle a batch of observations.
      device_count: Number of devices calling this function with replicated
        data states.

    Returns:
      The results together with the carry of the last evaluation:
      ::

        (results, carry) = full_data_mapper(...)
    """
# (init function, batch function, release function) as returned by
# random_reference_data.
RandomBatch = Tuple[Any, GetBatchFunction, Callable[[], None]]
# (init function, full-data map function, release function) as returned by
# full_reference_data.
OrderedBatch = Tuple[Any, FullDataMapFunction, Callable[[], None]]
class _Requests:
  """Thread-safe dispatcher of host-callback data requests.

  Holds the registry of host data loaders (keyed by the uuid of the callback
  that uses them), the token expected with the next request of every chain,
  and cached responses for requests that several devices send identically.
  """

  def __init__(self):
    # token uuid -> [remaining repeated requests, response, issued token]
    self._cached_requests: Dict[Any, List[Any]] = {}
    # callback uuid -> {chain id -> uuid of the token expected next}
    self._requests: Dict[Any, Dict[int, Any]] = {}
    # callback uuid -> data loader serving the requests of that callback
    self._host_data_loaders: Dict[Any, HostDataLoader] = {}
    self._lock = threading.Lock()

  def __call__(self,
               chain_id: int,
               token: JaxUUID,
               callback_uuid: JaxUUID,
               device_count: int = 1,
               strict: bool = False):
    """Returns the next batches of a chain together with a new token.

    Args:
      chain_id: Chain whose batches are requested.
      token: Token received together with the previous batches of the chain.
      callback_uuid: Key of the registered host data loader.
      device_count: Number of devices sending this identical request. The
        response is cached until every device has been served.
      strict: Raise an error instead of warning if the token is not the one
        expected for the chain.

    Returns:
      Tuple of the data loader response ``get_batches(chain_id)`` and the
      token to pass with the next request.

    Raises:
      RuntimeError: If ``strict`` is set and the token is invalid.
    """
    with self._lock:
      chain_tokens = self._requests.setdefault(callback_uuid.as_uuid, {})
      # The token that is expected for a new callback
      current_token = chain_tokens.get(int(chain_id))

      # A known token: another device repeats an already served request.
      old_request = self._cached_requests.get(token.as_uuid)
      if old_request is not None:
        counter, callback_response, new_token = old_request
        if counter == 1:
          # Every device has been served, the response is no longer necessary
          del self._cached_requests[token.as_uuid]
        else:
          old_request[0] -= 1
        return callback_response, new_token

      # Unexpected token, e.g. a stale data state
      if current_token is not None and current_token != token.as_uuid:
        if strict:
          raise RuntimeError(f"Invalid request for "
                             f"chain {chain_id}. This might be due to using a "
                             f"pmap in a jitted function. See "
                             f"usage/data.html#combining-pmap-and-jit in the "
                             f"docs.")
        # Non-strict mode: warn, then serve the request like a new one so that
        # a response and a token are always returned.
        warnings.warn(f"Invalid request for chain "
                      f"{chain_id}. This might be due to using a pmap in a "
                      f"jitted function. See "
                      f"usage/data.html#combining-pmap-and-jit in the docs.")

      # Issue a new token and request data
      new_token = JaxUUID()
      callback_response = self._host_data_loaders[
        callback_uuid.as_uuid].get_batches(chain_id)
      chain_tokens[int(chain_id)] = new_token.as_uuid
      # Store data if other devices are going to request it
      if device_count != 1:
        self._cached_requests[token.as_uuid] = [
          device_count - 1, callback_response, new_token]
      return callback_response, new_token

  def __getstate__(self):
    state = self.__dict__.copy()
    # Locks cannot be pickled; a new one is created in __setstate__
    del state['_lock']
    return state

  def __setstate__(self, state):
    self.__dict__.update(state)
    self._lock = threading.Lock()
# Module-level request dispatcher shared by all host callbacks of this module.
# Host data loaders are registered in it under the uuid of their callback.
_data_requests = _Requests()
def random_reference_data(data_loader: DataLoader,
                          cached_batches_count: int,
                          mb_size: int,
                          verify_calls: bool = False
                          ) -> RandomBatch:
  """Initializes reference data access in jit-compiled functions.

  Randomly draw batches from a given dataset on the host or the device.

  Args:
    data_loader: Reads data from storage.
    cached_batches_count: Number of batches in the cache. A larger number is
      faster, but requires more memory. Must be 1 for a device data loader.
    mb_size: Size of the data batch.
    verify_calls: Verify calls to the host when using pmap. This might not be
      necessary if batches are assembled randomly without extra conditions.

  Returns:
    Returns a tuple of functions to initialize a new reference data state, get
    a minibatch from the reference data state and release the data loader after
    the last computation.

  Raises:
    ValueError: If the batch size exceeds the number of observations, the
      cache or batch size is not positive, or a device data loader is used
      with caching.
    TypeError: If the data loader is neither a host nor a device data loader.
  """
  # Check batch size is not bigger than total observation count
  observation_count = data_loader.static_information["observation_count"]
  if observation_count < mb_size:
    raise ValueError(f"Batch size cannot be bigger than the number of total "
                     f"observations. Got {observation_count} and {mb_size}.")
  if cached_batches_count <= 0 or mb_size <= 0:
    raise ValueError(f"Cache size and batch size must be positive, got "
                     f"{cached_batches_count} and {mb_size}.")

  if isinstance(data_loader, HostDataLoader):
    return _random_reference_data_host(
      data_loader, cached_batches_count, mb_size,
      verify_calls=verify_calls)
  if isinstance(data_loader, DeviceDataLoader):
    # Batches are drawn directly on the device, so there is no cache.
    if cached_batches_count != 1:
      raise ValueError("No caching on device.")
    return _random_reference_data_device(data_loader, mb_size)
  raise TypeError("The DataLoader must inherit from HostDataLoader or "
                  "DeviceDataLoader")
def full_reference_data(data_loader: DataLoader,
                        cached_batches_count: int = 100,
                        mb_size: int = None
                        ) -> OrderedBatch:
  """Initializes reference data access in jit-compiled functions.

  Map a function batch-wise over a dataset on the host or the device.

  Args:
    data_loader: Reads data from storage.
    cached_batches_count: Number of batches in the cache. A larger number is
      faster, but requires more memory. Only used by host data loaders.
    mb_size: Size of the data batch. Must be specified.

  Returns:
    Returns a tuple of functions to initialize a new reference data state, map a
    function over the complete dataset and release the data loader after the
    last computation.

  Raises:
    TypeError: If ``mb_size`` is missing or the data loader is neither a host
      nor a device data loader.
    ValueError: If ``mb_size`` is not positive or exceeds the number of
      observations.
  """
  if mb_size is None:
    raise TypeError("The batch size mb_size must be specified.")
  if mb_size <= 0:
    raise ValueError(f"Batch size must be positive, got {mb_size}.")
  # Check batch size is not bigger than total observation count
  observation_count = data_loader.static_information["observation_count"]
  if observation_count < mb_size:
    raise ValueError(f"Batch size cannot be bigger than the number of total "
                     f"observations. Got {observation_count} and {mb_size}.")

  if isinstance(data_loader, HostDataLoader):
    init_fn, (_batch_fn, mb_information), cleanup = _full_reference_data_host(
      data_loader, cached_batches_count, mb_size)
  elif isinstance(data_loader, DeviceDataLoader):
    # The device loader has no cache; it only needs the batch size.
    init_fn, (_batch_fn, mb_information), cleanup = _full_reference_data_device(
      data_loader, mb_size)
  else:
    raise TypeError("The DataLoader must inherit from HostDataLoader or "
                    "DeviceDataLoader")

  num_iterations = int(onp.ceil(
    mb_information.observation_count / mb_information.batch_size))
  batch_size = mb_information.batch_size

  def _uninitialized_body_fn(fun,
                             state,
                             iteration,
                             information=False,
                             masking=False,
                             device_count=1):
    # The mask is 1 if the observation is valid and 0 otherwise. This is
    # necessary to ensure that fun is always called with the same tree shape.
    observations = iteration * batch_size + jnp.arange(batch_size)
    mask = observations < mb_information.observation_count
    data_state, fun_state = state
    data_state, batch = _batch_fn(
      data_state,
      information=information, device_count=device_count)
    if masking:
      result, fun_state = fun(batch, mask, fun_state)
    else:
      result, fun_state = fun(batch, fun_state)
    return (data_state, fun_state), result

  def batch_scan(fun: MappedFunction,
                 data_state: CacheState,
                 carry: PyTree,
                 masking: bool = False,
                 information: bool = False,
                 device_count: int = 1
                 ) -> Tuple[PyTree, PyTree]:
    """Maps the function over all data.

    Args:
      fun: Function accepting a batch of data, a mask and a state.
      data_state: Reference data state.
      carry: A argument that is carried over between iterations
      masking: If set to true, the mapped function is called with a positional
        argument mask and expected to return the results with a reduced
        dimension. Setting to true changes the signature from
        `fun(data, carry)` to `fun(data, mask, carry)`.
      information: Provide the minibatch information in addition to the data
        batch.
      device_count: Number of devices calling this function with replicated
        data states.

    Returns:
      The new data state and the tuple ``(results, carry)``.
    """
    _body_fn = partial(_uninitialized_body_fn,
                       fun,
                       information=information,
                       masking=masking,
                       device_count=device_count)
    (data_state, carry), results = lax.scan(
      _body_fn, (data_state, carry), onp.arange(num_iterations))
    if masking:
      true_results = results
    else:
      # The per-batch results must be concatenated
      concat_results = tree_util.tree_map(
        partial(jnp.concatenate, axis=0),
        results)
      # Invalid results (fillers of the last batch) must be thrown away
      true_results = tree_util.tree_map(
        lambda leaf: leaf[0:mb_information.observation_count],
        concat_results)
    return data_state, (true_results, carry)

  return init_fn, batch_scan, cleanup
def tree_dtype_struct(pytree: PyTree):
  """Replaces every leaf by a ``jax.ShapeDtypeStruct`` of the same shape/dtype.

  Args:
    pytree: Tree whose leaves expose ``shape`` and ``dtype``.

  Returns:
    A tree of the same structure describing only shapes and dtypes.
  """
  def as_struct(array_like):
    return jax.ShapeDtypeStruct(dtype=array_like.dtype, shape=array_like.shape)

  return tree_util.tree_map(as_struct, pytree)
def tree_index(pytree: PyTree, index):
  """Selects entries along the first axis of every leaf.

  Args:
    pytree: Tree with array-like leaves.
    index: Index (or index array) applied to the first axis of every leaf.

  Returns:
    A tree with the structure of ``pytree``. For a scalar index each leaf
    loses its first dimension.
  """
  # For leaves of any rank >= 1, ``leaf[index]`` indexes the first axis and
  # keeps all remaining axes unchanged.
  return tree_util.tree_map(lambda leaf: leaf[index], pytree)
# The callback does not depend on how the batches are assembled (random or
# ordered); the chain id decides that on the host.
def _hcb_wrapper(data_loader: HostDataLoader,
                 cached_batches_count: int,
                 mb_size: int,
                 verify_calls: bool = False
                 ) -> Tuple[GetBatchFunction, MiniBatchInformation]:
  """Builds a batch function that refills its cache through a host callback.

  Args:
    data_loader: Host data loader providing the batch format and the data.
    cached_batches_count: Number of batches fetched per host request.
    mb_size: Number of observations per batch.
    verify_calls: Whether the request dispatcher raises on invalid tokens
      instead of only warning.

  Returns:
    The batch function and the static mini-batch information.
  """
  # The format of the cached batches is static and must be known before the
  # first host call.
  hcb_format, mb_information = data_loader.batch_format(
    cached_batches_count, mb_size=mb_size)
  mask_shape = (cached_batches_count, mb_size)

  def get_data(request):
    """Host side: fetches the next cache of a chain via the dispatcher."""
    chain_id, callback_uuid, token, device_count = request
    (batches, valid), next_token = _data_requests(
      chain_id, token, callback_uuid, device_count,
      strict=verify_calls)
    # Without a mask every sample is valid. The mask must be created with
    # numpy on the host, as creating it on the device deadlocks.
    if valid is None:
      valid = onp.ones(mask_shape, dtype=jnp.bool_)
    return batches, valid, next_token

  def as_result_struct(leaf):
    return jax.ShapeDtypeStruct(
      shape=leaf.shape, dtype=jax.dtypes.canonicalize_dtype(leaf.dtype))

  def refill_cache(operand: Tuple[CacheState, int]) -> CacheState:
    """Replaces the exhausted cache with new batches from the host."""
    old_state, device_count = operand
    result_structs = tree_util.tree_map(
      as_result_struct,
      (hcb_format,
       jax.ShapeDtypeStruct(shape=mask_shape, dtype=jnp.bool_),
       JaxUUID()))
    batches, valid, next_token = jxp.io_callback(
      get_data,
      result_structs,
      (old_state.chain_id, old_state.callback_uuid, old_state.token,
       device_count),
    )
    return CacheState(
      cached_batches_count=old_state.cached_batches_count,
      cached_batches=batches,
      current_line=jnp.array(0),
      chain_id=old_state.chain_id,
      valid=valid,
      callback_uuid=old_state.callback_uuid,
      token=next_token)

  def keep_cache(operand: Tuple[CacheState, int]) -> CacheState:
    """Leaves the cache untouched while unused batches remain."""
    return operand[0]

  # Under vmap a cond is replaced by a select, which would execute the
  # side-effecting branch unconditionally.
  @stop_vmap.stop_vmap
  def refresh_if_exhausted(data_state, device_count):
    exhausted = data_state.current_line == data_state.cached_batches_count
    return lax.cond(exhausted, refill_cache, keep_cache,
                    (data_state, device_count))

  def batch_fn(data_state: CacheState,
               information: bool = False,
               device_count: int = 1
               ) -> Batch:
    """Returns the next batch; host transfers happen only on cache refills.

    Args:
      data_state: State holding the cached batches.
      information: Whether to also return the batch information.
      device_count: Number of parallel programs calling the batch function.

    Returns:
      The new data state and the batch, optionally paired with a
      ``MiniBatchInformation`` containing the validity mask of the batch.

    Raises:
      ValueError: If ``device_count`` exceeds the number of devices.
    """
    if device_count > jax.device_count():
      raise ValueError(f"The value of device_count cannot exceed the true "
                       f"device count. Expecting device_count (given: "
                       f"{device_count}) <= {jax.device_count()}.")
    if device_count != 1:
      warnings.warn("Changing the device count can cause memory accumulation. "
                    "Continue only if you know what you are doing.")

    # Fetch new batches first once all cached batches have been consumed.
    data_state = refresh_if_exhausted(data_state, device_count)
    line = jnp.mod(data_state.current_line, data_state.cached_batches_count)

    batch = tree_index(data_state.cached_batches, line)
    info = MiniBatchInformation(
      observation_count=mb_information.observation_count,
      batch_size=mb_information.batch_size,
      mask=data_state.valid[line, :])

    next_state = CacheState(
      cached_batches=data_state.cached_batches,
      cached_batches_count=data_state.cached_batches_count,
      current_line=line + 1,
      chain_id=data_state.chain_id,
      valid=data_state.valid,
      callback_uuid=data_state.callback_uuid,
      token=data_state.token)

    return (next_state, (batch, info)) if information else (next_state, batch)

  return batch_fn, mb_information
def _random_reference_data_host(data_loader: HostDataLoader,
                                cached_batches_count: int = 100,
                                mb_size: int = 1,
                                verify_calls: bool = False) -> RandomBatch:
  """Random reference data access via host-callback.

  Returns:
    Tuple ``(init_fn, batch_fn, release)``. ``init_fn`` registers a new random
    chain and loads its first cache, ``batch_fn`` draws the next batch and
    ``release`` unregisters the data loader.
  """
  total = data_loader.static_information["observation_count"]
  if cached_batches_count * mb_size > total:
    warnings.warn("Cached batches are bigger than the total dataset. Consider "
                  "using a DeviceDataLoader.")

  batch_fn, _ = _hcb_wrapper(
    data_loader, cached_batches_count, mb_size, verify_calls=verify_calls)

  # Register the loader so that the host callback can look it up.
  callback_uuid = JaxUUID()
  _data_requests._host_data_loaders[callback_uuid.as_uuid] = data_loader

  def init_fn(**kwargs) -> CacheState:
    """Registers a random chain and returns its initial cache state."""
    # The data loader returns a unique chain id for reproducibility.
    chain_id = data_loader.register_random_pipeline(
      cached_batches_count, mb_size=mb_size, **kwargs)
    batches, valid = _data_requests._host_data_loaders[
      callback_uuid.as_uuid].get_batches(chain_id)
    if valid is None:
      valid = jnp.ones((cached_batches_count, mb_size), dtype=jnp.bool_)
    return CacheState(
      cached_batches=batches,
      cached_batches_count=jnp.array(cached_batches_count),
      current_line=jnp.array(0),
      chain_id=jnp.array(chain_id),
      valid=valid,
      callback_uuid=callback_uuid,
      token=JaxUUID())

  def release():
    """Unregisters the data loader after the last computation."""
    _data_requests._host_data_loaders.pop(callback_uuid.as_uuid)

  return init_fn, batch_fn, release
def _random_reference_data_device(data_loader: DeviceDataLoader,
                                  mb_size: int
                                  ) -> RandomBatch:
  """Random reference data on device.

  Args:
    data_loader: Device data loader drawing the random batches.
    mb_size: Number of observations per batch.

  Returns:
    Tuple ``(init_fn, batch_fn, release)``. ``release`` does nothing because no
    host resources are registered.
  """

  def init_fn(**kwargs) -> CacheState:
    """Wraps the random state of the data loader in a CacheState."""
    state = data_loader.init_random_data(**kwargs)
    return CacheState(state=state)

  def batch_fn(state: CacheState,
               information: bool = False,
               device_count: int = 1
               ) -> Batch:
    """Draws a random batch on the device.

    Args:
      state: CacheState wrapping the random state of the data loader.
      information: Whether to also return the batch information.
      device_count: Accepted for the same call signature as the other batch
        functions; ignored, because no host requests are made.
    """
    del device_count
    state, (batch, mb_info) = data_loader.get_random_data(
      state.state, batch_size=mb_size)
    new_state = CacheState(state=state)
    if information:
      return new_state, (batch, mb_info)
    return new_state, batch

  def release():
    pass

  return init_fn, batch_fn, release
def _full_reference_data_host(data_loader: HostDataLoader,
                              cached_batches_count: int = 100,
                              mb_size: int = None
                              ) -> Tuple[Callable, Tuple[Callable, MiniBatchInformation], Callable]:
  """Sequentially load batches of reference data via host-callback.

  Returns:
    Tuple ``(init_fn, (batch_fn, mb_information), release)``. ``init_fn``
    registers a new ordered chain and loads its first cache, ``batch_fn``
    returns the next batch and ``release`` unregisters the data loader.
  """
  total = data_loader.static_information["observation_count"]
  if cached_batches_count * mb_size > total:
    warnings.warn("Cached batches are bigger than the total dataset. Consider "
                  "using a DeviceDataLoader.")

  # Requests of ordered chains are always verified (strict mode).
  batch_fn, mb_information = _hcb_wrapper(
    data_loader, cached_batches_count, mb_size, verify_calls=True)

  # Register the loader so that the host callback can look it up.
  callback_uuid = JaxUUID()
  _data_requests._host_data_loaders[callback_uuid.as_uuid] = data_loader

  def init_fn(**kwargs) -> CacheState:
    """Registers an ordered chain and returns its initial cache state."""
    loader = _data_requests._host_data_loaders[callback_uuid.as_uuid]
    # The data loader returns a unique chain id for reproducibility.
    chain_id = loader.register_ordered_pipeline(
      cached_batches_count, mb_size=mb_size, **kwargs)
    batches, valid = loader.get_batches(chain_id)
    if valid is None:
      valid = jnp.ones((cached_batches_count, mb_size), dtype=jnp.bool_)
    return CacheState(
      cached_batches=batches,
      cached_batches_count=jnp.array(cached_batches_count),
      current_line=jnp.array(0),
      chain_id=jnp.array(chain_id),
      valid=valid,
      callback_uuid=callback_uuid,
      token=JaxUUID())

  def release():
    """Unregisters the data loader after the last computation."""
    _data_requests._host_data_loaders.pop(callback_uuid.as_uuid)

  return init_fn, (batch_fn, mb_information), release
def _full_reference_data_device(data_loader: DeviceDataLoader,
                                mb_size: int = None
                                ) -> Tuple[Callable,
                                           Tuple[Callable,
                                                 MiniBatchInformation],
                                           Callable]:
  """Batches the dataset on the device.

  Returns:
    Tuple ``(init_fn, (batch_fn, mb_info), release)``. The state only stores
    the offset of the next batch; batches wrap around at the end of the data.
  """
  reference_data = data_loader.get_full_data()
  total_observations = data_loader.static_information["observation_count"]

  # Static batch information; the mask flags every sample as valid.
  mb_info = MiniBatchInformation(
    observation_count=total_observations,
    batch_size=mb_size,
    mask=onp.ones(mb_size),
  )

  def init_fn(offset: jnp.ndarray = 0):
    """Starts the iteration at observation ``offset``."""
    if offset >= total_observations:
      raise ValueError(f"The offset cannot be greater than the total "
                       f"observation count. Given {offset} and "
                       f"{total_observations}.")
    return CacheState(current_line=offset)

  def batch_fn(data_state: CacheState,
               information: bool = False,
               device_count: int = 1):
    """Slices the next ``mb_size`` consecutive observations."""
    del device_count
    # Consecutive indices from the current offset, wrapping at the end.
    indices = jnp.mod(data_state.current_line + jnp.arange(mb_size),
                      total_observations)
    batch = tree_index(reference_data, indices)
    # The next batch starts right after the last index of this one.
    next_state = CacheState(
      current_line=jnp.mod(indices[-1] + 1, total_observations))
    if not information:
      return next_state, batch
    return next_state, (batch, mb_info)

  def release():
    pass

  return init_fn, (batch_fn, mb_info), release
class _FullDataHelper:
  """Pool of unused CacheStates for repeated passes over the full dataset.

  States handed back after a pass are reused by the next one; a new state is
  only created when the pool is empty. Access to the pool is guarded by a
  lock.
  """

  def __init__(self, data_loader, cache_size, batch_size):
    self._init_fn, self._map_fn, self._cleanup_fn = full_reference_data(
      data_loader, cache_size, batch_size)
    # A first state is created eagerly to record the structure of a CacheState
    first_state = self._new_cache_state()
    self._unused_states = [first_state]
    self._cache_state_format = tree_util.tree_map(
      lambda leaf: jax.ShapeDtypeStruct(dtype=leaf.dtype, shape=leaf.shape),
      first_state)
    self._lock = threading.Lock()

  @property
  def get_map_fn(self) -> FullDataMapFunction:
    """Function mapping over the full dataset."""
    return self._map_fn

  @property
  def get_cache_state_format(self) -> PyTree:
    """Shapes and dtypes of a CacheState."""
    return self._cache_state_format

  def get_cache_state(self) -> CacheState:
    """Takes an unused cache state from the pool or creates a new one."""
    with self._lock:
      if len(self._unused_states) > 0:
        return self._unused_states.pop()
      return self._new_cache_state()

  def free_cache_state(self, cache_state: CacheState) -> Array:
    """Returns a cache state to the pool of unused states."""
    with self._lock:
      self._unused_states.append(cache_state)
    return jnp.array(1.0)

  def _new_cache_state(self) -> CacheState:
    """Creates a fresh cache state."""
    return self._init_fn()

  def cleanup(self):
    """Releases the data loader and drops the pooled states."""
    self._cleanup_fn()
    self._unused_states = None

  def __getstate__(self):
    # Locks cannot be pickled; a new one is created in __setstate__
    state = dict(self.__dict__)
    state.pop("_lock")
    return state

  def __setstate__(self, state):
    self.__dict__.update(state)
    self._lock = threading.Lock()
def full_data_mapper(data_loader: DataLoader = None,
                     cached_batches_count: int = 1,
                     mb_size: int = 1
                     ) -> Tuple[FullDataMapperFunction, Callable]:
  """Initializes a functional to map a function over a complete dataset.

  This function extends the functionality of :func:`full_reference_data` by
  loading the data states from the host before each mapping.

  Args:
    data_loader: Reads data from storage
    cached_batches_count: Number of batches in the cache. A larger number is
      faster, but requires more memory
    mb_size: Size of the data batch

  Returns:
    Returns a tuple of functions to map another function over a complete dataset
    of an appropriate :class:`DataLoader` and another function to release
    the data loader after the last computation.
  """
  _helper = _FullDataHelper(data_loader, cached_batches_count, mb_size)

  # Helper functions to load/save the CacheStates on the host via io_callback
  def _get_state() -> CacheState:
    cache_state = jxp.io_callback(
      lambda _: _helper.get_cache_state(),
      _helper.get_cache_state_format,
      jnp.array(1.0)
    )
    return cache_state

  def _free_cache_state(cache_state: CacheState, results: PyTree) -> Array:
    # Loop-through the results to hinder XLA to remove the callback
    jxp.io_callback(
      lambda cs: _helper.free_cache_state(cs),
      jnp.array(1.0),
      cache_state
    )
    return results

  def mapper_fn(fun: Union[MaskedMappedFunction, UnmaskedMappedFunction],
                carry: PyTree,
                masking: bool = False,
                information: bool = False,
                batched: bool = True,
                device_count: int = 1) -> PyTree:
    """Maps ``fun`` over the full dataset and returns ``(results, carry)``.

    Raises:
      ValueError: If ``masking`` is requested for a non-batched function.
    """
    data_state = _get_state()
    # Batch the function if it is not batched.
    if batched:
      batched_fun = fun
    else:
      if masking:
        raise ValueError("The function must be vectorized manually to allow "
                         "masking.")
      if mb_size == 1:
        # No vmapping required but first axis of all observations must be
        # removed.
        def batched_fun(batch, state):
          squeezed_batch = tree_util.tree_map(
            partial(jnp.squeeze, axis=0),
            batch)
          result, state = fun(squeezed_batch, state)
          expanded_result = tree_util.tree_map(
            partial(jnp.expand_dims, axis=0),
            result)
          return expanded_result, state
      else:
        # Only the first resulting state is passed to the next iteration
        vmapped_fun = jax.vmap(fun, in_axes=(0, None))

        def batched_fun(batch, state):
          results, states = vmapped_fun(batch, state)
          if states is None:
            state = None
          else:
            # Select index 0 of the leading (vmapped) axis of every leaf.
            # ``take_along_axis`` cannot do this: it requires an index array
            # with the same rank as the leaf and rejects a scalar index.
            state = tree_util.tree_map(lambda leaf: leaf[0], states)
          return results, state

    (new_data_state, results) = _helper.get_map_fn(
      batched_fun, data_state, carry,
      masking=masking, information=information, device_count=device_count)
    results = _free_cache_state(new_data_state, results)
    return results

  def release():
    _helper.cleanup()

  return mapper_fn, release