Skip to content

Distributions

CensoredNormal

CensoredNormal(
    loc=0, scale=1, lower_limit=-inf, upper_limit=inf, validate_args=None
)

Bases: Distribution

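Censored normal distribution under which samples are clipped (censored) to lie within a specified interval: a draw below the lower limit is replaced by the lower limit, and a draw above the upper limit is replaced by the upper limit. This implementation is adapted from https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py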

Default constructor

Parameters:

Name Type Description Default
loc

The mean of the normal distribution. Defaults to 0.

0
scale

The standard deviation of the normal distribution. Must be positive. Defaults to 1.

1
lower_limit

The lower bound of the interval for censoring. Defaults to -inf (no lower bound).

-inf
upper_limit

The upper bound of the interval for censoring. Defaults to inf (no upper bound).

inf
validate_args

If True, checks if parameters are valid. Defaults to None.

None

Returns:

Type Description
None
Source code in pyrenew/distributions/censorednormal.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
def __init__(
    self,
    loc=0,
    scale=1,
    lower_limit=-jnp.inf,
    upper_limit=jnp.inf,
    validate_args=None,
):
    """
    Construct a censored normal distribution.

    Parameters
    ----------
    loc
        Mean of the underlying normal distribution.
        Defaults to 0.
    scale
        Standard deviation of the underlying normal
        distribution. Must be positive. Defaults to 1.
    lower_limit
        Lower bound of the censoring interval.
        Defaults to -inf (no lower bound).
    upper_limit
        Upper bound of the censoring interval.
        Defaults to inf (no upper bound).
    validate_args
        If True, checks if parameters are valid.
        Defaults to None.

    Returns
    -------
    None
    """
    self.loc, self.scale = promote_shapes(loc, scale)

    # The batch shape comes from loc and scale alone; the censoring
    # limits do not contribute to it.
    batch_shape = jax.lax.broadcast_shapes(jnp.shape(loc), jnp.shape(scale))

    self.lower_limit, self.upper_limit = lower_limit, upper_limit
    self._support = constraints.interval(lower_limit, upper_limit)

    # Uncensored normal: its draws are clipped in `sample`, and its
    # density is used for in-bounds values in `log_prob`.
    self.normal_ = numpyro.distributions.Normal(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
    )
    super().__init__(batch_shape=batch_shape, validate_args=validate_args)

log_prob

log_prob(value)

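Computes the log probability of the given value(s) under the censored normal distribution. Values strictly inside the censoring interval receive the normal log density. Values at or below the lower limit receive the log of the probability mass censored there, and values at or above the upper limit receive the log of the mass censored at the upper limit.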

Returns:

Type Description
Array

Containing log probability of the given value(s) under the censored normal distribution

Source code in pyrenew/distributions/censorednormal.py
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
@validate_sample
def log_prob(self, value):
    """
    Computes the log probability of the given value(s) under
    the censored normal distribution.

    Values strictly inside (lower_limit, upper_limit) receive the
    ordinary normal log density. A value at or below lower_limit
    receives log P(X <= lower_limit). Any other value, i.e. one at or
    above upper_limit, receives log P(X > upper_limit). Here X is the
    underlying, uncensored normal random variable.

    Parameters
    ----------
    value
        Value(s) at which to evaluate the log probability.

    Returns
    -------
    Array
        Containing log probability of the given value(s)
        under the censored normal distribution
    """
    # An infinite limit censors zero probability mass, so its
    # log-mass is exactly -inf. Do not push the infinite limit
    # through log_ndtr, though: its derivative at +/-inf evaluates
    # to NaN. jnp.where does not stop that NaN from reaching the
    # gradient even when the branch is not selected (0 * NaN = NaN).
    # Use a finite placeholder for infinite limits, then mask the
    # placeholder result out explicitly.
    lower_unbounded = self.lower_limit == -jnp.inf
    upper_unbounded = self.upper_limit == jnp.inf
    safe_lower = jnp.where(lower_unbounded, 0.0, self.lower_limit)
    safe_upper = jnp.where(upper_unbounded, 0.0, self.upper_limit)
    rescaled_llim = (safe_lower - self.loc) / self.scale
    rescaled_ulim = (safe_upper - self.loc) / self.scale

    # log P(X <= lower_limit) = log Phi((lower_limit - loc) / scale)
    log_cdf_lower = jnp.where(
        lower_unbounded,
        -jnp.inf,
        jax.scipy.special.log_ndtr(rescaled_llim),
    )
    # For a standard normal Z, P(Z > a) = P(Z < -a), so the log
    # survival function at the upper limit is log_ndtr(-a).
    log_sf_upper = jnp.where(
        upper_unbounded,
        -jnp.inf,
        jax.scipy.special.log_ndtr(-rescaled_ulim),
    )

    lim_val = jnp.where(value <= self.lower_limit, log_cdf_lower, log_sf_upper)
    inbounds = jnp.logical_and(value > self.lower_limit, value < self.upper_limit)
    return jnp.where(inbounds, self.normal_.log_prob(value), lim_val)

sample

sample(key, sample_shape=())

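Generates samples from the censored normal distribution by drawing from the underlying normal distribution and clipping each draw to the interval [lower_limit, upper_limit].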

Returns:

Type Description
Array

Containing samples from the censored normal distribution.

Source code in pyrenew/distributions/censorednormal.py
75
76
77
78
79
80
81
82
83
84
85
86
def sample(self, key, sample_shape=()):
    """
    Draw samples from the censored normal distribution.

    Draws come from the underlying (uncensored) normal distribution
    and are then clipped to [lower_limit, upper_limit]. A draw
    outside the interval is replaced by the nearest limit.

    Parameters
    ----------
    key
        PRNG key used for the underlying normal draw.
    sample_shape
        Shape of the sample to draw. Defaults to ().

    Returns
    -------
    Array
        Containing samples from the censored normal distribution.
    """
    assert numpyro.util.is_prng_key(key)
    uncensored_draws = self.normal_.sample(key, sample_shape)
    censored_draws = jnp.clip(
        uncensored_draws,
        min=self.lower_limit,
        max=self.upper_limit,
    )
    return censored_draws