Source code for pyrenew.randomvariable.distributionalvariable

# numpydoc ignore=GL08

from typing import Callable, Self

import numpyro
import numpyro.distributions as dist
from jax.typing import ArrayLike
from numpyro.infer.reparam import Reparam

from pyrenew.metaclass import RandomVariable


class DynamicDistributionalVariable(RandomVariable):
    """
    Wrapper class for random variables that sample
    from a single :class:`numpyro.distributions.Distribution`
    that is parameterized / instantiated at `sample()` time
    (rather than at RandomVariable instantiation time).
    """

    def __init__(
        self,
        name: str,
        distribution_constructor: Callable,
        reparam: Reparam | None = None,
        expand_by_shape: tuple | None = None,
    ) -> None:
        """
        Default constructor for DynamicDistributionalVariable.

        Parameters
        ----------
        name : str
            Name of the random variable.
        distribution_constructor : Callable
            Callable that returns a concrete parametrized
            numpyro.distributions.Distribution instance.
        reparam : numpyro.infer.reparam.Reparam, optional
            If not None, reparameterize sampling
            from the distribution according to the
            given numpyro reparameterizer. Default None.
        expand_by_shape : tuple, optional
            If not None, call :meth:`expand_by()` on the
            underlying distribution once it is instantiated
            with the given `expand_by_shape`.
            Default None.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `distribution_constructor` is not callable, or if
            `expand_by_shape` is neither None nor a tuple.
        """
        self.name = name
        self.validate(distribution_constructor)
        self.distribution_constructor = distribution_constructor
        # The reparam config is keyed by this variable's sample-site name,
        # which is the form numpyro.handlers.reparam expects.
        if reparam is not None:
            self.reparam_dict = {self.name: reparam}
        else:
            self.reparam_dict = {}
        if not (expand_by_shape is None or isinstance(expand_by_shape, tuple)):
            # Single concatenated message (previously the fragments were
            # passed as separate ValueError args by a stray comma).
            raise ValueError(
                "expand_by_shape must be a tuple or be None. "
                f"Got {type(expand_by_shape)}"
            )
        self.expand_by_shape = expand_by_shape
        return None

    @staticmethod
    def validate(distribution_constructor: object) -> None:
        """
        Confirm that the distribution_constructor is callable.

        Parameters
        ----------
        distribution_constructor : object
            Putative distribution_constructor to validate.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `distribution_constructor` is not callable.
        """
        if not callable(distribution_constructor):
            raise ValueError(
                "To instantiate a DynamicDistributionalVariable, "
                "one must provide a Callable that returns a "
                "numpyro.distributions.Distribution as the "
                "distribution_constructor argument. "
                f"Got {type(distribution_constructor)}, which "
                "does not appear to be callable"
            )
        return None

    def sample(
        self,
        *args,
        obs: ArrayLike | None = None,
        **kwargs,
    ) -> ArrayLike:
        """
        Sample from the distributional rv.

        Parameters
        ----------
        *args :
            Positional arguments passed to self.distribution_constructor
        obs : ArrayLike, optional
            Observations passed as the `obs` argument to
            :meth:`numpyro.sample()`. Default `None`.
        **kwargs : dict, optional
            Keyword arguments passed to self.distribution_constructor

        Returns
        -------
        ArrayLike
            a sample from the distribution.
        """
        # Build the distribution now, from the sample-time arguments.
        distribution = self.distribution_constructor(*args, **kwargs)
        if self.expand_by_shape is not None:
            distribution = distribution.expand_by(self.expand_by_shape)
        with numpyro.handlers.reparam(config=self.reparam_dict):
            sample = numpyro.sample(
                name=self.name,
                fn=distribution,
                obs=obs,
            )
        return sample

    def expand_by(self, sample_shape) -> Self:
        """
        Expand the distribution by a given sample_shape,
        if possible. Returns a new DynamicDistributionalVariable
        whose underlying distribution will be expanded by
        the given shape at sample() time.

        If this variable already carries an expansion, the new
        `sample_shape` is prepended to it, matching the semantics of
        calling :meth:`numpyro.distributions.Distribution.expand_by`
        twice in succession.

        Parameters
        ----------
        sample_shape : tuple
            Sample shape by which to expand the distribution.
            Passed to the expand_by() method of
            :class:`numpyro.distributions.Distribution`
            after the distribution is instantiated.

        Returns
        -------
        DynamicDistributionalVariable
            Whose underlying distribution will be expanded by
            the given sample shape at sampling time.

        Raises
        ------
        ValueError
            If `sample_shape` is not a tuple.
        """
        if not isinstance(sample_shape, tuple):
            raise ValueError(
                "sample_shape for expand()-ing "
                "a DistributionalVariable must be a "
                f"tuple. Got {type(sample_shape)}"
            )
        # Compose with any existing expansion rather than replacing it:
        # d.expand_by(a).expand_by(b) yields batch shape b + a + batch_shape.
        if self.expand_by_shape is None:
            new_shape = sample_shape
        else:
            new_shape = sample_shape + self.expand_by_shape
        return DynamicDistributionalVariable(
            name=self.name,
            distribution_constructor=self.distribution_constructor,
            reparam=self.reparam_dict.get(self.name, None),
            expand_by_shape=new_shape,
        )
