Overview (Tutorial): Vectorized Probabilistic Programming #

To introduce our language, consider the task of polynomial regression: given a dataset of pairs $(x_i, y_i) \in \mathbb{R}^2$ for $1 \leq i \leq n$, we wish to infer a polynomial relating $x$ and $y$. In the following sections, we illustrate how to solve this problem using generative functions and programmable inference in GenJAX.

Vectorizing Generative Functions with vmap

The first figure depicts a generative model for quadratic regression. The goal is, given a noisy dataset $(x_i, y_i)_{1 \leq i \leq n}$, to infer a quadratic function that plausibly governs the relationship between $x$ and $y$. Our model for this task is built by composing generative functions, each written as a @gen-decorated Python function. A key feature of GenJAX is that each random choice is assigned a string-valued name using the syntax dist @ "name". The polynomial generative function describes a prior distribution on the coefficients (a, b, c) of the underlying quadratic function. The point generative function models how an individual datapoint is generated based on those coefficients. Finally, npoint_curve calls polynomial to generate coefficients and maps the point generative function over a vector of inputs.

This is our first use of vmap: we use it to generate multiple y values in parallel, exploiting the fact that datapoints are generated conditionally independently of one another, given the coefficients. This is an instance of a general pattern that appears in many probabilistic programs, and it is one key place where vectorization yields significant speed-ups: parts of the generative model that are conditionally independent can be executed in parallel.

Generative functions
# Basic polynomial model
@gen
def polynomial():
  # @ denotes introduction of
  # random choices
  a = normal(0, 1) @ "a"
  b = normal(0, 1) @ "b"
  c = normal(0, 1) @ "c"
  return (a, b, c)

# Point model with noise
@gen
def point(x, a, b, c):
  y_mean = a + b * x + c * x ** 2
  y = normal(y_mean, 0.2) @ "obs"
  return y
Vectorization with vmap
@gen
def npoint_curve(xs):
  (a, b, c) = polynomial() @ "curve"
  # Vectorization for modeling: here, over data points
  ys = point.vmap(args_mapped=0)(xs, a, b, c) @ "ys"
  return (a, b, c), ys

# Vectorized sampling from the generative function
# using the simulate interface.
xs = array([0.1, 0.3, 0.4, 0.6])
traces = vmap(simulate(npoint_curve), repeat=4)(xs)

# Vectorized evaluation of the pointwise density
# using the assess interface.
xs = traces.get_args()
chms = get_choices(traces)
densities, retvals = (
    vmap(assess(npoint_curve), args_mapped=0)(
        chms, xs
    )
)
Figure. Vectorization of generative functions. Left: Probabilistic programs encoding a prior over quadratic functions and a single-datapoint likelihood. Right: vmap can be used to parallelize the likelihood; the same program that works for a single point works for many points. Inference operations are also compatible with vmap.

When vmap transforms a generative function, it induces a corresponding transformation on the values in the trace: the structure of the trace is preserved, scalars become arrays, and the result is a single trace in struct-of-arrays representation.
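For example, the following sketch (using the get_choices accessor shown in the interface figure later in this section; expected shapes are given as comments) illustrates the batched layout produced by the vmapped point model inside npoint_curve:

# A minimal sketch: the "curve" choices remain scalars, while the
# vmapped "ys" choices carry one entry per input x.
xs = array([0.1, 0.3, 0.4, 0.6])
tr = simulate(npoint_curve)(xs)
chm = get_choices(tr)
chm["curve"]["a"].shape    # ()    scalar coefficient
chm["ys"]["obs"].shape     # (4,)  one observation per input x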

Vectorized Programmable Inference

Generative functions are compiled to implementations of the generative function interface, which includes methods like simulate and assess. The simulate method runs a generative function and yields an execution trace, while assess computes the joint probability density of a generative function at a given assignment of its random choices (a choice map). These methods can be composed to implement inference algorithms. For example, likelihood weighting simulates many candidate traces from the prior and assesses them under the likelihood.
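Concretely, for latent choices $z$ (here, the curve coefficients) proposed from a distribution $q$ and observations $y$, the single-particle importance weight is

$$\log w \;=\; \log p(z, y) \;-\; \log q(z),$$

the difference between the model's joint log density (computed with assess) and the proposal's log density (the score recorded by simulate); this is the quantity computed in the figure below.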

By vectorizing the compiled simulate and assess methods, we can generate or assess many traces at once. We can use vmap to scale the number of particles, automatically transforming single-particle code into a many-particle vectorized version. When the algorithm executes in parallel on a GPU, the number of particles can be increased freely as long as the GPU has free memory: the runtime remains nearly constant as the particle count grows, while accuracy improves toward convergence.
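The average of the resulting weights, $\frac{1}{N} \sum_{i=1}^{N} w_i$, is an unbiased estimate of the marginal likelihood $p(y)$; in log space,

$$\log \hat{p}(y) \;=\; \operatorname{logsumexp}(\log w_1, \ldots, \log w_N) \;-\; \log N,$$

which is what the lmle helper in the figure computes.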

Single particle importance sampling
# Single particle importance sampling.
def importance_sampling(ys, xs):
  # Propose latents (curve coefficients) from the default proposal.
  trace = simulate(default_proposal)(xs)
  # Assess the model's joint density at the proposed latents
  # merged with the observed data.
  chm = {**get_choices(trace), "ys": {"obs": ys}}
  logp, _ = assess(npoint_curve)(chm, xs)
  # Importance weight: model joint density over proposal density.
  w = logp - trace.get_score()
  return (trace, w)
Vectorized over N particles
# Vectorized over N particles.
def vectorized_importance_sampling(ys, xs, N):
  # vmap automatically batches over N independent copies
  return vmap(
      importance_sampling,
      repeat=N
  )(ys, xs)

# Compute log marginal likelihood estimate.
def lmle(ws, N):
  return logsumexp(ws) - log(N)
Scaling behavior of vectorized importance sampling
Posterior approximations at different particle counts
Figure. Vectorized programmable inference. Top left: Single-particle importance sampling with a proposal (the default proposal is the prior in the npoint_curve model, excluding the "obs" random variable) implemented using generative function interface methods (simulate and assess). Top right: Using vmap, we can automatically transform the single-particle version into a many-particle vectorized version. Middle: The vectorized version runs in parallel on GPUs; the runtime is nearly constant as long as the GPU has memory to spare. Increasing the number of particles increases accuracy. Bottom: Posterior approximations for different numbers of particles N.
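Putting the pieces together, a typical invocation looks like the following sketch (the name ys_obs and the particle count are illustrative placeholders for the observed responses at the inputs xs):

# Run N particles in parallel and estimate the log marginal likelihood.
N = 1000
traces, ws = vectorized_importance_sampling(ys_obs, xs, N)
log_ml = lmle(ws, N)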

Improving Robustness Using Stochastic Branching

In real-world data, the assumptions of simple polynomial regression are often violated. Our polynomial model assumes every data point follows the same noise model — but what if 10% of our measurements follow a different distribution? We can improve robustness by using stochastic branching, which allows us to account for outlier observations through heterogeneous mixture modeling. Each data point gets a latent outlier flag. If the flag is true, the observation comes from a uniform distribution; if false, it follows the noisy polynomial curve.
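Marginalizing out the flag, each observation follows a two-component mixture: with probability $0.1$ it is drawn uniformly on $[-2, 2]$, and with probability $0.9$ it follows a (truncated) normal centered on the curve,

$$p(y \mid x, a, b, c) \;=\; 0.1 \cdot \mathrm{Uniform}(y; -2, 2) \;+\; 0.9 \cdot \mathcal{N}_{\mathrm{trunc}}\!\left(y;\; a + b x + c x^{2},\; \sigma^{2}\right),$$

where $\sigma$ denotes the inlier noise scale used in the code below.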

Outlier-robust observation model
# Outlier-robust observation model
@gen
def point_with_outliers(x, a, b, c):
  outlier_flag = bernoulli(0.1) @ "outlier"
  y_mean = a + b * x + c * x ** 2
  # Branch on the latent flag: outliers are uniform on [-2, 2],
  # inliers follow a truncated normal centered on the curve.
  return cond(outlier_flag,
    lambda mu: uniform(-2.0, 2.0),
    lambda mu: trunc_norm(mu, 0.05, 2.0),
    y_mean,
  ) @ "obs"
Vectorized curve model with outliers
# Vectorized curve model with outliers
@gen
def npoint_curve_with_outliers(xs):
  (a, b, c) = polynomial() @ "curve"
  ys = point_with_outliers.vmap(
      args_mapped=0,
  )(xs, a, b, c) @ "ys"
  return ys
Outlier detection comparison
Figure. Robust modeling with stochastic branching. Stochastic branching allows us to extend our models to explain more complex data, including data with outliers. Circle markers depict observed data points; the shading of the marker denotes the estimated posterior probability that the point is an outlier. Bottom, left: Using importance sampling to construct a posterior in our original model results in a poor explanation of the data. Bottom, middle: Extending the model to explicitly represent outliers as random variables should allow us to produce better explanations, but results in a harder inference problem which importance sampling cannot effectively solve. Bottom, right: Changing inference to vectorized MCMC using Gibbs sampling (to infer outliers) and Hamiltonian Monte Carlo (to infer continuous parameters) finds better explanations of the data, i.e., more accurate posterior approximations.

Improving Inference Accuracy Using Programmable Inference

Even when a model's assumptions are sensible, inference can fail to find good explanations of a given dataset. Importance sampling applied to the outlier model identifies likely outliers, but has wide uncertainty over the possible curves, and several curves do not seem to explain the data well. This is a kind of underfitting: by adding new latent variables to our model, we have made inference more challenging, and the "guess and check" approach of importance sampling runs into limitations, even with $N = 10^5$ particles — the limit where our GPU memory begins to saturate.

The right panel of the outlier figure illustrates the results of a custom hybrid algorithm, which combines Gibbs sampling and Hamiltonian Monte Carlo (HMC). The algorithm uses Gibbs sampling to identify which points are outliers, and HMC to sample from the posterior distribution over curves, given the inliers. This algorithm generates much more accurate posterior samples that explain the data well.

Vectorized Gibbs Sampling and the Generative Function Interface

We present the GenJAX implementation of the Gibbs sampling step of our hybrid algorithm. Our implementation highlights generative function interface methods, including trace manipulation and getter methods. In our outlier model, we apply Gibbs sampling to update the vector of outlier choices, keeping all other random choices fixed. Because each outlier choice is conditionally independent of the others (given all the non-outlier choices), the outlier updates can be vectorized. For each element, we enumerate the unnormalized posterior density of each possible value of the outlier indicator, and then sample a new value from the resulting categorical distribution.
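Because the joint density factorizes over data points given the curve, the conditional distribution of a single indicator $o_i$ is proportional to the per-point joint evaluated at each setting,

$$p(o_i = v \mid y_i, x_i, a, b, c) \;\propto\; p(o_i = v,\, y_i \mid x_i, a, b, c), \qquad v \in \{\mathrm{false}, \mathrm{true}\},$$

which is what the per-point assess calls below compute; the categorical sample then normalizes over the two values.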

Enumerative Gibbs update for single point
def gibbs_outlier(subtrace):
  def _assess(v):
    (x, a, b, c) = subtrace.get_args()
    chm = {"outlier": v,
           "obs": subtrace["obs"]}
    log_prob, _ = assess(point_with_outliers)(
      chm, x, a, b, c
    )
    return log_prob

  # Enumerate both settings of the outlier indicator.
  log_probs = vmap(_assess)(
    array([False, True])
  )
  # Sample an index from the normalized categorical;
  # index 1 corresponds to "outlier" (True).
  return categorical(log_probs) == 1
Vectorized enumerative Gibbs
# `trace` is a single trace object
# whose fields store batched values.
def enumerative_gibbs(trace):
  xs = trace.get_args()
  # `subtrace` refers to the struct-of-arrays
  # view for the "ys" addresses.
  subtrace = trace.get_subtrace("ys")
  new_outliers = vmap(gibbs_outlier)(subtrace)
  # `update` is the generative function interface
  # method that edits a trace with new choices.
  new_trace, weight, _ = update(npoint_curve_with_outliers)(
    trace,
    {"ys": {"outlier": new_outliers}},
    xs,
  )
  return new_trace
Figure. Vectorized enumerative Gibbs sampling for outlier detection. Left: Enumerative Gibbs update for a single data point's outlier indicator. For each possible value (inlier/outlier), we compute the log probability under the model (proportional to the unnormalized posterior) and sample a new indicator using categorical sampling. Right: Vectorized Gibbs sampling step that applies the single-point update across all data points using vmap, then updates the trace with the new outlier indicators.
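The Gibbs step above composes with an HMC move on the continuous choices to form the full hybrid kernel. The following sketch assumes a hypothetical hmc_update helper (not shown in the figures) that applies Hamiltonian Monte Carlo to the choices under the "curve" address:

# Sketch of the hybrid kernel: alternate enumerative Gibbs on the
# discrete outlier flags with an HMC move on the curve coefficients.
# `hmc_update` is a hypothetical helper, not part of the figures.
def hybrid_kernel(trace):
  trace = enumerative_gibbs(trace)
  trace = hmc_update(trace, address="curve")
  return trace

def run_chain(trace, n_steps):
  for _ in range(n_steps):
    trace = hybrid_kernel(trace)
  return trace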
simulate: sampling
# Unconstrained sampling of a trace
tr = simulate(npoint_curve)(xs)
assess: density evaluation
# Evaluate log density at traced sample
chm = get_choices(tr)
logp, retval = assess(npoint_curve)(chm, xs)
generate: importance sampling
# Constrained sampling of a trace
partial_chm = {"ys": {"obs": data}}
tr_, weight = generate(npoint_curve)(
    partial_chm, xs
)
update: trace modification
# Modify a trace given constraints
new_chm = {"curve": {"a": 1.0}}
tr_, w, discard = update(npoint_curve)(
    tr, new_chm, xs
)
Figure. Generative function interface methods. GenJAX's generative functions provide several methods for programmable inference, a way to extend the system with new variants of inference using high-level interfaces. For authoring programmable algorithms which use proposal distributions (like sequential Monte Carlo), the simulate method performs unconstrained sampling and reciprocal density evaluation. The assess method evaluates the log joint density of a generative function at a complete choice map of traced samples. The generate interface performs constrained sampling (using importance weighting), allowing construction of a trace from observation constraints. The update method modifies a trace with provided choices, returning an updated trace and an incremental importance weight; it is used by algorithms like Gibbs sampling or Hamiltonian Monte Carlo to modify traces.
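Under this reading of the interface, the simulate-and-assess pattern from the importance sampling figure can also be written as a single generate call; the sketch below assumes that generate's returned weight plays the role of the importance weight computed there:

# Sketch: constrained generation bundles proposal and weighting.
def importance_sampling_via_generate(ys, xs):
  tr, w = generate(npoint_curve)({"ys": {"obs": ys}}, xs)
  return (tr, w)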

Formal Model (Theory): $\gen$ — A Core Calculus #

In this section, we give the syntax and semantics of a core calculus for traced probabilistic programming with vectors, and formalize a program transformation that vectorizes probabilistic programs. The formal model distills key ideas from our actual implementation in JAX, described in Section 6.

Syntax of $\gen$

$\gen$ is a simply-typed lambda calculus which extends a standard array programming calculus in two main ways: (1) a probability monad for stochastic computations; and (2) a graded monad of generative functions, or traced probabilistic programs. Generative functions can be automatically compiled to the density functions and stochastic traced simulation procedures necessary for inference.

$$ \begin{aligned} B &::= \mathbb{B} \mid \mathbb{R} \mid \mathbb{R}_{>0} \\ T &::= B \mid T[n] \\ \eta &::= 1 \mid T \mid \eta_1 \times \eta_2 \mid \{k_1 : \eta_1, \ldots, k_n : \eta_n\} \\ \tau &::= \eta \mid \tau_1 \rightarrow \tau_2 \mid \tau_1 \times \tau_2 \mid \Dm{\eta} \mid \Pm{\eta} \mid \Gm{\gamma}{\eta} \end{aligned} $$
Figure. Core type grammar for $\gen$: base types, batched types, ground types, and computation types.

Denotational Semantics

We give a denotational semantics for $\gen$ using quasi-Borel spaces (QBS), a standard mathematical framework for higher-order probabilistic programming. We assign to each type a space and to each term a map from the interpretation of the environment to the interpretation of its return type. A generative function is interpreted as a pair of a measure on traces and a return value function that computes the program's output given values for all random choices.
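Schematically, writing $M$ for the measure monad on quasi-Borel spaces and $\llbracket \gamma \rrbracket$ for the space of traces of type $\gamma$, this interpretation can be summarized as

$$\llbracket \Gm{\gamma}{\eta} \rrbracket \;\cong\; M\,\llbracket \gamma \rrbracket \;\times\; \left( \llbracket \gamma \rrbracket \to \llbracket \eta \rrbracket \right),$$

a measure on traces paired with a deterministic return-value map.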

Programmable Inference Transformations

The formal model characterizes how generative functions support programmable inference by compiling to simulation and density-evaluation procedures. These transformations provide the foundation for implementing inference algorithms via high-level interfaces such as simulate, assess, generate, and update, which we use throughout the system.

Vectorization as Program Transformation

We introduce vmap as a program transform for vectorization and prove its correctness for deterministic, probabilistic, and generative computations. The proofs show how vectorization preserves distributions and trace structure and how it interacts with programmable inference interfaces.
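Informally, the headline property for generative computations can be read as follows: writing $(\mu_{f(x)}, r_{f(x)})$ for the trace measure and return map of a generative computation $f$ at input $x$ (as in the semantics above), mapping $f$ over a batch of inputs $x_{1:n}$ denotes the product of the per-element trace measures together with the elementwise return map,

$$\mu_{\mathrm{vmap}(f)(x_{1:n})} \;=\; \bigotimes_{i=1}^{n} \mu_{f(x_i)}, \qquad r_{\mathrm{vmap}(f)(x_{1:n})}\bigl(t_{1:n}\bigr) \;=\; \bigl( r_{f(x_1)}(t_1), \ldots, r_{f(x_n)}(t_n) \bigr).$$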

Compiler Connection

The formal results justify the implementation strategy in the compiler section: inference on a vectorized model can be implemented by vectorizing inference applied to the model. Section 6 describes the compiler architecture and how these ideas are realized in a JAX-based implementation.
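As a concrete illustration of this claim (a sketch in the notation of the tutorial figures, not a literal API guarantee), the two routes below produce batched traces with the same struct-of-arrays layout:

# Route 1: vectorize the model, then simulate it.
traces_a = simulate(point.vmap(args_mapped=0))(xs, a, b, c)
# Route 2: simulate the model, then vectorize the sampler.
traces_b = vmap(simulate(point), args_mapped=0)(xs, a, b, c)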