genjax.pjax¶
Probabilistic JAX (PJAX) - foundational probabilistic programming primitives.
pjax
¶
PJAX: Probabilistic JAX
This module implements PJAX (Probabilistic JAX), which extends JAX with probabilistic primitives and specialized interpreters for handling probabilistic computations.
PJAX provides the foundational infrastructure for GenJAX's probabilistic programming capabilities by introducing:
- Probabilistic Primitives: Custom JAX primitives (sample_p, log_density_p) that represent random sampling and density evaluation operations.
- JAX-aware Interpreters: Specialized interpreters that handle probabilistic primitives while preserving JAX's transformation semantics:
  - Seed: Eliminates PJAX's sampling primitive in favor of JAX PRNG implementations
  - ModularVmap: Vectorizes probabilistic computations
- Staging Infrastructure: Tools for converting Python functions to JAX's intermediate representation (Jaxpr) while preserving probabilistic semantics.
Key Concepts
- Sample Primitive (sample_p): Represents random sampling operations in Jaxpr
- Seed Transformation: Converts probabilistic functions to accept explicit keys
- Modular Vmap: Vectorizes probabilistic functions while preserving semantics
- Elaborated Primitives: Enhanced primitives with metadata for pretty printing
Keyful Sampler Contract
All samplers used with PJAX must follow this signature contract:
def keyful_sampler(key: PRNGKey, *args, sample_shape: tuple[int, ...], **kwargs) -> Array:
    '''Sample from a distribution.

    Args:
        key: JAX PRNGKey for randomness
        *args: Distribution parameters (positional)
        sample_shape: REQUIRED keyword argument specifying sample shape
        **kwargs: Additional distribution parameters (keyword)

    Returns:
        Array with shape sample_shape + distribution.event_shape
    '''
The sample_shape parameter is REQUIRED and must be accepted as a keyword argument. This ensures compatibility with JAX/TensorFlow Probability conventions.
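As an illustration of the contract, the following minimal sketch writes a conforming sampler with plain jax.random instead of TensorFlow Probability. The name normal_keyful_sampler is hypothetical and not part of GenJAX; the sketch only assumes that sample_shape is prepended to the broadcast shape of the parameters.

```python
import jax
import jax.numpy as jnp


def normal_keyful_sampler(key, loc, scale, *, sample_shape=(), **kwargs):
    """Draw Normal(loc, scale) samples; hypothetical example, not GenJAX code."""
    loc = jnp.asarray(loc, dtype=jnp.float32)
    scale = jnp.asarray(scale, dtype=jnp.float32)
    # The parameters broadcast against each other to give the batch shape.
    batch_shape = jnp.broadcast_shapes(loc.shape, scale.shape)
    # sample_shape is prepended to the batch shape.
    noise = jax.random.normal(key, shape=tuple(sample_shape) + batch_shape)
    return loc + scale * noise


key = jax.random.PRNGKey(0)
single = normal_keyful_sampler(key, 0.0, 1.0, sample_shape=())
batch = normal_keyful_sampler(key, 0.0, 1.0, sample_shape=(4, 2))
repeat = normal_keyful_sampler(key, 0.0, 1.0, sample_shape=(4, 2))
```

Because all randomness comes from the key argument, calling the sampler twice with the same key and arguments returns the same array.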
Usage
from genjax.pjax import seed, modular_vmap, sample_binder
import tensorflow_probability.substrates.jax as tfp

# Create a keyful sampler following the contract
def normal_sampler(key, loc, scale, sample_shape=(), **kwargs):
    dist = tfp.distributions.Normal(loc, scale)
    return dist.sample(seed=key, sample_shape=sample_shape)

# Bind to PJAX primitive
normal = sample_binder(normal_sampler, name="normal")

# Transform probabilistic function to accept explicit keys
# (probabilistic_function, key, args, and batched_args are placeholders)
seeded_fn = seed(probabilistic_function)
result = seeded_fn(key, args)

# Vectorize probabilistic computations
vmap_fn = modular_vmap(probabilistic_function, in_axes=(0,))
results = vmap_fn(batched_args)
Technical Details
PJAX works by representing sampling and density evaluation as JAX primitives that can be interpreted differently depending on the transformation applied. The seed transformation eliminates the sampling primitive by providing explicit randomness, while modular_vmap preserves both primitives for probability-aware vectorization.
References
- JAX Primitives: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
- GenJAX Documentation: See src/genjax/CLAUDE.md for PJAX usage patterns
sample_p
module-attribute
¶
Core primitive representing random sampling operations.
sample_p is the fundamental primitive in PJAX that represents the act of drawing a random sample from a probability distribution. It appears in Jaxpr when probabilistic functions are staged, and different interpreters handle it in different ways:
- Seed: Replaces it with actual sampling using the provided PRNG key
- ModularVmap: Vectorizes the sampling operation
- Standard JAX: Raises warning/exception (requires transformation)
The primitive carries metadata about the sampler function, distribution parameters, and optional support constraints.
log_density_p
module-attribute
¶
Core primitive representing log-density evaluation operations.
log_density_p represents the evaluation of log probability density at a given value. This is dual to sample_p: while sample_p generates samples, log_density_p evaluates how likely those samples are under the distribution.
Used primarily in:
- Density evaluation for inference algorithms
- Gradient computation in variational methods
- Importance weight calculations in SMC
enforce_lowering_exception
module-attribute
¶
Whether to raise exceptions when sample_p primitives reach MLIR lowering.
When True, attempting to compile probabilistic functions (e.g., with jax.jit) without first applying seed() will raise a LoweringSamplePrimitiveToMLIRException.
This prevents silent errors where PRNG keys get baked into compiled code.
Set to False for debugging or if you want warnings instead of exceptions.
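The "baked-in key" failure mode can be reproduced in plain JAX, without any PJAX primitive. The sketch below is an analogy rather than PJAX's actual lowering path: a key captured by closure becomes a constant of the compiled function, so every call returns the same draw, whereas a key passed as an argument does not.

```python
import jax

key = jax.random.PRNGKey(0)


@jax.jit
def baked_in_draw():
    # The key is captured by closure, so it is a constant of the compiled code.
    return jax.random.normal(key)


@jax.jit
def explicit_draw(k):
    # The key is a traced argument, so each call can receive fresh randomness.
    return jax.random.normal(k)


first = baked_in_draw()
second = baked_in_draw()  # identical to `first`: the randomness is frozen

k1, k2 = jax.random.split(key)
draw1 = explicit_draw(k1)
draw2 = explicit_draw(k2)  # differs from draw1 because the key differs
```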
lowering_warning
module-attribute
¶
Whether to show warnings when sample_p primitives reach MLIR lowering.
When True, shows warning messages instead of raising exceptions when probabilistic primitives reach compilation without proper transformation. Generally, exceptions (enforce_lowering_exception=True) are preferred as they prevent subtle bugs.
InitialStylePrimitive
¶
Bases: Primitive
JAX primitive with configurable transformation implementations.
This class extends JAX's Primitive to provide a convenient way to define custom primitives where transformation semantics are provided at the binding site rather than registration time. This is essential for PJAX's dynamic primitive creation where the same primitive can have different behaviors depending on the probabilistic context.
The primitive expects implementations for JAX transformations to be provided as parameters during initial_style_bind(...) calls using these keys:
Transformation Keys
- impl: Evaluation semantics - the concrete implementation that executes when the primitive is evaluated with concrete values
- abstract: Abstract semantics - used by JAX when tracing a Python program to a Jaxpr; determines output shapes and dtypes from input abstract values
- jvp: Forward-mode automatic differentiation - defines how to compute Jacobian-vector products for this primitive
- batch: Vectorization semantics for vmap - defines how the primitive behaves when vectorized over a batch dimension
- lowering: Compilation semantics for jit - defines how to lower the primitive to MLIR for XLA compilation
Technical Details
Unlike standard JAX primitives where transformation rules are registered once, InitialStylePrimitive defers all rule definitions to binding time. The primitive acts as a parameterizable template where transformation semantics are injected dynamically, enabling PJAX's context-dependent reinterpretation of probabilistic operations.
Example
my_primitive = InitialStylePrimitive("my_op")

# Transformation semantics provided at binding time
result = my_primitive.bind(
    inputs,
    impl=lambda x: x + 1,  # Evaluation: add 1
    abstract=lambda aval: aval,  # Same shape/dtype
    jvp=lambda primals, tangents: (primals[0] + 1, tangents[0]),
    batch=lambda args, dim: (args[0] + 1,),  # Vectorized add
    lowering=my_lowering_rule,
)
Source code in src/genjax/pjax.py
PPPrimitive
¶
Bases: Primitive
A primitive wrapper that hides metadata from JAX's Jaxpr pretty printer.
PPPrimitive (Pretty Print Primitive) wraps an underlying InitialStylePrimitive and stores metadata parameters in a hidden field to prevent them from cluttering JAX's Jaxpr pretty printer output. PJAX's probabilistic primitives carry complex metadata (samplers, distributions, transformation rules, etc.) that would make Jaxpr representations unreadable if displayed.
The wrapper:
- Stores the underlying primitive and its parameters in a private field
- Hides metadata from JAX's Jaxpr pretty printer
- Acts as a transparent proxy for all JAX transformations
- Preserves all transformation behavior of the wrapped primitive
Technical Details
When JAX creates a Jaxpr representation, it only shows the primitive name and visible parameters. By storing metadata in the PPPrimitive's internal state rather than as binding parameters, we get clean Jaxpr output while preserving all the functionality and metadata needed for transformations.
Example
base_prim = InitialStylePrimitive("sample")
# Without PPPrimitive: cluttered Jaxpr with all metadata visible
# sample[impl=<function>, abstract=<function>, distribution="normal", ...]
# With PPPrimitive: clean Jaxpr output
pretty_prim = PPPrimitive(base_prim, distribution="normal", name="x")
# Jaxpr shows: sample
result = pretty_prim.bind(args, mu=0.0, sigma=1.0)
Source code in src/genjax/pjax.py
TerminalStyle
¶
ANSI terminal styling for pretty-printed primitives.
LoweringSamplePrimitiveToMLIRException
¶
Bases: Exception
Exception raised when PJAX sample_p primitives reach MLIR lowering.
This exception occurs when probabilistic functions containing sample_p primitives are passed to JAX transformations (like jit, grad, vmap) without first applying the seed() transformation. This prevents silent errors where PRNG keys get baked into compiled code, leading to deterministic behavior.
The exception includes execution context to help identify where the problematic binding occurred in the user's code.
Initialize the exception with lowering message and binding context.
Source code in src/genjax/pjax.py
SamplerConfig
dataclass
¶
SamplerConfig(keyful_sampler: Callable[..., Any], name: str | None = None, sample_shape: tuple[int, ...] = (), support: Callable[..., Any] | None = None)
Configuration for a probabilistic sampler.
Encapsulates all the information needed to create and execute a sampler, making the relationships between different components explicit.
Keyful Sampler Contract
The keyful_sampler must follow this exact signature:
def keyful_sampler(key: PRNGKey, *args, sample_shape: tuple[int, ...], **kwargs) -> Array:
    '''Sample from a distribution.

    Args:
        key: JAX PRNGKey for randomness
        *args: Distribution parameters (positional)
        sample_shape: REQUIRED keyword argument specifying sample shape
        **kwargs: Additional distribution parameters (keyword)

    Returns:
        Array with shape sample_shape + distribution.event_shape
    '''
The sample_shape parameter is REQUIRED and must be accepted as a keyword argument. This ensures compatibility with JAX/TensorFlow Probability conventions and enables proper vectorization under PJAX's modular_vmap.
KeylessWrapper
¶
Wrapper that provides keyless sampling interface using global counter.
This encapsulates the "cheeky" global counter pattern and makes it explicit that this is a convenience wrapper with known limitations.
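To make the pattern and its limitation concrete, here is a rough sketch of what a counter-driven keyless wrapper could look like. It is an assumption for illustration only, not GenJAX's implementation; make_keyless, _global_counter, and toy_sampler are hypothetical names.

```python
import itertools

import jax

# Hypothetical module-level counter that stands in for a PRNG key stream.
_global_counter = itertools.count()


def make_keyless(keyful_sampler):
    """Wrap a keyful sampler so callers do not have to pass a key."""

    def keyless_sampler(*args, sample_shape=(), **kwargs):
        # A fresh key is derived from the counter on every Python-level call.
        key = jax.random.PRNGKey(next(_global_counter))
        return keyful_sampler(key, *args, sample_shape=sample_shape, **kwargs)

    return keyless_sampler


def toy_sampler(key, loc, *, sample_shape=()):
    return loc + jax.random.normal(key, shape=tuple(sample_shape))


keyless_normal = make_keyless(toy_sampler)
draw_a = keyless_normal(0.0)
draw_b = keyless_normal(0.0)  # different draw: the counter has advanced
```

The known limitation of any such wrapper is that the counter advances only when Python code runs. Under jax.jit the wrapper body executes once at trace time, so the derived key would be frozen into the compiled function.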
Source code in src/genjax/pjax.py
FlatSamplerCache
¶
Manages the flattened version of samplers for JAX interpretation.
JAX interpreters work with flattened argument lists, so we need to pre-compute a flattened version of the sampler for efficiency.
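The idea of a flattened sampler can be sketched with jax.tree_util alone. The helper make_flat_sampler below is hypothetical and does not mirror FlatSamplerCache's interface; it only shows how a sampler that takes structured arguments can be exposed as one that takes a flat list of leaves.

```python
import jax
from jax import tree_util


def make_flat_sampler(sampler, args, kwargs):
    """Return a sampler over flat leaves plus the leaves of the given arguments."""
    leaves, treedef = tree_util.tree_flatten((args, kwargs))

    def flat_sampler(key, *flat_args, sample_shape=()):
        # Rebuild the original (args, kwargs) structure from the flat leaves.
        rebuilt_args, rebuilt_kwargs = tree_util.tree_unflatten(treedef, flat_args)
        return sampler(key, *rebuilt_args, sample_shape=sample_shape, **rebuilt_kwargs)

    return flat_sampler, leaves


def dict_normal_sampler(key, params, *, sample_shape=()):
    # `params` is a pytree (a dict), which interpreters would see as two leaves.
    noise = jax.random.normal(key, shape=tuple(sample_shape))
    return params["loc"] + params["scale"] * noise


key = jax.random.PRNGKey(1)
params = {"loc": 1.0, "scale": 2.0}
flat_sampler, leaves = make_flat_sampler(dict_normal_sampler, (params,), {})

direct = dict_normal_sampler(key, params, sample_shape=(3,))
via_flat = flat_sampler(key, *leaves, sample_shape=(3,))
```

Building the flat version once and reusing it avoids repeating the flatten and unflatten bookkeeping on every interpretation, which is presumably the efficiency the cache is after.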
Source code in src/genjax/pjax.py
get_flat_sampler
¶
Get or create the flattened sampler for these arguments.
Source code in src/genjax/pjax.py
VmapBatchHandler
¶
Handles the complex vmap batching logic for probabilistic primitives.
This encapsulates the logic for how sample shapes change under vmap and how primitives get rebound with new sample shapes.
Source code in src/genjax/pjax.py
create_batch_rule
¶
Create the batch rule function for this sampler.
Source code in src/genjax/pjax.py
LogDensityConfig
dataclass
¶
Configuration for a log density function.
Encapsulates all the information needed to create and execute a log density primitive, following the same component-based architecture as SamplerConfig.
Log Density Function Contract
The log_density_impl must follow this signature:
def log_density_impl(value, *args, **kwargs) -> float:
    '''Compute log probability density.

    Args:
        value: The value to evaluate density at
        *args: Distribution parameters (positional)
        **kwargs: Additional distribution parameters (keyword)

    Returns:
        Scalar log probability density
    '''
Log density functions always return scalars - there is no sample_shape concept.
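A function that satisfies this contract for a univariate normal distribution might look like the sketch below. The name normal_log_density is hypothetical; the result is checked against jax.scipy.stats.norm.logpdf.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def normal_log_density(value, loc, scale, **kwargs):
    """Log density of Normal(loc, scale) at `value`; returns a scalar."""
    z = (value - loc) / scale
    return -0.5 * z**2 - jnp.log(scale) - 0.5 * jnp.log(2.0 * jnp.pi)


logp = normal_log_density(0.5, 0.0, 2.0)
reference = norm.logpdf(0.5, loc=0.0, scale=2.0)
```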
LogDensityVmapHandler
¶
Handles the complex vmap batching logic for log density primitives.
This encapsulates the logic for how log density functions get vectorized under vmap, handling both args-only and args+kwargs cases.
Source code in src/genjax/pjax.py
create_batch_rule
¶
Create the batch rule function for this log density function.
Source code in src/genjax/pjax.py
Environment
dataclass
¶
Variable environment for Jaxpr interpretation.
Manages the mapping between JAX variables (from Jaxpr) and their concrete values during interpretation. This is essential for interpreters that need to execute Jaxpr equations step-by-step while maintaining state.
The environment handles both:
- Var objects: Variables with unique identifiers
- Literal objects: Constant values embedded in the Jaxpr
This design enables efficient interpretation of probabilistic Jaxpr by PJAX's specialized interpreters.
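The role of such an environment is easiest to see in a small Jaxpr evaluator. The sketch below follows the well-known interpreter pattern from the JAX documentation; ToyEnvironment and eval_jaxpr are hypothetical stand-ins, not GenJAX's Environment class or interpreters.

```python
from dataclasses import dataclass, field

import jax
import jax.numpy as jnp

try:  # Literal lives in different modules depending on the JAX version.
    from jax.extend.core import Literal
except ImportError:
    from jax.core import Literal


@dataclass
class ToyEnvironment:
    """Maps Jaxpr variables to concrete values (illustrative only)."""

    store: dict = field(default_factory=dict)

    def read(self, var):
        # Literals carry their value inline; Vars are looked up in the store.
        if isinstance(var, Literal):
            return var.val
        return self.store[var]

    def write(self, var, value):
        self.store[var] = value
        return value

    def copy(self):
        return ToyEnvironment(dict(self.store))


def eval_jaxpr(closed_jaxpr, *args):
    """Execute a ClosedJaxpr equation by equation."""
    jaxpr = closed_jaxpr.jaxpr
    env = ToyEnvironment()
    for var, const in zip(jaxpr.constvars, closed_jaxpr.consts):
        env.write(var, const)
    for var, arg in zip(jaxpr.invars, args):
        env.write(var, arg)
    for eqn in jaxpr.eqns:
        invals = [env.read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for var, val in zip(eqn.outvars, outvals):
            env.write(var, val)
    return [env.read(v) for v in jaxpr.outvars]


def f(x):
    return jax.lax.sin(x) * 2.0 + 1.0


x = jnp.float32(0.3)
closed = jax.make_jaxpr(f)(x)
(result,) = eval_jaxpr(closed, x)
```

A probabilistic interpreter would follow the same loop but special-case equations whose primitive is a sampling primitive instead of calling bind on them unchanged.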
read
¶
Read a value from a variable in the environment.
Source code in src/genjax/pjax.py
write
¶
Write a value to a variable in the environment.
Source code in src/genjax/pjax.py
copy
¶
Environment.copy is sometimes used to create a new environment with the same variables and values as the original, especially in CPS interpreters (where a continuation closes over the application of an interpreter to a Jaxpr).
Source code in src/genjax/pjax.py
Seed
dataclass
¶
Interpreter that eliminates probabilistic primitives with explicit randomness.
The Seed interpreter is PJAX's core mechanism for making probabilistic computations compatible with standard JAX transformations. It works by traversing a Jaxpr and replacing sample_p primitives with actual sampling operations using explicit PRNG keys.
Key Features:
- Eliminates PJAX primitives: Converts sample_p to concrete sampling
- Explicit randomness: Uses provided PRNG key for all random operations
- JAX compatibility: Output can be jit'd, vmap'd, grad'd normally
- Deterministic: Same key produces same results (good for debugging)
- Hierarchical key splitting: Automatically manages keys for nested operations
The interpreter handles JAX control flow primitives (cond, scan) by recursively applying the seed transformation to their sub-computations.
Usage Pattern
This interpreter is primarily used via the seed() transformation, documented under seed below.
Technical Details
The interpreter maintains a PRNG key that is split at each random operation, ensuring proper randomness while maintaining determinism. For control flow, it passes seeded versions of sub-computations to JAX's control primitives.
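The key-threading discipline described here can be written out by hand for a function with two random draws. The sketch below is a manual equivalent in plain JAX, not the output of the Seed interpreter; seeded_two_draws is a hypothetical name.

```python
import jax


def seeded_two_draws(key, mu):
    # Split before each random operation so every draw gets its own subkey.
    key, sub = jax.random.split(key)
    x = mu + jax.random.normal(sub)
    key, sub = jax.random.split(key)
    y = x + jax.random.normal(sub)
    return x, y


key = jax.random.PRNGKey(42)
x_eager, y_eager = seeded_two_draws(key, 0.0)

# With explicit keys, standard transformations apply directly.
x_jit, y_jit = jax.jit(seeded_two_draws)(key, 0.0)
keys = jax.random.split(key, 3)
x_batch, y_batch = jax.vmap(seeded_two_draws, in_axes=(0, None))(keys, 0.0)
```

Since the function depends only on its inputs, the eager and jitted calls agree for the same key, which is the determinism property listed above.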
ModularVmap
dataclass
¶
Vectorization interpreter that preserves probabilistic primitives.
The ModularVmap interpreter extends JAX's vmap to handle PJAX's probabilistic primitives correctly. Unlike standard vmap, which isn't aware of PJAX primitives, this interpreter knows how to vectorize probabilistic operations while preserving their semantic meaning.
Key Capabilities:
- Probabilistic vectorization: Correctly handles sample_p under vmap
- Sample shape inference: Automatically adjusts distribution sample shapes
- Control flow support: Handles cond/scan within vectorized computations
- Semantic preservation: Maintains probabilistic meaning across batches
How It Works
The interpreter uses a "dummy argument" technique to track the vectorization axis size and injects this information into probabilistic primitives so they can adjust their behavior appropriately (e.g., sampling multiple independent values vs. broadcasting parameters).
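The distinction between the two cases can be sketched in plain JAX. This is an assumption about the general idea, not ModularVmap's actual mechanism; vectorized_draw and its axis_size argument are hypothetical. When parameters are not batched, the batch axis has to come from a larger sample_shape. When parameters are batched, broadcasting already yields one draw per element.

```python
import jax
import jax.numpy as jnp


def normal_sampler(key, loc, scale, *, sample_shape=()):
    batch_shape = jnp.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))
    noise = jax.random.normal(key, shape=tuple(sample_shape) + batch_shape)
    return loc + scale * noise


def vectorized_draw(key, loc, scale, *, axis_size, sample_shape=()):
    # Unbatched parameters: prepend the axis size so each batch element
    # receives its own independent sample.
    return normal_sampler(
        key, loc, scale, sample_shape=(axis_size,) + tuple(sample_shape)
    )


key = jax.random.PRNGKey(0)

# Case 1: scalar parameters, vectorization axis of size 5.
independent = vectorized_draw(key, 0.0, 1.0, axis_size=5)

# Case 2: parameters already carry the batch axis, so sample_shape is unchanged.
locs = jnp.arange(3.0)
per_element = normal_sampler(key, locs, 1.0, sample_shape=())
```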
Usage
Primarily used via the modular_vmap() function, documented under modular_vmap below.
Technical Details
The interpreter maintains PJAX primitives in the Jaxpr rather than eliminating them (unlike Seed). This allows proper vectorization semantics for probabilistic operations.
get_shaped_aval
¶
cached_stage_dynamic
¶
Cache-enabled function to stage a flattened function to Jaxpr.
Source code in src/genjax/pjax.py
stage
¶
Stage a function to JAX's intermediate representation (Jaxpr).
Converts a Python function into JAX's Jaxpr format, which enables interpretation and transformation of the function's computation graph. This is essential for PJAX's ability to inspect and transform probabilistic computations.
Example
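As a stand-in that avoids guessing at the return type of stage, the sketch below uses jax.make_jaxpr, JAX's built-in staging utility, to show what staging a Python function to a Jaxpr produces. Whether stage returns the same structure is not stated in this page, so treat the correspondence as an assumption.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jax.lax.sin(x) + 1.0


# Trace the function with an example input to obtain its Jaxpr.
closed_jaxpr = jax.make_jaxpr(f)(jnp.float32(0.5))

# Each equation records which primitive was applied; an interpreter
# dispatches on these names when walking the computation.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
num_inputs = len(closed_jaxpr.jaxpr.invars)
```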
Source code in src/genjax/pjax.py
initial_style_bind
¶
Binds a primitive to a function call.
Source code in src/genjax/pjax.py
create_log_density_primitive
¶
Create a log density primitive from a log density configuration.
This is the main entry point that orchestrates all the log density components.
Source code in src/genjax/pjax.py
log_density_binder
¶
Create a log density primitive using component-based architecture.
Source code in src/genjax/pjax.py
create_sample_primitive
¶
Create a sample primitive from a sampler configuration.
This is the main entry point that orchestrates all the components. Replaces the current sample_binder function with clearer separation of concerns.
Source code in src/genjax/pjax.py
sample_binder
¶
sample_binder(keyful_sampler: Callable[..., Any], name: str | None = None, sample_shape: tuple[int, ...] = (), support: Callable[..., Any] | None = None)
Create a sample primitive that binds a keyful sampler to PJAX's sample_p primitive.
Uses a component-based architecture to handle the complex interactions between keyless sampling, sample shapes, JAX flattening, and vmap transformations.
Note
The keyful_sampler MUST accept sample_shape as a keyword argument. This is required for compatibility with JAX transformations and proper vectorization.
Source code in src/genjax/pjax.py
wrap_logpdf
¶
Wrap a log-density function to work with PJAX primitives.
Source code in src/genjax/pjax.py
seed
¶
Transform a function to accept an explicit PRNG key.
This transformation eliminates probabilistic primitives by providing explicit randomness through a PRNG key, enabling the use of standard JAX transformations like jit and vmap.
Example
import jax.random as jrand
from genjax import gen, normal, seed

@gen
def model():
    return normal(0.0, 1.0) @ "x"

seeded_model = seed(model.simulate)
key = jrand.key(0)
trace = seeded_model(key)
Source code in src/genjax/pjax.py
modular_vmap
¶
modular_vmap(f: Callable[..., R], in_axes: int | tuple[int | None, ...] | Sequence[Any] | None = 0, axis_size: int | None = None, axis_name: str | None = None, spmd_axis_name: str | None = None) -> Callable[..., R]
Vectorize a function while preserving probabilistic semantics.
This is PJAX's probabilistic-aware version of jax.vmap. Unlike standard vmap, which fails on probabilistic primitives, modular_vmap correctly handles probabilistic computations by preserving their semantic meaning across the vectorized dimension.
Key Differences from jax.vmap:
- Probabilistic awareness: Handles sample_p and log_density_p primitives
- Sample shape handling: Automatically adjusts distribution sample shapes
- Independent sampling: Each vectorized element gets independent randomness
- Semantic correctness: Maintains probabilistic meaning across batches
Example
import jax.numpy as jnp
import jax.random as jrand
from genjax import normal, modular_vmap, seed

def sample_normal(mu):
    return normal.sample(mu, 1.0)  # Contains sample_p primitive

# Vectorize over different means
batch_sample = modular_vmap(sample_normal, in_axes=(0,))
mus = jnp.array([0.0, 1.0, 2.0])
samples = batch_sample(mus)  # Shape: (3,), independent samples

# Compare with seed for JAX compatibility
key = jrand.key(0)
seeded_fn = seed(batch_sample)
samples = seeded_fn(key, mus)  # Can be jit'd, vmap'd, etc.
Note
For JAX transformations (jit, grad, etc.), use seed() first: jax.jit(seed(modular_vmap(f))) rather than trying to jit the modular_vmap directly.
Source code in src/genjax/pjax.py