# Source code for jax_sgmc.data.hdf5_loader

# Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Use samples saved with :class:`jax_sgmc.io.HDF5Collector` as reference data. """

import itertools
from typing import Any

import h5py
import numpy as onp
import jax.numpy as jnp
import jax

from jax_sgmc.data.numpy_loader import NumpyDataLoader
from jax_sgmc.io import pytree_dict_keys

PyTree = Any

# Inherit from NumpyDataLoader because slicing of arrays is similar
class HDF5Loader(NumpyDataLoader):
    """Load reference data from HDF5-files.

    This data loader can load reference data stored in HDF5 files. This makes
    it possible to use the :mod:`jax_sgmc.data` module to evaluate samples
    saved via the :class:`jax_sgmc.io.HDF5Collector`.

    Args:
      file: Path to the HDF5 file containing the reference data. An already
        opened :class:`h5py.File` is also accepted and is used as-is.
      subdir: Path to the subset of the data set which should be loaded
      sample: PyTree to specify the original shape of the sub-pytree before it
        has been saved by the :class:`jax_sgmc.io.HDF5Collector`.
        NOTE(review): the default ``None`` presumably yields no leaves, in
        which case reading the observation count below fails -- confirm
        against :func:`jax_sgmc.io.pytree_dict_keys`.

    """

    def __init__(self, file, subdir="/chain~0/variables/", sample=None):
        # The sample is necessary to return the observations in the correct
        # format.
        super().__init__()

        if isinstance(file, h5py.File):
            self._dataset = file
        else:
            self._dataset = h5py.File(name=file, mode="r")

        # One HDF5 path per leaf of the sample, built from the sub-directory
        # and the keys leading to the leaf.
        self._reference_data = [
            "/".join(itertools.chain([subdir], key_tuple))
            for key_tuple in pytree_dict_keys(sample)]

        # Use the jax.tree_util functions: the top-level aliases
        # (jax.tree_structure, jax.tree_map, jax.tree_unflatten) are
        # deprecated and no longer available in recent JAX releases.
        self._pytree_structure = jax.tree_util.tree_structure(sample)
        self._sample_format = jax.tree_util.tree_map(
            lambda leaf: jax.ShapeDtypeStruct(
                shape=leaf.shape, dtype=leaf.dtype),
            sample)

        # Every leaf is looked up in the file (a missing leaf fails here), but
        # only the length of the first leaf is taken as the observation count.
        observations_counts = [
            len(self._dataset[leaf_name])
            for leaf_name in self._reference_data]
        self._observation_count = observations_counts[0]

    def get_batches(self, chain_id: int) -> PyTree:
        """Draws a batch from a chain.

        Args:
          chain_id: ID of the chain, which holds the information about the
            form of the batch and the process of assembling.

        Returns:
          Returns a tuple consisting of a superbatch as registered by
          :func:`register_random_pipeline` or
          :func:`register_ordered_pipeline` with `cache_size` batches holding
          `mb_size` observations and a boolean mask built from the selection
          mask of the chain.

        """
        # Data slicing is the same for all methods of random and ordered
        # access, only the indices for slicing differ. The method _get_indices
        # finds the correct method for the chain.
        selections_idx, selections_mask = self._get_indices(chain_id)

        # Read each distinct index only once per batch: onp.unique returns the
        # sorted unique indices together with the inverse mapping which
        # restores the requested order (and duplicates) after loading.
        # NOTE(review): h5py fancy indexing presumably requires increasing
        # indices without duplicates, which is why the detour is necessary.
        unique_and_inverse = [
            onp.unique(batch_idx, return_inverse=True)
            for batch_idx in selections_idx]

        # Slice the data and transform into pytree
        leaves = []
        for leaf_name in self._reference_data:
            # Resolve the dataset once per leaf instead of once per batch.
            leaf_dataset = self._dataset[leaf_name]
            batches = [
                jnp.array(leaf_dataset[unique_idx])[restore_idx]
                for unique_idx, restore_idx in unique_and_inverse]
            # Stack the batches of this leaf into a single array.
            leaves.append(jnp.array(batches))

        selected_observations = jax.tree_util.tree_unflatten(
            self._pytree_structure, leaves)
        return selected_observations, jnp.array(
            selections_mask, dtype=jnp.bool_)

    def save_state(self, chain_id: int) -> PyTree:
        """Not supported, always raises :class:`NotImplementedError`."""
        raise NotImplementedError(
            "Saving of the DataLoader state is not supported.")

    def load_state(self, chain_id: int, data) -> None:
        """Not supported, always raises :class:`NotImplementedError`."""
        raise NotImplementedError(
            "Loading of the DataLoader state is not supported.")

    @property
    def _format(self):
        """Returns shape and dtype of a single observation. """
        return self._sample_format

    @property
    def static_information(self):
        """Returns information about the total observation count. """
        information = {
            "observation_count": self._observation_count
        }
        return information

    def close(self):
        """Closes the underlying HDF5 file.

        The file is closed regardless of whether it has been opened by this
        loader or has been passed as an open :class:`h5py.File`.
        """
        self._dataset.close()