genjax.sp
Structural primitives and combinators for building complex models.
Stochastic Probabilities (SP) for GenJAX
This module implements SPDistribution following the design of GenSP.jl, enabling probabilistic programming with importance-weighted samples. SP distributions produce weighted samples that enable unbiased estimation of probabilities and expectations.
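As a minimal illustration of why weighted samples support unbiased estimation, the sketch below estimates the normalizing constant of an unnormalized Gaussian density by averaging importance weights. It is plain Python, not GenJAX; the target, the Normal(0, 2) proposal, and the function names are assumptions made for the example.

```python
import math
import random


def unnormalized_target(x: float) -> float:
    # Unnormalized density exp(-x^2 / 2); its true normalizing constant is sqrt(2 * pi).
    return math.exp(-0.5 * x * x)


def proposal_pdf(x: float, sigma: float) -> float:
    # Density of the proposal Normal(0, sigma).
    return math.exp(-0.5 * (x / sigma) ** 2) / (sigma * math.sqrt(2 * math.pi))


def estimate_normalizer(n: int, sigma: float = 2.0, seed: int = 0) -> float:
    """Average importance weight: an unbiased estimate of the normalizing constant."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        x = rng.gauss(0.0, sigma)  # draw x ~ q
        total += unnormalized_target(x) / proposal_pdf(x, sigma)  # importance weight
    return total / n


if __name__ == "__main__":
    print(estimate_normalizer(100_000), math.sqrt(2 * math.pi))
```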
References
- Alexander K. Lew, Matin Ghavami, Martin Rinard, Vikash K. Mansinghka. "Probabilistic Programming with Stochastic Probabilities."
- GenSP.jl: https://github.com/probcomp/GenSP.jl
SPDistribution
Bases: GFI[X, X], ABC
Abstract base class for Stochastic Probability distributions.
SPDistributions extend the GFI interface with importance-weighted sampling. Instead of implementing simulate/assess directly, subclasses implement random_weighted and estimate_logpdf.
Note: SPDistribution is a GFI[X, X], so its return value is the same as its choices (as with Distribution in core.py).
random_weighted (abstract method)
Sample a value and compute its importance weight.
estimate_logpdf (abstract method)
Estimate the log probability density of a value.
simulate
Simulate by calling random_weighted.
Source code in src/genjax/sp.py
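The contract can be sketched in plain Python. This is not the GenJAX class: the names SPDistributionSketch and ExactNormal, the explicit random.Random argument (GenJAX uses JAX-style randomness instead), and the scalar values are assumptions made for illustration. The sketch shows subclasses supplying only random_weighted and estimate_logpdf, with simulate (and generate) derived from them.

```python
import math
import random
from abc import ABC, abstractmethod
from typing import Tuple


class SPDistributionSketch(ABC):
    """Hypothetical stand-in for the SPDistribution contract (not the GenJAX class)."""

    @abstractmethod
    def random_weighted(self, rng: random.Random) -> Tuple[float, float]:
        """Return a sample x together with a log weight (a log density estimate at x)."""

    @abstractmethod
    def estimate_logpdf(self, x: float, rng: random.Random) -> float:
        """Return an estimate of log p(x)."""

    def simulate(self, rng: random.Random) -> Tuple[float, float]:
        # simulate is derived from random_weighted instead of being implemented directly.
        return self.random_weighted(rng)

    def generate(self, rng: random.Random) -> Tuple[float, float]:
        # Mirrors the note that generate is the same as simulate for SP distributions.
        return self.simulate(rng)


class ExactNormal(SPDistributionSketch):
    """Degenerate SP distribution whose 'estimates' are the exact Normal log density."""

    def __init__(self, mu: float, sigma: float):
        self.mu, self.sigma = mu, sigma

    def _logpdf(self, x: float) -> float:
        z = (x - self.mu) / self.sigma
        return -0.5 * z * z - math.log(self.sigma) - 0.5 * math.log(2 * math.pi)

    def random_weighted(self, rng: random.Random) -> Tuple[float, float]:
        x = rng.gauss(self.mu, self.sigma)
        return x, self._logpdf(x)

    def estimate_logpdf(self, x: float, rng: random.Random) -> float:
        return self._logpdf(x)
```

Real SP distributions return noisy estimates; ExactNormal only shows where the two required methods plug in.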
assess
generate
Generate. For SP distributions, generate is the same as simulate.
update
regenerate
merge
filter
get_selection
Get the selection for an SP distribution's choices.
For SPDistribution, the selection should be determined from the actual structure of x, so the behavior is implementation-specific.
SMCAlgorithm
Bases: SPDistribution
Abstract base class for SMC-based SP distributions.
Extends SPDistribution with composable SMC functionality, bridging to the GenJAX inference.smc module.
run_smc (abstract method)
Run the standard Sequential Monte Carlo (SMC) algorithm.
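For readers unfamiliar with SMC, here is a minimal bootstrap particle filter in plain Python. It is not GenJAX's inference.smc implementation; the Gaussian random-walk model, multinomial resampling at every step, and the function names are assumptions chosen to keep the sketch short. It returns particle trajectories, their final log weights, and the log of the standard unbiased marginal-likelihood estimate.

```python
import math
import random
from typing import List, Tuple


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def bootstrap_smc(
    observations: List[float], n_particles: int, seed: int = 0
) -> Tuple[List[List[float]], List[float], float]:
    """Hypothetical model: x_0 ~ N(0, 1), x_t ~ N(x_{t-1}, 1), y_t ~ N(x_t, 1)."""
    rng = random.Random(seed)
    trajs: List[List[float]] = [[] for _ in range(n_particles)]
    log_w = [0.0] * n_particles
    log_z = 0.0
    for t, y in enumerate(observations):
        if t > 0:
            # Multinomial resampling proportional to the previous weights.
            lse = logsumexp(log_w)
            probs = [math.exp(w - lse) for w in log_w]
            idx = rng.choices(range(n_particles), weights=probs, k=n_particles)
            trajs = [list(trajs[i]) for i in idx]
            log_w = [0.0] * n_particles
        for i in range(n_particles):
            prev = trajs[i][-1] if trajs[i] else 0.0
            x = rng.gauss(prev, 1.0)  # propose from the transition prior
            trajs[i].append(x)
            log_w[i] += normal_logpdf(y, x, 1.0)  # weight by the observation likelihood
        # Accumulate log of the mean incremental weight.
        log_z += logsumexp(log_w) - math.log(n_particles)
    return trajs, log_w, log_z
```

With a single observation the model's marginal likelihood is N(y; 0, sqrt(2)), which gives an easy sanity check.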
run_csmc (abstract method)
Run the Conditional Sequential Monte Carlo (CSMC) algorithm.
Ensures one particle follows the retained trajectory while maintaining proper importance weighting.
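A conditional SMC sketch under the same assumptions (plain Python, a hypothetical Gaussian random-walk model, bootstrap proposals, multinomial resampling; names are illustrative, not GenJAX API). Particle 0 is clamped to the retained trajectory at every step and always keeps itself as ancestor, while the other particles are proposed and resampled as usual; every particle, including the retained one, is weighted by the observation likelihood.

```python
import math
import random
from typing import List, Tuple


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def conditional_smc(
    observations: List[float], retained: List[float], n_particles: int, seed: int = 0
) -> Tuple[List[List[float]], List[float], float]:
    """Hypothetical model: x_0 ~ N(0, 1), x_t ~ N(x_{t-1}, 1), y_t ~ N(x_t, 1)."""
    rng = random.Random(seed)
    trajs: List[List[float]] = [[] for _ in range(n_particles)]
    log_w = [0.0] * n_particles
    log_z = 0.0
    for t, y in enumerate(observations):
        if t > 0:
            lse = logsumexp(log_w)
            probs = [math.exp(w - lse) for w in log_w]
            # Particle 0 keeps its own ancestor; particles 1..N-1 resample freely.
            idx = [0] + rng.choices(range(n_particles), weights=probs, k=n_particles - 1)
            trajs = [list(trajs[i]) for i in idx]
            log_w = [0.0] * n_particles
        for i in range(n_particles):
            prev = trajs[i][-1] if trajs[i] else 0.0
            # Particle 0 is set to the retained value instead of being sampled.
            x = retained[t] if i == 0 else rng.gauss(prev, 1.0)
            trajs[i].append(x)
            log_w[i] += normal_logpdf(y, x, 1.0)
        log_z += logsumexp(log_w) - math.log(n_particles)
    return trajs, log_w, log_z
```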
Target
ImportanceSampling
Bases: SMCAlgorithm, Pytree
Importance sampling as an SPDistribution.
Samples from a target distribution using a proposal distribution and importance weighting.
random_weighted
Sample using vectorized importance sampling.
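One common way to turn importance sampling into a weighted sampler is sketched below, assuming the construction used for SP-style importance sampling: draw N particles from the proposal, resample one in proportion to its weight, and return it with log p~(x) minus the log of the mean weight. This is plain Python with hypothetical names (is_random_weighted, log_target, proposal_sample); GenJAX vectorizes these steps with JAX and may differ in detail.

```python
import math
import random
from typing import Callable, List, Tuple


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def is_random_weighted(
    log_target: Callable[[float], float],  # log of the (possibly unnormalized) target density
    proposal_sample: Callable[[random.Random], float],
    proposal_logpdf: Callable[[float], float],
    n_particles: int,
    rng: random.Random,
) -> Tuple[float, float]:
    xs = [proposal_sample(rng) for _ in range(n_particles)]
    log_ws = [log_target(x) - proposal_logpdf(x) for x in xs]  # log importance weights
    lse = logsumexp(log_ws)
    log_z_hat = lse - math.log(n_particles)  # log of the mean weight
    probs = [math.exp(w - lse) for w in log_ws]
    j = rng.choices(range(n_particles), weights=probs, k=1)[0]  # resample one particle
    return xs[j], log_target(xs[j]) - log_z_hat
```

When the proposal equals a normalized target, every weight is 1 and the returned weight reduces to the exact log density.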
estimate_logpdf
Estimate the log probability using vectorized importance sampling.
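The matching density estimate can be sketched as conditional importance sampling: the queried value x occupies one particle slot, the remaining N - 1 particles come from the proposal, and the result is log p~(x) minus the log of the mean weight. As before, this is plain Python with hypothetical names, an assumed formulation rather than the GenJAX code.

```python
import math
import random
from typing import Callable, List


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def is_estimate_logpdf(
    x: float,
    log_target: Callable[[float], float],  # log of the (possibly unnormalized) target density
    proposal_sample: Callable[[random.Random], float],
    proposal_logpdf: Callable[[float], float],
    n_particles: int,
    rng: random.Random,
) -> float:
    # x is placed in one particle slot; the other n_particles - 1 are fresh proposals.
    others = [proposal_sample(rng) for _ in range(n_particles - 1)]
    log_ws = [log_target(v) - proposal_logpdf(v) for v in [x] + others]
    log_z_hat = logsumexp(log_ws) - math.log(n_particles)  # log of the mean weight
    return log_target(x) - log_z_hat
```

If the target is a normalized density scaled by a constant and the proposal is that normalized density, the constant cancels and the estimate is exact.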
run_smc
Run the SMC algorithm using the existing importance sampling implementation.
Bridges to the inference.smc.init functionality in GenJAX.
run_csmc
Run the Conditional SMC algorithm with a retained particle.
Uses the conditional SMC functionality from the inference.smc module.
Marginal
Bases: SPDistribution, Pytree
Marginal distribution over a specific address using an SMC algorithm.
Following the GenSP.jl design, Marginal is parameterized by an algorithm that handles the actual inference, while Marginal itself specifies which address to marginalize.
Returns the value at the specified address extracted from algorithm samples.
random_weighted
Sample from the marginal distribution using the algorithm.
estimate_logpdf
Estimate the marginal log probability using the algorithm.
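The marginalization pattern can be sketched for a hypothetical two-address model, z ~ Normal(0, 1) at "z" and x ~ Normal(z, 1) at "x", using the prior over "z" as the proposal. The model, the address names, and the proposal choice are assumptions; Marginal delegates these steps to its algorithm instead of hard-coding them. Sampling keeps only the value at "x", and the density estimate averages p(x | z) over prior draws of z.

```python
import math
import random
from typing import Dict, List, Tuple


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs: List[float]) -> float:
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def joint_simulate(rng: random.Random) -> Dict[str, float]:
    # Hypothetical model: z ~ Normal(0, 1) at "z", then x ~ Normal(z, 1) at "x".
    z = rng.gauss(0.0, 1.0)
    x = rng.gauss(z, 1.0)
    return {"z": z, "x": x}


def marginal_x_random_weighted(rng: random.Random) -> Tuple[float, float]:
    # Sample the full joint, then keep only the value at address "x".
    choices = joint_simulate(rng)
    # Given x, z is distributed as p(z | x), so exp(-weight) = 1 / p(x | z)
    # is an unbiased estimate of 1 / p(x) when the prior is the proposal for z.
    return choices["x"], normal_logpdf(choices["x"], choices["z"], 1.0)


def marginal_x_estimate_logpdf(x: float, n_particles: int, rng: random.Random) -> float:
    # Importance sampling over "z" with the prior as proposal:
    # mean_k p(x | z_k), z_k ~ p(z), is an unbiased estimate of p(x).
    log_ws = [normal_logpdf(x, rng.gauss(0.0, 1.0), 1.0) for _ in range(n_particles)]
    return logsumexp(log_ws) - math.log(n_particles)
```

For this model the exact marginal of "x" is Normal(0, sqrt(2)), so the estimate can be checked directly.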
get_selection
Create a Selection object from a choice map.
This function creates a Selection that matches all addresses present in the given choice map structure. It handles different types of choice maps used by various generative functions:
- None: Returns NoneSel (matches no addresses)
- dict: Returns selection matching all keys in the dictionary
- other: Returns AllSel (matches all addresses)
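The three-way dispatch described above can be sketched as follows. The classes are hypothetical stand-ins for GenJAX's NoneSel and AllSel (and for whatever selection the dict case builds); only top-level dictionary keys are handled, and nested choice maps are not.

```python
from dataclasses import dataclass
from typing import Any, FrozenSet, Union


@dataclass(frozen=True)
class NoneSelSketch:
    def matches(self, addr: str) -> bool:
        return False  # matches no addresses


@dataclass(frozen=True)
class AllSelSketch:
    def matches(self, addr: str) -> bool:
        return True  # matches every address


@dataclass(frozen=True)
class KeySelSketch:
    keys: FrozenSet[str]

    def matches(self, addr: str) -> bool:
        return addr in self.keys  # matches exactly the keys of the choice map


def get_selection_sketch(choice_map: Any) -> Union[NoneSelSketch, KeySelSketch, AllSelSketch]:
    if choice_map is None:
        return NoneSelSketch()
    if isinstance(choice_map, dict):
        return KeySelSketch(frozenset(choice_map.keys()))
    return AllSelSketch()
```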
importance_sampling
importance_sampling(target: Target, proposal: Optional[GFI[X, Any]] = None, n_particles: int = 100) -> ImportanceSampling
Create an importance sampling SPDistribution.
marginal
Create a marginal distribution over a specific address using an algorithm.
Following the GenSP.jl design, the marginal is parameterized by an algorithm.