Source code for pyrenew.latent.infection_initialization_method
# numpydoc ignore=GL08
from abc import ABCMeta, abstractmethod
import jax.numpy as jnp
from jax.typing import ArrayLike
from pyrenew.metaclass import RandomVariable
[docs]
class InfectionInitializationMethod(metaclass=ABCMeta):
"""Method for initializing infections in a renewal process."""
def __init__(self, n_timepoints: int):
"""Default constructor for
:class:`InfectionInitializationMethod`.
Parameters
----------
n_timepoints : int
the number of time points for which to
generate initial infections
Returns
-------
None
"""
self.validate(n_timepoints)
self.n_timepoints = n_timepoints
[docs]
@staticmethod
def validate(n_timepoints: int) -> None:
"""
Validate inputs to the
:class:`InfectionInitializationMethod`
constructor.
Parameters
----------
n_timepoints : int
the number of time points to generate initial infections for
Returns
-------
None
"""
if not isinstance(n_timepoints, int):
raise TypeError(
f"n_timepoints must be an integer. Got {type(n_timepoints)}"
)
if n_timepoints <= 0:
raise ValueError(
f"n_timepoints must be positive. Got {n_timepoints}"
)
[docs]
@abstractmethod
def initialize_infections(self, I_pre_init: ArrayLike):
"""Generate the number of initialized infections at each time point.
Parameters
----------
I_pre_init : ArrayLike
An array representing some number of latent infections to be used with the specified ``InfectionInitializationMethod``.
Returns
-------
ArrayLike
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
"""
def __call__(self, I_pre_init: ArrayLike):
return self.initialize_infections(I_pre_init)
[docs]
class InitializeInfectionsZeroPad(InfectionInitializationMethod):
"""
Create an initial infection vector of specified length by
padding a shorter vector with an appropriate number of
zeros at the beginning of the time series.
"""
[docs]
def initialize_infections(self, I_pre_init: ArrayLike):
"""Pad the initial infections with zeros at the beginning of the time series.
Parameters
----------
I_pre_init : ArrayLike
An array with initialized infections to be padded with zeros.
Returns
-------
ArrayLike
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
"""
I_pre_init = jnp.atleast_1d(I_pre_init)
if self.n_timepoints < I_pre_init.size:
raise ValueError(
"I_pre_init must be no longer than n_timepoints. "
f"Got I_pre_init of size {I_pre_init.size} and "
f" n_timepoints of size {self.n_timepoints}."
)
return jnp.pad(I_pre_init, (self.n_timepoints - I_pre_init.size, 0))
[docs]
class InitializeInfectionsFromVec(InfectionInitializationMethod):
"""Create initial infections from a vector of infections."""
[docs]
def initialize_infections(self, I_pre_init: ArrayLike) -> ArrayLike:
"""Create initial infections from a vector of infections.
Parameters
----------
I_pre_init : ArrayLike
An array with the same length as ``n_timepoints`` to be
used as the initial infections.
Returns
-------
ArrayLike
An array of length ``n_timepoints`` with the number of
initialized infections at each time point.
"""
I_pre_init = jnp.array(I_pre_init)
if I_pre_init.size != self.n_timepoints:
raise ValueError(
"I_pre_init must have the same size as n_timepoints. "
f"Got I_pre_init of size {I_pre_init.size} "
f"and n_timepoints of size {self.n_timepoints}."
)
return I_pre_init
[docs]
class InitializeInfectionsExponentialGrowth(InfectionInitializationMethod):
r"""Generate initial infections according to exponential growth.
Notes
-----
The number of incident infections at time `t` is given by:
.. math:: I(t) = I_p \exp \left( r (t - t_p) \right)
Where :math:`I_p` is ``I_pre_init``, :math:`r` is ``rate``, and :math:`t_p` is ``t_pre_init``.
This ensures that :math:`I(t_p) = I_p`.
We default to ``t_pre_init = n_timepoints - 1``, so that
``I_pre_init`` represents the number of incident infections immediately
before the renewal process begins.
"""
def __init__(
self,
n_timepoints: int,
rate_rv: RandomVariable,
t_pre_init: int | None = None,
):
"""Default constructor for the ``InitializeInfectionsExponentialGrowth`` class.
Parameters
----------
n_timepoints : int
the number of time points to generate initial infections for
rate_rv : RandomVariable
A random variable representing the rate of exponential growth
t_pre_init : int | None, optional
The time point whose number of infections is described by ``I_pre_init``. Defaults to ``n_timepoints - 1``.
"""
super().__init__(n_timepoints)
self.rate_rv = rate_rv
if t_pre_init is None:
t_pre_init = n_timepoints - 1
self.t_pre_init = t_pre_init
[docs]
def initialize_infections(self, I_pre_init: ArrayLike):
"""Generate initial infections according to exponential growth.
Parameters
----------
I_pre_init : ArrayLike
An array of size 1 representing the number of infections at time ``t_pre_init``.
Returns
-------
ArrayLike
An array of length ``n_timepoints`` with the number of initialized infections at each time point.
"""
I_pre_init = jnp.array(I_pre_init)
rate = jnp.array(self.rate_rv())
initial_infections = I_pre_init * jnp.exp(
rate
* (jnp.arange(self.n_timepoints)[:, jnp.newaxis] - self.t_pre_init)
)
return jnp.squeeze(initial_infections)