Source code for pyrenew.latent.infections

# numpydoc ignore=GL08

from __future__ import annotations

from typing import NamedTuple

import jax.numpy as jnp
from jax.typing import ArrayLike

import pyrenew.latent.infection_functions as inf
from pyrenew.metaclass import RandomVariable


class InfectionsSample(NamedTuple):
    """
    A container for holding the output from `latent.Infections()`.

    Attributes
    ----------
    post_initialization_infections : ArrayLike | None, optional
        The estimated latent infections. Defaults to None.
    """

    post_initialization_infections: ArrayLike | None = None

    def __repr__(self):
        # Hard-coded class name (not type(self).__name__) keeps the output
        # stable for subclasses; the field is rendered with str(), not repr().
        return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections})"
class Infections(RandomVariable):
    r"""Latent infections

    This class samples infections given Rt,
    initial infections, and generation interval.

    Notes
    -----
    The mathematical model is given by:

    .. math::

            I(t) = R(t) \times \sum_{\tau < t} I(\tau) g(t-\tau)

    where :math:`I(t)` is the number of infections at time :math:`t`,
    :math:`R(t)` is the reproduction number at time :math:`t`, and
    :math:`g(t-\tau)` is the generation interval.
    """

    @staticmethod
    def validate() -> None:  # numpydoc ignore=GL08
        return None

    def sample(
        self,
        Rt: ArrayLike,
        I0: ArrayLike,
        gen_int: ArrayLike,
        **kwargs,
    ) -> InfectionsSample:
        """
        Samples infections given Rt, initial infections, and generation
        interval.

        Parameters
        ----------
        Rt : ArrayLike
            Reproduction number. The leading axis is the non-batch (time)
            axis; any trailing axes are batch axes and must match those
            of ``I0``.
        I0 : ArrayLike
            Initial infections. Must be at least one-dimensional and at
            least as long as ``gen_int`` along its leading axis; only the
            last ``gen_int.size`` entries along that axis are used.
            Trailing (batch) axes must match those of ``Rt``.
        gen_int : ArrayLike
            Generation interval pmf vector. Must be non-empty.
        **kwargs : dict, optional
            Additional keyword arguments passed through to internal
            sample calls, should there be any. Currently unused.

        Returns
        -------
        InfectionsSample
            Named tuple whose ``post_initialization_infections`` field
            holds the result of ``compute_infections_from_rt``.

        Raises
        ------
        ValueError
            If ``gen_int`` is empty, if ``I0`` is zero-dimensional, if
            ``I0`` is shorter than ``gen_int`` along its leading axis, or
            if the batch shapes of ``I0`` and ``Rt`` differ.
        """
        # Coerce scalars / NumPy inputs so ``.shape``/``.size``/``.ndim``
        # exist; jnp.asarray returns JAX arrays and tracers unchanged.
        Rt = jnp.asarray(Rt)
        I0 = jnp.asarray(I0)
        gen_int = jnp.asarray(gen_int)

        n_gen = gen_int.size
        # Guard needed for the tail slice below: ``I0[-0:]`` would return
        # the whole array rather than an empty one.
        if n_gen == 0:
            raise ValueError(
                "Generation interval must contain at least one element."
            )
        if I0.ndim == 0:
            raise ValueError(
                "Initial infections must be at least one-dimensional."
            )
        if I0.shape[0] < n_gen:
            raise ValueError(
                "Initial infections vector must be at least as long as "
                "the generation interval. "
                f"Initial infections vector length: {I0.shape[0]}, "
                f"generation interval length: {gen_int.size}."
            )

        if I0.shape[1:] != Rt.shape[1:]:
            raise ValueError(
                "Initial infections and Rt must have the same batch shapes. "
                f"Got initial infections of batch shape {I0.shape[1:]} "
                f"and Rt of batch shape {Rt.shape[1:]}."
            )

        gen_int_rev = jnp.flip(gen_int)
        # Only the most recent ``n_gen`` initial infections can contribute
        # to the renewal sum, so drop any older entries.
        recent_I0 = I0[-n_gen:]

        post_initialization_infections = inf.compute_infections_from_rt(
            I0=recent_I0,
            Rt=Rt,
            reversed_generation_interval_pmf=gen_int_rev,
        )

        return InfectionsSample(post_initialization_infections)