# Source code for pyrenew.metaclass

# -*- coding: utf-8 -*-

"""
pyrenew helper classes
"""

from abc import ABCMeta, abstractmethod
from typing import NamedTuple, get_type_hints

import jax
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
import numpy as np
import numpyro as npro
import polars as pl
from jax.typing import ArrayLike
from numpyro.infer import MCMC, NUTS, Predictive
from pyrenew.mcmcutils import plot_posterior, spread_draws


def _assert_sample_and_rtype(
    rp: "RandomVariable", skip_if_none: bool = True
) -> None:
    """
    Return type-checking for RandomVariable's sample function

    Objects passed as `RandomVariable` should (a) have a sample() method that
    (b) returns either a tuple or a named tuple.

    Parameters
    ----------
    rp : RandomVariable
        Random variable to check.
    skip_if_none : bool, optional
        When `True` it returns if `rp` is None. Defaults to True.

    Returns
    -------
    None

    Raises
    ------
    Exception
        If rp is not a RandomVariable, does not have a sample function, or
        does not return a tuple. Also occurs if rettype does not initialized
        properly.
    """

    # Addressing the None case
    if (rp is None) and (not skip_if_none):
        Exception(
            "The passed object cannot be None. It should be RandomVariable"
        )
    elif skip_if_none and (rp is None):
        return None

    if not isinstance(rp, RandomVariable):
        raise Exception(f"{rp} is not an instance of RandomVariable.")

    # Otherwise, checking for the sample function (must have one)
    # with a defined rtype.
    try:
        sfun = rp.sample
    except Exception:
        raise Exception(
            f"The RandomVariable {rp} does not have a sample function."
        )  # noqa: E722

    # Getting the return annotation (if any)
    rettype = get_type_hints(sfun).get("return", None)

    if rettype is None:
        raise Exception(
            f"The RandomVariable {rp} does not have return type "
            + "annotation."
        )

    try:
        if not isinstance(rettype(), tuple):
            raise Exception(
                f"The RandomVariable {rp}'s return type annotation is not"
                + "a tuple"
            )
    except Exception:
        raise Exception(
            f"There was a problem when trying to initialize {rettype}."
            + "the rtype of the random variable should be a tuple or a namedtuple"
            + " with default values."
        )

    return None


[docs] class RandomVariable(metaclass=ABCMeta): """ Abstract base class for latent and observed random variables. Notes ----- RandomVariables in pyrenew can be time-aware, meaning that they can have a t_start and t_unit attribute. These attributes are expected to be used internally mostly for tasks including padding, alignment of time series, and other time-aware operations. Both attributes give information about the output of the sample() method, in other words, the relative time units of the returning value. Attributes ---------- t_start : int The start of the time series. t_unit : int The unit of the time series relative to the model's fundamental (smallest) time unit. e.g. if the fundamental unit is days, then 1 corresponds to units of days and 7 to units of weeks. """ t_start: int = None t_unit: int = None def __init__(self, **kwargs): """ Default constructor """ pass
[docs] def set_timeseries( self, t_start: int, t_unit: int, ) -> None: """ Set the time series start and unit Parameters ---------- t_start : int The start of the time series relative to the model time. It could be negative, indicating that the sample() method returns timepoints that occur prior to the model t = 0. t_unit : int The unit of the time series relative to the model's fundamental (smallest) time unit. e.g. if the fundamental unit is days, then 1 corresponds to units of days and 7 to units of weeks. Returns ------- None """ # Timeseries unit should be a positive integer assert isinstance( t_unit, int ), f"t_unit should be an integer. It is {type(t_unit)}." # Timeseries unit should be a positive integer assert ( t_unit > 0 ), f"t_unit should be a positive integer. It is {t_unit}." # Data starts should be a positive integer assert isinstance( t_start, int ), f"t_start should be an integer. It is {type(t_start)}." self.t_start = t_start self.t_unit = t_unit return None
[docs] @abstractmethod def sample( self, **kwargs, ) -> tuple: """ Sample method of the process The method design in the class should have at least kwargs. Parameters ---------- **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` calls, should there be any. Returns ------- tuple """ pass
[docs] @staticmethod @abstractmethod def validate(**kwargs) -> None: """ Validation of kwargs to be implemented in subclasses. """ pass
[docs] class DistributionalRVSample(NamedTuple): """ Named tuple for the sample method of DistributionalRV Attributes ---------- value : ArrayLike Sampled value from the distribution. """ value: ArrayLike | None = None def __repr__(self) -> str: """ Representation of the DistributionalRVSample """ return f"DistributionalRVSample(value={self.value})"
class DistributionalRV(RandomVariable):
    """
    Wrapper class for random variables that sample
    from a single `numpyro.distributions.Distribution`.
    """

    def __init__(
        self,
        dist: npro.distributions.Distribution,
        name: str,
    ):
        """
        Default constructor for DistributionalRV.

        Parameters
        ----------
        dist : npro.distributions.Distribution
            Distribution of the random variable.
        name : str
            Name of the random variable, used as the numpyro sample site name.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `dist` is not a `numpyro.distributions.Distribution`.
        """
        self.validate(dist)

        self.dist = dist
        self.name = name

        return None

    @staticmethod
    def validate(dist: object) -> None:
        """
        Check that `dist` is a numpyro distribution.

        Parameters
        ----------
        dist : object
            Candidate distribution object.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `dist` is not an instance of
            `numpyro.distributions.Distribution`.
        """
        if not isinstance(dist, npro.distributions.Distribution):
            raise ValueError(
                "dist should be an instance of "
                f"numpyro.distributions.Distribution, got {dist}"
            )

        return None

    def sample(
        self,
        obs: ArrayLike | None = None,
        **kwargs,
    ) -> DistributionalRVSample:
        """
        Sample from the distribution.

        Parameters
        ----------
        obs : ArrayLike, optional
            Observations passed as the `obs` argument to
            `numpyro.sample()`. Default `None`.
        **kwargs : dict, optional
            Additional keyword arguments passed through to internal
            sample calls, should there be any. Not used by this class.

        Returns
        -------
        DistributionalRVSample
            The result of `numpyro.sample()` at site `self.name`, wrapped
            with `jnp.atleast_1d` so scalar draws become 1-d arrays.
        """
        return DistributionalRVSample(
            value=jnp.atleast_1d(
                npro.sample(
                    name=self.name,
                    fn=self.dist,
                    obs=obs,
                )
            ),
        )
