
genjax.state

State interpreter for inspecting and manipulating probabilistic computations.

state

JAX interpreter for inspecting and organizing tagged state inside JAX Python functions.

This module provides a State interpreter that collects tagged values from within JAX computations and organizes them hierarchically. Because tagging and namespacing are implemented as JAX primitives, the interpreter is designed to compose with JAX transformations, and the collected values can be arranged into nested dictionaries via namespaces.

Core Features:

  • State Collection: Tag intermediate values during computation for inspection
  • Hierarchical Organization: Use namespaces to create nested state structures
  • JAX Integration: Full compatibility with jit, vmap, grad, scan, and other JAX transforms
  • Error Safety: Automatic cleanup of the namespace stack on exceptions
  • Zero Overhead: No performance cost when not using the @state decorator

Primary API:

Basic State Collection:

  • state(f): Transform a function to collect tagged state values
  • save(**tagged_values): Tag multiple values by name (named mode)
  • save(*values): Save values directly at the current namespace leaf (leaf mode)

Hierarchical Organization:

  • namespace(fn, ns): Transform a function to collect state under a namespace
  • Supports arbitrary nesting: namespace(namespace(fn, "inner"), "outer") stores state at state_dict["outer"]["inner"]

Lower-level API:

  • tag_state(*values, name="..."): Tag individual values for collection

Usage Examples:

Basic state collection:

from genjax.state import namespace, save, state

@state
def computation(x):
    y = x + 1
    save(intermediate=y, doubled=x*2)
    return y

result, state_dict = computation(5)
# state_dict = {"intermediate": 6, "doubled": 10}

Hierarchical organization with named mode:

@state
def complex_computation(x):
    save(input=x)

    # Namespace for processing steps (named mode)
    processing = namespace(
        lambda: save(step1=x*2, step2=x+1),
        "processing"
    )
    processing()

    # Nested namespaces
    analysis = namespace(
        namespace(lambda: save(mean=x), "stats"),
        "analysis"
    )
    analysis()

    return x

result, state_dict = complex_computation(5)
# state_dict = {
#     "input": 5,
#     "processing": {"step1": 10, "step2": 6},
#     "analysis": {"stats": {"mean": 5}}
# }

Leaf mode for direct namespace storage:

@state
def leaf_computation(x):
    save(input=x)

    # Leaf mode: save values directly at namespace (no additional keys)
    coords = namespace(lambda: save(x, x*2, x*3), "coordinates")
    coords()

    # Mixed with named mode in different namespace
    stats = namespace(lambda: save(mean=x, variance=x**2), "statistics")
    stats()

    return x

result, state_dict = leaf_computation(5)
# state_dict = {
#     "input": 5,
#     "coordinates": (5, 10, 15),  # Leaf mode: tuple stored directly
#     "statistics": {"mean": 5, "variance": 25}  # Named mode: dict
# }

JAX Integration:

import jax

# Works with JAX transformations (computation is defined above)
jitted_fn = jax.jit(computation)
vmapped_fn = jax.vmap(computation)
grad_fn = jax.grad(lambda x: computation(x)[0])

Implementation Details:

The state interpreter uses JAX primitives (state_p, namespace_push_p, namespace_pop_p) to integrate with JAX's transformation system. This ensures proper behavior under jit, vmap, grad, and other JAX transforms.
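
Under scan, the scan branch of eval_jaxpr_state wraps the loop body with state, so each iteration produces its own state dict and scan stacks those dicts leaf by leaf along a new leading axis. A plain-Python sketch of that stacking step, assuming every iteration saves the same keys; the helper stack_iteration_states is hypothetical, and Python lists stand in for the stacked JAX arrays:

```python
from typing import Any


def stack_iteration_states(per_iter: list[dict[str, Any]]) -> dict[str, Any]:
    """Hypothetical helper: turn a list of per-iteration state dicts into one
    dict whose leaves are lists indexed by iteration (a stand-in for the
    leading axis that jax.lax.scan adds to stacked outputs)."""
    if not per_iter:
        return {}
    stacked: dict[str, Any] = {}
    for key in per_iter[0]:
        column = [s[key] for s in per_iter]
        if isinstance(column[0], dict):
            # Nested namespaces are stacked recursively, key by key.
            stacked[key] = stack_iteration_states(column)
        else:
            stacked[key] = column
    return stacked


# Three iterations of a loop body that saves `step=i` and a namespaced value.
iters = [{"step": i, "inner": {"sq": i * i}} for i in range(3)]
print(stack_iteration_states(iters))
# {'step': [0, 1, 2], 'inner': {'sq': [0, 1, 4]}}
```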

The namespace functionality is implemented using a stack-based approach where namespace push/pop operations are tracked via JAX primitives, allowing the interpreter to maintain correct hierarchical structure even under complex JAX transformations.
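
To illustrate the stack-based approach, this sketch replays a flat sequence of push/pop/tag events, in the same order the interpreter meets namespace_push_p, namespace_pop_p, and state_p equations, and rebuilds the nested dictionary. The event tuples and the replay function are hypothetical stand-ins for jaxpr equations, not genjax APIs:

```python
from typing import Any


def replay(events: list[tuple]) -> dict[str, Any]:
    """Hypothetical sketch: rebuild nested state from ("push", ns),
    ("pop",) and ("tag", name, value) events, processed in order."""
    collected: dict[str, Any] = {}
    stack: list[str] = []
    for event in events:
        kind = event[0]
        if kind == "push":
            stack.append(event[1])
        elif kind == "pop":
            if not stack:
                raise ValueError("pop with empty namespace stack")
            stack.pop()
        elif kind == "tag":
            _, name, value = event
            # Walk (and create) one dict level per active namespace.
            node = collected
            for ns in stack:
                node = node.setdefault(ns, {})
            node[name] = value
        else:
            raise ValueError(f"unknown event {kind!r}")
    return collected


events = [
    ("tag", "input", 5),
    ("push", "analysis"), ("push", "stats"),
    ("tag", "mean", 5),
    ("pop",), ("pop",),
]
print(replay(events))
# {'input': 5, 'analysis': {'stats': {'mean': 5}}}
```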

State dataclass

State(collected_state: dict[str, Any], namespace_stack: list[str] = list())

JAX interpreter that collects tagged state values.

This interpreter processes JAX computations and collects values that are tagged with the state_p primitive. Tagged values are accumulated and returned alongside the original computation result.
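
The evaluation loop follows the usual pattern for JAX interpreters: read each equation's inputs from an environment, either handle the primitive specially or evaluate it normally, then write the outputs back. A minimal pure-Python sketch of that pattern, assuming a toy equation format in place of a real jaxpr; Eqn, run, and the "tag"/"add"/"mul" primitives are hypothetical:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Eqn:
    """Hypothetical stand-in for a jaxpr equation."""
    prim: str
    invars: list[str]
    outvars: list[str]
    params: dict[str, Any] = field(default_factory=dict)


# "Normal" evaluation rules for ordinary primitives.
IMPLS = {
    "add": lambda a, b: a + b,
    "mul": lambda a, b: a * b,
}


def run(eqns: list[Eqn], invars: list[str], args: list[Any], outvars: list[str]):
    env: dict[str, Any] = dict(zip(invars, args))
    collected: dict[str, Any] = {}
    for eqn in eqns:
        invals = [env[v] for v in eqn.invars]
        if eqn.prim == "tag":
            # Special case: record the value, then pass it through unchanged.
            collected[eqn.params["name"]] = invals[0]
            outvals = invals
        else:
            outvals = [IMPLS[eqn.prim](*invals)]
        env.update(zip(eqn.outvars, outvals))
    return [env[v] for v in outvars], collected


# y = x + 1; tag y as "intermediate"; z = y * 2
eqns = [
    Eqn("add", ["x", "one"], ["y"]),
    Eqn("tag", ["y"], ["y_tagged"], {"name": "intermediate"}),
    Eqn("mul", ["y_tagged", "two"], ["z"]),
]
print(run(eqns, ["x", "one", "two"], [5, 1, 2], ["z"]))
# ([12], {'intermediate': 6})
```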

eval_jaxpr_state

eval_jaxpr_state(jaxpr: Jaxpr, consts: list[Any], args: list[Any])

Evaluate a jaxpr while collecting tagged state values.

Source code in src/genjax/state.py
def eval_jaxpr_state(
    self,
    jaxpr: Jaxpr,
    consts: list[Any],
    args: list[Any],
):
    """Evaluate a jaxpr while collecting tagged state values."""
    env = Environment()
    safe_map(env.write, jaxpr.constvars, consts)
    safe_map(env.write, jaxpr.invars, args)

    for eqn in jaxpr.eqns:
        invals = safe_map(env.read, eqn.invars)
        subfuns, params = eqn.primitive.get_bind_params(eqn.params)
        args = subfuns + invals
        primitive, inner_params = PPPrimitive.unwrap(eqn.primitive)

        if primitive == state_p:
            # Collect the tagged values with namespace support
            name = params.get("name", inner_params.get("name"))
            if name is None:
                raise ValueError("tag_state requires a 'name' parameter")
            values = list(invals) if invals else []
            value = (
                tuple(values)
                if len(values) > 1
                else (values[0] if values else None)
            )

            # Handle leaf mode storage (special case for save(*args))
            if name == "__NAMESPACE_LEAF__":
                namespace_path = tuple(self.namespace_stack)
                if namespace_path:
                    # Store directly at the namespace path (no additional key)
                    current = self.collected_state
                    for namespace in namespace_path[:-1]:
                        if namespace not in current:
                            current[namespace] = {}
                        current = current[namespace]
                    # Store at the final namespace level
                    current[namespace_path[-1]] = value
                else:
                    # If no namespace, we can't do leaf storage at root
                    raise ValueError(
                        "Leaf mode save() requires being inside a namespace"
                    )
            else:
                # Handle namespace path using interpreter's stack (named mode)
                namespace_path = tuple(self.namespace_stack)
                if namespace_path:
                    _nested_dict_set(
                        self.collected_state, namespace_path, name, value
                    )
                else:
                    self.collected_state[name] = value

            # The state primitive returns the values as-is due to multiple_results
            outvals = values

        elif primitive == namespace_push_p:
            # Push namespace onto interpreter's stack
            namespace = params.get("namespace", inner_params.get("namespace"))
            if namespace is None:
                raise ValueError("namespace_push requires a 'namespace' parameter")
            self.namespace_stack.append(namespace)
            # Namespace push doesn't take or return values
            outvals = []

        elif primitive == namespace_pop_p:
            # Pop namespace from interpreter's stack
            if not self.namespace_stack:
                raise ValueError("namespace_pop called with empty namespace stack")
            self.namespace_stack.pop()
            # Namespace pop doesn't take or return values
            outvals = []

        elif primitive == scan_p:
            # Handle scan primitive by transforming body to collect state
            body_jaxpr = params["jaxpr"]
            length = params["length"]
            reverse = params["reverse"]
            num_consts = params["num_consts"]
            num_carry = params["num_carry"]
            const_vals, carry_vals, xs_vals = split_list(
                invals, [num_consts, num_carry]
            )

            body_fun = jex.core.jaxpr_as_fun(body_jaxpr)

            def new_body(carry, scanned_in):
                in_carry = carry
                all_values = const_vals + jtu.tree_leaves((in_carry, scanned_in))
                # Apply state transformation to the body
                body_result, body_state = state(body_fun)(*all_values)
                # Split the body result back into carry and scan parts
                out_carry, out_scan = split_list(
                    jtu.tree_leaves(body_result), [num_carry]
                )
                # Return carry, scan output, and collected state
                return out_carry, (out_scan, body_state)

            flat_carry_out, (scanned_out, scan_states) = scan(
                new_body,
                carry_vals,
                xs_vals,
                length=length,
                reverse=reverse,
            )

            # Merge vectorized scan states into collected state
            # scan_states is already vectorized by scan - just merge it
            for name, vectorized_values in scan_states.items():
                self.collected_state[name] = vectorized_values

            outvals = jtu.tree_leaves(
                (flat_carry_out, scanned_out),
            )

        else:
            # For all other primitives, use normal JAX evaluation
            outvals = eqn.primitive.bind(*args, **params)
            if not eqn.outvars:
                outvals = []
            elif isinstance(outvals, (list, tuple)):
                outvals = list(outvals)
            else:
                outvals = [outvals]

        safe_map(env.write, eqn.outvars, outvals)

    return safe_map(env.read, jaxpr.outvars)
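
The named-mode branch calls _nested_dict_set, whose source is not shown on this page. Judging only from the call site, (collected_state, namespace_path, name, value), a plausible sketch follows; the body is an assumption rather than the genjax implementation, and _nested_dict_set_sketch is a hypothetical name:

```python
from typing import Any


def _nested_dict_set_sketch(
    d: dict[str, Any], path: tuple[str, ...], key: str, value: Any
) -> None:
    """Assumed behavior: set d[path[0]]...[path[-1]][key] = value,
    creating intermediate dicts as needed (mutates d in place)."""
    node = d
    for ns in path:
        if ns not in node:
            node[ns] = {}
        node = node[ns]
    node[key] = value


state_dict: dict[str, Any] = {"input": 5}
_nested_dict_set_sketch(state_dict, ("processing",), "step1", 10)
_nested_dict_set_sketch(state_dict, ("processing",), "step2", 6)
_nested_dict_set_sketch(state_dict, ("analysis", "stats"), "mean", 5)
print(state_dict)
# {'input': 5, 'processing': {'step1': 10, 'step2': 6}, 'analysis': {'stats': {'mean': 5}}}
```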

eval

eval(fn, *args)

Run the interpreter on a function with given arguments.

Source code in src/genjax/state.py
def eval(self, fn, *args):
    """Run the interpreter on a function with given arguments."""
    closed_jaxpr, (flat_args, _, out_tree) = stage(fn)(*args)
    jaxpr, consts = closed_jaxpr.jaxpr, closed_jaxpr.literals
    flat_out = self.eval_jaxpr_state(
        jaxpr,
        consts,
        flat_args,
    )
    result = jtu.tree_unflatten(out_tree(), flat_out)
    return result, self.collected_state

state

state(f: Callable[..., Any])

Transform a function to collect tagged state values.

This transformation wraps a function to intercept and collect values that are tagged with the state_p primitive. The transformed function returns both the original result and a dictionary of collected state.

Example

from genjax.state import state, save

@state
def computation(x):
    y = x + 1
    z = x * 2
    values = save(intermediate=y, doubled=z)
    return values["intermediate"] * 2

result, state_dict = computation(5)
print(result)      # 12
print(state_dict)  # {"intermediate": 6, "doubled": 10}

Source code in src/genjax/state.py
def state(f: Callable[..., Any]):
    """Transform a function to collect tagged state values.

    This transformation wraps a function to intercept and collect values
    that are tagged with the `state_p` primitive. The transformed function
    returns both the original result and a dictionary of collected state.

    Args:
        f: Function containing state tags to transform.

    Returns:
        Function that returns a tuple of (original_result, collected_state).

    Example:
        >>> from genjax.state import state, save
        >>>
        >>> @state
        ... def computation(x):
        ...     y = x + 1
        ...     z = x * 2
        ...     values = save(intermediate=y, doubled=z)
        ...     return values["intermediate"] * 2
        >>>
        >>> result, state_dict = computation(5)
        >>> print(result)  # 12
        >>> print(state_dict)  # {"intermediate": 6, "doubled": 10}
    """

    @wraps(f)
    def wrapped(*args):
        interpreter = State(collected_state={}, namespace_stack=[])
        return interpreter.eval(f, *args)

    return wrapped

tag_state

tag_state(*values: Any, name: str) -> Any

Tag one or more values to be collected by the State interpreter.

Note: Consider using save(**tagged_values) for most use cases, as it provides a more convenient API for tagging multiple values.

This function marks values to be collected when the computation is run through the state transformation. The values are passed through unchanged in normal execution.
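
The pass-through contract, and the way the interpreter packs a tag's inputs into the stored value, can be sketched in plain Python. pass_through mirrors identity_fn in the source below, and stored_value mirrors the packing in eval_jaxpr_state (one value stored as-is, several as a tuple, none as None); both helper names are hypothetical:

```python
from typing import Any


def pass_through(*values: Any) -> Any:
    """Mirrors identity_fn: one value comes back bare, several as a tuple."""
    if not values:
        raise ValueError("tag_state requires at least one value")
    return tuple(values) if len(values) > 1 else values[0]


def stored_value(invals: list[Any]) -> Any:
    """Mirrors how eval_jaxpr_state packs a tag's inputs for the state dict."""
    values = list(invals) if invals else []
    if len(values) > 1:
        return tuple(values)
    return values[0] if values else None


print(pass_through(42))      # 42
print(pass_through(1, 2))    # (1, 2)
print(stored_value([1, 2]))  # (1, 2)
print(stored_value([42]))    # 42
```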

Example

x = 42
y = tag_state(x, name="my_value")  # y == x == 42

# Multiple values
a, b = tag_state(1, 2, name="pair")  # a == 1, b == 2

# When run through the state() transformation,
# values will be collected in the state dict

# Prefer save() for multiple named values:
values = save(x=42, y=24)  # More convenient

Source code in src/genjax/state.py
def tag_state(*values: Any, name: str) -> Any:
    """Tag one or more values to be collected by the StateInterpreter.

    **Note: Consider using `save(**tagged_values)` for most use cases, as it
    provides a more convenient API for tagging multiple values.**

    This function marks values to be collected when the computation
    is run through the `state` transformation. The values are passed
    through unchanged in normal execution.

    Args:
        *values: The values to tag and collect.
        name: Required string identifier for this state value.

    Returns:
        The original values (identity function). If single value, returns
        the value directly. If multiple values, returns a tuple.

    Example:
        >>> x = 42
        >>> y = tag_state(x, name="my_value")  # y == x == 42
        >>> # Multiple values
        >>> a, b = tag_state(1, 2, name="pair")  # a == 1, b == 2
        >>> # When run through state() transformation,
        >>> # values will be collected in state dict
        >>>
        >>> # Prefer save() for multiple named values:
        >>> values = save(x=42, y=24)  # More convenient
    """
    if not values:
        raise ValueError("tag_state requires at least one value")

    # Use initial_style_bind for proper JAX transformation compatibility
    def identity_fn(*args):
        return tuple(args) if len(args) > 1 else args[0]

    # Create a batch rule that re-inserts the primitive under vmap
    def batch_rule(vector_args, dims, **params):
        # Re-insert the state primitive with the vectorized args
        def vectorized_identity(*args):
            return tuple(args) if len(args) > 1 else args[0]

        # Apply the primitive to the vectorized args
        result = initial_style_bind(
            state_p,
            batch=batch_rule,  # Self-reference for nested vmaps
        )(vectorized_identity, name=params.get("name"))(*vector_args)

        # Return result with appropriate batching dimensions
        if isinstance(result, tuple):
            # For multiple outputs, each has the same dims as inputs
            return result, tuple(dims[0] if dims else () for _ in result)
        else:
            # For single output, return as tuple (JAX expects a sequence for dims_out)
            return (result,), (dims[0] if dims else (),)

    result = initial_style_bind(
        state_p,
        batch=batch_rule,
    )(identity_fn, name=name)(*values)

    return result

save

save(*values, **tagged_values) -> Any

Save values either at the current namespace leaf or under explicit names (primary API).

This is the recommended way to tag state values. It supports two modes:

  1. Leaf mode (*args): Save values directly at the current namespace leaf
  2. Named mode (**kwargs): Save values with explicit names (original behavior)
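
The two modes differ only in where the value lands. Following the leaf-mode and named-mode branches of eval_jaxpr_state, a leaf save replaces the entry at the current namespace path, a named save adds a key beneath it, and a leaf save at the root is an error. A pure-Python sketch of that rule, with store as a hypothetical helper:

```python
from typing import Any

LEAF = "__NAMESPACE_LEAF__"  # reserved name used by save(*values)


def store(collected: dict[str, Any], path: tuple[str, ...], name: str, value: Any) -> None:
    """Hypothetical helper mirroring the leaf/named branches."""
    if name == LEAF:
        if not path:
            raise ValueError("Leaf mode save() requires being inside a namespace")
        node = collected
        for ns in path[:-1]:
            node = node.setdefault(ns, {})
        node[path[-1]] = value  # the value itself is the namespace entry
    else:
        node = collected
        for ns in path:
            node = node.setdefault(ns, {})
        node[name] = value  # the value sits under an explicit key


collected: dict[str, Any] = {}
store(collected, (), "input", 5)
store(collected, ("coordinates",), LEAF, (5, 10, 15))
store(collected, ("statistics",), "mean", 5)
store(collected, ("statistics",), "variance", 25)
print(collected)
# {'input': 5, 'coordinates': (5, 10, 15), 'statistics': {'mean': 5, 'variance': 25}}
```

Because a leaf save assigns the namespace entry directly, a second leaf save in the same namespace overwrites the first.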

Example

Leaf mode (saves at current namespace):

@state
def computation():
    namespace_fn = namespace(lambda: save(1, 2, 3), "coords")
    namespace_fn()
    return 42

result, state_dict = computation()
# state_dict == {"coords": (1, 2, 3)}

Named mode (original behavior):

@state
def computation():
    values = save(first=1, second=2)
    return sum(values.values())

result, state_dict = computation()
# state_dict == {"first": 1, "second": 2}

Source code in src/genjax/state.py
def save(*values, **tagged_values) -> Any:
    """Save values either at current namespace leaf or with explicit names (primary API).

    **This is the recommended way to tag state values.** Supports two modes:

    1. **Leaf mode** (`*args`): Save values directly at current namespace leaf
    2. **Named mode** (`**kwargs`): Save values with explicit names (original behavior)

    Args:
        *values: Values to save at current namespace leaf (mutually exclusive with **tagged_values)
        **tagged_values: Keyword arguments where keys are names and values are the values to save

    Returns:
        - Leaf mode: The values as a tuple (or single value if only one)
        - Named mode: Dictionary of the saved values (for convenience)

    Example:
        Leaf mode (saves at current namespace):
        >>> @state
        ... def computation():
        ...     namespace_fn = namespace(lambda: save(1, 2, 3), "coords")
        ...     namespace_fn()
        ...     return 42
        >>> result, state_dict = computation()
        >>> # state_dict == {"coords": (1, 2, 3)}

        Named mode (original behavior):
        >>> @state
        ... def computation():
        ...     values = save(first=1, second=2)
        ...     return sum(values.values())
        >>> result, state_dict = computation()
        >>> # state_dict == {"first": 1, "second": 2}
    """
    if values and tagged_values:
        raise ValueError(
            "Cannot use both positional args (*values) and keyword args (**tagged_values) in save()"
        )

    if values:
        # Leaf mode: save values directly at current namespace leaf
        # Use a special reserved name to indicate leaf storage
        leaf_value = values if len(values) > 1 else values[0]
        tag_state(leaf_value, name="__NAMESPACE_LEAF__")
        return leaf_value
    else:
        # Named mode: original behavior with explicit names
        result = {}
        for name, value in tagged_values.items():
            result[name] = tag_state(value, name=name)
        return result

namespace

namespace(f: Callable[..., Any], ns: str) -> Callable[..., Any]

Transform a function to collect state under a namespace.

This function wraps another function so that any state collected within it will be organized under the specified namespace. Namespaces can be nested by applying this function multiple times.
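
Because each wrapper pushes its name before calling the wrapped function and pops it in a finally block, nesting namespace(namespace(f, "deep"), "outer") yields the path ("outer", "deep"), and the stack is restored even if f raises. A plain-Python sketch of that wrapper over an ordinary list, where STACK and with_namespace are hypothetical stand-ins for the interpreter's stack and the namespace_push_p/namespace_pop_p primitives:

```python
from functools import wraps
from typing import Any, Callable

STACK: list[str] = []  # stands in for State.namespace_stack


def with_namespace(f: Callable[..., Any], ns: str) -> Callable[..., Any]:
    @wraps(f)
    def namespaced(*args, **kwargs):
        STACK.append(ns)  # push before running f
        try:
            return f(*args, **kwargs)
        finally:
            STACK.pop()  # always pop, even if f raises
    return namespaced


# The outer wrapper runs first, so its name is pushed first.
deep = with_namespace(with_namespace(lambda: tuple(STACK), "deep"), "outer")
print(deep())  # ('outer', 'deep')


def boom():
    raise RuntimeError("failure inside namespace")


try:
    with_namespace(boom, "risky")()
except RuntimeError:
    pass
print(STACK)  # []
```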

Example

@state
def computation(x):
    # State collected directly at root level
    save(root_val=x)

    # State collected under "inner" namespace
    inner_fn = namespace(lambda y: save(nested_val=y * 2), "inner")
    inner_fn(x)

    # Nested namespaces: state under "outer.deep"
    deep_fn = namespace(
        namespace(lambda z: save(deep_val=z * 3), "deep"),
        "outer"
    )
    deep_fn(x)

    return x

result, state_dict = computation(5)
# state_dict == {
#     "root_val": 5,
#     "inner": {"nested_val": 10},
#     "outer": {"deep": {"deep_val": 15}}
# }

Source code in src/genjax/state.py
def namespace(f: Callable[..., Any], ns: str) -> Callable[..., Any]:
    """Transform a function to collect state under a namespace.

    This function wraps another function so that any state collected within
    it will be organized under the specified namespace. Namespaces can be
    nested by applying this function multiple times.

    Args:
        f: Function to wrap with namespace context
        ns: Namespace string to organize state under

    Returns:
        Function that collects state under the specified namespace

    Example:
        >>> @state
        ... def computation(x):
        ...     # State collected directly at root level
        ...     save(root_val=x)
        ...
        ...     # State collected under "inner" namespace
        ...     inner_fn = namespace(lambda y: save(nested_val=y * 2), "inner")
        ...     inner_fn(x)
        ...
        ...     # Nested namespaces: state under "outer.deep"
        ...     deep_fn = namespace(
        ...         namespace(lambda z: save(deep_val=z * 3), "deep"),
        ...         "outer"
        ...     )
        ...     deep_fn(x)
        ...
        ...     return x
        >>>
        >>> result, state_dict = computation(5)
        >>> # state_dict == {
        >>> #     "root_val": 5,
        >>> #     "inner": {"nested_val": 10},
        >>> #     "outer": {"deep": {"deep_val": 15}}
        >>> # }
    """

    @wraps(f)
    def namespaced_fn(*args, **kwargs):
        # Push namespace using JAX primitive
        _namespace_push(ns)
        try:
            result = f(*args, **kwargs)
            return result
        finally:
            # Always pop namespace, even if function raises
            _namespace_pop()

    return namespaced_fn