genjax.adev¶
Automatic differentiation of expected values for gradient estimation.
adev
¶
ADEV: Sound Automatic Differentiation of Expected Values
This module implements ADEV (Automatic Differentiation of Expected Values), a system for computing sound, unbiased gradient estimators for expectations of stochastic functions. Based on the research presented in "ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs" (Lew et al., POPL 2023, arXiv:2212.06386).
ADEV is a source-to-source transformation that extends forward-mode automatic differentiation to correctly handle probabilistic computations. The key insight is transforming a probabilistic program into a new program whose expected return value is the derivative of the original program's expectation.
Theoretical Foundation: ADEV uses a continuation-passing style (CPS) transformation that reflects the law of iterated expectation: E[f(X)] = E[E[f(X) | Z]] where X depends on Z. This enables modular composition of different gradient estimation strategies while maintaining soundness guarantees proven via logical relations.
Key Concepts
- ADEVPrimitive: Stochastic primitives with custom gradient estimation strategies
- Dual Numbers: Pairs (primal, tangent) for forward-mode automatic differentiation
- Continuations: Higher-order functions representing "the rest of the computation"
- Pure continuation: Operates on primal values only (no differentiation)
- Dual continuation: Operates on dual numbers (differentiation applied)
- CPS Transformation: Allows modular selection of gradient strategies per distribution
Gradient Estimation Strategies
- REINFORCE: Score function estimator ∇_θ E[f(X)] = E[f(X) · ∇_θ log p(X; θ)]
- Reparameterization: Pathwise estimator ∇_θ E[f(g(ε; θ))] = E[∇_θ f(g(ε; θ))] where ε is parameter-free noise
- Enumeration: Exact computation for finite discrete distributions
- Measure-Valued Derivatives: Advanced discrete gradient estimators
Example
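A minimal sketch, mirroring the usage documented for the expectation decorator and the normal_reparam primitive later on this page:

```python
from genjax.adev import expectation, normal_reparam

@expectation
def objective(theta):
    x = normal_reparam(theta, 1.0)   # X ~ Normal(theta, 1), reparameterized
    return x ** 2

grad = objective.grad_estimate(0.5)   # unbiased estimate of d/dtheta E[X^2]
value = objective.estimate(0.5)       # single forward estimate of E[X^2]
```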
References
Lew, A. K., Huot, M., Staton, S., & Mansinghka, V. K. (2023). ADEV: Sound Automatic Differentiation of Expected Values of Probabilistic Programs. Proceedings of the ACM on Programming Languages, 7(POPL), 121-148.
DualTree
module-attribute
¶
DualTree is the type of Pytree argument values with Dual leaves.
ADEVPrimitive
¶
Bases: Pytree
Base class for stochastic primitives with custom gradient estimation strategies.
An ADEVPrimitive represents a stochastic operation (like sampling from a distribution) that can provide custom gradient estimates through the ADEV system. Each primitive implements both forward sampling and a strategy for computing Jacobian-Vector Product (JVP) estimates during automatic differentiation.
The key insight is that different stochastic operations benefit from different gradient estimation strategies (REINFORCE, reparameterization, enumeration, etc.), and ADEVPrimitive allows each operation to specify its optimal strategy.
sample
abstractmethod
¶
Forward sampling operation.
prim_jvp_estimate
abstractmethod
¶
prim_jvp_estimate(dual_tree: tuple[DualTree, ...], konts: tuple[Callable[..., Any], Callable[..., Any]]) -> Dual
Custom JVP gradient estimation strategy.
This method implements the core gradient estimation logic for this primitive. It receives dual numbers (primal + tangent values) for the arguments and two continuations representing the rest of the computation.
Note
The choice of continuation reflects ADEV's CPS transformation approach:
- Pure continuation: Evaluates the remaining computation on primal values
- Dual continuation: Applies the ADEV transformation to the remaining computation

Different gradient strategies utilize these continuations differently:
- REINFORCE: Uses the dual continuation to evaluate f(X), computes ∇log p(X)
- Reparameterization: Uses the dual continuation with reparameterized samples
- Enumeration: May use both to compute weighted exact expectations
This design enables modular composition as described in the ADEV paper.
Source code in src/genjax/adev/__init__.py
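As a rough illustration of how a primitive can combine a sample, a continuation, and a score term, here is a REINFORCE-style JVP estimate written as plain JAX. The names sample_fn, logpdf_fn, and kdual are illustrative stand-ins, not the library's exact interface:

```python
import jax
import jax.numpy as jnp

def reinforce_style_jvp(key, theta, theta_dot, sample_fn, logpdf_fn, kdual):
    # Forward sample X ~ p(.; theta) at the primal parameter value.
    x = sample_fn(key, theta)
    # The dual continuation returns (primal, tangent) of the rest of the computation.
    primal, tangent = kdual(x)
    # Score-function term: f(X) * d/dtheta log p(X; theta), pushed forward along theta_dot.
    _, score_dot = jax.jvp(lambda t: logpdf_fn(x, t), (theta,), (theta_dot,))
    return primal, tangent + primal * score_dot

# Tiny usage: X ~ Normal(theta, 1), f(x) = x**2, no downstream tangent.
sample_fn = lambda key, theta: theta + jax.random.normal(key)
logpdf_fn = lambda x, theta: -0.5 * (x - theta) ** 2 - 0.5 * jnp.log(2 * jnp.pi)
kdual = lambda x: (x ** 2, jnp.zeros(()))
est = reinforce_style_jvp(
    jax.random.PRNGKey(0), jnp.asarray(0.5), jnp.asarray(1.0), sample_fn, logpdf_fn, kdual
)
```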
Dual
¶
Bases: Pytree
Dual number for forward-mode automatic differentiation.
A Dual number represents both a value (primal) and its derivative (tangent) with respect to some input. This is the fundamental data structure for forward-mode AD in the ADEV system.
Example
x = Dual(3.0, 1.0)  # x = 3, dx/dx = 1
y = Dual(x.primal**2, 2 * x.primal * x.tangent)  # y = x^2, dy/dx = 2x
tree_pure
staticmethod
¶
Convert a tree to have Dual leaves with zero tangents.
This is used to "lift" regular values into the dual number system by pairing them with zero tangents, indicating no sensitivity.
Source code in src/genjax/adev/__init__.py
dual_tree
staticmethod
¶
Combine primal and tangent trees into a tree of Dual numbers.
Source code in src/genjax/adev/__init__.py
tree_primal
staticmethod
¶
Extract primal values from a tree of Dual numbers.
Source code in src/genjax/adev/__init__.py
tree_tangent
staticmethod
¶
Extract tangent values from a tree of Dual numbers.
Source code in src/genjax/adev/__init__.py
tree_leaves
staticmethod
¶
Get leaves of a tree, treating Dual numbers as atomic.
Source code in src/genjax/adev/__init__.py
tree_unzip
staticmethod
¶
Separate a tree of Dual numbers into primal and tangent trees.
Source code in src/genjax/adev/__init__.py
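A small sketch of how these helpers fit together, assuming the static methods behave as documented above (exact signatures may differ):

```python
import jax.numpy as jnp
from genjax.adev import Dual

# Lift a plain pytree into Duals with zero tangents (no sensitivity).
lifted = Dual.tree_pure({"mu": jnp.array(1.0), "sigma": jnp.array(0.5)})

# Pair explicit primal and tangent pytrees into a pytree of Dual leaves.
duals = Dual.dual_tree((1.0, 2.0), (1.0, 0.0))

# Recover the separate pytrees again.
primals = Dual.tree_primal(duals)
tangents = Dual.tree_tangent(duals)
primals2, tangents2 = Dual.tree_unzip(duals)
```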
static_check_is_dual
staticmethod
¶
Check whether a value is a Dual instance.
static_check_dual_tree
staticmethod
¶
Check if all leaves in a tree are Dual numbers.
ADEV
¶
Bases: Pytree
Interpreter for ADEV's continuation-passing style automatic differentiation.
The ADEV interpreter processes JAX computation graphs (Jaxpr) and transforms them to support stochastic automatic differentiation. It implements a continuation-passing style (CPS) transformation that reflects the law of iterated expectation, allowing different gradient estimation strategies for each stochastic operation.
Key responsibilities:
1. Propagate dual numbers through deterministic JAX operations
2. Apply the CPS transformation at stochastic operations (sample_p primitives)
3. Create continuation closures for gradient estimation strategies
4. Handle control flow (conditionals, loops) within the AD system

The CPS transformation is crucial: when encountering a stochastic operation, the interpreter creates two continuations representing the rest of the computation:
- Pure continuation: For sampling-based gradient estimates
- Dual continuation: For the ADEV-transformed remainder
This allows each ADEVPrimitive to choose its optimal gradient strategy while maintaining composability across the entire computation graph.
ADEVProgram
¶
Bases: Pytree
Internal representation of a stochastic program for ADEV gradient estimation.
An ADEVProgram wraps a source function containing stochastic operations and provides the infrastructure for computing Jacobian-Vector Product (JVP) estimates through the ADEV system. This class serves as an intermediate representation between user-defined @expectation functions and the low-level ADEV interpreter.
The ADEVProgram handles the integration between:
1. User source code containing ADEV primitives
2. The ADEV interpreter's CPS transformation
3. Continuation-based gradient estimation strategies
Note
This class is typically not used directly by users. It's created internally by the @expectation decorator and managed by the Expectation class.
jvp_estimate
¶
Compute JVP estimate for the stochastic program.
This method applies the ADEV forward-mode transformation to compute an unbiased estimate of the Jacobian-Vector Product for expectations involving stochastic operations. It uses the continuation-passing style transformation to integrate different gradient estimation strategies.
Note
This method coordinates between the user's source function and the ADEV interpreter to apply the appropriate gradient estimation strategies for each stochastic primitive encountered during execution.
Source code in src/genjax/adev/__init__.py
Expectation
¶
Bases: Pytree
Represents an expectation with automatic differentiation support.
An Expectation object encapsulates a stochastic computation and provides methods to compute unbiased gradient estimates of expectation values. This is the primary user-facing interface for ADEV (Automatic Differentiation of Expected Values).
The key insight is that for expectations E[f(X)] where X is a random variable, we can compute unbiased gradient estimates ∇E[f(X)] using various strategies:
- REINFORCE: ∇E[f(X)] = E[f(X) * ∇log p(X)]
- Reparameterization: ∇E[f(X)] = E[∇f(g(ε))] where X = g(ε), ε ~ fixed distribution
- Enumeration: Exact computation for discrete distributions with finite support
Example

from genjax.adev import expectation, normal_reparam
import jax.numpy as jnp

@expectation
def loss_function(theta):
    x = normal_reparam(theta, 1.0)
    return x**2

# Compute gradient estimate
grad = loss_function.grad_estimate(0.5)
jnp.isfinite(grad)  # Array(True, dtype=bool)

# Compute expectation value
value = loss_function.estimate(0.5)
jnp.isfinite(value)  # Array(True, dtype=bool)
jvp_estimate
¶
Compute Jacobian-Vector Product estimate for the expectation.
This method provides the core JVP computation for ADEV. It applies the continuation-passing style transformation with an identity continuation, meaning this expectation represents the "final" computation in the chain.
Note
This is the foundational method that enables both grad_estimate and integration with JAX's automatic differentiation system.
Source code in src/genjax/adev/__init__.py
grad_estimate
¶
Compute unbiased gradient estimate of the expectation.
This method provides the primary interface for computing gradients of expectation values. It leverages JAX's grad transformation combined with ADEV's custom JVP rules to produce unbiased gradient estimates.
Example

from genjax.adev import expectation, normal_reparam
import jax.numpy as jnp

@expectation
def objective(mu, sigma):
    x = normal_reparam(mu, sigma)
    return x**2

# Compute gradient with respect to both parameters
grad_mu, grad_sigma = objective.grad_estimate(1.0, 0.5)
jnp.isfinite(grad_mu)     # Array(True, dtype=bool)
jnp.isfinite(grad_sigma)  # Array(True, dtype=bool)
Note
The gradient estimates are unbiased, meaning E[∇̂f] = ∇E[f], but they may have variance. The choice of gradient estimation strategy (REINFORCE, reparameterization, etc.) affects this variance.
Source code in src/genjax/adev/__init__.py
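As noted above, these estimates compose with JAX transformations. A sketch of a small stochastic-gradient loop reusing the `objective` defined in the example above (fixed sigma and step size are arbitrary choices for illustration):

```python
import jax

# jit-compile one gradient-estimation step with sigma held at 0.5.
step = jax.jit(lambda mu: objective.grad_estimate(mu, 0.5))

mu = 1.0
for _ in range(100):
    g_mu, _ = step(mu)
    mu = mu - 1e-2 * g_mu   # stochastic gradient descent; drives mu toward 0
```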
estimate
¶
Compute the expectation value (forward pass only).
This method evaluates E[f(X)] without computing gradients. It's useful when you only need the expectation value itself, not its derivatives.
Example

from genjax.adev import expectation, normal_reparam
import jax.numpy as jnp

@expectation
def mean_squared(mu):
    x = normal_reparam(mu, 1.0)
    return x**2

# Just compute E[X^2] where X ~ Normal(mu, 1)
expectation_value = mean_squared.estimate(2.0)
jnp.isfinite(expectation_value)  # Array(True, dtype=bool)
expectation_value > 0  # Should be positive for squared values: Array(True, dtype=bool)
Note
This method uses zero tangents in the dual number computation, effectively performing only the forward pass through the stochastic computation graph.
Source code in src/genjax/adev/__init__.py
REINFORCE
¶
Bases: ADEVPrimitive
REINFORCE (score function) gradient estimator primitive.
Implements the REINFORCE gradient estimator (Williams, 1992), also known as the score function estimator or likelihood ratio method. This estimator is one of the key gradient estimation strategies supported by the ADEV framework.
Theoretical Foundation: The REINFORCE estimator is based on the score function identity: ∇_θ E[f(X)] = E[f(X) * ∇_θ log p(X; θ)]
where ∇_θ log p(X; θ) is the score function. This identity holds for any distribution p(X; θ) with differentiable log-density, making REINFORCE universally applicable but potentially high-variance.
ADEV Implementation: Within ADEV's CPS framework, REINFORCE:
1. Samples X ~ p(·; θ) using the current parameters
2. Evaluates f(X) using the dual continuation (kdual)
3. Computes the score function ∇_θ log p(X; θ) via JAX's JVP
4. Returns f(X) + f(X) * ∇_θ log p(X; θ) as the gradient estimate
Note
While general-purpose, REINFORCE can exhibit high variance. Reparameterization is preferred when available, as proven more efficient in the ADEV paper.
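For example, the REINFORCE-backed normal primitive (normal_reinforce, referenced later on this page) can be used directly inside an @expectation program. A sketch; individual estimates are unbiased but noisy:

```python
from genjax.adev import expectation, normal_reinforce

@expectation
def reinforce_objective(theta):
    x = normal_reinforce(theta, 1.0)   # score-function gradient strategy
    return (x - 2.0) ** 2

g = reinforce_objective.grad_estimate(0.0)   # unbiased but potentially high-variance
```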
sample
¶
prim_jvp_estimate
¶
REINFORCE gradient estimation using the score function identity.
Implements the score function estimator: ∇_θ E[f(X)] = E[f(X) * ∇_θ log p(X; θ)]
This method applies the ADEV CPS transformation for REINFORCE:
1. Sample X ~ p(·; θ) from the distribution with current parameters
2. Evaluate f(X) using the dual continuation (kdual) to get the function value
3. Compute the score function ∇_θ log p(X; θ) using JAX's forward-mode AD
4. Combine via the REINFORCE identity: f(X) + f(X) * ∇_θ log p(X; θ)
The dual continuation captures the ADEV-transformed "rest of the computation" after this stochastic choice, enabling modular composition with other gradient estimation strategies as described in the ADEV paper.
Source code in src/genjax/adev/__init__.py
FlipEnum
¶
Bases: ADEVPrimitive
Exact enumeration gradient estimator for Bernoulli distributions.
For discrete distributions with finite support, we can compute exact gradients by enumerating all possible outcomes and weighting by their probabilities. This gives zero-variance gradient estimates for the flip/Bernoulli case.
The estimator computes the expectation exactly as E[f(X)] = p · f(True) + (1 − p) · f(False), so differentiating through the probability weights yields the exact gradient ∇_p E[f(X)] = f(True) − f(False).
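A sketch using flip_enum (listed among the ADEV variants in the notes further below); for this objective the enumerated gradient is exactly f(True) − f(False) = 2:

```python
import jax.numpy as jnp
from genjax.adev import expectation, flip_enum

@expectation
def discrete_objective(p):
    b = flip_enum(p)                    # enumerated Bernoulli(p) choice
    return jnp.where(b, 1.0, -1.0)

g = discrete_objective.grad_estimate(0.3)   # exactly 2.0: d/dp [p*1 + (1-p)*(-1)]
```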
FlipMVD
¶
Bases: ADEVPrimitive
Measure-Valued Derivative (MVD) gradient estimator for Bernoulli distributions.
Implements the measure-valued derivative approach for gradient estimation with discrete distributions. MVD is a flexible gradient estimation technique that decomposes the derivative of a probability density into positive and negative components: ∇_θ p(x; θ) = c_θ(p^+(x; θ) - p^-(x; θ)).
Theoretical Foundation: For discrete distributions like Bernoulli, MVD enables gradient estimation without requiring differentiability assumptions. The estimator works by:
1. Sampling from the original distribution
2. Evaluating the function on both the sampled value and its complement
3. Using a signed difference to create an unbiased gradient estimate

MVD Implementation for Bernoulli: The key insight is using the "phantom estimator" approach where:
- The sampled outcome determines the sign via (-1)^v
- Both the actual outcome and its complement are evaluated
- The difference (other - b_primal) captures the discrete gradient

Advantages:
- Works with discrete distributions where REINFORCE may have issues
- No differentiability requirements on the objective function
- Provides unbiased gradient estimates for discrete parameters

Disadvantages:
- Computationally expensive (requires multiple evaluations)
- Higher variance than reparameterization when applicable
- Only applies to single parameters at a time
Note
This is a "phantom estimator" that evaluates the function on auxiliary samples (the complement outcome) to construct the gradient estimate. The (-1)^v term creates the appropriate sign for the discrete difference.
sample
¶
prim_jvp_estimate
¶
Measure-valued derivative gradient estimation for Bernoulli.
Implements the MVD approach using phantom estimation:
1. Sample v ~ Bernoulli(p) to get the primary outcome
2. Evaluate f(v) using the dual continuation (kdual)
3. Evaluate f(¬v) using the pure continuation (kpure) as a phantom estimate
4. Combine with the signed difference: (-1)^v * (f(¬v) - f(v))

The (-1)^v term ensures the correct sign for the discrete gradient:
- When v=1: -1 * (f(0) - f(1)) = f(1) - f(0)
- When v=0: +1 * (f(1) - f(0)) = f(1) - f(0)
This creates an unbiased estimator of ∇_p E[f(X)] for X ~ Bernoulli(p).
Source code in src/genjax/adev/__init__.py
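The same estimate written as plain JAX outside the CPS machinery; bernoulli_mvd_estimate is an illustrative helper, not the library implementation:

```python
import jax
import jax.numpy as jnp

def bernoulli_mvd_estimate(key, p, f):
    v = jax.random.bernoulli(key, p)        # 1. sample v ~ Bernoulli(p)
    f_v = f(v)                              # 2. evaluate f on the sample
    f_other = f(jnp.logical_not(v))         # 3. phantom evaluation on the complement
    sign = jnp.where(v, -1.0, 1.0)          # (-1)^v
    return sign * (f_other - f_v)           # unbiased estimate of d/dp E[f(X)]

# For f below, d/dp E[f(X)] = f(1) - f(0) = 3.0 regardless of the sampled v.
estimate = bernoulli_mvd_estimate(
    jax.random.PRNGKey(0), 0.3, lambda b: jnp.where(b, 2.0, -1.0)
)
```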
NormalREPARAM
¶
Bases: ADEVPrimitive
Reparameterization (pathwise) gradient estimator for normal distributions.
Implements the reparameterization trick, also known as the pathwise estimator, which is one of the core gradient estimation strategies in the ADEV framework. This provides low-variance gradient estimates for reparameterizable distributions.
Theoretical Foundation: For a reparameterizable distribution p(X; θ) = p(g(ε; θ)) where ε ~ p(ε) is parameter-free, the pathwise estimator is: ∇_θ E[f(X)] = E[∇_θ f(g(ε; θ))]
For Normal(μ, σ): X = g(ε; μ, σ) = μ + σ * ε, where ε ~ Normal(0, 1) This reparameterization allows gradients to flow directly through the parameters μ and σ via standard automatic differentiation (chain rule).
ADEV Implementation: Within ADEV's CPS framework, reparameterization:
1. Samples parameter-free noise ε ~ Normal(0, 1)
2. Applies the transformation X = μ + σ * ε with JAX's JVP for gradients
3. Passes the reparameterized sample through the dual continuation (kdual)
This strategy typically exhibits lower variance than REINFORCE, as noted in the ADEV paper and empirical studies (Kingma & Welling, 2014).
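Written as ordinary JAX, the transformation underlying this primitive amounts to the following sketch (not the primitive's actual code path):

```python
import jax
import jax.numpy as jnp

def sample_normal_reparam(key, mu, sigma):
    eps = jax.random.normal(key)    # parameter-free noise, eps ~ Normal(0, 1)
    return mu + sigma * eps         # gradients flow through mu and sigma

# Pathwise gradient of X^2 w.r.t. mu for a single noise draw.
g = jax.grad(lambda mu: sample_normal_reparam(jax.random.PRNGKey(0), mu, 1.0) ** 2)(0.5)
```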
prim_jvp_estimate
¶
Reparameterization gradient estimation using the pathwise estimator.
Implements: ∇_θ E[f(X)] = E[∇_θ f(g(ε; θ))] where X = g(ε; θ)
This method applies the ADEV CPS transformation for reparameterization:
1. Sample parameter-free noise ε ~ Normal(0, 1)
2. Apply the reparameterization X = μ + σ * ε with gradients via JAX's JVP
3. Pass the dual number (X, ∇X) to the dual continuation (kdual)
The dual continuation captures the ADEV-transformed remainder of the computation, enabling low-variance gradient flow as described in the ADEV paper.
Source code in src/genjax/adev/__init__.py
MultivariateNormalREPARAM
¶
Bases: ADEVPrimitive
Multivariate reparameterization (pathwise) gradient estimator.
Extends the reparameterization trick to multivariate normal distributions, implementing the pathwise estimator for high-dimensional parameter spaces as supported by the ADEV framework.
Theoretical Foundation: For MultivariateNormal(μ, Σ), the reparameterization is: X = g(ε; μ, Σ) = μ + L @ ε where L = cholesky(Σ) and ε ~ Normal(0, I).
The pathwise estimator then gives: ∇_{μ,Σ} E[f(X)] = E[∇_{μ,Σ} f(μ + L @ ε)]
ADEV Implementation: This primitive enables efficient gradient flow with respect to both the mean vector μ and covariance matrix Σ, crucial for scalable variational inference in high-dimensional spaces. The Cholesky decomposition ensures positive definiteness while enabling automatic differentiation through the covariance structure.
This implementation follows the ADEV paper's approach to modular gradient estimation, allowing seamless integration with other stochastic primitives in complex probabilistic programs.
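The corresponding hand-written reparameterization, as a sketch of the transformation rather than the primitive's internal implementation:

```python
import jax
import jax.numpy as jnp

def sample_mvn_reparam(key, mu, cov):
    L = jnp.linalg.cholesky(cov)                    # cov = L @ L.T
    eps = jax.random.normal(key, shape=mu.shape)    # eps ~ Normal(0, I)
    return mu + L @ eps                             # gradients flow through mu and cov

mu = jnp.zeros(2)
cov = jnp.array([[1.0, 0.3], [0.3, 1.0]])
x = sample_mvn_reparam(jax.random.PRNGKey(0), mu, cov)
```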
prim_jvp_estimate
¶
Multivariate reparameterization using Cholesky decomposition.
Implements: ∇E[f(X)] = E[∇f(μ + L @ ε)] where L = cholesky(Σ)
This method applies the pathwise estimator for multivariate normal distributions:
1. Sample standard multivariate normal noise ε ~ Normal(0, I)
2. Apply the Cholesky reparameterization X = μ + L @ ε with gradient flow
3. Pass the dual number (X, ∇X) to the dual continuation (kdual)
The Cholesky decomposition ensures efficient and numerically stable gradients with respect to the covariance matrix Σ, as described in the ADEV framework for modular gradient estimation strategies.
Source code in src/genjax/adev/__init__.py
sample_primitive
¶
Integrate an ADEV primitive with the PJAX infrastructure.
This function wraps an ADEVPrimitive so it can be used within GenJAX's probabilistic programming system. It ensures the primitive works correctly with JAX transformations (jit, vmap, grad) and addressing (@) operators.
The key insight is that ADEV primitives need to be integrated with PJAX's sample_binder to get proper parameter setup (like flat_keyful_sampler) that enables compatibility with the seed transformation and other GenJAX features.
Note
This function was crucial for fixing the flat_keyful_sampler error - previously ADEV primitives bypassed sample_binder and lacked proper parameter setup for JAX transformations.
Source code in src/genjax/adev/__init__.py
expectation
¶
Decorator to create an Expectation object from a stochastic function.
This decorator transforms a function containing stochastic operations into an Expectation object that supports automatic differentiation of expectation values. The decorated function should use ADEV-compatible distributions (those with gradient estimation strategies like normal_reparam, normal_reinforce, etc.).
Example

from genjax.adev import expectation, normal_reparam, normal_reinforce
import jax.numpy as jnp

# Basic usage
@expectation
def quadratic_loss(theta):
    x = normal_reparam(theta, 1.0)  # Reparameterizable distribution
    return (x - 2.0)**2

# Compute gradient
gradient = quadratic_loss.grad_estimate(1.0)
jnp.isfinite(gradient)  # Array(True, dtype=bool)

# Compute expectation value
loss_value = quadratic_loss.estimate(1.0)
jnp.isfinite(loss_value)  # Array(True, dtype=bool)

More complex example with multiple variables:

@expectation
def complex_objective(mu, sigma):
    x = normal_reparam(mu, sigma)
    y = normal_reinforce(0.0, 1.0)  # REINFORCE strategy
    return jnp.sin(x) * jnp.cos(y)

grad_mu, grad_sigma = complex_objective.grad_estimate(0.5, 1.2)
Note
The function should only use ADEV-compatible distributions that have gradient estimation strategies. Regular distributions (normal, beta, etc.) won't provide gradient estimates - use their ADEV variants instead (normal_reparam, normal_reinforce, flip_enum, etc.).
The resulting Expectation object's interfaces (grad_estimate, estimate, etc.) are compatible with JAX transformations like jit and modular_vmap. The Expectation object itself is also a Pytree, so it can be passed as an argument to JAX-transformed functions. Use modular_vmap instead of regular vmap for proper handling of probabilistic primitives within ADEV programs.
Source code in src/genjax/adev/__init__.py
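A sketch of the compatibility described in the note above, reusing quadratic_loss from the example; the import location of modular_vmap is an assumption:

```python
import jax
import jax.numpy as jnp
from genjax import modular_vmap   # assumed import path for modular_vmap

# jit-compile the gradient estimator.
fast_grad = jax.jit(quadratic_loss.grad_estimate)

# Batch the estimator over a vector of parameters with modular_vmap.
thetas = jnp.linspace(-1.0, 1.0, 8)
batched_grads = modular_vmap(quadratic_loss.grad_estimate)(thetas)
```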
invoke_closed_over
¶
Primal forward-mode function for Expectation objects with custom JVP rule.
This function serves as the primal computation for JAX's custom JVP rule registration. It's defined externally to the Expectation class to avoid complications with defining custom JVP rules on Pytree classes.
Note
This function is decorated with @jax.custom_jvp to register ADEV's jvp_estimate as the custom JVP rule. This allows JAX to automatically synthesize grad implementations that use ADEV's unbiased gradient estimators.
Source code in src/genjax/adev/__init__.py
invoke_closed_over_jvp
¶
Custom JVP rule that delegates to ADEV's jvp_estimate method.
This function registers ADEV's jvp_estimate as the JVP rule for Expectation objects, enabling JAX to automatically synthesize grad implementations. When JAX encounters invoke_closed_over in a computation that requires differentiation, it will use this rule instead of trying to differentiate through the stochastic computation.
Note
This converts between JAX's JVP representation (separate primals/tangents) and ADEV's dual number representation, then delegates to jvp_estimate for the actual gradient computation using ADEV's CPS transformation.
Source code in src/genjax/adev/__init__.py
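The general shape of this registration pattern in plain JAX, independent of GenJAX internals; the estimator below is a toy single-sample pathwise estimate, not ADEV's:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def expected_square(theta):
    # Primal: single-sample forward estimate of E[X^2], X ~ Normal(theta, 1).
    eps = jax.random.normal(jax.random.PRNGKey(0))
    return (theta + eps) ** 2

@expected_square.defjvp
def expected_square_jvp(primals, tangents):
    (theta,), (theta_dot,) = primals, tangents
    eps = jax.random.normal(jax.random.PRNGKey(0))
    x = theta + eps
    # Substitute an unbiased pathwise derivative estimate: d/dtheta (theta + eps)^2 = 2x.
    return x ** 2, 2.0 * x * theta_dot

g = jax.grad(expected_square)(0.5)   # grad is synthesized from the custom JVP rule
```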
reinforce
¶
Factory function for creating REINFORCE gradient estimators.
Example
normal_reinforce_prim = reinforce(normal.sample, normal.logpdf)