Source code for pyrenew.latent.hospitaladmissions

# numpydoc ignore=GL08

from __future__ import annotations

from typing import Any, NamedTuple

import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike

import pyrenew.arrayutils as au
from pyrenew.convolve import compute_delay_ascertained_incidence
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable


[docs] class HospitalAdmissionsSample(NamedTuple): """ A container to hold the output of :meth:`HospitalAdmissions.sample`. Attributes ---------- infection_hosp_rate : ArrayLike, optional The infection-to-hospitalization rate. Defaults to None. latent_hospital_admissions : ArrayLike or None The computed number of hospital admissions. Defaults to None. multiplier : ArrayLike or None The day of the week effect multiplier. Defaults to None. It should match the number of timepoints in the latent hospital admissions. """ infection_hosp_rate: ArrayLike | None = None latent_hospital_admissions: ArrayLike | None = None multiplier: ArrayLike | None = None def __repr__(self): return f"HospitalAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, latent_hospital_admissions={self.latent_hospital_admissions}, multiplier={self.multiplier})"
[docs] class HospitalAdmissions(RandomVariable): r""" Latent hospital admissions Implements a renewal process for the expected number of hospital admissions. Notes ----- The following text was directly extracted from the wastewater model documentation (`link <https://github.com/cdcent/cfa-forecast-renewal-ww/blob/a17efc090b2ffbc7bc11bdd9eec5198d6bcf7322/model_definition.md#hospital-admissions-component>`_). Following other semi-mechanistic renewal frameworks, we model the *expected* hospital admissions per capita :math:`H(t)` as a convolution of the *expected* latent incident infections per capita :math:`I(t)`, and a discrete infection to hospitalization distribution :math:`d(\tau)`, scaled by the probability of being hospitalized :math:`p_\mathrm{hosp}(t)`. To account for day-of-week effects in hospital reporting, we use an estimated *day of the week effect* :math:`\omega(t)`. If :math:`t` and :math:`t'` are the same day of the week, :math:`\omega(t) = \omega(t')`. The seven values that :math:`\omega(t)` takes on are constrained to have mean 1. .. math:: H(t) = \omega(t) p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau) Where :math:`T_d` is the maximum delay from infection to hospitalization that we consider. """ def __init__( self, infection_to_admission_interval_rv: RandomVariable, infection_hospitalization_ratio_rv: RandomVariable, day_of_week_effect_rv: RandomVariable | None = None, hospitalization_reporting_ratio_rv: RandomVariable | None = None, obs_data_first_day_of_the_week: int = 0, ) -> None: """ Default constructor Parameters ---------- infection_to_admission_interval_rv : RandomVariable pmf for reporting (informing) hospital admissions (see pyrenew.observations.Deterministic). infection_hospitalization_ratio_rv : RandomVariable Infection to hospitalization rate random variable. day_of_week_effect_rv : RandomVariable, optional Day of the week effect. Should return a ArrayLike with 7 values. Defaults to a deterministic variable with jax.numpy.ones(7) (no effect). hospitalization_reporting_ratio_rv : RandomVariable, optional Random variable for the hospital admission reporting probability. Defaults to 1 (full reporting). obs_data_first_day_of_the_week : int, optional The day of the week that the first day of the observation data corresponds to. Valid values are 0-6, where 0 is Monday and 6 is Sunday. Defaults to 0. Returns ------- None """ if day_of_week_effect_rv is None: day_of_week_effect_rv = DeterministicVariable( name="weekday_effect", value=jnp.ones(7) ) if hospitalization_reporting_ratio_rv is None: hospitalization_reporting_ratio_rv = DeterministicVariable( name="hosp_report_prob", value=1.0 ) HospitalAdmissions.validate( infection_to_admission_interval_rv, infection_hospitalization_ratio_rv, day_of_week_effect_rv, hospitalization_reporting_ratio_rv, obs_data_first_day_of_the_week, ) self.infection_to_admission_interval_rv = ( infection_to_admission_interval_rv ) self.infection_hospitalization_ratio_rv = ( infection_hospitalization_ratio_rv ) self.day_of_week_effect_rv = day_of_week_effect_rv self.hospitalization_reporting_ratio_rv = ( hospitalization_reporting_ratio_rv ) self.obs_data_first_day_of_the_week = obs_data_first_day_of_the_week
[docs] @staticmethod def validate( infection_to_admission_interval_rv: Any, infection_hospitalization_ratio_rv: Any, day_of_week_effect_rv: Any, hospitalization_reporting_ratio_rv: Any, obs_data_first_day_of_the_week: Any, ) -> None: """ Validates that the IHR, weekday effects, probability of being reported hospitalized distributions, and infection to hospital admissions reporting delay pmf are RandomVariable types Parameters ---------- infection_to_admission_interval_rv : Any Possibly incorrect input for the infection to hospitalization interval distribution. infection_hospitalization_ratio_rv : Any Possibly incorrect input for infection to hospitalization rate distribution. day_of_week_effect_rv : Any Possibly incorrect input for day of the week effect. hospitalization_reporting_ratio_rv : Any Possibly incorrect input for distribution or fixed value for the hospital admission reporting probability. obs_data_first_day_of_the_week : Any Possibly incorrect input for the day of the week that the first day of the observation data corresponds to. Valid values are 0-6, where 0 is Monday and 6 is Sunday. Returns ------- None Raises ------ AssertionError If any of the random variables are not of the correct type, or if the day of the week is not within the valid range. """ assert isinstance(infection_to_admission_interval_rv, RandomVariable) assert isinstance(infection_hospitalization_ratio_rv, RandomVariable) assert isinstance(day_of_week_effect_rv, RandomVariable) assert isinstance(hospitalization_reporting_ratio_rv, RandomVariable) assert isinstance(obs_data_first_day_of_the_week, int) assert 0 <= obs_data_first_day_of_the_week <= 6 return None
[docs] def sample( self, latent_infections: ArrayLike, **kwargs, ) -> HospitalAdmissionsSample: """ Samples from the observation process Parameters ---------- latent_infections : ArrayLike Latent infections. **kwargs : dict, optional Additional keyword arguments passed through to internal :meth:`sample() <pyrenew.metaclass.RandomVariable.sample>` calls, should there be any. Returns ------- HospitalAdmissionsSample """ infection_hosp_rate = self.infection_hospitalization_ratio_rv(**kwargs) infection_to_admission_interval = ( self.infection_to_admission_interval_rv(**kwargs) ) latent_hospital_admissions = compute_delay_ascertained_incidence( latent_infections, infection_to_admission_interval, infection_hosp_rate, ) # Applying the day of the week effect. For this we need to: # 1. Get the day of the week effect # 2. Identify the offset of the latent_infections # 3. Apply the day of the week effect to the # latent_hospital_admissions dow_effect_sampled = self.day_of_week_effect_rv(**kwargs) if dow_effect_sampled.size != 7: raise ValueError( "Day of the week effect should have 7 values. " f"Got {dow_effect_sampled.size} instead." ) inf_offset = self.obs_data_first_day_of_the_week % 7 # Replicating the day of the week effect to match the number of # timepoints dow_effect = au.tile_until_n( data=dow_effect_sampled, n_timepoints=latent_hospital_admissions.size, offset=inf_offset, ) latent_hospital_admissions = latent_hospital_admissions * dow_effect # Applying reporting probability latent_hospital_admissions = ( latent_hospital_admissions * self.hospitalization_reporting_ratio_rv(**kwargs) ) numpyro.deterministic( "latent_hospital_admissions", latent_hospital_admissions ) return HospitalAdmissionsSample( infection_hosp_rate=infection_hosp_rate, latent_hospital_admissions=latent_hospital_admissions, multiplier=dow_effect, )