Source code for pyrenew.process.rtrandomwalk

# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

import numpyro as npro
import numpyro.distributions as dist
import pyrenew.transformation as t
from pyrenew.metaclass import RandomVariable
from pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess


class RtRandomWalkProcess(RandomVariable):
    r"""Rt Randomwalk Process

    Notes
    -----
    The process is defined as follows:

    .. math::

        Rt(0) &\sim \text{Rt0_dist} \\
        \text{Rt_transformed_rw}(0) &= \text{Rt_transform}(Rt(0)) \\
        Rt(t) &= \text{Rt_transform}^{-1}(\text{Rt_transformed_rw}(t))

    where :math:`\text{Rt_transformed_rw}` is a
    :class:`SimpleRandomWalkProcess` constructed from ``Rt_rw_dist`` and
    initialized at the transformed initial value of Rt.
    """

    def __init__(
        self,
        Rt0_dist: dist.Distribution,
        Rt_rw_dist: dist.Distribution,
        Rt_transform: t.Transform | None = None,
    ) -> None:
        """
        Default constructor

        Parameters
        ----------
        Rt0_dist : dist.Distribution
            Initial distribution of Rt.
        Rt_rw_dist : dist.Distribution
            Distribution handed to the underlying
            ``SimpleRandomWalkProcess``, which runs on the transformed
            scale.
        Rt_transform : numpyro.distributions.transforms.Transform, optional
            Transformation mapping Rt onto the scale on which the random
            walk runs; its inverse maps the walk back to Rt. If None, the
            identity transformation is used.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If any argument has the wrong type (see ``validate``).
        """
        if Rt_transform is None:
            Rt_transform = t.IdentityTransform()

        RtRandomWalkProcess.validate(Rt0_dist, Rt_transform, Rt_rw_dist)

        self.Rt0_dist = Rt0_dist
        self.Rt_transform = Rt_transform
        self.Rt_rw_dist = Rt_rw_dist

    @staticmethod
    def validate(
        Rt0_dist: dist.Distribution,
        Rt_transform: t.Transform,
        Rt_rw_dist: dist.Distribution,
    ) -> None:
        """
        Validates Rt0_dist, Rt_transform, and Rt_rw_dist.

        Parameters
        ----------
        Rt0_dist : dist.Distribution
            Initial distribution of Rt, expected dist.Distribution.
        Rt_transform : numpyro.distributions.transforms.Transform
            Transformation mapping Rt onto the random-walk scale,
            expected t.Transform.
        Rt_rw_dist : dist.Distribution
            Random walk distribution, expected dist.Distribution.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If Rt0_dist or Rt_rw_dist are not dist.Distribution or if
            Rt_transform is not a t.Transform.
        """
        # Raise explicitly instead of using ``assert`` so the checks are
        # not stripped when Python runs with ``-O``. AssertionError is
        # kept because it is the documented exception type.
        if not isinstance(Rt0_dist, dist.Distribution):
            raise AssertionError(
                "Rt0_dist must be a dist.Distribution, "
                f"got {type(Rt0_dist).__name__}"
            )
        if not isinstance(Rt_transform, t.Transform):
            raise AssertionError(
                "Rt_transform must be a t.Transform, "
                f"got {type(Rt_transform).__name__}"
            )
        if not isinstance(Rt_rw_dist, dist.Distribution):
            raise AssertionError(
                "Rt_rw_dist must be a dist.Distribution, "
                f"got {type(Rt_rw_dist).__name__}"
            )

    def sample(
        self,
        n_timepoints: int,
        **kwargs,
    ) -> tuple:
        """
        Generate samples from the process

        Parameters
        ----------
        n_timepoints : int
            Number of timepoints to sample.
        **kwargs : dict, optional
            Accepted but currently ignored: they are not forwarded to the
            internal sample() calls.

        Returns
        -------
        tuple
            With a single array of shape (n_timepoints,) holding Rt on the
            original (untransformed) scale, also recorded as the ``"Rt"``
            deterministic site.
        """
        # Draw the initial Rt and move it onto the random-walk scale.
        Rt0 = npro.sample("Rt0", self.Rt0_dist)
        Rt0_trans = self.Rt_transform(Rt0)

        # Run the random walk on the transformed scale, anchored at the
        # transformed initial value.
        Rt_trans_proc = SimpleRandomWalkProcess(self.Rt_rw_dist)
        Rt_trans_ts, *_ = Rt_trans_proc.sample(
            n_timepoints=n_timepoints,
            name="Rt_transformed_rw",
            init=Rt0_trans,
        )

        # Map the walk back to the Rt scale and record it as a site.
        Rt = npro.deterministic("Rt", self.Rt_transform.inv(Rt_trans_ts))

        return (Rt,)