jax_sgmc.io
Save and checkpoint chains.
JaxSGMC supports saving and checkpointing inside jit-compiled functions. Saving works by combining a Data Collector with the host callback wrappers.
Data Collectors
Data Collector Interface
Collectors
- class jax_sgmc.io.MemoryCollector(save_dir=None)
Stores samples entirely in host RAM (as numpy arrays).
The host RAM is usually larger than the device memory and thus allows storing a greater number of samples.
- Parameters:
save_dir – Directory to which the results are written as numpy .npz files, one file per chain. If None, the results are only returned.
- finalize(chain_id)
Waits for all writing to be finished (scheduled via host_callback).
- Parameters:
chain_id (int) – ID of the chain which is finished
- finished(chain_id)
Returns samples after all data has been processed.
finalize is scheduled via host_callback, whereas finished is called in the normal Python flow. A barrier makes it possible to pause the program flow until all asynchronously saved data has been processed.
- Parameters:
chain_id – ID of chain requesting continuation of normal program flow.
- Returns:
Returns the collected samples in the original tree format, but with numpy arrays as leaves.
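The finalize/finished handshake can be illustrated with a small pure-Python sketch. It is not the jax_sgmc implementation. The class AsyncHostCollector and its methods are hypothetical, and a worker thread reading from a FIFO queue stands in for the asynchronous host_callback scheduling. Because finalize is queued behind all earlier saves, waiting on the per-chain threading.Event in finished acts as the barrier.

```python
import queue
import threading


class AsyncHostCollector:
    """Hypothetical sketch of asynchronous host-side collection with a barrier."""

    def __init__(self):
        self._queue = queue.Queue()  # FIFO: work is processed in submission order
        self._samples = {}           # chain_id -> list of collected samples
        self._finalized = {}         # chain_id -> threading.Event used as barrier
        worker = threading.Thread(target=self._process, daemon=True)
        worker.start()

    def register_chain(self, chain_id):
        self._samples[chain_id] = []
        self._finalized[chain_id] = threading.Event()

    def save(self, chain_id, sample):
        # Returns immediately; the worker stores the sample later.
        self._queue.put(("save", chain_id, sample))

    def finalize(self, chain_id):
        # Queued behind all earlier saves of this chain.
        self._queue.put(("finalize", chain_id, None))

    def finished(self, chain_id):
        # Barrier: block the normal program flow until finalize was processed.
        self._finalized[chain_id].wait()
        return list(self._samples[chain_id])

    def _process(self):
        while True:
            kind, chain_id, sample = self._queue.get()
            if kind == "save":
                self._samples[chain_id].append(sample)
            else:
                self._finalized[chain_id].set()


collector = AsyncHostCollector()
collector.register_chain(0)
for it in range(5):
    if it % 2 == 0:  # keep every second iteration
        collector.save(0, {"iteration": it})
collector.finalize(0)
samples = collector.finished(0)
print(samples)
```

In this sketch, saving never blocks the producer. Only finished waits, and only until the finalize marker for that chain has been processed.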
- class jax_sgmc.io.HDF5Collector(file)
Saves to HDF5 format.
This data collector supports serializing collected samples and checkpoints into the HDF5 file format. The samples are saved in a structure similar to the original pytree and can thus be inspected easily with an HDF5 viewer.
Note
This class requires that h5py is installed. Additional information can be found in the installation instructions.
>>> import tempfile
>>> tf = tempfile.TemporaryFile()
>>>
>>> import h5py
>>> from jax_sgmc import io
>>>
>>> file = h5py.File(tf, "w")  # Use a real file for real saving
>>> data_collector = io.HDF5Collector(file)
>>> saving = io.save(data_collector)
>>>
>>> # ... use the solver ...
>>>
>>> # Close the file
>>> file.close()
- Parameters:
file (File) – HDF5 file object
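The phrase "a structure similar to the original pytree" can be made concrete with a small sketch. It maps each leaf of a pytree-as-dict to a slash-separated path, which is how a hierarchy of HDF5 groups and datasets might mirror the tree. The path layout, the "/chain_0" prefix, and the function leaf_paths are assumptions for illustration only. This is not the layout HDF5Collector actually writes, and the sketch makes no h5py calls.

```python
def leaf_paths(tree_as_dict, prefix=""):
    """Hypothetical: map each leaf of a nested dict to a slash-separated path."""
    paths = {}
    for key in sorted(tree_as_dict):  # sorted, like jax.tree_util flattens dicts
        path = f"{prefix}/{key}"
        value = tree_as_dict[key]
        if isinstance(value, dict):
            paths.update(leaf_paths(value, path))  # dict -> nested group
        else:
            paths[path] = value                    # leaf -> dataset
    return paths


sample_as_dict = {"a": 0.0,
                  "b": {"b1": {"list_element_0": 0.0, "list_element_1": 0.1},
                        "b2": 0.0}}
paths = leaf_paths(sample_as_dict, prefix="/chain_0")
for path, value in paths.items():
    print(path, value)
```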
- finalize(chain_id)
Waits for all writing to be finished (scheduled via host_callback).
- Parameters:
chain_id (int) – ID of the chain which is finished
- finished(chain_id)
Returns after everything has been written to the file.
finalize is scheduled via host_callback, whereas finished is called in the normal Python flow. A barrier makes it possible to pause the program flow until all asynchronously saved data has been processed.
- Parameters:
chain_id (int) – ID of chain requesting continuation of normal program flow.
- register_chain(init_sample=None, init_checkpoint=None, static_information=None)
Registers a chain to save samples from.
- Parameters:
init_sample (Optional[Any]) – Pytree determining the tree structure and leaves of the samples to be saved.
init_checkpoint (Optional[Any]) – Pytree determining the tree structure and leaves of the states to be checkpointed.
static_information (Optional[Any]) – Information about the total number of samples collected.
- Return type:
Saving Strategies
- jax_sgmc.io.save(data_collector=None, checkpoint_every=0)
Initializes asynchronous saving of samples and checkpoints.
Accepted samples are sent to the host and processed there. This drastically reduces device memory usage and also allows inspecting the data while the simulation is running.
The returned statistics and samples depend on the data collector. For example, the HDF5 collector can handle more samples than fit into (device) memory and therefore returns no samples, whereas the memory collector returns the collected samples as numpy arrays.
Example usage:
>>> import jax.numpy as jnp
>>> from jax.lax import scan
>>> from jax_sgmc import io, scheduler
>>>
>>> dc = io.MemoryCollector()
>>> init_save, save, postprocess_save = io.save(dc)
>>>
>>> def update(saving_state, it):
...     saving_state = save(saving_state, jnp.mod(it, 2) == 0, {'iteration': it})
...     return saving_state
>>>
>>> # The information about the number of collected samples must be defined
>>> # before the run
>>> static_information = scheduler.static_information(samples_collected=3)
>>>
>>> # The saving function must know the form of the sample which should be saved
>>> saving_state = init_save({'iteration': jnp.array(0)}, {}, static_information)
>>> final_state, _ = scan(update, saving_state, jnp.arange(5))
>>>
>>> saved_samples = postprocess_save(final_state, None)
>>> print(saved_samples)
{'sample_count': Array(3, dtype=int32, weak_type=True), 'samples': {'iteration': array([0, 2, 4], dtype=int32)}}
- Parameters:
data_collector (Optional[DataCollector]) – Stateful object for data storage and serialization
checkpoint_every (int) – Create a checkpoint for later resuming every n iterations
Warning
Checkpointing is currently not supported.
- jax_sgmc.io.no_save()
Does not save the data on the host but returns it instead.
If the samples are small, they can be collected on the device. Note that this requires copying the samples repeatedly.
Save every second element of a 5-element scan:
>>> import jax.numpy as jnp
>>> from jax.lax import scan
>>> from jax_sgmc import io, scheduler
>>>
>>> init_save, save, postprocess_save = io.no_save()
>>>
>>> def update(saving_state, it):
...     saving_state = save(saving_state, jnp.mod(it, 2) == 0, {'iteration': it})
...     return saving_state
>>>
>>> # The information about the number of collected samples must be defined
>>> # before the run
>>> static_information = scheduler.static_information(samples_collected=3)
>>>
>>> # The saving function must know the form of the sample which should be saved
>>> saving_state = init_save({'iteration': jnp.array(0)}, None, static_information)
>>> final_state, _ = scan(update, saving_state, jnp.arange(5))
>>>
>>> saved_samples = postprocess_save(final_state, None)
>>> print(saved_samples)
{'sample_count': Array(3, dtype=int32, weak_type=True), 'samples': {'iteration': Array([0, 2, 4], dtype=int32)}}
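The trade-off of on-device collection can be sketched in plain Python. Storage for samples_collected samples is preallocated up front, because in JAX array shapes must be static. Each accepted sample then produces a new copy of the buffer rather than an in-place write. The names init_device_buffer and maybe_save are hypothetical and do not belong to jax_sgmc. A JAX version could express the same functional update with buffer.at[idx].set(value).

```python
def init_device_buffer(init_sample, samples_collected):
    # Hypothetical: preallocate one slot per expected sample, since the
    # buffer size has to be known before the run.
    return {
        "sample_count": 0,
        "samples": {key: [value] * samples_collected
                    for key, value in init_sample.items()},
    }


def maybe_save(state, keep, sample):
    # Skipped iterations leave the state unchanged.
    if not keep:
        return state
    idx = state["sample_count"]
    samples = {}
    for key, buffer in state["samples"].items():
        if idx >= len(buffer):
            raise ValueError("more samples than preallocated slots")
        # Functional update: copy the buffer instead of mutating it, which is
        # the repeated copying that on-device collection requires.
        new_buffer = list(buffer)
        new_buffer[idx] = sample[key]
        samples[key] = new_buffer
    return {"sample_count": idx + 1, "samples": samples}


state = init_device_buffer({"iteration": 0}, samples_collected=3)
for it in range(5):
    state = maybe_save(state, it % 2 == 0, {"iteration": it})
print(state)
```

The capacity is three, and every second iteration of five is accepted. The buffer therefore ends up holding 0, 2, and 4, matching the doctest above.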
Pytree to Dict Transformation
- jax_sgmc.io.pytree_to_dict(tree)
Constructs a dictionary from a pytree.
Transforms each node of a pytree to a dict by applying the defined dictionize rules. New rules can be specified via the jax_sgmc.io.register_dictionize_rule() decorator.
>>> from jax_sgmc import io
>>> some_tree = {'a': 0.0, 'b': {'b1': [0.0, 0.1], 'b2': 0.0}}
>>> as_dict = io.pytree_to_dict(some_tree)
>>> print(as_dict)
{'a': 0.0, 'b': {'b1': {'list_element_0': 0.0, 'list_element_1': 0.1}, 'b2': 0.0}}
- Parameters:
tree (Any) – All nodes of the tree must either have no children or a registered transformation to dict.
- Returns:
Returns the tree as a dict with similar structure.
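A minimal recursive sketch of the default rules follows. Based on the example output, it assumes that dicts keep their keys and that list entries are keyed as list_element_i. Treating tuples, including namedtuples, like lists is an assumption of this sketch; the registered rules in jax_sgmc may handle them differently. The function name pytree_to_dict_sketch is hypothetical.

```python
def pytree_to_dict_sketch(tree):
    """Minimal sketch of the default dictionize rules (illustrative only)."""
    if isinstance(tree, dict):
        # Dicts keep their keys; children are converted recursively.
        return {key: pytree_to_dict_sketch(value) for key, value in tree.items()}
    if isinstance(tree, (list, tuple)):
        # Assumption: lists and tuples (incl. namedtuples) get positional keys.
        return {f"list_element_{i}": pytree_to_dict_sketch(value)
                for i, value in enumerate(tree)}
    # Anything else has no children and is kept as a leaf.
    return tree


some_tree = {'a': 0.0, 'b': {'b1': [0.0, 0.1], 'b2': 0.0}}
as_dict = pytree_to_dict_sketch(some_tree)
print(as_dict)
```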
- jax_sgmc.io.dict_to_pytree(pytree_as_dict, target)
Restores the original tree structure given by the target from a dict.
Restores the pytree-as-dict to its original tree structure. This function can also operate on subtrees, as long as the subtree (a dict) of the pytree-as-dict matches the corresponding subtree of the target.
>>> from jax_sgmc import io
>>> some_tree = {'a': 0.0, 'b': {'b1': [0.0, 0.1], 'b2': 0.0}}
>>> as_dict = io.pytree_to_dict(some_tree)
>>> sub_pytree = io.dict_to_pytree(as_dict['b'], some_tree['b'])
>>> print(sub_pytree)
{'b1': [0.0, 0.1], 'b2': 0.0}
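The inverse direction can be sketched the same way. The target supplies only the structure, and the leaves are taken from the dict. This is an illustrative reimplementation (dict_to_pytree_sketch is a hypothetical name), not jax_sgmc's code. It assumes the list_element_i keys from the pytree_to_dict example and rebuilds lists, tuples, and namedtuples from them. Because the recursion follows the target, passing a matching subtree works just as in the doctest.

```python
def dict_to_pytree_sketch(tree_as_dict, target):
    """Rebuild the structure of `target`, taking the leaves from `tree_as_dict`."""
    if isinstance(target, dict):
        return {key: dict_to_pytree_sketch(tree_as_dict[key], value)
                for key, value in target.items()}
    if isinstance(target, (list, tuple)):
        children = [dict_to_pytree_sketch(tree_as_dict[f"list_element_{i}"], value)
                    for i, value in enumerate(target)]
        if hasattr(target, "_fields"):  # namedtuple: positional constructor
            return type(target)(*children)
        return type(target)(children)   # list or tuple
    # Leaf: the stored value replaces the target's placeholder value.
    return tree_as_dict


as_dict = {'a': 0.0, 'b': {'b1': {'list_element_0': 0.0, 'list_element_1': 0.1}, 'b2': 0.0}}
target_b = {'b1': [0.0, 0.0], 'b2': 0.0}  # only the structure of the target matters
sub_pytree = dict_to_pytree_sketch(as_dict['b'], target_b)
print(sub_pytree)
```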
- jax_sgmc.io.pytree_dict_keys(tree)
Returns a list of keys to access the leaves of the tree.
- Parameters:
tree (Any) – Pytree as a dict
- Returns:
Returns a list of tuples, where each tuple contains the keys to access the leaf of the flattened pytree in the unflattened dict.
For example:
>>> from jax.tree_util import tree_leaves
>>> from jax_sgmc import io
>>> pytree = {"a": [0.0, 1.0], "b": 2.0}
>>> pytree_as_dict = io.pytree_to_dict(pytree)
>>> pytree_leaves = tree_leaves(pytree_as_dict)
>>> pytree_keys = io.pytree_dict_keys(pytree_as_dict)
>>> print(pytree_leaves)
[0.0, 1.0, 2.0]
>>> print(pytree_keys)
[('a', 'list_element_0'), ('a', 'list_element_1'), ('b',)]
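The key tuples link each flattened leaf back to its position in the nested dict. The following is a sketch of the idea. It assumes only nested dicts, as produced by pytree_to_dict, and relies on jax.tree_util flattening dicts in sorted-key order, so the i-th key tuple addresses the i-th leaf. The functions dict_key_paths and get_leaf are illustrative stand-ins, not library functions.

```python
def dict_key_paths(tree_as_dict):
    """Key tuples leading to each leaf, in sorted-key (flattening) order."""
    paths = []
    for key in sorted(tree_as_dict):
        value = tree_as_dict[key]
        if isinstance(value, dict):
            paths.extend((key,) + sub_path for sub_path in dict_key_paths(value))
        else:
            paths.append((key,))
    return paths


def get_leaf(tree_as_dict, path):
    """Follow one key tuple down to its leaf."""
    for key in path:
        tree_as_dict = tree_as_dict[key]
    return tree_as_dict


pytree_as_dict = {"a": {"list_element_0": 0.0, "list_element_1": 1.0}, "b": 2.0}
keys = dict_key_paths(pytree_as_dict)
leaves = [get_leaf(pytree_as_dict, path) for path in keys]
print(keys)
print(leaves)
```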
- jax_sgmc.io.register_dictionize_rule(type)
Decorator to define new rules transforming a pytree node to a dict.
By default, transformations are defined for the following types:
- list
- dict
- (named)tuple
Additionally, transformations for the following optional libraries are implemented:
- haiku._src.data_structures.FlatMapping
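A registry keyed by node type is one common way to implement such a decorator. The sketch below is hypothetical: register_rule, _RULES, and to_dict are not jax_sgmc names, and the library's rule signature may differ. Each rule turns a node into a dict of its children, and nodes without a rule are treated as leaves. The lookup uses exact types, so subclasses such as namedtuples would need their own rule.

```python
_RULES = {}  # hypothetical registry: node type -> rule


def register_rule(node_type):
    """Hypothetical decorator registering a node-to-dict rule for node_type."""
    def decorator(rule):
        _RULES[node_type] = rule
        return rule
    return decorator


@register_rule(dict)
def _dict_rule(node):
    return dict(node)


@register_rule(list)
def _list_rule(node):
    return {f"list_element_{i}": child for i, child in enumerate(node)}


def to_dict(tree):
    rule = _RULES.get(type(tree))  # exact type lookup, no subclass matching
    if rule is None:
        return tree  # no registered rule: treat as leaf
    return {key: to_dict(child) for key, child in rule(tree).items()}


class Pair:
    """Example custom node type (hypothetical)."""

    def __init__(self, first, second):
        self.first = first
        self.second = second


@register_rule(Pair)
def _pair_rule(node):
    return {"first": node.first, "second": node.second}


result = to_dict({"p": Pair(1.0, [2.0, 3.0])})
print(result)
```

Registering Pair is enough for its children to be converted recursively. The list stored in second still goes through the list rule.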