Source code for pyrenew.observation.poisson
# -*- coding: utf-8 -*-
# numpydoc ignore=GL08
from __future__ import annotations
import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
[docs]
class PoissonObservation(RandomVariable):
    """
    Poisson observation process.

    Observations are modeled as draws from a Poisson distribution whose
    rate is the supplied ``mu`` shifted by a small constant ``eps``.
    """

    def __init__(
        self,
        parameter_name: str = "poisson_rv",
        eps: float = 1e-8,
    ) -> None:
        """
        Default Constructor

        Parameters
        ----------
        parameter_name : str, optional
            Default sample-site name passed to numpyro.sample when
            ``sample`` is called without an explicit ``name``.
            Defaults to "poisson_rv".
        eps : float, optional
            Small value added to the rate parameter to avoid zero values.
            Defaults to 1e-8.

        Returns
        -------
        None
        """
        self.parameter_name = parameter_name
        self.eps = eps

    def sample(
        self,
        mu: ArrayLike,
        obs: ArrayLike | None = None,
        name: str | None = None,
        **kwargs,
    ) -> tuple:
        """
        Sample from the Poisson process.

        Parameters
        ----------
        mu : ArrayLike
            Rate parameter of the Poisson distribution. ``self.eps`` is
            added to it before the distribution is constructed.
        obs : ArrayLike | None, optional
            Observed data. When given, the sample site is conditioned on
            these values. Defaults to None.
        name : str | None, optional
            Name of the sample site. Defaults to None, in which case
            ``self.parameter_name`` is used.
        **kwargs : dict, optional
            Accepted for interface compatibility; currently ignored
            (not forwarded to ``numpyro.sample``).

        Returns
        -------
        tuple
            A one-element tuple holding the value returned by
            ``numpyro.sample``: the observed data when ``obs`` is given,
            otherwise a draw from the Poisson distribution.
        """
        site_name = self.parameter_name if name is None else name
        # Shift the rate by eps so a zero entry in ``mu`` does not give a
        # zero Poisson rate (assumes eps > 0).
        rate = mu + self.eps
        return (
            numpyro.sample(
                name=site_name,
                fn=dist.Poisson(rate=rate),
                obs=obs,
            ),
        )

    @staticmethod
    def validate() -> None:
        """
        Validate the observation process.

        No validation is performed; this method is a no-op.

        Returns
        -------
        None
        """
        return None