# Source code for pyrenew.observation.negativebinomial

# numpydoc ignore=GL08

from __future__ import annotations

import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike

from pyrenew.metaclass import RandomVariable


class NegativeBinomialObservation(RandomVariable):
    """Negative Binomial observation.

    Draws (or conditions on) values from a
    ``numpyro.distributions.NegativeBinomial2`` distribution whose mean is
    supplied at sample time and whose concentration is sampled from
    ``concentration_rv``.
    """

    def __init__(
        self,
        name: str,
        concentration_rv: RandomVariable,
        eps: float = 1e-10,
    ) -> None:
        """
        Default constructor

        Parameters
        ----------
        name : str
            Name for the numpyro variable.
        concentration_rv : RandomVariable
            Random variable from which to sample the positive concentration
            parameter of the negative binomial. This parameter is sometimes
            called k, phi, or the "dispersion" or "overdispersion" parameter,
            despite the fact that larger values imply that the distribution
            becomes more Poissonian, while smaller ones imply a greater
            degree of dispersion.
        eps : float, optional
            Small value to add to the predicted mean to prevent numerical
            instability. Defaults to 1e-10.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If ``concentration_rv`` is not a ``RandomVariable``.
        """
        NegativeBinomialObservation.validate(concentration_rv)

        self.name = name
        self.concentration_rv = concentration_rv
        self.eps = eps

    @staticmethod
    def validate(concentration_rv: RandomVariable) -> None:
        """
        Check that the concentration_rv is actually a RandomVariable

        Parameters
        ----------
        concentration_rv : any
            RandomVariable from which to sample the positive concentration
            parameter of the negative binomial.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If ``concentration_rv`` is not an instance of ``RandomVariable``.
        """
        # Raise explicitly instead of using a bare ``assert`` so the check is
        # not stripped when Python runs with ``-O``. AssertionError is kept
        # (rather than TypeError) so existing callers expecting it still work.
        if not isinstance(concentration_rv, RandomVariable):
            raise AssertionError(
                "concentration_rv must be a RandomVariable, got "
                f"{type(concentration_rv).__name__}"
            )
        return None

    def sample(
        self,
        mu: ArrayLike,
        obs: ArrayLike | None = None,
        **kwargs,
    ) -> ArrayLike:
        """
        Sample from the negative binomial distribution

        Parameters
        ----------
        mu : ArrayLike
            Mean parameter of the negative binomial distribution. ``self.eps``
            is added to it before it is passed to the distribution.
        obs : ArrayLike, optional
            Observed data passed to ``numpyro.sample`` as ``obs``, by default
            None.
        **kwargs : dict, optional
            Accepted for interface compatibility with other ``RandomVariable``
            ``sample`` methods. They are currently not used and are NOT
            forwarded to ``concentration_rv.sample``.

        Returns
        -------
        ArrayLike
            The value returned by ``numpyro.sample`` for the site named
            ``self.name``.
        """
        concentration = self.concentration_rv.sample()

        negative_binomial_sample = numpyro.sample(
            name=self.name,
            fn=dist.NegativeBinomial2(
                # eps keeps the mean strictly positive to avoid numerical
                # issues when mu contains zeros.
                mean=mu + self.eps,
                concentration=concentration,
            ),
            obs=obs,
        )

        return negative_binomial_sample