========================================
Fitting a Hospital Admissions-only Model
========================================
This document illustrates how to fit a hospital admissions-only model
using data from the Pyrenew package, specifically the wastewater
dataset. The CFA wastewater team created this dataset, which contains
simulated data.
Model definition
================
In this section, we provide the formal definition of the model. The
hospitalization model is a semi-mechanistic model that describes the
number of observed hospital admissions as a function of a set of latent
variables. Specifically, the observed number of hospital admissions
follows a discrete distribution centered on the number of latent
hospital admissions:
.. math::
h(t) \sim \text{HospDist}\left(H(t)\right)
where :math:`h(t)` is the observed number of hospital admissions at time
:math:`t`, :math:`H(t)` is the number of latent hospital admissions at
time :math:`t`, and :math:`\text{HospDist}` is a discrete distribution.
For this example, we use a negative binomial distribution:
.. math::
\begin{align*}
h(t) & \sim \text{NegativeBinomial}\left(\text{concentration} = 1, \text{mean} = H(t)\right) \\
H(t) & = \omega(t) p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau)
\end{align*}
where :math:`d(\tau)` is the infection to hospitalization interval (a
PMF over delays of up to :math:`T_d` days), :math:`I(t)` is the number
of latent infections at time :math:`t`, :math:`p_\mathrm{hosp}(t)` is
the infection to hospitalization rate, and :math:`\omega(t)` is the
day-of-the-week effect at time :math:`t`. The last section of this
document provides an example of building such a day-of-the-week
``RandomVariable``.
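
The two equations above can be sketched in plain NumPy to make the
convolution explicit. This is a minimal illustration under stated
assumptions, not Pyrenew's implementation: ``latent_admissions`` and
``sample_negbin`` are hypothetical helpers, the hospitalization rate is
treated as a single scalar, infections before the first time point are
treated as zero, and the negative binomial is mapped onto NumPy's
``(n, p)`` parameterization with ``n = concentration`` and
``p = concentration / (concentration + mean)``, which keeps the mean at
:math:`H(t)`.

.. code:: python

   import numpy as np


   def latent_admissions(infections, delay_pmf, ihr, dow_effect):
       """Hypothetical helper: H(t) = omega(t) * p_hosp * sum_tau d(tau) I(t - tau).

       Terms with t - tau < 0 are dropped, i.e. earlier infections are zero.
       """
       n = len(infections)
       # np.convolve(I, d)[t] = sum_tau d(tau) * I(t - tau); keep t = 0..n-1
       convolved = np.convolve(infections, delay_pmf)[:n]
       return dow_effect * ihr * convolved


   def sample_negbin(rng, mean, concentration):
       """Hypothetical helper: NegativeBinomial(concentration, mean) draws.

       NumPy's negative_binomial(n, p) has mean n * (1 - p) / p, so setting
       n = concentration and p = concentration / (concentration + mean)
       yields the requested mean.
       """
       p = concentration / (concentration + mean)
       return rng.negative_binomial(concentration, p)


   # Toy inputs (made up for illustration): flat infections, 3-day delay PMF
   H = latent_admissions(
       infections=np.full(10, 100.0),
       delay_pmf=np.array([0.2, 0.5, 0.3]),
       ihr=0.05,
       dow_effect=np.ones(10),
   )
   h = sample_negbin(np.random.default_rng(0), H, concentration=1.0)

With these toy inputs, ``H`` starts at 1.0 and 3.5 and then stays at
5.0: the first days are lower because the delay window is not yet fully
covered by infections.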
The number of latent hospital admissions at time :math:`t` is thus a
function of the latent infections over the preceding :math:`T_d` days,
the infection to hospitalization interval, and the infection to
hospitalization rate. The latent infections are modeled as a renewal
process, where :math:`g(\cdot)` is the generation interval:
.. math::
\begin{align*}
I(t) &= R(t) \times \sum_{\tau < t} I(\tau) g(t - \tau) \\
   I(0) &\sim \text{LogNormal}(\mu = \log(100), \sigma = 0.5)
\end{align*}
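
A direct NumPy transcription of the renewal equation helps make the
indexing explicit. This is a sketch, not Pyrenew's
``latent.Infections``: the function name ``renewal`` is hypothetical,
``gen_pmf[k]`` is assumed to hold the generation interval
:math:`g(k + 1)` (no same-day transmission), and ``I_seed`` stands in
for the infections before the first modeled day.

.. code:: python

   import numpy as np


   def renewal(I_seed, Rt, gen_pmf):
       """Hypothetical helper: I(t) = R(t) * sum_{tau < t} I(tau) g(t - tau).

       Assumes gen_pmf[k] holds g(k + 1); returns only the new infections.
       """
       gen_pmf = np.asarray(gen_pmf, dtype=float)
       infections = [float(x) for x in I_seed]
       for r in Rt:
           # Most recent infection first, so recent[k] pairs with g(k + 1)
           recent = np.array(infections[::-1][: gen_pmf.size])
           infections.append(r * float(np.dot(recent, gen_pmf[: recent.size])))
       return np.array(infections[len(I_seed):])


   # With R(t) = 2 and all transmission at lag 1, infections double each day
   print(renewal([1.0], [2.0, 2.0, 2.0], [1.0]))  # [2. 4. 8.]

With :math:`R(t) = 1` and a PMF that sums to one, a flat seed stays
flat, which is a quick sanity check on the indexing.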
The reproductive number :math:`R(t)` is modeled as a random walk on the
log scale:
.. math::
\begin{align*}
   \log R(t) & = \log R(t-1) + \epsilon \\
   \epsilon & \sim \text{Normal}(\mu=0, \sigma=0.025) \\
R(0) &\sim \text{TruncatedNormal}(\text{loc}=1.2, \text{scale}=0.2, \text{min}=0)
\end{align*}
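
The model code later in this document passes
``Rt_transform=transformation.ExpTransform().inv`` and
``Rt_rw_dist=dist.Normal(0, 0.025)``, which we read as a random walk on
:math:`\log R(t)`. The NumPy sketch below simulates such a log-scale
walk. It is illustrative only: ``sample_rt_path`` and
``sample_truncated_normal`` are hypothetical helpers, and the truncated
normal uses simple rejection sampling rather than whatever numpyro does
internally. Because the increments are added on the log scale, every
sampled :math:`R(t)` stays positive.

.. code:: python

   import numpy as np


   def sample_truncated_normal(rng, loc, scale, low, high=np.inf):
       """Hypothetical helper: rejection-sample one draw from (low, high)."""
       while True:
           x = rng.normal(loc, scale)
           if low < x < high:
               return x


   def sample_rt_path(rng, n, rt0, sigma):
       """Hypothetical helper: log R(t) = log R(t-1) + eps, eps ~ Normal(0, sigma)."""
       steps = rng.normal(0.0, sigma, size=n - 1)
       # Cumulative sum of increments on the log scale, starting at log R(0)
       log_rt = np.log(rt0) + np.concatenate([[0.0], np.cumsum(steps)])
       return np.exp(log_rt)


   rng = np.random.default_rng(42)
   rt0 = sample_truncated_normal(rng, loc=1.2, scale=0.2, low=0.0)
   rt = sample_rt_path(rng, n=120, rt0=rt0, sigma=0.025)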
Data processing
===============
We start by loading the data and inspecting the first five rows.
.. container:: cell
.. code:: python
import polars as pl
from pyrenew import datasets
dat = datasets.load_wastewater()
dat.head(5)
.. container:: cell-output cell-output-display
:name: data-inspect
      shape: (5, 14)

      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============
      t   lab_wwtp_unique_id log_conc date       lod_sewage below_lod daily_hosp_admits daily_hosp_admits_for_eval pop forecast_date hosp_calibration_time site ww_pop   inf_per_capita
      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============
      i64 i64                f64      date       f64        i64       i64               i64                        f64 date          i64                   i64  f64      f64
      1   1                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    1    400000.0 0.000663
      1   2                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    1    400000.0 0.000663
      1   3                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    2    200000.0 0.000663
      1   4                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    3    100000.0 0.000663
      1   5                  null     2023-10-30 null       null      6                 6                          1e6 2024-02-05    90                    4    50000.0  0.000663
      === ================== ======== ========== ========== ========= ================= ========================== === ============= ===================== ==== ======== ==============
The data contain one entry per site for each day, but because of how
they were simulated, the number of admissions is the same across sites.
Thus, we keep only the first observation per day.
.. container:: cell
.. code:: python
# Keeping the first observation of each date
dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"])
# Now, sorting by date
dat = dat.sort("date")
# Keeping the first 90 days
dat = dat.head(90)
dat.head(5)
.. container:: cell-output cell-output-display
:name: aggregation
shape: (5, 2)
========== =================
date daily_hosp_admits
========== =================
date i64
2023-10-30 6
2023-10-31 8
2023-11-01 4
2023-11-02 8
2023-11-03 4
========== =================
Let’s plot the daily number of hospital admissions:
.. container:: cell
.. code:: python
import matplotlib.pyplot as plt
# Rotating the x-axis labels, and only showing ~10 labels
ax = plt.gca()
ax.xaxis.set_major_locator(plt.MaxNLocator(nbins=10))
ax.xaxis.set_tick_params(rotation=45)
plt.plot(dat["date"].to_numpy(), dat["daily_hosp_admits"].to_numpy())
plt.xlabel("Date")
plt.ylabel("Admissions")
plt.show()
.. container:: cell-output cell-output-display
|image1|
Building the model
==================
First, we load two datasets that we will treat as deterministic
quantities: the generation interval and the infection to hospitalization
interval.
.. container:: cell
.. code:: python
gen_int = datasets.load_generation_interval()
inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy()
      # Taking a peek at the first 5 elements of each
gen_int[:5], inf_hosp_int[:5]
# Visualizing both quantities side by side
fig, axs = plt.subplots(1, 2)
axs[0].plot(gen_int)
axs[0].set_title("Generation interval")
axs[1].plot(inf_hosp_int)
axs[1].set_title("Infection to hospitalization interval")
plt.show()
.. container:: cell-output cell-output-display
|image2|
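
Before using these arrays as deterministic PMFs, it can help to confirm
that each is a valid probability mass function and to compute its mean
delay, :math:`\sum_k k\,p(k)`. The helper below is a hypothetical
sketch; in particular, whether the first entry corresponds to a lag of 0
or 1 day is an assumption to check against the dataset documentation,
so it is exposed as the ``offset`` parameter.

.. code:: python

   import numpy as np


   def check_pmf(pmf, offset=0, atol=1e-6):
       """Hypothetical helper: validate a PMF and return its mean lag.

       offset is the lag of pmf[0] (0 or 1, depending on the convention).
       """
       pmf = np.asarray(pmf, dtype=float)
       if np.any(pmf < 0):
           raise ValueError("PMF has negative entries")
       if not np.isclose(pmf.sum(), 1.0, atol=atol):
           raise ValueError(f"PMF sums to {pmf.sum():.6f}, not 1")
       lags = offset + np.arange(pmf.size)
       return float(np.dot(lags, pmf))


   # Toy PMF (made up): lags 0, 1, 2 with mean lag 1
   print(check_pmf([0.25, 0.5, 0.25]))  # 1.0

In this document, the same check could be applied as
``check_pmf(gen_int_array)`` with whichever ``offset`` matches the
dataset's convention.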
With these two in hand, we can start building the model. First, we will
define the latent hospital admissions:
.. container:: cell
:name: latent-hosp
.. code:: python
from pyrenew import latent, deterministic, metaclass
import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
inf_hosp_int, name="inf_hosp_int"
)
hosp_rate = metaclass.DistributionalRV(
dist=dist.LogNormal(jnp.log(0.05), 0.1),
name="IHR",
)
latent_hosp = latent.HospitalAdmissions(
infection_to_admission_interval_rv=inf_hosp_int,
infect_hosp_rate_rv=hosp_rate,
)
The ``inf_hosp_int`` object is a ``DeterministicPMF`` that takes the
infection to hospitalization interval as input. The ``hosp_rate`` object
is a ``DistributionalRV`` that wraps a numpyro distribution representing
the infection to hospitalization rate. The ``HospitalAdmissions`` class
is a ``RandomVariable`` that takes two random variables as inputs: the
infection to admission interval and the infection to hospitalization
rate. Now, we can define the remaining components:
.. container:: cell
:name: initializing-rest-of-model
.. code:: python
from pyrenew import model, process, observation, metaclass, transformation
from pyrenew.latent import InfectionSeedingProcess, SeedInfectionsExponential
# Infection process
latent_inf = latent.Infections()
I0 = InfectionSeedingProcess(
"I0_seeding",
metaclass.DistributionalRV(
dist=dist.LogNormal(loc=jnp.log(100), scale=0.5), name="I0"
),
SeedInfectionsExponential(
gen_int_array.size,
deterministic.DeterministicVariable(0.5, name="rate"),
),
t_unit=1,
)
# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(gen_int, name="gen_int")
rtproc = process.RtRandomWalkProcess(
Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
Rt_transform=transformation.ExpTransform().inv,
Rt_rw_dist=dist.Normal(0, 0.025),
)
# The observation model
obs = observation.NegativeBinomialObservation(concentration_prior=1.0)
Notice that all the components are ``RandomVariable`` instances. We can
now build the model:
.. container:: cell
:name: init-model
.. code:: python
hosp_model = model.HospitalAdmissionsModel(
latent_infections_rv=latent_inf,
latent_hosp_admissions_rv=latent_hosp,
I0_rv=I0,
gen_int_rv=gen_int,
Rt_process_rv=rtproc,
hosp_admission_obs_process_rv=obs,
)
Let’s simulate from the model to check that it works:
.. container:: cell
:name: simulation
.. code:: python
import numpyro as npro
import numpy as np
timeframe = 120
np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, timeframe)):
sim_data = hosp_model.sample(n_timepoints_to_simulate=timeframe)
.. container:: cell
.. code:: python
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
      # Observed hospital admissions plot
      axs[1].plot(sim_data.observed_hosp_admissions)
      axs[1].set_ylabel("Observed hospital admissions")
axs[1].set_yscale("log")
fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()
.. container:: cell-output cell-output-display
|image3|
Fitting the model
=================
We can now fit the model to the data using the ``run`` method of the
model object:
.. container:: cell
:name: model-fit
.. code:: python
import jax
hosp_model.run(
num_samples=2000,
num_warmup=2000,
data_observed_hosp_admissions=dat["daily_hosp_admits"].to_numpy(),
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
We can use the ``plot_posterior`` method to visualize the results [1]_:
.. container:: cell
.. code:: python
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
dat["daily_hosp_admits"].to_numpy().astype(float),
(gen_int_array.size, 0),
constant_values=np.nan,
),
)
.. container:: cell-output cell-output-display
|image4|
The fit is poor over the first half of the series, where the model
overestimates hospital admissions. The reason is that the infection to
hospitalization interval PMF makes it unlikely to observe admissions at
the very beginning of the series. The following section shows how to
fix this.
Padding the model
=================
We can use the ``padding`` argument to fix the overestimation of
hospital admissions in the first half of the series. By setting
``padding > 0``, the model assumes that the first ``padding``
observations are missing; thus, only observations after the first
``padding`` count towards the likelihood. In practice, the model extends
the estimated Rt and latent infections by ``padding`` days, giving it
time to adjust to the observed data. The following code adds 21 days of
missing data at the beginning of the series and re-estimates the model
with ``padding = 21``:
.. container:: cell
:name: model-fit-padding
.. code:: python
days_to_impute = 21
      # Prepend days_to_impute (21) NaNs to the observed admissions
dat_w_padding = np.pad(
dat["daily_hosp_admits"].to_numpy().astype(float),
(days_to_impute, 0),
constant_values=np.nan,
)
hosp_model.run(
num_samples=2000,
num_warmup=2000,
data_observed_hosp_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
padding=days_to_impute, # Padding the model
)
And plotting the results:
.. container:: cell
.. code:: python
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan
),
)
.. container:: cell-output cell-output-display
|image5|
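
Conceptually, padding means that the first ``padding`` entries of the
observed series do not enter the likelihood, which is why it is safe to
fill them with ``NaN``. The sketch below shows that idea with a
hand-written negative binomial log-likelihood (mean/concentration
parameterization, as in the model definition). It is a simplified
illustration, not Pyrenew's implementation; ``negbin_logpmf`` and
``padded_loglik`` are hypothetical names.

.. code:: python

   import math

   import numpy as np


   def negbin_logpmf(y, mean, concentration):
       """Hypothetical helper: log NegativeBinomial(y | mean, concentration)."""
       c = concentration
       return (
           math.lgamma(y + c)
           - math.lgamma(c)
           - math.lgamma(y + 1)
           + c * math.log(c / (c + mean))
           + y * math.log(mean / (c + mean))
       )


   def padded_loglik(observed, mean, concentration, padding):
       """Hypothetical helper: sum log-likelihood terms after the first `padding`."""
       total = 0.0
       for y, mu in zip(observed[padding:], mean[padding:]):
           total += negbin_logpmf(y, mu, concentration)
       return total


   # Three padded days (NaN) followed by two real observations
   observed = np.array([np.nan, np.nan, np.nan, 0.0, 2.0])
   expected = np.ones_like(observed)
   print(padded_loglik(observed, expected, concentration=1.0, padding=3))

With ``padding=3`` the ``NaN`` entries are never touched; with
``padding=0`` the sum would become ``nan``.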
We can use `ArviZ `__ to visualize the results.
Let’s start by converting the fitted model to an ArviZ
``InferenceData`` object:
.. container:: cell
:name: convert-inferencedata
.. code:: python
import arviz as az
idata = az.from_numpyro(hosp_model.mcmc)
We obtain the summary of model diagnostics and print the diagnostics for
``latent_hospital_admissions[1]``:
.. container:: cell
:name: diagnostics
.. code:: python
diagnostic_stats_summary = az.summary(
idata.posterior,
kind="diagnostics",
)
print(diagnostic_stats_summary.loc["latent_hospital_admissions[1]"])
.. container:: cell-output cell-output-stdout
::
mcse_mean 0.0
mcse_sd 0.0
ess_bulk 5000.0
ess_tail 2788.0
r_hat 1.0
Name: latent_hospital_admissions[1], dtype: float64
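
The ``r_hat`` column compares between-chain and within-chain
variability; values close to 1 suggest that the chains agree. ArviZ's
default is the rank-normalized split-:math:`\hat{R}`. The NumPy sketch
below implements the simpler classic split-:math:`\hat{R}` instead,
only to show the idea; ``split_rhat`` is a hypothetical helper, and its
values will differ slightly from those reported by ``az.summary``.

.. code:: python

   import numpy as np


   def split_rhat(draws):
       """Hypothetical helper: classic split-R-hat for (n_chains, n_draws) draws."""
       draws = np.asarray(draws, dtype=float)
       n_draws = draws.shape[1]
       half = n_draws // 2
       # Split every chain into a first and second half (middle draw dropped if odd)
       split = np.concatenate([draws[:, :half], draws[:, n_draws - half:]], axis=0)
       n = split.shape[1]
       within = split.var(axis=1, ddof=1).mean()  # W
       between = n * split.mean(axis=1).var(ddof=1)  # B
       var_plus = (n - 1) / n * within + between / n
       return float(np.sqrt(var_plus / within))


   rng = np.random.default_rng(0)
   well_mixed = rng.normal(size=(2, 1000))
   stuck = well_mixed + np.array([[0.0], [3.0]])  # second chain offset by 3
   print(split_rhat(well_mixed), split_rhat(stuck))

The well-mixed chains give a value very close to 1, while the offset
chains give a value well above the commonly used 1.01 threshold.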
Below we plot 90% and 50% highest density intervals for latent hospital
admissions using
`plot_hdi `__:
.. container:: cell
.. code:: python
x_data = idata.posterior["latent_hospital_admissions_dim_0"]
y_data = idata.posterior["latent_hospital_admissions"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
      # Add the posterior mean (averaged over chains and draws) to the figure
      mean_latent_hosp_admission = idata.posterior[
          "latent_hospital_admissions"
      ].mean(dim=["chain", "draw"])
      axes.plot(x_data, mean_latent_hosp_admission, color="C0", label="Mean")
axes.legend()
axes.set_title("Posterior Hospital Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10);
.. container:: cell-output cell-output-display
|image6|
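
A highest density interval is the narrowest interval that contains a
given fraction of the posterior draws. The sketch below computes it for
one-dimensional samples by sliding a window over the sorted draws. It
follows the usual sorted-window approach, but ``hdi`` here is a
hypothetical helper, not the ArviZ function, and it assumes a unimodal
posterior.

.. code:: python

   import numpy as np


   def hdi(samples, prob):
       """Hypothetical helper: narrowest interval holding `prob` of the samples."""
       x = np.sort(np.asarray(samples, dtype=float).ravel())
       n = x.size
       k = int(np.floor(prob * n))  # the window spans k + 1 sorted draws
       widths = x[k:] - x[: n - k]
       i = int(np.argmin(widths))
       return x[i], x[i + k]


   rng = np.random.default_rng(0)
   skewed = rng.exponential(scale=1.0, size=100_000)
   print(hdi(skewed, 0.9))

For a right-skewed posterior like this one, the 90% HDI starts near 0,
whereas the equal-tailed 90% interval would start at the 5% quantile
(about 0.05).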
We can also take a look at the latent infections:
.. container:: cell
.. code:: python
out2 = hosp_model.plot_posterior(
var="all_latent_infections", ylab="Latent Infections"
)
.. container:: cell-output cell-output-display
|image7|
and the 90% and 50% highest density intervals of the latent infections:
.. container:: cell
.. code:: python
x_data = idata.posterior["all_latent_infections_dim_0"]
y_data = idata.posterior["all_latent_infections"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
      # Add the posterior mean (averaged over chains and draws) to the figure
      mean_latent_infection = idata.posterior["all_latent_infections"].mean(
          dim=["chain", "draw"]
      )
      axes.plot(x_data, mean_latent_infection, color="C0", label="Mean")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10);
.. container:: cell-output cell-output-display
|image8|
Round 2: Incorporating day-of-the-week effects
==============================================
We will reuse the infection to admission interval and infection to
hospitalization rate from the previous model, but we will also add a
day-of-the-week effect. To do this, we create a new ``RandomVariable``
subclass that models the effect. The class samples seven multipliers
from a normal distribution with location 1.0 and scale 0.5, truncated
to the interval between 0.1 and 10.0. These seven values are repeated
for the number of weeks in the dataset and cut to the dataset length.
Note that a similar weekday effect is implemented in its own module,
with example code `here `__.
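
The sampling and weekly repetition described above can be previewed
with NumPy alone. This sketch is illustrative: ``sample_truncated_normal``
and ``day_of_week_effect`` are hypothetical helpers, and the truncated
normal is drawn by rejection sampling rather than with numpyro's
``TruncatedNormal``.

.. code:: python

   import math

   import numpy as np


   def sample_truncated_normal(rng, loc, scale, low, high, size):
       """Hypothetical helper: redraw values until all lie in (low, high)."""
       out = rng.normal(loc, scale, size)
       bad = (out <= low) | (out >= high)
       while bad.any():
           out[bad] = rng.normal(loc, scale, int(bad.sum()))
           bad = (out <= low) | (out >= high)
       return out


   def day_of_week_effect(rng, n_obs):
       """Hypothetical helper: 7 weekday multipliers repeated to length n_obs."""
       week = sample_truncated_normal(rng, 1.0, 0.5, 0.1, 10.0, 7)
       n_weeks = math.ceil(n_obs / 7)
       # Repeat the same 7 values every week, then cut to the series length
       return np.tile(week, n_weeks)[:n_obs]


   effect = day_of_week_effect(np.random.default_rng(0), n_obs=90)

Because the same seven values are reused, ``effect[7:14]`` equals
``effect[:7]``, and the last partial week (days 85 to 90) reuses the
first six multipliers.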
.. container:: cell
:name: weekly-effect
.. code:: python
from pyrenew import metaclass
import numpyro as npro
class DayOfWeekEffect(metaclass.RandomVariable):
"""Day of the week effect"""
def __init__(self, len: int):
"""Initialize the day of the week effect distribution
Parameters
----------
len : int
The number of observations
"""
self.nweeks = int(jnp.ceil(len / 7))
self.len = len
@staticmethod
def validate():
return None
def sample(self, **kwargs):
ans = npro.sample(
name="dayofweek_effect",
fn=npro.distributions.TruncatedNormal(
loc=1.0, scale=0.5, low=0.1, high=10.0
),
sample_shape=(7,),
)
return jnp.tile(ans, self.nweeks)[: self.len]
# Initializing the RV
dayofweek_effect = DayOfWeekEffect(dat.shape[0])
Notice that the instance’s ``nweeks`` and ``len`` members are set when
the object is constructed. Computing the number of weeks and the length
of the dataset in the ``validate`` method instead would raise a ``jit``
error in ``jax``, because the shapes and sizes of elements are not known
during the validation step, which happens before the model is run. With
the new effect, we can rebuild the latent hospitalization model:
.. container:: cell
:name: latent-hosp-weekday
.. code:: python
latent_hosp_wday_effect = latent.HospitalAdmissions(
infection_to_admission_interval_rv=inf_hosp_int,
infect_hosp_rate_rv=hosp_rate,
day_of_week_effect_rv=dayofweek_effect,
)
hosp_model_weekday = model.HospitalAdmissionsModel(
latent_infections_rv=latent_inf,
latent_hosp_admissions_rv=latent_hosp_wday_effect,
I0_rv=I0,
gen_int_rv=gen_int,
Rt_process_rv=rtproc,
hosp_admission_obs_process_rv=obs,
)
Running the model (with the same padding as before):
.. container:: cell
:name: model-2-run
.. code:: python
hosp_model_weekday.run(
num_samples=2000,
num_warmup=2000,
data_observed_hosp_admissions=dat_w_padding,
rng_key=jax.random.PRNGKey(54),
mcmc_args=dict(progress_bar=False),
padding=days_to_impute,
)
And plotting the results:
.. container:: cell
.. code:: python
out = hosp_model_weekday.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=np.pad(
dat_w_padding, (gen_int_array.size, 0), constant_values=np.nan
),
)
.. container:: cell-output cell-output-display
|image9|
.. [1]
   The output is captured to prevent ``quarto`` from displaying it
   twice.
.. |image1| image:: example_with_datasets_files/figure-rst/fig-plot-hospital-admissions-output-1.png
.. |image2| image:: example_with_datasets_files/figure-rst/fig-data-extract-output-1.png
.. |image3| image:: example_with_datasets_files/figure-rst/fig-basic-output-1.png
.. |image4| image:: example_with_datasets_files/figure-rst/fig-output-hospital-admissions-output-1.png
.. |image5| image:: example_with_datasets_files/figure-rst/fig-output-admissions-with-padding-output-1.png
.. |image6| image:: example_with_datasets_files/figure-rst/fig-output-admission-distribution-output-1.png
.. |image7| image:: example_with_datasets_files/figure-rst/fig-output-infections-with-padding-output-1.png
.. |image8| image:: example_with_datasets_files/figure-rst/fig-output-infections-distribution-output-1.png
.. |image9| image:: example_with_datasets_files/figure-rst/fig-output-admissions-padding-and-weekday-output-1.png