Abstract #
We present GenJAX, a new language and compiler for vectorized programmable probabilistic inference.
GenJAX integrates the vectorizing map (vmap) operation from array programming frameworks such as JAX
into the programmable inference paradigm, enabling compositional vectorization of features such as probabilistic program traces,
stochastic branching (for expressing mixture models), and programmable inference interfaces for writing custom probabilistic inference algorithms.
We formalize vectorization as a source-to-source program transformation on a core calculus for probabilistic programming ($\gen$), and
prove that it correctly vectorizes both modeling and inference operations.
We have implemented our approach in the GenJAX language and compiler,
and have empirically evaluated this implementation on several benchmarks and case studies. Our results show that our implementation
supports a wide and expressive set of programmable inference patterns and delivers performance comparable to hand-optimized JAX code.
Presentation #
Watch on YouTube: POPL 2026 Presentation
Introduction #
In recent years, probabilistic programming has demonstrated remarkable effectiveness in a range of application domains, including 3D perception and scene understanding, probabilistic robotics, automated data cleaning and analysis, particle physics, time series structure discovery, test-time control of large language models, and cognitive modeling of theory of mind. All of these applications require sophisticated probabilistic reasoning over complex, structured data and rely on probabilistic programming languages (PPLs) with programmable inference — the ability to customize probabilistic inference algorithms through proposals, kernels, and variational families — to improve the quality of posterior approximation. But fully realizing the benefits that probabilistic programming can deliver often requires substantial computational resources, as probabilistic inference scales by increasing the number of likelihood evaluations, sequential Monte Carlo particles, or Markov chain Monte Carlo chains.
We present GenJAX, a new language and compiler for vectorized programmable probabilistic inference.
GenJAX integrates the vectorizing map (vmap) operation from array programming frameworks such as JAX
into the context of probabilistic programming with programmable inference, enabling the compositional vectorization of features such as
probabilistic program traces, stochastic branching (for expressing mixture models), and programmable inference interfaces. This vectorization
enables the implementation of compute-intensive probabilistic programming and probabilistic inference operations on modern GPUs, making it
possible to deploy the substantial computational resources that GPUs provide to accelerate large-scale probabilistic inference.
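As background, JAX's vmap transforms a function written for single values into one that operates over a batch dimension. A minimal sketch (plain JAX, not GenJAX's API) showing the two vectorization directions the paper exploits — over data and over particles:

```python
import jax
import jax.numpy as jnp

# Log-density of a single Bernoulli observation under bias p.
def bernoulli_logpdf(p, x):
    return x * jnp.log(p) + (1.0 - x) * jnp.log1p(-p)

# vmap over the data axis: score many flips under one bias...
flips = jnp.array([1.0, 0.0, 1.0, 1.0])
per_flip = jax.vmap(bernoulli_logpdf, in_axes=(None, 0))(0.7, flips)

# ...and over the "particle" axis: score one flip under many biases.
biases = jnp.array([0.1, 0.5, 0.9])
per_particle = jax.vmap(bernoulli_logpdf, in_axes=(0, None))(biases, 1.0)

print(per_flip.shape)      # (4,)
print(per_particle.shape)  # (3,)
```

The `in_axes` argument selects which arguments are batched, which is what lets the same scoring code serve both modeling (map over data) and inference (map over particles or chains).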
Design Considerations
GenJAX is designed around the interaction between vmap and several probabilistic programming features that support the
implementation of sophisticated models and inference algorithms:
- Compositional vectorization. Our target class of probabilistic programs features multiple vectorizable computational patterns. Examples include computing likelihoods simultaneously on many pieces of data (as part of modeling) and evolving collections of particles or chains (as part of inference). Our integration of vmap therefore supports vectorization of both modeling and inference code.
- Vectorization of probabilistic program traces. Traces are structured record objects used to represent samples. They are a data lingua franca for Monte Carlo and variational inference. Under vectorization by vmap, traces support an efficient struct-of-array representation.
- Vectorized stochastic branching. Probabilistic mixture models, regime-switching dynamics models, and adaptive inference algorithms require stochastic branching on random values. GenJAX supports stochastic branching while maintaining vectorization.
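The struct-of-array trace representation can be illustrated with plain JAX pytrees (this toy `simulate` and dict-shaped "trace" are illustrative stand-ins, not GenJAX's actual trace type):

```python
import jax
import jax.numpy as jnp

# A toy "trace": a pytree recording each random choice and a score.
# (Illustrative only -- GenJAX's trace objects are richer.)
def simulate(key):
    k1, k2 = jax.random.split(key)
    slope = jax.random.normal(k1)
    noise = jax.random.normal(k2)
    return {"choices": {"slope": slope, "noise": noise},
            "score": -0.5 * (slope**2 + noise**2)}

keys = jax.random.split(jax.random.PRNGKey(0), 1000)
traces = jax.vmap(simulate)(keys)

# vmap batches each leaf, yielding a struct-of-arrays layout:
print(traces["choices"]["slope"].shape)  # (1000,)
print(traces["score"].shape)             # (1000,)
```

Rather than a Python list of 1000 record objects (array-of-structs), each field becomes one contiguous batched array, which is the layout GPUs process efficiently.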
Figure caption (design overview): vmap should be applicable in both settings. Center: Traces are records used to represent samples from probabilistic programs; both vectorized models and vectorized inference algorithms are designed to work with vectorized (struct-of-array) traces. Right: Probabilistic programs can branch on random values, and vmap of probabilistic programs should preserve this capability.

Figure caption (system overview): vmap applies to both generative models and inference algorithms; our system implements inference on a vectorized model by vectorizing inference applied to the model, which is justified by the commutativity corollary in the paper. Right: Survey of features in our language and compiler; usage of these features is illustrated in Section 2 and Section 7, and implementation is discussed in Section 6.

Contributions
- GenJAX: high-performance compiler (Section 6). GenJAX is an open-source compiler that extends JAX and vmap to support programmable probabilistic inference. Probabilistic programs in GenJAX can be systematically transformed to take advantage of opportunities for vectorization in both modeling and inference. Our compiler also eliminates the overhead present in many libraries for programmable inference: we implement simulation and density interfaces using lightweight effect handlers, and exploit JAX's support for program tracing to partially evaluate inference logic away at compile time, leaving only optimized array operations. Our design maintains full compatibility with JAX's underlying ecosystem for automatic differentiation and CPU/GPU/TPU compilation.
- Formal model: interaction between vmap and programmable inference features (Section 3). We develop a formal model characterizing how vmap interacts with probabilistic program traces and programmable inference interfaces. We introduce $\gen$, a calculus for probabilistic programming and programmable inference, on top of a core probabilistic array language for stochastic parallel computations. We define vmap as a program transformation, prove its correctness, and show how it interacts with programmable inference interfaces to support vectorization of probabilistic computations and traces.
- Empirical evaluation (Section 7). We evaluate our design and implementation through a series of benchmarks and case studies: performance comparison, where GenJAX achieves near-handcoded JAX performance and can outperform existing vectorized and high-performance PPLs and array programming frameworks (JAX, PyTorch, Pyro, NumPyro, and Gen); and high-dimensional vectorized inference, where we study approximate Game of Life inversion (find the previous 512 x 512 board state which leads to the observed state) and sequential 2D robot localization with simulated LIDAR measurements. In both case studies, we use GenJAX to develop sophisticated vectorized inference algorithms, including vectorized Gibbs sampling and sequential Monte Carlo with vectorized proposals. Our final GenJAX programs exhibit high approximation accuracy and run in milliseconds on consumer-grade GPUs.
Evaluation #
We evaluate our language and compiler implementation on benchmarks and case studies designed to assess the following criteria:
- Performance. How does the performance of our compiler implementation compare to leading programmable inference systems? Do our abstractions introduce overhead compared to handcoded implementations of inference? We survey the performance properties of GenJAX against open-source PPLs and tensor frameworks on standard modeling and inference tasks, for both embarrassingly parallel algorithms (importance sampling) and iterative differentiable algorithms (Hamiltonian Monte Carlo).
- Inference quality. vmap provides a convenient way to express inference problems over high-dimensional spaces. We study probabilistic Game of Life inversion on large boards using approximate inference and use GenJAX to construct an efficient nested vectorized Gibbs sampler. We also study a probabilistic model of robot localization using simulated LIDAR measurements and construct sequential Monte Carlo algorithms, including an efficient algorithm using proposals with vectorized locally optimal grid approximations.
Performance Survey Evaluation
The top panel examines the runtime characteristics of our compiler on importance sampling in a Beta-Bernoulli model. The model infers the bias of a coin from observed flips, using a Beta(1,1) prior and Bernoulli likelihood. We observe 50 flips and construct a posterior approximation using importance sampling. The results confirm that all frameworks accurately recover the true posterior distribution. GenJAX achieves near-identical performance to handcoded JAX (100.1% relative time).
The bottom panel presents performance results for importance sampling and Hamiltonian Monte Carlo (HMC) on the polynomial regression problem from the overview. Importance sampling exhibits parallel scaling with the number of particles: vectorized PPLs and tensor frameworks have near constant scaling while the GPU is not saturated. HMC is run iteratively, so scaling is linear in the length of the chain. GenJAX is consistently close to handcoded and optimized JAX, validating that our abstractions for programmable inference introduce minimal overhead.
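The Beta-Bernoulli benchmark above can be sketched in plain JAX (a minimal hand-vectorized baseline, not the GenJAX program; particle count and data split are illustrative):

```python
import jax
import jax.numpy as jnp

# Importance sampling for the Beta-Bernoulli model, prior as proposal.
def infer(key, flips, n_particles=10_000):
    # Beta(1,1) prior == Uniform(0,1): draw proposal particles.
    ps = jax.random.uniform(key, (n_particles,))
    # Vectorized log-likelihood of all flips under each particle.
    def loglik(p):
        return jnp.sum(flips * jnp.log(p) + (1 - flips) * jnp.log1p(-p))
    log_w = jax.vmap(loglik)(ps)
    # Self-normalized importance-sampling estimate of the posterior mean.
    w = jax.nn.softmax(log_w)
    return jnp.sum(w * ps)

flips = jnp.concatenate([jnp.ones(35), jnp.zeros(15)])  # 50 observed flips
post_mean = infer(jax.random.PRNGKey(0), flips)
# Conjugacy gives the exact posterior Beta(36, 16), mean 36/52.
```

The point of the benchmark is that the GenJAX program, written against the modeling and inference interfaces, compiles down to array code comparable to a handcoded kernel like this one.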
Probabilistic Game of Life Inversion
Game of Life inversion is the problem of inverting the dynamics of Conway's Game of Life: given a final state, what is a possible previous state that evolves to the final state under the rules of the game? Brute force discrete search is computationally intractable, requiring evaluation of $2^{N \times N}$ states, where $N$ is the linear dimension of a square Game of Life board. In this case study, we introduce probabilistic noise into the dynamics: from an initial state, we evolve forward using the deterministic rules, but then sample with Bernoulli noise around the true value of each pixel. We illustrate approximate inversion using vectorized Gibbs sampling.
Because each cell's value is conditionally independent of non-neighboring cells' values, given its eight neighbors, we partition the board's cells into conditionally independent groups (given the other groups). Within each group, we can perform parallel Gibbs updates on all the cells, an instance of chromatic Gibbs sampling. The result is a highly efficient probabilistic inversion algorithm which can invert Life states with up to 90% accuracy in a few seconds.
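The chromatic structure can be sketched in plain JAX. The stride-3 coloring below is one valid partition under the 3x3 output dependence of the Life rules (an assumption here; the paper's exact scheme may differ), and `toy_update` is a placeholder for the true conditional update given neighbors and the observed next state:

```python
import jax
import jax.numpy as jnp

N = 12  # board size (toy; the paper uses 512)

# Stride-3 coloring: cells of the same color never influence a common
# output cell under the 3x3 neighborhood, so they are conditionally
# independent and can be Gibbs-updated in parallel.
i, j = jnp.meshgrid(jnp.arange(N), jnp.arange(N), indexing="ij")
color = (i % 3) * 3 + (j % 3)  # 9 colors, values 0..8

def gibbs_sweep(key, board, update_cell):
    # One sweep: update each color group in turn; within a group,
    # all cells update simultaneously via a masked vectorized step.
    for c in range(9):
        key, sub = jax.random.split(key)
        mask = (color == c)
        proposed = update_cell(sub, board)  # batched over the board
        board = jnp.where(mask, proposed, board)
    return board

# Toy update: resample each cell as a fair coin (stands in for the
# true conditional distribution of the previous-state cell).
def toy_update(key, board):
    return jax.random.bernoulli(key, 0.5, board.shape).astype(board.dtype)

board = jnp.zeros((N, N), dtype=jnp.int32)
board = gibbs_sweep(jax.random.PRNGKey(0), board, toy_update)
```

Each color group's update is a single vectorized array operation over the whole board, which is why the sampler runs in seconds even at 512 x 512.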
Robot Localization
In robotics, simultaneous localization and mapping (SLAM) refers to the problem of constructing a representation of the map of an environment and the position of the robot within the map based on measurements. If the map is given, the problem is called localization: a robot maneuvers through a known space and receives measurements of distance to the walls. The goal is to construct a probabilistic representation of where the robot is located. In this case study, we use GenJAX to write a model for robot localization, with Gaussian drift dynamics and a simulated LIDAR measurement. Given a sequence of LIDAR measurements over time as observations, we constrain the model to produce a posterior over robot locations.
- Bootstrap filter. Sequential Monte Carlo where the prior (from the model) is used as the proposal for the latent position of the robot.
- SMC + HMC. Adds HMC moves to the bootstrap filter. These moves are applied to the particle collection after resampling.
- SMC + Locally Optimal. Uses a smart proposal based on enumerative grids: enumerate a grid in position space, evaluate each position against the observation likelihood, select the maximum likelihood grid point, and sample from a normal distribution around that point.
SMC supports natural vectorization over particles. In our experiments, the best algorithm from the standpoint of efficiency and accuracy is locally optimal SMC, which adds
another layer of vectorization within the custom proposal. The likelihood grid evaluations can be fully vectorized, and the model already features vectorization in the LIDAR
measurement model. These opportunities for vectorization (in the model, in the locally optimal proposal, and across the particle collection) are convenient to program against
with vmap and lead to a highly efficient inference algorithm which can accurately track the 2D robot's location within the map in milliseconds.
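The four steps of the locally optimal grid proposal can be sketched in plain JAX (hypothetical names throughout; `obs_loglik`, the grid size, and `sigma` stand in for the paper's LIDAR likelihood and tuning choices):

```python
import jax
import jax.numpy as jnp

def grid_proposal(key, center, obs_loglik, half_width=1.0, n=25, sigma=0.1):
    # 1. Enumerate an n x n grid of candidate positions around `center`.
    xs = jnp.linspace(-half_width, half_width, n)
    gx, gy = jnp.meshgrid(xs, xs, indexing="ij")
    grid = center + jnp.stack([gx.ravel(), gy.ravel()], axis=-1)  # (n*n, 2)
    # 2. Evaluate each candidate against the observation likelihood,
    #    fully vectorized with vmap.
    scores = jax.vmap(obs_loglik)(grid)
    # 3. Select the maximum-likelihood grid point.
    best = grid[jnp.argmax(scores)]
    # 4. Sample the proposed position from a normal around that point.
    return best + sigma * jax.random.normal(key, (2,))

# Toy observation likelihood peaked at (0.3, -0.2):
target = jnp.array([0.3, -0.2])
loglik = lambda p: -jnp.sum((p - target) ** 2) / (2 * 0.05**2)
pos = grid_proposal(jax.random.PRNGKey(0), jnp.zeros(2), loglik)
```

In the full algorithm this proposal is itself vmapped over the particle collection, giving the nested layers of vectorization described above.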
Conclusion #
This work presents GenJAX, a language and compiler for vectorized probabilistic programming with programmable inference. The system integrates vmap with
programmable inference features: we extend vmap support to generative functions, including vmap-based vectorization of probabilistic
program traces, stochastic branching, and programmable inference interfaces. Benchmarks show this approach yields low overhead relative to hand-optimized JAX
while delivering greater expressiveness and competitive performance compared to other probabilistic programming systems targeting modern accelerators.
Future Work
- Vectorized inference diagnostics. By automating the vectorized implementation of nested models and inference algorithms, GenJAX makes it easy to experiment with parallel implementations of custom Monte Carlo estimators of a broad range of information-theoretic quantities derived from probabilistic programs, including KL divergence between inference algorithms and the conditional mutual information among subsets of latent variables. Although computationally intensive on CPUs, these estimators consist of nested, massively parallel computations, and may become more practical and widespread given suitable automation.
- Spatial or geometric probabilistic programs. We expect that GenJAX's support for array programming and programmable probabilistic inference may be well-suited for spatial computing applications. Domains such as robotics, autonomous navigation, computational imaging, and scientific simulation increasingly require sophisticated probabilistic reasoning over high-dimensional spatial data — including LiDAR point clouds, depth images, and other spatial data types. Probabilistic programming applications in these domains naturally involve computations that manipulate multi-dimensional arrays. GenJAX's design is uniquely suited to support practitioners writing these types of probabilistic programs and provides useful vectorization automation and support for compilation to efficient GPU implementations.
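As one example of the diagnostics mentioned in the first item, a standard vectorized Monte Carlo estimator of a KL divergence from log-density evaluators can be written directly with vmap (a textbook estimator, not the paper's specific construction; all names are illustrative):

```python
import jax
import jax.numpy as jnp

# Monte Carlo estimate of KL(q || p) = E_q[log q(x) - log p(x)],
# with the sampling and scoring both vectorized via vmap.
def kl_estimate(key, sample_q, logpdf_q, logpdf_p, n=100_000):
    xs = jax.vmap(sample_q)(jax.random.split(key, n))
    return jnp.mean(jax.vmap(logpdf_q)(xs) - jax.vmap(logpdf_p)(xs))

# Two unit-variance Gaussians: KL(N(0,1) || N(1,1)) = 0.5 exactly.
logpdf = lambda mu: lambda x: -0.5 * (x - mu) ** 2 - 0.5 * jnp.log(2 * jnp.pi)
kl = kl_estimate(jax.random.PRNGKey(0),
                 lambda k: jax.random.normal(k),
                 logpdf(0.0), logpdf(1.0))
# Close to the analytic value 0.5.
```

For nested estimators (e.g. KL between inference algorithms, where the inner densities are themselves Monte Carlo estimates), the inner computation is vmapped as well, which is where GPU vectorization pays off most.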
Artifact & Data Availability #
The artifact associated with this paper is available on Zenodo. The source code is available at https://github.com/probcomp/genjax.
Citation #
@article{becker2026genjax,
title = {Probabilistic Programming with Vectorized Programmable Inference},
author = {Becker, McCoy R. and Huot, Mathieu and Matheos, George and
Wang, Xiaoyan and Chung, Karen and Smith, Colin and
Ritchie, Sam and Saurous, Rif A. and Lew, Alexander K. and
Rinard, Martin C. and Mansinghka, Vikash K.},
journal = {Proceedings of the ACM on Programming Languages},
volume = {10},
number = {POPL},
articleno = {87},
year = {2026},
publisher = {ACM},
doi = {10.1145/3776729},
}