Source code for pyrenew.distutil

"""
distutil

Utilities for working with commonly-
encountered probability distributions
found in renewal equation modeling,
such as discrete time-to-event distributions
"""

from __future__ import annotations

import jax.numpy as jnp
from jax.typing import ArrayLike


[docs] def validate_discrete_dist_vector( discrete_dist: ArrayLike, tol: float = 1e-5 ) -> ArrayLike: """ Validate that a vector represents a discrete probability distribution to within a specified tolerance, raising a ValueError if not. Parameters ---------- discrete_dist : ArrayLike An jax array containing non-negative values that represent a discrete probability distribution. The values must sum to 1 within the specified tolerance. tol : float, optional The tolerance within which the sum of the distribution must be 1. Defaults to 1e-5. Returns ------- ArrayLike The normalized distribution array if the input is valid. Raises ------ ValueError If any value in discrete_dist is negative or if the sum of the distribution does not equal 1 within the specified tolerance. """ discrete_dist = discrete_dist.flatten() if not jnp.all(discrete_dist >= 0): raise ValueError( "Discrete distribution " "vector must have " "only non-negative " "entries; got {}" "".format(discrete_dist) ) dist_norm = jnp.sum(discrete_dist) if not jnp.abs(dist_norm - 1) < tol: raise ValueError( "Discrete generation interval " "distributions must sum to 1 " "with a tolerance of {}" "".format(tol) ) return discrete_dist / dist_norm
[docs] def reverse_discrete_dist_vector(dist: ArrayLike) -> ArrayLike: """ Reverse a discrete distribution vector (useful for discrete time-to-event distributions). Parameters ---------- dist : ArrayLike A discrete distribution vector (likely discrete time-to-event distribution) Returns ------- ArrayLike A reversed (jnp.flip) discrete distribution vector """ return jnp.flip(dist)