========================================
Fitting a Hospital Admissions-only Model
========================================

This document illustrates how a hospital admissions-only model can be
fitted using data from the Pyrenew package, particularly the wastewater
dataset. The CFA wastewater team created this dataset, which contains
simulated data.

Model definition
================

In this section, we provide the formal definition of the model. The
hospitalization model is a semi-mechanistic model that describes the
number of observed hospital admissions as a function of a set of latent
variables. Mainly, the observed number of hospital admissions is
discretely distributed with location at the number of latent hospital
admissions:

.. math::

   h(t) \sim \text{HospDist}\left(H(t)\right)

Where :math:`h(t)` is the observed number of hospital admissions at time
:math:`t`, and :math:`H(t)` is the number of latent hospital admissions
at time :math:`t`.

The distribution :math:`\text{HospDist}` is discrete. For this example,
we will use a negative binomial distribution:

.. math::

   \begin{align*}
   h(t) & \sim \text{NegativeBinomial}\left(\text{concentration} = 1, \text{mean} = H(t)\right) \\
   H(t) & = \omega(t) p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau)
   \end{align*}

Where :math:`d(\tau)` is the infection to hospitalization interval,
:math:`I(t)` is the number of latent infections at time :math:`t`,
:math:`p_\mathrm{hosp}(t)` is the infection to hospitalization rate, and
:math:`\omega(t)` is the day-of-the-week effect at time :math:`t`; the
last section provides an example building such a ``RandomVariable``.

The number of latent hospital admissions at time :math:`t` is a function
of the number of latent infections at time :math:`t` and the infection
to hospitalization rate. The latent infections are modeled as a renewal
process:

.. math::

   \begin{align*}
   I(t) &= R(t) \times \sum_{\tau < t} I(\tau) g(t - \tau) \\
   I(0) &\sim \text{LogNormal}(\mu = \log(80/0.05), \sigma = 1.5)
   \end{align*}

The reproductive number :math:`R(t)` is modeled as a random walk
process:

.. math::

   \begin{align*}
   R(t) & = R(t-1) + \epsilon\\
   \log{\epsilon} & \sim \text{Normal}(\mu=0, \sigma=0.1) \\
   R(0) &\sim \text{TruncatedNormal}(\text{loc}=1.2, \text{scale}=0.2, \text{min}=0)
   \end{align*}

Data processing
===============

We start by loading the data and inspecting the first five rows.

.. container:: cell

   .. code:: python

      import polars as pl
      from pyrenew import datasets

      dat = datasets.load_wastewater()
      dat.head(5)

   .. container:: cell-output cell-output-display
      :name: data-inspect

      shape: (5, 14)

      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============
      t   lab_wwtp_unique_id log_conc date       lod_sewage below_lod daily_hosp_admits daily_hosp_admits_for_eval pop forecast_date hosp_calibration_time site ww_pop   inf_per_capita
      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============
      i64 i64                f64      date       f64        i64       i64               i64                        f64 date          i64                   i64  f64      f64
      1   1                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    1    400000.0 0.000663
      1   2                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    1    400000.0 0.000663
      1   3                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    2    200000.0 0.000663
      1   4                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    3    100000.0 0.000663
      1   5                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    4    50000.0  0.000663
      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============
The data contain one row per site but, because of how they were
simulated, the number of admissions is the same across sites on any
given day. Thus, we keep only the first observation per day.

.. container:: cell

   .. code:: python

      # Keeping the first observation of each date
      dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"])

      # Now, sorting by date
      dat = dat.sort("date")

      # Keeping the first 90 days
      dat = dat.head(90)

      dat.head(5)

   .. container:: cell-output cell-output-display
      :name: aggregation

      shape: (5, 2)

      ========== =================
      date       daily_hosp_admits
      ========== =================
      date       i64
      2023-10-30 6
      2023-10-31 8
      2023-11-01 4
      2023-11-02 8
      2023-11-03 4
      ========== =================
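Dropping all but the first row per date is only safe if the sites really
do agree on the admissions count. A minimal sketch of that check in plain
Python follows. It uses hypothetical toy rows rather than the real
dataset, and the helper name ``dates_with_conflicting_admits`` is made up
for illustration; the same idea could be expressed with a grouped
aggregation on the data frame.

```python
from collections import defaultdict

# Hypothetical toy rows: (date, site, daily_hosp_admits). Not the real dataset.
rows = [
    ("2023-10-30", 1, 6),
    ("2023-10-30", 2, 6),
    ("2023-10-31", 1, 8),
    ("2023-10-31", 2, 8),
]


def dates_with_conflicting_admits(rows):
    """Return the sorted dates whose rows disagree on the admissions count."""
    values_by_date = defaultdict(set)
    for date, _site, admits in rows:
        values_by_date[date].add(admits)
    # A date is a conflict if more than one distinct count was recorded for it
    return sorted(d for d, vals in values_by_date.items() if len(vals) > 1)


conflicts = dates_with_conflicting_admits(rows)
print(conflicts)
```

An empty result means that keeping the first observation per date loses
no information; any date listed would need a closer look before
aggregating.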
Let’s take a look at the daily number of hospital admissions.

.. container:: cell

   .. code:: python

      import matplotlib.pyplot as plt

      # Rotating the x-axis labels, and only showing ~10 labels
      ax = plt.gca()
      ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
      ax.xaxis.set_tick_params(rotation=45)
      plt.plot(dat["date"].to_numpy(), dat["daily_hosp_admits"].to_numpy())
      plt.xlabel("Date")
      plt.ylabel("Admissions")
      plt.show()

   .. container:: cell-output cell-output-display

      |image1|

Building the model
==================

First, we will extract two datasets we will use as deterministic
quantities: the generation interval and the infection to hospitalization
interval.

.. container:: cell

   .. code:: python

      gen_int = datasets.load_generation_interval()
      inf_hosp_int = datasets.load_infection_admission_interval()

      # We only need the probability_mass column of each dataset
      gen_int_array = gen_int["probability_mass"].to_numpy()
      gen_int = gen_int_array
      inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()

      # Taking a peek at the first 5 elements of each
      gen_int[:5], inf_hosp_int[:5]

      # Visualizing both quantities side by side
      fig, axs = plt.subplots(1, 2)

      axs[0].plot(gen_int)
      axs[0].set_title("Generation interval")
      axs[1].plot(inf_hosp_int)
      axs[1].set_title("Infection to hospitalization interval")
      plt.show()

   .. container:: cell-output cell-output-display

      |image2|

With these two in hand, we can start building the model. First, we will
define the latent hospital admissions:

.. container:: cell
   :name: latent-hosp

   .. code:: python

      from pyrenew import latent, deterministic, metaclass
      import jax.numpy as jnp
      import numpyro.distributions as dist

      inf_hosp_int = deterministic.DeterministicPMF(
          inf_hosp_int, name="inf_hosp_int"
      )

      hosp_rate = metaclass.DistributionalRV(
          dist=dist.LogNormal(jnp.log(0.05), 0.1),
          name="IHR",
      )

      latent_hosp = latent.HospitalAdmissions(
          infection_to_admission_interval_rv=inf_hosp_int,
          infect_hosp_rate_rv=hosp_rate,
      )

   .. container:: cell-output cell-output-stderr

      ::

         /home/runner/.cache/pypoetry/virtualenvs/pyrenew-ay_vsbmF-py3.12/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
           from .autonotebook import tqdm as notebook_tqdm

The ``inf_hosp_int`` is a ``DeterministicPMF`` object that takes the
infection to hospitalization interval as input. The ``hosp_rate`` is a
``DistributionalRV`` object that takes a numpyro distribution to
represent the infection to hospitalization rate. The
``HospitalAdmissions`` class is a ``RandomVariable`` that takes two
random variables as inputs: the infection to admission interval and the
infection to hospitalization rate. Now, we can define the remaining
components:

.. container:: cell
   :name: initializing-rest-of-model

   .. code:: python

      from pyrenew import model, process, observation, metaclass, transformation
      from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponential

      # Infection process
      latent_inf = latent.Infections()
      I0 = InfectionSeedingProcess(
          "I0_seeding",
          metaclass.DistributionalRV(
              dist=dist.LogNormal(loc=jnp.log(100), scale=0.5), name="I0"
          ),
          SeedInfectionsExponential(
              gen_int_array.size,
              deterministic.DeterministicVariable(0.5, name="rate"),
          ),
          t_unit=1,
      )

      # Generation interval and Rt
      gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int")
      rtproc = process.RtRandomWalkProcess(
          Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
          Rt_transform=transformation.ExpTransform().inv,
          Rt_rw_dist=dist.Normal(0, 0.025),
      )

      # The observation model
      obs = observation.NegativeBinomialObservation(concentration_prior=1.0)

Notice all the components are ``RandomVariable`` instances. We can now
build the model:

.. container:: cell
   :name: init-model

   .. code:: python

      hosp_model = model.HospitalAdmissionsModel(
          latent_infections_rv=latent_inf,
          latent_hosp_admissions_rv=latent_hosp,
          I0_rv=I0,
          gen_int_rv=gen_int,
          Rt_process_rv=rtproc,
          hosp_admission_obs_process_rv=obs,
      )

Let’s simulate from the model to check that it works:

.. container:: cell
   :name: simulation

   .. code:: python

      import numpyro as npro
      import numpy as np

      timeframe = 120

      np.random.seed(223)
      with npro.handlers.seed(rng_seed=np.random.randint(1, timeframe)):
          sim_data = hosp_model.sample(n_timepoints_to_simulate=timeframe)

   .. container:: cell-output cell-output-stderr

      ::

         /home/runner/.cache/pypoetry/virtualenvs/pyrenew-ay_vsbmF-py3.12/lib/python3.12/site-packages/jax/_src/numpy/lax_numpy.py:3044: FutureWarning: None encountered in jnp.array(); this is currently treated as NaN. In the future this will result in an error.
           return array(arys[0], copy=False, ndmin=1)

.. container:: cell

   .. code:: python

      import matplotlib.pyplot as plt

      fig, axs = plt.subplots(1, 2)

      # Rt plot
      axs[0].plot(sim_data.Rt)
      axs[0].set_ylabel("Rt")

      # Observed hospital admissions plot
      axs[1].plot(sim_data.observed_hosp_admissions)
      axs[1].set_ylabel("Hospital admissions")
      axs[1].set_yscale("log")

      fig.suptitle("Basic renewal model")
      fig.supxlabel("Time")
      plt.tight_layout()
      plt.show()

   .. container:: cell-output cell-output-display

      |image3|

Fitting the model
=================

We can fit the model to the data. We will use the ``run`` method of the
model object:

.. container:: cell
   :name: model-fit

   .. code:: python

      import jax

      hosp_model.run(
          num_samples=2000,
          num_warmup=2000,
          data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
          rng_key=jax.random.PRNGKey(54),
          mcmc_args=dict(progress_bar=False, num_chains=2),
      )

   .. container:: cell-output cell-output-stderr

      ::

         /home/runner/work/multisignal-epi-inference/multisignal-epi-inference/model/src/pyrenew/metaclass.py:373: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
           self.mcmc = MCMC(

We can use the ``plot_posterior`` method to visualize the results [1]_:

.. container:: cell

   .. code:: python

      out = hosp_model.plot_posterior(
          var="latent_hospital_admissions",
          ylab="Hospital Admissions",
          obs_signal=np.pad(
              dat["daily_hosp_admits"].to_numpy().astype(float),
              (gen_int_array.size, 0),
              constant_values=np.nan,
          ),
      )

   .. container:: cell-output cell-output-display

      |image4|

The fit over the first half of the time series looks poor. The reason is
that the infection to hospitalization interval PMF makes it unlikely to
observe admissions from the beginning. The following section shows how
to fix this.

Padding the model
=================

We can use the ``padding`` argument to solve the overestimation of
hospital admissions in the first half of the model. By setting
``padding > 0``, the model assumes that the first ``padding``
observations are missing; thus, only observations after ``padding``
count towards the likelihood of the model. In practice, the model
extends the estimated Rt and latent infections by ``padding`` days,
giving it time to adjust to the observed data. The following code adds
21 days of missing data at the beginning of the series and re-estimates
the model with ``padding = 21``:

.. container:: cell
   :name: model-fit-padding

   .. code:: python

      days_to_impute = 21

      # Prepend 21 NaNs to the observed admissions
      dat_w_padding = np.pad(
          dat["daily_hosp_admits"].to_numpy().astype(float),
          (days_to_impute, 0),
          constant_values=np.nan,
      )

      hosp_model.run(
          num_samples=2000,
          num_warmup=2000,
          data_observed_hosp_admissions=dat_w_padding,
          rng_key=jax.random.PRNGKey(54),
          mcmc_args=dict(progress_bar=False, num_chains=2),
          padding=days_to_impute,  # Padding the model
      )

And plotting the results:

.. container:: cell

   .. code:: python

      out = hosp_model.plot_posterior(
          var="latent_hospital_admissions",
          ylab="Hospital Admissions",
          obs_signal=np.pad(
              dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan
          ),
      )

   .. container:: cell-output cell-output-display

      |image5|

We can use `ArviZ `__ to visualize the results. Let’s start by
converting the fitted model to an ArviZ ``InferenceData`` object:

.. container:: cell
   :name: convert-inferencedata

   .. code:: python

      import arviz as az

      idata = az.from_numpyro(hosp_model.mcmc)

We obtain the summary of model diagnostics and print the diagnostics for
``latent_hospital_admissions[1]``:

.. container:: cell
   :name: diagnostics

   .. code:: python

      diagnostic_stats_summary = az.summary(
          idata.posterior,
          kind="diagnostics",
      )

      print(diagnostic_stats_summary.loc["latent_hospital_admissions[1]"])

   .. container:: cell-output cell-output-stdout

      ::

         mcse_mean       0.0
         mcse_sd         0.0
         ess_bulk     5000.0
         ess_tail     2788.0
         r_hat           1.0
         Name: latent_hospital_admissions[1], dtype: float64

Below we plot 90% and 50% highest density intervals for latent hospital
admissions using `plot_hdi `__:

.. container:: cell

   .. code:: python

      x_data = idata.posterior["latent_hospital_admissions_dim_0"]
      y_data = idata.posterior["latent_hospital_admissions"]

      fig, axes = plt.subplots(figsize=(6, 5))
      az.plot_hdi(
          x_data,
          y_data,
          hdi_prob=0.9,
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.3},
          ax=axes,
      )

      az.plot_hdi(
          x_data,
          y_data,
          hdi_prob=0.5,
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.6},
          ax=axes,
      )

      # Add mean of the posterior to the figure
      mean_latent_hosp_admission = np.mean(
          idata.posterior["latent_hospital_admissions"], axis=1
      )
      axes.plot(x_data, mean_latent_hosp_admission[0], color="C0", label="Mean")
      axes.legend()
      axes.set_title("Posterior Hospital Admissions", fontsize=10)
      axes.set_xlabel("Time", fontsize=10)
      axes.set_ylabel("Hospital Admissions", fontsize=10);

   .. container:: cell-output cell-output-display

      |image6|

We can also take a look at the latent infections:

.. container:: cell

   .. code:: python

      out2 = hosp_model.plot_posterior(
          var="all_latent_infections", ylab="Latent Infections"
      )

   .. container:: cell-output cell-output-display

      |image7|

and the distribution of latent infections:

.. container:: cell

   .. code:: python

      x_data = idata.posterior["all_latent_infections_dim_0"]
      y_data = idata.posterior["all_latent_infections"]

      fig, axes = plt.subplots(figsize=(6, 5))
      az.plot_hdi(
          x_data,
          y_data,
          hdi_prob=0.9,
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.3},
          ax=axes,
      )

      az.plot_hdi(
          x_data,
          y_data,
          hdi_prob=0.5,
          color="C0",
          smooth=False,
          fill_kwargs={"alpha": 0.6},
          ax=axes,
      )

      # Add mean of the posterior to the figure
      mean_latent_infection = np.mean(
          idata.posterior["all_latent_infections"], axis=1
      )
      axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
      axes.legend()
      axes.set_title("Posterior Latent Infections", fontsize=10)
      axes.set_xlabel("Time", fontsize=10)
      axes.set_ylabel("Latent Infections", fontsize=10);

   .. container:: cell-output cell-output-display

      |image8|

Round 2: Incorporating day-of-the-week effects
==============================================

We will re-use the infection to admission interval and infection to
hospitalization rate from the previous model, and add a day-of-the-week
effect. To do this, we will create a new ``RandomVariable`` subclass to
model the effect. The class will be based on a truncated normal
distribution with a mean of 1.0 and a standard deviation of 0.5,
truncated between 0.1 and 10.0. The seven sampled weekday effects will
be repeated for the number of weeks in the dataset. Note a similar
weekday effect is implemented in its own module, with example code
`here `__.

.. container:: cell
   :name: weekly-effect

   .. code:: python

      from pyrenew import metaclass
      import numpyro as npro


      class DayOfWeekEffect(metaclass.RandomVariable):
          """Day of the week effect"""

          def __init__(self, len: int):
              """Initialize the day of the week effect distribution

              Parameters
              ----------
              len : int
                  The number of observations
              """
              self.nweeks = int(jnp.ceil(len / 7))
              self.len = len

          @staticmethod
          def validate():
              return None

          def sample(self, **kwargs):
              ans = npro.sample(
                  name="dayofweek_effect",
                  fn=npro.distributions.TruncatedNormal(
                      loc=1.0, scale=0.5, low=0.1, high=10.0
                  ),
                  sample_shape=(7,),
              )

              return jnp.tile(ans, self.nweeks)[: self.len]


      # Initializing the RV
      dayofweek_effect = DayOfWeekEffect(dat.shape[0])

Notice that the instance’s ``nweeks`` and ``len`` members are set during
construction. Trying to compute the number of weeks and the length of
the dataset in the ``validate`` method will raise a ``jit`` error in
``jax`` as the shape and size of elements are not known during the
validation step, which happens before the model is run. With the new
effect, we can rebuild the latent hospitalization model:

.. container:: cell
   :name: latent-hosp-weekday

   .. code:: python

      latent_hosp_wday_effect = latent.HospitalAdmissions(
          infection_to_admission_interval_rv=inf_hosp_int,
          infect_hosp_rate_rv=hosp_rate,
          day_of_week_effect_rv=dayofweek_effect,
      )

      hosp_model_weekday = model.HospitalAdmissionsModel(
          latent_infections_rv=latent_inf,
          latent_hosp_admissions_rv=latent_hosp_wday_effect,
          I0_rv=I0,
          gen_int_rv=gen_int,
          Rt_process_rv=rtproc,
          hosp_admission_obs_process_rv=obs,
      )

Running the model (with the same padding as before):

.. container:: cell
   :name: model-2-run

   .. code:: python

      hosp_model_weekday.run(
          num_samples=2000,
          num_warmup=2000,
          data_observed_hosp_admissions=dat_w_padding,
          rng_key=jax.random.PRNGKey(54),
          mcmc_args=dict(progress_bar=False),
          padding=days_to_impute,
      )

And plotting the results:

.. container:: cell

   .. code:: python

      out = hosp_model_weekday.plot_posterior(
          var="latent_hospital_admissions",
          ylab="Hospital Admissions",
          obs_signal=np.pad(
              dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan
          ),
      )

   .. container:: cell-output cell-output-display

      |image9|

.. [1]
   The output is captured to prevent ``quarto`` from displaying it
   twice.

.. |image1| image:: example_with_datasets_files/figure-rst/fig-plot-hospital-admissions-output-1.png
.. |image2| image:: example_with_datasets_files/figure-rst/fig-data-extract-output-1.png
.. |image3| image:: example_with_datasets_files/figure-rst/fig-basic-output-1.png
.. |image4| image:: example_with_datasets_files/figure-rst/fig-output-hospital-admissions-output-1.png
.. |image5| image:: example_with_datasets_files/figure-rst/fig-output-admissions-with-padding-output-1.png
.. |image6| image:: example_with_datasets_files/figure-rst/fig-output-admission-distribution-output-1.png
.. |image7| image:: example_with_datasets_files/figure-rst/fig-output-infections-with-padding-output-1.png
.. |image8| image:: example_with_datasets_files/figure-rst/fig-output-infections-distribution-output-1.png
.. |image9| image:: example_with_datasets_files/figure-rst/fig-output-admissions-padding-and-weekday-output-1.png
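As a closing illustration, the deterministic part of the model definition
(the renewal equation for :math:`I(t)` and the delayed, weekday-scaled
admissions :math:`H(t)`) can be written out in a few lines of NumPy. This
is a sketch under stated assumptions, not how Pyrenew implements it: the
function names ``renewal_infections`` and ``latent_admissions`` and every
number below are hypothetical, the infection to hospitalization rate is
held constant, and nothing is sampled.

```python
import numpy as np


def renewal_infections(i0, rt, gen_int):
    """I(t) = R(t) * sum_{tau < t} I(tau) g(t - tau).

    i0 holds the seed infections (oldest first), one per generation-interval
    lag; gen_int[k] is the probability of a generation interval of k + 1 days.
    """
    n_seed = len(gen_int)
    infections = np.concatenate([i0, np.zeros(len(rt))])
    for t, r in enumerate(rt):
        # Most recent infection first, so that it pairs with gen_int[0] (lag 1)
        recent = infections[t : t + n_seed][::-1]
        infections[t + n_seed] = r * np.dot(recent, gen_int)
    # Drop the seed period and return only the simulated days
    return infections[n_seed:]


def latent_admissions(infections, ihr, inf_hosp_int, dow_effect):
    """H(t) = omega(t) * p_hosp * sum_tau d(tau) I(t - tau), with constant p_hosp."""
    n = len(infections)
    # Full convolution truncated to the first n days
    delayed = np.convolve(infections, inf_hosp_int)[:n]
    # Repeat the 7 weekday multipliers and trim to n days
    omega = np.tile(dow_effect, int(np.ceil(n / 7)))[:n]
    return omega * ihr * delayed


# Hypothetical toy inputs, chosen only so the result is easy to reason about
toy_gen_int = np.array([0.5, 0.3, 0.2])
toy_inf_hosp_int = np.array([0.0, 0.25, 0.5, 0.25])
toy_rt = np.ones(28)
toy_dow = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 0.5, 0.5])

toy_infections = renewal_infections(np.full(3, 100.0), toy_rt, toy_gen_int)
toy_admissions = latent_admissions(
    toy_infections, ihr=0.05, inf_hosp_int=toy_inf_hosp_int, dow_effect=toy_dow
)
print(toy_admissions[:10])
```

In this toy setup the first few days sit below the later level because
the infections that would have produced those admissions lie before the
start of the series, which is the kind of start-up effect the padding
section is meant to absorb.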