
genjax.core

Core functionality for GenJAX including the Generative Function Interface, traces, and model construction.

Mathematical Foundation

The Generative Function Interface (GFI) is based on measure theory. A generative function \(g\) defines:

  • A measure kernel \(P(dx; \text{args})\) over a measurable space \(X\)
  • A return value function \(f(x, \text{args}) \rightarrow R\)
  • An internal proposal family \(Q(dx'; \text{args}, x)\)

The importance weight from generate is:

\[w = \log \frac{P(\text{all\_choices})}{Q(\text{free\_choices} | \text{constrained\_choices})}\]
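As a worked instance, take an illustrative model that is not defined in this module: a latent \(x \sim \mathcal{N}(0, 1)\) and an observation \(y \sim \mathcal{N}(x, 1)\). Constrain \(y\), and assume the free choice \(x\) is proposed from its prior, so \(Q(x \mid y) = P(x)\). The prior factors cancel, and the weight reduces to the log likelihood of the constrained choice:

\[
\begin{aligned}
w &= \log \frac{P(x, y)}{Q(x \mid y)}
   = \log \frac{P(x)\, P(y \mid x)}{P(x)}
   = \log P(y \mid x) \\
  &= -\tfrac{1}{2}(y - x)^2 - \tfrac{1}{2}\log(2\pi).
\end{aligned}
\]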

core

Pytree

Bases: Struct

Pytree is an abstract base class which registers a class with JAX's Pytree system. JAX's Pytree system tracks how data classes should behave across JAX-transformed function boundaries, like jax.jit or jax.vmap.

Inheriting this class provides the implementor with the freedom to declare how the subfields of a class should behave:

  • Pytree.static(...): the value of the field cannot be a JAX traced value; it must be a Python literal or a constant. The values of static fields are embedded in the PyTreeDef of any instance of the class.
  • Pytree.field(...) or no annotation: the value may be a JAX traced value, and JAX will attempt to convert it to a tracer inside its transformations.

If a field points to another Pytree, it should not be declared with Pytree.static(); the Pytree interface automatically handles Pytree-valued fields as dynamic fields.
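The static/dynamic split can be pictured with a plain-Python sketch. This is a toy, not GenJAX's implementation: the `Model` class and the `flatten`/`unflatten` helpers are hypothetical. Only the `{"pytree_node": False}` metadata convention is taken from the `Pytree.static` source in this module. Static values travel in the "treedef", while dynamic values become leaves that a JAX transformation could replace with tracers.

```python
from dataclasses import dataclass, field, fields
from typing import Any


# Hypothetical stand-in for a Pytree dataclass. The static field carries the
# same metadata that Pytree.static(...) attaches; every other field is dynamic.
@dataclass(frozen=True)
class Model:
    n_steps: int = field(metadata={"pytree_node": False})
    mu: float = 0.0


def flatten(obj: Any):
    """Split an instance into dynamic leaves and a static 'treedef'."""
    leaves, static_items, dynamic_names = [], [], []
    for f in fields(obj):
        value = getattr(obj, f.name)
        if f.metadata.get("pytree_node", True):
            leaves.append(value)  # could become a tracer under jit/vmap
            dynamic_names.append(f.name)
        else:
            static_items.append((f.name, value))  # embedded in the treedef
    treedef = (type(obj), tuple(static_items), tuple(dynamic_names))
    return leaves, treedef


def unflatten(treedef, leaves):
    """Rebuild an instance from a treedef and (possibly new) leaves."""
    cls, static_items, dynamic_names = treedef
    kwargs = dict(static_items)
    kwargs.update(zip(dynamic_names, leaves))
    return cls(**kwargs)


leaves, treedef = flatten(Model(n_steps=10, mu=1.5))
print(leaves)                     # [1.5]
print(unflatten(treedef, [2.0]))  # Model(n_steps=10, mu=2.0)
```

Swapping the leaves changes only `mu`. `n_steps` is rebuilt from the treedef, which is why a static field must hold a plain Python value.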

dataclass staticmethod

dataclass(incoming: None = None, /, **kwargs) -> Callable[[type[R]], type[R]]
dataclass(incoming: type[R], /, **kwargs) -> type[R]
dataclass(incoming: type[R] | None = None, /, **kwargs) -> type[R] | Callable[[type[R]], type[R]]

Denote that a class (which is inheriting Pytree) should be treated as a dataclass, meaning it can hold data in fields which are declared as part of the class.

A dataclass is distinguished from a "methods only" Pytree class, which has no fields but may define methods. A methods-only class cannot be instantiated, only inherited from; a dataclass can be instantiated. The Pytree.dataclass declaration tells the system how to instantiate the class as a dataclass, and how to automatically define JAX's Pytree interfaces (tree_flatten, tree_unflatten, etc.) from the fields declared in the class and any Pytree.static(...) or Pytree.field(...) annotations. Fields without an annotation default to Pytree.field(...).

All Pytree dataclasses support pretty printing, as well as rendering to HTML.

Examples

>>> from genjax import Pytree
>>> from jaxtyping import ArrayLike
>>> import jax.numpy as jnp
>>>
>>> @Pytree.dataclass
... class MyClass(Pytree):
...     my_static_field: int = Pytree.static()
...     my_dynamic_field: ArrayLike
>>>
>>> instance = MyClass(10, jnp.array(5.0))
>>> instance.my_static_field
10
>>> instance.my_dynamic_field  # doctest: +ELLIPSIS
Array(5., dtype=float32...)

Source code in src/genjax/core.py
@dataclass_transform(
    frozen_default=True,
)
@staticmethod
def dataclass(
    incoming: type[R] | None = None,
    /,
    **kwargs,
) -> type[R] | Callable[[type[R]], type[R]]:
    """
    Denote that a class (which is inheriting `Pytree`) should be treated
    as a dataclass, meaning it can hold data in fields which are
    declared as part of the class.

    A dataclass is to be distinguished from a "methods only"
    `Pytree` class, which does not have fields, but may define methods.
    The latter cannot be instantiated, but can be inherited from,
    while the former can be instantiated:
    the `Pytree.dataclass` declaration informs the system _how
    to instantiate_ the class as a dataclass,
    and how to automatically define JAX's `Pytree` interfaces
    (`tree_flatten`, `tree_unflatten`, etc.) for the dataclass, based
    on the fields declared in the class, and possibly `Pytree.static(...)`
    or `Pytree.field(...)` annotations (or lack thereof, the default is
    that all fields are `Pytree.field(...)`).

    All `Pytree` dataclasses support pretty printing, as well as rendering
    to HTML.

    Examples
    --------

    >>> from genjax import Pytree
    >>> from jaxtyping import ArrayLike
    >>> import jax.numpy as jnp
    >>>
    >>> @Pytree.dataclass
    ... class MyClass(Pytree):
    ...     my_static_field: int = Pytree.static()
    ...     my_dynamic_field: ArrayLike
    >>>
    >>> instance = MyClass(10, jnp.array(5.0))
    >>> instance.my_static_field
    10
    >>> instance.my_dynamic_field  # doctest: +ELLIPSIS
    Array(5., dtype=float32...)
    """

    return pz.pytree_dataclass(
        incoming,
        overwrite_parent_init=True,
        **kwargs,
    )

static staticmethod

static(**kwargs)

Declare a field of a Pytree dataclass to be static. Users can provide additional keyword arguments, such as default or default_factory, to customize how the field is initialized when an instance of the dataclass is created. Fields with default values must come after required fields in the dataclass declaration.

Examples

>>> from genjax import Pytree
>>> from jaxtyping import ArrayLike
>>> import jax.numpy as jnp
>>>
>>> @Pytree.dataclass
... class MyClass(Pytree):
...     my_dynamic_field: ArrayLike
...     my_static_field: int = Pytree.static(default=0)
>>>
>>> instance = MyClass(jnp.array(5.0))
>>> instance.my_static_field
0
>>> instance.my_dynamic_field  # doctest: +ELLIPSIS
Array(5., dtype=float32...)
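The ordering rule is the standard Python dataclass rule. The sketch below assumes that `Pytree.dataclass` inherits the rule from the standard-library machinery it builds on; that is an assumption, since the source delegates to `pz.pytree_dataclass`. Using only `dataclasses`, it shows the valid order and the `TypeError` raised by the invalid one. `Good`, `Bad`, and `declare_bad` are illustrative names.

```python
from dataclasses import dataclass, field


# Valid: the required field comes first, the defaulted (static-style) field second.
@dataclass(frozen=True)
class Good:
    my_dynamic_field: float
    my_static_field: int = field(default=0, metadata={"pytree_node": False})


def declare_bad():
    # Invalid: a required field follows a field that has a default.
    # The dataclass decorator rejects this when the class is created.
    @dataclass(frozen=True)
    class Bad:
        my_static_field: int = field(default=0, metadata={"pytree_node": False})
        my_dynamic_field: float

    return Bad


print(Good(5.0))  # Good(my_dynamic_field=5.0, my_static_field=0)
try:
    declare_bad()
except TypeError as err:
    print("TypeError:", err)
```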

Source code in src/genjax/core.py
@staticmethod
def static(**kwargs):
    """Declare a field of a `Pytree` dataclass to be static.
    Users can provide additional keyword argument options,
    like `default` or `default_factory`, to customize how the field is
    instantiated when an instance of
    the dataclass is instantiated. Fields which are provided with default
    values must come after required fields in the dataclass declaration.

    Examples
    --------

    >>> from genjax import Pytree
    >>> from jaxtyping import ArrayLike
    >>> import jax.numpy as jnp
    >>>
    >>> @Pytree.dataclass
    ... class MyClass(Pytree):
    ...     my_dynamic_field: ArrayLike
    ...     my_static_field: int = Pytree.static(default=0)
    >>>
    >>> instance = MyClass(jnp.array(5.0))
    >>> instance.my_static_field
    0
    >>> instance.my_dynamic_field  # doctest: +ELLIPSIS
    Array(5., dtype=float32...)

    """
    return field(metadata={"pytree_node": False}, **kwargs)

field staticmethod

field(**kwargs)

Declare a field of a Pytree dataclass to be dynamic. Alternatively, omit Pytree.field(...) from the declaration; fields without an annotation are dynamic by default.

Source code in src/genjax/core.py
@staticmethod
def field(**kwargs):
    """Declare a field of a `Pytree` dataclass to be dynamic.
    Alternatively, one can leave the annotation off in the declaration."""
    return field(**kwargs)

Const

Bases: Generic[A], Pytree

A Pytree wrapper for Python literals that should remain static.

This class wraps Python values that need to stay as literals (not become JAX tracers) when used inside JAX transformations. The wrapped value is marked as static, ensuring it's embedded in the PyTreeDef rather than becoming a traced value.

Example
from genjax import const
from genjax.core import Const

# Instead of: n_steps: int (becomes a tracer inside JAX transforms)
# Use: n_steps: Const[int] (stays a Python int)

def my_function(n_steps: Const[int]):
    for i in range(n_steps.value):  # n_steps.value is a Python int
        ...

my_function(const(10))

NotFixedException

NotFixedException(choice_map_status: X)

Bases: Exception

Exception raised when trace verification finds unfixed values.

This exception provides a clear visualization of which parts of the choice map are properly fixed (True) vs. unfixed (False), helping users debug model structure issues during constrained inference.

Source code in src/genjax/core.py
def __init__(self, choice_map_status: X):
    self.choice_map_status = choice_map_status
    super().__init__(self._format_message())

Fixed

Bases: Generic[A], Pytree

A Pytree wrapper that denotes a random choice was provided (fixed), not proposed by a GFI's internal proposal family.

This wrapper is used internally by Distribution implementations in generate, update, and regenerate methods to mark values that were constrained or provided externally rather than sampled from the distribution's internal proposal.

The Fixed wrapper helps debug model structure issues during inference by tracking which random choices were externally constrained vs. internally proposed.

Trace

Bases: Generic[X, R], Pytree

get_fixed_choices abstractmethod

get_fixed_choices() -> X

Get choices preserving Fixed wrappers.

Returns the raw choice structure with Fixed wrappers intact, used for verification that values were constrained during inference.

Source code in src/genjax/core.py
@abstractmethod
def get_fixed_choices(self) -> X:
    """Get choices preserving Fixed wrappers.

    Returns the raw choice structure with Fixed wrappers intact,
    used for verification that values were constrained during inference.
    """
    pass

verify

verify() -> None

Verify that all leaf values in the trace choices were fixed (constrained).

Checks that all random choices in the trace are wrapped with Fixed, indicating they were provided externally rather than proposed by the GFI's internal proposal family. This helps debug model structure issues during inference.

Source code in src/genjax/core.py
def verify(self) -> None:
    """Verify that all leaf values in the trace choices were fixed (constrained).

    Checks that all random choices in the trace are wrapped with Fixed,
    indicating they were provided externally rather than proposed by the
    GFI's internal proposal family. This helps debug model structure issues
    during inference.

    Raises:
        NotFixedException: If any leaf value is not wrapped with Fixed.
                          The exception includes a detailed choice map showing
                          which values are fixed vs. unfixed.
    """
    # Get choices preserving Fixed wrappers
    choice_values = get_fixed_choices(self)

    # Check if value is Fixed
    def check_instance_fixed(x):
        return isinstance(x, Fixed)

    # Flatten the tree to get all leaf choice values
    leaf_values, tree_def = jtu.tree_flatten(
        choice_values, is_leaf=check_instance_fixed
    )

    # Check if all leaves are Fixed
    all_fixed = all(isinstance(leaf, Fixed) for leaf in leaf_values)

    if not all_fixed:
        # Create a boolean choice map showing which values are fixed
        def make_bool_status(x):
            if isinstance(x, Fixed):
                return True
            else:
                return False

        choice_map_status = jtu.tree_map(
            make_bool_status, choice_values, is_leaf=check_instance_fixed
        )

        raise NotFixedException(choice_map_status)
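The same check can be sketched without JAX on nested dictionaries. In this toy version, `ToyFixed` is a hypothetical dataclass, and the dict-based choice map is an assumption, since GenJAX choice maps need not be dicts. `ValueError` stands in for `NotFixedException`, and the recursion plays the role of `jtu.tree_flatten(..., is_leaf=...)` and `jtu.tree_map`.

```python
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class ToyFixed:
    """Toy stand-in for the Fixed wrapper (illustrative, not the GenJAX class)."""
    value: Any


def fixed_status(choices: Any) -> Any:
    """Mirror a nested dict of choices, replacing each leaf with True if it is fixed."""
    if isinstance(choices, dict):
        return {k: fixed_status(v) for k, v in choices.items()}
    return isinstance(choices, ToyFixed)


def iter_leaves(tree: Any):
    """Yield the leaves of a nested dict."""
    if isinstance(tree, dict):
        for v in tree.values():
            yield from iter_leaves(v)
    else:
        yield tree


def toy_verify(choices: Any) -> None:
    status = fixed_status(choices)
    if not all(iter_leaves(status)):
        # The real method raises NotFixedException carrying a status map like this one.
        raise ValueError(f"unfixed choices: {status}")


toy_verify({"x": ToyFixed(1.0), "inner": {"y": ToyFixed(2.0)}})  # passes silently
try:
    toy_verify({"x": ToyFixed(1.0), "inner": {"y": 2.0}})
except ValueError as err:
    print(err)  # unfixed choices: {'x': True, 'inner': {'y': False}}
```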

Tr

Bases: Trace[X, R], Pytree

Concrete implementation of the Trace interface.

Stores all components of an execution trace including the generative function, arguments, random choices, return value, and score.

get_fixed_choices

get_fixed_choices() -> X

Get choices preserving Fixed wrappers.

Source code in src/genjax/core.py
def get_fixed_choices(self) -> X:
    """Get choices preserving Fixed wrappers."""
    return get_fixed_choices(self._choices)

AllSel

Bases: Pytree

Selection that matches all addresses.

NoneSel

Bases: Pytree

Selection that matches no addresses.

StrSel

Bases: Pytree

Selection that matches a specific string address.

TupleSel

Bases: Pytree

Selection that matches a hierarchical tuple address.

Tuple addresses represent hierarchical paths like ("outer", "inner", "leaf"). When matched against a single string address, it checks if that string matches the first element of the tuple, and returns a selection for the remaining path.
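One matching step can be sketched as a pure function. `match_tuple` is hypothetical and is not GenJAX's `TupleSel`. It only mirrors the `(matched, remaining_selection)` pair shape that `selection.match(...)` returns in the `Distribution.filter` source, with the remaining path standing in for the returned sub-selection.

```python
from typing import Optional, Tuple

Path = Tuple[str, ...]


def match_tuple(path: Path, addr: str) -> Tuple[bool, Optional[Path]]:
    """Toy step of hierarchical matching (hypothetical, not GenJAX's TupleSel).

    If `addr` equals the head of `path`, report a match and return the rest of
    the path as the selection to use one level deeper; otherwise, no match.
    """
    if path and path[0] == addr:
        return True, path[1:]
    return False, None


path = ("outer", "inner", "leaf")
print(match_tuple(path, "outer"))  # (True, ('inner', 'leaf'))
print(match_tuple(path, "other"))  # (False, None)

# Walk the hierarchy one address component at a time.
_, rest = match_tuple(path, "outer")
_, rest = match_tuple(rest, "inner")
print(rest)  # ('leaf',)
```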

DictSel

Bases: Pytree

Selection that matches addresses using a dictionary mapping.

ComplSel

Bases: Pytree

Selection that matches the complement of another selection.

InSel

Bases: Pytree

Selection representing the intersection of two selections.

Selection

Bases: Pytree

A Selection acts as a filter to specify which random choices in a trace should be regenerated during the regenerate method call.

Selections are used in inference algorithms like MCMC to specify which subset of random choices should be updated while keeping others fixed. The Selection determines which addresses (random choice names) match the selection criteria.

A Selection wraps one of several concrete selection types:

  • StrSel: Matches a specific string address
  • DictSel: Matches addresses using a dictionary mapping
  • AllSel: Matches all addresses
  • NoneSel: Matches no addresses
  • ComplSel: Matches the complement of another selection
  • InSel: Matches the intersection of two selections
  • OrSel: Matches the union of two selections

Example
from genjax.core import sel

# Select a specific address
selection = sel("x")  # Matches address "x"

# Select all addresses
selection = sel(())   # Matches all addresses

# Select nested addresses
selection = sel({"outer": sel("inner")})  # Matches "outer"/"inner"

# Use in regenerate (assumes gen_fn, trace, and args are already defined)
new_trace, weight, discard = gen_fn.regenerate(trace, selection, *args)
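The combinator semantics in the list of selection types can be modeled with a few toy classes. Every `Toy*` class here is hypothetical and only answers "is this single string address selected?". Real GenJAX selections also handle hierarchical addresses and return sub-selections. The `|` operator mirrors the `sel("x") | sel("y")` usage shown in the regenerate docs.

```python
from dataclasses import dataclass


class ToySel:
    """Base class for the toy selections; `|` builds a union."""

    def matches(self, addr: str) -> bool:
        raise NotImplementedError

    def __or__(self, other: "ToySel") -> "ToySel":
        return ToyOr(self, other)


@dataclass(frozen=True)
class ToyAll(ToySel):
    def matches(self, addr: str) -> bool:
        return True


@dataclass(frozen=True)
class ToyNone(ToySel):
    def matches(self, addr: str) -> bool:
        return False


@dataclass(frozen=True)
class ToyStr(ToySel):
    name: str

    def matches(self, addr: str) -> bool:
        return addr == self.name


@dataclass(frozen=True)
class ToyCompl(ToySel):
    inner: ToySel

    def matches(self, addr: str) -> bool:
        return not self.inner.matches(addr)


@dataclass(frozen=True)
class ToyIn(ToySel):
    left: ToySel
    right: ToySel

    def matches(self, addr: str) -> bool:
        return self.left.matches(addr) and self.right.matches(addr)


@dataclass(frozen=True)
class ToyOr(ToySel):
    left: ToySel
    right: ToySel

    def matches(self, addr: str) -> bool:
        return self.left.matches(addr) or self.right.matches(addr)


addrs = ["x", "y", "z"]
xy = ToyStr("x") | ToyStr("y")
print([a for a in addrs if xy.matches(a)])                                 # ['x', 'y']
print([a for a in addrs if ToyCompl(ToyStr("x")).matches(a)])              # ['y', 'z']
print([a for a in addrs if ToyIn(xy, ToyCompl(ToyStr("y"))).matches(a)])   # ['x']
```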

GFI

Bases: Generic[X, R], Pytree

Generative Function Interface - the core abstraction for probabilistic programs.

The GFI defines the standard interface that all generative functions must implement. It provides methods for simulation, assessment, generation, updating, and regeneration of probabilistic computations.

Mathematical Foundation: A generative function bundles three mathematical objects:

  1. Measure kernel P(dx; args): the probability distribution over choices
  2. Return value function f(x, args) -> R: deterministic computation from choices
  3. Internal proposal family Q(dx; args, context): for efficient inference

The GFI methods provide access to these mathematical objects and enable:

  • Forward sampling (simulate)
  • Density evaluation (assess)
  • Constrained generation (generate)
  • Edit moves (update, regenerate)

All density computations are in log space for numerical stability. Weights from generate/update/regenerate enable importance sampling and MCMC.

Type Parameters

  • X: The type of the random choices (choice map).
  • R: The type of the return value.

Core Methods

  • simulate: Sample (choices, retval) ~ P(·; args)
  • assess: Compute log P(choices; args)
  • generate: Sample with constraints, return importance weight
  • update: Update trace arguments/choices, return incremental importance weight
  • regenerate: Resample selected choices, return incremental importance weight

Additional Methods

  • merge: Combine choice maps (for compositional functions)
  • log_density: Convenience method for assess that sums log densities
  • vmap/repeat: Vectorization combinators
  • cond: Conditional execution combinator
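To make the simulate/assess contract concrete, here is a pure-Python sketch that is not GenJAX code. `ToyNormal`, `ToyTrace`, and the explicit `random.Random` argument are hypothetical; GenJAX runs on JAX and manages randomness differently. The sketch checks the documented relationship between the two methods: the score of a simulated trace is the negated log density that assess reports for the same choice.

```python
import math
import random
from dataclasses import dataclass


def normal_logpdf(x: float, mu: float, sigma: float) -> float:
    """Log density of N(mu, sigma) at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


@dataclass(frozen=True)
class ToyTrace:
    choice: float
    retval: float
    score: float  # -log P(choice; args), the sign convention documented for simulate
    args: tuple


class ToyNormal:
    """Hypothetical pure-Python generative function with a single choice x ~ N(mu, sigma)."""

    def simulate(self, mu: float, sigma: float, rng: random.Random) -> ToyTrace:
        x = rng.gauss(mu, sigma)
        return ToyTrace(choice=x, retval=x, score=-normal_logpdf(x, mu, sigma), args=(mu, sigma))

    def assess(self, x: float, mu: float, sigma: float):
        # Returns (log density, return value); for a distribution the retval is x itself.
        return normal_logpdf(x, mu, sigma), x


model = ToyNormal()
tr = model.simulate(0.0, 1.0, random.Random(0))
log_density, retval = model.assess(tr.choice, *tr.args)
print(math.isclose(tr.score, -log_density))  # True
```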

simulate abstractmethod

simulate(*args, **kwargs) -> Trace[X, R]

Sample an execution trace from the generative function.

Mathematical specification:

  • Samples (choices, retval) ~ P(·; args), where P is the generative function's measure kernel
  • Returns a trace containing the choices, return value, score, and arguments

The score in the returned trace is log(1/P(choices; args)), i.e., the negative log probability density of the sampled choices.

Example
trace = model.simulate(mu, sigma)  # assumes model, mu, and sigma are defined
choices = trace.get_choices()
score = trace.get_score()  # -log P(choices; mu, sigma)

Source code in src/genjax/core.py
@abstractmethod
def simulate(
    self,
    *args,
    **kwargs,
) -> Trace[X, R]:
    """Sample an execution trace from the generative function.

    Mathematical specification:
    - Samples (choices, retval) ~ P(·; args) where P is the generative function's measure kernel
    - Returns trace containing choices, return value, score, and arguments

    The score in the returned trace is log(1/P(choices; args)), i.e., the negative
    log probability density of the sampled choices.

    Args:
        *args: Arguments to the generative function.
        **kwargs: Keyword arguments to the generative function.

    Returns:
        A trace containing the sampled choices, return value, score, and arguments.

    Example:
        >>> # trace = model.simulate(mu, sigma)  # Example usage
        >>> # choices = trace.get_choices()
        >>> # score = trace.get_score()  # -log P(choices; mu, sigma)
        >>> pass  # doctest placeholder
    """
    pass

generate abstractmethod

generate(x: X | None, *args, **kwargs) -> tuple[Trace[X, R], Weight]

Generate a trace with optional constraints on some choices.

Mathematical specification:

  • Samples unconstrained choices ~ Q(·; constrained_choices, args)
  • Computes importance weight: log [P(all_choices; args) / Q(unconstrained_choices; constrained_choices, args)]
  • When x=None, equivalent to simulate() but returns weight=0

The weight enables importance sampling and is crucial for inference algorithms. For fully constrained generation, the weight equals the log density.
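The three regimes of the weight can be seen in a toy two-choice model. Everything here is illustrative, not GenJAX code: `toy_generate` and the model x ~ N(0, 1), y ~ N(x, 1) are hypothetical. The sketch assumes that every unconstrained choice is proposed from its prior, as the Distribution docs state for Q. Each proposed choice then contributes P / Q = 1, and each constrained choice contributes its log density.

```python
import math
import random


def normal_logpdf(x, mu, sigma):
    """Log density of N(mu, sigma) at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def toy_generate(constraints, rng):
    """Toy generate for x ~ N(0, 1), y ~ N(x, 1); returns (choices, weight).

    weight = log P(all choices) - log Q(free choices | constrained choices),
    with Q taken to be the prior for every free choice.
    """
    constraints = constraints or {}
    weight = 0.0
    if "x" in constraints:
        x = constraints["x"]
        weight += normal_logpdf(x, 0.0, 1.0)  # constrained: contributes log P(x)
    else:
        x = rng.gauss(0.0, 1.0)  # proposed from the prior: P(x) / Q(x) = 1
    if "y" in constraints:
        y = constraints["y"]
        weight += normal_logpdf(y, x, 1.0)  # constrained: contributes log P(y | x)
    else:
        y = rng.gauss(x, 1.0)
    return {"x": x, "y": y}, weight


rng = random.Random(0)
_, w_none = toy_generate(None, rng)                   # like simulate: weight 0
_, w_full = toy_generate({"x": 0.5, "y": 1.0}, rng)   # log P(x, y): the log density
choices, w_part = toy_generate({"y": 1.0}, rng)       # log P(y | x) for the proposed x
print(w_none)  # 0.0
```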

Example
# Constrain some choices
constraints = {"x": 1.5, "y": 2.0}
trace, weight = model.generate(constraints, mu, sigma)
# weight accounts for the probability of the constrained choices

Source code in src/genjax/core.py
@abstractmethod
def generate(
    self,
    x: X | None,
    *args,
    **kwargs,
) -> tuple[Trace[X, R], Weight]:
    """Generate a trace with optional constraints on some choices.

    Mathematical specification:
    - Samples unconstrained choices ~ Q(·; constrained_choices, args)
    - Computes importance weight: log [P(all_choices; args) / Q(unconstrained_choices; constrained_choices, args)]
    - When x=None, equivalent to simulate() but returns weight=0

    The weight enables importance sampling and is crucial for inference algorithms.
    For fully constrained generation, the weight equals the log density.

    Args:
        x: Optional constraints on subset of choices. If None, equivalent to simulate.
        *args: Arguments to the generative function.
        **kwargs: Keyword arguments to the generative function.

    Returns:
        A tuple (trace, weight) where:
        - trace: contains all choices (constrained + sampled) and return value
        - weight: log [P(all_choices; args) / Q(unconstrained_choices; constrained_choices, args)]

    Example:
        >>> # Constrain some choices
        >>> # constraints = {"x": 1.5, "y": 2.0}
        >>> # trace, weight = model.generate(constraints, mu, sigma)
        >>> # weight accounts for probability of constrained choices
        >>> pass  # doctest placeholder
    """
    pass

assess abstractmethod

assess(x: X, *args, **kwargs) -> tuple[Density, R]

Compute the log probability density of given choices.

Mathematical specification:

  • Computes log P(choices; args), where P is the generative function's measure kernel
  • Also computes the return value for the given choices
  • Requires P(choices; args) > 0 (choices must be valid)

Example
log_density, retval = model.assess(choices, mu, sigma)
# log_density = log P(choices; mu, sigma)

Source code in src/genjax/core.py
@abstractmethod
def assess(
    self,
    x: X,
    *args,
    **kwargs,
) -> tuple[Density, R]:
    """Compute the log probability density of given choices.

    Mathematical specification:
    - Computes log P(choices; args) where P is the generative function's measure kernel
    - Also computes the return value for the given choices
    - Requires P(choices; args) > 0 (choices must be valid)

    Args:
        x: The choices to evaluate.
        *args: Arguments to the generative function.
        **kwargs: Keyword arguments to the generative function.

    Returns:
        A tuple (log_density, retval) where:
        - log_density: log P(choices; args)
        - retval: return value computed with the given choices

    Example:
        >>> # log_density, retval = model.assess(choices, mu, sigma)
        >>> # log_density = log P(choices; mu, sigma)
        >>> pass  # doctest placeholder
    """
    pass

update abstractmethod

update(tr: Trace[X, R], x_: X | None, *args, **kwargs) -> tuple[Trace[X, R], Weight, X | None]

Update a trace with new arguments and/or choice constraints.

Mathematical specification:

  • Transforms the trace from (old_args, old_choices) to (new_args, new_choices)
  • Computes an incremental importance weight (edit move):

\[w = \log \frac{P(\text{new\_choices}; \text{new\_args})}{Q(\text{new\_choices}; \text{new\_args}, \text{old\_choices}, \text{constraints})} - \log \frac{P(\text{old\_choices}; \text{old\_args})}{Q(\text{old\_choices}; \text{old\_args})}\]

where Q is the internal proposal distribution used for updating.
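For a single choice that is kept or replaced without resampling, the weight reduces to a ratio of densities. `toy_update` is hypothetical, pure Python, and handles one normal choice only. Like the example for this method, it assumes the proposal terms cancel, which gives log P(new) minus log P(old).

```python
import math


def normal_logpdf(x, mu, sigma):
    """Log density of N(mu, sigma) at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def toy_update(old_x, old_args, new_x, new_args):
    """Toy update for one choice x ~ N(mu, sigma); returns (x, weight, discarded).

    If new_x is None the old choice is kept; otherwise it is replaced and the
    old value is discarded. Nothing is resampled, so this sketch takes the
    Q terms to cancel and the weight is log P(new) - log P(old).
    """
    x = old_x if new_x is None else new_x
    discarded = None if new_x is None else old_x
    weight = normal_logpdf(x, *new_args) - normal_logpdf(old_x, *old_args)
    return x, weight, discarded


# Change only the arguments: keep x = 1.0, move mu from 0.0 to 1.0.
x, w, discarded = toy_update(1.0, (0.0, 1.0), None, (1.0, 1.0))
print(x, round(w, 6), discarded)  # 1.0 0.5 None
```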

Example
# Update trace with new arguments
new_trace, weight, discarded = model.update(old_trace, None, new_mu, new_sigma)
# weight = log P(new_choices; new_args) - log P(old_choices; old_args)

Source code in src/genjax/core.py
@abstractmethod
def update(
    self,
    tr: Trace[X, R],
    x_: X | None,
    *args,
    **kwargs,
) -> tuple[Trace[X, R], Weight, X | None]:
    """Update a trace with new arguments and/or choice constraints.

    Mathematical specification:
    - Transforms trace from (old_args, old_choices) to (new_args, new_choices)
    - Computes incremental importance weight (edit move):

    weight = log [P(new_choices; new_args) / Q(new_choices; new_args, old_choices, constraints)]
           - log [P(old_choices; old_args) / Q(old_choices; old_args)]

    where Q is the internal proposal distribution used for updating.

    Args:
        tr: Current trace to update.
        x_: Optional constraints on choices to enforce during update.
        *args: New arguments to the generative function.
        **kwargs: New keyword arguments to the generative function.

    Returns:
        A tuple (new_trace, weight, discarded_choices) where:
        - new_trace: updated trace with new arguments and choices
        - weight: incremental importance weight for the update (enables MCMC, SMC)
        - discarded_choices: old choice values that were changed

    Example:
        >>> # Update trace with new arguments
        >>> # new_trace, weight, discarded = model.update(old_trace, None, new_mu, new_sigma)
        >>> # weight = log P(new_choices; new_args) - log P(old_choices; old_args)
        >>> pass  # doctest placeholder
    """
    pass

regenerate abstractmethod

regenerate(tr: Trace[X, R], sel: Selection, *args, **kwargs) -> tuple[Trace[X, R], Weight, X | None]

Regenerate selected choices in a trace while keeping others fixed.

Mathematical specification:

  • Resamples choices at addresses selected by 'sel' from their conditional distribution
  • Keeps non-selected choices unchanged
  • Computes an incremental importance weight (edit move):

\[w = \log P(\text{new\_selected\_choices} \mid \text{non\_selected\_choices}; \text{args}) - \log P(\text{old\_selected\_choices} \mid \text{non\_selected\_choices}; \text{args})\]

When sel selects all addresses, regenerate becomes equivalent to simulate. When sel selects no addresses, the weight is 0 and the trace is unchanged.

Example
from genjax.core import sel

# Regenerate choices at addresses "x" and "y"
selection = sel("x") | sel("y")
new_trace, weight, discarded = model.regenerate(trace, selection, mu, sigma)
# weight accounts for the probability change due to regeneration

Source code in src/genjax/core.py
@abstractmethod
def regenerate(
    self,
    tr: Trace[X, R],
    sel: Selection,
    *args,
    **kwargs,
) -> tuple[Trace[X, R], Weight, X | None]:
    """Regenerate selected choices in a trace while keeping others fixed.

    Mathematical specification:
    - Resamples choices at addresses selected by 'sel' from their conditional distribution
    - Keeps non-selected choices unchanged
    - Computes incremental importance weight (edit move):

    weight = log P(new_selected_choices | non_selected_choices; args)
           - log P(old_selected_choices | non_selected_choices; args)

    When sel selects all addresses, regenerate becomes equivalent to simulate.
    When sel selects no addresses, weight = 0 and trace unchanged.

    Args:
        tr: Current trace to regenerate from.
        sel: Selection specifying which addresses to regenerate.
        *args: Arguments to the generative function.
        **kwargs: Keyword arguments to the generative function.

    Returns:
        A tuple (new_trace, weight, discarded_choices) where:
        - new_trace: trace with selected choices resampled
        - weight: incremental importance weight for the regeneration
        - discarded_choices: old values of the regenerated choices

    Example:
        >>> # Regenerate choices at addresses "x" and "y"
        >>> # selection = sel("x") | sel("y")
        >>> # new_trace, weight, discarded = model.regenerate(trace, selection, mu, sigma)
        >>> # weight accounts for probability change due to regeneration
        >>> pass  # doctest placeholder
    """
    pass

merge abstractmethod

merge(x: X, x_: X, check: ndarray | None = None) -> tuple[X, X | None]

Merge two choice maps, with the second taking precedence.

Used internally for compositional generative functions where choice maps from different components need to be combined. The merge operation resolves conflicts by preferring choices from x_ over x.
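On flat dictionaries, the precedence rule and the discard can be sketched as follows. `toy_merge` is hypothetical and is not the GenJAX `merge`. Real choice maps are JAX pytrees, and `check` is an array there, used with `jnp.where` as in the `Distribution.merge` source. Here `check` is a plain boolean, and both maps must share the same keys when it is given.

```python
def toy_merge(x, x_, check=None):
    """Toy merge for flat dict choice maps; returns (merged, discarded).

    Without `check`, values in x_ win on conflicting addresses and the
    overridden values of x are returned as the discard. With a boolean
    `check` (a scalar here, an array in GenJAX), x is taken where check is
    True and x_ where it is False; both maps must then share the same keys.
    """
    if check is not None:
        return {k: (x[k] if check else x_[k]) for k in x}, None
    merged = {**x, **x_}
    discarded = {k: v for k, v in x.items() if k in x_}
    return merged, (discarded or None)


merged, discarded = toy_merge({"mu": 0.0, "sigma": 1.0}, {"mu": 2.0})
print(merged)     # {'mu': 2.0, 'sigma': 1.0}
print(discarded)  # {'mu': 0.0}
```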

Source code in src/genjax/core.py
@abstractmethod
def merge(
    self, x: X, x_: X, check: jnp.ndarray | None = None
) -> tuple[X, X | None]:
    """Merge two choice maps, with the second taking precedence.

    Used internally for compositional generative functions where choice maps
    from different components need to be combined. The merge operation resolves
    conflicts by preferring choices from x_ over x.

    Args:
        x: First choice map.
        x_: Second choice map (takes precedence in conflicts).
        check: Optional boolean array for conditional selection.
               If provided, selects x where True, x_ where False.

    Returns:
        Tuple of (merged choice map, discarded values).
        - merged: Combined choices with x_ values overriding x values at conflicts
        - discarded: Values from x that were overridden by x_ (None if no conflicts)
    """
    pass

filter abstractmethod

filter(x: X, selection: Selection) -> tuple[X | None, X | None]

Filter choice map into selected and unselected parts.

Used to partition choices based on a selection, enabling fine-grained manipulation of subsets of choices in inference algorithms. Each GFI implementation specializes this method for its choice type X.
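A minimal sketch of the partition on flat dictionaries follows. `toy_filter` is hypothetical, and a plain predicate on address strings stands in for a `Selection`. As the docstring describes, an empty part is reported as `None`.

```python
def toy_filter(choices, is_selected):
    """Split a flat dict choice map into (selected, unselected).

    `is_selected` is a predicate standing in for a Selection; an empty part
    is returned as None.
    """
    selected = {k: v for k, v in choices.items() if is_selected(k)}
    unselected = {k: v for k, v in choices.items() if not is_selected(k)}
    return (selected or None), (unselected or None)


choices = {"mu": 1.0, "sigma": 2.0, "obs": 3.0}
selected, unselected = toy_filter(choices, lambda addr: addr in {"mu", "sigma"})
print(selected)    # {'mu': 1.0, 'sigma': 2.0}
print(unselected)  # {'obs': 3.0}
```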

Example
choices = {"mu": 1.0, "sigma": 2.0, "obs": 3.0}
selection = sel("mu") | sel("sigma")
selected, unselected = model.filter(choices, selection)
# selected = {"mu": 1.0, "sigma": 2.0}, unselected = {"obs": 3.0}

Source code in src/genjax/core.py
@abstractmethod
def filter(self, x: X, selection: "Selection") -> tuple[X | None, X | None]:
    """Filter choice map into selected and unselected parts.

    Used to partition choices based on a selection, enabling fine-grained manipulation
    of subsets of choices in inference algorithms. Each GFI implementation specializes
    this method for its choice type X.

    Args:
        x: Choice map to filter.
        selection: Selection specifying which addresses to include.

    Returns:
        Tuple of (selected_choices, unselected_choices) where:
        - selected_choices: Choice map containing only selected addresses, or None if no matches
        - unselected_choices: Choice map containing only unselected addresses, or None if no matches
        Both have the same structure as X but contain disjoint subsets of addresses.

    Example:
        >>> # choices = {"mu": 1.0, "sigma": 2.0, "obs": 3.0}
        >>> # selection = sel("mu") | sel("sigma")
        >>> # selected, unselected = model.filter(choices, selection)
        >>> # selected = {"mu": 1.0, "sigma": 2.0}, unselected = {"obs": 3.0}
        >>> pass  # doctest placeholder
    """
    pass

Thunk

Bases: Generic[X, R], Pytree

Delayed evaluation wrapper for generative functions.

A thunk represents a generative function call that has not yet been executed. It captures the function and its arguments for later evaluation.
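Delayed evaluation amounts to storing a callable with its arguments and calling it later. `ToyThunk` and its `force` method are hypothetical names; this page does not document the fields or methods of the real `Thunk` class.

```python
from dataclasses import dataclass, field
from typing import Any, Callable


@dataclass(frozen=True)
class ToyThunk:
    """Hypothetical delayed call: holds a function and its arguments until forced."""

    fn: Callable[..., Any]
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)

    def force(self) -> Any:
        # Only now is the captured function actually called.
        return self.fn(*self.args, **self.kwargs)


calls = []


def add(a, b):
    calls.append((a, b))  # record each real execution
    return a + b


thunk = ToyThunk(add, (1, 2))
print(calls)          # []
print(thunk.force())  # 3
print(calls)          # [(1, 2)]
```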

Vmap

Bases: Generic[X, R], GFI[X, R]

A Vmap is a generative function combinator that vectorizes another generative function.

Vmap applies a generative function across a batch dimension, similar to jax.vmap, but preserves probabilistic semantics. It uses GenJAX's modular_vmap to handle the vectorization of probabilistic computations correctly.

Mathematical ingredients:

  • If the callee has measure kernel P_callee(dx; args), then Vmap has kernel P_vmap(dX; Args) = ∏_i P_callee(dx_i; args_i), where X = [x_1, ..., x_n]
  • Return value function f_vmap(X, Args) = [f_callee(x_1, args_1), ..., f_callee(x_n, args_n)]
  • Internal proposal family inherits from the callee's proposal family
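Because the Vmap kernel is a product of independent callee kernels, its log density is a sum of per-element log densities. The pure-Python sketch below uses a hypothetical `vmapped_logpdf`. It maps over `mus` and shares `sigma`, which loosely mirrors `in_axes=(0, None)`; it does not use GenJAX's `modular_vmap`.

```python
import math


def normal_logpdf(x, mu, sigma):
    """Log density of N(mu, sigma) at x."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def vmapped_logpdf(xs, mus, sigma):
    """Log density of independent draws x_i ~ N(mu_i, sigma).

    mus is mapped over and sigma is shared. The product of per-element
    kernels becomes a sum of log densities.
    """
    if len(xs) != len(mus):
        raise ValueError("xs and mus must have the same batch size")
    return sum(normal_logpdf(x, mu, sigma) for x, mu in zip(xs, mus))


print(vmapped_logpdf([0.1, 1.2, 1.9], [0.0, 1.0, 2.0], 1.0))
```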

Example

import jax.numpy as jnp
from genjax import normal

# Vectorize a normal distribution over its first argument (mu)
vectorized_normal = normal.vmap(in_axes=(0, None))

mus = jnp.array([0.0, 1.0, 2.0])
sigma = 1.0
trace = vectorized_normal.simulate(mus, sigma)
samples = trace.get_choices()  # Array of 3 normal samples

filter

filter(x: X, selection: Selection) -> tuple[X | None, X | None]

Filter vectorized choices using the underlying generative function's filter.

For Vmap, choices are vectorized across the batch dimension. We apply the underlying GF's filter to each vectorized choice.

Source code in src/genjax/core.py
def filter(self, x: X, selection: "Selection") -> tuple[X | None, X | None]:
    """Filter vectorized choices using the underlying generative function's filter.

    For Vmap, choices are vectorized across the batch dimension. We apply
    the underlying GF's filter to each vectorized choice.

    Args:
        x: Vectorized choice to filter.
        selection: Selection specifying which addresses to include.

    Returns:
        Tuple of (selected_choices, unselected_choices) where each is vectorized or None.
    """
    # Use modular_vmap to apply filter across the batch dimension
    selected, unselected = modular_vmap(
        self.gen_fn.filter,
        in_axes=(0, None),
        axis_size=self.axis_size.value,
    )(x, selection)

    return selected, unselected

Distribution

Bases: Generic[X], GFI[X, X]

A Distribution is a generative function that implements a probability distribution.

Distributions are the fundamental building blocks of probabilistic programs. They implement the Generative Function Interface (GFI) by wrapping a sampling function and a log probability density function (logpdf).

Mathematical ingredients:

  • A measure kernel P(dx; args) over a measurable space X given arguments args
  • Return value function f(x, args) = x (identity function for distributions)
  • Internal proposal distribution family Q(dx; args, x') = P(dx; args) (the prior)

Example

>>> import jax
>>> import jax.numpy as jnp
>>> from genjax import Distribution, const
>>>
>>> # Create a custom normal distribution
>>> def sample_normal(mu, sigma):
...     key = jax.random.PRNGKey(0)  # In practice, use proper key management
...     return mu + sigma * jax.random.normal(key)
>>>
>>> def logpdf_normal(x, mu, sigma):
...     return -0.5 * ((x - mu) / sigma)**2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)
>>>
>>> normal = Distribution(const(sample_normal), const(logpdf_normal), const("normal"))
>>> trace = normal.simulate(0.0, 1.0)  # mu=0.0, sigma=1.0

sample

sample(*args, **kwargs) -> X

Sample from the distribution.

Source code in src/genjax/core.py
def sample(self, *args, **kwargs) -> X:
    """Sample from the distribution."""
    return self._sample.value(*args, **kwargs)

logpdf

logpdf(x: X, *args, **kwargs) -> Weight

Compute log probability density.

Source code in src/genjax/core.py
def logpdf(self, x: X, *args, **kwargs) -> Weight:
    """Compute log probability density."""
    return self._logpdf.value(x, *args, **kwargs)

merge

merge(x: X, x_: X, check: ndarray | None = None) -> tuple[X, X | None]

Merge distribution choices with optional conditional selection.

For distributions, choices are raw values from the sample space. When check is provided, we use jnp.where for conditional selection.

Source code in src/genjax/core.py
def merge(
    self, x: X, x_: X, check: jnp.ndarray | None = None
) -> tuple[X, X | None]:
    """Merge distribution choices with optional conditional selection.

    For distributions, choices are raw values from the sample space.
    When check is provided, we use jnp.where for conditional selection.
    """
    if check is not None:
        # Conditional merge using jnp.where
        merged = jtu.tree_map(lambda v1, v2: jnp.where(check, v1, v2), x, x_)
        # No values are truly "discarded" in conditional selection
        return merged, None
    else:
        # Without check, Distribution doesn't support merge
        raise Exception(
            "Can't merge: the underlying sample space `X` for the type `Distribution` doesn't support merging without a check parameter."
        )
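The sketch below illustrates the conditional merge semantics in plain Python for nested choice dictionaries. It mimics jtu.tree_map(lambda v1, v2: jnp.where(check, v1, v2), x, x_) for a scalar boolean check only. The tree_map_2 helper is hypothetical and handles dicts and leaves, but not arbitrary pytrees or batched check arrays.

```python
def tree_map_2(f, a, b):
    # Apply f leafwise to two structures with identical dict nesting.
    if isinstance(a, dict):
        assert a.keys() == b.keys(), "structures must match"
        return {k: tree_map_2(f, a[k], b[k]) for k in a}
    return f(a, b)

def merge(x, x_, check=None):
    if check is None:
        raise Exception("Can't merge distribution choices without a check parameter.")
    # Scalar analogue of jnp.where(check, v1, v2) applied at every leaf.
    merged = tree_map_2(lambda v1, v2: v1 if check else v2, x, x_)
    return merged, None  # nothing is discarded in conditional selection

left = {"a": 1.0, "b": {"c": 2.0}}
right = {"a": -1.0, "b": {"c": -2.0}}

try:
    merge(left, right)
    raised = False
except Exception:
    raised = True
```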

filter

filter(x: X, selection: Selection) -> tuple[X | None, X | None]

Filter choice into selected and unselected parts.

For Distribution, the choice is a single value X. Selection either matches the empty address () or it doesn't.

Source code in src/genjax/core.py
def filter(self, x: X, selection: "Selection") -> tuple[X | None, X | None]:
    """Filter choice into selected and unselected parts.

    For Distribution, the choice is a single value X. Selection either
    matches the empty address () or it doesn't.

    Args:
        x: Choice value to potentially filter.
        selection: Selection specifying whether to include the choice.

    Returns:
        Tuple of (selected_choice, unselected_choice) where exactly one is x and the other is None.
    """
    is_selected, _ = selection.match(())
    if is_selected:
        return x, None
    else:
        return None, x

Simulate dataclass

Simulate(score: Weight, trace_map: dict[str, Any], parent_fn: GFI = None)

Handler for simulating generative function executions.

Tracks the accumulated score and trace map during simulation.

Fn

Bases: Generic[R], GFI[dict[str, Any], R]

A Fn is a generative function created from a JAX-compatible Python function using the @gen decorator.

Fn implements the GFI by executing the wrapped function in different execution contexts (handlers) that intercept calls to other generative functions via the @ addressing syntax.

Mathematical ingredients:

  • A measure kernel \(P(dx; \text{args})\) defined by the composition of distributions in the function
  • A return value function \(f(x, \text{args})\) defined by the function's logic and return statement
  • An internal proposal family \(Q(dx; \text{args}, x')\) defined by ancestral sampling

The choice space X is a dictionary mapping addresses (strings) to the choices made at those addresses during execution.

Example

import jax.numpy as jnp
from genjax import gen, normal, exponential

@gen
def linear_regression(xs):
    slope = normal(0.0, 1.0) @ "slope"
    intercept = normal(0.0, 1.0) @ "intercept"
    noise = exponential(10.0) @ "noise"  # The observation scale must be positive
    return normal(slope * xs + intercept, noise) @ "y"

trace = linear_regression.simulate(jnp.array([1.0, 2.0, 3.0]))
choices = trace.get_choices()  # dict with keys "slope", "intercept", "noise", "y"

filter

filter(x: dict[str, Any], selection: Selection) -> tuple[dict[str, Any] | None, dict[str, Any] | None]

Filter choice map into selected and unselected parts.

For Fn, choices are stored as dict[str, Any] with string addresses.

Source code in src/genjax/core.py
def filter(
    self, x: dict[str, Any], selection: "Selection"
) -> tuple[dict[str, Any] | None, dict[str, Any] | None]:
    """Filter choice map into selected and unselected parts.

    For Fn, choices are stored as dict[str, Any] with string addresses.

    Args:
        x: Choice dictionary to filter.
        selection: Selection specifying which addresses to include.

    Returns:
        Tuple of (selected_choices, unselected_choices) where each is a dict or None.
    """
    if not x:
        return None, None

    selected = {}
    unselected = {}
    found_selected = False
    found_unselected = False

    for addr, value in x.items():
        is_selected, subselection = selection.match(addr)
        if is_selected:
            if isinstance(value, dict) and subselection is not None:
                # Recursively filter nested choices
                selected_sub, unselected_sub = self.filter(value, subselection)
                if selected_sub is not None:
                    selected[addr] = selected_sub
                    found_selected = True
                if unselected_sub is not None:
                    unselected[addr] = unselected_sub
                    found_unselected = True
            else:
                # Include the entire value in selected
                selected[addr] = value
                found_selected = True
        else:
            # Include the entire value in unselected
            unselected[addr] = value
            found_unselected = True

    return (
        selected if found_selected else None,
        unselected if found_unselected else None,
    )
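To show what Fn.filter returns for nested choices, the plain-Python sketch below re-implements its control flow as a free function. AddrSelection is a hypothetical stand-in that only reproduces the (is_selected, subselection) shape of selection.match; it is not GenJAX's Selection implementation.

```python
class AddrSelection:
    """Hypothetical stand-in for a GenJAX Selection, not its real implementation."""

    def __init__(self, table):
        # Maps each selected address to a nested AddrSelection, or None to take the whole value.
        self.table = table

    def match(self, addr):
        # Same return shape as selection.match(addr) in Fn.filter.
        return addr in self.table, self.table.get(addr)


def filter_choices(x, selection):
    # Mirrors Fn.filter's control flow over plain dicts.
    if not x:
        return None, None
    selected, unselected = {}, {}
    for addr, value in x.items():
        is_selected, sub = selection.match(addr)
        if not is_selected:
            unselected[addr] = value
        elif isinstance(value, dict) and sub is not None:
            # Recurse into nested choices with the subselection.
            sel_sub, unsel_sub = filter_choices(value, sub)
            if sel_sub is not None:
                selected[addr] = sel_sub
            if unsel_sub is not None:
                unselected[addr] = unsel_sub
        else:
            selected[addr] = value
    # An empty side is reported as None, like the found_* flags above.
    return (selected or None), (unselected or None)


choices = {"slope": 0.3, "intercept": -1.2, "obs": {"y0": 0.1, "y1": 0.4}}
selection = AddrSelection({"slope": None, "obs": AddrSelection({"y0": None})})
selected, unselected = filter_choices(choices, selection)
```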

ScanTr

Bases: Generic[X, R], Trace[X, R]

get_fixed_choices

get_fixed_choices() -> X

Get choices preserving Fixed wrappers.

Source code in src/genjax/core.py
def get_fixed_choices(self) -> X:
    """Get choices preserving Fixed wrappers."""
    return self.traces.get_fixed_choices()

Scan

Bases: Generic[X, R], GFI[X, R]

A Scan is a generative function combinator that implements sequential iteration.

Scan repeatedly applies a generative function in a sequential loop, similar to jax.lax.scan, but preserves probabilistic semantics. The callee function should take (carry, x) as input and return (new_carry, output).

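Mathematical ingredients:

  • If the callee has measure kernel \(P_{\text{callee}}(dx; \text{carry}, x)\), then Scan has kernel \(P_{\text{scan}}(dX; \text{init\_carry}, \text{xs}) = \prod_{i=1}^{n} P_{\text{callee}}(dx_i; \text{carry}_i, \text{xs}_i)\), where \(\text{carry}_1 = \text{init\_carry}\) and \(\text{carry}_{i+1} = f_{\text{callee}}(x_i, \text{carry}_i, \text{xs}_i)[0]\)
  • The return value function returns \((\text{final\_carry}, [\text{output}_1, \ldots, \text{output}_n])\)
  • The internal proposal family is inherited from the callee's proposal family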
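Because the Scan kernel is a product of callee kernels, its log density is a sum of per-step log densities, with each step's carry feeding the next. The plain-Python sketch below evaluates that sum for the step function in the example, using fixed noise values. The helpers step_assess and scan_assess are hypothetical illustrations, not GenJAX API.

```python
import math

def normal_logpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def step_assess(carry, x, noise):
    # One callee call from the example: noise ~ N(0, 0.1), new_carry = carry + x + noise.
    new_carry = carry + x + noise
    return (new_carry, new_carry), normal_logpdf(noise, 0.0, 0.1)

def scan_assess(init_carry, xs, noises):
    carry, outputs, score = init_carry, [], 0.0
    for x, noise in zip(xs, noises):
        (carry, out), logp = step_assess(carry, x, noise)  # carry_{i+1} comes from step i
        outputs.append(out)
        score += logp  # product of kernels becomes a sum of log densities
    return (carry, outputs), score

(final_carry, outputs), score = scan_assess(0.0, [1.0, 2.0, 3.0], [0.0, 0.0, 0.0])
```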

Example

from genjax import gen, normal, Scan, seed, const
import jax.numpy as jnp
import jax.random as jrand

@gen
def step(carry, x):
    noise = normal(0.0, 0.1) @ "noise"
    new_carry = carry + x + noise
    return new_carry, new_carry  # output equals new carry

scan_fn = Scan(step, length=const(3))
init_carry = 0.0
xs = jnp.array([1.0, 2.0, 3.0])

# Use seed transformation for PJAX primitives
key = jrand.key(0)
trace = seed(scan_fn.simulate)(key, init_carry, xs)
final_carry, outputs = trace.get_retval()
assert len(outputs) == 3  # Should have 3 outputs

filter

filter(x: X, selection: Selection) -> tuple[X | None, X | None]

Filter scan choices using the underlying generative function's filter.

For Scan, choices are structured according to the scan iterations. We delegate to the underlying callee's filter method.

Source code in src/genjax/core.py
def filter(self, x: X, selection: "Selection") -> tuple[X | None, X | None]:
    """Filter scan choices using the underlying generative function's filter.

    For Scan, choices are structured according to the scan iterations.
    We delegate to the underlying callee's filter method.

    Args:
        x: Scan choice structure to filter.
        selection: Selection specifying which addresses to include.

    Returns:
        Tuple of (selected_choices, unselected_choices) from the underlying callee.
    """
    return self.callee.filter(x, selection)

CondTr

Bases: Generic[X, R], Trace[X, R]

get_fixed_choices

get_fixed_choices() -> X

Get choices preserving Fixed wrappers.

Source code in src/genjax/core.py
def get_fixed_choices(self) -> X:
    """Get choices preserving Fixed wrappers."""
    chm, chm_ = map(lambda tr: tr.get_fixed_choices(), self.trs)

    # Use merge with check parameter for conditional selection
    merged, _ = self.gen_fn.merge(chm, chm_, self.check)
    return merged

Cond

Bases: Generic[X, R], GFI[X, R]

A Cond is a generative function combinator that implements conditional branching.

Cond takes a boolean condition and selects between two generative functions, similar to jax.lax.cond, while preserving probabilistic semantics. To stay compatible with JAX tracing, it evaluates both branches and then selects the result of the branch indicated by the condition.

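Mathematical ingredients:

  • If the branches have measure kernels \(P_{\text{true}}(dx; \text{args})\) and \(P_{\text{false}}(dx; \text{args})\), then Cond has kernel \(P_{\text{cond}}(dx; \text{check}, \text{args}) = P_{\text{true}}(dx; \text{args})\) if check, else \(P_{\text{false}}(dx; \text{args})\)
  • The return value function is \(f_{\text{cond}}(x, \text{check}, \text{args}) = f_{\text{true}}(x, \text{args})\) if check, else \(f_{\text{false}}(x, \text{args})\)
  • The internal proposal family selects the appropriate branch's proposal based on the condition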

Note: Both branches are always evaluated during simulation/generation to maintain JAX compatibility, but only the appropriate branch contributes to the final result.
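The "evaluate both, keep one" strategy can be illustrated in plain Python with the two branches from the example below. Both branch scores are computed, but only the selected one is returned. This sketch assumes that exponential's argument is a rate, which is an assumption rather than something stated on this page, and cond_assess is a hypothetical helper, not GenJAX API.

```python
import math

def exponential_logpdf(x, rate):
    # Log density of Exponential(rate) at x >= 0 (rate parameterization assumed).
    return math.log(rate) - rate * x

def cond_assess(check, value):
    # Evaluate both branches on the same choice, as Cond does for JAX compatibility...
    score_true = exponential_logpdf(value, 1.0)   # positive_branch
    score_false = exponential_logpdf(value, 2.0)  # negative_branch
    # ...then keep only the selected branch's contribution (a scalar jnp.where analogue).
    return score_true if check else score_false
```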

Example

from genjax import gen, normal, exponential, Cond

@gen
def positive_branch():
    return exponential(1.0) @ "value"

@gen
def negative_branch():
    return exponential(2.0) @ "value"

cond_fn = Cond(positive_branch, negative_branch)

# Use in a larger model
@gen
def conditional_model():
    x = normal(0.0, 1.0) @ "x"
    condition = x > 0
    result = cond_fn((condition,)) @ "conditional"
    return result

filter

filter(x: X, selection: Selection) -> tuple[X | None, X | None]

Filter conditional choices using the underlying generative function's filter.

For Cond, choices are determined by which branch was executed. We delegate to the first callee's filter method.

Source code in src/genjax/core.py
def filter(self, x: X, selection: "Selection") -> tuple[X | None, X | None]:
    """Filter conditional choices using the underlying generative function's filter.

    For Cond, choices are determined by which branch was executed.
    We delegate to the first callee's filter method.

    Args:
        x: Conditional choice structure to filter.
        selection: Selection specifying which addresses to include.

    Returns:
        Tuple of (selected_choices, unselected_choices) from the underlying callee.
    """
    return self.callee.filter(x, selection)

const

const(a: A) -> Const[A]

Create a Const wrapper for a static value.

Example
from genjax import const, Const

# Create a static value
length = const(10)
print(f"Value: {length.value}")
print(f"Type: {type(length)}")

# Use in arithmetic
doubled = length * 2
print(f"Doubled: {doubled.value}")

# Use as static parameter
# @gen
# def model(n: Const[int]):
#     # n.value is guaranteed to be Python int, not JAX tracer
#     for i in range(n.value):  # This works in JAX transforms!
#         ...
Source code in src/genjax/core.py
def const(a: A) -> Const[A]:
    """Create a Const wrapper for a static value.

    Args:
        a: The Python literal to wrap as static.

    Returns:
        A Const wrapper that keeps the value static in JAX transformations.

    Example:
        ```python
        from genjax import const, Const

        # Create a static value
        length = const(10)
        print(f"Value: {length.value}")
        print(f"Type: {type(length)}")

        # Use in arithmetic
        doubled = length * 2
        print(f"Doubled: {doubled.value}")

        # Use as static parameter
        # @gen
        # def model(n: Const[int]):
        #     # n.value is guaranteed to be Python int, not JAX tracer
        #     for i in range(n.value):  # This works in JAX transforms!
        #         ...
        ```
    """
    return Const(a)

fixed

fixed(a: A) -> Fixed[A]

Create a Fixed wrapper for a constrained value.

Source code in src/genjax/core.py
def fixed(a: A) -> Fixed[A]:
    """Create a Fixed wrapper for a constrained value.

    Args:
        a: The value that was provided/constrained externally.

    Returns:
        A Fixed wrapper indicating the value was not proposed internally.
    """
    return Fixed(a)

get_choices

get_choices(x: Trace[X, R] | X) -> X

Extract choices from a trace or nested structure containing traces.

Also strips Fixed wrappers from the choices, returning the unwrapped values. Fixed wrappers are used internally to track constrained vs. proposed values.

Source code in src/genjax/core.py
def get_choices(x: Trace[X, R] | X) -> X:
    """Extract choices from a trace or nested structure containing traces.

    Also strips Fixed wrappers from the choices, returning the unwrapped values.
    Fixed wrappers are used internally to track constrained vs. proposed values.

    Args:
        x: A trace object or nested structure that may contain traces.

    Returns:
        The random choices, with any nested traces recursively unwrapped and
        Fixed wrappers stripped.
    """
    x = x.get_choices() if isinstance(x, Trace) else x

    def _get_choices(x):
        if isinstance(x, Trace):
            return get_choices(x)
        else:
            return x

    # First unwrap any nested traces
    x = jtu.tree_map(
        _get_choices,
        x,
        is_leaf=lambda x: isinstance(x, Trace),
    )

    # Then strip Fixed wrappers
    def _strip_fixed(x):
        if isinstance(x, Fixed):
            return x.value  # Unwrap Fixed wrapper
        else:
            return x

    return jtu.tree_map(
        _strip_fixed,
        x,
        is_leaf=lambda x: isinstance(x, Fixed),
    )

get_fixed_choices

get_fixed_choices(x: Trace[X, R] | X) -> X

Extract choices from a trace or nested structure containing traces, preserving Fixed wrappers.

Similar to get_choices(), but preserves the Fixed wrappers around the choices, which is needed to verify that values were constrained during inference.

Source code in src/genjax/core.py
def get_fixed_choices(x: Trace[X, R] | X) -> X:
    """Extract choices from a trace or nested structure containing traces, preserving Fixed wrappers.

    Similar to get_choices() but preserves Fixed wrappers around the choices,
    which is needed for verification that values were constrained during inference.

    Args:
        x: A trace object or nested structure that may contain traces.

    Returns:
        The random choices, with any nested traces recursively unwrapped but
        Fixed wrappers preserved.
    """
    x = x.get_fixed_choices() if isinstance(x, Trace) else x

    def _get_fixed_choices(x):
        if isinstance(x, Trace):
            return get_fixed_choices(x)
        else:
            return x

    # Unwrap any nested traces but preserve Fixed wrappers
    # Note: Unlike get_choices(), we do NOT strip Fixed wrappers
    return jtu.tree_map(
        _get_fixed_choices,
        x,
        is_leaf=lambda x: isinstance(x, Trace),
    )
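The difference between get_choices (which strips Fixed wrappers) and get_fixed_choices (which preserves them) can be illustrated with a plain-Python sketch. FixedLike is a hypothetical stand-in for genjax.Fixed, and the recursive functions only handle nested dicts rather than arbitrary pytrees.

```python
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)
class FixedLike:
    # Hypothetical stand-in for genjax.Fixed: marks a value as externally constrained.
    value: Any

def strip_fixed(x):
    # Analogue of get_choices(): unwrap FixedLike leaves anywhere in a nested dict.
    if isinstance(x, FixedLike):
        return x.value
    if isinstance(x, dict):
        return {k: strip_fixed(v) for k, v in x.items()}
    return x

def fixed_addresses(x, prefix=()):
    # With wrappers preserved (as in get_fixed_choices()), constrained addresses stay visible.
    if isinstance(x, FixedLike):
        return [prefix]
    if isinstance(x, dict):
        return [p for k, v in x.items() for p in fixed_addresses(v, prefix + (k,))]
    return []

fixed_choices = {"x": FixedLike(1.5), "y": 0.2, "inner": {"z": FixedLike(-3.0)}}
plain_choices = strip_fixed(fixed_choices)
constrained = fixed_addresses(fixed_choices)
```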

get_score

get_score(x: Trace[X, R]) -> Weight

Extract the log probability score from a trace.

Source code in src/genjax/core.py
def get_score(x: Trace[X, R]) -> Weight:
    """Extract the log probability score from a trace.

    Args:
        x: Trace object to extract score from.

    Returns:
        The log probability score of the trace.
    """
    return x.get_score()

get_retval

get_retval(x: Trace[X, R]) -> R

Extract the return value from a trace.

Source code in src/genjax/core.py
def get_retval(x: Trace[X, R]) -> R:
    """Extract the return value from a trace.

    Args:
        x: Trace object to extract return value from.

    Returns:
        The return value of the trace.
    """
    return x.get_retval()

sel

sel(*v: tuple[()] | str | tuple[str, ...] | dict[str, Any] | None) -> Selection

Create a Selection from various input types.

This is a convenience function to create Selection objects from common patterns. Selections specify which random choices in a trace should be regenerated during inference operations like MCMC.

Examples:

from genjax import sel

# Select specific address
s1 = sel("x")
print(f"sel('x'): {s1}")

# Select hierarchical address
s2 = sel(("outer", "inner"))
print(f"sel(('outer', 'inner')): {s2}")

# Select all addresses
s3 = sel(())
print(f"sel(()): {s3}")

# Select no addresses
s4 = sel()
print(f"sel(): {s4}")

# Combine selections with OR
s5 = sel("x") | sel("y")
print(f"sel('x') | sel('y'): {s5}")

# Complement selection
s6 = ~sel("x")
print(f"~sel('x'): {s6}")
Source code in src/genjax/core.py
def sel(*v: tuple[()] | str | tuple[str, ...] | dict[str, Any] | None) -> Selection:
    """Create a Selection from various input types.

    This is a convenience function to create Selection objects from common patterns.
    Selections specify which random choices in a trace should be regenerated during
    inference operations like MCMC.

    Args:
        *v: Variable arguments specifying the selection pattern:
            - str: Select a specific address (e.g., sel("x"))
            - tuple[str, ...]: Select hierarchical address (e.g., sel(("outer", "inner")))
            - (): Select all addresses (e.g., sel(()))
            - dict: Select nested addresses (e.g., sel({"outer": sel("inner")}))
            - None or no args: Select no addresses (e.g., sel() or sel(None))

    Returns:
        Selection object that can be used with regenerate methods

    Examples:
        ```python
        from genjax import sel

        # Select specific address
        s1 = sel("x")
        print(f"sel('x'): {s1}")

        # Select hierarchical address
        s2 = sel(("outer", "inner"))
        print(f"sel(('outer', 'inner')): {s2}")

        # Select all addresses
        s3 = sel(())
        print(f"sel(()): {s3}")

        # Select no addresses
        s4 = sel()
        print(f"sel(): {s4}")

        # Combine selections with OR
        s5 = sel("x") | sel("y")
        print(f"sel('x') | sel('y'): {s5}")

        # Complement selection
        s6 = ~sel("x")
        print(f"~sel('x'): {s6}")
        ```
    """
    assert len(v) <= 1
    if len(v) == 1:
        if v[0] is None:
            return Selection(NoneSel())
        if v[0] == ():
            return Selection(AllSel())
        elif isinstance(v[0], dict):
            return Selection(DictSel(v[0]))
        elif isinstance(v[0], tuple) and all(isinstance(s, str) for s in v[0]):
            # Tuple of strings for hierarchical addresses
            return Selection(TupleSel(const(v[0])))
        else:
            assert isinstance(v[0], str)
            return Selection(StrSel(const(v[0])))
    else:
        return Selection(NoneSel())
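The sketch below models the selection patterns accepted by sel as predicates over address paths (tuples of strings), including OR and complement. It is a plain-Python illustration under two assumptions. First, selecting an address also selects everything nested under it, which is consistent with Fn.filter taking the whole value when a match has no subselection. Second, the dict form is omitted. GenJAX's real Selection instead returns (matched, subselection) pairs from match, and PathSelection and toy_sel are hypothetical names.

```python
from dataclasses import dataclass
from typing import Callable

@dataclass(frozen=True)
class PathSelection:
    # Hypothetical model: a selection is a predicate over address paths.
    pred: Callable[[tuple], bool]

    def __call__(self, path):
        return self.pred(path)

    def __or__(self, other):
        return PathSelection(lambda p: self(p) or other(p))

    def __invert__(self):
        return PathSelection(lambda p: not self(p))

def toy_sel(*v):
    # Follows sel()'s dispatch: no args or None selects nothing, () selects everything.
    if not v or v[0] is None:
        return PathSelection(lambda p: False)
    if v[0] == ():
        return PathSelection(lambda p: True)
    prefix = v[0] if isinstance(v[0], tuple) else (v[0],)
    # A path is selected if it starts with the given address (assumed subtree semantics).
    return PathSelection(lambda p: p[: len(prefix)] == prefix)
```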

distribution

distribution(sampler: Callable[..., Any], logpdf: Callable[..., Any], /, name: str | None = None) -> Distribution[Any]

Create a Distribution from sampling and log probability functions.

Source code in src/genjax/core.py
def distribution(
    sampler: Callable[..., Any],
    logpdf: Callable[..., Any],
    /,
    name: str | None = None,
) -> Distribution[Any]:
    """Create a Distribution from sampling and log probability functions.

    Args:
        sampler: Function that takes parameters and returns a sample.
        logpdf: Function that takes (value, *parameters) and returns log probability.
        name: Optional name for the distribution.

    Returns:
        A Distribution instance implementing the Generative Function Interface.
    """
    return Distribution(
        _sample=const(sampler),
        _logpdf=const(logpdf),
        name=const(name),
    )

tfp_distribution

tfp_distribution(dist: Callable[..., tfd.Distribution], /, name: str | None = None) -> Distribution[Any]

Create a Distribution from a TensorFlow Probability distribution.

Wraps a TFP distribution constructor to create a GenJAX Distribution that properly handles PJAX's sample_p primitive.

Example

import tensorflow_probability.substrates.jax as tfp
from genjax import tfp_distribution

# Create a normal distribution from TFP
normal = tfp_distribution(tfp.distributions.Normal, name="normal")

Source code in src/genjax/core.py
def tfp_distribution(
    dist: Callable[..., "tfd.Distribution"],
    /,
    name: str | None = None,
) -> Distribution[Any]:
    """Create a Distribution from a TensorFlow Probability distribution.

    Wraps a TFP distribution constructor to create a GenJAX Distribution
    that properly handles PJAX's `sample_p` primitive.

    Args:
        dist: TFP distribution constructor function.
        name: Optional name for the distribution.

    Returns:
        A Distribution that wraps the TFP distribution.

    Example:
        >>> import tensorflow_probability.substrates.jax as tfp
        >>> from genjax import tfp_distribution
        >>>
        >>> # Create a normal distribution from TFP
        >>> normal = tfp_distribution(tfp.distributions.Normal, name="normal")
    """

    def keyful_sampler(key, *args, sample_shape=(), **kwargs):
        d = dist(*args, **kwargs)
        return d.sample(seed=key, sample_shape=sample_shape)

    def logpdf(v, *args, **kwargs):
        d = dist(*args, **kwargs)
        return d.log_prob(v)

    return distribution(
        wrap_sampler(
            keyful_sampler,
            name=name,
        ),
        wrap_logpdf(logpdf),
        name=name,
    )
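The core pattern in tfp_distribution is that the sampler and the logpdf each rebuild the distribution from the call's arguments, and then call sample(seed=..., sample_shape=...) or log_prob(v) on it. The plain-Python sketch below reproduces that pattern with ToyNormal, a hypothetical TFP-like class that exists only for illustration. It omits the wrap_sampler and wrap_logpdf step that connects the real functions to PJAX.

```python
import math
import random

class ToyNormal:
    # Hypothetical TFP-style constructor: parameters in, sample/log_prob methods out.
    def __init__(self, loc, scale):
        self.loc, self.scale = loc, scale

    def sample(self, seed, sample_shape=()):
        rng = random.Random(seed)
        n = math.prod(sample_shape)  # () -> 1; other shapes give a flat list (toy only)
        draws = [self.loc + self.scale * rng.gauss(0.0, 1.0) for _ in range(n)]
        return draws[0] if sample_shape == () else draws

    def log_prob(self, v):
        z = (v - self.loc) / self.scale
        return -0.5 * z * z - math.log(self.scale) - 0.5 * math.log(2 * math.pi)

def wrap_constructor(dist):
    # Same shape as tfp_distribution's inner helpers.
    def keyful_sampler(key, *args, sample_shape=(), **kwargs):
        return dist(*args, **kwargs).sample(seed=key, sample_shape=sample_shape)

    def logpdf(v, *args, **kwargs):
        return dist(*args, **kwargs).log_prob(v)

    return keyful_sampler, logpdf

sampler, logpdf = wrap_constructor(ToyNormal)
```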

gen

gen(fn: Callable[..., R]) -> Fn[R]

Convert a function into a generative function.

The decorated function can use the @ operator to make addressed random choices from distributions and other generative functions.

Example

from genjax import gen, normal

@gen
def model(mu, sigma):
    x = normal(mu, sigma) @ "x"
    y = normal(x, 0.1) @ "y"
    return x + y

trace = model.simulate(0.0, 1.0)
choices = trace.get_choices()
# choices will contain {"x": <value>, "y": <value>}

Source code in src/genjax/core.py
def gen(fn: Callable[..., R]) -> Fn[R]:
    """Convert a function into a generative function.

    The decorated function can use the `@` operator to make addressed
    random choices from distributions and other generative functions.

    Args:
        fn: Function to convert into a generative function.

    Returns:
        A Fn instance that implements the Generative Function Interface.

    Example:
        >>> from genjax import gen, normal
        >>>
        >>> @gen
        ... def model(mu, sigma):
        ...     x = normal(mu, sigma) @ "x"
        ...     y = normal(x, 0.1) @ "y"
        ...     return x + y
        >>>
        >>> trace = model.simulate(0.0, 1.0)
        >>> choices = trace.get_choices()
        >>> # choices will contain {"x": <value>, "y": <value>}
    """
    gf = Fn(source=const(fn))
    # Copy function metadata to preserve name and module information
    try:
        gf.__name__ = fn.__name__
        gf.__qualname__ = fn.__qualname__
        gf.__module__ = fn.__module__
        gf.__doc__ = fn.__doc__
        gf.__annotations__ = getattr(fn, "__annotations__", {})
    except (AttributeError, TypeError):
        # If we can't set these attributes (e.g., on frozen dataclasses), continue anyway
        pass
    return gf
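Conceptually, simulating the model from the example above is ancestral sampling. Each addressed choice is drawn given the earlier choices and recorded under its address, and its log density is added to the trace score. The plain-Python sketch below walks through that bookkeeping for model(mu, sigma). The function simulate_model is hypothetical, and random.Random stands in for JAX keys.

```python
import math
import random

def normal_logpdf(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def simulate_model(rng, mu, sigma):
    choices, score = {}, 0.0
    # x = normal(mu, sigma) @ "x"
    x = mu + sigma * rng.gauss(0.0, 1.0)
    choices["x"] = x
    score += normal_logpdf(x, mu, sigma)
    # y = normal(x, 0.1) @ "y", drawn given the earlier choice x
    y = x + 0.1 * rng.gauss(0.0, 1.0)
    choices["y"] = y
    score += normal_logpdf(y, x, 0.1)
    return choices, score, x + y  # choices, log P(choices), return value

choices, score, retval = simulate_model(random.Random(0), 0.0, 1.0)
```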

Live Examples

Basic Model Definition

import jax
import jax.numpy as jnp
from genjax import gen, distributions

@gen
def coin_flip_model(n_flips):
    """A simple coin flipping model with unknown bias."""
    bias = distributions.beta(1.0, 1.0) @ "bias"

    # For demonstration, the three flips are unrolled by hand, so n_flips is unused.
    # In practice, use the Scan combinator for loops.
    flip_0 = distributions.bernoulli(bias) @ "flip_0"
    flip_1 = distributions.bernoulli(bias) @ "flip_1"
    flip_2 = distributions.bernoulli(bias) @ "flip_2"

    return jnp.array([flip_0, flip_1, flip_2])

print("Model defined successfully!")

Model defined successfully!

Assessing Log Probability

import jax
import jax.numpy as jnp
from genjax import gen, distributions

@gen
def coin_flip_model(n_flips):
    """A simple coin flipping model with unknown bias."""
    bias = distributions.beta(1.0, 1.0) @ "bias"

    # For demonstration, the three flips are unrolled by hand, so n_flips is unused.
    # In practice, use the Scan combinator for loops.
    flip_0 = distributions.bernoulli(bias) @ "flip_0"
    flip_1 = distributions.bernoulli(bias) @ "flip_1"
    flip_2 = distributions.bernoulli(bias) @ "flip_2"

    return jnp.array([flip_0, flip_1, flip_2])

# Assess the log probability of specific choices
choices = {"bias": 0.7, "flip_0": 1, "flip_1": 1, "flip_2": 0}
log_prob, retval = coin_flip_model.assess(choices, 3)

print(f"Given choices: {choices}")
print(f"Log probability: {log_prob:.3f}")
print(f"Return value (flips): {retval}")

Given choices: {'bias': 0.7, 'flip_0': 1, 'flip_1': 1, 'flip_2': 0}
Log probability: -1.910
Return value (flips): [1 1 0]
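The recorded log probability can be checked by hand. Beta(1, 1) is uniform on (0, 1), so it contributes 0, and the total is the sum of three Bernoulli log masses. The sketch below computes that sum under two readings of the argument 0.7: as a probability and as a logit. The logit reading reproduces the recorded -1.910, while the probability reading gives about -1.917. This suggests that distributions.bernoulli passes its first argument as a logit, consistent with TFP's Bernoulli, whose first parameter is logits. That conclusion is inferred from the arithmetic, not stated on this page, so check it against the distributions module if bias is meant to be a probability.

```python
import math

def log_sigmoid(t):
    # log(1 / (1 + exp(-t)))
    return -math.log1p(math.exp(-t))

def bernoulli_logpmf_logits(k, logit):
    # log sigma(t) for k = 1, log sigma(-t) for k = 0.
    return log_sigmoid(logit) if k == 1 else log_sigmoid(-logit)

def bernoulli_logpmf_probs(k, p):
    return math.log(p) if k == 1 else math.log1p(-p)

flips = [1, 1, 0]
beta11_logpdf = 0.0  # Beta(1, 1) density is 1 on (0, 1), so its log is 0
as_logits = beta11_logpdf + sum(bernoulli_logpmf_logits(k, 0.7) for k in flips)
as_probs = beta11_logpdf + sum(bernoulli_logpmf_probs(k, 0.7) for k in flips)
```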

Using Selections

from genjax import sel, Selection

# Create various selections
s1 = sel("bias")  # Select only bias
s2 = sel("flip_0") | sel("flip_1")  # Select two flips with OR
s3 = sel("bias") | sel("flip_2")  # Select bias OR flip_2

print(f"Selection s1 targets: bias")
print(f"Selection s2 targets: flip_0 or flip_1")
print(f"Selection s3 targets: bias or flip_2")

Selection s1 targets: bias
Selection s2 targets: flip_0 or flip_1
Selection s3 targets: bias or flip_2