Source code for pyrenew.deterministic.deterministic
# numpydoc ignore=GL08
from __future__ import annotations
import numpyro
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
[docs]
class DeterministicVariable(RandomVariable):
"""
A deterministic (degenerate) random variable. Useful to pass fixed
quantities.
"""
def __init__(
self,
name: str,
value: ArrayLike,
) -> None:
"""Default constructor
Parameters
----------
name : str
A name to assign to the variable.
value : ArrayLike
An ArrayLike object.
Returns
-------
None
"""
self.name = name
self.validate(value)
self.value = value
return None
[docs]
@staticmethod
def validate(value: ArrayLike) -> None:
"""
Validates input to DeterministicVariable
Parameters
----------
value : ArrayLike
An ArrayLike object.
Returns
-------
None
Raises
------
Exception
If the input value object is not an ArrayLike object.
"""
if not isinstance(value, ArrayLike):
raise ValueError(
f"value {value} passed to a DeterministicVariable "
f"is of type {type(value).__name__}, expected "
"an ArrayLike object"
)
return None
[docs]
def sample(
self,
record=False,
**kwargs,
) -> ArrayLike:
"""
Retrieve the value of the deterministic Rv
Parameters
----------
record : bool, optional
Whether to record the value of the deterministic
RandomVariable. Defaults to False.
**kwargs : dict, optional
Additional keyword arguments passed through to internal
sample calls, should there be any.
Returns
-------
ArrayLike
"""
if record:
numpyro.deterministic(self.name, self)
return self.value