genjax.inference.smc¶
Sequential Monte Carlo methods for particle-based inference.
smc¶
Standard library of programmable inference algorithms for GenJAX.
This module provides implementations of common inference algorithms that can be composed with generative functions through the GFI (Generative Function Interface). Uses GenJAX distributions and modular_vmap for efficient vectorized computation.
References
[1] P. Del Moral, A. Doucet, and A. Jasra, "Sequential Monte Carlo samplers," Journal of the Royal Statistical Society: Series B (Statistical Methodology), vol. 68, no. 3, pp. 411–436, 2006.
ParticleCollection¶
Bases: Pytree
Result of importance sampling containing traces, weights, and statistics.
log_marginal_likelihood¶
Estimate log marginal likelihood using importance sampling.
Source code in src/genjax/inference/smc.py
estimate¶
Compute weighted estimate of a function applied to particle traces.
Properly accounts for importance weights to give unbiased estimates.
Examples:
>>> import jax.numpy as jnp
>>> # particles.estimate(lambda choices: choices["param"]) # Posterior mean
>>> # particles.estimate(lambda choices: choices["param"]**2) - mean**2 # Variance
>>> # particles.estimate(lambda choices: jnp.sin(choices["x"]) + choices["y"]) # Custom
Source code in src/genjax/inference/smc.py
effective_sample_size¶
Compute the effective sample size (ESS) from log importance weights.
The ESS measures the efficiency of importance sampling by estimating the number of independent samples that would provide equivalent statistical information. It quantifies particle degeneracy in SMC algorithms.
Mathematical Formulation
Given N particles with unnormalized weights w₁, ..., wₙ:
ESS = (Σᵢ wᵢ)² / Σᵢ wᵢ²
For normalized weights, Σᵢ wᵢ = 1 and this reduces to:
ESS = 1 / Σᵢ wᵢ²
Interpretation
- ESS = N: Perfect sampling (uniform weights)
- ESS = 1: Complete degeneracy (single particle has all weight)
- ESS/N: Efficiency ratio, often used to trigger resampling when < 0.5
Connection to Importance Sampling
The ESS approximates the variance inflation factor for importance sampling estimates. For the self-normalized importance sampling estimator f̂ of 𝔼_π[f]:
Var[f̂] ≈ (N / ESS) × Var_π[f] / N = Var_π[f] / ESS
where π is the target distribution; the estimator behaves like an average over ESS independent draws from π.
References
[1] Kong, A., Liu, J. S., & Wong, W. H. (1994). "Sequential imputations and Bayesian missing data problems". Journal of the American Statistical Association, 89(425), 278-288.
[2] Liu, J. S. (2001). "Monte Carlo strategies in scientific computing". Springer, Chapter 3.
[3] Doucet, A., de Freitas, N., & Gordon, N. (2001). "Sequential Monte Carlo methods in practice". Springer, Chapter 1.
Notes
- Computed in log-space for numerical stability
- Input weights need not be normalized (handled internally)
- Common resampling threshold: ESS < N/2 (Doucet et al., 2001)
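For illustration, a minimal log-space ESS computation in JAX, mirroring the formulation above (a sketch, not the module's exact source):
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def ess(log_weights):
    # Normalize in log space: log w̄ᵢ = log wᵢ - logsumexp(log w)
    log_norm = log_weights - logsumexp(log_weights)
    # ESS = 1 / Σᵢ w̄ᵢ², computed stably as exp(-logsumexp(2 · log w̄))
    return jnp.exp(-logsumexp(2.0 * log_norm))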
Source code in src/genjax/inference/smc.py
systematic_resample¶
Systematic resampling from importance weights with minimal variance.
Implements the systematic resampling algorithm (Kitagawa, 1996), which has lower variance than multinomial resampling while maintaining unbiasedness. This is the preferred resampling method for particle filters.
Mathematical Formulation
Given normalized weights w₁, ..., wₙ and cumulative sum Cᵢ = Σⱼ≤ᵢ wⱼ:
- Draw U ~ Uniform(0, 1/M) where M is the output sample size
- For i = 1, ..., M:
- Set pointer position: uᵢ = (i-1)/M + U
- Select particle: Iᵢ = min{j : Cⱼ ≥ uᵢ}
Properties
- Unbiased: 𝔼[Nᵢ] = M × wᵢ where Nᵢ is count of particle i
- Lower variance than multinomial: Var[Nᵢ] ≤ M × wᵢ × (1 - wᵢ)
- Deterministic given U: reduces Monte Carlo variance
- Preserves particle order (stratified structure)
Time complexity: O(N + M) with a single sweep over the sorted pointers (per-pointer binary search gives O(N + M log N) instead). Space complexity: O(N) for the cumulative weights.
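For illustration, a self-contained JAX sketch of the index-selection step, following the formulation above (not the module's exact source):
import jax
import jax.numpy as jnp

def systematic_indices(key, log_weights, m):
    w = jax.nn.softmax(log_weights)          # normalized weights wᵢ
    cum = jnp.cumsum(w)                      # cumulative sums Cᵢ
    u = jax.random.uniform(key, minval=0.0, maxval=1.0 / m)
    pointers = u + jnp.arange(m) / m         # uᵢ = (i-1)/M + U
    return jnp.searchsorted(cum, pointers)   # Iᵢ = min{j : Cⱼ ≥ uᵢ}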
References
[1] Kitagawa, G. (1996). "Monte Carlo filter and smoother for non-Gaussian nonlinear state space models". Journal of Computational and Graphical Statistics, 5(1), 1-25.
[2] Doucet, A., & Johansen, A. M. (2009). "A tutorial on particle filtering and smoothing: Fifteen years later". Handbook of Nonlinear Filtering, 12(656-704), 3.
[3] Hol, J. D., Schön, T. B., & Gustafsson, F. (2006). "On resampling algorithms for particle filters". In IEEE Nonlinear Statistical Signal Processing Workshop (pp. 79-82).
Notes
- Systematic resampling is preferred over multinomial for most applications
- Maintains particle diversity better than multinomial resampling
- For theoretical analysis of resampling methods, see [3]
Source code in src/genjax/inference/smc.py
resample_vectorized_trace¶
resample_vectorized_trace(trace: Trace[X, R], log_weights: ndarray, n_samples: int, method: str = 'categorical') -> Trace[X, R]
Resample a vectorized trace using importance weights.
Uses categorical or systematic sampling to select indices and jax.tree_util.tree_map to index into the Pytree leaves.
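As a sketch of the indexing step (idx is a hypothetical array of selected indices, and every leaf of the vectorized trace is assumed to share a leading particle dimension):
import jax

resampled_trace = jax.tree_util.tree_map(lambda leaf: leaf[idx], trace)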
Source code in src/genjax/inference/smc.py
init¶
init(target_gf: GFI[X, R], target_args: tuple, n_samples: Const[int], constraints: X, proposal_gf: GFI[X, Any] | None = None) -> ParticleCollection
Initialize particle collection using importance sampling.
Uses either the target's default internal proposal or a custom proposal. Proposals use signature (constraints, *target_args).
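A minimal usage sketch against the signature above; the toy model and the genjax imports (gen, normal, const) are assumptions about the surrounding API:
from genjax import gen, normal, const
from genjax.inference.smc import init

@gen
def model():
    mu = normal(0.0, 1.0) @ "mu"    # latent
    return normal(mu, 0.5) @ "y"    # observed

particles = init(model, (), const(1000), {"y": 2.0})
lml = particles.log_marginal_likelihood()  # member documented above; call form assumed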
Source code in src/genjax/inference/smc.py
change¶
change(particles: ParticleCollection, new_target_gf: GFI[X, R], new_target_args: tuple, choice_fn: Callable[[X], X]) -> ParticleCollection
Change target move for particle collection.
Translates particles from one model to another by:
1. Mapping each particle's choices using choice_fn
2. Using generate with the new model to get new weights
3. Accumulating importance weights
Choice Function Specification
CRITICAL: choice_fn must be a bijection on address space only.
- If X is a scalar type (e.g., float): Must be identity function
- If X is dict[str, Any]: May remap keys but CANNOT modify values
- Values must be preserved exactly to maintain probability density
Valid Examples:
- lambda x: x (identity mapping)
- lambda d: {"new_key": d["old_key"]} (key remapping)
- lambda d: {"mu": d["mean"], "sigma": d["std"]} (multiple key remap)
Invalid Examples:
- lambda x: x + 1 (modifies scalar values - breaks assumptions)
- lambda d: {"key": d["key"] * 2} (modifies dict values - breaks assumptions)
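A usage sketch of change with a key-remapping choice_fn; new_model is a hypothetical target whose address "mu" replaces the old "mean":
from genjax.inference.smc import change

particles = change(
    particles,
    new_model,                    # hypothetical new target
    (),                           # its arguments
    lambda d: {"mu": d["mean"]},  # remap keys, preserve values exactly
)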
Source code in src/genjax/inference/smc.py
extend¶
extend(particles: ParticleCollection, extended_target_gf: GFI[X, R], extended_target_args: Any, constraints: X, extension_proposal: GFI[X, Any] | None = None) -> ParticleCollection
Extension move for particle collection.
Extends each particle by generating from the extended target model:
1. Without extension proposal: uses the extended target's generate with constraints directly
2. With extension proposal: samples the extension, merges it with constraints, then generates
The extended target model is responsible for recognizing and incorporating existing particle state through its internal structure.
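A usage sketch against the signature above; extended_model, the timestep argument, and the observation address are placeholders:
from genjax.inference.smc import extend

particles = extend(
    particles,
    extended_model,      # hypothetical extended target
    (t,),                # e.g. the next timestep
    {"obs": y_t},        # new observation supplied as constraints
)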
Source code in src/genjax/inference/smc.py
rejuvenate¶
rejuvenate(particles: ParticleCollection, mcmc_kernel: Callable[[Trace[X, R]], Trace[X, R]]) -> ParticleCollection
Rejuvenate move for particle collection.
Applies an MCMC kernel to each particle independently to improve particle diversity and reduce degeneracy. The importance weights and diagnostic weights remain unchanged due to detailed balance.
Mathematical Foundation
For an MCMC kernel satisfying detailed balance, the log incremental weight is 0:
log_incremental_weight = log[p(x_new | args) / p(x_old | args)] + log[q(x_old | x_new) / q(x_new | x_old)]
Where:
- p(x_new | args) / p(x_old | args) is the model density ratio
- q(x_old | x_new) / q(x_new | x_old) is the proposal density ratio
Detailed balance ensures: p(x_old) * q(x_new | x_old) = p(x_new) * q(x_old | x_new)
Therefore: p(x_new) / p(x_old) = q(x_new | x_old) / q(x_old | x_new)
The model density ratio and proposal density ratio exactly cancel: log[p(x_new) / p(x_old)] + log[q(x_old | x_new) / q(x_new | x_old)] = 0
This means the importance weight contribution from the MCMC move is 0, preserving the particle weights while improving sample diversity.
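A usage sketch; mh_move stands in for whichever MCMC kernel is available (only its Trace -> Trace shape matters here):
from genjax.inference.smc import rejuvenate

def kernel(trace):
    return mh_move(trace)    # hypothetical MCMC kernel call

particles = rejuvenate(particles, kernel)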
Source code in src/genjax/inference/smc.py
resample¶
Resample particle collection to combat degeneracy.
Computes log normalized weights for diagnostics before resampling. After resampling, weights are reset to uniform (zero in log space) and the marginal likelihood estimate is updated to include the average weight before resampling.
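A sketch of the common ESS-triggered pattern; the log_weights attribute and the resample call form are assumptions about the API:
import jax

ess = effective_sample_size(particles.log_weights)  # attribute name assumed
particles = jax.lax.cond(
    ess < n_particles / 2,   # common threshold (Doucet et al., 2001)
    resample,                # call form assumed: ParticleCollection -> ParticleCollection
    lambda p: p,
    particles,
)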
Source code in src/genjax/inference/smc.py
rejuvenation_smc¶
rejuvenation_smc(model: GFI[X, R], transition_proposal: GFI[X, Any] | None = None, mcmc_kernel: Const[Callable[[Trace[X, R]], Trace[X, R]]] | None = None, observations: X | None = None, initial_model_args: tuple | None = None, n_particles: Const[int] = const(1000), return_all_particles: Const[bool] = const(False), n_rejuvenation_moves: Const[int] = const(1)) -> ParticleCollection
Complete SMC algorithm with rejuvenation using jax.lax.scan.
Implements sequential Monte Carlo with particle extension, resampling, and MCMC rejuvenation. Uses a single model with feedback loop where the return value becomes the next timestep's arguments, creating sequential dependencies.
Note on Return Value
By default (return_all_particles=const(False)), this function returns only the final ParticleCollection after processing all observations; intermediate timesteps are computed but not returned. Passing return_all_particles=const(True) returns the per-timestep collections instead, which internally corresponds to:
final_particles, all_particles = jax.lax.scan(smc_step, particles, remaining_obs)
return all_particles  # vectorized ParticleCollection with time dimension
The all-timesteps result carries an additional leading time dimension in all its fields (traces, log_weights, etc.), giving access to the full particle trajectory across timesteps.
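A sketch requesting all timesteps via the flag in the signature above; model, kernel, and obs are placeholders:
from genjax import const
from genjax.inference.smc import rejuvenation_smc

all_particles = rejuvenation_smc(
    model,
    mcmc_kernel=const(kernel),
    observations=obs,                   # leading time dimension
    initial_model_args=(0.0,),
    n_particles=const(500),
    return_all_particles=const(True),
)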
Source code in src/genjax/inference/smc.py
init_csmc¶
init_csmc(target_gf: GFI[X, R], target_args: tuple, n_samples: Const[int], constraints: X, retained_choices: X, proposal_gf: GFI[X, Any] | None = None) -> ParticleCollection
Initialize particle collection for conditional SMC with retained particle.
Simple approach: run regular init and manually override particle 0.
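A usage sketch against the signature above; retained is a hypothetical choice map (e.g. from a previous trace) that particle 0 is pinned to:
from genjax import const
from genjax.inference.smc import init_csmc

particles = init_csmc(model, (), const(100), {"y": 2.0}, retained)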
Source code in src/genjax/inference/smc.py
extend_csmc¶
extend_csmc(particles: ParticleCollection, extended_target_gf: GFI[X, R], extended_target_args: Any, constraints: X, retained_choices: X, extension_proposal: GFI[X, Any] | None = None) -> ParticleCollection
Extension move for conditional SMC with retained particle.
Like extend() but ensures particle 0 follows retained trajectory.
Source code in src/genjax/inference/smc.py
Core Functions¶
importance_sampling¶
Basic importance sampling with multiple particles.
particle_filter¶
Sequential Monte Carlo for state-space models.
rejuvenation_smc¶
SMC with MCMC rejuvenation steps for better particle diversity.
Usage Examples¶
Importance Sampling¶
from genjax import const
from genjax.inference.smc import init

# Importance sampling with 1000 particles via the documented init()
particles = init(model, args, const(1000), constraints)

# Weighted posterior mean via the weight-aware estimate() method
posterior_mean = particles.estimate(lambda choices: choices["parameter"])
Particle Filter¶
from genjax import gen
from genjax.inference.smc import particle_filter

# For sequential data; distributions.normal follows the usage elsewhere
# in these docs and assumes a distributions namespace is in scope
@gen
def transition(prev_state, t):
    return distributions.normal(prev_state, 0.1) @ f"state_{t}"

@gen
def observation(state, t):
    return distributions.normal(state, 0.5) @ f"obs_{t}"

# Run the particle filter (call signature illustrative)
particles = particle_filter(
    initial_model,
    transition,
    observation,
    observations,
    n_particles=100,
    key=key,
)
SMC with Rejuvenation¶
from genjax import const
from genjax.inference.smc import rejuvenation_smc

# SMC with optional MCMC rejuvenation, matching the signature documented above
result = rejuvenation_smc(
    model,
    mcmc_kernel=const(kernel),       # optional: rejuvenation kernel
    observations=observations,
    n_particles=const(100),
    n_rejuvenation_moves=const(5),
)
Best Practices¶
- Particle Count: Use enough particles for stable estimates (typically 100-10000, depending on model dimensionality)
- Resampling: Monitor the effective sample size and resample when ESS/N falls below about 0.5
- Proposal Design: Prefer proposals close to the posterior; the target's internal proposal is the default
- Rejuvenation: Add MCMC moves after resampling to restore particle diversity