Source code for pyrenew.observation.poisson
# numpydoc ignore=GL08
from __future__ import annotations
import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
[docs]
class PoissonObservation(RandomVariable):
"""
Poisson observation process
"""
def __init__(
self,
name: str,
eps: float = 1e-8,
) -> None:
"""
Default Constructor
Parameters
----------
name : str
Passed to numpyro.sample.
eps : float, optional
Small value added to the rate parameter to avoid zero values.
Defaults to 1e-8.
Returns
-------
None
"""
self.name = name
self.eps = eps
return None
[docs]
@staticmethod
def validate(): # numpydoc ignore=GL08
None
[docs]
def sample(
self,
mu: ArrayLike,
obs: ArrayLike | None = None,
**kwargs,
) -> ArrayLike:
"""
Sample from the Poisson process
Parameters
----------
mu : ArrayLike
Rate parameter of the Poisson distribution.
obs : ArrayLike | None, optional
Observed data. Defaults to None.
**kwargs : dict, optional
Additional keyword arguments passed through to internal sample calls, should there be any.
Returns
-------
ArrayLike
"""
poisson_sample = numpyro.sample(
name=self.name,
fn=dist.Poisson(rate=mu + self.eps),
obs=obs,
)
return poisson_sample