# Copyright 2021 Multiscale Modeling of Fluid Materials, TU Munich
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Save and checkpoint chains.
**JaxSGMC** supports saving and checkpointing inside jit-compiled
functions. Saving works by combining a Data Collector with the host callback
wrappers.
"""
# Todo: If dataloader supports checkpointing, it must be able to checkpoint
# itself.
import abc
import itertools
from pathlib import Path
from functools import partial
import threading
from typing import Any, Union, Tuple, Callable, Type, Dict, NoReturn
from collections import namedtuple
import numpy as onp
import jax.numpy as jnp
from jax import tree_util, lax
from jax.experimental import host_callback
try:
import ujson
except ModuleNotFoundError:
ujson = None
try:
import h5py
HDF5File = h5py.File
except ModuleNotFoundError:
HDF5File = None
h5py = None
# Import haiku to register a rule for transforming FlatMapping into a dict
try:
import haiku._src.data_structures as haiku_ds
except ModuleNotFoundError:
haiku_ds = None
from jax_sgmc import data
from jax_sgmc import scheduler
from jax_sgmc.util import stop_vmap
# Alias used in the annotations of this module for arbitrary pytrees.
PyTree = Any
# Global rules for translating a tree-node into a dict. Maps the type of a node
# to the rule which returns the (key, child) pairs of a node of this type.
_dictionize_rules: Dict[Type, Callable] = {}
def register_dictionize_rule(type: Type) -> Callable[[Callable], None]:
  """Decorator to define new rules transforming a pytree node to a dict.

  A rule receives a node and returns an iterable of ``(key, child)`` pairs.
  By default, transformations are defined for some default types:

  - list
  - dict
  - (named)tuple

  Additionally, transformation for the following optional libraries are
  implemented:

  - haiku._src.data_structures.FlatMapping

  Args:
    type: Type (or class) of the currently undefined node

  Returns:
    The decorated function is not intended to be used directly. The decorator
    only stores the rule and returns ``None``.

  """
  def _store_rule(rule: Callable):
    # Every node type may only be claimed by a single rule
    already_known = type in _dictionize_rules
    assert not already_known, (f"Rule for {type.__name__} is "
                               f"already defined.")
    _dictionize_rules[type] = rule
  return _store_rule
def _default_dictionize(node):
leaves = tree_util.tree_leaves(node)
leaf_names = [f"unknown~pytree~leaf~{idx}" for idx in range(len(leaves))]
return zip(leaf_names, leaves)
def _dictionize(node):
  """Apply the registered rules to a pytree node.

  Returns the (key, child) pairs produced by the first registered rule whose
  type matches the node. Nodes without a matching rule are handled by the
  default rule.
  """
  # If there are no leaves, then a dictionize rule is not required
  if not tree_util.tree_leaves(node):
    return []
  matching_rules = (rule
                    for known_type, rule in _dictionize_rules.items()
                    if isinstance(node, known_type))
  selected_rule = next(matching_rules, _default_dictionize)
  return selected_rule(node)
def pytree_to_dict(tree: PyTree):
  """Constructs a dictionary from a pytree.

  Transforms each node of a pytree to a dict by applying the defined dictionize
  rules. New rules can be specified by the
  :func:`jax_sgmc.io.register_dictionize_rule` -decorator.

  .. doctest::

    >>> from jax_sgmc import io
    >>> some_tree = {'a': 0.0, 'b': {'b1': [0.0, 0.1], 'b2': 0.0}}
    >>> as_dict = io.pytree_to_dict(some_tree)
    >>> print(as_dict)
    {'a': 0.0, 'b': {'b1': {'list_element_0': 0.0, 'list_element_1': 0.1}, 'b2': 0.0}}

  Args:
    tree: All nodes of the tree must either have no children or a registered
      transformation to dict.

  Returns:
    Returns the tree as a dict with similar structure.

  """
  # Leaves (and nodes consisting of a single tree node) are passed through
  if tree_util.treedef_is_leaf(tree_util.tree_structure(tree)):
    return tree

  as_dict = {}
  for name, subtree in _dictionize(tree):
    as_dict[name] = pytree_to_dict(subtree)
  return as_dict
def pytree_dict_keys(tree: PyTree):
  """Returns a list of keys to access the leaves of the tree.

  The keys are listed in the order of the leaves of the flattened ``tree``.
  The i-th key tuple therefore addresses, in ``pytree_to_dict(tree)``, the
  i-th element of ``tree_leaves(tree)``. This also holds if the dict form of
  the tree would be flattened in a different order than the tree itself, e.g.
  for namedtuples with unsorted field names or lists with more than ten
  elements (dict keys are sorted during flattening).

  Args:
    tree: Pytree or pytree as a dict

  Returns:
    Returns a list of tuples, where each tuple contains the keys to access the
    leaf of the flattened pytree in the unflattened dict.

  For example:

  .. doctest::

    >>> from jax.tree_util import tree_leaves
    >>> from jax_sgmc import io
    >>> pytree = {"a": [0.0, 1.0], "b": 2.0}
    >>> pytree_as_dict = io.pytree_to_dict(pytree)
    >>> pytree_leaves = tree_leaves(pytree_as_dict)
    >>> pytree_keys = io.pytree_dict_keys(pytree_as_dict)
    >>> print(pytree_leaves)
    [0.0, 1.0, 2.0]
    >>> print(pytree_keys)
    [('a', 'list_element_0'), ('a', 'list_element_1'), ('b',)]

  """
  leaves, treedef = tree_util.tree_flatten(tree)
  # Replace every leaf by its position in the flattened original tree before
  # transforming the tree to a dict. Numbering the leaves of the dict form
  # instead would follow the (sorted) key order of the dicts, which does not
  # always agree with the leaf order of the original nodes.
  idx_tree = tree_util.tree_unflatten(treedef, list(range(len(leaves))))
  idx_dict = pytree_to_dict(idx_tree)

  key_list = [None] * len(leaves)

  def _recurse(node, path):
    if isinstance(node, dict):
      for key, value in node.items():
        _recurse(value, path + (key,))
    elif isinstance(node, int):
      # The leaves are the positions assigned above
      key_list[node] = path
    # Everything else is a node without leaves (e.g. None or an empty list),
    # which does not contribute a key.

  _recurse(idx_dict, ())
  return key_list
def dict_to_pytree(pytree_as_dict: dict, target: PyTree):
  """Restores the original tree structure given by the target from a dict.

  Restores the pytree as a dict to its original tree structure. This function
  can also operate on subtrees, as long as the subtree (a dict) of the pytree
  as dict matches the subtree of the target dict.

  .. doctest::

    >>> from jax_sgmc import io
    >>> some_tree = {'a': 0.0, 'b': {'b1': [0.0, 0.1], 'b2': 0.0}}
    >>> as_dict = io.pytree_to_dict(some_tree)
    >>> sub_pytree = io.dict_to_pytree(as_dict['b'], some_tree['b'])
    >>> print(sub_pytree)
    {'b1': [0.0, 0.1], 'b2': 0.0}

  Args:
    pytree_as_dict: A pytree which has been transformed to a dict of dicts.
    target: A pytree defining the original tree structure.

  Returns:
    Returns a pytree with the structure of the target and the leaves looked up
    in the dict.

  """
  key_paths = pytree_dict_keys(target)
  target_structure = tree_util.tree_structure(target)

  def _lookup(key_path):
    # Descend one dict level per key of the path
    entry = pytree_as_dict
    for key in key_path:
      entry = entry.get(key)
    return entry

  restored_leaves = (_lookup(key_path) for key_path in key_paths)
  return tree_util.tree_unflatten(target_structure, restored_leaves)
@register_dictionize_rule(dict)
def _dict_to_dict(some_dict: dict):
  """Rule for dicts: the children are addressed by their original keys."""
  key_child_pairs = some_dict.items()
  return key_child_pairs
@register_dictionize_rule(list)
def _list_to_dict(some_list: list):
  """Rule for lists: the children are named after their position."""
  indexed_elements = enumerate(some_list)
  return ((f"list_element_{position}", element)
          for position, element in indexed_elements)
@register_dictionize_rule(tuple)
def _namedtuple_to_dict(some_tuple: Union[tuple, namedtuple]):
  """Rule for tuples: uses field names if available, otherwise positions."""
  is_named = hasattr(some_tuple, '_fields')
  if not is_named:
    # Plain tuples are treated like lists
    return ((f"list_element_{position}", element)
            for position, element in enumerate(some_tuple))
  return some_tuple._asdict().items()
# The rule can only be registered if haiku has been imported successfully
if haiku_ds is not None:
  @register_dictionize_rule(haiku_ds.FlatMapping)
  def _flat_mapping_to_dict(flat_mapping: haiku_ds.FlatMapping):
    """Rule for haiku's FlatMapping, delegating to ``to_immutable_dict``."""
    converted_mapping = haiku_ds.to_immutable_dict(flat_mapping)
    return converted_mapping.items()
