# Source code for pyrenew.distributions.censorednormal
# numpydoc ignore=GL08
import jax
import jax.numpy as jnp
import numpyro
import numpyro.util
from numpyro.distributions import constraints
from numpyro.distributions.util import promote_shapes, validate_sample
# [docs]  (documentation-link marker left over from the HTML source listing)
class CensoredNormal(numpyro.distributions.Distribution):
    """
    Censored normal distribution, under which draws from a
    normal distribution that fall outside a specified interval
    are recorded at the nearest interval limit.

    Sampling clips normal draws to the closed interval
    ``[lower_limit, upper_limit]``. The log probability of a
    value strictly inside the interval is the normal log
    density; a value at or beyond a limit receives the log of
    the normal probability mass lying beyond that limit.

    This implementation is adapted from
    https://github.com/dylanhmorris/host-viral-determinants/blob/main/src/distributions.py
    """

    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    # Attributes carried through pytree flattening. ``normal_`` must be
    # listed: ``sample`` and ``log_prob`` both read it, and an instance
    # rebuilt from its flattened form (e.g. after crossing a ``jax.jit``,
    # ``jax.vmap`` or ``jax.lax.scan`` boundary) only gets back the
    # attributes named here, not the ones assigned in ``__init__``.
    pytree_data_fields = (
        "loc",
        "scale",
        "lower_limit",
        "upper_limit",
        "_support",
        "normal_",
    )

    def __init__(
        self,
        loc=0,
        scale=1,
        lower_limit=-jnp.inf,
        upper_limit=jnp.inf,
        validate_args=None,
    ):
        """
        Default constructor

        Parameters
        ----------
        loc : ArrayLike or float, optional
            The mean of the normal distribution.
            Defaults to 0.
        scale : ArrayLike or float, optional
            The standard deviation of the normal
            distribution. Must be positive. Defaults to 1.
        lower_limit : float, optional
            The lower bound of the interval for censoring.
            Defaults to -inf (no lower bound).
        upper_limit : float, optional
            The upper bound of the interval for censoring.
            Defaults to inf (no upper bound).
        validate_args : bool, optional
            If True, checks if parameters are valid.
            Defaults to None.

        Returns
        -------
        None
        """
        self.loc, self.scale = promote_shapes(loc, scale)
        self.lower_limit = lower_limit
        self.upper_limit = upper_limit
        self._support = constraints.interval(
            self.lower_limit, self.upper_limit
        )
        batch_shape = jax.lax.broadcast_shapes(
            jnp.shape(loc), jnp.shape(scale)
        )
        # Uncensored normal that supplies the raw draws and the
        # in-bounds log density.
        self.normal_ = numpyro.distributions.Normal(
            loc=loc, scale=scale, validate_args=validate_args
        )
        super().__init__(batch_shape=batch_shape, validate_args=validate_args)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self):  # numpydoc ignore=GL08
        return self._support

    def sample(self, key, sample_shape=()):
        """
        Generates samples from the censored normal distribution.

        Draws from the underlying normal distribution and clips
        each draw to ``[lower_limit, upper_limit]``.

        Parameters
        ----------
        key : ArrayLike
            PRNG key used to draw the underlying normal samples.
        sample_shape : tuple, optional
            Sample shape passed to the underlying normal
            distribution's sampler. Defaults to ().

        Returns
        -------
        Array
            Containing samples from the censored normal distribution.
        """
        assert numpyro.util.is_prng_key(key)
        result = self.normal_.sample(key, sample_shape)
        return jnp.clip(result, min=self.lower_limit, max=self.upper_limit)

    @validate_sample
    def log_prob(self, value):
        """
        Computes the log probability density of a given value(s) under
        the censored normal distribution.

        Values strictly between the limits get the normal log
        density. Values at or below ``lower_limit`` get the log
        of the normal mass at or below that limit; all remaining
        values (at or above ``upper_limit``) get the log of the
        normal mass at or above that limit.

        Parameters
        ----------
        value : ArrayLike
            Value(s) at which to evaluate the log probability.

        Returns
        -------
        Array
            Containing log probability of the given value(s)
            under the censored normal distribution
        """
        # Limits expressed in standard-normal units.
        rescaled_ulim = (self.upper_limit - self.loc) / self.scale
        rescaled_llim = (self.lower_limit - self.loc) / self.scale
        # Log mass of the censored tail that applies to each value.
        # For the upper tail we exploit the symmetry of the standard
        # normal, P(x > a) = P(x < -a), to compute the log
        # complementary CDF with log_ndtr.
        # NOTE(review): both tail terms are evaluated for every value,
        # including at infinite limits; confirm that gradients with
        # respect to loc/scale stay finite in that case under autodiff.
        lim_val = jnp.where(
            value <= self.lower_limit,
            jax.scipy.special.log_ndtr(rescaled_llim),
            jax.scipy.special.log_ndtr(-rescaled_ulim),
        )
        inbounds = jnp.logical_and(
            value > self.lower_limit, value < self.upper_limit
        )
        result = jnp.where(inbounds, self.normal_.log_prob(value), lim_val)
        return result