genjax.core¶
Core functionality for GenJAX including the Generative Function Interface, traces, and model construction.
Mathematical Foundation¶
The Generative Function Interface (GFI) is based on measure theory. A generative function \(g\) defines:
- A measure kernel \(P(dx; \text{args})\) over measurable space \(X\)
- A return value function \(f(x, \text{args}) \rightarrow R\)
- Internal proposal family \(Q(dx'; \text{args}, x)\)
The importance weight returned by generate is
\[ \log \frac{P(x; \text{args})}{Q(x'; \text{args}, x_{\text{constrained}})} \]
where \(x\) is the full set of choices and \(x'\) is the subset sampled from the internal proposal given the constrained choices \(x_{\text{constrained}}\).
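For intuition: when every choice is constrained, the proposal contributes nothing and the weight reduces to the log density of the constraints. A minimal sketch under that assumption (the model and its addresses "x" and "y" are hypothetical; see the GFI method docs below):

```python
# Sketch: fully constrained generation, so weight equals assess's log density
constraints = {"x": 1.0, "y": 2.0}  # hypothetical choice map
trace, weight = model.generate(constraints, *args)
log_density, _ = model.assess(constraints, *args)
# weight == log_density (up to floating-point error)
```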
core¶
Pytree¶
Bases: Struct
Pytree is an abstract base class which registers a class with JAX's Pytree system. JAX's Pytree system tracks how data classes should behave across JAX-transformed function boundaries, like jax.jit or jax.vmap.
Inheriting this class gives the implementor the freedom to declare how the fields of a class should behave:
- Pytree.static(...): the value of the field cannot be a JAX traced value; it must be a Python literal or constant. The values of static fields are embedded in the PyTreeDef of any instance of the class.
- Pytree.field(...) or no annotation: the value may be a JAX traced value, and JAX will attempt to convert it to a tracer value inside of its transformations.
If a field points to another Pytree, it should not be declared as Pytree.static(), as the Pytree interface will automatically handle Pytree fields as dynamic fields.
dataclass staticmethod¶
Denote that a class (which inherits Pytree) should be treated as a dataclass, meaning it can hold data in fields which are declared as part of the class. A dataclass is to be distinguished from a "methods only" Pytree class, which does not have fields but may define methods. The latter cannot be instantiated but can be inherited from, while the former can be instantiated: the Pytree.dataclass declaration informs the system how to instantiate the class as a dataclass, and how to automatically define JAX's Pytree interfaces (tree_flatten, tree_unflatten, etc.) for the dataclass, based on the fields declared in the class and any Pytree.static(...) or Pytree.field(...) annotations (the default, absent an annotation, is that all fields are Pytree.field(...)).
All Pytree dataclasses support pretty printing, as well as rendering to HTML.
Examples¶
```python
from genjax import Pytree
from jaxtyping import ArrayLike
import jax.numpy as jnp

@Pytree.dataclass
class MyClass(Pytree):
    my_static_field: int = Pytree.static()
    my_dynamic_field: ArrayLike

instance = MyClass(10, jnp.array(5.0))
instance.my_static_field   # 10
instance.my_dynamic_field  # Array(5., dtype=float32)
```
static staticmethod¶
Declare a field of a Pytree dataclass to be static. Users can provide additional keyword argument options, like default or default_factory, to customize how the field is instantiated when an instance of the dataclass is created. Fields which are provided with default values must come after required fields in the dataclass declaration.
Examples¶
```python
from genjax import Pytree
from jaxtyping import ArrayLike
import jax.numpy as jnp

@Pytree.dataclass
class MyClass(Pytree):
    my_dynamic_field: ArrayLike
    my_static_field: int = Pytree.static(default=0)

instance = MyClass(jnp.array(5.0))
instance.my_static_field   # 0
instance.my_dynamic_field  # Array(5., dtype=float32)
```
field staticmethod¶
Declare a field of a Pytree dataclass to be dynamic. Alternatively, one can leave the annotation off in the declaration.
Const¶
Bases: Generic[A], Pytree
A Pytree wrapper for Python literals that should remain static.
This class wraps Python values that need to stay as literals (not become JAX tracers) when used inside JAX transformations. The wrapped value is marked as static, ensuring it's embedded in the PyTreeDef rather than becoming a traced value.
Example
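A minimal sketch, using the const helper documented later on this page:

```python
from genjax import const

# Wrap a Python int so it remains a literal under jax.jit / jax.vmap
n = const(10)
n.value  # -> 10, a plain Python int, never a JAX tracer
```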
NotFixedException¶
Bases: Exception
Exception raised when trace verification finds unfixed values.
This exception provides a clear visualization of which parts of the choice map are properly fixed (True) vs. unfixed (False), helping users debug model structure issues during constrained inference.
Fixed¶
Bases: Generic[A], Pytree
A Pytree wrapper that denotes a random choice was provided (fixed), not proposed by a GFI's internal proposal family.
This wrapper is used internally by Distribution implementations in generate, update, and regenerate methods to mark values that were constrained or provided externally rather than sampled from the distribution's internal proposal.
The Fixed wrapper helps debug model structure issues during inference by tracking which random choices were externally constrained vs. internally proposed.
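A minimal sketch of the wrapper's behavior, assuming the fixed and get_choices helpers documented later on this page are importable from genjax:

```python
from genjax import fixed, get_choices

wrapped = fixed(1.5)   # mark the value as externally constrained
get_choices(wrapped)   # strips the Fixed wrapper -> 1.5
```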
Trace¶
Bases: Generic[X, R], Pytree
get_fixed_choices abstractmethod¶
Get choices preserving Fixed wrappers.
Returns the raw choice structure with Fixed wrappers intact, used for verification that values were constrained during inference.
verify¶
Verify that all leaf values in the trace choices were fixed (constrained).
Checks that all random choices in the trace are wrapped with Fixed, indicating they were provided externally rather than proposed by the GFI's internal proposal family. This helps debug model structure issues during inference.
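A hedged sketch of the intended workflow (the model and its addresses "x" and "y" are hypothetical):

```python
# Fully constrained generation: every leaf choice should be Fixed
constraints = {"x": 1.0, "y": 2.0}
trace, weight = model.generate(constraints, mu, sigma)
trace.verify()  # raises NotFixedException if any choice was internally proposed
```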
Tr¶
Concrete implementation of the Trace interface.
Stores all components of an execution trace including the generative function, arguments, random choices, return value, and score.
TupleSel¶
Bases: Pytree
Selection that matches a hierarchical tuple address.
Tuple addresses represent hierarchical paths like ("outer", "inner", "leaf"). When matched against a single string address, it checks if that string matches the first element of the tuple, and returns a selection for the remaining path.
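For example, using the sel helper documented below, a tuple address builds a hierarchical selection:

```python
from genjax import sel

# Matches "outer" first; the remaining path ("inner", "leaf") is
# matched against the corresponding subtree of choices
s = sel(("outer", "inner", "leaf"))
```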
Selection¶
Bases: Pytree
A Selection acts as a filter to specify which random choices in a trace should be regenerated during the regenerate method call.
Selections are used in inference algorithms like MCMC to specify which subset of random choices should be updated while keeping others fixed. The Selection determines which addresses (random choice names) match the selection criteria.
A Selection wraps one of several concrete selection types:
- StrSel: Matches a specific string address
- DictSel: Matches addresses using a dictionary mapping
- AllSel: Matches all addresses
- NoneSel: Matches no addresses
- ComplSel: Matches the complement of another selection
- InSel: Matches the intersection of two selections
- OrSel: Matches the union of two selections
Example
```python
from genjax.core import sel

# Select a specific address
selection = sel("x")  # Matches address "x"

# Select all addresses
selection = sel(())  # Matches all addresses

# Select nested addresses
selection = sel({"outer": sel("inner")})  # Matches "outer"/"inner"

# Use in regenerate (trace first, then selection, then args)
new_trace, weight, discard = gen_fn.regenerate(trace, selection, *args)
```
GFI¶
Bases: Generic[X, R], Pytree
Generative Function Interface - the core abstraction for probabilistic programs.
The GFI defines the standard interface that all generative functions must implement. It provides methods for simulation, assessment, generation, updating, and regeneration of probabilistic computations.
Mathematical Foundation: A generative function bundles three mathematical objects:
1. Measure kernel P(dx; args) - the probability distribution over choices
2. Return value function f(x, args) -> R - deterministic computation from choices
3. Internal proposal family Q(dx; args, context) - for efficient inference

The GFI methods provide access to these mathematical objects and enable:
- Forward sampling (simulate)
- Density evaluation (assess)
- Constrained generation (generate)
- Edit moves (update, regenerate)
All density computations are in log space for numerical stability. Weights from generate/update/regenerate enable importance sampling and MCMC.
Type Parameters
- X: The type of the random choices (choice map).
- R: The type of the return value.

Core Methods
- simulate: Sample (choices, retval) ~ P(·; args)
- assess: Compute log P(choices; args)
- generate: Sample with constraints, return importance weight
- update: Update trace arguments/choices, return incremental importance weight
- regenerate: Resample selected choices, return incremental importance weight

Additional Methods
- merge: Combine choice maps (for compositional functions)
- log_density: Convenience method for assess that sums log densities
- vmap/repeat: Vectorization combinators
- cond: Conditional execution combinator
simulate abstractmethod¶
Sample an execution trace from the generative function.
Mathematical specification:
- Samples (choices, retval) ~ P(·; args), where P is the generative function's measure kernel
- Returns a trace containing choices, return value, score, and arguments
The score in the returned trace is log(1/P(choices; args)), i.e., the negative log probability density of the sampled choices.
Example
```python
trace = model.simulate(mu, sigma)
choices = trace.get_choices()
score = trace.get_score()  # -log P(choices; mu, sigma)
```
generate abstractmethod¶
Generate a trace with optional constraints on some choices.
Mathematical specification:
- Samples unconstrained choices ~ Q(·; constrained_choices, args)
- Computes importance weight: log [P(all_choices; args) / Q(unconstrained_choices; constrained_choices, args)]
- When x=None, equivalent to simulate() but returns weight=0
The weight enables importance sampling and is crucial for inference algorithms. For fully constrained generation, the weight equals the log density.
Example
```python
# Constrain some choices (hypothetical choice map; addresses depend on the model)
constraints = {"x": 1.0}
trace, weight = model.generate(constraints, mu, sigma)
# weight accounts for the probability of the constrained choices
```
assess abstractmethod¶
Compute the log probability density of given choices.
Mathematical specification:
- Computes log P(choices; args), where P is the generative function's measure kernel
- Also computes the return value for the given choices
- Requires P(choices; args) > 0 (choices must be valid)
Example
```python
log_density, retval = model.assess(choices, mu, sigma)
# log_density = log P(choices; mu, sigma)
```
update abstractmethod¶
Update a trace with new arguments and/or choice constraints.
Mathematical specification:
- Transforms the trace from (old_args, old_choices) to (new_args, new_choices)
- Computes the incremental importance weight (edit move):

  weight = log [P(new_choices; new_args) / Q(new_choices; new_args, old_choices, constraints)]
         - log [P(old_choices; old_args) / Q(old_choices; old_args)]

  where Q is the internal proposal distribution used for updating.
Example
```python
# Update the trace with new arguments
new_trace, weight, discarded = model.update(old_trace, None, new_mu, new_sigma)
# weight = log P(new_choices; new_args) - log P(old_choices; old_args)
```
regenerate abstractmethod¶
regenerate(tr: Trace[X, R], sel: Selection, *args, **kwargs) -> tuple[Trace[X, R], Weight, X | None]
Regenerate selected choices in a trace while keeping others fixed.
Mathematical specification:
- Resamples choices at addresses selected by sel from their conditional distribution
- Keeps non-selected choices unchanged
- Computes the incremental importance weight (edit move):

  weight = log P(new_selected_choices | non_selected_choices; args) - log P(old_selected_choices | non_selected_choices; args)
When sel selects all addresses, regenerate becomes equivalent to simulate. When sel selects no addresses, weight = 0 and trace unchanged.
Example
```python
# Regenerate choices at addresses "x" and "y"
selection = sel("x") | sel("y")
new_trace, weight, discarded = model.regenerate(trace, selection, mu, sigma)
# weight accounts for the probability change due to regeneration
```
merge abstractmethod¶
Merge two choice maps, with the second taking precedence.
Used internally for compositional generative functions where choice maps from different components need to be combined. The merge operation resolves conflicts by preferring choices from x_ over x.
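A minimal sketch for dict-shaped choice maps (the Fn case), assuming merge returns the combined map directly:

```python
# Choices from the second map take precedence on conflicting addresses
merged = model.merge({"x": 1.0, "y": 2.0}, {"y": 5.0})
# merged == {"x": 1.0, "y": 5.0}
```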
filter abstractmethod¶
Filter choice map into selected and unselected parts.
Used to partition choices based on a selection, enabling fine-grained manipulation of subsets of choices in inference algorithms. Each GFI implementation specializes this method for its choice type X.
Example
```python
choices = {"mu": 1.0, "sigma": 2.0, "obs": 3.0}
selection = sel("mu") | sel("sigma")
selected, unselected = model.filter(choices, selection)
# selected == {"mu": 1.0, "sigma": 2.0}; unselected == {"obs": 3.0}
```
Thunk¶
Bases: Generic[X, R], Pytree
Delayed evaluation wrapper for generative functions.
A thunk represents a generative function call that has not yet been executed. It captures the function and its arguments for later evaluation.
Vmap¶
Bases: Generic[X, R], GFI[X, R]
A Vmap is a generative function combinator that vectorizes another generative function. Vmap applies a generative function across a batch dimension, similar to jax.vmap, but preserves probabilistic semantics. It uses GenJAX's modular_vmap to handle the vectorization of probabilistic computations correctly.
Mathematical ingredients:
- If the callee has measure kernel P_callee(dx; args), then Vmap has kernel P_vmap(dX; Args) = ∏_i P_callee(dx_i; args_i), where X = [x_1, ..., x_n]
- Return value function f_vmap(X, Args) = [f_callee(x_1, args_1), ..., f_callee(x_n, args_n)]
- Internal proposal family inherits from the callee's proposal family
Example
```python
import jax.numpy as jnp
from genjax import normal

# Vectorize a normal distribution over its first argument
vectorized_normal = normal.vmap(in_axes=(0, None))
mus = jnp.array([0.0, 1.0, 2.0])
sigma = 1.0
trace = vectorized_normal.simulate(mus, sigma)
samples = trace.get_choices()  # Array of 3 normal samples
```
filter¶
Filter vectorized choices using the underlying generative function's filter.
For Vmap, choices are vectorized across the batch dimension. We apply the underlying GF's filter to each vectorized choice.
Distribution¶
Bases: Generic[X], GFI[X, X]
A Distribution is a generative function that implements a probability distribution.
Distributions are the fundamental building blocks of probabilistic programs. They implement the Generative Function Interface (GFI) by wrapping a sampling function and a log probability density function (logpdf).
Mathematical ingredients:
- A measure kernel P(dx; args) over a measurable space X given arguments args
- Return value function f(x, args) = x (identity function for distributions)
- Internal proposal distribution family Q(dx; args, x') = P(dx; args) (the prior)
Example
```python
import jax
import jax.numpy as jnp
from genjax import Distribution, const

# Create a custom normal distribution
def sample_normal(mu, sigma):
    key = jax.random.PRNGKey(0)  # In practice, use proper key management
    return mu + sigma * jax.random.normal(key)

def logpdf_normal(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)

normal = Distribution(const(sample_normal), const(logpdf_normal), const("normal"))
trace = normal.simulate(0.0, 1.0)  # mu=0.0, sigma=1.0
```
sample¶
logpdf¶
merge¶
Merge distribution choices with optional conditional selection.
For distributions, choices are raw values from the sample space. When check is provided, we use jnp.where for conditional selection.
filter¶
Filter choice into selected and unselected parts.
For Distribution, the choice is a single value X. Selection either matches the empty address () or it doesn't.
Simulate dataclass¶
Handler for simulating generative function executions.
Tracks the accumulated score and trace map during simulation.
Fn¶
Bases: Generic[R], GFI[dict[str, Any], R]
A Fn is a generative function created from a JAX Python function using the @gen decorator. Fn implements the GFI by executing the wrapped function in different execution contexts (handlers) that intercept calls to other generative functions via the @ addressing syntax.
Mathematical ingredients:
- Measure kernel P(dx; args) defined by the composition of distributions in the function
- Return value function f(x, args) defined by the function's logic and return statement
- Internal proposal distribution family Q(dx; args, x') defined by ancestral sampling
The choice space X is a dictionary mapping addresses (strings) to the choices made at those addresses during execution.
Example
```python
import jax.numpy as jnp
from genjax import gen, normal

@gen
def linear_regression(xs):
    slope = normal(0.0, 1.0) @ "slope"
    intercept = normal(0.0, 1.0) @ "intercept"
    noise = normal(0.0, 0.1) @ "noise"
    return normal(slope * xs + intercept, noise) @ "y"

trace = linear_regression.simulate(jnp.array([1.0, 2.0, 3.0]))
choices = trace.get_choices()  # dict with keys "slope", "intercept", "noise", "y"
```
filter¶
filter(x: dict[str, Any], selection: Selection) -> tuple[dict[str, Any] | None, dict[str, Any] | None]
Filter choice map into selected and unselected parts.
For Fn, choices are stored as dict[str, Any] with string addresses.
ScanTr¶
Scan¶
Bases: Generic[X, R], GFI[X, R]
A Scan is a generative function combinator that implements sequential iteration. Scan repeatedly applies a generative function in a sequential loop, similar to jax.lax.scan, but preserves probabilistic semantics. The callee function should take (carry, x) as input and return (new_carry, output).
Mathematical ingredients:
- If the callee has measure kernel P_callee(dx; carry, x), then Scan has kernel P_scan(dX; init_carry, xs) = ∏_i P_callee(dx_i; carry_i, xs_i), where carry_{i+1} = f_callee(x_i, carry_i, xs_i)[0]
- Return value function returns (final_carry, [output_1, ..., output_n])
- Internal proposal family inherits from the callee's proposal family
Example
```python
import jax.numpy as jnp
import jax.random as jrand
from genjax import gen, normal, Scan, seed, const

@gen
def step(carry, x):
    noise = normal(0.0, 0.1) @ "noise"
    new_carry = carry + x + noise
    return new_carry, new_carry  # output equals the new carry

scan_fn = Scan(step, length=const(3))
init_carry = 0.0
xs = jnp.array([1.0, 2.0, 3.0])

# Use the seed transformation for PJAX primitives
key = jrand.key(0)
trace = seed(scan_fn.simulate)(key, init_carry, xs)
final_carry, outputs = trace.get_retval()
assert len(outputs) == 3  # Should have 3 outputs
```
filter¶
Filter scan choices using the underlying generative function's filter.
For Scan, choices are structured according to the scan iterations. We delegate to the underlying callee's filter method.
CondTr¶
Cond¶
Bases: Generic[X, R], GFI[X, R]
A Cond is a generative function combinator that implements conditional branching. Cond takes a boolean condition and executes one of two generative functions based on the condition, similar to jax.lax.cond, but preserves probabilistic semantics by evaluating both branches and selecting the appropriate one.
Mathematical ingredients:
- If the branches have measure kernels P_true(dx; args) and P_false(dx; args), then Cond has kernel P_cond(dx; check, args) = P_true(dx; args) if check else P_false(dx; args)
- Return value function f_cond(x, check, args) = f_true(x, args) if check else f_false(x, args)
- Internal proposal family selects the appropriate branch proposal based on the condition
Note: Both branches are always evaluated during simulation/generation to maintain JAX compatibility, but only the appropriate branch contributes to the final result.
Example
```python
from genjax import gen, normal, exponential, Cond

@gen
def positive_branch():
    return exponential(1.0) @ "value"

@gen
def negative_branch():
    return exponential(2.0) @ "value"

cond_fn = Cond(positive_branch, negative_branch)

# Use in a larger model
@gen
def conditional_model():
    x = normal(0.0, 1.0) @ "x"
    condition = x > 0
    result = cond_fn((condition,)) @ "conditional"
    return result
```
filter¶
Filter conditional choices using the underlying generative function's filter.
For Cond, choices are determined by which branch was executed. We delegate to the first callee's filter method.
const¶
Create a Const wrapper for a static value.
Example
```python
from genjax import const, Const

# Create a static value
length = const(10)
print(f"Value: {length.value}")
print(f"Type: {type(length)}")

# Use in arithmetic
doubled = length * 2
print(f"Doubled: {doubled.value}")

# Use as a static parameter
# @gen
# def model(n: Const[int]):
#     # n.value is guaranteed to be a Python int, not a JAX tracer
#     for i in range(n.value):  # This works in JAX transforms!
#         ...
```
fixed¶
Create a Fixed wrapper for a constrained value.
get_choices¶
Extract choices from a trace or nested structure containing traces.
Also strips Fixed wrappers from the choices, returning the unwrapped values. Fixed wrappers are used internally to track constrained vs. proposed values.
get_fixed_choices¶
Extract choices from a trace or nested structure containing traces, preserving Fixed wrappers.
Similar to get_choices() but preserves Fixed wrappers around the choices, which is needed for verification that values were constrained during inference.
get_score¶
Extract the log probability score from a trace.
get_retval¶
Extract the return value from a trace.
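A minimal sketch combining these accessors, assuming they are importable from genjax and reusing a model defined with @gen as in the example under gen below:

```python
from genjax import get_choices, get_score, get_retval

trace = model.simulate(0.0, 1.0)
get_choices(trace)  # the trace's random choices, e.g. a dict for @gen functions
get_score(trace)    # the trace's log probability score
get_retval(trace)   # the model's return value
```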
sel¶
Create a Selection from various input types.
This is a convenience function to create Selection objects from common patterns. Selections specify which random choices in a trace should be regenerated during inference operations like MCMC.
Examples:
```python
from genjax import sel

# Select a specific address
s1 = sel("x")
print(f"sel('x'): {s1}")

# Select a hierarchical address
s2 = sel(("outer", "inner"))
print(f"sel(('outer', 'inner')): {s2}")

# Select all addresses
s3 = sel(())
print(f"sel(()): {s3}")

# Select no addresses
s4 = sel()
print(f"sel(): {s4}")

# Combine selections with OR
s5 = sel("x") | sel("y")
print(f"sel('x') | sel('y'): {s5}")

# Complement a selection
s6 = ~sel("x")
print(f"~sel('x'): {s6}")
```
distribution¶
distribution(sampler: Callable[..., Any], logpdf: Callable[..., Any], /, name: str | None = None) -> Distribution[Any]
Create a Distribution from sampling and log probability functions.
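A minimal sketch, reusing the sample_normal and logpdf_normal functions from the Distribution example above:

```python
from genjax import distribution

my_normal = distribution(sample_normal, logpdf_normal, name="normal")
trace = my_normal.simulate(0.0, 1.0)  # behaves like the Distribution example
```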
tfp_distribution¶
tfp_distribution(dist: Callable[..., Distribution], /, name: str | None = None) -> Distribution[Any]
Create a Distribution from a TensorFlow Probability distribution.
Wraps a TFP distribution constructor to create a GenJAX Distribution that properly handles PJAX's sample_p primitive.
Example
```python
import tensorflow_probability.substrates.jax as tfp
from genjax import tfp_distribution

# Create a normal distribution from TFP
normal = tfp_distribution(tfp.distributions.Normal, name="normal")
```
gen¶
Convert a function into a generative function.
The decorated function can use the @ operator to make addressed random choices from distributions and other generative functions.
Example
```python
from genjax import gen, normal

@gen
def model(mu, sigma):
    x = normal(mu, sigma) @ "x"
    y = normal(x, 0.1) @ "y"
    return x + y

trace = model.simulate(0.0, 1.0)
choices = trace.get_choices()
# choices will contain the addresses "x" and "y"
```
Live Examples¶
Basic Model Definition¶
```python
import jax
import jax.numpy as jnp
from genjax import gen, distributions

@gen
def coin_flip_model(n_flips):
    """A simple coin flipping model with unknown bias."""
    bias = distributions.beta(1.0, 1.0) @ "bias"
    # For demonstration, we'll show manual unrolling
    # In practice, use the Scan combinator for loops
    flip_0 = distributions.bernoulli(bias) @ "flip_0"
    flip_1 = distributions.bernoulli(bias) @ "flip_1"
    flip_2 = distributions.bernoulli(bias) @ "flip_2"
    return jnp.array([flip_0, flip_1, flip_2])

print("Model defined successfully!")
```
Model defined successfully!
Assessing Log Probability¶
```python
import jax
import jax.numpy as jnp
from genjax import gen, distributions

@gen
def coin_flip_model(n_flips):
    """A simple coin flipping model with unknown bias."""
    bias = distributions.beta(1.0, 1.0) @ "bias"
    # For demonstration, we'll show manual unrolling
    # In practice, use the Scan combinator for loops
    flip_0 = distributions.bernoulli(bias) @ "flip_0"
    flip_1 = distributions.bernoulli(bias) @ "flip_1"
    flip_2 = distributions.bernoulli(bias) @ "flip_2"
    return jnp.array([flip_0, flip_1, flip_2])

# Assess the log probability of specific choices
choices = {"bias": 0.7, "flip_0": 1, "flip_1": 1, "flip_2": 0}
log_prob, retval = coin_flip_model.assess(choices, 3)
print(f"Given choices: {choices}")
print(f"Log probability: {log_prob:.3f}")
print(f"Return value (flips): {retval}")
```
Given choices: {'bias': 0.7, 'flip_0': 1, 'flip_1': 1, 'flip_2': 0}
Log probability: -1.910
Return value (flips): [1 1 0]
Using Selections¶
```python
from genjax import sel, Selection

# Create various selections
s1 = sel("bias")                    # Select only bias
s2 = sel("flip_0") | sel("flip_1")  # Select two flips with OR
s3 = sel("bias") | sel("flip_2")    # Select bias OR flip_2

print("Selection s1 targets: bias")
print("Selection s2 targets: flip_0 or flip_1")
print("Selection s3 targets: bias or flip_2")
```
Selection s1 targets: bias
Selection s2 targets: flip_0 or flip_1
Selection s3 targets: bias or flip_2