Source code for pyrenew.latent.hospitaladmissions

# -*- coding: utf-8 -*-
# numpydoc ignore=GL08

from __future__ import annotations

from typing import Any, NamedTuple

import jax.numpy as jnp
import numpyro as npro
from jax.typing import ArrayLike
from pyrenew.deterministic import DeterministicVariable
from pyrenew.metaclass import RandomVariable


class HospitalAdmissionsSample(NamedTuple):
    """
    Result bundle returned by `HospitalAdmissions.sample()`.

    Attributes
    ----------
    infection_hosp_rate : float, optional
        The infection-to-hospitalization rate. Defaults to None.
    latent_hospital_admissions : ArrayLike or None
        The computed number of hospital admissions. Defaults to None.
    """

    infection_hosp_rate: float | None = None
    latent_hospital_admissions: ArrayLike | None = None

    def __repr__(self) -> str:
        # Fields are interpolated with default f-string formatting rather
        # than repr(), so e.g. string values are rendered without quotes.
        rate, admissions = self
        return (
            "HospitalAdmissionsSample("
            f"infection_hosp_rate={rate}, "
            f"latent_hospital_admissions={admissions})"
        )
class HospitalAdmissions(RandomVariable):
    r"""
    Latent hospital admissions

    Implements a renewal process for the expected number of hospital
    admissions.

    Notes
    -----
    The following text was directly extracted from the wastewater model
    documentation (`link <https://github.com/cdcent/cfa-forecast-renewal-ww/blob/a17efc090b2ffbc7bc11bdd9eec5198d6bcf7322/model_definition.md#hospital-admissions-component>`_).

    Following other semi-mechanistic renewal frameworks, we model the
    *expected* hospital admissions per capita :math:`H(t)` as a convolution
    of the *expected* latent incident infections per capita :math:`I(t)`,
    and a discrete infection to hospitalization distribution
    :math:`d(\tau)`, scaled by the probability of being hospitalized
    :math:`p_\mathrm{hosp}(t)`.

    To account for day-of-week effects in hospital reporting, we use an
    estimated *day of the week effect* :math:`\omega(t)`. If :math:`t` and
    :math:`t'` are the same day of the week, :math:`\omega(t) = \omega(t')`.
    The seven values that :math:`\omega(t)` takes on are constrained to have
    mean 1.

    .. math::

        H(t) = \omega(t) p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau)

    Where :math:`T_d` is the maximum delay from infection to hospitalization
    that we consider.

    In this implementation the result is additionally multiplied by the
    sampled hospital admission reporting probability (see `sample`).
    """

    def __init__(
        self,
        infection_to_admission_interval_rv: RandomVariable,
        infect_hosp_rate_rv: RandomVariable,
        latent_hospital_admissions_varname: str = "latent_hospital_admissions",
        day_of_week_effect_rv: RandomVariable | None = None,
        hosp_report_prob_rv: RandomVariable | None = None,
    ) -> None:
        """
        Default constructor

        Parameters
        ----------
        infection_to_admission_interval_rv : RandomVariable
            Infection-to-admission interval (pmf). Its sampled value is the
            kernel convolved with the rate-scaled latent infections in
            `sample`.
        infect_hosp_rate_rv : RandomVariable
            Infection to hospitalization rate random variable.
        latent_hospital_admissions_varname : str
            Name under which the latent hospital admissions are recorded
            through `numpyro.deterministic`. Defaults to
            "latent_hospital_admissions".
        day_of_week_effect_rv : RandomVariable, optional
            Day of the week effect. Defaults to a deterministic value of 1
            (no effect).
        hosp_report_prob_rv : RandomVariable, optional
            Random variable for the hospital admission reporting
            probability. Defaults to 1 (full reporting).

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If any of the four random-variable arguments (after defaults
            are filled in) is not a `RandomVariable` instance.
        """
        if day_of_week_effect_rv is None:
            day_of_week_effect_rv = DeterministicVariable(1, "weekday_effect")
        if hosp_report_prob_rv is None:
            hosp_report_prob_rv = DeterministicVariable(1, "hosp_report_prob")

        HospitalAdmissions.validate(
            infect_hosp_rate_rv,
            day_of_week_effect_rv,
            hosp_report_prob_rv,
        )
        # `validate` keeps its three-argument signature for existing callers,
        # so the interval variable is checked separately here. This fails at
        # construction time rather than deep inside `sample`.
        HospitalAdmissions._require_random_variable(
            infection_to_admission_interval_rv,
            "infection_to_admission_interval_rv",
        )

        self.latent_hospital_admissions_varname = (
            latent_hospital_admissions_varname
        )
        self.infect_hosp_rate_rv = infect_hosp_rate_rv
        self.day_of_week_effect_rv = day_of_week_effect_rv
        self.hosp_report_prob_rv = hosp_report_prob_rv
        self.infection_to_admission_interval_rv = (
            infection_to_admission_interval_rv
        )

    @staticmethod
    def _require_random_variable(candidate: Any, name: str) -> None:
        """
        Raise if `candidate` is not a `RandomVariable`.

        Parameters
        ----------
        candidate : Any
            Object to check.
        name : str
            Argument name used in the error message.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If `candidate` is not an instance of `RandomVariable`.
        """
        # Raised explicitly (instead of using an `assert` statement) so the
        # check is not stripped when Python runs with -O; the exception type
        # is kept as AssertionError for backward compatibility.
        if not isinstance(candidate, RandomVariable):
            raise AssertionError(
                f"{name} must be a RandomVariable, "
                f"got {type(candidate).__name__}"
            )

    @staticmethod
    def validate(
        infect_hosp_rate_rv: Any,
        day_of_week_effect_rv: Any,
        hosp_report_prob_rv: Any,
    ) -> None:
        """
        Validates that the IHR, weekday effects, and probability of being
        reported hospitalized distributions are RandomVariable types

        Parameters
        ----------
        infect_hosp_rate_rv : Any
            Possibly incorrect input for infection to hospitalization rate
            distribution.
        day_of_week_effect_rv : Any
            Possibly incorrect input for day of the week effect.
        hosp_report_prob_rv : Any
            Possibly incorrect input for distribution or fixed value for the
            hospital admission reporting probability.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If any argument is not an instance of `RandomVariable`.
        """
        candidates = {
            "infect_hosp_rate_rv": infect_hosp_rate_rv,
            "day_of_week_effect_rv": day_of_week_effect_rv,
            "hosp_report_prob_rv": hosp_report_prob_rv,
        }
        for name, candidate in candidates.items():
            HospitalAdmissions._require_random_variable(candidate, name)

    def sample(
        self,
        latent_infections: ArrayLike,
        **kwargs,
    ) -> HospitalAdmissionsSample:
        """
        Samples the latent hospital admissions

        The result is also recorded with `numpyro.deterministic` under
        `self.latent_hospital_admissions_varname`.

        Parameters
        ----------
        latent_infections : ArrayLike
            Latent infections.
        **kwargs : dict, optional
            Additional keyword arguments passed through to internal
            `sample()` calls, should there be any.

        Returns
        -------
        HospitalAdmissionsSample
            The sampled infection-to-hospitalization rate together with the
            computed latent hospital admissions.
        """
        infection_hosp_rate, *_ = self.infect_hosp_rate_rv.sample(**kwargs)

        # Infections scaled by the infection-to-hospitalization rate
        infection_hosp_rate_t = infection_hosp_rate * latent_infections

        (
            infection_to_admission_interval,
            *_,
        ) = self.infection_to_admission_interval_rv.sample(**kwargs)

        # A "full" convolution is longer than its first input; keep only the
        # leading entries so the output has the same length as the
        # rate-scaled infections.
        latent_hospital_admissions = jnp.convolve(
            infection_hosp_rate_t,
            infection_to_admission_interval,
            mode="full",
        )[: infection_hosp_rate_t.shape[0]]

        # Applying the day of the week effect
        latent_hospital_admissions = (
            latent_hospital_admissions
            * self.day_of_week_effect_rv.sample(**kwargs)[0]
        )

        # Applying the hospital admission reporting probability
        latent_hospital_admissions = (
            latent_hospital_admissions
            * self.hosp_report_prob_rv.sample(**kwargs)[0]
        )

        npro.deterministic(
            self.latent_hospital_admissions_varname, latent_hospital_admissions
        )

        return HospitalAdmissionsSample(
            infection_hosp_rate, latent_hospital_admissions
        )