Source code for pyrenew.latent.infection_seeding_method

# -*- coding: utf-8 -*-
# numpydoc ignore=GL08
from abc import ABCMeta, abstractmethod

import jax.numpy as jnp
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable


class InfectionSeedMethod(metaclass=ABCMeta):
    """Method for seeding initial infections in a renewal process."""

    def __init__(self, n_timepoints: int):
        """Default constructor for the ``InfectionSeedMethod`` class.

        Parameters
        ----------
        n_timepoints : int
            the number of time points to generate seed infections for

        Returns
        -------
        None

        Raises
        ------
        TypeError
            If ``n_timepoints`` is not an integer.
        ValueError
            If ``n_timepoints`` is not positive.
        """
        self.validate(n_timepoints)
        self.n_timepoints = n_timepoints

    @staticmethod
    def validate(n_timepoints: int) -> None:
        """Validate inputs for the ``InfectionSeedMethod`` class constructor

        Parameters
        ----------
        n_timepoints : int
            the number of time points to generate seed infections for

        Returns
        -------
        None

        Raises
        ------
        TypeError
            If ``n_timepoints`` is not an ``int`` (``bool`` is rejected).
        ValueError
            If ``n_timepoints`` is zero or negative.
        """
        # ``bool`` is a subclass of ``int``; exclude it explicitly so that
        # e.g. ``True`` is not silently accepted as "one time point".
        if not isinstance(n_timepoints, int) or isinstance(n_timepoints, bool):
            raise TypeError(
                f"n_timepoints must be an integer. Got {type(n_timepoints)}"
            )
        if n_timepoints <= 0:
            raise ValueError(
                f"n_timepoints must be positive. Got {n_timepoints}"
            )

    @abstractmethod
    def seed_infections(self, I_pre_seed: ArrayLike):
        """Generate the number of seeded infections at each time point.

        Parameters
        ----------
        I_pre_seed : ArrayLike
            An array representing some number of latent infections to be used
            with the specified ``InfectionSeedMethod``.

        Returns
        -------
        ArrayLike
            An array of length ``n_timepoints`` with the number of seeded
            infections at each time point.
        """

    def __call__(self, I_pre_seed: ArrayLike):
        """Seed infections; equivalent to ``self.seed_infections(I_pre_seed)``.

        Parameters
        ----------
        I_pre_seed : ArrayLike
            Passed through unchanged to :meth:`seed_infections`.

        Returns
        -------
        ArrayLike
            Whatever :meth:`seed_infections` returns.
        """
        return self.seed_infections(I_pre_seed)
class SeedInfectionsZeroPad(InfectionSeedMethod):
    """
    Create a seed infection vector of specified length by
    padding a shorter vector with an appropriate number of
    zeros at the beginning of the time series.
    """

    def seed_infections(self, I_pre_seed: ArrayLike):
        """Pad the seed infections with zeros at the beginning of the time series.

        Parameters
        ----------
        I_pre_seed : ArrayLike
            An array with seeded infections to be padded with zeros. A scalar
            is treated as a length-1 vector.

        Returns
        -------
        ArrayLike
            An array of length ``n_timepoints`` with the number of seeded
            infections at each time point.

        Raises
        ------
        ValueError
            If ``I_pre_seed`` has more elements than ``n_timepoints``.
        """
        # ``ArrayLike`` includes Python scalars, which have no ``.size``, and
        # 0-d arrays, which ``jnp.pad`` cannot extend into a vector; promote
        # both to a 1-d JAX array before inspecting or padding.
        I_pre_seed = jnp.atleast_1d(jnp.asarray(I_pre_seed))
        if self.n_timepoints < I_pre_seed.size:
            raise ValueError(
                "I_pre_seed must be no longer than n_timepoints. "
                f"Got I_pre_seed of size {I_pre_seed.size} and "
                f"n_timepoints of size {self.n_timepoints}."
            )
        # Prepend (n_timepoints - size) zeros; nothing is appended at the end.
        return jnp.pad(I_pre_seed, (self.n_timepoints - I_pre_seed.size, 0))