# State passed through the saving functions:
#   chain_id: Id of the chain the samples belong to
#   saved_samples: Count of the samples saved so far
#   data: Samples collected on the device by ``no_save``; set to None by
#     ``save``, which hands the samples to a data collector instead
saving_state = namedtuple("saving_state",
                          ["chain_id",
                           "saved_samples",
                           "data"])
# A saving strategy is a triple of functions (init, save, postprocess), as
# returned by ``save`` and ``no_save``.
Saving = Tuple[Callable[[PyTree, PyTree, PyTree], saving_state],
               Callable[[saving_state,
                         jnp.bool_,
                         PyTree,
                         PyTree,
                         PyTree],
                        Union[Any, NoReturn]],
               Callable[[Any], Union[Any, NoReturn]]]
class DataCollector(metaclass=abc.ABCMeta):
  """Interface of objects collecting sampled data and data loader states.

  All methods are abstract and must be provided by the concrete collectors.
  """

  @abc.abstractmethod
  def register_data_loader(self, data_loader: data.DataLoader):
    """Registers a data loader, such that its state can be saved.

    Args:
      data_loader: Data loader to be registered

    """

  @abc.abstractmethod
  def register_chain(self,
                     init_sample: PyTree = None,
                     init_checkpoint: PyTree = None,
                     static_information: PyTree = None
                     ) -> int:
    """Registers a chain to save samples from.

    Args:
      init_sample: Determining shape, dtype and tree structure of sample.
      init_checkpoint: Determining shape, dtype and tree structure of solver
        state to enable checkpointing.
      static_information: Information about the total number of collected
        samples

    Returns:
      Returns id of the new chain.

    """

  @abc.abstractmethod
  def save(self, chain_id: int, values):
    """Called with collected samples.

    Args:
      chain_id: Id of the chain the samples belong to
      values: Collected samples

    """

  @abc.abstractmethod
  def finalize(self, chain_id: int):
    """Called after solver finished.

    Args:
      chain_id: Id of the finished chain

    """

  @abc.abstractmethod
  def finished(self, chain_id: int):
    """Called in main thread after jax threads have been released.

    Args:
      chain_id: Id of the finished chain

    """

  @abc.abstractmethod
  def checkpoint(self, chain_id: int, state):
    """Called every nth step.

    Args:
      chain_id: Id of the chain to be checkpointed
      state: State to be checkpointed

    """

  @abc.abstractmethod
  def resume(self):
    """Called to restore data loader state and return sample states. """
