Source code for pyrenew.model.admissionsmodel

# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

from __future__ import annotations

from typing import NamedTuple

import jax.numpy as jnp
import pyrenew.arrayutils as au
from jax.typing import ArrayLike
from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype
from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel


class HospModelSample(NamedTuple):
    """
    A container for holding the output from
    `model.HospitalAdmissionsModel.sample()`.

    Attributes
    ----------
    Rt : float | None, optional
        The reproduction number over time. Defaults to None.
    latent_infections : ArrayLike | None, optional
        The estimated number of new infections over time.
        Defaults to None.
    infection_hosp_rate : float | None, optional
        The infected hospitalization rate. Defaults to None.
    latent_hosp_admissions : ArrayLike | None, optional
        The estimated latent hospitalizations. Defaults to None.
    observed_hosp_admissions : ArrayLike | None, optional
        The sampled or observed hospital admissions.
        Defaults to None.
    """

    Rt: float | None = None
    latent_infections: ArrayLike | None = None
    infection_hosp_rate: float | None = None
    latent_hosp_admissions: ArrayLike | None = None
    observed_hosp_admissions: ArrayLike | None = None

    def __repr__(self):
        """
        Return a one-line, constructor-like description of the sample.

        Field values are interpolated with their ``str`` form (not
        ``repr``), which is why this overrides the representation
        that ``NamedTuple`` would generate automatically.

        Returns
        -------
        str
        """
        # The trailing ")" closes the "HospModelSample(" opened by the
        # first fragment, so the representation is balanced.
        return (
            f"HospModelSample(Rt={self.Rt}, "
            f"latent_infections={self.latent_infections}, "
            f"infection_hosp_rate={self.infection_hosp_rate}, "
            f"latent_hosp_admissions={self.latent_hosp_admissions}, "
            f"observed_hosp_admissions={self.observed_hosp_admissions})"
        )
