
Distutil

distutil

Utilities for working with commonly encountered probability distributions in renewal equation modeling, such as discrete time-to-event distributions.

reverse_discrete_dist_vector

reverse_discrete_dist_vector(dist: ArrayLike) -> ArrayLike

Reverse a discrete distribution vector (useful for discrete time-to-event distributions).

Parameters:

Name | Type | Description | Default
dist | ArrayLike | A discrete distribution vector (typically a discrete time-to-event distribution). | required

Returns:

Type | Description
ArrayLike | A reversed (jnp.flip) discrete distribution vector.

Source code in pyrenew/distutil.py
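# Module-level imports this function relies on
import jax.numpy as jnp
from jax.typing import ArrayLike


def reverse_discrete_dist_vector(dist: ArrayLike) -> ArrayLike: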
    """
    Reverse a discrete distribution
    vector (useful for discrete
    time-to-event distributions).

    Parameters
    ----------
    dist
        A discrete distribution vector (typically a discrete time-to-event distribution)

    Returns
    -------
    ArrayLike
        A reversed (jnp.flip) discrete distribution vector
    """
    return jnp.flip(dist)
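One common reason to reverse a time-to-event vector is the renewal equation, I_t = R_t * sum_{s=1}^{n} w_s * I_{t-s}: if recent infections are stored oldest to newest, flipping the generation interval w lines up w_1 with I_{t-1}, w_2 with I_{t-2}, and so on, so the sum becomes a single dot product. The following is a minimal sketch under those assumptions; the names `gen_int`, `recent_infections`, and `rt` and their values are hypothetical, and `np.flip` stands in for `jnp.flip` so the example runs with NumPy alone.

```python
import numpy as np


def reverse_discrete_dist_vector(dist):
    # NumPy stand-in for the jnp.flip-based function above
    return np.flip(dist)


# Hypothetical generation interval: gen_int[0] = P(lag 1), gen_int[1] = P(lag 2), ...
gen_int = np.array([0.5, 0.3, 0.2])

# Hypothetical recent infections ordered oldest -> newest: I_{t-3}, I_{t-2}, I_{t-1}
recent_infections = np.array([10.0, 20.0, 40.0])

rt = 1.5  # hypothetical reproduction number

# Renewal equation: I_t = R_t * sum_s w_s * I_{t-s}.
# After reversing, [w_3, w_2, w_1] aligns elementwise with [I_{t-3}, I_{t-2}, I_{t-1}].
new_infections = rt * np.dot(reverse_discrete_dist_vector(gen_int), recent_infections)

# Same quantity written out term by term, for comparison
explicit = rt * (0.5 * 40.0 + 0.3 * 20.0 + 0.2 * 10.0)
print(new_infections, explicit)  # both 42.0, up to floating-point rounding
```

Reversing once and taking a dot product avoids index arithmetic inside the sum; the same pattern works with `jnp.dot` on JAX arrays.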

validate_discrete_dist_vector

validate_discrete_dist_vector(
    discrete_dist: ArrayLike, tol: float = 1e-05
) -> ArrayLike

Validate that a vector represents a discrete probability distribution to within a specified tolerance, raising a ValueError if not.

Parameters:

Name | Type | Description | Default
discrete_dist | ArrayLike | A JAX array containing non-negative values that represent a discrete probability distribution. The values must sum to 1 within the specified tolerance. | required
tol | float | The tolerance within which the sum of the distribution must be 1. Defaults to 1e-5. | 1e-05

Returns:

Type | Description
ArrayLike | The flattened (1-D) distribution array, divided by its sum, if the input is valid.

Raises:

Type | Description
ValueError | If any value in discrete_dist is negative, or if the sum of the distribution does not equal 1 within the specified tolerance.
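The tolerance check accepts a vector when |sum(v) - 1| < tol and then divides by the sum, so a vector that is off by floating-point noise is both accepted and corrected. The sketch below illustrates only that acceptance condition with NumPy; `within_tol` is a hypothetical helper written for this example, and the two input vectors are made up.

```python
import numpy as np


def within_tol(v, tol=1e-5):
    # Hypothetical helper: same acceptance condition as the source, |sum(v) - 1| < tol
    return bool(np.abs(np.sum(v) - 1) < tol)


# Sums to 0.999996: 4e-6 away from 1, inside the default tol, so it would be accepted
nearly_valid = np.array([0.25, 0.25, 0.25, 0.249996])

# Sums to 0.99: 0.01 away from 1, outside the default tol, so it would raise ValueError
invalid = np.array([0.5, 0.29, 0.2])

print(within_tol(nearly_valid))  # True
print(within_tol(invalid))       # False

# For an accepted vector, the returned array is the input divided by its sum
normalized = nearly_valid / np.sum(nearly_valid)
print(np.sum(normalized))  # 1.0, up to floating-point rounding
```

Because the comparison is strict, a sum that differs from 1 by exactly `tol` is rejected.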

Source code in pyrenew/distutil.py
# Module-level imports this function relies on
import numpy as np
from jax.typing import ArrayLike


def validate_discrete_dist_vector(
    discrete_dist: ArrayLike, tol: float = 1e-5
) -> ArrayLike:
    """
    Validate that a vector represents a discrete
    probability distribution to within a specified
    tolerance, raising a ValueError if not.

    Parameters
    ----------
    discrete_dist
        A JAX array containing non-negative values that
        represent a discrete probability distribution. The values
        must sum to 1 within the specified tolerance.
    tol
        The tolerance within which the sum of the distribution must
        be 1. Defaults to 1e-5.

    Returns
    -------
    ArrayLike
        The flattened (1-D) distribution array, divided by its
        sum, if the input is valid.

    Raises
    ------
    ValueError
        If any value in discrete_dist is negative or if the sum of the
        distribution does not equal 1 within the specified tolerance.
    """
    discrete_dist = discrete_dist.flatten()
    if not np.all(discrete_dist >= 0):
        raise ValueError(
            "Discrete distribution "
            "vector must have "
            "only non-negative "
            "entries; got {}"
            "".format(discrete_dist)
        )
    dist_norm = np.sum(discrete_dist)
    if not np.abs(dist_norm - 1) < tol:
        raise ValueError(
            "Discrete distribution "
            "vector must sum to 1 "
            "within a tolerance of {}; "
            "got a sum of {}"
            "".format(tol, dist_norm)
        )
    return discrete_dist / dist_norm
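The validator starts with `discrete_dist.flatten()`, so any array exposing a `.flatten()` method is reduced to 1-D before the checks: a column vector of shape (3, 1) is validated and returned with shape (3,). A plain Python list has no `.flatten()` method and would raise `AttributeError` at that line, so converting with `np.asarray` first is a reasonable caller-side precaution (an assumption about usage, not something the source does). This NumPy sketch shows both cases; the variable names are illustrative.

```python
import numpy as np

# A column vector of shape (3, 1), e.g. the result of a reshape elsewhere
col = np.array([[0.2], [0.3], [0.5]])
flat = col.flatten()  # what the validator does first: shape (3,)

# Plain lists have no .flatten(); convert to an array before validating
as_list = [0.2, 0.3, 0.5]
flat_from_list = np.asarray(as_list).flatten()

print(flat.shape, flat_from_list.shape)  # (3,) (3,)
print(hasattr(as_list, "flatten"))       # False
```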