# Todo: Maybe remove? Not usable for big data -> better hdf5
class JSONCollector(DataCollector):
  """Saves samples in json format.

  All samples of a chain are kept in memory. Every written file contains all
  samples which have been collected for the chain up to this point.

  Args:
    dir: Directory to save the collected samples. Missing parent directories
      are created, too.
    write_frequency: Number of samples to be collected before a file is
      written. For values smaller than one, a file is only written when the
      chain is finalized.

  """

  def __init__(self, dir: str, write_frequency=0):
    assert ujson is not None, "ujson is required to save samples to json files"
    self._dir = Path(dir)
    # Create missing parents as well, otherwise nested output directories
    # cause a FileNotFoundError.
    self._dir.mkdir(parents=True, exist_ok=True)
    # One list of samples and one sample counter per registered chain
    self._collected_samples = []
    self._sample_count = []
    self._write_frequency = write_frequency

  def register_data_loader(self, data_loader: data.DataLoader):
    """Not supported for JSON collector.

    Raises:
      NotImplementedError: Always, as checkpointing is not supported.

    """
    raise NotImplementedError("Checkpointing is not supported by JSON loader.")

  def register_chain(self,
                     init_sample: PyTree = None,
                     init_checkpoint: PyTree = None,
                     static_information: PyTree = None
                     ) -> int:
    """Register a chain to save samples from.

    The arguments are accepted for compatibility with the other collectors
    but are not used.

    Returns:
      Returns the id of the new chain.

    """
    new_chain = len(self._collected_samples)
    self._collected_samples.append([])
    self._sample_count.append(0)
    return new_chain

  def save(self, chain_id: int, values):
    """Called with collected samples.

    Args:
      chain_id: Id from register_chain
      values: Pytree of the new sample

    """
    # Store new values
    self._collected_samples[chain_id].append(
      tree_util.tree_map(onp.array, values))
    self._sample_count[chain_id] += 1

    # Write to file but keep collected samples in memory
    count = self._sample_count[chain_id]
    if self._write_frequency > 0 and count % self._write_frequency == 0:
      self._write_file(chain_id, count)

  def finalize(self, chain_id: int):
    """Called after solver finished. Writes all collected samples to a file.

    Args:
      chain_id: Id of the finished chain

    """
    if self._sample_count[chain_id] > 0:
      self._write_file(chain_id, self._sample_count[chain_id])

  def checkpoint(self, chain_id: int, state):
    """Called every nth step.

    Raises:
      NotImplementedError: Always, as checkpointing is not supported.

    """
    raise NotImplementedError("Checkpointing is not supported by JSON loader.")

  def resume(self):
    """Called to restore data loader state and return sample states.

    Raises:
      NotImplementedError: Always, as checkpointing is not supported.

    """
    raise NotImplementedError("Checkpointing is not supported by JSON loader.")

  def _write_file(self, chain_id, iteration):
    """Dumps all samples collected for the chain into a single json file."""
    filename = self._dir / f"chain_{chain_id}_iteration_{iteration}.json"

    # Transform collected samples to list after chaining together
    stacked_samples = tree_util.tree_map(
      lambda *samples: onp.stack(samples, axis=0),
      *self._collected_samples[chain_id])
    samples_as_list = tree_util.tree_map(
      lambda leaf: leaf.tolist(),
      stacked_samples)

    # Opening in write mode creates the file, no separate touch is necessary
    with open(filename, "w", encoding="utf-8") as file:
      ujson.dump(samples_as_list, file) # pylint: disable=c-extension-no-member

  def finished(self, chain_id: int):
    """Simply return, nothing to wait for. """
    return None
