Fitting a hospital admissions-only model#
This document illustrates how a hospital admissions-only model can be fitted using data from the Pyrenew package, particularly the wastewater dataset. The CFA wastewater team created this dataset, which contains simulated data.
We begin by loading numpyro and configuring the device count to 2 to enable running MCMC chains in parallel. By default, XLA (which is used by JAX for compilation) considers all CPU cores as one device. Depending on your system’s configuration, we recommend using numpyro’s set_host_device_count() function to set the number of devices available for parallel computing.
import numpyro
numpyro.set_host_device_count(2)
Model definition#
In this section, we provide the formal definition of the model. The hospital admissions model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Mainly, the observed number of hospital admissions is discretely distributed with location at the number of latent hospital admissions:
\[h(t) \sim \text{HospDist}\left(H(t)\right)\]
Where \(h(t)\) is the observed number of hospital admissions at time \(t\), and \(H(t)\) is the number of latent hospital admissions at time \(t\). The distribution \(\text{HospDist}\) is discrete. For this example, we will use a negative binomial distribution with an inferred concentration.
\[H(t) = p_\mathrm{hosp}(t) \sum_{\tau \geq 0} d(\tau) I(t - \tau)\]
Where \(d(\tau)\) is the infection to hospital admission interval, \(I(t)\) is the number of latent infections at time \(t\), and \(p_\mathrm{hosp}(t)\) is the infection to admission rate.
The number of latent hospital admissions at time \(t\) is thus a function of the latent infections up to time \(t\) and the infection to admission rate. The latent infections are modeled as a renewal process:
\[I(t) = \mathcal{R}(t) \sum_{\tau < t} I(\tau) g(t - \tau)\]
Where \(g\) is the generation interval.
The reproduction number \(\mathcal{R}(t)\) is modeled as a random walk in logarithmic space, i.e.:
\[\log \mathcal{R}(t) = \log \mathcal{R}(t - 1) + \epsilon(t), \quad \epsilon(t) \sim \text{Normal}(0, \sigma)\]
Where \(\sigma\) is the standard deviation of the random walk steps.
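The model described in this section can be sketched end to end in plain Python. This is only an illustration under stated assumptions, not the Pyrenew implementation: the function name `simulate_forward`, its arguments, and the convention that `gen_int[k]` holds the mass at a lag of `k + 1` days are all hypothetical choices made for this sketch.

```python
import math
import random


def simulate_forward(n_days, gen_int, inf_hosp_int, ihr, i0_history, rw_sd, seed=0):
    """Simulate Rt, latent infections, and latent admissions (illustrative sketch).

    gen_int[k] is the generation-interval mass at a lag of k + 1 days (assumption);
    inf_hosp_int[k] is the infection-to-admission mass at a lag of k days.
    """
    rng = random.Random(seed)

    # Rt follows a random walk on the log scale, starting at log(1) = 0.
    log_rt = 0.0
    rt = []
    for _ in range(n_days):
        rt.append(math.exp(log_rt))
        log_rt += rng.gauss(0.0, rw_sd)

    # Renewal equation: new infections are Rt times a weighted sum of
    # past infections, with weights given by the generation interval.
    infections = list(i0_history)
    n_init = len(infections)
    for t in range(n_days):
        recent = infections[::-1][: len(gen_int)]  # most recent first
        force = sum(g * i for g, i in zip(gen_int, recent))
        infections.append(rt[t] * force)

    # Latent admissions: convolve infections with the delay PMF and
    # scale by the infection-to-admission rate.
    admissions = []
    for t in range(n_init, n_init + n_days):
        total = sum(
            d * infections[t - lag]
            for lag, d in enumerate(inf_hosp_int)
            if t - lag >= 0
        )
        admissions.append(ihr * total)

    return rt, infections[n_init:], admissions


# Hypothetical inputs: with no random-walk noise, Rt stays at 1 and a flat
# infection history stays flat.
rt, infections, admissions = simulate_forward(
    n_days=30,
    gen_int=[0.5, 0.3, 0.2],
    inf_hosp_int=[0.0, 0.2, 0.5, 0.3],
    ihr=0.05,
    i0_history=[100.0, 100.0, 100.0],
    rw_sd=0.0,
)
```

Setting `rw_sd` to a positive value lets Rt drift, which in turn makes infections and admissions grow or shrink over time.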
Data processing#
We start by loading the data and inspecting the first five rows.
import polars as pl
from pyrenew import datasets
dat = datasets.load_wastewater()
dat.head(5)
| t | lab_wwtp_unique_id | log_conc | date | lod_sewage | below_lod | daily_hosp_admits | daily_hosp_admits_for_eval | pop | forecast_date | hosp_calibration_time | site | ww_pop | inf_per_capita |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| i64 | i64 | f64 | date | f64 | i64 | i64 | i64 | f64 | date | i64 | i64 | f64 | f64 |
| 1 | 1 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 1 | 400000.0 | 0.000663 |
| 1 | 2 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 1 | 400000.0 | 0.000663 |
| 1 | 3 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 2 | 200000.0 | 0.000663 |
| 1 | 4 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 3 | 100000.0 | 0.000663 |
| 1 | 5 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 4 | 50000.0 | 0.000663 |
The data shows one entry per site, but the way it was simulated, the number of admissions is the same across sites. Thus, we will only keep the first observation per day.
# Keeping the first observation of each date
dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"])
# Now, sorting by date
dat = dat.sort("date")
# Keeping the first 90 days
dat = dat.head(90)
dat.head(5)
| date | daily_hosp_admits |
|---|---|
| date | i64 |
| 2023-10-30 | 6 |
| 2023-10-31 | 8 |
| 2023-11-01 | 4 |
| 2023-11-02 | 8 |
| 2023-11-03 | 4 |
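The claim that admissions are identical across sites can be checked before discarding rows. The following is a minimal sketch in plain Python, using hypothetical rows rather than the real dataset; the function name `admissions_by_date` is illustrative and is not part of Pyrenew or polars.

```python
from collections import defaultdict

# Hypothetical rows in (date, site, daily_hosp_admits) form, standing in
# for the real dataset.
rows = [
    ("2023-10-30", 1, 6),
    ("2023-10-30", 2, 6),
    ("2023-10-31", 1, 8),
    ("2023-10-31", 2, 8),
]


def admissions_by_date(rows):
    """Collapse to one value per date, failing if sites disagree."""
    values = defaultdict(set)
    for date, _site, admits in rows:
        values[date].add(admits)
    conflicting = sorted(d for d, v in values.items() if len(v) > 1)
    if conflicting:
        raise ValueError(f"admissions differ across sites on: {conflicting}")
    # ISO-formatted date strings sort chronologically.
    return {date: next(iter(values[date])) for date in sorted(values)}


daily = admissions_by_date(rows)
```

A check of this kind fails loudly if the "same across sites" assumption does not hold, instead of silently keeping whichever row happens to come first.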
Let’s take a look at the daily number of hospital admissions.
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
daily_hosp_admits = dat["daily_hosp_admits"].to_numpy()
dates = dat["date"].to_numpy()
ax = plt.gca()
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
ax.xaxis.set_major_locator(mdates.DayLocator(interval=7))
ax.set_xlim(dates[0], dates[-1])
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.plot(dates, daily_hosp_admits, "-o")
plt.xlabel("Date")
plt.ylabel("Admissions")
plt.show()
Figure 1: Daily hospital admissions from the simulated data
Building the model#
First, we will extract two datasets we will use as deterministic quantities: the generation interval and the infection to hospital admission interval.
gen_int = datasets.load_generation_interval()
inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()
# Taking a peek at the first 5 elements of each
gen_int[:5], inf_hosp_int_array[:5]
# Visualizing both quantities side by side
fig, axs = plt.subplots(1, 2)
axs[0].plot(gen_int)
axs[0].set_title("Generation interval")
axs[1].plot(inf_hosp_int_array)
axs[1].set_title("Infection to hospital admission interval")
plt.show()
Figure 2: Generation interval and infection to hospital admission interval
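Since both intervals are used as probability mass functions, it can be worth confirming that each array is non-negative and sums to one before wrapping it in a model component. The sketch below is a hedged illustration in plain Python: `check_pmf` is a hypothetical helper, the example values are made up, and treating index `k` as a lag of `k` days is an assumption that may not match the indexing convention of every interval.

```python
import math


def check_pmf(pmf, tol=1e-6):
    """Return the mean lag of a discrete PMF after basic validity checks."""
    if any(p < 0 for p in pmf):
        raise ValueError("PMF has negative entries")
    total = math.fsum(pmf)
    if abs(total - 1.0) > tol:
        raise ValueError(f"PMF sums to {total}, expected 1")
    # Mean delay in days, treating index k as a lag of k days (assumption).
    return math.fsum(k * p for k, p in enumerate(pmf))


# Hypothetical PMF, not the values shipped with the package.
example_pmf = [0.0, 0.2, 0.5, 0.3]
mean_lag = check_pmf(example_pmf)
```

The same function could be applied to `gen_int_array.tolist()` and `inf_hosp_int_array.tolist()` from the code above.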
With these two in hand, we can start building the model. First, we will define the latent hospital admissions:
from pyrenew import latent, deterministic, metaclass, randomvariable
import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int_array
)
hosp_rate = randomvariable.DistributionalVariable(
name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)
latent_hosp = latent.HospitalAdmissions(
infection_to_admission_interval_rv=inf_hosp_int,
infection_hospitalization_ratio_rv=hosp_rate,
)
The inf_hosp_int is a DeterministicPMF object that takes the infection to hospital admission interval as input. The hosp_rate is a DistributionalVariable object that takes a numpyro distribution to represent the infection to hospital admission rate. The HospitalAdmissions class is a RandomVariable that takes two random variables as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the remaining components:
from pyrenew import model, process, observation, metaclass, transformation
from pyrenew.latent import (
InfectionInitializationProcess,
InitializeInfectionsExponentialGrowth,
)
# Infection process
latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
I0 = InfectionInitializationProcess(
"I0_initialization",
randomvariable.DistributionalVariable(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
InitializeInfectionsExponentialGrowth(
n_initialization_points,
deterministic.DeterministicVariable(name="rate", value=0.05),
),
)
# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)
class MyRt(metaclass.RandomVariable):
    def validate(self):
        pass

    def sample(self, n: int, **kwargs) -> tuple:
        # Note: sd_rt is sampled and recorded, but the random-walk step
        # below uses a fixed scale of 0.025 rather than sd_rt.
        sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
        rt_rv = randomvariable.TransformedVariable(
            name="log_rt_random_walk",
            base_rv=process.RandomWalk(
                name="log_rt",
                step_rv=randomvariable.DistributionalVariable(
                    name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                ),
            ),
            transforms=transformation.ExpTransform(),
        )
        rt_init_rv = randomvariable.DistributionalVariable(
            name="init_log_rt", distribution=dist.Normal(0, 0.2)
        )
        init_rt = rt_init_rv.sample()
        return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
rtproc = MyRt()
# The observation model
# The concentration parameter of the negative binomial is defined
# as a truncated-Normal variable raised to the power -2.
nb_conc_rv = randomvariable.TransformedVariable(
"concentration",
randomvariable.DistributionalVariable(
name="concentration_raw",
distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
transformation.PowerTransform(-2),
)
# now we define the observation process
obs = observation.NegativeBinomialObservation(
"negbinom_rv",
concentration_rv=nb_conc_rv,
)
Notice all the components are RandomVariable instances. We can now build the model:
hosp_model = model.HospitalAdmissionsModel(
latent_infections_rv=latent_inf,
latent_hosp_admissions_rv=latent_hosp,
I0_rv=I0,
gen_int_rv=gen_int,
Rt_process_rv=rtproc,
hosp_admission_obs_process_rv=obs,
)
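The priors in the components above are specified on transformed scales, which can make them hard to read at a glance. The following sketch, in plain Python, converts the two LogNormal priors into a median and a central 95% interval, and mirrors the power transform applied to the raw concentration variable. The helper names `lognormal_interval` and `concentration_from_raw` are hypothetical and only serve this illustration.

```python
import math
from statistics import NormalDist


def lognormal_interval(loc, scale, prob=0.95):
    """Median and central interval of a LogNormal(loc, scale) distribution."""
    z = NormalDist().inv_cdf(0.5 + prob / 2)
    return (
        math.exp(loc),
        math.exp(loc - z * scale),
        math.exp(loc + z * scale),
    )


# IHR prior: LogNormal(log(0.05), log(1.1))
ihr_median, ihr_low, ihr_high = lognormal_interval(math.log(0.05), math.log(1.1))

# Initial infections prior: LogNormal(log(100), log(1.75))
i0_median, i0_low, i0_high = lognormal_interval(math.log(100), math.log(1.75))


def concentration_from_raw(raw):
    """Mirror PowerTransform(-2): concentration = raw ** -2."""
    return raw ** -2
```

Under these priors, the infection to admission rate has a median of 0.05 with a central 95% interval of roughly 0.041 to 0.060, and the initial infections have a median of 100 with an interval of roughly 33 to 300. Because the raw concentration variable is truncated below at 0.01, the concentration itself cannot exceed about 10,000.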
Let’s simulate from the prior predictive distribution to check that the model is working:
import numpy as np
timeframe = 120
with numpyro.handlers.seed(rng_seed=223):
    simulated_data = hosp_model.sample(n_datapoints=timeframe)
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(simulated_data.Rt)
axs[0].set_ylabel("Simulated Rt")
# Admissions plot
axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
axs[1].set_ylabel("Simulated Admissions")
fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()
Figure 3: Simulated Rt and Admissions
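The simulated admissions are noisier than the latent curve because of the negative binomial observation process. As a rough illustration of the mean and concentration parameterization, the sketch below draws negative binomial counts as a gamma-Poisson mixture using only the Python standard library. It is not how numpyro samples internally; `sample_negbinom` and the chosen values are hypothetical.

```python
import math
import random


def sample_negbinom(mean, concentration, rng):
    """Draw one negative binomial count as a gamma-Poisson mixture."""
    # Gamma-distributed rate with shape = concentration and mean = `mean`.
    rate = rng.gammavariate(concentration, mean / concentration)
    # Poisson draw by Knuth's multiplication method (fine for modest rates).
    threshold = math.exp(-rate)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


rng = random.Random(1)
draws = [
    sample_negbinom(mean=10.0, concentration=5.0, rng=rng) for _ in range(20000)
]
sample_mean = sum(draws) / len(draws)
sample_var = sum((x - sample_mean) ** 2 for x in draws) / (len(draws) - 1)
```

With this parameterization the variance is `mean + mean**2 / concentration`, so smaller concentrations produce more overdispersion relative to a Poisson distribution with the same mean.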
Fitting the model#
We now fit the model, not to these simulated data, but rather to the dataset we retrieved above. We use the run method of the Model object:
import jax
hosp_model.run(
num_samples=1000,
num_warmup=1000,
data_observed_hosp_admissions=daily_hosp_admits,
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
We can use arviz to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e. the mean of the negative binomial observation process)[1]:
import arviz as az
ppc_samples = hosp_model.posterior_predictive(
n_datapoints=daily_hosp_admits.size
)
idata = az.from_numpyro(
posterior=hosp_model.mcmc,
posterior_predictive=ppc_samples,
)
axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles, ["Observed", "Sample Mean", "Posterior Samples"], loc="best"
)
plt.show()
Figure 4: Posterior predictive samples of hospital admissions (sky blue), their mean (dashed black), and the observed admissions time series (blue).
Results exploration and MCMC diagnostics#
To explore further, we can use ArviZ to visualize the results. Let’s start by converting the fitted model to an ArviZ InferenceData object:
idata = az.from_numpyro(hosp_model.mcmc)
We obtain the summary of model diagnostics and print the diagnostics for latent_hospital_admissions[1]:
diagnostic_stats_summary = az.summary(
idata.posterior,
kind="diagnostics",
)
print(diagnostic_stats_summary.loc["latent_hospital_admissions[1]"])
mcse_mean 0.013
mcse_sd 0.009
ess_bulk 2144.000
ess_tail 1307.000
r_hat 1.000
Name: latent_hospital_admissions[1], dtype: float64
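An r_hat close to 1 indicates that the chains agree with each other. To give a sense of what this statistic measures, the sketch below computes a classic split R-hat in plain Python. This is a simplified version under stated assumptions: ArviZ reports a rank-normalized variant, so its values can differ, and the function `split_rhat` and the synthetic chains are hypothetical.

```python
import random
from statistics import fmean, variance


def split_rhat(chains):
    """Classic split R-hat for a list of equal-length chains of draws."""
    half = len(chains[0]) // 2
    # Split each chain in two so that within-chain drift also inflates R-hat.
    halves = []
    for chain in chains:
        halves.append(chain[:half])
        halves.append(chain[half : 2 * half])
    chain_means = [fmean(h) for h in halves]
    within = fmean([variance(h) for h in halves])
    between = half * variance(chain_means)
    var_plus = (half - 1) / half * within + between / half
    return (var_plus / within) ** 0.5


rng = random.Random(0)
# Two synthetic chains targeting the same distribution.
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(2)]
# The same chains, with the second one shifted to mimic poor mixing.
stuck = [mixed[0], [x + 3.0 for x in mixed[1]]]

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
```

Here `rhat_mixed` stays near 1, while `rhat_stuck` is well above it because the between-chain variance dominates.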
Below we plot 90% and 50% highest density intervals for latent hospital admissions using plot_hdi:
x_data = idata.posterior["latent_hospital_admissions_dim_0"]
y_data = idata.posterior["latent_hospital_admissions"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Hospital Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()
Figure 5: Hospital Admissions posterior distribution
We can also look at credible intervals for the posterior distribution of latent infections:
x_data = (
idata.posterior["all_latent_infections_dim_0"] - n_initialization_points
)
y_data = idata.posterior["all_latent_infections"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
Figure 6: Posterior Latent Infections
Predictive checks and forecasting#
We can use the Model’s posterior_predictive and prior_predictive methods to generate posterior and prior predictive samples for observed admissions.
idata = az.from_numpyro(
hosp_model.mcmc,
posterior_predictive=hosp_model.posterior_predictive(
n_datapoints=len(daily_hosp_admits)
),
prior=hosp_model.prior_predictive(
n_datapoints=len(daily_hosp_admits),
numpyro_predictive_args={"num_samples": 1000},
),
)
We will use the plot_lm function from ArviZ to plot the posterior predictive distribution against the actual observed data below:
fig, ax = plt.subplots()
az.plot_lm(
"negbinom_rv",
idata=idata,
kind_pp="hdi",
y_kwargs={"color": "black"},
y_hat_fill_kwargs={"color": "C0"},
axes=ax,
)
ax.set_title("Posterior Predictive Plot")
ax.set_ylabel("Hospital Admissions")
ax.set_xlabel("Days")
plt.show()
Figure 7: Hospital Admissions posterior distribution with plot_lm
By increasing n_datapoints, we can perform forecasting using the posterior predictive distribution.
n_forecast_points = 28
idata = az.from_numpyro(
hosp_model.mcmc,
posterior_predictive=hosp_model.posterior_predictive(
n_datapoints=len(daily_hosp_admits) + n_forecast_points,
),
prior=hosp_model.prior_predictive(
n_datapoints=len(daily_hosp_admits),
numpyro_predictive_args={"num_samples": 1000},
),
)
Below we plot the prior predictive distributions using equal-tailed Bayesian credible intervals:
def compute_eti(dataset, eti_prob):
    eti_bdry = dataset.quantile(
        ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw")
    )
    return eti_bdry.T
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
idata.prior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.9),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
idata.prior_predictive["negbinom_rv_dim_0"],
hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.5),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
plt.scatter(
idata.observed_data["negbinom_rv_dim_0"],
idata.observed_data["negbinom_rv"],
color="black",
)
axes.set_title("Prior Predictive Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Observed Admissions", fontsize=10)
plt.yscale("log")
plt.show()
Figure 8: Prior Predictive Admissions
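The prior predictive plot uses equal-tailed intervals, whereas the earlier latent-variable plots used highest density intervals. The difference matters for skewed distributions such as counts. The sketch below contrasts the two on a deliberately skewed set of values using plain Python. Both functions are hypothetical, simplified stand-ins and are not the xarray or ArviZ implementations.

```python
def equal_tailed_interval(draws, prob):
    """Interval between the (1 - prob)/2 and (1 + prob)/2 empirical quantiles."""
    ordered = sorted(draws)

    def quantile(q):
        # Linear interpolation between order statistics.
        pos = q * (len(ordered) - 1)
        lo = int(pos)
        hi = min(lo + 1, len(ordered) - 1)
        return ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo])

    return quantile((1 - prob) / 2), quantile((1 + prob) / 2)


def highest_density_interval(draws, prob):
    """Narrowest window that contains a fraction `prob` of the sorted draws."""
    ordered = sorted(draws)
    width = max(1, int(prob * len(ordered)))
    starts = range(len(ordered) - width + 1)
    best = min(starts, key=lambda i: ordered[i + width - 1] - ordered[i])
    return ordered[best], ordered[best + width - 1]


# Hypothetical right-skewed values: the squares of 0..100.
skewed = [x**2 for x in range(101)]
eti = equal_tailed_interval(skewed, 0.9)
hdi = highest_density_interval(skewed, 0.9)
```

For these values the equal-tailed interval leaves equal probability in each tail, while the highest density interval shifts toward the dense lower end and is narrower.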
And now we plot the posterior predictive distributions with a 28-day-ahead forecast:
x_data = idata.posterior_predictive["negbinom_rv_dim_0"]
y_data = idata.posterior_predictive["negbinom_rv"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
hdi_data=compute_eti(y_data, 0.9),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
hdi_data=compute_eti(y_data, 0.5),
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
# Add median of the posterior to the figure
median_ts = y_data.median(dim=["chain", "draw"])
plt.plot(
x_data,
median_ts,
color="C0",
label="Median",
)
plt.scatter(
idata.observed_data["negbinom_rv_dim_0"],
idata.observed_data["negbinom_rv"],
color="black",
)
axes.legend()
axes.set_title(
"Posterior Predictive Admissions, including a forecast", fontsize=10
)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()
Figure 9: Posterior predictive admissions, including a forecast.