class Model(metaclass=ABCMeta):
    """Abstract base class for models"""

    # Class-level None defaults. _init_model() assigns these on the instance,
    # so kernels/samplers are never shared across instances.
    kernel = None
    mcmc = None

    @abstractmethod
    def __init__(self, **kwargs) -> None:  # numpydoc ignore=GL08
        pass

    @staticmethod
    @abstractmethod
    def validate() -> None:  # numpydoc ignore=GL08
        pass

    @abstractmethod
    def sample(
        self,
        **kwargs,
    ) -> tuple:
        """
        Sample method of the model

        The method design in the class should have at least kwargs.

        Parameters
        ----------
        **kwargs : dict, optional
            Additional keyword arguments passed through to internal `sample()`
            calls, should there be any.

        Returns
        -------
        tuple
        """
        pass

    @staticmethod
    def _resolve_rng_key(rng_key: ArrayLike | None) -> ArrayLike:
        """
        Return `rng_key` unchanged, or a new JAX key when it is None.

        The fallback seed is drawn from NumPy's global random state, so it
        follows `np.random.seed()`.

        Parameters
        ----------
        rng_key : ArrayLike or None
            User-supplied key, or None to request a fresh one.

        Returns
        -------
        ArrayLike
        """
        if rng_key is not None:
            return rng_key
        # dtype is explicit because the platform default integer can be
        # 32-bit (e.g. NumPy < 2 on Windows), which rejects int64 bounds.
        rand_int = int(
            np.random.randint(
                np.iinfo(np.int64).min,
                np.iinfo(np.int64).max,
                dtype=np.int64,
            )
        )
        return jr.key(rand_int)

    def _ensure_fitted(self) -> None:
        """
        Raise ValueError if no MCMC sampler exists yet (run() not called).
        """
        if self.mcmc is None:
            raise ValueError(
                "No posterior samples available. Run model with model.run()."
            )

    def _init_model(
        self,
        num_warmup,
        num_samples,
        nuts_args: dict | None = None,
        mcmc_args: dict | None = None,
    ) -> None:
        """
        Creates the NUTS kernel and MCMC model

        Parameters
        ----------
        num_warmup : int
            Number of warmup steps passed to MCMC.
        num_samples : int
            Number of samples passed to MCMC.
        nuts_args : dict, optional
            Dictionary of arguments passed to NUTS. Defaults to None.
        mcmc_args : dict, optional
            Dictionary of arguments passed to the MCMC sampler. Defaults to
            None.

        Returns
        -------
        None
        """
        if nuts_args is None:
            nuts_args = dict()

        if mcmc_args is None:
            mcmc_args = dict()

        self.kernel = NUTS(
            model=self.sample,
            **nuts_args,
        )

        self.mcmc = MCMC(
            self.kernel,
            num_warmup=num_warmup,
            num_samples=num_samples,
            **mcmc_args,
        )

        return None

    def run(
        self,
        num_warmup,
        num_samples,
        rng_key: ArrayLike | None = None,
        nuts_args: dict | None = None,
        mcmc_args: dict | None = None,
        **kwargs,
    ) -> None:
        """
        Runs the model

        The NUTS kernel and MCMC sampler are built only on the first call.
        On later calls `num_warmup`, `num_samples`, `nuts_args` and
        `mcmc_args` are ignored and the existing sampler is run again.

        Parameters
        ----------
        num_warmup : int
            Number of warmup steps (used only when the sampler is first
            built).
        num_samples : int
            Number of samples (used only when the sampler is first built).
        rng_key : ArrayLike, optional
            Random key for the MCMC run. If None, one is generated from
            NumPy's global random state. Defaults to None.
        nuts_args : dict, optional
            Dictionary of arguments passed to the NUTS. Defaults to None.
        mcmc_args : dict, optional
            Dictionary of arguments passed to the MCMC sampler. Defaults to
            None.
        **kwargs
            Additional keyword arguments forwarded to `MCMC.run()`.

        Returns
        -------
        None
        """
        if self.mcmc is None:
            self._init_model(
                num_warmup=num_warmup,
                num_samples=num_samples,
                nuts_args=nuts_args,
                mcmc_args=mcmc_args,
            )

        rng_key = self._resolve_rng_key(rng_key)

        self.mcmc.run(rng_key=rng_key, **kwargs)

        return None

    def print_summary(
        self,
        prob: float = 0.9,
        exclude_deterministic: bool = True,
    ) -> None:
        """
        A wrapper of MCMC.print_summary

        Parameters
        ----------
        prob : float, optional
            Probability mass of the credible interval reported in the
            summary. Defaults to 0.9.
        exclude_deterministic : bool, optional
            If True, deterministic sites are left out of the summary.
            Defaults to True.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If the model has not been run yet.
        """
        self._ensure_fitted()
        return self.mcmc.print_summary(prob, exclude_deterministic)

    def spread_draws(self, variables_names: list) -> pl.DataFrame:
        """
        A wrapper of mcmcutils.spread_draws

        Parameters
        ----------
        variables_names : list
            A list of variable names to create a table of samples.

        Returns
        -------
        pl.DataFrame

        Raises
        ------
        ValueError
            If the model has not been run yet.
        """
        self._ensure_fitted()
        # Module-level pyrenew.mcmcutils.spread_draws, not this method.
        return spread_draws(self.mcmc.get_samples(), variables_names)

    def plot_posterior(
        self,
        var: list,
        obs_signal: ArrayLike | None = None,
        xlab: str | None = None,
        ylab: str = "Signal",
        samples: int = 50,
        figsize: list | None = None,
        draws_col: str = "darkblue",
        obs_col: str = "black",
    ) -> plt.Figure:  # numpydoc ignore=RT01
        """
        A wrapper of pyrenew.mcmcutils.plot_posterior

        `var` is paired with the "time" dimension when requesting draws via
        `self.spread_draws`; presumably it is the name of a sampled site —
        confirm against pyrenew.mcmcutils. `figsize` defaults to [4, 5].
        Raises ValueError if the model has not been run yet.
        """
        if figsize is None:
            # Fresh list per call instead of a shared mutable default.
            figsize = [4, 5]

        return plot_posterior(
            var=var,
            draws=self.spread_draws([(var, "time")]),
            xlab=xlab,
            ylab=ylab,
            samples=samples,
            obs_signal=obs_signal,
            figsize=figsize,
            draws_col=draws_col,
            obs_col=obs_col,
        )

    def posterior_predictive(
        self,
        rng_key: ArrayLike | None = None,
        numpyro_predictive_args: dict | None = None,
        **kwargs,
    ) -> dict:
        """
        A wrapper for numpyro.infer.Predictive to generate posterior predictive samples.

        Parameters
        ----------
        rng_key : ArrayLike, optional
            Random key for the Predictive function call. If None, one is
            generated from NumPy's global random state. Defaults to None.
        numpyro_predictive_args : dict, optional
            Dictionary of arguments to be passed to the
            numpyro.inference.Predictive constructor. Defaults to None
            (no extra arguments).
        **kwargs
            Additional named arguments passed to the
            `__call__()` method of numpyro.inference.Predictive

        Returns
        -------
        dict

        Raises
        ------
        ValueError
            If the model has not been run yet.
        """
        self._ensure_fitted()

        if numpyro_predictive_args is None:
            numpyro_predictive_args = {}

        rng_key = self._resolve_rng_key(rng_key)

        predictive = Predictive(
            model=self.sample,
            posterior_samples=self.mcmc.get_samples(),
            **numpyro_predictive_args,
        )

        return predictive(rng_key, **kwargs)

    def prior_predictive(
        self,
        rng_key: ArrayLike | None = None,
        numpyro_predictive_args: dict | None = None,
        **kwargs,
    ) -> dict:
        """
        A wrapper for numpyro.infer.Predictive to generate prior predictive samples.

        Parameters
        ----------
        rng_key : ArrayLike, optional
            Random key for the Predictive function call. If None, one is
            generated from NumPy's global random state. Defaults to None.
        numpyro_predictive_args : dict, optional
            Dictionary of arguments to be passed to the
            numpyro.inference.Predictive constructor. Defaults to None
            (no extra arguments).
        **kwargs
            Additional named arguments passed to the
            `__call__()` method of numpyro.inference.Predictive

        Returns
        -------
        dict
        """
        if numpyro_predictive_args is None:
            numpyro_predictive_args = {}

        rng_key = self._resolve_rng_key(rng_key)

        predictive = Predictive(
            model=self.sample,
            posterior_samples=None,
            **numpyro_predictive_args,
        )

        return predictive(rng_key, **kwargs)