class HDF5Collector(DataCollector):
  """Save to hdf5 format.

  This data collector supports serializing collected samples and checkpoints
  into the hdf5 file format. The samples are saved in a structure similar to the
  original pytree and can thus be viewed easily via the hdf5-viewer.

  Note:
    This class requires that ``h5py`` is installed. Additional information
    can be found in the :ref:`installation instructions<additional_requirements>`.

  .. doctest::

    >>> import tempfile
    >>> tf = tempfile.TemporaryFile()
    >>>
    >>> import h5py
    >>> from jax_sgmc import io
    >>>
    >>> file = h5py.File(tf, "w") # Use a real file for real saving
    >>> data_collector = io.HDF5Collector(file)
    >>> saving = io.save(data_collector)
    >>>
    >>> # ... use the solver ...
    >>>
    >>> # Close the file
    >>> file.close()

  Args:
    file: hdf5 file object

  """

  def __init__(self, file: HDF5File):
    assert h5py is not None, "h5py must be installed to use this DataCollector."
    self._file = file
    # Per-chain bookkeeping, the chain id is the index into these lists
    self._leaf_names = []
    self._sample_count = []
    # Barrier to wait until all data has been processed
    self._finished = []

  def register_data_loader(self, data_loader: data.DataLoader):
    """Registers data loader to save the state. """

  def register_chain(self,
                     init_sample: PyTree = None,
                     init_checkpoint: PyTree = None,
                     static_information: PyTree = None
                     ) -> int:
    """Registers a chain to save samples from.

    Creates one dataset per leaf of the sample, which can hold all samples to
    be collected.

    Args:
      init_sample: Pytree determining tree structure and leafs of samples to be
        saved.
      init_checkpoint: Pytree determining tree structure and leafs of the states
        to be checkpointed
      static_information: Information about the total numbers of samples
        collected.

    Returns:
      Returns the id of the new chain.

    """
    assert init_sample is not None, "Need a sample-pytree to allocate memory."
    assert init_checkpoint is not None, ("Need a checkpoint-pytree to allocate "
                                         "memory.")

    chain_id = len(self._finished)

    # Save the pytree as it would be transformed to a dict by pytree_to_dict
    chain_group = f"/chain~{chain_id}"
    leaf_names = []
    for key_tuple in pytree_dict_keys(init_sample):
      leaf_names.append("/".join((chain_group, *key_tuple)))

    # Allocate memory upfront, this is more effective than the chunk-wise
    # allocation
    for dataset_name, leaf in zip(leaf_names,
                                  tree_util.tree_leaves(init_sample)):
      dataset_shape = (int(static_information.samples_collected),
                       *(int(size) for size in leaf.shape))
      self._file.create_dataset(
        dataset_name,
        shape=dataset_shape,
        dtype=leaf.dtype)

    self._sample_count.append(0)
    self._leaf_names.append(leaf_names)
    self._finished.append(threading.Barrier(2))

    return chain_id

  def save(self, chain_id: int, values):
    """Saves new leaves to dataset.

    Args:
      chain_id: ID from register_chain
      values: Tree leaves of sample

    """
    # All leaves of a sample are written to the same row of their datasets
    row = self._sample_count[chain_id]
    for dataset_name, value in zip(self._leaf_names[chain_id], values):
      self._file[dataset_name][row] = value
    self._sample_count[chain_id] += 1

  def finalize(self, chain_id: int):
    """Waits for all writing to be finished (scheduled via host_callback).

    Note:
      The wait on the barrier is currently disabled, such that this method
      returns immediately.

    Args:
      chain_id: Id of the chain which is finished

    """
    # Is called after all host callback calls have been processed
    # self._finished[chain_id].wait()

  def checkpoint(self, chain_id: int, state):
    """Called every nth step. """

  def resume(self):
    """Called to restore data loader state and return sample states. """

  def finished(self, chain_id: int):
    """Returns after everything has been written to the file.

    Finalize is scheduled via host_callback and finished is called in the normal
    python flow. Via a barrier it is possible to pause the program flow until
    all asynchronously saved data has been processed.

    Note:
      The wait on the barrier is currently disabled, such that this method
      returns immediately.

    Args:
      chain_id: ID of chain requesting continuation of normal program flow.

    """
    # self._finished[chain_id].wait()
