jax_sgmc.io

Save and checkpoint chains.

JaxSGMC supports saving and checkpointing inside jit-compiled functions. Saving works by combining a Data Collector with the host callback wrappers.

Data Collectors

Data Collector Interface

class jax_sgmc.io.DataCollector

Collects sampled data and data loader states.

abstract checkpoint(chain_id, state)

Called every nth step.

abstract finalize(chain_id)

Called after the solver has finished.

abstract finished(chain_id)

Called in the main thread after the JAX threads have been released.

abstract register_chain(init_sample=None, init_checkpoint=None, static_information=None)

Registers a chain to save samples from.

Parameters:
  • init_sample (Optional[Any]) – Determines the shape, dtype and tree structure of the sample.

  • init_checkpoint (Optional[Any]) – Determines the shape, dtype and tree structure of the solver state to enable checkpointing.

  • static_information (Optional[Any]) – Information about the total number of collected samples.

Return type:

int

Returns:

Returns the ID of the new chain.

abstract register_data_loader(data_loader)

Registers a data loader whose state is saved.

abstract resume()

Called to restore the data loader state and return the sample states.

abstract save(chain_id, values)

Called with collected samples.
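To make the order of these calls concrete, the following sketch implements the same method names on a plain Python class. ListCollector is hypothetical: it does not subclass jax_sgmc.io.DataCollector, ignores checkpoints and data loaders, and assumes that save receives the list of tree leaves of one sample, as described for the concrete collectors below.

```python
from typing import Any, Dict, List, Optional


class ListCollector:
    """Hypothetical collector mirroring the DataCollector method names (no JAX involved)."""

    def __init__(self) -> None:
        self._leaves: Dict[int, List[List[Any]]] = {}
        self._done: Dict[int, bool] = {}

    def register_data_loader(self, data_loader: Any) -> None:
        pass  # data loader states are not saved in this sketch

    def register_chain(self, init_sample: Optional[List[Any]] = None,
                       init_checkpoint: Any = None,
                       static_information: Any = None) -> int:
        # Assumption: init_sample is already a flat list of leaves.
        chain_id = len(self._leaves)
        num_leaves = len(init_sample) if init_sample is not None else 0
        self._leaves[chain_id] = [[] for _ in range(num_leaves)]
        self._done[chain_id] = False
        return chain_id

    def save(self, chain_id: int, values: List[Any]) -> None:
        # Append each leaf of the accepted sample to its own list.
        for storage, leaf in zip(self._leaves[chain_id], values):
            storage.append(leaf)

    def checkpoint(self, chain_id: int, state: Any) -> None:
        pass  # checkpointing is not implemented in this sketch

    def finalize(self, chain_id: int) -> None:
        # Nothing is written asynchronously here, so finalizing only marks the chain.
        self._done[chain_id] = True

    def finished(self, chain_id: int) -> List[List[Any]]:
        # Hand out the collected leaves only after the chain was finalized.
        if not self._done[chain_id]:
            raise RuntimeError(f"chain {chain_id} has not been finalized")
        return self._leaves[chain_id]

    def resume(self) -> None:
        raise NotImplementedError("resuming is not supported in this sketch")


collector = ListCollector()
cid = collector.register_chain(init_sample=[0, 0.0])
for it in range(5):
    if it % 2 == 0:  # keep every second iteration
        collector.save(cid, [it, it * 0.5])
collector.finalize(cid)
print(collector.finished(cid))  # [[0, 2, 4], [0.0, 1.0, 2.0]]
```

A real collector receives save and finalize through host callbacks, so it must also cope with these calls arriving asynchronously.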

Collectors

class jax_sgmc.io.MemoryCollector(save_dir=None)

Stores samples entirely in RAM (numpy arrays).

The host RAM is usually larger than the device memory and thus allows storing a greater number of samples.

Parameters:

save_dir – Directory to which the results are written as NumPy .npz files, one file per chain. If None, the results are only returned.

checkpoint(chain_id, state)

Called every nth step.

finalize(chain_id)

Waits for all writing to be finished (scheduled via host_callback).

Parameters:

chain_id (int) – ID of the chain which is finished

finished(chain_id)

Returns samples after all data has been processed.

finalize is scheduled via host_callback, whereas finished is called in the normal Python program flow. A barrier pauses the program flow until all asynchronously saved data has been processed.

Parameters:

chain_id – ID of chain requesting continuation of normal program flow.

Returns:

Returns the collected samples in the original tree format, but with NumPy arrays as leaves.
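The interplay of finalize and finished can be illustrated with the standard library alone. In this sketch a worker thread stands in for the host callbacks: save only enqueues data, finalize enqueues a sentinel, and finished blocks on a threading.Barrier until the worker has drained the queue. The class name AsyncWriter and the queue-based design are assumptions for illustration, not the MemoryCollector implementation.

```python
import queue
import threading
from typing import Any, List

_STOP = object()  # sentinel marking the end of a chain


class AsyncWriter:
    """Hypothetical writer whose samples are processed on a background thread."""

    def __init__(self) -> None:
        self.samples: List[Any] = []
        self._queue: "queue.Queue[Any]" = queue.Queue()
        # Two parties meet at the barrier: the worker and the caller of finished().
        self._barrier = threading.Barrier(2)
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self) -> None:
        while True:
            item = self._queue.get()
            if item is _STOP:
                # Every item queued before the sentinel has been processed.
                self._barrier.wait()
                return
            self.samples.append(item)

    def save(self, value: Any) -> None:
        # Stands in for data arriving asynchronously via a host callback.
        self._queue.put(value)

    def finalize(self) -> None:
        # Scheduled after the last save; does not block the caller.
        self._queue.put(_STOP)

    def finished(self) -> List[Any]:
        # Called in the normal program flow; blocks until the worker is done.
        self._barrier.wait()
        return self.samples


writer = AsyncWriter()
for it in range(5):
    writer.save(it)
writer.finalize()
result = writer.finished()
print(result)  # [0, 1, 2, 3, 4]
```

Because the worker meets the barrier only after reaching the sentinel, finished cannot return while earlier samples are still waiting in the queue.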

register_chain(init_sample=None, static_information=None, **unused_kwargs)

Registers a chain to save samples from.

Parameters:
  • init_sample (Optional[Any]) – Pytree determining the tree structure and leaves of the samples to be saved.

  • static_information – Information about the total number of samples collected.

Return type:

int

register_data_loader(data_loader)

Registers data loader to save the state.

resume()

Called to restore data loader state and return sample states.

save(chain_id, values)

Saves new leaves to the dataset.

Parameters:
  • chain_id (int) – ID from register_chain

  • values – Tree leaves of sample

class jax_sgmc.io.HDF5Collector(file)

Saves to the HDF5 format.

This data collector supports serializing collected samples and checkpoints into the HDF5 file format. The samples are saved in a structure similar to the original pytree and can thus be inspected easily with an HDF5 viewer.

Note

This class requires that h5py is installed. Additional information can be found in the installation instructions.

>>> import tempfile
>>> tf = tempfile.TemporaryFile()
>>>
>>> import h5py
>>> from jax_sgmc import io
>>>
>>> file = h5py.File(tf, "w") # Use a real file for real saving
>>> data_collector = io.HDF5Collector(file)
>>> saving = io.save(data_collector)
>>>
>>> # ... use the solver ...
>>>
>>> # Close the file
>>> file.close()

Parameters:

file (File) – HDF5 file object (h5py.File)
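One way to picture a structure similar to the original pytree is to turn every leaf path into a slash-separated HDF5 path, each naming one dataset. The helper hdf5_paths below is hypothetical and uses only the standard library; the exact group names written by HDF5Collector are an assumption here, and the list_element_<i> naming is borrowed from pytree_to_dict.

```python
from typing import Any, List, Tuple


def hdf5_paths(tree: Any, prefix: str = "") -> List[Tuple[str, Any]]:
    """Hypothetical helper: list (path, leaf) pairs for nested dicts and lists."""
    if isinstance(tree, dict):
        children = [(str(key), tree[key]) for key in sorted(tree)]
    elif isinstance(tree, list):
        children = [(f"list_element_{i}", value) for i, value in enumerate(tree)]
    else:
        # A leaf: in an HDF5 file this would become a dataset at `prefix`.
        return [(prefix or "/", tree)]
    paths: List[Tuple[str, Any]] = []
    for name, child in children:
        paths.extend(hdf5_paths(child, f"{prefix}/{name}"))
    return paths


sample = {"params": {"w": [0.5, -1.0], "b": 0.1}, "log_prob": -3.2}
for path, leaf in hdf5_paths(sample):
    print(path, leaf)
# /log_prob -3.2
# /params/b 0.1
# /params/w/list_element_0 0.5
# /params/w/list_element_1 -1.0
```

In a real collector each dataset would additionally get a leading axis for the sample index, so that successive samples of the same leaf are stored together.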

checkpoint(chain_id, state)

Called every nth step.

finalize(chain_id)

Waits for all writing to be finished (scheduled via host_callback).

Parameters:

chain_id (int) – ID of the chain which is finished

finished(chain_id)

Returns after everything has been written to the file.

finalize is scheduled via host_callback, whereas finished is called in the normal Python program flow. A barrier pauses the program flow until all asynchronously saved data has been processed.

Parameters:

chain_id (int) – ID of chain requesting continuation of normal program flow.

register_chain(init_sample=None, init_checkpoint=None, static_information=None)

Registers a chain to save samples from.

Parameters:
  • init_sample (Optional[Any]) – Pytree determining the tree structure and leaves of the samples to be saved.

  • init_checkpoint (Optional[Any]) – Pytree determining the tree structure and leaves of the states to be checkpointed.

  • static_information (Optional[Any]) – Information about the total number of samples collected.

Return type:

int

register_data_loader(data_loader)

Registers data loader to save the state.

resume()

Called to restore data loader state and return sample states.

save(chain_id, values)

Saves new leaves to the dataset.

Parameters:
  • chain_id (int) – ID from register_chain

  • values – Tree leaves of sample

Saving Strategies

jax_sgmc.io.save(data_collector=None, checkpoint_every=0)

Initializes asynchronous saving of samples and checkpoints.

Accepted samples are sent to the host and processed there. This drastically reduces the device memory usage and also allows gaining insight into the data while the simulation is running.

The returned statistics and samples depend on the Data Collector. For example, the HDF5 collector can store more samples than fit into (device) memory and therefore returns no samples, whereas the memory collector returns the collected samples as NumPy arrays.

Example usage:

>>> import jax.numpy as jnp
>>> from jax.lax import scan
>>> from jax_sgmc import io, scheduler
>>>
>>> dc = io.MemoryCollector()
>>> init_save, save, postprocess_save = io.save(dc)
>>>
>>> def update(saving_state, it):
...   saving_state = save(saving_state, jnp.mod(it, 2) == 0, {'iteration': it})
...   return saving_state
>>>
>>> # The information about the number of collected samples must be defined
>>> # before the run
>>> static_information = scheduler.static_information(samples_collected=3)
>>>
>>> # The saving function must know the form of the sample to be saved
>>> saving_state = init_save({'iteration': jnp.array(0)}, {}, static_information)
>>> final_state, _ = scan(update, saving_state, jnp.arange(5))
>>>
>>> saved_samples = postprocess_save(final_state, None)
>>> print(saved_samples)
{'sample_count': Array(3, dtype=int32, weak_type=True), 'samples': {'iteration': array([0, 2, 4], dtype=int32)}}

Parameters:
  • data_collector (Optional[DataCollector]) – Stateful object for data storage and serialization

  • checkpoint_every (int) – Creates a checkpoint for later resuming every n iterations.

Warning

Checkpointing is currently not supported.

Return type:

Tuple[Callable[[Any, Any, Any], saving_state], Callable[[saving_state, bool_, Any, Any, Any], Union[Any, NoReturn]], Callable[[Any], Union[Any, NoReturn]]]

Returns:

Returns a saving strategy, i.e. the tuple of functions (init_save, save, postprocess_save).

jax_sgmc.io.no_save()

Does not save the data on the host but returns it instead.

If the samples are small, they can be collected on the device. Note, however, that the samples must be copied repeatedly.

Keep every second element of a 5-element scan:

>>> import jax.numpy as jnp
>>> from jax.lax import scan
>>> from jax_sgmc import io, scheduler
>>>
>>> init_save, save, postprocess_save = io.no_save()
>>>
>>> def update(saving_state, it):
...   saving_state = save(saving_state, jnp.mod(it, 2) == 0, {'iteration': it})
...   return saving_state
>>>
>>> # The information about the number of collected samples must be defined
>>> # before the run
>>> static_information = scheduler.static_information(samples_collected=3)
>>>
>>> # The saving function must know the form of the sample to be saved
>>> saving_state = init_save({'iteration': jnp.array(0)}, None, static_information)
>>> final_state, _ = scan(update, saving_state, jnp.arange(5))
>>>
>>> saved_samples = postprocess_save(final_state, None)
>>> print(saved_samples)
{'sample_count': Array(3, dtype=int32, weak_type=True), 'samples': {'iteration': Array([0, 2, 4], dtype=int32)}}

Return type:

Tuple[Callable[[Any, Any, Any], saving_state], Callable[[saving_state, bool_, Any, Any, Any], Union[Any, NoReturn]], Callable[[Any], Union[Any, NoReturn]]]

Returns:

Returns a saving strategy, which keeps the samples entirely in the device’s memory.
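One way to read the remark about repeated copying: in a functional setting the sample buffer is part of the saving state, and every accepted sample produces a new state. The stdlib sketch below mimics this with immutable tuples. The state layout (buffer, count) and the names init_buffer and save_if are hypothetical; the buffer size plays the role of samples_collected from static_information, and the loop reproduces the every-second-element selection of the example above.

```python
from typing import Any, Tuple

State = Tuple[Tuple[Any, ...], int]  # (preallocated buffer, number of saved samples)


def init_buffer(samples_collected: int, fill: Any = 0) -> State:
    # The buffer size must be known before the run, like samples_collected.
    return (fill,) * samples_collected, 0


def save_if(state: State, keep: bool, sample: Any) -> State:
    buffer, count = state
    if not keep:
        return state
    # Immutable update: a new buffer is built to change a single entry.
    new_buffer = buffer[:count] + (sample,) + buffer[count + 1:]
    return new_buffer, count + 1


state = init_buffer(samples_collected=3)
for it in range(5):
    state = save_if(state, it % 2 == 0, it)
print(state)  # ((0, 2, 4), 3)
```

Under jit, keep is a traced boolean, so a JAX version would need jax.lax.cond or jnp.where instead of the Python if statement.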

Pytree to Dict Transformation

jax_sgmc.io.pytree_to_dict(tree)

Constructs a dictionary from a pytree.

Transforms each node of a pytree to a dict by applying the defined dictionize rules. New rules can be specified via the jax_sgmc.io.register_dictionize_rule() decorator.

>>> from jax_sgmc import io
>>> some_tree = {'a': 0.0, 'b': {'b1': [0.0, 0.1], 'b2': 0.0}}
>>> as_dict = io.pytree_to_dict(some_tree)
>>> print(as_dict)
{'a': 0.0, 'b': {'b1': {'list_element_0': 0.0, 'list_element_1': 0.1}, 'b2': 0.0}}

Parameters:

tree (Any) – All nodes of the tree must either have no children or a registered transformation to dict.

Returns:

Returns the tree as a dict with similar structure.
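The recursion behind this transformation can be sketched in plain Python for the two container types used in the example. to_dict_sketch is a hypothetical stand-in that covers only dict and list; the library's handling of tuples, named tuples and registered rules is not reproduced.

```python
from typing import Any


def to_dict_sketch(tree: Any) -> Any:
    """Hypothetical re-implementation for dict and list nodes only."""
    if isinstance(tree, dict):
        return {key: to_dict_sketch(value) for key, value in tree.items()}
    if isinstance(tree, list):
        # Lists become dicts keyed by the element position.
        return {f"list_element_{i}": to_dict_sketch(value)
                for i, value in enumerate(tree)}
    # Anything else is treated as a leaf and returned unchanged.
    return tree


some_tree = {'a': 0.0, 'b': {'b1': [0.0, 0.1], 'b2': 0.0}}
print(to_dict_sketch(some_tree))
# {'a': 0.0, 'b': {'b1': {'list_element_0': 0.0, 'list_element_1': 0.1}, 'b2': 0.0}}
```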

jax_sgmc.io.dict_to_pytree(pytree_as_dict, target)

Restores the original tree structure given by the target from a dict.

Converts a pytree that has been transformed to a dict back to its original tree structure. The function also operates on subtrees, as long as the subtree (a dict) of the dict-transformed pytree matches the corresponding subtree of the target.

>>> from jax_sgmc import io
>>> some_tree = {'a': 0.0, 'b': {'b1': [0.0, 0.1], 'b2': 0.0}}
>>> as_dict = io.pytree_to_dict(some_tree)
>>> sub_pytree = io.dict_to_pytree(as_dict['b'], some_tree['b'])
>>> print(sub_pytree)
{'b1': [0.0, 0.1], 'b2': 0.0}

Parameters:
  • pytree_as_dict (dict) – A pytree which has been transformed to a dict of dicts.

  • target (Any) – A pytree defining the original tree structure.
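The inverse direction can be sketched the same way: the target only provides the structure, while the leaf values are taken from the dict. from_dict_sketch is hypothetical and, like the forward sketch, handles only dict and list nodes. Because it recurses on the target, it works on any subtree whose dict form matches the target.

```python
from typing import Any


def from_dict_sketch(as_dict: Any, target: Any) -> Any:
    """Hypothetical inverse for dict and list nodes; `target` fixes the structure."""
    if isinstance(target, dict):
        return {key: from_dict_sketch(as_dict[key], value)
                for key, value in target.items()}
    if isinstance(target, list):
        # Read the elements back in their original order.
        return [from_dict_sketch(as_dict[f"list_element_{i}"], value)
                for i, value in enumerate(target)]
    # Leaf: take the stored value; the target leaf itself is ignored.
    return as_dict


target = {'b1': [0.0, 0.0], 'b2': 0.0}
stored = {'b1': {'list_element_0': 1.0, 'list_element_1': 2.0}, 'b2': 3.0}
print(from_dict_sketch(stored, target))
# {'b1': [1.0, 2.0], 'b2': 3.0}
```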

jax_sgmc.io.pytree_dict_keys(tree)

Returns a list of keys to access the leaves of the tree.

Parameters:

tree (Any) – Pytree as a dict

Returns:

Returns a list of tuples, where each tuple contains the keys to access the leaf of the flattened pytree in the unflattened dict.

For example:

>>> from jax.tree_util import tree_leaves
>>> from jax_sgmc import io
>>> pytree = {"a": [0.0, 1.0], "b": 2.0}
>>> pytree_as_dict = io.pytree_to_dict(pytree)
>>> pytree_leaves = tree_leaves(pytree_as_dict)
>>> pytree_keys = io.pytree_dict_keys(pytree_as_dict)
>>> print(pytree_leaves)
[0.0, 1.0, 2.0]
>>> print(pytree_keys)
[('a', 'list_element_0'), ('a', 'list_element_1'), ('b',)]
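The key tuples can be produced by a depth-first traversal that visits dict keys in sorted order, which is the order in which JAX flattens dicts. dict_keys_sketch is a hypothetical helper, not the library function.

```python
from typing import Any, List, Tuple


def dict_keys_sketch(tree: Any, prefix: Tuple[str, ...] = ()) -> List[Tuple[str, ...]]:
    """Hypothetical: key paths to every leaf of a dict of dicts."""
    if not isinstance(tree, dict):
        return [prefix]
    keys: List[Tuple[str, ...]] = []
    # Sorted keys follow the order in which JAX flattens dicts.
    for key in sorted(tree):
        keys.extend(dict_keys_sketch(tree[key], prefix + (key,)))
    return keys


pytree_as_dict = {'a': {'list_element_0': 0.0, 'list_element_1': 1.0}, 'b': 2.0}
print(dict_keys_sketch(pytree_as_dict))
# [('a', 'list_element_0'), ('a', 'list_element_1'), ('b',)]
```

In this sketch the keys are sorted as strings, so list_element_10 would precede list_element_2. This matches how tree_leaves orders the leaves of the dict, not the order of the original list.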

jax_sgmc.io.register_dictionize_rule(type)

Decorator to define new rules transforming a pytree node to a dict.

By default, transformations are defined for the following types:

  • list

  • dict

  • (named)tuple

Additionally, transformations for the following optional libraries are implemented:

  • haiku._src.data_structures.FlatMapping

Parameters:

type (Type) – Type (or class) of the node for which no rule has been defined yet

Return type:

Callable[[Callable], None]

Returns:

The decorated function is not intended to be used directly.