class SeedInfectionsFromVec(InfectionSeedMethod):
    """Create seed infections from a vector of infections."""

    def seed_infections(self, I_pre_seed: ArrayLike):
        """Create seed infections from a vector of infections.

        Parameters
        ----------
        I_pre_seed : ArrayLike
            An array with the same length as ``n_timepoints`` to be used as
            the seed infections. A scalar is treated as a length-1 vector.

        Returns
        -------
        ArrayLike
            An array of length ``n_timepoints`` with the number of seeded
            infections at each time point.

        Raises
        ------
        ValueError
            If the number of elements in ``I_pre_seed`` differs from
            ``n_timepoints``.
        """
        # Python scalars (valid ``ArrayLike``) have no ``.size``, and a 0-d
        # array would otherwise be returned as 0-d rather than as a
        # length-``n_timepoints`` vector; promote to a 1-d JAX array first.
        I_pre_seed = jnp.atleast_1d(jnp.asarray(I_pre_seed))
        if I_pre_seed.size != self.n_timepoints:
            raise ValueError(
                "I_pre_seed must have the same size as n_timepoints. "
                f"Got I_pre_seed of size {I_pre_seed.size} "
                f"and n_timepoints of size {self.n_timepoints}."
            )
        return I_pre_seed
class SeedInfectionsExponential(InfectionSeedMethod):
    r"""Generate seed infections according to exponential growth.

    Notes
    -----
    The number of incident infections at time `t` is given by:

    .. math:: I(t) = I_p \exp \left( r (t - t_p) \right)

    Where :math:`I_p` is ``I_pre_seed``, :math:`r` is ``rate``, and
    :math:`t_p` is ``t_pre_seed``. This ensures that :math:`I(t_p) = I_p`.
    We default to ``t_pre_seed = n_timepoints - 1``, so that ``I_pre_seed``
    represents the number of incident infections immediately before the
    renewal process begins.
    """

    def __init__(
        self,
        n_timepoints: int,
        rate: RandomVariable,
        t_pre_seed: int | None = None,
    ):
        """Default constructor for the ``SeedInfectionsExponential`` class.

        Parameters
        ----------
        n_timepoints : int
            the number of time points to generate seed infections for
        rate : RandomVariable
            A random variable representing the rate of exponential growth.
            Its ``sample()`` method must return a one-element tuple whose
            single entry has size 1.
        t_pre_seed : int | None, optional
            The time point whose number of infections is described by
            ``I_pre_seed``. Defaults to ``n_timepoints - 1``.

        Raises
        ------
        TypeError
            If ``n_timepoints`` is not an integer.
        ValueError
            If ``n_timepoints`` is not positive.
        """
        super().__init__(n_timepoints)
        self.rate = rate
        if t_pre_seed is None:
            t_pre_seed = n_timepoints - 1
        self.t_pre_seed = t_pre_seed

    def seed_infections(self, I_pre_seed: ArrayLike):
        """Generate seed infections according to exponential growth.

        Parameters
        ----------
        I_pre_seed : ArrayLike
            An array of size 1 (or a scalar) representing the number of
            infections at time ``t_pre_seed``.

        Returns
        -------
        ArrayLike
            An array of length ``n_timepoints`` with the number of seeded
            infections at each time point.

        Raises
        ------
        ValueError
            If ``I_pre_seed`` or the sampled ``rate`` does not have exactly
            one element.
        """
        # Python scalars are valid ``ArrayLike`` but lack ``.size``; normalise
        # to a JAX array so the size check (and broadcasting below) works.
        I_pre_seed = jnp.asarray(I_pre_seed)
        if I_pre_seed.size != 1:
            raise ValueError(
                f"I_pre_seed must be an array of size 1. Got size {I_pre_seed.size}."
            )
        (rate,) = self.rate.sample()
        rate = jnp.asarray(rate)
        if rate.size != 1:
            raise ValueError(
                f"rate must be an array of size 1. Got size {rate.size}."
            )
        # Offsetting by t_pre_seed makes the curve pass through I_pre_seed at
        # index t_pre_seed (exp(0) == 1).
        return I_pre_seed * jnp.exp(
            rate * (jnp.arange(self.n_timepoints) - self.t_pre_seed)
        )