
genjax.distributions

Built-in probability distributions that implement the Generative Function Interface.

distributions

Standard probability distributions for GenJAX.

This module provides a collection of common probability distributions wrapped as GenJAX Distribution objects. All distributions are built using TensorFlow Probability as the backend.

bernoulli module-attribute

bernoulli = tfp_distribution(Bernoulli, name='Bernoulli')

Bernoulli distribution for binary outcomes.

Mathematical Formulation

PMF: P(X = k) = p^k × (1-p)^(1-k) for k ∈ {0, 1}

Where p is the probability of success.

Mean: 𝔼[X] = p
Variance: Var[X] = p(1-p)
Support: {0, 1}

Parameterization

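Can be specified via:
  • probs: p ∈ [0, 1] (probability of success)
  • logits: log(p/(1-p)) ∈ ℝ (log-odds)

When called positionally, e.g. bernoulli(x), the argument is interpreted as logits, following the argument order of TFP's Bernoulli. Use flip(p) to pass a probability directly (with boolean output).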
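A minimal stdlib-only sketch of the two parameterizations. The helper names are hypothetical and not part of GenJAX. It shows why a value passed as logits differs from the same value passed as a probability: logit 0.7 gives log P(X = 1) ≈ -0.403 (the value printed for bernoulli.assess(1, 0.7) in the Live Examples), whereas p = 0.7 gives log 0.7 ≈ -0.357.

```python
import math

def logit_from_prob(p):
    # Log-odds: log(p / (1 - p))
    return math.log(p / (1.0 - p))

def prob_from_logit(t):
    # Inverse of the log-odds map (the logistic sigmoid)
    return 1.0 / (1.0 + math.exp(-t))

def bernoulli_log_pmf(k, logit):
    # log P(X = 1) = log sigmoid(t) = -log(1 + e^{-t})
    # log P(X = 0) = log(1 - sigmoid(t)) = -log(1 + e^{t})
    if k == 1:
        return -math.log1p(math.exp(-logit))
    return -math.log1p(math.exp(logit))

# A logit of 0.7 corresponds to p ≈ 0.668, not p = 0.7
print(round(prob_from_logit(0.7), 3))                        # 0.668
print(round(bernoulli_log_pmf(1, 0.7), 3))                   # -0.403
print(round(bernoulli_log_pmf(1, logit_from_prob(0.7)), 3))  # -0.357
```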

References

.. [1] Johnson, N. L., Kotz, S., & Kemp, A. W. (1992). "Univariate Discrete Distributions". Wiley, Chapter 3.

flip module-attribute

flip = tfp_distribution(lambda p: Bernoulli(probs=p, dtype=bool_), name='Flip')

Flip distribution (Bernoulli with boolean output).

beta module-attribute

beta = tfp_distribution(Beta, name='Beta')

Beta distribution on the interval [0, 1].

Mathematical Formulation

PDF: f(x; α, β) = Γ(α+β)/(Γ(α)Γ(β)) × x^(α-1) × (1-x)^(β-1)

Where Γ is the gamma function, α > 0, β > 0.

Mean: 𝔼[X] = α/(α+β)
Variance: Var[X] = αβ/((α+β)²(α+β+1))
Mode: (α-1)/(α+β-2) for α,β > 1
Support: [0, 1]

Special Cases
  • Beta(1, 1) = Uniform(0, 1)
  • Beta(α, α) is symmetric about 0.5
  • As α,β → ∞ with α/(α+β) fixed, approaches Normal
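A stdlib-only sketch of the Beta log-density and moments from the formulas above (helper names are illustrative, not GenJAX API). It reproduces the Beta(2, 2) value used in the Live Examples and the Beta(1, 1) = Uniform(0, 1) special case.

```python
import math

def beta_log_pdf(x, a, b):
    # log f(x; a, b) = log Γ(a+b) - log Γ(a) - log Γ(b) + (a-1) log x + (b-1) log(1-x)
    if not 0.0 < x < 1.0:
        return -math.inf
    return (math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
            + (a - 1.0) * math.log(x) + (b - 1.0) * math.log1p(-x))

def beta_mean_var(a, b):
    # Mean a/(a+b) and variance ab/((a+b)²(a+b+1))
    s = a + b
    return a / s, a * b / (s * s * (s + 1.0))

print(round(beta_log_pdf(0.7, 2.0, 2.0), 3))        # 0.231, i.e. log(6 * 0.7 * 0.3)
print(abs(beta_log_pdf(0.3, 1.0, 1.0)) < 1e-12)     # True: Beta(1, 1) has density 1 on (0, 1)
print(beta_mean_var(2.0, 2.0))                      # (0.5, 0.05)
```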
References

.. [1] Gupta, A. K., & Nadarajah, S. (2004). "Handbook of Beta Distribution and Its Applications". CRC Press.

categorical module-attribute

categorical = tfp_distribution(lambda logits: Categorical(logits), name='Categorical')

Categorical distribution over discrete outcomes.

Mathematical Formulation

PMF: P(X = k) = p_k for k ∈ {0, 1, ..., K-1}

Where ∑_k p_k = 1 and p_k ≥ 0.

Mean: 𝔼[X] = ∑_k k × p_k
Variance: Var[X] = ∑_k k² × p_k - (𝔼[X])²
Entropy: H[X] = -∑_k p_k log(p_k)
Support: {0, 1, ..., K-1}

Parameterization
  • logits: θ_k ∈ ℝ, where p_k = exp(θ_k) / ∑_j exp(θ_j)
  • Softmax transformation ensures valid probabilities
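The softmax parameterization above can be sketched with a numerically stable log-sum-exp (stdlib only; the helper name is hypothetical). Because only differences between logits matter, adding a constant to every logit leaves the probabilities unchanged.

```python
import math

def categorical_log_prob(k, logits):
    # log p_k = θ_k - log ∑_j exp(θ_j), computed stably by subtracting max(θ)
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(t - m) for t in logits))
    return logits[k] - log_norm

probs = [0.2, 0.3, 0.5]
logits = [math.log(p) for p in probs]
print(round(categorical_log_prob(2, logits), 3))  # -0.693 = log(0.5)

# Shift invariance of the softmax
shifted = [t + 10.0 for t in logits]
print(abs(categorical_log_prob(2, shifted) - categorical_log_prob(2, logits)) < 1e-12)  # True
```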
Connection to Other Distributions
  • K=2: Equivalent to Bernoulli
  • Special case of the multinomial with a single trial (n = 1)
References

.. [1] Bishop, C. M. (2006). "Pattern Recognition and Machine Learning". Springer, Section 2.2.

geometric module-attribute

geometric = tfp_distribution(Geometric, name='Geometric')

Geometric distribution (number of failures before the first success).

Mathematical Formulation

PMF: P(X = k) = (1-p)^k × p for k ∈ {0, 1, 2, ...}

Where p ∈ (0, 1] is the probability of success. This is the convention of TFP's Geometric, which backs this distribution.

Mean: 𝔼[X] = (1-p)/p
Variance: Var[X] = (1-p)/p²
CDF: F(k) = 1 - (1-p)^(k+1)
Support: {0, 1, 2, ...}

Memoryless Property

P(X ≥ m + n | X ≥ m) = P(X ≥ n) for all m, n ∈ {0, 1, 2, ...}

The geometric is the only memoryless distribution on the nonnegative integers.

Alternative Parameterization

Some define X as the number of trials until the first success, i.e. the failure count plus one: P(X = k) = (1-p)^(k-1) × p for k ∈ {1, 2, 3, ...}, with mean 1/p and the same variance (1-p)/p².
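A stdlib-only sketch comparing the two conventions: failures before the first success (support starting at 0, the convention of TFP's Geometric) and trials until the first success (support starting at 1). The helper names are hypothetical. Truncated sums approximate the means, which differ by exactly one.

```python
import math

def geometric_failures_log_pmf(k, p):
    # Failures before the first success: P(X = k) = (1-p)^k p, k = 0, 1, 2, ...
    return k * math.log1p(-p) + math.log(p)

def geometric_trials_log_pmf(k, p):
    # Trials until the first success: P(X = k) = (1-p)^(k-1) p, k = 1, 2, 3, ...
    return geometric_failures_log_pmf(k - 1, p)

p = 0.25
# (1-p)/p = 3 for the failures convention, 1/p = 4 for the trials convention
mean_failures = sum(k * math.exp(geometric_failures_log_pmf(k, p)) for k in range(0, 500))
mean_trials = sum(k * math.exp(geometric_trials_log_pmf(k, p)) for k in range(1, 501))
print(round(mean_failures, 6), round(mean_trials, 6))  # 3.0 4.0
```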

References

.. [1] Johnson, N. L., Kotz, S., & Kemp, A. W. (1992). "Univariate Discrete Distributions". Wiley, Chapter 5.

normal module-attribute

normal = tfp_distribution(Normal, name='Normal')

Normal (Gaussian) distribution.

Mathematical Formulation

PDF: f(x; μ, σ) = (1/√(2πσ²)) × exp(-(x-μ)²/(2σ²))

Where μ ∈ ℝ is the mean, σ > 0 is the standard deviation.

Mean: 𝔼[X] = μ
Variance: Var[X] = σ²
MGF: M(t) = exp(μt + σ²t²/2)
Support: ℝ

Standard Normal

Z = (X - μ)/σ ~ N(0, 1)

Φ(z) = P(Z ≤ z) = ∫_{-∞}^z (1/√(2π)) exp(-t²/2) dt
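The standardization above reduces every normal CDF to Φ, which can be written with the error function as Φ(z) = ½(1 + erf(z/√2)). A stdlib-only sketch (helper names are illustrative, not GenJAX API):

```python
import math

def normal_log_pdf(x, mu, sigma):
    # log f = -log(σ) - ½ log(2π) - (x - μ)² / (2σ²)
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2.0 * math.pi) - 0.5 * z * z

def std_normal_cdf(z):
    # Φ(z) = ½ (1 + erf(z / √2))
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))

def normal_cdf(x, mu, sigma):
    # Standardize, then apply Φ: P(X ≤ x) = Φ((x - μ) / σ)
    return std_normal_cdf((x - mu) / sigma)

print(round(normal_log_pdf(1.5, 0.0, 1.0), 3))  # -2.044
print(round(std_normal_cdf(1.96), 3))           # 0.975
print(normal_cdf(3.0, 3.0, 2.0))                # 0.5
```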

Properties
  • Maximum entropy distribution for fixed mean and variance
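  • Stable under convolution: for independent X₁ ~ N(μ₁, σ₁²) and X₂ ~ N(μ₂, σ₂²), X₁ + X₂ ~ N(μ₁+μ₂, σ₁²+σ₂²)
  • Central Limit Theorem: standardized sample means of i.i.d. variables with finite variance converge in distribution to N(0, 1)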
Example
```python
from genjax import distributions, gen

# Sample from a normal distribution
trace = distributions.normal.simulate(0.0, 1.0)
sample = trace.get_retval()
print(f"Sample from Normal(0, 1): {sample:.3f}")

# Evaluate the log probability of a value
log_prob, _ = distributions.normal.assess(1.5, 0.0, 1.0)
print(f"Log prob of 1.5 under Normal(0, 1): {log_prob:.3f}")

# Use distributions inside a generative function
@gen
def model():
    x = distributions.normal(0.0, 1.0) @ "x"
    y = distributions.normal(x, 0.1) @ "y"
    return x + y

# Simulate the model
trace = model.simulate()
print(f"Model output: {trace.get_retval():.3f}")
print(f"Choices: x={trace.get_choices()['x']:.3f}, y={trace.get_choices()['y']:.3f}")
```
References

.. [1] Patel, J. K., & Read, C. B. (1996). "Handbook of the Normal Distribution". Marcel Dekker, 2nd edition.

uniform module-attribute

uniform = tfp_distribution(Uniform, name='Uniform')

Uniform distribution on an interval.

Mathematical Formulation

PDF: f(x; a, b) = 1/(b-a) for x ∈ [a, b], 0 otherwise

Where a < b define the support interval.

Mean: 𝔼[X] = (a + b)/2
Variance: Var[X] = (b - a)²/12
CDF: F(x) = (x - a)/(b - a) for x ∈ [a, b]
Support: [a, b]

Properties
  • Maximum entropy distribution on bounded interval
  • All moments exist: 𝔼[X^n] = (b^(n+1) - a^(n+1))/((n+1)(b-a))
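  • Order statistics: for n i.i.d. U(0,1) draws, the k-th smallest ~ Beta(k, n-k+1)
Connection to Other Distributions
  • Inverse-CDF method: if U ~ U(0,1) and F is a CDF with (generalized) inverse F⁻¹, then F⁻¹(U) has CDF F
  • -log(U) ~ Exponential(1)
  • U^(1/α) ~ Beta(α, 1), the power-function distribution on [0, 1]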
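A stdlib-only Monte Carlo sketch of the two transformations listed above, using Python's random module as a stand-in for any uniform source. The seed and sample size are arbitrary choices.

```python
import math
import random

rng = random.Random(0)
n = 100_000
us = [rng.random() for _ in range(n)]   # U ~ Uniform(0, 1), values in [0, 1)

# -log(U) ~ Exponential(1). Since 1 - U is also uniform, -log(1 - U) has the same
# law and avoids log(0), because random() can return exactly 0.0
exp_samples = [-math.log1p(-u) for u in us]

# U^(1/α) ~ Beta(α, 1), whose mean is α / (α + 1)
alpha = 3.0
pow_samples = [u ** (1.0 / alpha) for u in us]

print(sum(exp_samples) / n)   # close to 1.0
print(sum(pow_samples) / n)   # close to 0.75
```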
References

.. [1] Johnson, N. L., Kotz, S., & Balakrishnan, N. (1995). "Continuous Univariate Distributions". Wiley, Vol. 2, Chapter 26.

exponential module-attribute

exponential = tfp_distribution(Exponential, name='Exponential')

Exponential distribution for positive continuous values.

Mathematical Formulation

PDF: f(x; λ) = λ exp(-λx) for x ≥ 0

Where λ > 0 is the rate parameter.

Mean: 𝔼[X] = 1/λ
Variance: Var[X] = 1/λ²
CDF: F(x) = 1 - exp(-λx)
Support: [0, ∞)

Memoryless Property

P(X > s + t | X > s) = P(X > t)

The only continuous distribution with this property.

Connection to Other Distributions
  • Special case of Gamma(1, λ)
  • -log(U) ~ Exponential(1) where U ~ Uniform(0,1)
  • Minimum of n independent Exponential(λ) variables ~ Exponential(nλ)
  • Sum of n independent Exponential(λ) variables ~ Gamma(n, λ)
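A stdlib-only sketch checking memorylessness through the survival function S(x) = exp(-λx), plus a Monte Carlo check of the minimum property. Python's random.expovariate takes the rate λ; the seed and sizes are arbitrary.

```python
import math
import random

lam = 2.0

def survival(x):
    # P(X > x) = 1 - F(x) = exp(-λx)
    return math.exp(-lam * x)

# Memorylessness: P(X > s + t | X > s) = S(s + t) / S(s) = S(t)
s, t = 0.7, 1.3
print(abs(survival(s + t) / survival(s) - survival(t)) < 1e-12)  # True

# Minimum of n independent Exponential(λ) draws is Exponential(nλ), mean 1 / (nλ)
rng = random.Random(0)
n_vars, n_trials = 4, 50_000
mins = [min(rng.expovariate(lam) for _ in range(n_vars)) for _ in range(n_trials)]
print(sum(mins) / n_trials)  # close to 1 / (4 * 2) = 0.125
```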

poisson module-attribute

poisson = tfp_distribution(Poisson, name='Poisson')

Poisson distribution for count data.

Mathematical Formulation

PMF: P(X = k) = (λ^k / k!) × exp(-λ) for k ∈ {0, 1, 2, ...}

Where λ > 0 is the rate parameter (expected count).

Mean: 𝔼[X] = λ
Variance: Var[X] = λ
MGF: M(t) = exp(λ(e^t - 1))
Support: {0, 1, 2, ...}

Properties
  • Mean equals variance (equidispersion)
  • Sum of independent Poissons: X₁ ~ Pois(λ₁), X₂ ~ Pois(λ₂) ⇒ X₁+X₂ ~ Pois(λ₁+λ₂)
  • Limit of Binomial: Bin(n,p) → Pois(np) as n→∞, p→0, np=λ
Connection to Other Distributions
  • Poisson process: Inter-arrival times ~ Exponential(λ)
  • Large λ: Approximately Normal(λ, λ)
  • Gamma mixture: if λ ~ Gamma(α, β) (rate β) and X | λ ~ Pois(λ), then marginally X ~ NegBin(α, β/(1+β)) with mean α/β, where β/(1+β) is the success probability
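A stdlib-only sketch of the Poisson log-PMF and of the binomial limit listed above: with p = λ/n, the Bin(n, p) probability of k approaches the Pois(λ) probability as n grows. Helper names are illustrative.

```python
import math

def poisson_log_pmf(k, lam):
    # log P(X = k) = k log λ - λ - log k!
    return k * math.log(lam) - lam - math.lgamma(k + 1)

def binomial_log_pmf(k, n, p):
    # log [C(n, k) p^k (1-p)^(n-k)]
    return math.log(math.comb(n, k)) + k * math.log(p) + (n - k) * math.log1p(-p)

print(round(poisson_log_pmf(4, 3.0), 3))  # -1.784

# Bin(n, λ/n) approaches Pois(λ) as n grows
lam, k = 3.0, 4
errors = []
for n in (10, 100, 10_000):
    binom = math.exp(binomial_log_pmf(k, n, lam / n))
    errors.append(abs(binom - math.exp(poisson_log_pmf(k, lam))))
print(errors)  # shrinks as n grows
```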
References

.. [1] Johnson, N. L., Kotz, S., & Kemp, A. W. (1992). "Univariate Discrete Distributions". Wiley, Chapter 4.
.. [2] Haight, F. A. (1967). "Handbook of the Poisson Distribution". Wiley.

multivariate_normal module-attribute

multivariate_normal = tfp_distribution(MultivariateNormalFullCovariance, name='MultivariateNormal')

Multivariate normal distribution.

Mathematical Formulation

PDF: f(x; μ, Σ) = (2π)^(-k/2) |det(Σ)|^(-1/2) exp(-½(x-μ)^T Σ^(-1) (x-μ))

Where μ ∈ ℝ^k is the mean vector, Σ is k×k positive definite covariance.

Mean: 𝔼[X] = μ
Covariance: Cov[X] = Σ
MGF: M(t) = exp(t^Tμ + ½t^TΣt)
Support: ℝ^k

Properties
  • Linear transformations: If Y = AX + b, then Y ~ N(Aμ + b, AΣA^T)
  • Marginals are Normal: X_i ~ N(μ_i, Σ_{ii})
  • Conditional distributions are Normal with closed-form parameters
  • Maximum entropy for fixed mean and covariance
Special Cases
  • Σ = σ²I: Spherical/isotropic Gaussian
  • Σ diagonal: Independent components
  • k = 1: Univariate normal
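A stdlib-only sketch of the density formula for k = 2, using the closed-form inverse and determinant of a 2×2 matrix (a real implementation would use a Cholesky factorization). It checks the "Σ diagonal: independent components" case: the joint log-density equals the sum of univariate log-densities. Helper names are hypothetical.

```python
import math

def mvn_log_pdf_2d(x, mu, cov):
    # k = 2: log f = -log(2π) - ½ log det Σ - ½ (x-μ)^T Σ^{-1} (x-μ)
    (a, b), (c, d) = cov
    det = a * d - b * c
    if a <= 0 or det <= 0:
        raise ValueError("covariance must be positive definite")
    dx0, dx1 = x[0] - mu[0], x[1] - mu[1]
    # Σ^{-1} = (1/det) [[d, -b], [-c, a]]
    quad = (d * dx0 * dx0 - (b + c) * dx0 * dx1 + a * dx1 * dx1) / det
    return -math.log(2.0 * math.pi) - 0.5 * math.log(det) - 0.5 * quad

def normal_log_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2.0 * math.pi) - 0.5 * z * z

# Diagonal Σ: the joint density factorizes into independent univariate normals
x, mu = (0.5, -1.0), (0.0, 1.0)
cov = ((4.0, 0.0), (0.0, 0.25))
joint = mvn_log_pdf_2d(x, mu, cov)
factored = normal_log_pdf(0.5, 0.0, 2.0) + normal_log_pdf(-1.0, 1.0, 0.5)
print(abs(joint - factored) < 1e-12)  # True
```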
References

.. [1] Mardia, K. V., Kent, J. T., & Bibby, J. M. (1979). "Multivariate Analysis". Academic Press, Chapter 3.
.. [2] Tong, Y. L. (1990). "The Multivariate Normal Distribution". Springer-Verlag.

dirichlet module-attribute

dirichlet = tfp_distribution(Dirichlet, name='Dirichlet')

Dirichlet distribution for probability vectors.

Mathematical Formulation

PDF: f(x; α) = [Γ(∑ᵢαᵢ)/∏ᵢΓ(αᵢ)] × ∏ᵢ xᵢ^(αᵢ-1)

Where x ∈ Δ_{k-1} (probability simplex), αᵢ > 0 are concentrations.

Mean: 𝔼[Xᵢ] = αᵢ / ∑ⱼαⱼ
Variance: Var[Xᵢ] = [αᵢ(α₀-αᵢ)] / [α₀²(α₀+1)], where α₀ = ∑ⱼαⱼ
Support: Δ_{k-1} = {x ∈ ℝ^k : xᵢ ≥ 0, ∑ᵢxᵢ = 1}

Properties
  • Conjugate prior for categorical/multinomial
  • Marginals: Xᵢ ~ Beta(αᵢ, ∑_{j≠i} αⱼ)
  • Aggregation property: merging components adds their parameters, e.g. (Xᵢ + Xⱼ, X_rest) ~ Dir(αᵢ + αⱼ, α_rest)
  • Flat case: Dir(1, 1, ..., 1) is uniform on the simplex
Connection to Other Distributions
  • k=2: Dir(α₁, α₂) equivalent to Beta(α₁, α₂)
  • Gamma construction: If Yᵢ ~ Gamma(αᵢ, 1), then Y/∑Y ~ Dir(α)
  • Logistic-normal approximation, increasingly accurate for large α
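A stdlib-only sketch of the Gamma construction listed above: draw independent Gamma(αᵢ, 1) variables and normalize. Note that Python's random.gammavariate(alpha, beta) takes a scale (not rate) parameter, so scale 1 is used here; the seed and sample size are arbitrary.

```python
import random

def dirichlet_sample(alpha, rng):
    # Y_i ~ Gamma(α_i, 1) independently, then normalize onto the simplex
    ys = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(ys)
    return [y / total for y in ys]

rng = random.Random(0)
alpha = [2.0, 3.0, 5.0]
n = 20_000
samples = [dirichlet_sample(alpha, rng) for _ in range(n)]

# Every sample sums to 1; the empirical mean approaches α_i / α_0 = [0.2, 0.3, 0.5]
means = [sum(s[i] for s in samples) / n for i in range(3)]
print(all(abs(sum(s) - 1.0) < 1e-12 for s in samples))  # True
print([round(m, 2) for m in means])                     # approximately [0.2, 0.3, 0.5]
```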
References

.. [1] Kotz, S., Balakrishnan, N., & Johnson, N. L. (2000). "Continuous Multivariate Distributions". Wiley, Vol. 1, Chapter 49.
.. [2] Ng, K. W., Tian, G. L., & Tang, M. L. (2011). "Dirichlet and Related Distributions". Wiley.

binomial module-attribute

binomial = tfp_distribution(Binomial, name='Binomial')

Binomial distribution for count data with fixed number of trials.

Mathematical Formulation

PMF: P(X = k) = C(n,k) × p^k × (1-p)^(n-k) for k ∈ {0, 1, ..., n}

Where n is the number of trials, p is success probability, and C(n,k) = n!/(k!(n-k)!) is the binomial coefficient.

Mean: 𝔼[X] = np
Variance: Var[X] = np(1-p)
MGF: M(t) = (1 - p + pe^t)^n
Support: {0, 1, 2, ..., n}

Properties
  • Sum of Bernoullis: X = ∑ᵢ₌₁ⁿ Yᵢ where Yᵢ ~ Bernoulli(p) are i.i.d.
  • Additivity: for independent X₁ ~ Bin(n₁,p) and X₂ ~ Bin(n₂,p), X₁ + X₂ ~ Bin(n₁+n₂,p)
  • Symmetry: If p = 0.5, then P(X = k) = P(X = n-k)
Approximations
  • Normal: approximately N(np, np(1-p)) for large n (a common rule of thumb is np(1-p) > 10)
  • Poisson: For large n, small p, np = λ moderate, approximately Pois(λ)
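A stdlib-only sketch that enumerates the PMF above and checks the stated mean, variance, MGF, and the p = 0.5 symmetry (the helper name is illustrative).

```python
import math

def binomial_pmf(k, n, p):
    # C(n, k) p^k (1-p)^(n-k)
    return math.comb(n, k) * p**k * (1.0 - p) ** (n - k)

n, p = 20, 0.3
pmf = [binomial_pmf(k, n, p) for k in range(n + 1)]
mean = sum(k * q for k, q in enumerate(pmf))
var = sum(k * k * q for k, q in enumerate(pmf)) - mean**2
print(round(sum(pmf), 12), round(mean, 9), round(var, 9))  # 1.0 6.0 4.2

# MGF: ∑ e^{tk} P(X = k) = (1 - p + p e^t)^n
t = 0.3
mgf_sum = sum(math.exp(t * k) * q for k, q in enumerate(pmf))
mgf_closed = (1.0 - p + p * math.exp(t)) ** n
print(abs(mgf_sum - mgf_closed) < 1e-9 * mgf_closed)  # True

# Symmetry for p = 0.5: P(X = k) = P(X = n - k)
print(binomial_pmf(3, 10, 0.5) == binomial_pmf(7, 10, 0.5))  # True
```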
References

.. [1] Johnson, N. L., Kotz, S., & Kemp, A. W. (1992). "Univariate Discrete Distributions". Wiley, Chapter 3.

gamma module-attribute

gamma = tfp_distribution(Gamma, name='Gamma')

Gamma distribution for positive continuous values.

Mathematical Formulation

PDF: f(x; α, β) = (β^α / Γ(α)) × x^(α-1) × exp(-βx) for x > 0

Where α > 0 is the shape, β > 0 is the rate (or θ = 1/β is scale).

Mean: 𝔼[X] = α/β = αθ
Variance: Var[X] = α/β² = αθ²
Mode: (α-1)/β for α ≥ 1
Support: (0, ∞)

Special Cases
  • α = 1: Exponential(β)
  • α = k/2, β = 1/2: Chi-squared(k)
  • Integer α: Erlang distribution
Properties
  • Additivity: for independent X₁ ~ Gamma(α₁,β) and X₂ ~ Gamma(α₂,β), X₁ + X₂ ~ Gamma(α₁+α₂,β)
  • Scaling: cX ~ Gamma(α, β/c) for c > 0
  • Conjugate prior for Poisson rate, exponential rate
Connection to Other Distributions
  • If Xᵢ ~ Gamma(αᵢ, 1) independently, then (X₁, ..., X_k)/∑ⱼXⱼ ~ Dirichlet(α)
  • Inverse: 1/X ~ InverseGamma(α, β)
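A stdlib-only sketch of the shape/rate log-density above (the helper name is illustrative). It checks two listed facts: α = 1 recovers Exponential(β), and scaling by c > 0 divides the rate by c, which follows from the change of variables f_{cX}(y) = f_X(y/c)/c.

```python
import math

def gamma_log_pdf(x, alpha, beta):
    # Shape α, rate β: log f = α log β - log Γ(α) + (α-1) log x - βx, for x > 0
    if x <= 0:
        return -math.inf
    return alpha * math.log(beta) - math.lgamma(alpha) + (alpha - 1.0) * math.log(x) - beta * x

x, beta = 1.7, 2.5
# α = 1 recovers Exponential(β): log f = log β - βx
print(abs(gamma_log_pdf(x, 1.0, beta) - (math.log(beta) - beta * x)) < 1e-12)  # True

# Scaling: if X ~ Gamma(α, β) then cX ~ Gamma(α, β/c)
alpha, c, y = 3.0, 4.0, 2.2
lhs = gamma_log_pdf(y, alpha, beta / c)
rhs = gamma_log_pdf(y / c, alpha, beta) - math.log(c)
print(abs(lhs - rhs) < 1e-12)  # True
```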
References

.. [1] Johnson, N. L., Kotz, S., & Balakrishnan, N. (1994). "Continuous Univariate Distributions". Wiley, Vol. 1, Chapter 17.

log_normal module-attribute

log_normal = tfp_distribution(LogNormal, name='LogNormal')

Log-normal distribution (exponential of normal random variable).

Mathematical Formulation

If Y ~ N(μ, σ²), then X = exp(Y) ~ LogNormal(μ, σ²)

PDF: f(x; μ, σ) = (1/(xσ√(2π))) × exp(-(ln(x)-μ)²/(2σ²)) for x > 0

Mean: 𝔼[X] = exp(μ + σ²/2)
Variance: Var[X] = (exp(σ²) - 1) × exp(2μ + σ²)
Mode: exp(μ - σ²)
Support: (0, ∞)

Properties
  • Multiplicative: if Xᵢ ~ LogN(μᵢ, σᵢ²) are independent, then ∏Xᵢ ~ LogN(∑μᵢ, ∑σᵢ²)
  • Not closed under addition (sum of log-normals is not log-normal)
  • Heavy right tail: all moments exist but grow rapidly
  • Median: exp(μ)
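A stdlib-only sketch of the change of variables X = exp(Y): the log-normal log-density is the normal log-density at log x minus log x. A Monte Carlo run with Python's random.lognormvariate(mu, sigma) then checks the mean exp(μ + σ²/2) and median exp(μ). Helper names, seed, and sample size are illustrative.

```python
import math
import random

def normal_log_pdf(x, mu, sigma):
    z = (x - mu) / sigma
    return -math.log(sigma) - 0.5 * math.log(2.0 * math.pi) - 0.5 * z * z

def lognormal_log_pdf(x, mu, sigma):
    # Change of variables X = exp(Y): f_X(x) = f_Y(log x) / x
    return normal_log_pdf(math.log(x), mu, sigma) - math.log(x)

mu, sigma = 0.2, 0.5
x = 1.9
direct = (-math.log(x * sigma * math.sqrt(2.0 * math.pi))
          - (math.log(x) - mu) ** 2 / (2.0 * sigma**2))
print(abs(lognormal_log_pdf(x, mu, sigma) - direct) < 1e-12)  # True

rng = random.Random(0)
draws = sorted(rng.lognormvariate(mu, sigma) for _ in range(100_001))
mean_mc = sum(draws) / len(draws)
median_mc = draws[len(draws) // 2]
print(mean_mc, math.exp(mu + sigma**2 / 2))  # close to each other
print(median_mc, math.exp(mu))               # close to each other
```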
Applications
  • Income distributions
  • Stock prices (geometric Brownian motion)
  • Particle sizes
  • Species abundance
References

.. [1] Crow, E. L., & Shimizu, K. (Eds.). (1988). "Lognormal Distributions: Theory and Applications". Marcel Dekker.
.. [2] Limpert, E., Stahel, W. A., & Abbt, M. (2001). "Log-normal distributions across the sciences". BioScience, 51(5), 341-352.

student_t module-attribute

student_t = tfp_distribution(StudentT, name='StudentT')

Student's t-distribution with specified degrees of freedom.

Mathematical Formulation

PDF: f(x; ν, μ, σ) = Γ((ν+1)/2)/(Γ(ν/2)√(νπ)σ) × [1 + ((x-μ)/σ)²/ν]^(-(ν+1)/2)

Where ν > 0 is degrees of freedom, μ is location, σ > 0 is scale.

Mean: 𝔼[X] = μ for ν > 1 (undefined for ν ≤ 1)
Variance: Var[X] = σ²ν/(ν-2) for ν > 2 (infinite for 1 < ν ≤ 2)
Support: ℝ

Properties
  • Heavier tails than normal (polynomial vs exponential decay)
  • ν → ∞: Converges to Normal(μ, σ²)
  • ν = 1: Cauchy distribution (no mean)
  • ν = 2: Finite mean but infinite variance
  • Symmetric about μ
Standardized Form

If T ~ t(ν), then X = μ + σT ~ t(ν, μ, σ)

Connection to Other Distributions
  • Ratio of normal to chi: If Z ~ N(0,1), V ~ χ²(ν), then Z/√(V/ν) ~ t(ν)
  • F-distribution: T² ~ F(1, ν) if T ~ t(ν)
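A stdlib-only sketch of the location-scale log-density above (the helper name is illustrative). It checks two listed properties: ν = 1 gives the standard Cauchy density 1/(π(1 + x²)), and for very large ν the density approaches the standard normal.

```python
import math

def student_t_log_pdf(x, nu, mu=0.0, sigma=1.0):
    # log Γ((ν+1)/2) - log Γ(ν/2) - ½ log(νπ) - log σ - (ν+1)/2 · log(1 + z²/ν), z = (x-μ)/σ
    z = (x - mu) / sigma
    return (math.lgamma((nu + 1.0) / 2.0) - math.lgamma(nu / 2.0)
            - 0.5 * math.log(nu * math.pi) - math.log(sigma)
            - (nu + 1.0) / 2.0 * math.log1p(z * z / nu))

x = 0.8
# ν = 1 is the standard Cauchy
cauchy = -math.log(math.pi * (1.0 + x * x))
print(abs(student_t_log_pdf(x, 1.0) - cauchy) < 1e-12)  # True

# Large ν approaches the standard normal density
normal = -0.5 * math.log(2.0 * math.pi) - 0.5 * x * x
print(abs(student_t_log_pdf(x, 1e6) - normal) < 1e-5)  # True
```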
References

.. [1] Lange, K. L., Little, R. J., & Taylor, J. M. (1989). "Robust statistical modeling using the t distribution". JASA, 84(408), 881-896.
.. [2] Kotz, S., & Nadarajah, S. (2004). "Multivariate t-distributions and their applications". Cambridge University Press.

laplace module-attribute

laplace = tfp_distribution(Laplace, name='Laplace')

Laplace (double exponential) distribution.

half_normal module-attribute

half_normal = tfp_distribution(HalfNormal, name='HalfNormal')

Half-normal distribution (positive half of normal distribution).

inverse_gamma module-attribute

inverse_gamma = tfp_distribution(InverseGamma, name='InverseGamma')

Inverse gamma distribution for positive continuous values.

weibull module-attribute

weibull = tfp_distribution(Weibull, name='Weibull')

Weibull distribution for modeling survival times and reliability.

cauchy module-attribute

cauchy = tfp_distribution(Cauchy, name='Cauchy')

Cauchy distribution with heavy tails.

chi2 module-attribute

chi2 = tfp_distribution(Chi2, name='Chi2')

Chi-squared distribution.

multinomial module-attribute

multinomial = tfp_distribution(Multinomial, name='Multinomial')

Multinomial distribution over count vectors.

negative_binomial module-attribute

negative_binomial = tfp_distribution(NegativeBinomial, name='NegativeBinomial')

Negative binomial distribution for overdispersed count data.

zipf module-attribute

zipf = tfp_distribution(Zipf, name='Zipf')

Zipf distribution for power-law distributed discrete data.

Live Examples

Continuous Distributions

```python
from genjax import distributions

# Assess log probability under various distributions
x = 1.5

# Normal distribution
log_prob_normal, _ = distributions.normal.assess(x, 0.0, 1.0)
print(f"Log prob of {x} under Normal(0, 1): {log_prob_normal:.3f}")

# Beta distribution (x must be in [0, 1])
x_beta = 0.7
log_prob_beta, _ = distributions.beta.assess(x_beta, 2.0, 2.0)
print(f"Log prob of {x_beta} under Beta(2, 2): {log_prob_beta:.3f}")

# Exponential distribution
log_prob_exp, _ = distributions.exponential.assess(x, 1.0)
print(f"Log prob of {x} under Exponential(1): {log_prob_exp:.3f}")
```

Log prob of 1.5 under Normal(0, 1): -2.044
Log prob of 0.7 under Beta(2, 2): 0.231
Log prob of 1.5 under Exponential(1): -1.500

Discrete Distributions

```python
import jax.numpy as jnp
from genjax import distributions

# Bernoulli distribution: the positional argument is interpreted as logits
# (TFP's Bernoulli argument order), so 0.7 here means p = sigmoid(0.7) ≈ 0.668
log_prob_bern, _ = distributions.bernoulli.assess(1, 0.7)
print(f"Log prob of 1 under Bernoulli(logits=0.7): {log_prob_bern:.3f}")

# Categorical distribution (uses logits, not probs)
probs = jnp.array([0.2, 0.3, 0.5])
logits = jnp.log(probs)
log_prob_cat, _ = distributions.categorical.assess(2, logits)
print(f"Log prob of category 2 under Categorical(probs={probs}): {log_prob_cat:.3f}")

# Poisson distribution
log_prob_pois, _ = distributions.poisson.assess(4, 3.0)
print(f"Log prob of 4 under Poisson(3.0): {log_prob_pois:.3f}")
```

Log prob of 1 under Bernoulli(logits=0.7): -0.403
Log prob of category 2 under Categorical(probs=[0.2 0.3 0.5]): -0.693
Log prob of 4 under Poisson(3.0): -1.784

Distribution Parameters

```python
# Distributions take their parameters in the argument order of the backing TFP constructor
print("Common distribution parameterizations:")
print("- normal(mu, sigma)")
print("- beta(alpha, beta)")
print("- exponential(rate)")
print("- bernoulli(logits)  # Note: uses logits; flip(p) takes a probability")
print("- categorical(logits)  # Note: uses logits, not probs")
print("- poisson(rate)")
print("- gamma(concentration, rate)")
print("- uniform(low, high)")
```

Common distribution parameterizations:
- normal(mu, sigma)
- beta(alpha, beta)
- exponential(rate)
- bernoulli(logits)  # Note: uses logits; flip(p) takes a probability
- categorical(logits)  # Note: uses logits, not probs
- poisson(rate)
- gamma(concentration, rate)
- uniform(low, high)