Source code for pyrenew.deterministic.deterministicpmf
# numpydoc ignore=GL08
from __future__ import annotations
from jax.typing import ArrayLike
from pyrenew.deterministic.deterministic import DeterministicVariable
from pyrenew.distutil import validate_discrete_dist_vector
from pyrenew.metaclass import RandomVariable
[docs]
class DeterministicPMF(RandomVariable):
"""
A deterministic (degenerate) random variable that represents a PMF."""
def __init__(
self,
name: str,
value: ArrayLike,
tol: float = 1e-5,
) -> None:
"""
Default constructor
Automatically checks that the elements in `value` can be indeed
considered to be a PMF by calling
pyrenew.distutil.validate_discrete_dist_vector on each one of its
entries.
Parameters
----------
name : str
A name to assign to the variable.
value : ArrayLike
An ArrayLike object.
tol : float, optional
Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults
to 1e-5.
Returns
-------
None
"""
value = validate_discrete_dist_vector(
discrete_dist=value,
tol=tol,
)
self.basevar = DeterministicVariable(
name=name,
value=value,
)
return None
[docs]
@staticmethod
def validate(value: ArrayLike) -> None:
"""
Validates input to DeterministicPMF
Parameters
----------
value : ArrayLike
An ArrayLike object.
Returns
-------
None
"""
return None
[docs]
def sample(
self,
**kwargs,
) -> ArrayLike:
"""
Retrieves the deterministic PMF
Parameters
----------
**kwargs : dict, optional
Additional keyword arguments passed through to internal `sample()`
calls, if any
Returns
-------
ArrayLike
"""
return self.basevar.sample(**kwargs)
[docs]
def size(self) -> int:
"""
Returns the size of the PMF
Returns
-------
int
The size of the PMF
"""
return self.basevar.value.size