class StaticDistributionalVariable(RandomVariable):
    """
    Wrapper class for random variables that sample
    from a single :class:`numpyro.distributions.Distribution`
    that is parameterized / instantiated at RandomVariable
    instantiation time (rather than at `sample()`-ing time).
    """

    def __init__(
        self,
        name: str,
        distribution: numpyro.distributions.Distribution,
        reparam: Reparam | None = None,
    ) -> None:
        """
        Default constructor for StaticDistributionalVariable.

        Parameters
        ----------
        name : str
            Name of the random variable.
        distribution : numpyro.distributions.Distribution
            Distribution of the random variable.
        reparam : numpyro.infer.reparam.Reparam, optional
            If not None, reparameterize sampling
            from the distribution according to the
            given numpyro reparameterizer. Default None.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `distribution` is not a
            numpyro.distributions.Distribution instance.
        """
        self.name = name
        self.validate(distribution)
        self.distribution = distribution
        # The reparam config is keyed by this variable's sample-site name,
        # which is the form numpyro.handlers.reparam expects.
        if reparam is not None:
            self.reparam_dict = {self.name: reparam}
        else:
            self.reparam_dict = {}
        return None

    @staticmethod
    def validate(distribution: object) -> None:
        """
        Confirm that `distribution` is a numpyro Distribution instance.

        Parameters
        ----------
        distribution : object
            Putative distribution to validate.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `distribution` is not a
            numpyro.distributions.Distribution instance.
        """
        if not isinstance(distribution, numpyro.distributions.Distribution):
            # f-prefix added so the offending type is actually interpolated
            # (previously "{type(distribution)}" was emitted literally).
            raise ValueError(
                "distribution should be an instance of "
                "numpyro.distributions.Distribution, got "
                f"{type(distribution)}"
            )
        return None

    def sample(
        self,
        obs: ArrayLike | None = None,
        **kwargs,
    ) -> ArrayLike:
        """
        Sample from the distribution.

        Parameters
        ----------
        obs : ArrayLike, optional
            Observations passed as the `obs` argument to
            :meth:`numpyro.sample()`. Default `None`.
        **kwargs : dict, optional
            Additional keyword arguments passed through
            to internal sample calls, should there be any.

        Returns
        -------
        ArrayLike
            Containing a sample from the distribution.
        """
        with numpyro.handlers.reparam(config=self.reparam_dict):
            sample = numpyro.sample(
                name=self.name,
                fn=self.distribution,
                obs=obs,
            )
        return sample

    def expand_by(self, sample_shape) -> Self:
        """
        Expand the distribution by the given sample_shape,
        if possible. Returns a new StaticDistributionalVariable
        whose underlying distribution has been expanded by
        the given sample_shape via
        :meth:`~numpyro.distributions.Distribution.expand_by()`

        Parameters
        ----------
        sample_shape : tuple
            Sample shape for the expansion. Passed to the
            :meth:`expand_by()` method of
            :class:`numpyro.distributions.Distribution`.

        Returns
        -------
        StaticDistributionalVariable
            Whose underlying distribution has been expanded by
            the given sample shape.

        Raises
        ------
        ValueError
            If `sample_shape` is not a tuple.
        """
        if not isinstance(sample_shape, tuple):
            raise ValueError(
                "sample_shape for expand()-ing "
                "a DistributionalVariable must be a "
                f"tuple. Got {type(sample_shape)}"
            )
        return StaticDistributionalVariable(
            name=self.name,
            distribution=self.distribution.expand_by(sample_shape),
            reparam=self.reparam_dict.get(self.name, None),
        )
def DistributionalVariable(
    name: str,
    distribution: numpyro.distributions.Distribution | Callable,
    reparam: Reparam = None,
) -> RandomVariable:
    """
    Build a distributional RandomVariable, choosing the static or
    dynamic flavor based on what `distribution` is.

    Parameters
    ----------
    name : str
        Name of the random variable.
    distribution : numpyro.distributions.Distribution | Callable
        A numpyro.distributions.Distribution instance, giving a
        fixed (static) distribution for the random variable, or a
        callable returning a parameterized
        numpyro.distributions.Distribution when invoked. The callable
        form supports dynamically-parameterized variables, e.g. a
        Normal whose location and scale are themselves inferred.
    reparam : numpyro.infer.reparam.Reparam
        If not None, reparameterize sampling
        from the distribution according to the
        given numpyro reparameterizer

    Returns
    -------
    DynamicDistributionalVariable | StaticDistributionalVariable
        A StaticDistributionalVariable for a Distribution instance,
        otherwise a DynamicDistributionalVariable.

    Raises
    ------
    ValueError
        If `distribution` is neither a Distribution instance nor
        callable.
    """
    # The instance test comes first so that a Distribution object is always
    # routed to the static wrapper, even if it also happens to be callable.
    is_static = isinstance(distribution, dist.Distribution)

    if not is_static and not callable(distribution):
        raise ValueError(
            "distribution argument to DistributionalVariable "
            "must be either a numpyro.distributions.Distribution "
            "(for instantiating a static DistributionalVariable) "
            "or a callable that returns a "
            "numpyro.distributions.Distribution (for "
            "a dynamic DistributionalVariable)."
        )

    if is_static:
        return StaticDistributionalVariable(
            name=name, distribution=distribution, reparam=reparam
        )
    return DynamicDistributionalVariable(
        name=name, distribution_constructor=distribution, reparam=reparam
    )