"""
pyrenew helper classes
"""
from abc import ABCMeta, abstractmethod
import jax.random as jr
import numpy as np
from jax.typing import ArrayLike
from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample
def _assert_type(arg_name: str, value, expected_type) -> None:
    """
    Validate that an argument has the expected type.

    Parameters
    ----------
    arg_name : str
        Name of the argument, used in the error message.
    value : object
        The object being validated.
    expected_type : type
        The type ``value`` must be an instance of (anything accepted
        as the second argument of :func:`isinstance`).

    Raises
    ------
    TypeError
        If `value` is not an instance of `expected_type`.

    Returns
    -------
    None
    """
    # Guard clause: nothing to do when the type check passes.
    if isinstance(value, expected_type):
        return None
    message = (
        f"{arg_name} must be an instance of {expected_type}. "
        f"Got {type(value)}"
    )
    raise TypeError(message)
class RandomVariable(metaclass=ABCMeta):
    """
    Abstract base class for latent and observed random variables.

    Concrete subclasses must implement :meth:`sample` and the static
    method :meth:`validate`. Instances are callable: calling one is the
    same as calling :meth:`sample` with the same keyword arguments.
    """

    def __init__(self, **kwargs):
        """
        Default constructor; accepts and ignores any keyword arguments.
        """

    @abstractmethod
    def sample(self, **kwargs) -> tuple:
        """
        Sample method of the process.

        Implementations should accept at least ``**kwargs``.

        Parameters
        ----------
        **kwargs : dict, optional
            Additional keyword arguments passed through to internal
            :meth:`sample` calls, should there be any.

        Returns
        -------
        tuple
        """

    @staticmethod
    @abstractmethod
    def validate(**kwargs) -> None:
        """
        Validate keyword arguments; to be implemented in subclasses.
        """

    def __call__(self, **kwargs):
        """
        Alias for :meth:`sample`.
        """
        draw = self.sample
        return draw(**kwargs)
[docs]
class Model(metaclass=ABCMeta):
"""Abstract base class for models"""
# Since initialized in none, values not shared across instances
kernel = None
mcmc = None
@abstractmethod
def __init__(self, **kwargs) -> None: # numpydoc ignore=GL08
pass
[docs]
@staticmethod
@abstractmethod
def validate() -> None: # numpydoc ignore=GL08
pass
[docs]
@abstractmethod
def sample(
self,
**kwargs,
) -> tuple:
"""
Sample method of the model.
The method design in the class should have at least kwargs.
Parameters
----------
**kwargs : dict, optional
Additional keyword arguments passed through to internal
:meth:`sample` calls, should there be any.
Returns
-------
tuple
"""
pass
[docs]
def model(self, **kwargs) -> tuple:
"""
Alias for the sample method.
Parameters
----------
**kwargs : dict, optional
Additional keyword arguments passed through to
internal :meth:`sample` calls, should there be any.
Returns
-------
tuple
"""
return self.sample(**kwargs)
def _init_model(
self,
num_warmup,
num_samples,
nuts_args: dict = None,
mcmc_args: dict = None,
) -> None:
"""
Creates the NUTS kernel and MCMC model
Parameters
----------
nuts_args : dict, optional
Dictionary of arguments passed to the
:class:`numpyro.infer.hmc.NUTS` constructor.
Default None.
mcmc_args : dict, optional
Dictionary of arguments passed to the
:class:`numpyro.infer.mcmc.MCMC` constructor.
Default None.
Returns
-------
None
"""
if nuts_args is None:
nuts_args = dict()
if "find_heuristic_step_size" not in nuts_args:
nuts_args["find_heuristic_step_size"] = True
if "init_strategy" not in nuts_args:
nuts_args["init_strategy"] = init_to_sample
if mcmc_args is None:
mcmc_args = dict()
self.kernel = NUTS(
model=self.model,
**nuts_args,
)
self.mcmc = MCMC(
self.kernel,
num_warmup=num_warmup,
num_samples=num_samples,
**mcmc_args,
)
return None
[docs]
def run(
self,
num_warmup,
num_samples,
rng_key: ArrayLike | None = None,
nuts_args: dict = None,
mcmc_args: dict = None,
**kwargs,
) -> None:
"""
Runs the model
Parameters
----------
nuts_args : dict, optional
Dictionary of arguments passed to the kernel
(:class:`numpyro.infer.hmc.NUTS`) constructor.
Defaults to None.
mcmc_args : dict, optional
Dictionary of arguments passed to the MCMC runner
(:class:`numpyro.infer.mcmc.MCMC`) constructor.
Defaults to None.
Returns
-------
None
"""
self._init_model(
num_warmup=num_warmup,
num_samples=num_samples,
nuts_args=nuts_args,
mcmc_args=mcmc_args,
)
if rng_key is None:
rand_int = np.random.randint(
np.iinfo(np.int64).min, np.iinfo(np.int64).max
)
rng_key = jr.key(rand_int)
self.mcmc.run(rng_key=rng_key, **kwargs)
return None
[docs]
def print_summary(
self,
prob: float = 0.9,
exclude_deterministic: bool = True,
) -> None:
"""
A wrapper of :meth:`MCMC.print_summary()
<numpyro.infer.mcmc.MCMC.print_summary>`.
Parameters
----------
prob : float, optional
The width of the credible interval to show. Default 0.9
exclude_deterministic : bool, optional
Whether to print deterministic sites in the summary.
Defaults to True.
Returns
-------
None
"""
return self.mcmc.print_summary(prob, exclude_deterministic)
[docs]
def posterior_predictive(
self,
rng_key: ArrayLike | None = None,
numpyro_predictive_args: dict = {},
**kwargs,
) -> dict:
"""
A wrapper of :class:`numpyro.infer.util.Predictive` to generate
posterior predictive samples.
Parameters
----------
rng_key : ArrayLike, optional
Random key for the Predictive function call. Defaults to None.
numpyro_predictive_args : dict, optional
Dictionary of arguments to be passed to the
:class:`numpyro.infer.util.Predictive` constructor.
**kwargs
Additional named arguments passed to the
:meth:`__call__()` method of
:class:`numpyro.infer.util.Predictive`.
Returns
-------
dict
"""
if self.mcmc is None:
raise ValueError(
"No posterior samples available. Run model with model.run()."
)
if rng_key is None:
rand_int = np.random.randint(
np.iinfo(np.int64).min, np.iinfo(np.int64).max
)
rng_key = jr.key(rand_int)
predictive = Predictive(
model=self.model,
posterior_samples=self.mcmc.get_samples(),
**numpyro_predictive_args,
)
return predictive(rng_key, **kwargs)
[docs]
def prior_predictive(
self,
rng_key: ArrayLike | None = None,
numpyro_predictive_args: dict = {},
**kwargs,
) -> dict:
"""
A wrapper for :class:`numpyro.infer.util.Predictive`
to generate prior predictive samples.
Parameters
----------
rng_key : ArrayLike, optional
Random key for the Predictive function call.
Default None.
numpyro_predictive_args : dict, optional
Dictionary of arguments to be passed to
the :class:`numpyro.infer.util.Predictive`
constructor. Default None.
**kwargs
Additional named arguments passed to the
:meth:`__call__()` method of
:class:`numpyro.infer.util.Predictive`.
Returns
-------
dict
"""
if rng_key is None:
rand_int = np.random.randint(
np.iinfo(np.int64).min, np.iinfo(np.int64).max
)
rng_key = jr.key(rand_int)
predictive = Predictive(
model=self.model,
posterior_samples=None,
**numpyro_predictive_args,
)
return predictive(rng_key, **kwargs)