class MemoryCollector(DataCollector):
  """Stores samples entirely in RAM (numpy arrays).

  The RAM is usually larger than the device array and thus allows to store a
  greater number of samples.

  Args:
    save_dir: Directory to output results as numpy-npz with one file per chain.
      If none, the results will only be returned.

  """

  def __init__(self, save_dir=None):
    # Per-chain bookkeeping, the chain id is the index into these lists
    self._finished = []
    self._samples = []
    self._samples_count = []
    self._treedefs = []
    self._leafnames = []
    self._dir = save_dir

  def register_data_loader(self, data_loader: data.DataLoader):
    """Registers data loader to save the state. """

  def register_chain(self,
                     init_sample: PyTree = None,
                     static_information = None,
                     **unused_kwargs) -> int:
    """Registers a chain to save samples from.

    Allocates one numpy array per leaf of the sample, which can hold all
    samples to be collected.

    Args:
      init_sample: Pytree determining tree structure and leafs of samples to be
        saved.
      static_information: Information about the total numbers of samples
        collected.

    Returns:
      Returns the id of the new chain.

    """
    assert init_sample is not None, "Need a sample-pytree to allocate memory."

    chain_id = len(self._finished)
    leaves, treedef = tree_util.tree_flatten(init_sample)

    def leaf_shape(leaf):
      # Prepend the sample axis to the shape of a single sample
      new_shape = onp.append(
        static_information.samples_collected,
        leaf.shape)
      return tuple(int(s) for s in new_shape)

    sample_cache = [onp.zeros(leaf_shape(leaf), dtype=leaf.dtype)
                    for leaf in leaves]

    # Only generate the keys for each leaf if necessary
    if self._dir:
      leaf_names = ["/".join(key_tuple)
                    for key_tuple in pytree_dict_keys(init_sample)]
    else:
      # The generic names are already complete strings and must not be joined
      # again (joining a string would separate its characters).
      leaf_names = [f"leaf~{idx}" for idx in range(len(leaves))]

    # The lock is held until finalize has been called for the chain
    chain_lock = threading.Lock()
    chain_lock.acquire()

    self._finished.append(chain_lock)
    self._samples.append(sample_cache)
    self._treedefs.append(treedef)
    self._samples_count.append(0)
    self._leafnames.append(leaf_names)

    return chain_id

  def save(self, chain_id: int, values):
    """Saves new leaves to dataset.

    Args:
      chain_id: ID from register_chain
      values: Tree leaves of sample

    """
    row = self._samples_count[chain_id]
    for leaf, value in zip(self._samples[chain_id], values):
      leaf[row] = value
    self._samples_count[chain_id] += 1

  def finalize(self, chain_id: int):
    """Writes the results to the output directory and unblocks ``finished``.

    This method is scheduled via host_callback after all samples have been
    sent to the host.

    Args:
      chain_id: ID of the chain which is finished

    """
    # Is called after all host callback calls have been processed. The lock is
    # only released after the file has been written, such that finished does
    # not return before the results are on the disk. It is released even if
    # writing fails, such that finished cannot block forever.
    try:
      if self._dir:
        output_dir = Path(self._dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        output_file = output_dir / f"chain_{chain_id}.npz"
        onp.savez(
          output_file,
          **dict(zip(self._leafnames[chain_id], self._samples[chain_id])))
    finally:
      self._finished[chain_id].release()

  def finished(self, chain_id):
    """Returns samples after all data has been processed.

    Finalize is scheduled via host_callback and finished is called in the normal
    python flow. Via a lock it is possible to pause the program flow until
    all asynchronously saved data has been processed.

    Args:
      chain_id: ID of chain requesting continuation of normal program flow.

    Returns:
      Returns the collected samples in the original tree format but with numpy-
      arrays as leaves.

    """
    # Blocks until finalize released the lock and frees it again immediately
    with self._finished[chain_id]:
      pass

    # Restore original tree shape
    return tree_util.tree_unflatten(
      self._treedefs[chain_id],
      self._samples[chain_id])

  def checkpoint(self, chain_id: int, state):
    """Called every nth step. """

  def resume(self):
    """Called to restore data loader state and return sample states. """
def load(init_state, checkpoint):
  """Reconstructs an earlier checkpoint.

  Raises:
    NotImplementedError: Always, as checkpointing is not implemented.

  """
  message = "Checkpointing is currently not supported."
  raise NotImplementedError(message)
def save(data_collector: DataCollector = None,
         checkpoint_every: int = 0
         ) -> Saving:
  """Initializes asynchronous saving of samples and checkpoints.

  Accepted samples are sent to the host and processed there. This optimizes the
  memory usage drastically and also allows gaining insight in the data while the
  simulation is running.

  Returns statistics and samples depending on the Data Collector. For
  example hdf5 can be used for samples larger than the (device-)memory.
  Therefore, no samples are returned. Instead, the memory collector returns
  the samples collected as numpy arrays.

  Example usage:

  .. doctest::

    >>> import jax.numpy as jnp
    >>> from jax.lax import scan
    >>> from jax_sgmc import io, scheduler
    >>>
    >>> dc = io.MemoryCollector()
    >>> init_save, save, postprocess_save = io.save(dc)
    >>>
    >>> def update(saving_state, it):
    ...   saving_state = save(saving_state, jnp.mod(it, 2) == 0, {'iteration': it})
    ...   return saving_state
    >>>
    >>> # The information about the number of collected samples must be defined
    >>> # before the run
    >>> static_information = scheduler.static_information(samples_collected=3)
    >>>
    >>> # The saving function must know the form of the sample which should be saved
    >>> saving_state = init_save({'iteration': jnp.array(0)}, {}, static_information)
    >>> final_state, _ = scan(update, saving_state, jnp.arange(5))
    >>>
    >>> saved_samples = postprocess_save(final_state, None)
    >>> print(saved_samples)
    {'sample_count': Array(3, dtype=int32, weak_type=True), 'samples': {'iteration': array([0, 2, 4], dtype=int32)}}

  Args:
    data_collector: Stateful object for data storage and serialization
    checkpoint_every: Create a checkpoint for late resuming every n iterations

  Warning:
    Checkpointing is currently not supported.

  Returns:
    Returns a saving strategy.

  """
  # Todo: Implement checkpointing
  if checkpoint_every != 0:
    raise NotImplementedError("Checkpointing is not supported yet.")

  # Helper functions for host_callback

  def _save(tapped, *unused_args):
    """Runs on the host and forwards the sample leaves to the collector."""
    chain_ids, leaves = tapped
    # id_tap sends batched arguments to the host
    if chain_ids.ndim > 0:
      # Split the batch into the samples of the individual chains
      for batch_idx, chain_id in enumerate(chain_ids):
        data_collector.save(chain_id, [leaf[batch_idx] for leaf in leaves])
    else:
      data_collector.save(chain_ids, leaves)

  def _save_wrapper(args) -> int:
    # Use the result to count the number of saved samples. The result must be
    # used to avoid losing the call to Jax's optimizations.
    # Only return the leaves, as tree structure is redundant and requires
    # flattening on the host.
    chain_id, sample = args
    flat_sample = (chain_id, tree_util.tree_leaves(sample))
    return host_callback.id_tap(_save, flat_sample, result=1)

  def init(init_sample, init_checkpoint, static_information) -> saving_state:
    """Initializes the saving state.

    Args:
      init_sample: Determining shape and dtype of collected samples
      init_checkpoint: Determining shape and dtype of checkpointed states
      static_information: Information about e. g. the total count of samples
        collected.

    Returns:
      Returns initial state.

    """
    new_chain_id = data_collector.register_chain(
      init_sample=init_sample,
      init_checkpoint=init_checkpoint,
      static_information=static_information)
    # The count of saved samples is important to ensure that the callback
    # function is not removed by jax's optimization procedures.
    return saving_state(chain_id=new_chain_id, saved_samples=0, data=None)

  @stop_vmap.stop_vmap
  def _save_helper(keep, state, sample):
    # Only the accepted samples are sent to the host, the result counts the
    # samples saved in this step (1 or 0).
    return lax.cond(keep,
                    _save_wrapper,
                    lambda *args: 0,
                    (state.chain_id, sample))

  # Todo: Generalize the saving by contracting the scheduler state and the
  #       solver state to a single checkpointing state.
  def save(state: saving_state,
           keep: jnp.bool_,
           sample: Any,
           scheduler_state: scheduler.scheduler_state = None,
           solver_state: Any = None):
    """Calls the data collector on the host via host callback module."""
    # Save sample if samples is not subject to burn in or discarded by thinning
    newly_saved = _save_helper(keep, state, sample)

    # Todo: Implement checkpointing
    # last_checkpoint = lax.cond(time_for_checkpoint,
    #                            _checkpoint_wrapper,
    #                            lambda *args: last_checkpoint,
    #                            (scheduler_state, soler_state)

    updated_state = saving_state(
      chain_id=state.chain_id,
      saved_samples=state.saved_samples + newly_saved,
      data=None)
    return updated_state, None

  def postprocess(state: saving_state, unused_saved):
    """Finalizes the chain and returns the results of the data collector."""
    def _finalize(chain_id, *unused):
      return data_collector.finalize(chain_id)

    # Call with host callback to ensure that finalized is called after all other
    # id_tap processes were finished.
    host_callback.id_tap(_finalize, state.chain_id)

    samples = data_collector.finished(state.chain_id)
    return {"sample_count": state.saved_samples,
            "samples": samples}

  return init, save, postprocess
# Todo: Removed the sample collection, consider removing last argument in
# postprocess function
def no_save() -> Saving:
  """Does not save the data on the host but return it instead.

  If the samples are small, collection on the device is possible. Samples must
  be copied repeatedly.

  Keep every second element of a 5 element scan:

  .. doctest::

    >>> import jax.numpy as jnp
    >>> from jax.lax import scan
    >>> from jax_sgmc import io, scheduler
    >>>
    >>> init_save, save, postprocess_save = io.no_save()
    >>>
    >>> def update(saving_state, it):
    ...   saving_state = save(saving_state, jnp.mod(it, 2) == 0, {'iteration': it})
    ...   return saving_state
    >>>
    >>> # The information about the number of collected samples must be defined
    >>> # before the run
    >>> static_information = scheduler.static_information(samples_collected=3)
    >>>
    >>> # The saving function must know the form of the sample which should be saved
    >>> saving_state = init_save({'iteration': jnp.array(0)}, None, static_information)
    >>> final_state, _ = scan(update, saving_state, jnp.arange(5))
    >>>
    >>> saved_samples = postprocess_save(final_state, None)
    >>> print(saved_samples)
    {'sample_count': Array(3, dtype=int32, weak_type=True), 'samples': {'iteration': Array([0, 2, 4], dtype=int32)}}

  Returns:
    Returns a saving strategy, which keeps the samples entirely in the
    device's memory.

  """

  def init(init_sample: PyTree,
           unused_solver_state: PyTree,
           static_information: PyTree
           ) -> saving_state:
    """Initializes the saving state.

    Args:
      init_sample: Determining shape and dtype of collected samples
      unused_solver_state: Not used by this saving strategy
      static_information: Information about e. g. the total count of samples
        collected.

    Returns:
      Returns initial state.

    """
    def init_zeros(leaf):
      # Prepend the sample axis to the shape of a single sample
      extended_shape = onp.append(
        static_information.samples_collected, leaf.shape)
      storage_shape = tuple(int(size) for size in extended_shape)
      return jnp.zeros(storage_shape, dtype=leaf.dtype)

    storage = tree_util.tree_map(init_zeros, init_sample)
    # The chain id is unnecessary
    return saving_state(chain_id=0, saved_samples=0, data=storage)

  def _save_sample(args):
    state, sample = args
    # The count of saved samples is the next free row of the storage
    row = state.saved_samples
    updated_data = tree_util.tree_map(
      lambda stored_leaf, new_leaf: stored_leaf.at[row].set(new_leaf),
      state.data,
      sample)
    return saving_state(
      chain_id=state.chain_id,
      saved_samples=row + 1,
      data=updated_data)

  def _skip_sample(args):
    # Return the saving state unchanged
    return args[0]

  @stop_vmap.stop_vmap
  def _save_helper(keep, state, sample):
    return lax.cond(keep,
                    _save_sample,
                    _skip_sample,
                    (state, sample))

  def save(state: saving_state,
           keep: jnp.bool_,
           sample: Any,
           **unused_kwargs: Any
           ) -> Any:
    """Determines whether a sample should be saved.

    A sample will be saved if it is not subject to burn in and not discarded
    due to thinning.

    """
    return _save_helper(keep, state, sample), None

  def postprocess(state: saving_state, unused_saved):
    """Return the accepted samples. """
    results = {"sample_count": state.saved_samples,
               "samples": state.data}
    return results

  return init, save, postprocess