class HospitalAdmissionsModel(Model):
    """
    Hospital Admissions Model (BasicRenewal + HospitalAdmissions)

    Subclass of ``Model``. It wraps an ``RtInfectionsRenewalModel``
    (stored as ``basic_renewal``) and layers on top of it a latent
    hospital admissions random variable plus an optional observation
    process for those admissions.
    """

    def __init__(
        self,
        latent_hosp_admissions_rv: RandomVariable,
        latent_infections_rv: RandomVariable,
        gen_int_rv: RandomVariable,
        I0_rv: RandomVariable,
        Rt_process_rv: RandomVariable,
        hosp_admission_obs_process_rv: RandomVariable,
    ) -> None:  # numpydoc ignore=PR04
        """
        Default constructor

        Parameters
        ----------
        latent_hosp_admissions_rv : RandomVariable
            Latent process for the hospital admissions.
        latent_infections_rv : RandomVariable
            The infections latent process (forwarded to
            RtInfectionsRenewalModel).
        gen_int_rv : RandomVariable
            Generation time (forwarded to RtInfectionsRenewalModel).
        I0_rv : RandomVariable
            Initial infections (forwarded to RtInfectionsRenewalModel).
        Rt_process_rv : RandomVariable
            Rt process (forwarded to RtInfectionsRenewalModel).
        hosp_admission_obs_process_rv : RandomVariable
            Observation process for the hospital admissions. The
            argument is required, but ``None`` is tolerated by
            ``validate`` and by ``sample`` (no observations are then
            produced).

        Returns
        -------
        None
        """
        # The wrapped renewal model is built without an infection
        # observation process.
        # NOTE(review): presumably because this model observes hospital
        # admissions rather than infections -- confirm intent.
        renewal_components = dict(
            gen_int_rv=gen_int_rv,
            I0_rv=I0_rv,
            latent_infections_rv=latent_infections_rv,
            infection_obs_process_rv=None,
            Rt_process_rv=Rt_process_rv,
        )
        self.basic_renewal = RtInfectionsRenewalModel(**renewal_components)

        # Check the admissions-specific components before storing them.
        HospitalAdmissionsModel.validate(
            latent_hosp_admissions_rv,
            hosp_admission_obs_process_rv,
        )

        self.latent_hosp_admissions_rv = latent_hosp_admissions_rv
        self.hosp_admission_obs_process_rv = hosp_admission_obs_process_rv

    @staticmethod
    def validate(
        latent_hosp_admissions_rv, hosp_admission_obs_process_rv
    ) -> None:
        """
        Verifies types and status (RV) of latent and observed hospital
        admissions

        Parameters
        ----------
        latent_hosp_admissions_rv : RandomVariable
            The latent process for the hospital admissions. Checked
            unconditionally.
        hosp_admission_obs_process_rv : RandomVariable
            The observation process for the hospital admissions. The
            check is skipped when this is None.

        Returns
        -------
        None

        See Also
        --------
        _assert_sample_and_rtype : Perform type-checking and verify RV
        """
        # (candidate, may_be_none) pairs, checked in this order.
        checks = (
            (latent_hosp_admissions_rv, False),
            (hosp_admission_obs_process_rv, True),
        )
        for candidate, may_be_none in checks:
            _assert_sample_and_rtype(candidate, skip_if_none=may_be_none)

    def sample(
        self,
        n_timepoints_to_simulate: int | None = None,
        data_observed_hosp_admissions: ArrayLike | None = None,
        padding: int = 0,
        **kwargs,
    ) -> HospModelSample:
        """
        Sample from the HospitalAdmissions model

        Exactly one of `n_timepoints_to_simulate` and
        `data_observed_hosp_admissions` must be given.

        Parameters
        ----------
        n_timepoints_to_simulate : int, optional
            Number of timepoints to sample (passed to the basic
            renewal model).
        data_observed_hosp_admissions : ArrayLike, optional
            The observed hospitalization data. When given, its length
            sets the number of timepoints and it is passed to the
            hospital admissions observation process as the observed
            values. Defaults to None (simulation, rather than fit).
        padding : int, optional
            Number of padding timepoints to add to the beginning of
            the simulation. Forwarded to the basic renewal model and,
            when data are supplied, added to the offset from which the
            latent admissions and the data are compared.
            Defaults to 0.
        **kwargs : dict, optional
            Additional keyword arguments passed through to internal
            sample() calls, should there be any.

        Returns
        -------
        HospModelSample

        Raises
        ------
        ValueError
            If both or neither of `n_timepoints_to_simulate` and
            `data_observed_hosp_admissions` are passed.

        See Also
        --------
        basic_renewal.sample : For sampling the basic renewal model
        """
        horizon_given = n_timepoints_to_simulate is not None
        data_given = data_observed_hosp_admissions is not None

        if not (horizon_given or data_given):
            raise ValueError(
                "Either n_timepoints_to_simulate or "
                "data_observed_hosp_admissions must be passed."
            )
        if horizon_given and data_given:
            raise ValueError(
                "Cannot pass both n_timepoints_to_simulate and "
                "data_observed_hosp_admissions."
            )

        if horizon_given:
            n_timepoints = n_timepoints_to_simulate
        else:
            n_timepoints = len(data_observed_hosp_admissions)

        # Rt and latent infections come from the wrapped renewal model.
        renewal_draw = self.basic_renewal.sample(
            n_timepoints_to_simulate=n_timepoints,
            data_observed_infections=None,
            padding=padding,
            **kwargs,
        )

        # Latent hospital admissions derived from the latent infections;
        # only the first two returned elements are used.
        latent_draw = self.latent_hosp_admissions_rv.sample(
            latent_infections=renewal_draw.latent_infections,
            **kwargs,
        )
        infection_hosp_rate, latent_hosp_admissions, *_ = latent_draw

        # Number of leading latent entries beyond the requested
        # timepoints.
        i0_size = len(latent_hosp_admissions) - n_timepoints

        observed_hosp_admissions = None
        if self.hosp_admission_obs_process_rv is not None:
            if data_given:
                # Pad the data at the start with NaN relative to the
                # latent series (presumably to equal length -- see
                # pyrenew.arrayutils.pad_x_to_match_y), then drop the
                # initialization and padding entries from both.
                padded_data = au.pad_x_to_match_y(
                    data_observed_hosp_admissions,
                    latent_hosp_admissions,
                    jnp.nan,
                    pad_direction="start",
                )
                first_kept = i0_size + padding
                obs_mean = latent_hosp_admissions[first_kept:]
                obs_values = padded_data[first_kept:]
            else:
                # Pure simulation: use the full latent series, no data.
                obs_mean = latent_hosp_admissions
                obs_values = None

            observed_hosp_admissions, *_ = (
                self.hosp_admission_obs_process_rv.sample(
                    mu=obs_mean,
                    obs=obs_values,
                    **kwargs,
                )
            )

        return HospModelSample(
            Rt=renewal_draw.Rt,
            latent_infections=renewal_draw.latent_infections,
            infection_hosp_rate=infection_hosp_rate,
            latent_hosp_admissions=latent_hosp_admissions,
            observed_hosp_admissions=observed_hosp_admissions,
        )