# Source code for pyrenew.randomvariable.transformedvariable

# numpydoc ignore=GL08

import numpyro

from pyrenew.metaclass import RandomVariable
from pyrenew.transformation import Transform


class TransformedVariable(RandomVariable):
    """
    Class to represent RandomVariables defined by taking the
    output of another RV's :meth:`RandomVariable.sample()` method
    and transforming it by a given transformation (typically a
    :class:`Transform`)
    """

    def __init__(
        self,
        name: str,
        base_rv: RandomVariable,
        transforms: Transform | tuple[Transform, ...],
    ):
        """
        Default constructor

        Parameters
        ----------
        name : str
            A name for the random variable instance. Also used as the
            name (or name prefix) of any ``numpyro.deterministic``
            sites recorded by :meth:`sample`.
        base_rv : RandomVariable
            The underlying (untransformed) RandomVariable.
        transforms : Transform or tuple of Transform
            Transformation or tuple of transformations to apply
            to the output of `base_rv.sample()`; single values will
            be coerced to a length-one tuple. If a tuple, should be
            the same length as the tuple returned by
            `base_rv.sample()`.

        Returns
        -------
        None

        Raises
        ------
        ValueError
            Raised by :meth:`validate` if any transform is not
            callable, or if ``base_rv`` has a ``sample_length``
            method whose result differs from the number of transforms.
        """
        self.name = name
        self.base_rv = base_rv

        # Normalize a single transform to a length-one tuple so that
        # sample() and validate() can always iterate over it.
        if not isinstance(transforms, tuple):
            transforms = (transforms,)
        self.transforms = transforms

        self.validate()

    def sample(self, record=False, **kwargs) -> tuple:
        """
        Sample method. Call self.base_rv.sample()
        and then apply the transforms specified
        in self.transforms.

        Parameters
        ----------
        record : bool, optional
            Whether to record the transformed value(s) as
            ``numpyro.deterministic`` sites. A single transformed
            value is recorded under ``self.name``; multiple values are
            recorded as ``f"{self.name}_{suffix}"``, where ``suffix``
            is the field name when the base sample is a namedtuple
            (has ``_fields``) and the positional index otherwise.
            Defaults to False.
        **kwargs :
            Keyword arguments passed to self.base_rv.sample()

        Returns
        -------
        tuple or single value
            The transformed values, one per (transform, base value)
            pair. If there is exactly one transformed value, it is
            returned on its own rather than wrapped in a tuple.
        """
        untransformed_values = self.base_rv.sample(**kwargs)
        # Treat a bare (non-tuple) sample as a length-one tuple so it
        # pairs with the first transform.
        if not isinstance(untransformed_values, tuple):
            untransformed_values = (untransformed_values,)

        # NOTE: zip stops at the shorter sequence. Matching lengths are
        # only enforced by validate() when base_rv has sample_length().
        transformed_values = tuple(
            t(uv) for t, uv in zip(self.transforms, untransformed_values)
        )

        if record:
            if len(transformed_values) == 1:
                # Record the value itself rather than the length-one
                # tuple wrapping it, so the site matches the return.
                numpyro.deterministic(self.name, transformed_values[0])
            else:
                # Prefer namedtuple field names as site suffixes;
                # fall back to positional indices.
                suffixes = (
                    untransformed_values._fields
                    if hasattr(untransformed_values, "_fields")
                    else range(len(transformed_values))
                )
                for suffix, tv in zip(suffixes, transformed_values):
                    numpyro.deterministic(f"{self.name}_{suffix}", tv)

        if len(transformed_values) == 1:
            transformed_values = transformed_values[0]

        return transformed_values

    def sample_length(self):
        """
        Sample length for a transformed random variable
        must be equal to the length of self.transforms or
        validation will fail.

        Returns
        -------
        int
            Equal to the length of self.transforms
        """
        return len(self.transforms)

    def validate(self):
        """
        Perform validation checks on a TransformedVariable instance,
        confirming that all transformations are callable and that
        the number of transformations is equal to the sample length
        of the base random variable (when the base random variable
        has a ``sample_length`` method).

        Returns
        -------
        None
            On successful validation.

        Raises
        ------
        ValueError
            If any entry of self.transforms is not callable, or if the
            number of transforms differs from
            ``self.base_rv.sample_length()``.
        """
        for t in self.transforms:
            if not callable(t):
                raise ValueError(
                    "All entries in self.transforms must be callable"
                )

        # The length check is only possible when the base RV can report
        # how many entries its sample() returns.
        if hasattr(self.base_rv, "sample_length"):
            n_transforms = len(self.transforms)
            n_entries = self.base_rv.sample_length()
            if n_transforms != n_entries:
                raise ValueError(
                    "There must be exactly as many transformations "
                    "specified in self.transforms as there are "
                    "entries in the tuple returned by "
                    "self.base_rv.sample(). "
                    f"Got {n_transforms} transforms and {n_entries} "
                    "entries"
                )