jax-sgmc
latest
Getting Started
Installation
Quickstart
Reference Documentation
Data Loading
Compute Potential from Likelihood
Saving of Samples
Set Up Schedulers
Set Up a Custom Solver
Advanced Topics
Extend Adapted Quantities
Extending Schedulers
Examples
Image Classification on CIFAR-10
API Documentation
jax_sgmc
jax-sgmc
Index
Edit on GitHub
Index
_
|
A
|
B
|
C
|
D
|
F
|
G
|
H
|
I
|
J
|
K
|
L
|
M
|
N
|
O
|
P
|
R
|
S
|
T
|
U
|
V
_
__call__() (jax_sgmc.data.core.FullDataMapFunction method)
(jax_sgmc.data.core.FullDataMapperFunction method)
(jax_sgmc.data.core.GetBatchFunction method)
(jax_sgmc.data.core.MaskedMappedFunction method)
(jax_sgmc.data.core.UnmaskedMappedFunction method)
(jax_sgmc.potential.FullPotential method)
(jax_sgmc.potential.StochasticPotential method)
__init__() (jax_sgmc.util.Array method)
A
accept (in module jax_sgmc.scheduler)
acceptance_ratio (jax_sgmc.solver.AMAGOLDState attribute)
(jax_sgmc.solver.SGGMCState attribute)
adapt_state (jax_sgmc.integrator.LangevinState attribute)
adaption() (in module jax_sgmc.adaption)
AdaptionState (class in jax_sgmc.adaption)
adaptive_step_size() (in module jax_sgmc.scheduler)
amagold() (in module jax_sgmc.alias)
(in module jax_sgmc.solver)
AMAGOLDState (class in jax_sgmc.solver)
Array (class in jax_sgmc.util)
B
b_sqrt (jax_sgmc.adaption.NoiseModel attribute)
batch_format() (jax_sgmc.data.core.HostDataLoader method)
burn_in (in module jax_sgmc.scheduler)
burn_in_state (in module jax_sgmc.scheduler)
C
CacheState (class in jax_sgmc.data.core)
cb_diff_sqrt (jax_sgmc.adaption.NoiseModel attribute)
checkpoint() (jax_sgmc.io.DataCollector method)
(jax_sgmc.io.HDF5Collector method)
(jax_sgmc.io.MemoryCollector method)
constant_temperature() (in module jax_sgmc.scheduler)
cyclic_burn_in() (in module jax_sgmc.scheduler)
cyclic_temperature() (in module jax_sgmc.scheduler)
D
data_state (jax_sgmc.integrator.LangevinState attribute)
(jax_sgmc.integrator.LeapfrogState attribute)
(jax_sgmc.integrator.ObaboState attribute)
DataCollector (class in jax_sgmc.io)
DataLoader (class in jax_sgmc.data.core)
DeviceDataLoader (class in jax_sgmc.data.core)
DeviceNumpyDataLoader (class in jax_sgmc.data.numpy_loader)
dict_to_pytree() (in module jax_sgmc.io)
F
finalize() (jax_sgmc.io.DataCollector method)
(jax_sgmc.io.HDF5Collector method)
(jax_sgmc.io.MemoryCollector method)
finished() (jax_sgmc.io.DataCollector method)
(jax_sgmc.io.HDF5Collector method)
(jax_sgmc.io.MemoryCollector method)
fisher_information() (in module jax_sgmc.adaption)
flat_potential (jax_sgmc.adaption.AdaptionState attribute)
friction_leapfrog() (in module jax_sgmc.integrator)
full_data_mapper() (in module jax_sgmc.data.core)
full_data_state (jax_sgmc.solver.AMAGOLDState attribute)
(jax_sgmc.solver.SGGMCState attribute)
full_potential() (in module jax_sgmc.potential)
full_reference_data() (in module jax_sgmc.data.core)
FullDataMapFunction (class in jax_sgmc.data.core)
FullDataMapperFunction (class in jax_sgmc.data.core)
FullPotential (class in jax_sgmc.potential)
G
g_inv (jax_sgmc.adaption.Manifold attribute)
gamma (jax_sgmc.adaption.Manifold attribute)
get_batches() (jax_sgmc.data.core.HostDataLoader method)
(jax_sgmc.data.numpy_loader.NumpyDataLoader method)
(jax_sgmc.data.tensorflow_loader.TensorflowDataLoader method)
get_fn (jax_sgmc.scheduler.specific_scheduler attribute)
get_full_data() (jax_sgmc.data.core.DeviceDataLoader method)
(jax_sgmc.data.numpy_loader.DeviceNumpyDataLoader method)
get_random_data() (jax_sgmc.data.core.DeviceDataLoader method)
(jax_sgmc.data.numpy_loader.DeviceNumpyDataLoader method)
get_unravel_fn() (in module jax_sgmc.adaption)
GetBatchFunction (class in jax_sgmc.data.core)
H
HDF5Collector (class in jax_sgmc.io)
HDF5Loader (class in jax_sgmc.data.hdf5_loader)
HostDataLoader (class in jax_sgmc.data.core)
I
init_fn (jax_sgmc.scheduler.specific_scheduler attribute)
init_mass() (in module jax_sgmc.integrator)
init_random_data() (jax_sgmc.data.core.DeviceDataLoader method)
(jax_sgmc.data.numpy_loader.DeviceNumpyDataLoader method)
init_scheduler() (in module jax_sgmc.scheduler)
initial_burn_in() (in module jax_sgmc.scheduler)
initializer_batch() (jax_sgmc.data.core.DataLoader method)
integrator_state (jax_sgmc.solver.AMAGOLDState attribute)
(jax_sgmc.solver.SGGMCState attribute)
inv (jax_sgmc.adaption.MassMatrix attribute)
J
jax_sgmc.adaption
module
jax_sgmc.alias
module
jax_sgmc.data.core
module
jax_sgmc.data.hdf5_loader
module
jax_sgmc.data.numpy_loader
module
jax_sgmc.data.tensorflow_loader
module
jax_sgmc.integrator
module
jax_sgmc.io
module
jax_sgmc.potential
module
jax_sgmc.scheduler
module
jax_sgmc.solver
module
jax_sgmc.util
module
K
key (jax_sgmc.integrator.LangevinState attribute)
(jax_sgmc.integrator.LeapfrogState attribute)
(jax_sgmc.integrator.ObaboState attribute)
(jax_sgmc.solver.AMAGOLDState attribute)
(jax_sgmc.solver.SGGMCState attribute)
kinetic_energy_end (jax_sgmc.integrator.ObaboState attribute)
kinetic_energy_start (jax_sgmc.integrator.ObaboState attribute)
L
langevin_diffusion() (in module jax_sgmc.integrator)
LangevinState (class in jax_sgmc.integrator)
latent_variables (jax_sgmc.integrator.LangevinState attribute)
LeapfrogState (class in jax_sgmc.integrator)
list_pmap() (in module jax_sgmc.util)
list_vmap() (in module jax_sgmc.util)
load_state() (jax_sgmc.data.core.HostDataLoader method)
(jax_sgmc.data.numpy_loader.NumpyDataLoader method)
M
Manifold (class in jax_sgmc.adaption)
MaskedMappedFunction (class in jax_sgmc.data.core)
mass_matrix() (in module jax_sgmc.adaption)
mass_state (jax_sgmc.solver.AMAGOLDState attribute)
(jax_sgmc.solver.SGGMCState attribute)
MassMatrix (class in jax_sgmc.adaption)
mcmc (class in jax_sgmc.solver)
MemoryCollector (class in jax_sgmc.io)
minibatch_potential() (in module jax_sgmc.potential)
MiniBatchInformation (class in jax_sgmc.data.core)
model_state (jax_sgmc.integrator.LangevinState attribute)
(jax_sgmc.integrator.LeapfrogState attribute)
(jax_sgmc.integrator.ObaboState attribute)
module
jax_sgmc.adaption
jax_sgmc.alias
jax_sgmc.data.core
jax_sgmc.data.hdf5_loader
jax_sgmc.data.numpy_loader
jax_sgmc.data.tensorflow_loader
jax_sgmc.integrator
jax_sgmc.io
jax_sgmc.potential
jax_sgmc.scheduler
jax_sgmc.solver
jax_sgmc.util
momentum (jax_sgmc.integrator.LeapfrogState attribute)
(jax_sgmc.integrator.ObaboState attribute)
N
no_save() (in module jax_sgmc.io)
NoiseModel (class in jax_sgmc.adaption)
NumpyBase (class in jax_sgmc.data.numpy_loader)
NumpyDataLoader (class in jax_sgmc.data.numpy_loader)
O
obabo() (in module jax_sgmc.alias)
(in module jax_sgmc.integrator)
ObaboState (class in jax_sgmc.integrator)
P
parallel_tempering() (in module jax_sgmc.solver)
polynomial_step_size() (in module jax_sgmc.scheduler)
polynomial_step_size_first_last() (in module jax_sgmc.scheduler)
positions (jax_sgmc.integrator.LeapfrogState attribute)
(jax_sgmc.integrator.ObaboState attribute)
potential (jax_sgmc.integrator.LangevinState attribute)
(jax_sgmc.integrator.LeapfrogState attribute)
(jax_sgmc.integrator.ObaboState attribute)
(jax_sgmc.solver.AMAGOLDState attribute)
(jax_sgmc.solver.SGGMCState attribute)
progress_bar_state (in module jax_sgmc.scheduler)
pytree_dict_keys() (in module jax_sgmc.io)
pytree_leaves_to_list() (in module jax_sgmc.util)
pytree_list_to_leaves() (in module jax_sgmc.util)
pytree_to_dict() (in module jax_sgmc.io)
R
random_reference_data() (in module jax_sgmc.data.core)
random_thinning() (in module jax_sgmc.scheduler)
random_tree() (in module jax_sgmc.integrator)
ravel_fn (jax_sgmc.adaption.AdaptionState attribute)
re_sgld() (in module jax_sgmc.alias)
reference_data (jax_sgmc.data.numpy_loader.NumpyBase property)
register_chain() (jax_sgmc.io.DataCollector method)
(jax_sgmc.io.HDF5Collector method)
(jax_sgmc.io.MemoryCollector method)
register_data_loader() (jax_sgmc.io.DataCollector method)
(jax_sgmc.io.HDF5Collector method)
(jax_sgmc.io.MemoryCollector method)
register_dictionize_rule() (in module jax_sgmc.io)
register_ordered_pipeline() (jax_sgmc.data.core.HostDataLoader method)
(jax_sgmc.data.numpy_loader.NumpyDataLoader method)
(jax_sgmc.data.tensorflow_loader.TensorflowDataLoader method)
register_random_pipeline() (jax_sgmc.data.core.HostDataLoader method)
(jax_sgmc.data.numpy_loader.NumpyDataLoader method)
(jax_sgmc.data.tensorflow_loader.TensorflowDataLoader method)
resume() (jax_sgmc.io.DataCollector method)
(jax_sgmc.io.HDF5Collector method)
(jax_sgmc.io.MemoryCollector method)
reversible_leapfrog() (in module jax_sgmc.integrator)
rms_prop() (in module jax_sgmc.adaption)
S
samples_collected (in module jax_sgmc.scheduler)
save() (in module jax_sgmc.io)
(jax_sgmc.io.DataCollector method)
(jax_sgmc.io.HDF5Collector method)
(jax_sgmc.io.MemoryCollector method)
save_state() (jax_sgmc.data.core.HostDataLoader method)
(jax_sgmc.data.numpy_loader.NumpyDataLoader method)
schedule() (in module jax_sgmc.scheduler)
scheduler_state() (in module jax_sgmc.scheduler)
sggmc() (in module jax_sgmc.alias)
(in module jax_sgmc.solver)
SGGMCState (class in jax_sgmc.solver)
sghmc() (in module jax_sgmc.alias)
sgld() (in module jax_sgmc.alias)
sgmc() (in module jax_sgmc.solver)
specific_scheduler (class in jax_sgmc.scheduler)
sqrt (jax_sgmc.adaption.MassMatrix attribute)
sqrt_g_inv (jax_sgmc.adaption.Manifold attribute)
state (in module jax_sgmc.scheduler)
(jax_sgmc.adaption.AdaptionState attribute)
static_information (jax_sgmc.data.core.DataLoader property)
(jax_sgmc.data.numpy_loader.NumpyBase property)
(jax_sgmc.data.tensorflow_loader.TensorflowDataLoader property)
static_information() (in module jax_sgmc.scheduler)
step_size (in module jax_sgmc.scheduler)
step_size_state (in module jax_sgmc.scheduler)
StochasticPotential (class in jax_sgmc.potential)
T
temperature (in module jax_sgmc.scheduler)
temperature_state (in module jax_sgmc.scheduler)
TensorflowDataLoader (class in jax_sgmc.data.tensorflow_loader)
thinning_state (in module jax_sgmc.scheduler)
tree_add() (in module jax_sgmc.util)
tree_dtype_struct() (in module jax_sgmc.data.core)
tree_index() (in module jax_sgmc.data.core)
tree_multiply() (in module jax_sgmc.util)
tree_scale() (in module jax_sgmc.util)
U
UnmaskedMappedFunction (class in jax_sgmc.data.core)
unravel_fn (jax_sgmc.adaption.AdaptionState attribute)
update_fn (jax_sgmc.scheduler.specific_scheduler attribute)
V
variance (jax_sgmc.integrator.LangevinState attribute)
Read the Docs
v: latest
Versions
latest
stable
Downloads
On Read the Docs
Project Home
Builds