Fitting a hospital admissions-only model

This document illustrates how a hospital admissions-only model can be fitted using data from the Pyrenew package, particularly the wastewater dataset. The CFA wastewater team created this dataset, which contains simulated data.

We begin by loading numpyro and configuring the device count to 2 to enable running MCMC chains in parallel. By default, XLA (which is used by JAX for compilation) considers all CPU cores as one device. Depending on your system’s configuration, we recommend using numpyro’s set_host_device_count() function to set the number of devices available for parallel computing.

import numpyro

numpyro.set_host_device_count(2)
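As a rough guide for choosing that number, one option is to cap the requested device count at the number of CPU cores the machine reports. The helper below, choose_device_count, is a hypothetical convenience function written for this illustration; it is not part of numpyro and uses only the standard library.

```python
import os


def choose_device_count(num_chains: int) -> int:
    """Return a device count that does not exceed the reported CPU cores.

    Hypothetical helper for illustration; not part of numpyro.
    """
    if num_chains < 1:
        raise ValueError("num_chains must be at least 1")
    available_cores = os.cpu_count() or 1  # os.cpu_count() may return None
    return min(num_chains, available_cores)


n_devices = choose_device_count(num_chains=2)
# The result could then be passed to numpyro.set_host_device_count(n_devices).
print(n_devices)
```

On a machine with at least two cores this gives 2, matching the setting used above.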

Model definition

In this section, we provide the formal definition of the model. The hospital admissions model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Specifically, the observed number of hospital admissions follows a discrete distribution centered on the number of latent hospital admissions:

\[h(t) \sim \text{HospDist}\left(H(t)\right)\]

Where \(h(t)\) is the observed number of hospital admissions at time \(t\), and \(H(t)\) is the number of latent hospital admissions at time \(t\). The distribution \(\text{HospDist}\) is discrete. For this example, we will use a negative binomial distribution with an inferred concentration.

\[\begin{split}\begin{align*} h(t) & \sim \mathrm{NegativeBinomial}\left(\mathrm{mean} = H(t), \mathrm{concentration} = k\right) \\ H(t) & = p_\mathrm{hosp}(t) \sum_{\tau = 0}^{T_d} d(\tau) I(t-\tau) \\ \log[p_\mathrm{hosp}(t)] & \sim \mathrm{Normal}(\mu=\log(0.05), \sigma=\log(1.1)) \\ \log(k) & \sim \mathrm{Normal}(\mu=\log(1), \sigma=\log(10)) \end{align*}\end{split}\]

Where \(d(\tau)\) is the infection to hospital admission interval, \(T_d\) is the maximum delay considered, \(I(t)\) is the number of latent infections at time \(t\), and \(p_\mathrm{hosp}(t)\) is the infection to admission rate.
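The convolution that defines \(H(t)\) can be sketched in plain Python. This is an illustration only, not Pyrenew's implementation: the function name latent_admissions and the toy inputs are made up, and the infection to admission rate is held constant over time for simplicity.

```python
def latent_admissions(infections, delay_pmf, p_hosp):
    """Compute H(t) = p_hosp * sum_tau d(tau) * I(t - tau).

    Only time points for which every lagged infection count exists are
    returned, so the output is shorter than `infections` by
    len(delay_pmf) - 1 entries.
    """
    n_lags = len(delay_pmf)
    admissions = []
    for t in range(n_lags - 1, len(infections)):
        expected = sum(
            delay_pmf[tau] * infections[t - tau] for tau in range(n_lags)
        )
        admissions.append(p_hosp * expected)
    return admissions


# Toy inputs (made up for illustration).
infections = [100.0, 120.0, 150.0, 180.0, 200.0]
delay_pmf = [0.0, 0.5, 0.5]  # d(0), d(1), d(2)
p_hosp = 0.05

H = latent_admissions(infections, delay_pmf, p_hosp)
print(H)  # approximately [5.5, 6.75, 8.25]
```

Because each \(H(t)\) needs infections from up to \(T_d\) steps earlier, the infection series has to start before the first admission that is modeled.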

The number of latent hospital admissions at time \(t\) is a function of the latent infections up to time \(t\) and the infection to admission rate. The latent infections are modeled as a renewal process with generation interval \(g\):

\[\begin{split}\begin{align*} I(t) &= \mathcal{R}(t) \times \sum_{\tau < t} I(\tau) g(t - \tau) \\ \log[I(0)] &\sim \text{Normal}(\mu=\log(100), \sigma=\log(1.75)) \end{align*}\end{split}\]
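A minimal sketch of the renewal recursion, in plain Python and with made-up numbers, may help make the sum concrete. The function name renewal_infections is hypothetical, and the generation interval is assumed to be a finite probability mass function whose first entry corresponds to a lag of one time step.

```python
def renewal_infections(initial_infections, rt, gen_int_pmf):
    """Iterate I(t) = R(t) * sum_{s=1}^{S} g(s) * I(t - s).

    gen_int_pmf[s - 1] holds g(s). `initial_infections` seeds the
    recursion and must contain at least len(gen_int_pmf) values.
    """
    n_lags = len(gen_int_pmf)
    if len(initial_infections) < n_lags:
        raise ValueError("need at least len(gen_int_pmf) initial infections")
    infections = list(initial_infections)
    for r_t in rt:
        # Weighted sum of the most recent infections, newest first.
        force = sum(
            gen_int_pmf[s - 1] * infections[-s] for s in range(1, n_lags + 1)
        )
        infections.append(r_t * force)
    # Return only the newly generated values.
    return infections[len(initial_infections):]


new_infections = renewal_infections(
    initial_infections=[100.0, 100.0],
    rt=[1.0, 2.0, 1.0],
    gen_int_pmf=[0.5, 0.5],
)
print(new_infections)  # [100.0, 200.0, 150.0]
```

The recursion cannot start without a seed of past infections at least as long as the generation interval.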

The reproduction number \(\mathcal{R}(t)\) is modeled as a random walk in logarithmic space, i.e.:

\[\begin{split}\begin{align*} \log[\mathcal{R}(t)] & = \log[\mathcal{R}(t-1)] + \epsilon\\ \epsilon & \sim \text{Normal}(\mu=0, \sigma=0.025) \\ \mathcal{R}(0) &\sim \text{TruncatedNormal}(\text{loc}=1.2, \text{scale}=0.2, \text{min}=0) \end{align*}\end{split}\]
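To see what such a random walk looks like, the sketch below simulates one path with the standard library. It is an illustration under simplifying assumptions: the initial value is fixed at 1.2 rather than drawn from its prior, and simulate_rt is a hypothetical helper, not a Pyrenew function.

```python
import math
import random


def simulate_rt(n_steps, r0, step_sd, seed=0):
    """Simulate R(t) as the exponential of a Gaussian random walk."""
    rng = random.Random(seed)
    log_rt = [math.log(r0)]
    for _ in range(n_steps - 1):
        # log R(t) = log R(t - 1) + epsilon, epsilon ~ Normal(0, step_sd)
        log_rt.append(log_rt[-1] + rng.gauss(0.0, step_sd))
    # Exponentiating keeps every R(t) strictly positive.
    return [math.exp(value) for value in log_rt]


rt_path = simulate_rt(n_steps=90, r0=1.2, step_sd=0.025)
print(min(rt_path), max(rt_path))
```

Working on the log scale means each step changes \(\mathcal{R}(t)\) multiplicatively, by roughly 2.5% per step at this step size, and the path can never become negative.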

Data processing

We start by loading the data and inspecting the first five rows.

import polars as pl
from pyrenew import datasets

dat = datasets.load_wastewater()
dat.head(5)
shape: (5, 14)

| t | lab_wwtp_unique_id | log_conc | date | lod_sewage | below_lod | daily_hosp_admits | daily_hosp_admits_for_eval | pop | forecast_date | hosp_calibration_time | site | ww_pop | inf_per_capita |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| i64 | i64 | f64 | date | f64 | i64 | i64 | i64 | f64 | date | i64 | i64 | f64 | f64 |
| 1 | 1 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 1 | 400000.0 | 0.000663 |
| 1 | 2 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 1 | 400000.0 | 0.000663 |
| 1 | 3 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 2 | 200000.0 | 0.000663 |
| 1 | 4 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 3 | 100000.0 | 0.000663 |
| 1 | 5 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 4 | 50000.0 | 0.000663 |

The data contain one entry per site, but because of how they were simulated, the number of admissions on a given day is identical across sites. Thus, we keep only the first observation per day.

# Keeping the first observation of each date
dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"])

# Now, sorting by date
dat = dat.sort("date")

# Keeping the first 90 days
dat = dat.head(90)

dat.head(5)
shape: (5, 2)

| date | daily_hosp_admits |
|---|---|
| date | i64 |
| 2023-10-30 | 6 |
| 2023-10-31 | 8 |
| 2023-11-01 | 4 |
| 2023-11-02 | 8 |
| 2023-11-03 | 4 |
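The logic of the deduplication step (keep the first record per date, then sort by date) can be pictured in plain Python. The records below are made up for illustration, and this sketch shows only the idea, not how polars implements group_by.

```python
from datetime import date

# Toy records standing in for (date, site, daily_hosp_admits) rows.
records = [
    (date(2023, 10, 31), 1, 8),
    (date(2023, 10, 30), 1, 6),
    (date(2023, 10, 30), 2, 6),
    (date(2023, 10, 31), 2, 8),
]

first_per_date = {}
for day, _site, admits in records:
    # setdefault keeps the first value seen for each date and ignores repeats.
    first_per_date.setdefault(day, admits)

# Sort the (date, admissions) pairs chronologically.
daily = sorted(first_per_date.items())
print(daily)
```

Because admissions are identical across sites on any given day, which site's row is kept does not affect the result.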

Let’s take a look at the daily number of hospital admissions.

import matplotlib.pyplot as plt
import matplotlib.dates as mdates

daily_hosp_admits = dat["daily_hosp_admits"].to_numpy()
dates = dat["date"].to_numpy()
ax = plt.gca()
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
ax.xaxis.set_major_locator(mdates.DayLocator(interval=7))
ax.set_xlim(dates[0], dates[-1])
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.plot(dates, daily_hosp_admits, "-o")
plt.xlabel("Date")
plt.ylabel("Admissions")
plt.show()

image1

Building the model

First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospital admission interval.

gen_int = datasets.load_generation_interval()
inf_hosp_int = datasets.load_infection_admission_interval()

# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()

# Taking a peek at the first 5 elements of each
gen_int[:5], inf_hosp_int_array[:5]

# Visualizing both quantities side by side
fig, axs = plt.subplots(1, 2)

axs[0].plot(gen_int)
axs[0].set_title("Generation interval")
axs[1].plot(inf_hosp_int_array)
axs[1].set_title("Infection to hospital admission interval")
plt.show()

image2

With these two in hand, we can start building the model. First, we will define the latent hospital admissions:

from pyrenew import latent, deterministic, metaclass, randomvariable
import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(
    name="inf_hosp_int", value=inf_hosp_int_array
)

hosp_rate = randomvariable.DistributionalVariable(
    name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)

latent_hosp = latent.HospitalAdmissions(
    infection_to_admission_interval_rv=inf_hosp_int,
    infection_hospitalization_ratio_rv=hosp_rate,
)

Here, inf_hosp_int is a DeterministicPMF object that wraps the infection to hospital admission interval, and hosp_rate is a DistributionalVariable object that wraps a numpyro distribution representing the infection to hospital admission rate. The HospitalAdmissions class is a RandomVariable that takes these two random variables as inputs: the infection to admission interval and the infection to hospital admission rate. Now we can define the remaining components:

from pyrenew import model, process, observation, metaclass, transformation
from pyrenew.latent import (
    InfectionInitializationProcess,
    InitializeInfectionsExponentialGrowth,
)


# Infection process
latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
I0 = InfectionInitializationProcess(
    "I0_initialization",
    randomvariable.DistributionalVariable(
        name="I0",
        distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
    ),
    InitializeInfectionsExponentialGrowth(
        n_initialization_points,
        deterministic.DeterministicVariable(name="rate", value=0.05),
    ),
)

# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)


class MyRt(metaclass.RandomVariable):

    def validate(self):
        pass

    def sample(self, n: int, **kwargs) -> tuple:
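        # Note: sd_rt is sampled (and so appears as a model parameter) but is
        # not used below; the random walk step has a fixed sd of 0.025.
        sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))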

        rt_rv = randomvariable.TransformedVariable(
            name="log_rt_random_walk",
            base_rv=process.RandomWalk(
                name="log_rt",
                step_rv=randomvariable.DistributionalVariable(
                    name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                ),
            ),
            transforms=transformation.ExpTransform(),
        )
        rt_init_rv = randomvariable.DistributionalVariable(
            name="init_log_rt", distribution=dist.Normal(0, 0.2)
        )
        init_rt = rt_init_rv.sample()

        return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)


rtproc = MyRt()

# The observation model

# We sample a positive "raw" value from a truncated Normal and raise it
# to the power -2 to obtain the concentration parameter of the
# negative binomial.
nb_conc_rv = randomvariable.TransformedVariable(
    "concentration",
    randomvariable.DistributionalVariable(
        name="concentration_raw",
        distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
    ),
    transformation.PowerTransform(-2),
)

# now we define the observation process
obs = observation.NegativeBinomialObservation(
    "negbinom_rv",
    concentration_rv=nb_conc_rv,
)
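The effect of the power transform on the concentration can be checked with a few lines of arithmetic. The sketch below assumes the mean and concentration parameterization from the model definition, under which the variance of the negative binomial is mean + mean² / concentration. The function names are hypothetical and the mean of 10 is a made-up value.

```python
def concentration_from_raw(raw: float) -> float:
    """Apply a power transform with exponent -2: concentration = raw ** -2."""
    if raw <= 0:
        raise ValueError("raw must be positive")
    return raw ** -2


def negbinom_variance(mean: float, concentration: float) -> float:
    """Variance of a negative binomial with the given mean and concentration."""
    return mean + mean ** 2 / concentration


# Smaller raw values give larger concentrations and less overdispersion.
for raw in (0.01, 0.5, 1.0, 2.0):
    k = concentration_from_raw(raw)
    print(raw, k, negbinom_variance(10.0, k))
```

With the lower truncation bound of 0.01 on the raw value, the concentration cannot exceed roughly 10,000, at which point the distribution is close to a Poisson with the same mean.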

Notice all the components are RandomVariable instances. We can now build the model:

hosp_model = model.HospitalAdmissionsModel(
    latent_infections_rv=latent_inf,
    latent_hosp_admissions_rv=latent_hosp,
    I0_rv=I0,
    gen_int_rv=gen_int,
    Rt_process_rv=rtproc,
    hosp_admission_obs_process_rv=obs,
)

Let’s simulate from the prior predictive distribution to check that the model is working:

import numpy as np

timeframe = 120

with numpyro.handlers.seed(rng_seed=223):
    simulated_data = hosp_model.sample(n_datapoints=timeframe)

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(simulated_data.Rt)
axs[0].set_ylabel("Simulated Rt")

# Admissions plot
axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
axs[1].set_ylabel("Simulated Admissions")

fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()

image3

Fitting the model

We now fit the model, not to these simulated data, but rather to the dataset we retrieved above. We use the run method of the Model object:

import jax

hosp_model.run(
    num_samples=1000,
    num_warmup=1000,
    data_observed_hosp_admissions=daily_hosp_admits,
    rng_key=jax.random.key(54),
    mcmc_args=dict(progress_bar=False, num_chains=2),
)

We can use the Model object’s plot_posterior method to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e. the mean of the negative binomial observation process) [1]:

out = hosp_model.plot_posterior(
    var="latent_hospital_admissions",
    ylab="Hospital Admissions",
    obs_signal=daily_hosp_admits.astype(float),
)

image4

Results exploration and MCMC diagnostics

To explore further, we can use ArviZ to visualize the results. Let’s start by loading the module and converting the fitted model to an ArviZ InferenceData object:

import arviz as az

idata = az.from_numpyro(hosp_model.mcmc)

We obtain the summary of model diagnostics and print the diagnostics for latent_hospital_admissions[1]:

diagnostic_stats_summary = az.summary(
    idata.posterior,
    kind="diagnostics",
)

print(diagnostic_stats_summary.loc["latent_hospital_admissions[1]"])
mcse_mean       0.013
mcse_sd         0.009
ess_bulk     2118.000
ess_tail     1652.000
r_hat           1.000
Name: latent_hospital_admissions[1], dtype: float64
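To build intuition for the r_hat column, the sketch below computes the classic potential scale reduction factor from between-chain and within-chain variances. This is a simplified version for illustration: ArviZ reports a rank-normalized, split-chain variant, so its numbers will not match this function exactly. The function name classic_r_hat and the simulated chains are made up.

```python
import math
import random
import statistics


def classic_r_hat(chains):
    """Classic (non-split) potential scale reduction factor."""
    n_draws = len(chains[0])
    chain_means = [statistics.fmean(chain) for chain in chains]
    # W: average of the within-chain sample variances.
    within = statistics.fmean([statistics.variance(chain) for chain in chains])
    # B: n times the sample variance of the chain means.
    between = n_draws * statistics.variance(chain_means)
    pooled = (n_draws - 1) / n_draws * within + between / n_draws
    return math.sqrt(pooled / within)


rng = random.Random(1)
chain_a = [rng.gauss(0.0, 1.0) for _ in range(1000)]
chain_b = [rng.gauss(0.0, 1.0) for _ in range(1000)]

mixed = classic_r_hat([chain_a, chain_b])
# Shifting one chain mimics chains that explore different regions.
stuck = classic_r_hat([chain_a, [x + 3.0 for x in chain_b]])
print(mixed, stuck)
```

Values close to 1, as in the summary above, indicate that the chains agree with one another; values well above 1 suggest that they have not converged to the same distribution.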

Below we plot 90% and 50% highest density intervals for latent hospital admissions using plot_hdi:

x_data = idata.posterior["latent_hospital_admissions_dim_0"]
y_data = idata.posterior["latent_hospital_admissions"]


fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])

axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Hospital Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()

image5
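A highest density interval is the narrowest interval that contains a given share of the posterior draws. The sketch below shows that idea for a single variable using sorted samples. It is a simplified illustration, not the ArviZ implementation, and the function name hdi_from_samples and the sample values are made up.

```python
import math
import random


def hdi_from_samples(samples, prob):
    """Return the narrowest interval spanning about `prob` of the samples."""
    ordered = sorted(samples)
    n = len(ordered)
    span = math.floor(prob * n)  # index offset between the two endpoints
    if span < 1 or span >= n:
        raise ValueError("prob is too small or too large for this sample size")
    # Width of every candidate interval with the required coverage.
    widths = [ordered[i + span] - ordered[i] for i in range(n - span)]
    start = widths.index(min(widths))
    return ordered[start], ordered[start + span]


# A small sample with one outlier: the interval avoids the long tail.
toy = [0, 1, 2, 3, 4, 5, 6, 7, 8, 100]
toy_interval = hdi_from_samples(toy, 0.5)
print(toy_interval)  # (0, 5)

# A right-skewed sample: the 90% interval hugs the mode near zero.
rng = random.Random(0)
skewed = [rng.expovariate(1.0) for _ in range(5000)]
skewed_interval = hdi_from_samples(skewed, 0.9)
print(skewed_interval)
```

For skewed distributions the highest density interval is shifted toward the mode relative to an equal-tailed interval with the same coverage.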

We can also look at credible intervals for the posterior distribution of latent infections:

x_data = (
    idata.posterior["all_latent_infections_dim_0"] - n_initialization_points
)
y_data = idata.posterior["all_latent_infections"]

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()

image6

Predictive checks and forecasting

We can use the Model’s posterior_predictive and prior_predictive methods to generate posterior and prior predictive samples for observed admissions.

import arviz as az

idata = az.from_numpyro(
    hosp_model.mcmc,
    posterior_predictive=hosp_model.posterior_predictive(
        n_datapoints=len(daily_hosp_admits)
    ),
    prior=hosp_model.prior_predictive(
        n_datapoints=len(daily_hosp_admits),
        numpyro_predictive_args={"num_samples": 1000},
    ),
)

We will use the plot_lm function from ArviZ to plot the posterior predictive distribution against the observed data:

fig, ax = plt.subplots()
az.plot_lm(
    "negbinom_rv",
    idata=idata,
    kind_pp="hdi",
    y_kwargs={"color": "black"},
    y_hat_fill_kwargs={"color": "C0"},
    axes=ax,
)

ax.set_title("Posterior Predictive Plot")
ax.set_ylabel("Hospital Admissions")
ax.set_xlabel("Days")
plt.show()

image7

By increasing n_datapoints, we can perform forecasting using the posterior predictive distribution.

n_forecast_points = 28
idata = az.from_numpyro(
    hosp_model.mcmc,
    posterior_predictive=hosp_model.posterior_predictive(
        n_datapoints=len(daily_hosp_admits) + n_forecast_points,
    ),
    prior=hosp_model.prior_predictive(
        n_datapoints=len(daily_hosp_admits),
        numpyro_predictive_args={"num_samples": 1000},
    ),
)
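With the forecast included, the time axis of the posterior predictive samples covers the 90 fitted days followed by 28 forecast days. The sketch below works out the matching calendar dates with the standard library, assuming the series starts on 2023-10-30 as in the data shown earlier. The variable names are illustrative.

```python
from datetime import date, timedelta

first_observed = date(2023, 10, 30)  # first date in the fitted data
n_observed = 90  # number of days used for fitting
n_forecast_points = 28  # forecast horizon in days

all_dates = [
    first_observed + timedelta(days=i)
    for i in range(n_observed + n_forecast_points)
]
observed_dates = all_dates[:n_observed]
forecast_dates = all_dates[n_observed:]

print(observed_dates[-1])  # 2024-01-27
print(forecast_dates[0], forecast_dates[-1])  # 2024-01-28 2024-02-24
```

Indices 0 to 89 along the time dimension therefore correspond to fitted days and indices 90 to 117 to forecast days.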

Below we plot the prior predictive distributions using equal-tailed Bayesian credible intervals:

def compute_eti(dataset, eti_prob):
    eti_bdry = dataset.quantile(
        ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw")
    )
    return eti_bdry.T


fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    idata.prior_predictive["negbinom_rv_dim_0"],
    hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.9),
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    idata.prior_predictive["negbinom_rv_dim_0"],
    hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.5),
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

plt.scatter(
    idata.observed_data["negbinom_rv_dim_0"],
    idata.observed_data["negbinom_rv"],
    color="black",
)

axes.set_title("Prior Predictive Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Observed Admissions", fontsize=10)
plt.yscale("log")
plt.show()

image8
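The quantile arithmetic behind an equal-tailed interval can be checked by hand: a 90% interval leaves 5% of the probability in each tail, so its bounds are the 0.05 and 0.95 quantiles. The sketch below reproduces that calculation for one variable in plain Python, using linear interpolation between order statistics. The function names and the toy draws are made up for illustration.

```python
import math


def eti_quantile_levels(eti_prob):
    """Quantile levels bounding an equal-tailed interval."""
    return (1 - eti_prob) / 2, 1 / 2 + eti_prob / 2


def quantile(sorted_values, q):
    """Quantile by linear interpolation between order statistics."""
    position = q * (len(sorted_values) - 1)
    lower = math.floor(position)
    upper = min(lower + 1, len(sorted_values) - 1)
    frac = position - lower
    return sorted_values[lower] * (1 - frac) + sorted_values[upper] * frac


def eti_from_samples(samples, eti_prob):
    """Equal-tailed interval for a one-dimensional collection of draws."""
    ordered = sorted(samples)
    low_q, high_q = eti_quantile_levels(eti_prob)
    return quantile(ordered, low_q), quantile(ordered, high_q)


draws = list(range(101))  # toy draws: 0, 1, ..., 100
print(eti_from_samples(draws, 0.5))  # (25.0, 75.0)
print(eti_from_samples(draws, 0.9))  # approximately (5.0, 95.0)
```

Unlike a highest density interval, an equal-tailed interval always cuts the same probability from each tail, whatever the shape of the distribution.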

And now we plot the posterior predictive distributions with a 28-day-ahead forecast:

x_data = idata.posterior_predictive["negbinom_rv_dim_0"]
y_data = idata.posterior_predictive["negbinom_rv"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    hdi_data=compute_eti(y_data, 0.9),
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    hdi_data=compute_eti(y_data, 0.5),
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# Add median of the posterior to the figure
median_ts = y_data.median(dim=["chain", "draw"])

plt.plot(
    x_data,
    median_ts,
    color="C0",
    label="Median",
)
plt.scatter(
    idata.observed_data["negbinom_rv_dim_0"],
    idata.observed_data["negbinom_rv"],
    color="black",
)
axes.legend()
axes.set_title(
    "Posterior Predictive Admissions, including a forecast", fontsize=10
)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()

image9