genjax.adev

Automatic differentiation of expected values for gradient estimation.

adev

ADEV: Sound Automatic Differentiation of Expected Values

This module implements ADEV (Automatic Differentiation of Expected Values), a system for computing sound, unbiased gradient estimators of expectations involving stochastic functions. It is based on the research presented in "ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs" (Lew et al., POPL 2023, arXiv:2212.06386).

ADEV is a source-to-source transformation that extends forward-mode automatic differentiation to correctly handle probabilistic computations. The key insight is transforming a probabilistic program into a new program whose expected return value is the derivative of the original program's expectation.

Theoretical Foundation: ADEV uses a continuation-passing style (CPS) transformation that reflects the law of iterated expectation, E[f(X)] = E[E[f(X) | Z]] where X depends on Z. This enables modular composition of different gradient estimation strategies while maintaining soundness guarantees proven via logical relations.

Key Concepts
  • ADEVPrimitive: Stochastic primitives with custom gradient estimation strategies
  • Dual Numbers: Pairs (primal, tangent) for forward-mode automatic differentiation
  • Continuations: Higher-order functions representing "the rest of the computation"
    • Pure continuation: Operates on primal values only (no differentiation)
    • Dual continuation: Operates on dual numbers (differentiation applied)
  • CPS Transformation: Allows modular selection of gradient strategies per distribution
Gradient Estimation Strategies
  • REINFORCE: Score function estimator ∇E[f(X)] = E[f(X) * ∇log p(X)]
  • Reparameterization: Pathwise estimator ∇_θ E[f(g(ε; θ))] = E[∇_θ f(g(ε; θ))]
  • Enumeration: Exact computation for finite discrete distributions
  • Measure-Valued Derivatives: Advanced discrete gradient estimators
Example
from genjax.adev import expectation, normal_reparam

@expectation
def objective(theta):
    x = normal_reparam(theta, 1.0)  # Reparameterizable distribution
    return x**2

grad = objective.grad_estimate(0.5)  # Unbiased gradient estimate
References

Lew, A. K., Huot, M., Staton, S., & Mansinghka, V. K. (2023). ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs. Proceedings of the ACM on Programming Languages, 7(POPL), 121-148.

DualTree module-attribute

DualTree = Annotated[Any, Is[lambda v: static_check_dual_tree(v)]]

DualTree is the type of Pytree argument values with Dual leaves.
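A quick illustration (a sketch; it assumes Dual and its static checkers are importable from genjax.adev as documented below): any pytree whose leaves are all Dual values qualifies as a DualTree.

from genjax.adev import Dual

# A dict-shaped pytree whose leaves are all Dual numbers: a valid DualTree
params = {"mu": Dual(0.0, 1.0), "sigma": Dual(1.0, 0.0)}
assert Dual.static_check_dual_tree(params)

# A raw float leaf disqualifies the tree
assert not Dual.static_check_dual_tree({"mu": Dual(0.0, 1.0), "sigma": 1.0})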

ADEVPrimitive

Bases: Pytree

Base class for stochastic primitives with custom gradient estimation strategies.

An ADEVPrimitive represents a stochastic operation (like sampling from a distribution) that can provide custom gradient estimates through the ADEV system. Each primitive implements both forward sampling and a strategy for computing Jacobian-Vector Product (JVP) estimates during automatic differentiation.

The key insight is that different stochastic operations benefit from different gradient estimation strategies (REINFORCE, reparameterization, enumeration, etc.), and ADEVPrimitive allows each operation to specify its optimal strategy.
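To make the contract concrete, here is a minimal sketch of a custom primitive, modeled on the NormalREPARAM implementation shown later on this page; it is illustrative only and elides the Pytree field and registration details a real subclass would need.

# Imports of ADEVPrimitive, Dual, and the `normal` distribution are assumed here
class ShiftedNormal(ADEVPrimitive):
    """Hypothetical primitive: X = mu + eps, with eps ~ Normal(0, 1)."""

    def sample(self, mu):
        return mu + normal.sample(0.0, 1.0)

    def prim_jvp_estimate(self, dual_tree, konts):
        _, kdual = konts  # use the ADEV-transformed (dual) continuation
        (mu_primal,) = Dual.tree_primal(dual_tree)
        (mu_tangent,) = Dual.tree_tangent(dual_tree)
        eps = normal.sample(0.0, 1.0)
        # Pathwise gradient: X = mu + eps, so dX/dmu = 1
        return kdual(Dual(mu_primal + eps, mu_tangent))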

sample abstractmethod

sample(*args) -> Any

Forward sampling operation.

Source code in src/genjax/adev/__init__.py
@abstractmethod
def sample(self, *args) -> Any:
    """Forward sampling operation.

    Args:
        *args: Parameters for the stochastic operation

    Returns:
        Sample from the distribution/stochastic process
    """
    pass

prim_jvp_estimate abstractmethod

prim_jvp_estimate(dual_tree: tuple[DualTree, ...], konts: tuple[Callable[..., Any], Callable[..., Any]]) -> Dual

Custom JVP gradient estimation strategy.

This method implements the core gradient estimation logic for this primitive. It receives dual numbers (primal + tangent values) for the arguments and two continuations representing the rest of the computation.

Note

The choice of continuation reflects ADEV's CPS transformation approach:
  • Pure continuation: Evaluates the remaining computation on primal values
  • Dual continuation: Applies the ADEV transformation to the remaining computation

Different gradient strategies utilize these continuations differently:
  • REINFORCE: Uses the dual continuation to evaluate f(X), computes ∇log p(X)
  • Reparameterization: Uses the dual continuation with reparameterized samples
  • Enumeration: May use both to compute weighted exact expectations

This design enables modular composition as described in the ADEV paper.

Source code in src/genjax/adev/__init__.py
@abstractmethod
def prim_jvp_estimate(
    self,
    dual_tree: tuple[DualTree, ...],
    konts: tuple[
        Callable[..., Any],  # Pure continuation (kpure)
        Callable[..., Any],  # Dual continuation (kdual)
    ],
) -> "Dual":
    """Custom JVP gradient estimation strategy.

    This method implements the core gradient estimation logic for this primitive.
    It receives dual numbers (primal + tangent values) for the arguments and
    two continuations representing the rest of the computation.

    Args:
        dual_tree: Arguments as dual numbers (primal, tangent) pairs
        konts: Pair of continuations:
            - konts[0] (kpure): Pure continuation - no ADEV transformation
            - konts[1] (kdual): Dual continuation - ADEV transformation applied

    Returns:
        Dual number representing the gradient estimate for this operation

    Note:
        The choice of continuation reflects ADEV's CPS transformation approach:
        - Pure continuation: Evaluates the remaining computation on primal values
        - Dual continuation: Applies ADEV transformation to remaining computation

        Different gradient strategies utilize these continuations differently:
        - REINFORCE: Uses dual continuation to evaluate f(X), computes ∇log p(X)
        - Reparameterization: Uses dual continuation with reparameterized samples
        - Enumeration: May use both to compute weighted exact expectations

        This design enables modular composition as described in the ADEV paper.
    """
    pass

Dual

Bases: Pytree

Dual number for forward-mode automatic differentiation.

A Dual number represents both a value (primal) and its derivative (tangent) with respect to some input. This is the fundamental data structure for forward-mode AD in the ADEV system.

Example

x = Dual(3.0, 1.0)  # x = 3, dx/dx = 1
y = Dual(x.primal**2, 2 * x.primal * x.tangent)  # y = x^2, dy/dx = 2x

tree_pure staticmethod

tree_pure(v)

Convert a tree to have Dual leaves with zero tangents.

This is used to "lift" regular values into the dual number system by pairing them with zero tangents, indicating no sensitivity.

Source code in src/genjax/adev/__init__.py
@staticmethod
def tree_pure(v):
    """Convert a tree to have Dual leaves with zero tangents.

    This is used to "lift" regular values into the dual number system
    by pairing them with zero tangents, indicating no sensitivity.

    Args:
        v: Pytree that may contain mix of Dual and regular values

    Returns:
        Pytree where all leaves are Dual numbers
    """

    def _inner(v):
        if isinstance(v, Dual):
            return v
        else:
            return Dual(v, jnp.zeros_like(v))

    return jtu.tree_map(_inner, v, is_leaf=lambda v: isinstance(v, Dual))

dual_tree staticmethod

dual_tree(primals, tangents)

Combine primal and tangent trees into a tree of Dual numbers.

Source code in src/genjax/adev/__init__.py
@staticmethod
def dual_tree(primals, tangents):
    """Combine primal and tangent trees into a tree of Dual numbers.

    Args:
        primals: Tree of primal values
        tangents: Tree of tangent values (same structure as primals)

    Returns:
        Tree of Dual numbers combining corresponding primals and tangents
    """
    return jtu.tree_map(lambda v1, v2: Dual(v1, v2), primals, tangents)

tree_primal staticmethod

tree_primal(v)

Extract primal values from a tree of Dual numbers.

Source code in src/genjax/adev/__init__.py
@staticmethod
def tree_primal(v):
    """Extract primal values from a tree of Dual numbers.

    Args:
        v: Tree that may contain Dual numbers

    Returns:
        Tree with Dual numbers replaced by their primal values
    """

    def _inner(v):
        if isinstance(v, Dual):
            return v.primal
        else:
            return v

    return jtu.tree_map(_inner, v, is_leaf=lambda v: isinstance(v, Dual))

tree_tangent staticmethod

tree_tangent(v)

Extract tangent values from a tree of Dual numbers.

Source code in src/genjax/adev/__init__.py
@staticmethod
def tree_tangent(v):
    """Extract tangent values from a tree of Dual numbers.

    Args:
        v: Tree that may contain Dual numbers

    Returns:
        Tree with Dual numbers replaced by their tangent values
    """

    def _inner(v):
        if isinstance(v, Dual):
            return v.tangent
        else:
            return v

    return jtu.tree_map(_inner, v, is_leaf=lambda v: isinstance(v, Dual))

tree_leaves staticmethod

tree_leaves(v)

Get leaves of a tree, treating Dual numbers as atomic.

Source code in src/genjax/adev/__init__.py
@staticmethod
def tree_leaves(v):
    """Get leaves of a tree, treating Dual numbers as atomic.

    Args:
        v: Tree structure

    Returns:
        List of Dual leaves
    """
    v = Dual.tree_pure(v)
    return jtu.tree_leaves(v, is_leaf=lambda v: isinstance(v, Dual))

tree_unzip staticmethod

tree_unzip(v)

Separate a tree of Dual numbers into primal and tangent trees.

Source code in src/genjax/adev/__init__.py
@staticmethod
def tree_unzip(v):
    """Separate a tree of Dual numbers into primal and tangent trees.

    Args:
        v: Tree containing Dual numbers

    Returns:
        Tuple of (primal_leaves, tangent_leaves) as flat lists
    """
    primals = jtu.tree_leaves(Dual.tree_primal(v))
    tangents = jtu.tree_leaves(Dual.tree_tangent(v))
    return tuple(primals), tuple(tangents)
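Taken together, these helpers convert between the primal and dual worlds. A small round trip (sketch; assumes Dual is importable from genjax.adev):

import jax.numpy as jnp
from genjax.adev import Dual

primals = (jnp.asarray(1.0), jnp.asarray(2.0))
tangents = (jnp.asarray(1.0), jnp.asarray(0.0))

duals = Dual.dual_tree(primals, tangents)  # tree of Dual numbers
p, t = Dual.tree_unzip(duals)              # back to ((1.0, 2.0), (1.0, 0.0))

lifted = Dual.tree_pure((jnp.asarray(3.0),))  # raw values get zero tangents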

static_check_is_dual staticmethod

static_check_is_dual(v) -> bool

Check if a value is a Dual number.

Source code in src/genjax/adev/__init__.py
@staticmethod
def static_check_is_dual(v) -> bool:
    """Check if a value is a Dual number."""
    return isinstance(v, Dual)

static_check_dual_tree staticmethod

static_check_dual_tree(v) -> bool

Check if all leaves in a tree are Dual numbers.

Source code in src/genjax/adev/__init__.py
@staticmethod
def static_check_dual_tree(v) -> bool:
    """Check if all leaves in a tree are Dual numbers."""
    return all(
        map(
            lambda v: isinstance(v, Dual),
            jtu.tree_leaves(v, is_leaf=Dual.static_check_is_dual),
        )
    )

ADEV

Bases: Pytree

Interpreter for ADEV's continuation-passing style automatic differentiation.

The ADEV interpreter processes JAX computation graphs (Jaxpr) and transforms them to support stochastic automatic differentiation. It implements a continuation-passing style (CPS) transformation that reflects the law of iterated expectation, allowing different gradient estimation strategies for each stochastic operation.

Key responsibilities:
  1. Propagate dual numbers through deterministic JAX operations
  2. Apply the CPS transformation at stochastic operations (sample_p primitives)
  3. Create continuation closures for gradient estimation strategies
  4. Handle control flow (conditionals, loops) within the AD system

The CPS transformation is crucial: when encountering a stochastic operation, the interpreter creates two continuations representing the rest of the computation:
  • Pure continuation: For sampling-based gradient estimates
  • Dual continuation: For the ADEV-transformed remainder

This allows each ADEVPrimitive to choose its optimal gradient strategy while maintaining composability across the entire computation graph.

ADEVProgram

Bases: Pytree

Internal representation of a stochastic program for ADEV gradient estimation.

An ADEVProgram wraps a source function containing stochastic operations and provides the infrastructure for computing Jacobian-Vector Product (JVP) estimates through the ADEV system. This class serves as an intermediate representation between user-defined @expectation functions and the low-level ADEV interpreter.

The ADEVProgram handles the integration between:
  1. User source code containing ADEV primitives
  2. The ADEV interpreter's CPS transformation
  3. Continuation-based gradient estimation strategies

Note

This class is typically not used directly by users. It's created internally by the @expectation decorator and managed by the Expectation class.

jvp_estimate

jvp_estimate(duals: tuple[DualTree, ...], dual_kont: Callable[..., Any]) -> Dual

Compute JVP estimate for the stochastic program.

This method applies the ADEV forward-mode transformation to compute an unbiased estimate of the Jacobian-Vector Product for expectations involving stochastic operations. It uses the continuation-passing style transformation to integrate different gradient estimation strategies.

Note

This method coordinates between the user's source function and the ADEV interpreter to apply the appropriate gradient estimation strategies for each stochastic primitive encountered during execution.

Source code in src/genjax/adev/__init__.py
def jvp_estimate(
    self,
    duals: tuple[DualTree, ...],  # Pytree with Dual leaves.
    dual_kont: Callable[..., Any],
) -> Dual:
    """Compute JVP estimate for the stochastic program.

    This method applies the ADEV forward-mode transformation to compute
    an unbiased estimate of the Jacobian-Vector Product for expectations
    involving stochastic operations. It uses the continuation-passing style
    transformation to integrate different gradient estimation strategies.

    Args:
        duals: Input arguments as dual numbers (primal, tangent) pairs
        dual_kont: Continuation representing the computation after this program

    Returns:
        Dual number containing the JVP estimate (primal value + gradient estimate)

    Note:
        This method coordinates between the user's source function and the
        ADEV interpreter to apply the appropriate gradient estimation strategies
        for each stochastic primitive encountered during execution.
    """

    def adev_jvp(f):
        @wraps(f)
        def wrapped(*duals: DualTree):
            return ADEV.forward_mode(self.source.value, dual_kont)(*duals)

        return wrapped

    return adev_jvp(self.source.value)(*duals)

Expectation

Bases: Pytree

Represents an expectation with automatic differentiation support.

An Expectation object encapsulates a stochastic computation and provides methods to compute unbiased gradient estimates of expectation values. This is the primary user-facing interface for ADEV (Automatic Differentiation of Expected Values).

The key insight is that for expectations E[f(X)] where X is a random variable, we can compute unbiased gradient estimates ∇E[f(X)] using various strategies:
  • REINFORCE: ∇E[f(X)] = E[f(X) * ∇log p(X)]
  • Reparameterization: ∇E[f(X)] = E[∇f(g(ε))] where X = g(ε), ε ~ fixed distribution
  • Enumeration: Exact computation for discrete distributions with finite support

Example

from genjax.adev import expectation, normal_reparam
import jax.numpy as jnp

@expectation
def loss_function(theta):
    x = normal_reparam(theta, 1.0)
    return x**2

Compute gradient estimate

grad = loss_function.grad_estimate(0.5)
jnp.isfinite(grad)  # Array(True, dtype=bool)

Compute expectation value

value = loss_function.estimate(0.5)
jnp.isfinite(value)  # Array(True, dtype=bool)

jvp_estimate

jvp_estimate(*duals: DualTree)

Compute Jacobian-Vector Product estimate for the expectation.

This method provides the core JVP computation for ADEV. It applies the continuation-passing style transformation with an identity continuation, meaning this expectation represents the "final" computation in the chain.

Note

This is the foundational method that enables both grad_estimate and integration with JAX's automatic differentiation system.

Source code in src/genjax/adev/__init__.py
def jvp_estimate(self, *duals: DualTree):
    """Compute Jacobian-Vector Product estimate for the expectation.

    This method provides the core JVP computation for ADEV. It applies the
    continuation-passing style transformation with an identity continuation,
    meaning this expectation represents the "final" computation in the chain.

    Args:
        *duals: Input arguments as dual numbers (primal, tangent) pairs

    Returns:
        Dual number with primal value E[f(X)] and tangent containing ∇E[f(X)]

    Note:
        This is the foundational method that enables both grad_estimate and
        integration with JAX's automatic differentiation system.
    """

    # Identity continuation - this expectation is the final computation
    def _identity(v):
        return v

    return self.prog.jvp_estimate(duals, _identity)
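For example, a directional derivative of an @expectation objective can be requested directly by passing Dual inputs, mirroring how estimate uses this method internally (sketch; assumes Dual is importable from genjax.adev):

from genjax.adev import expectation, normal_reparam, Dual

@expectation
def objective(theta):
    x = normal_reparam(theta, 1.0)
    return x**2

out = objective.jvp_estimate(Dual(0.5, 1.0))  # tangent 1.0 selects d/dtheta
out.primal   # one sample of f(X)
out.tangent  # unbiased estimate of dE[f(X)]/dtheta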

grad_estimate

grad_estimate(*primals)

Compute unbiased gradient estimate of the expectation.

This method provides the primary interface for computing gradients of expectation values. It leverages JAX's grad transformation combined with ADEV's custom JVP rules to produce unbiased gradient estimates.

Example

from genjax.adev import expectation, normal_reparam
import jax.numpy as jnp

@expectation
def objective(mu, sigma):
    x = normal_reparam(mu, sigma)
    return x**2

Compute gradient with respect to both parameters

grad_mu, grad_sigma = objective.grad_estimate(1.0, 0.5)
jnp.isfinite(grad_mu)  # Array(True, dtype=bool)
jnp.isfinite(grad_sigma)  # Array(True, dtype=bool)

Note

The gradient estimates are unbiased, meaning E[∇̂f] = ∇E[f], but they may have variance. The choice of gradient estimation strategy (REINFORCE, reparameterization, etc.) affects this variance.

Source code in src/genjax/adev/__init__.py
def grad_estimate(self, *primals):
    """Compute unbiased gradient estimate of the expectation.

    This method provides the primary interface for computing gradients of
    expectation values. It leverages JAX's grad transformation combined with
    ADEV's custom JVP rules to produce unbiased gradient estimates.

    Args:
        *primals: Input values to compute gradients with respect to

    Returns:
        If single argument: Single gradient estimate array
        If multiple arguments: Tuple of gradient estimates

    Example:
        >>> from genjax.adev import expectation, normal_reparam
        >>> import jax.numpy as jnp
        >>>
        >>> @expectation
        ... def objective(mu, sigma):
        ...     x = normal_reparam(mu, sigma)
        ...     return x**2
        >>>
        >>> # Compute gradient with respect to both parameters
        >>> grad_mu, grad_sigma = objective.grad_estimate(1.0, 0.5)
        >>> jnp.isfinite(grad_mu)  # doctest: +ELLIPSIS
        Array(True, dtype=bool...)
        >>> jnp.isfinite(grad_sigma)  # doctest: +ELLIPSIS
        Array(True, dtype=bool...)

    Note:
        The gradient estimates are unbiased, meaning E[∇̂f] = ∇E[f], but they
        may have variance. The choice of gradient estimation strategy (REINFORCE,
        reparameterization, etc.) affects this variance.
    """

    def _invoke_closed_over(primals):
        return invoke_closed_over(self, primals)

    grad_result = jax.grad(_invoke_closed_over)(primals)

    # Return single gradient for single argument, tuple for multiple arguments
    if len(primals) == 1:
        return grad_result[0]
    else:
        return grad_result

estimate

estimate(*args)

Compute the expectation value (forward pass only).

This method evaluates E[f(X)] without computing gradients. It's useful when you only need the expectation value itself, not its derivatives.

Example

from genjax.adev import expectation, normal_reparam
import jax.numpy as jnp

@expectation
def mean_squared(mu):
    x = normal_reparam(mu, 1.0)
    return x**2

Just compute E[X^2] where X ~ Normal(mu, 1)

expectation_value = mean_squared.estimate(2.0)
jnp.isfinite(expectation_value)  # Array(True, dtype=bool)
expectation_value > 0  # Should be positive for squared values: Array(True, dtype=bool)

Note

This method uses zero tangents in the dual number computation, effectively performing only the forward pass through the stochastic computation graph.

Source code in src/genjax/adev/__init__.py
def estimate(self, *args):
    """Compute the expectation value (forward pass only).

    This method evaluates E[f(X)] without computing gradients. It's useful
    when you only need the expectation value itself, not its derivatives.

    Args:
        *args: Arguments to the expectation function

    Returns:
        The expectation value E[f(X)] as computed by the stochastic program

    Example:
        >>> from genjax.adev import expectation, normal_reparam
        >>> import jax.numpy as jnp
        >>>
        >>> @expectation
        ... def mean_squared(mu):
        ...     x = normal_reparam(mu, 1.0)
        ...     return x**2
        >>>
        >>> # Just compute E[X^2] where X ~ Normal(mu, 1)
        >>> expectation_value = mean_squared.estimate(2.0)
        >>> jnp.isfinite(expectation_value)  # doctest: +ELLIPSIS
        Array(True, dtype=bool...)
        >>> expectation_value > 0  # Should be positive for squared values  # doctest: +ELLIPSIS
        Array(True, dtype=bool...)

    Note:
        This method uses zero tangents in the dual number computation,
        effectively performing only the forward pass through the stochastic
        computation graph.
    """
    tangents = jtu.tree_map(lambda _: 0.0, args)
    return self.jvp_estimate(*Dual.dual_tree(args, tangents)).primal

REINFORCE

Bases: ADEVPrimitive

REINFORCE (score function) gradient estimator primitive.

Implements the REINFORCE gradient estimator (Williams, 1992), also known as the score function estimator or likelihood ratio method. This estimator is one of the key gradient estimation strategies supported by the ADEV framework.

Theoretical Foundation: The REINFORCE estimator is based on the score function identity:

∇_θ E[f(X)] = E[f(X) * ∇_θ log p(X; θ)]

where ∇_θ log p(X; θ) is the score function. This identity holds for any distribution p(X; θ) with differentiable log-density, making REINFORCE universally applicable but potentially high-variance.

ADEV Implementation: Within ADEV's CPS framework, REINFORCE:
  1. Samples X ~ p(·; θ) using the current parameters
  2. Evaluates f(X) using the dual continuation (kdual)
  3. Computes the score function ∇_θ log p(X; θ) via JAX's JVP
  4. Returns a dual number whose tangent is ∇f(X) + f(X) * ∇_θ log p(X; θ)

Note

While general-purpose, REINFORCE can exhibit high variance. Reparameterization is preferred when available, as it typically yields lower-variance estimates (see the ADEV paper).
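The contrast is easy to see by writing the same objective with both strategies (sketch): both estimators are unbiased for d/dθ E[X²] = 2θ when X ~ Normal(θ, 1), but the pathwise version typically has lower variance.

from genjax.adev import expectation, normal_reinforce, normal_reparam

@expectation
def loss_reinforce(theta):
    x = normal_reinforce(theta, 1.0)  # score-function estimator
    return x**2

@expectation
def loss_reparam(theta):
    x = normal_reparam(theta, 1.0)    # pathwise estimator
    return x**2

g1 = loss_reinforce.grad_estimate(0.5)  # unbiased, typically higher variance
g2 = loss_reparam.grad_estimate(0.5)    # unbiased, typically lower variance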

sample

sample(*args)

Forward sampling using the provided sample function.

Source code in src/genjax/adev/__init__.py
def sample(self, *args):
    """Forward sampling using the provided sample function."""
    return self.sample_function.value(*args)

prim_jvp_estimate

prim_jvp_estimate(dual_tree: DualTree, konts: tuple[Any, ...])

REINFORCE gradient estimation using the score function identity.

Implements the score function estimator: ∇_θ E[f(X)] = E[f(X) * ∇_θ log p(X; θ)]

This method applies the ADEV CPS transformation for REINFORCE:
  1. Sample X ~ p(·; θ) from the distribution with current parameters
  2. Evaluate f(X) using the dual continuation (kdual) to get the function value
  3. Compute the score function ∇_θ log p(X; θ) using JAX's forward-mode AD
  4. Combine into the tangent estimate ∇f(X) + f(X) * ∇_θ log p(X; θ)

The dual continuation captures the ADEV-transformed "rest of the computation" after this stochastic choice, enabling modular composition with other gradient estimation strategies as described in the ADEV paper.

Source code in src/genjax/adev/__init__.py
def prim_jvp_estimate(
    self,
    dual_tree: DualTree,
    konts: tuple[Any, ...],
):
    """REINFORCE gradient estimation using the score function identity.

    Implements the score function estimator: ∇_θ E[f(X)] = E[f(X) * ∇_θ log p(X; θ)]

    This method applies the ADEV CPS transformation for REINFORCE:
    1. Sample X ~ p(·; θ) from the distribution with current parameters
    2. Evaluate f(X) using the dual continuation (kdual) to get the function value
    3. Compute the score function ∇_θ log p(X; θ) using JAX's forward-mode AD
    4. Combine into the tangent estimate: ∇f(X) + f(X) * ∇_θ log p(X; θ)

    The dual continuation captures the ADEV-transformed "rest of the computation"
    after this stochastic choice, enabling modular composition with other
    gradient estimation strategies as described in the ADEV paper.
    """
    (_, kdual) = konts
    primals = Dual.tree_primal(dual_tree)
    tangents = Dual.tree_tangent(dual_tree)

    # Sample from the distribution
    v = self.sample(*primals)

    # Evaluate f(X) using dual continuation
    dual_tree = Dual.tree_pure(v)
    out_dual = kdual(dual_tree)
    (out_primal,), (out_tangent,) = Dual.tree_unzip(out_dual)

    # Compute score function: ∇log p(X)
    # For discrete values, use float0 tangent type as required by JAX
    v_tangent = (
        jnp.zeros(v.shape, dtype=jax.dtypes.float0)
        if v.dtype in (jnp.bool_, jnp.int32, jnp.int64)
        else jnp.zeros_like(v)
    )
    _, lp_tangent = jax.jvp(
        self.differentiable_logpdf.value,
        (v, *primals),
        (v_tangent, *tangents),
    )

    # REINFORCE tangent estimate: ∇f(X) + f(X) * ∇log p(X)
    # This gives an unbiased estimate of the gradient, as proven in the ADEV paper
    return Dual(out_primal, out_tangent + (out_primal * lp_tangent))

FlipEnum

Bases: ADEVPrimitive

Exact enumeration gradient estimator for Bernoulli distributions.

For discrete distributions with finite support, we can compute exact gradients by enumerating all possible outcomes and weighting by their probabilities. This gives zero-variance gradient estimates for the flip/Bernoulli case.

The estimator computes: ∇E[f(X)] = ∇[p * f(True) + (1 - p) * f(False)]
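For instance (sketch, using the flip_enum export mentioned in the expectation docs below), enumeration makes the gradient of a Bernoulli objective exact:

import jax.numpy as jnp
from genjax.adev import expectation, flip_enum

@expectation
def discrete_objective(p):
    b = flip_enum(p)                 # enumerated Bernoulli choice
    return jnp.where(b, 1.0, -0.5)

# E[f] = p * 1.0 + (1 - p) * (-0.5), so dE/dp = 1.5 exactly
grad = discrete_objective.grad_estimate(0.3)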

FlipMVD

Bases: ADEVPrimitive

Measure-Valued Derivative (MVD) gradient estimator for Bernoulli distributions.

Implements the measure-valued derivative approach for gradient estimation with discrete distributions. MVD is a flexible gradient estimation technique that decomposes the derivative of a probability density into positive and negative components: ∇_θ p(x; θ) = c_θ(p^+(x; θ) - p^-(x; θ)).

Theoretical Foundation: For discrete distributions like Bernoulli, MVD enables gradient estimation without requiring differentiability assumptions. The estimator works by:
  1. Sampling from the original distribution
  2. Evaluating the function on both the sampled value and its complement
  3. Using a signed difference to create an unbiased gradient estimate

MVD Implementation for Bernoulli: The key insight is using the "phantom estimator" approach where:
  • The sampled outcome determines the sign via (-1)^v
  • Both the actual outcome and its complement are evaluated
  • The difference (other - b_primal) captures the discrete gradient

Advantages:
  • Works with discrete distributions where REINFORCE may have issues
  • No differentiability requirements on the objective function
  • Provides unbiased gradient estimates for discrete parameters

Disadvantages:
  • Computationally expensive (requires multiple evaluations)
  • Higher variance than reparameterization when applicable
  • Only applies to single parameters at a time

Note

This is a "phantom estimator" that evaluates the function on auxiliary samples (the complement outcome) to construct the gradient estimate. The (-1)^v term creates the appropriate sign for the discrete difference.

sample

sample(*args)

Sample from Bernoulli distribution.

Source code in src/genjax/adev/__init__.py
def sample(self, *args):
    """Sample from Bernoulli distribution."""
    (p,) = args
    return 1 == bernoulli.sample(probs=p)

prim_jvp_estimate

prim_jvp_estimate(dual_tree: DualTree, konts: tuple[Any, ...])

Measure-valued derivative gradient estimation for Bernoulli.

Implements the MVD approach using phantom estimation:
  1. Sample v ~ Bernoulli(p) to get the primary outcome
  2. Evaluate f(v) using the dual continuation (kdual)
  3. Evaluate f(¬v) using the pure continuation (kpure) as a phantom estimate
  4. Combine with the signed difference: (-1)^v * (f(¬v) - f(v))

The (-1)^v term ensures the correct sign for the discrete gradient:
  • When v=1: -1 * (f(0) - f(1)) = f(1) - f(0)
  • When v=0: +1 * (f(1) - f(0)) = f(1) - f(0)

This creates an unbiased estimator of ∇_p E[f(X)] for X ~ Bernoulli(p).
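A plain-Python sanity check of the signed difference (sketch): for a deterministic f, both branches of the estimator reduce to f(True) - f(False), which is exactly dE/dp.

def f(b):
    return 2.0 if b else 0.5

for v in (0, 1):
    other = f(not v)                        # phantom evaluation on the complement
    est = (-1) ** v * (other - f(bool(v)))  # 1.5 in both cases: f(True) - f(False)
    print(est)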

Source code in src/genjax/adev/__init__.py
def prim_jvp_estimate(
    self,
    dual_tree: DualTree,
    konts: tuple[Any, ...],
):
    """Measure-valued derivative gradient estimation for Bernoulli.

    Implements the MVD approach using phantom estimation:
    1. Sample v ~ Bernoulli(p) to get the primary outcome
    2. Evaluate f(v) using the dual continuation (kdual)
    3. Evaluate f(¬v) using the pure continuation (kpure) as phantom estimate
    4. Combine with signed difference: (-1)^v * (f(¬v) - f(v))

    The (-1)^v term ensures the correct sign for the discrete gradient:
    - When v=1: -1 * (f(0) - f(1)) = f(1) - f(0)
    - When v=0: +1 * (f(1) - f(0)) = f(1) - f(0)

    This creates an unbiased estimator of ∇_p E[f(X)] for X ~ Bernoulli(p).
    """
    (kpure, kdual) = konts
    (p_primal,) = Dual.tree_primal(dual_tree)
    (p_tangent,) = Dual.tree_tangent(dual_tree)

    # Sample from Bernoulli(p)
    v = bernoulli.sample(probs=p_primal)
    b = v == 1

    # Evaluate f(v) using dual continuation
    # For discrete values, use float0 tangent type as required by JAX
    b_tangent_zero = (
        jnp.zeros(b.shape, dtype=jax.dtypes.float0)
        if b.dtype in (jnp.bool_, jnp.int32, jnp.int64)
        else jnp.zeros_like(b)
    )
    b_dual = kdual(Dual(b, b_tangent_zero))
    (b_primal,), (b_tangent,) = Dual.tree_unzip(b_dual)

    # Evaluate f(¬v) using pure continuation (phantom estimate)
    other_result = kpure(jnp.logical_not(b))

    # Extract scalar value using JAX-compatible tree operations
    # kpure may return a pytree structure, so we flatten and take the first element
    other_flat, _ = jtu.tree_flatten(other_result)
    other = other_flat[0]  # Assume there's always at least one element

    # MVD estimator: (-1)^v * (f(¬v) - f(v))
    # This creates the signed discrete difference for gradient estimation
    est = ((-1) ** v) * (other - b_primal)

    return Dual(b_primal, b_tangent + est * p_tangent)

NormalREPARAM

Bases: ADEVPrimitive

Reparameterization (pathwise) gradient estimator for normal distributions.

Implements the reparameterization trick, also known as the pathwise estimator, which is one of the core gradient estimation strategies in the ADEV framework. This provides low-variance gradient estimates for reparameterizable distributions.

Theoretical Foundation: For a reparameterizable distribution, a sample X ~ p(X; θ) can be written as X = g(ε; θ) where ε ~ p(ε) is parameter-free. The pathwise estimator is then:

∇_θ E[f(X)] = E[∇_θ f(g(ε; θ))]

For Normal(μ, σ): X = g(ε; μ, σ) = μ + σ * ε, where ε ~ Normal(0, 1). This reparameterization allows gradients to flow directly through the parameters μ and σ via standard automatic differentiation (chain rule).

ADEV Implementation: Within ADEV's CPS framework, reparameterization:
  1. Samples parameter-free noise ε ~ Normal(0, 1)
  2. Applies the transformation X = μ + σ * ε with JAX's JVP for gradients
  3. Passes the reparameterized sample through the dual continuation (kdual)

This strategy typically exhibits lower variance than REINFORCE, as noted in the ADEV paper and empirical studies (Kingma & Welling, 2014).
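A quick check against a closed form (sketch): for X ~ Normal(μ, σ), E[X²] = μ² + σ², so the estimates below target (2μ, 2σ).

from genjax.adev import expectation, normal_reparam

@expectation
def second_moment(mu, sigma):
    x = normal_reparam(mu, sigma)
    return x**2

# Analytic gradient at (1.0, 0.5) is (2.0, 1.0)
g_mu, g_sigma = second_moment.grad_estimate(1.0, 0.5)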

prim_jvp_estimate

prim_jvp_estimate(dual_tree: DualTree, konts: tuple[Any, ...])

Reparameterization gradient estimation using the pathwise estimator.

Implements: ∇_θ E[f(X)] = E[∇_θ f(g(ε; θ))] where X = g(ε; θ)

This method applies the ADEV CPS transformation for reparameterization:
  1. Sample parameter-free noise ε ~ Normal(0, 1)
  2. Apply reparameterization X = μ + σ * ε with gradients via JAX JVP
  3. Pass the dual number (X, ∇X) to the dual continuation (kdual)

The dual continuation captures the ADEV-transformed remainder of the computation, enabling low-variance gradient flow as described in the ADEV paper.

Source code in src/genjax/adev/__init__.py
def prim_jvp_estimate(
    self,
    dual_tree: DualTree,
    konts: tuple[Any, ...],
):
    """Reparameterization gradient estimation using the pathwise estimator.

    Implements: ∇_θ E[f(X)] = E[∇_θ f(g(ε; θ))] where X = g(ε; θ)

    This method applies the ADEV CPS transformation for reparameterization:
    1. Sample parameter-free noise ε ~ Normal(0, 1)
    2. Apply reparameterization X = μ + σ * ε with gradients via JAX JVP
    3. Pass the dual number (X, ∇X) to the dual continuation (kdual)

    The dual continuation captures the ADEV-transformed remainder of the
    computation, enabling low-variance gradient flow as described in the ADEV paper.
    """
    _, kdual = konts
    (mu_primal, sigma_primal) = Dual.tree_primal(dual_tree)
    (mu_tangent, sigma_tangent) = Dual.tree_tangent(dual_tree)

    # Sample parameter-free noise
    eps = normal.sample(0.0, 1.0)

    # Reparameterization: X = μ + σ * ε with gradient flow
    def _inner(mu, sigma):
        return mu + sigma * eps

    primal_out, tangent_out = jax.jvp(
        _inner,
        (mu_primal, sigma_primal),
        (mu_tangent, sigma_tangent),
    )
    return kdual(Dual(primal_out, tangent_out))

MultivariateNormalREPARAM

Bases: ADEVPrimitive

Multivariate reparameterization (pathwise) gradient estimator.

Extends the reparameterization trick to multivariate normal distributions, implementing the pathwise estimator for high-dimensional parameter spaces as supported by the ADEV framework.

Theoretical Foundation: For MultivariateNormal(μ, Σ), the reparameterization is: X = g(ε; μ, Σ) = μ + L @ ε where L = cholesky(Σ) and ε ~ Normal(0, I).

The pathwise estimator then gives:

∇_{μ,Σ} E[f(X)] = E[∇_{μ,Σ} f(μ + L @ ε)]

ADEV Implementation: This primitive enables efficient gradient flow with respect to both the mean vector μ and covariance matrix Σ, crucial for scalable variational inference in high-dimensional spaces. The Cholesky decomposition ensures positive definiteness while enabling automatic differentiation through the covariance structure.

This implementation follows the ADEV paper's approach to modular gradient estimation, allowing seamless integration with other stochastic primitives in complex probabilistic programs.
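The Cholesky pathwise trick itself can be illustrated standalone in JAX (a sketch of what the primitive's source below does, with a fixed noise draw standing in for ε):

import jax
import jax.numpy as jnp

eps = jnp.array([0.3, -1.2])  # fixed draw standing in for ε ~ Normal(0, I)

def reparam(loc, cov):
    L = jnp.linalg.cholesky(cov)
    return loc + L @ eps

loc, cov = jnp.zeros(2), jnp.eye(2)
# Directional derivative with respect to the mean (tangent direction: ones)
x, dx = jax.jvp(reparam, (loc, cov), (jnp.ones(2), jnp.zeros((2, 2))))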

prim_jvp_estimate

prim_jvp_estimate(dual_tree: DualTree, konts: tuple[Any, ...])

Multivariate reparameterization using Cholesky decomposition.

Implements: ∇E[f(X)] = E[∇f(μ + L @ ε)] where L = cholesky(Σ)

This method applies the pathwise estimator for multivariate normal distributions:
  1. Sample standard multivariate normal noise ε ~ Normal(0, I)
  2. Apply the Cholesky reparameterization X = μ + L @ ε with gradient flow
  3. Pass the dual number (X, ∇X) to the dual continuation (kdual)

The Cholesky decomposition ensures efficient and numerically stable gradients with respect to the covariance matrix Σ, as described in the ADEV framework for modular gradient estimation strategies.

Source code in src/genjax/adev/__init__.py
def prim_jvp_estimate(
    self,
    dual_tree: DualTree,
    konts: tuple[Any, ...],
):
    """Multivariate reparameterization using Cholesky decomposition.

    Implements: ∇E[f(X)] = E[∇f(μ + L @ ε)] where L = cholesky(Σ)

    This method applies the pathwise estimator for multivariate normal distributions:
    1. Sample standard multivariate normal noise ε ~ Normal(0, I)
    2. Apply Cholesky reparameterization X = μ + L @ ε with gradient flow
    3. Pass the dual number (X, ∇X) to the dual continuation (kdual)

    The Cholesky decomposition ensures efficient and numerically stable
    gradients with respect to the covariance matrix Σ, as described in
    the ADEV framework for modular gradient estimation strategies.
    """
    _, kdual = konts
    (loc_primal, cov_primal) = Dual.tree_primal(dual_tree)
    (loc_tangent, cov_tangent) = Dual.tree_tangent(dual_tree)

    # Sample standard multivariate normal: ε ~ Normal(0, I)
    eps = multivariate_normal.sample(
        jnp.zeros_like(loc_primal), jnp.eye(loc_primal.shape[-1])
    )

    # Multivariate reparameterization: X = μ + L @ ε where L = cholesky(Σ)
    # This provides efficient gradients for both mean and covariance parameters
    def _inner(loc, cov):
        L = jnp.linalg.cholesky(cov)
        return loc + L @ eps

    primal_out, tangent_out = jax.jvp(
        _inner,
        (loc_primal, cov_primal),
        (loc_tangent, cov_tangent),
    )
    return kdual(Dual(primal_out, tangent_out))

sample_primitive

sample_primitive(adev_prim: ADEVPrimitive, *args)

Integrate an ADEV primitive with the PJAX infrastructure.

This function wraps an ADEVPrimitive so it can be used within GenJAX's probabilistic programming system. It ensures the primitive works correctly with JAX transformations (jit, vmap, grad) and addressing (@) operators.

The key insight is that ADEV primitives need to be integrated with PJAX's sample_binder to get proper parameter setup (like flat_keyful_sampler) that enables compatibility with the seed transformation and other GenJAX features.

Note

This function was crucial for fixing the flat_keyful_sampler error - previously ADEV primitives bypassed sample_binder and lacked proper parameter setup for JAX transformations.

Source code in src/genjax/adev/__init__.py
def sample_primitive(adev_prim: ADEVPrimitive, *args):
    """Integrate an ADEV primitive with the PJAX infrastructure.

    This function wraps an ADEVPrimitive so it can be used within GenJAX's
    probabilistic programming system. It ensures the primitive works correctly
    with JAX transformations (jit, vmap, grad) and addressing (@) operators.

    The key insight is that ADEV primitives need to be integrated with PJAX's
    sample_binder to get proper parameter setup (like flat_keyful_sampler) that
    enables compatibility with the seed transformation and other GenJAX features.

    Args:
        adev_prim: The ADEV primitive to integrate
        *args: Arguments to pass to the primitive's sample method

    Returns:
        Sample from the primitive, properly integrated with PJAX infrastructure

    Note:
        This function was crucial for fixing the flat_keyful_sampler error -
        previously ADEV primitives bypassed sample_binder and lacked proper
        parameter setup for JAX transformations.
    """

    def _adev_prim_call(key, adev_prim, *args, **kwargs):
        """Wrapper function that conforms to sample_binder's expected signature."""
        return adev_prim.sample(*args)

    return sample_binder(_adev_prim_call)(adev_prim, *args)

expectation

expectation(source: Callable[..., Any]) -> Expectation

Decorator to create an Expectation object from a stochastic function.

This decorator transforms a function containing stochastic operations into an Expectation object that supports automatic differentiation of expectation values. The decorated function should use ADEV-compatible distributions (those with gradient estimation strategies like normal_reparam, normal_reinforce, etc.).

Example

from genjax.adev import expectation, normal_reparam

Basic usage

@expectation
def quadratic_loss(theta):
    x = normal_reparam(theta, 1.0)  # Reparameterizable distribution
    return (x - 2.0)**2

Compute gradient

gradient = quadratic_loss.grad_estimate(1.0)
import jax.numpy as jnp
jnp.isfinite(gradient)  # Array(True, dtype=bool)

Compute expectation value

loss_value = quadratic_loss.estimate(1.0)
jnp.isfinite(loss_value)  # Array(True, dtype=bool)

More complex example with multiple variables:

@expectation
def complex_objective(mu, sigma):
    x = normal_reparam(mu, sigma)
    y = normal_reinforce(0.0, 1.0)  # REINFORCE strategy
    return jnp.sin(x) * jnp.cos(y)

grad_mu, grad_sigma = complex_objective.grad_estimate(0.5, 1.2)

Note

The function should only use ADEV-compatible distributions that have gradient estimation strategies. Regular distributions (normal, beta, etc.) won't provide gradient estimates - use their ADEV variants instead (normal_reparam, normal_reinforce, flip_enum, etc.).

The resulting Expectation object's interfaces (grad_estimate, estimate, etc.) are compatible with JAX transformations like jit and modular_vmap. The Expectation object itself is also a Pytree, so it can be passed as an argument to JAX-transformed functions. Use modular_vmap instead of regular vmap for proper handling of probabilistic primitives within ADEV programs.
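For instance, the documented jit compatibility means a gradient estimator can be staged (sketch, reusing quadratic_loss from the example above):

import jax

fast_grad = jax.jit(lambda theta: quadratic_loss.grad_estimate(theta))
gradient = fast_grad(1.0)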

Source code in src/genjax/adev/__init__.py
def expectation(source: Callable[..., Any]) -> Expectation:
    """Decorator to create an Expectation object from a stochastic function.

    This decorator transforms a function containing stochastic operations into an
    Expectation object that supports automatic differentiation of expectation values.
    The decorated function should use ADEV-compatible distributions (those with
    gradient estimation strategies like normal_reparam, normal_reinforce, etc.).

    Args:
        source: Function containing stochastic operations using ADEV primitives

    Returns:
        Expectation object with grad_estimate, jvp_estimate, and estimate methods

    Example:
        >>> from genjax.adev import expectation, normal_reparam
        >>>
        >>> # Basic usage
        >>> @expectation
        ... def quadratic_loss(theta):
        ...     x = normal_reparam(theta, 1.0)  # Reparameterizable distribution
        ...     return (x - 2.0)**2
        >>>
        >>> # Compute gradient
        >>> gradient = quadratic_loss.grad_estimate(1.0)
        >>> import jax.numpy as jnp
        >>> jnp.isfinite(gradient)  # doctest: +ELLIPSIS
        Array(True, dtype=bool...)
        >>>
        >>> # Compute expectation value
        >>> loss_value = quadratic_loss.estimate(1.0)
        >>> jnp.isfinite(loss_value)  # doctest: +ELLIPSIS
        Array(True, dtype=bool...)

        More complex example with multiple variables:
        @expectation
        def complex_objective(mu, sigma):
            x = normal_reparam(mu, sigma)
            y = normal_reinforce(0.0, 1.0)  # REINFORCE strategy
            return jnp.sin(x) * jnp.cos(y)

        grad_mu, grad_sigma = complex_objective.grad_estimate(0.5, 1.2)

    Note:
        The function should only use ADEV-compatible distributions that have
        gradient estimation strategies. Regular distributions (normal, beta, etc.)
        won't provide gradient estimates - use their ADEV variants instead
        (normal_reparam, normal_reinforce, flip_enum, etc.).

        The resulting Expectation object's interfaces (grad_estimate, estimate, etc.)
        are compatible with JAX transformations like jit and modular_vmap. The
        Expectation object itself is also a Pytree, so it can be passed as an
        argument to JAX-transformed functions. Use modular_vmap instead of regular
        vmap for proper handling of probabilistic primitives within ADEV programs.
    """
    prog = ADEVProgram(const(source))
    return Expectation(prog)

invoke_closed_over

invoke_closed_over(instance, args)

Primal forward-mode function for Expectation objects with custom JVP rule.

This function serves as the primal computation for JAX's custom JVP rule registration. It's defined externally to the Expectation class to avoid complications with defining custom JVP rules on Pytree classes.

Note

This function is decorated with @jax.custom_jvp to register ADEV's jvp_estimate as the custom JVP rule. This allows JAX to automatically synthesize grad implementations that use ADEV's unbiased gradient estimators.

Source code in src/genjax/adev/__init__.py
@jax.custom_jvp
def invoke_closed_over(instance, args):
    """Primal forward-mode function for Expectation objects with custom JVP rule.

    This function serves as the primal computation for JAX's custom JVP rule
    registration. It's defined externally to the Expectation class to avoid
    complications with defining custom JVP rules on Pytree classes.

    Args:
        instance: Expectation object to evaluate
        args: Arguments to pass to the expectation

    Returns:
        The expectation value computed by instance.estimate(*args)

    Note:
        This function is decorated with @jax.custom_jvp to register ADEV's
        jvp_estimate as the custom JVP rule. This allows JAX to automatically
        synthesize grad implementations that use ADEV's unbiased gradient estimators.
    """
    return instance.estimate(*args)

invoke_closed_over_jvp

invoke_closed_over_jvp(primals: tuple, tangents: tuple)

Custom JVP rule that delegates to ADEV's jvp_estimate method.

This function registers ADEV's jvp_estimate as the JVP rule for Expectation objects, enabling JAX to automatically synthesize grad implementations. When JAX encounters invoke_closed_over in a computation that requires differentiation, it will use this rule instead of trying to differentiate through the stochastic computation.

Note

This converts between JAX's JVP representation (separate primals/tangents) and ADEV's dual number representation, then delegates to jvp_estimate for the actual gradient computation using ADEV's CPS transformation.

Source code in src/genjax/adev/__init__.py
def invoke_closed_over_jvp(primals: tuple, tangents: tuple):
    """Custom JVP rule that delegates to ADEV's jvp_estimate method.

    This function registers ADEV's jvp_estimate as the JVP rule for Expectation
    objects, enabling JAX to automatically synthesize grad implementations.
    When JAX encounters invoke_closed_over in a computation that requires
    differentiation, it will use this rule instead of trying to differentiate
    through the stochastic computation.

    Args:
        primals: Tuple of (instance, args) representing the primal values
        tangents: Tuple of (_, tangents) representing the tangent vectors

    Returns:
        Tuple of (primal_output, tangent_output) where:
        - primal_output: The expectation value E[f(X)]
        - tangent_output: ADEV's unbiased gradient estimate ∇E[f(X)]

    Note:
        This converts between JAX's JVP representation (separate primals/tangents)
        and ADEV's dual number representation, then delegates to jvp_estimate
        for the actual gradient computation using ADEV's CPS transformation.
    """
    (instance, primals) = primals
    (_, tangents) = tangents
    duals = Dual.dual_tree(primals, tangents)
    out_dual = instance.jvp_estimate(*duals)
    (v,), (tangent,) = Dual.tree_unzip(out_dual)
    return v, tangent
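Pairing the two functions follows JAX's standard custom-JVP registration; presumably the module registers the rule along these lines:

# Attach the ADEV JVP rule to the @jax.custom_jvp-decorated primal function
invoke_closed_over.defjvp(invoke_closed_over_jvp)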

reinforce

reinforce(sample_func, logpdf_func)

Factory function for creating REINFORCE gradient estimators.

Example

normal_reinforce_prim = reinforce(normal.sample, normal.logpdf)

Source code in src/genjax/adev/__init__.py
def reinforce(sample_func, logpdf_func):
    """Factory function for creating REINFORCE gradient estimators.

    Args:
        sample_func: Function to sample from distribution
        logpdf_func: Function to compute log-probability density

    Returns:
        REINFORCE primitive for the given distribution

    Example:
        >>> normal_reinforce_prim = reinforce(normal.sample, normal.logpdf)
    """
    return REINFORCE(const(sample_func), const(logpdf_func))
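Combined with sample_primitive, the factory lets a custom sampler/logpdf pair participate in @expectation programs. A sketch, assuming normal is the distribution object used throughout this module's source (its import path is not shown here):

from genjax.adev import expectation, reinforce, sample_primitive

normal_reinforce_prim = reinforce(normal.sample, normal.logpdf)

@expectation
def objective(theta):
    x = sample_primitive(normal_reinforce_prim, theta, 1.0)
    return x**2

grad = objective.grad_estimate(0.5)  # REINFORCE gradient estimate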