Source code for pyrenew.latent.infection_initialization_process

# numpydoc ignore=GL08
from jax.typing import ArrayLike

from pyrenew.latent.infection_initialization_method import (
    InfectionInitializationMethod,
)
from pyrenew.metaclass import RandomVariable, _assert_type


class InfectionInitializationProcess(RandomVariable):
    """
    Random variable that produces an initial infection history.

    It pairs a RandomVariable (`I_pre_init_rv`) with an
    `InfectionInitializationMethod`; sampling draws from the former and
    passes the draw to the latter.
    """

    def __init__(
        self,
        name,
        I_pre_init_rv: RandomVariable,
        infection_init_method: InfectionInitializationMethod,
    ) -> None:
        """
        Build an InfectionInitializationProcess.

        Parameters
        ----------
        name : str
            A name to assign to the RandomVariable.
        I_pre_init_rv : RandomVariable
            A RandomVariable representing the number of infections that
            occur at some time before the renewal process begins. How it
            is used depends on the chosen `infection_init_method`.
        infection_init_method : InfectionInitializationMethod
            The `InfectionInitializationMethod` that generates the
            initial infections for the renewal process.

        Returns
        -------
        None
        """
        # Check argument types first so that nothing is stored on the
        # instance when validation fails.
        InfectionInitializationProcess.validate(
            I_pre_init_rv, infection_init_method
        )

        self.name = name
        self.infection_init_method = infection_init_method
        self.I_pre_init_rv = I_pre_init_rv

    @staticmethod
    def validate(
        I_pre_init_rv: RandomVariable,
        infection_init_method: InfectionInitializationMethod,
    ) -> None:
        """
        Type-check the constructor arguments.

        Each argument is handed to `_assert_type` together with its
        expected class; anything `_assert_type` raises propagates to
        the caller.

        Parameters
        ----------
        I_pre_init_rv : RandomVariable
            A random variable representing the number of infections that
            occur at some time before the renewal process begins.
        infection_init_method : InfectionInitializationMethod
            A method to generate the initial infections.

        Returns
        -------
        None
        """
        # (label, value, required type), checked in this order.
        expectations = (
            ("I_pre_init_rv", I_pre_init_rv, RandomVariable),
            (
                "infection_init_method",
                infection_init_method,
                InfectionInitializationMethod,
            ),
        )
        for label, value, required_type in expectations:
            _assert_type(label, value, required_type)

    def sample(self) -> ArrayLike:
        """
        Sample the Infection Initialization Process.

        Calls `I_pre_init_rv` to obtain a draw, then passes that draw to
        `infection_init_method` and returns its result.

        Returns
        -------
        ArrayLike
            the number of initialized infections at each time point.
        """
        return self.infection_init_method(
            self.I_pre_init_rv(),
        )