Fitting a hospital admissions-only model
This document illustrates how to fit a hospital admissions-only model using the wastewater dataset bundled with the Pyrenew package. The CFA wastewater team created this dataset, which contains simulated data.
We begin by loading numpyro and setting the device count to 2 so that MCMC chains can run in parallel. By default, XLA (which JAX uses for compilation) treats all CPU cores as a single device. Depending on your system’s configuration, we recommend using numpyro’s set_host_device_count() function to set the number of devices available for parallel computation. This setting only takes effect at the beginning of a program, before JAX initializes its devices, so call it first.
import numpyro
numpyro.set_host_device_count(2)
Model definition
In this section, we provide the formal definition of the model. The hospital admissions model is a semi-mechanistic model that describes the number of observed hospital admissions as a function of a set of latent variables. Specifically, the observed number of hospital admissions follows a discrete distribution whose location is the number of latent hospital admissions:
\[
h(t) \sim \text{HospDist}\left(H(t)\right)
\]
Here, \(h(t)\) is the observed number of hospital admissions at time \(t\), and \(H(t)\) is the number of latent hospital admissions at time \(t\). The distribution \(\text{HospDist}\) is discrete. For this example, we use a negative binomial distribution with mean \(H(t)\) and an inferred concentration \(\phi\), i.e., \(h(t) \sim \text{NegativeBinomial}(\text{mean} = H(t), \text{concentration} = \phi)\).
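To make the mean/concentration parameterization concrete, the sketch below draws counts using the standard gamma-Poisson mixture. It assumes the convention of numpyro's `NegativeBinomial2`, where the variance is \(\mu + \mu^2/\phi\). It is a plain NumPy illustration, not Pyrenew's observation code, and the helper `sample_negbin_mean_conc` is hypothetical.

```python
import numpy as np

def sample_negbin_mean_conc(mean, concentration, size, rng):
    """Draw negative binomial counts with a given mean and concentration.

    Gamma-Poisson mixture (hypothetical helper for illustration):
        lam ~ Gamma(shape=concentration, rate=concentration / mean)
        y | lam ~ Poisson(lam)
    """
    # NumPy's gamma takes a scale parameter, i.e. 1 / rate.
    lam = rng.gamma(shape=concentration, scale=mean / concentration, size=size)
    return rng.poisson(lam)

rng = np.random.default_rng(0)
draws = sample_negbin_mean_conc(mean=20.0, concentration=5.0, size=200_000, rng=rng)
# Expected mean 20 and variance 20 + 20**2 / 5 = 100, up to Monte Carlo error.
print(draws.mean(), draws.var())
```

Smaller concentrations inflate the variance well beyond the Poisson value \(\mu\). As \(\phi \to \infty\), the distribution approaches a Poisson.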
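\[
H(t) = p_\mathrm{hosp}(t) \sum_{\tau \geq 0} d(\tau)\, I(t - \tau)
\]
where \(d(\tau)\) is the infection to hospital admission interval (a probability mass function over the delay \(\tau\)), \(I(t)\) is the number of latent infections at time \(t\), and \(p_\mathrm{hosp}(t)\) is the infection to admission rate.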
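The latent admissions equation is a discrete convolution of infections with the delay PMF, scaled by the admission rate. Below is a minimal NumPy sketch. It assumes a constant rate \(p_\mathrm{hosp}\) (in the example model this is the single `IHR` parameter) and that `delay_pmf[0]` corresponds to \(\tau = 0\). The helper `latent_admissions` is hypothetical and not part of Pyrenew.

```python
import numpy as np

def latent_admissions(infections, delay_pmf, ihr):
    """H(t) = ihr * sum_{tau >= 0} d(tau) * I(t - tau), for t with a full delay history.

    Assumes delay_pmf[0] is d(0), a constant ihr, and
    len(infections) >= len(delay_pmf).
    """
    infections = np.asarray(infections, dtype=float)
    delay_pmf = np.asarray(delay_pmf, dtype=float)
    # mode="valid" keeps only times t whose whole delay window lies in the data,
    # i.e. t = len(delay_pmf) - 1, ..., len(infections) - 1.
    return ihr * np.convolve(infections, delay_pmf, mode="valid")

h = latent_admissions([10, 20, 30, 40], delay_pmf=[0.2, 0.8], ihr=0.5)
print(h)  # approximately [6, 11, 16]
```

Each output combines infections from the current day and earlier days, weighted by the delay PMF.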
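The number of latent hospital admissions at time \(t\) is therefore the infection to admission rate multiplied by a delay-weighted sum of latent infections up to time \(t\). The latent infections are modeled as a renewal process:
\[
I(t) = \mathcal{R}(t) \sum_{\tau \geq 1} g(\tau)\, I(t - \tau)
\]
where \(g(\tau)\) is the generation interval.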
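The renewal equation can be iterated one day at a time. The sketch below makes two assumptions. First, the generation interval PMF is indexed so that its first element is \(g(1)\). Second, enough initial infections are supplied to cover the PMF's length, which is the role the initialization process plays in the model. `renewal_infections` is a hypothetical helper for illustration only.

```python
import numpy as np

def renewal_infections(rt, gen_int_pmf, init_infections):
    """Iterate I(t) = R(t) * sum_{tau=1}^{n} g(tau) * I(t - tau).

    Assumes gen_int_pmf[0] is g(1) and that init_infections holds at
    least len(gen_int_pmf) values, ordered oldest to newest.
    """
    g = np.asarray(gen_int_pmf, dtype=float)
    history = [float(x) for x in init_infections]
    n = g.size
    new_infections = []
    for r in rt:
        # Most recent n infections, newest first, so they line up with g(1), ..., g(n).
        recent = np.array(history[-n:][::-1])
        i_t = float(r * np.dot(g, recent))
        history.append(i_t)
        new_infections.append(i_t)
    return np.array(new_infections)

# With R = 2 and all transmission one day after infection, infections double daily.
print(renewal_infections([2.0, 2.0, 2.0], [1.0], [5.0]))  # [10. 20. 40.]
```

With \(\mathcal{R}(t) = 1\), the sketch keeps infections constant, matching the intuition that each infection replaces itself.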
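The reproduction number \(\mathcal{R}(t)\) is modeled as a random walk in logarithmic space, i.e.:
\[
\log \mathcal{R}(t) = \log \mathcal{R}(t - 1) + \epsilon(t), \quad \epsilon(t) \sim \text{Normal}(0, \sigma)
\]
In the example below, \(\sigma = 0.025\) and the initial value has the prior \(\log \mathcal{R}(0) \sim \text{Normal}(0, 0.2)\).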
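Here is a short NumPy sketch of the log-scale random walk. The step standard deviation of 0.025 mirrors the value used in the `MyRt` class. The function name, seed and series length are illustrative assumptions, and in this sketch the first element is \(\mathcal{R}(0)\) itself.

```python
import numpy as np

def log_rt_random_walk(n, init_log_rt, step_sd, rng):
    """Simulate R(t) with log R(t) = log R(t-1) + eps_t, eps_t ~ Normal(0, step_sd)."""
    steps = rng.normal(loc=0.0, scale=step_sd, size=n - 1)
    # Cumulative sum of the steps, starting from the initial log R.
    log_rt = init_log_rt + np.concatenate([[0.0], np.cumsum(steps)])
    return np.exp(log_rt)

rng = np.random.default_rng(1)
rt = log_rt_random_walk(90, init_log_rt=0.0, step_sd=0.025, rng=rng)
print(rt[:5])
```

Working on the log scale keeps \(\mathcal{R}(t)\) positive and makes the steps multiplicative.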
Data processing
We start by loading the data and inspecting the first five rows.
import polars as pl
from pyrenew import datasets
dat = datasets.load_wastewater()
dat.head(5)
| t | lab_wwtp_unique_id | log_conc | date | lod_sewage | below_lod | daily_hosp_admits |  | pop | forecast_date | hosp_calibration_time | site | ww_pop | inf_per_capita |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| i64 | i64 | f64 | date | f64 | i64 | i64 | i64 | f64 | date | i64 | i64 | f64 | f64 |
| 1 | 1 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 1 | 400000.0 | 0.000663 |
| 1 | 2 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 1 | 400000.0 | 0.000663 |
| 1 | 3 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 2 | 200000.0 | 0.000663 |
| 1 | 4 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 3 | 100000.0 | 0.000663 |
| 1 | 5 | null | 2023-10-30 | null | null | 6 | 6 | 1e6 | 2024-02-05 | 90 | 4 | 50000.0 | 0.000663 |
The data contain one entry per site. Because of how the data were simulated, however, the number of admissions is the same across sites. We therefore keep only the first observation per day.
# Keeping the first observation of each date
dat = dat.group_by("date").first().select(["date", "daily_hosp_admits"])
# Now, sorting by date
dat = dat.sort("date")
# Keeping the first 90 days
dat = dat.head(90)
dat.head(5)
| date | daily_hosp_admits |
|---|---|
| date | i64 |
| 2023-10-30 | 6 |
| 2023-10-31 | 8 |
| 2023-11-01 | 4 |
| 2023-11-02 | 8 |
| 2023-11-03 | 4 |
Let’s plot the daily number of hospital admissions:
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
daily_hosp_admits = dat["daily_hosp_admits"].to_numpy()
dates = dat["date"].to_numpy()
ax = plt.gca()
ax.xaxis.set_major_formatter(mdates.DateFormatter("%Y-%m-%d"))
ax.xaxis.set_major_locator(mdates.DayLocator(interval=7))
ax.set_xlim(dates[0], dates[-1])
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
plt.plot(dates, daily_hosp_admits, "-o")
plt.xlabel("Date")
plt.ylabel("Admissions")
plt.show()
Building the model
First, we load two datasets that we treat as deterministic (fixed) quantities: the generation interval and the infection to hospital admission interval.
gen_int = datasets.load_generation_interval()
inf_hosp_int = datasets.load_infection_admission_interval()
# We only need the probability_mass column of each dataset
gen_int_array = gen_int["probability_mass"].to_numpy()
gen_int = gen_int_array
inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()
# Taking a peek at the first 5 elements of each
gen_int[:5], inf_hosp_int_array[:5]
# Visualizing both quantities side by side
fig, axs = plt.subplots(1, 2)
axs[0].plot(gen_int)
axs[0].set_title("Generation interval")
axs[1].plot(inf_hosp_int_array)
axs[1].set_title("Infection to hospital admission interval")
plt.show()
With these two in hand, we can start building the model. First, we will define the latent hospital admissions:
from pyrenew import latent, deterministic, metaclass, randomvariable
import jax.numpy as jnp
import numpyro.distributions as dist
inf_hosp_int = deterministic.DeterministicPMF(
    name="inf_hosp_int", value=inf_hosp_int_array
)
hosp_rate = randomvariable.DistributionalVariable(
    name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)
latent_hosp = latent.HospitalAdmissions(
    infection_to_admission_interval_rv=inf_hosp_int,
    infection_hospitalization_ratio_rv=hosp_rate,
)
The inf_hosp_int object is a DeterministicPMF that takes the infection to hospital admission interval as input. The hosp_rate object is a DistributionalVariable that wraps a numpyro distribution representing the infection to hospital admission rate. The HospitalAdmissions class is a RandomVariable that takes two random variables as inputs: the infection to admission interval and the infection to hospital admission rate. Now we can define the remaining components:
from pyrenew import model, process, observation, metaclass, transformation
from pyrenew.latent import (
    InfectionInitializationProcess,
    InitializeInfectionsExponentialGrowth,
)
# Infection process
latent_inf = latent.Infections()
# Enough pre-observation time points to cover the longer of the two PMFs
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
# Initial infections: a LogNormal prior on I0, spread over the
# initialization period assuming exponential growth at a fixed rate of 0.05
I0 = InfectionInitializationProcess(
    "I0_initialization",
    randomvariable.DistributionalVariable(
        name="I0",
        distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
    ),
    InitializeInfectionsExponentialGrowth(
        n_initialization_points,
        deterministic.DeterministicVariable(name="rate", value=0.05),
    ),
)
# Generation interval and Rt
gen_int = deterministic.DeterministicPMF(name="gen_int", value=gen_int)
class MyRt(metaclass.RandomVariable):
    def validate(self):
        pass
    def sample(self, n: int, **kwargs) -> tuple:
        # Note: sd_rt is sampled but not used below; the random-walk
        # step standard deviation is fixed at 0.025 in step_rv.
        sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
        # log(Rt) follows a random walk with Normal(0, 0.025) steps;
        # ExpTransform maps it back to the Rt scale.
        rt_rv = randomvariable.TransformedVariable(
            name="log_rt_random_walk",
            base_rv=process.RandomWalk(
                name="log_rt",
                step_rv=randomvariable.DistributionalVariable(
                    name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                ),
            ),
            transforms=transformation.ExpTransform(),
        )
        # Prior on the initial value of log(Rt)
        rt_init_rv = randomvariable.DistributionalVariable(
            name="init_log_rt", distribution=dist.Normal(0, 0.2)
        )
        init_rt = rt_init_rv.sample()
        return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
rtproc = MyRt()
# The observation model
# We place a truncated normal prior on a raw parameter and transform it
# as concentration = raw ** -2, so raw plays the role of
# 1 / sqrt(concentration) for the negative binomial.
nb_conc_rv = randomvariable.TransformedVariable(
    "concentration",
    randomvariable.DistributionalVariable(
        name="concentration_raw",
        distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
    ),
    transformation.PowerTransform(-2),
)
# Now we define the observation process
obs = observation.NegativeBinomialObservation(
    "negbinom_rv",
    concentration_rv=nb_conc_rv,
)
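Because the concentration prior is defined through a transformation, the easiest way to understand its implied distribution is by simulation. The sketch below rejection-samples the `TruncatedNormal(0, 1, low=0.01)` base distribution in NumPy and applies \(x \mapsto x^{-2}\), which is what `PowerTransform(-2)` computes. Under the mean/concentration parameterization, \(1/\sqrt{\phi}\) is the coefficient of variation of the gamma mixing distribution, so values of `concentration_raw` near 0 correspond to nearly Poisson noise. `sample_truncated_normal` is a hypothetical helper, not numpyro's sampler.

```python
import numpy as np

def sample_truncated_normal(loc, scale, low, size, rng):
    """Rejection-sample Normal(loc, scale) restricted to values >= low."""
    kept = np.empty(0)
    while kept.size < size:
        draws = rng.normal(loc, scale, size=size)
        kept = np.concatenate([kept, draws[draws >= low]])
    return kept[:size]

rng = np.random.default_rng(2)
raw = sample_truncated_normal(0.0, 1.0, low=0.01, size=100_000, rng=rng)
concentration = raw ** -2.0  # the x -> x**(-2) map applied by PowerTransform(-2)
# 5%, 50% and 95% quantiles of the implied concentration prior
print(np.quantile(concentration, [0.05, 0.5, 0.95]))
```

The implied prior on the concentration is wide and right-skewed. The lower bound of 0.01 on the raw parameter caps the concentration at \(10^4\).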
Notice that all the components are RandomVariable instances. We can now build the model:
hosp_model = model.HospitalAdmissionsModel(
    latent_infections_rv=latent_inf,
    latent_hosp_admissions_rv=latent_hosp,
    I0_rv=I0,
    gen_int_rv=gen_int,
    Rt_process_rv=rtproc,
    hosp_admission_obs_process_rv=obs,
)
Let’s simulate from the prior predictive distribution to check that the model is working:
timeframe = 120
# Seed numpyro's sampler so the prior predictive draw is reproducible
with numpyro.handlers.seed(rng_seed=223):
    simulated_data = hosp_model.sample(n_datapoints=timeframe)
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(simulated_data.Rt)
axs[0].set_ylabel("Simulated Rt")
# Admissions plot
axs[1].plot(simulated_data.observed_hosp_admissions, "-o")
axs[1].set_ylabel("Simulated Admissions")
fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()
Fitting the model
We now fit the model, not to these simulated data, but to the dataset we loaded above. We use the run method of the Model object:
import jax
hosp_model.run(
    num_samples=1000,
    num_warmup=1000,
    data_observed_hosp_admissions=daily_hosp_admits,
    rng_key=jax.random.key(54),
    mcmc_args=dict(progress_bar=False, num_chains=2),
)
We can use the Model object’s plot_posterior method to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e., the mean of the negative binomial observation process) [1]:
out = hosp_model.plot_posterior(
    var="latent_hospital_admissions",
    ylab="Hospital Admissions",
    obs_signal=daily_hosp_admits.astype(float),
)
Results exploration and MCMC diagnostics
To explore further, we can use ArviZ to visualize the results. Let’s start by loading the module and converting the fitted model to an ArviZ InferenceData object:
import arviz as az
idata = az.from_numpyro(hosp_model.mcmc)
We obtain the summary of model diagnostics and print the diagnostics for latent_hospital_admissions[1]:
diagnostic_stats_summary = az.summary(
    idata.posterior,
    kind="diagnostics",
)
print(diagnostic_stats_summary.loc["latent_hospital_admissions[1]"])
mcse_mean 0.013
mcse_sd 0.009
ess_bulk 2118.000
ess_tail 1652.000
r_hat 1.000
Name: latent_hospital_admissions[1], dtype: float64
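The `r_hat` column compares variability between and within chains. Values close to 1 suggest that the chains agree. To illustrate the idea, here is the classic (non-rank-normalized) split-\(\hat{R}\) in NumPy. ArviZ's `az.summary` reports a rank-normalized variant by default, so its numbers will not match this sketch exactly. `split_rhat` is a hypothetical helper.

```python
import numpy as np

def split_rhat(draws):
    """Classic split-R-hat for draws with shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split every chain into two halves, doubling the number of chains.
    chains = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()        # W: mean within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)     # B: between-chain variance
    var_plus = (n - 1) / n * within + between / n     # pooled variance estimate
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(3)
mixed = rng.normal(size=(2, 1000))                    # two chains, same target
stuck = mixed + np.array([[0.0], [3.0]])              # second chain offset by 3
print(split_rhat(mixed), split_rhat(stuck))           # near 1, and well above 1
```

Splitting each chain also flags a single chain that drifts, because its two halves then disagree.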
Below we plot 90% and 50% highest density intervals for latent hospital admissions using plot_hdi:
x_data = idata.posterior["latent_hospital_admissions_dim_0"]
y_data = idata.posterior["latent_hospital_admissions"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)
# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Hospital Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()
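For a single time point, a highest density interval can be approximated from posterior draws. It is the shortest interval that contains the requested fraction of the sorted samples. This sketch assumes a unimodal posterior. `hdi_1d` is a hypothetical helper; in practice, ArviZ's `az.hdi` does this work.

```python
import numpy as np

def hdi_1d(samples, prob):
    """Shortest interval containing a fraction `prob` of the samples."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    k = int(np.floor(prob * n))          # index span of each candidate interval
    widths = x[k:] - x[: n - k]          # width of every candidate interval
    start = int(np.argmin(widths))       # the narrowest one wins
    return float(x[start]), float(x[start + k])

rng = np.random.default_rng(5)
samples = rng.exponential(scale=1.0, size=100_000)
lower, upper = hdi_1d(samples, 0.9)
# For a right-skewed Exponential(1), the 90% HDI starts near 0 and ends near -log(0.1) ≈ 2.30.
print(lower, upper)
```

For skewed posteriors such as admission counts, the HDI shifts toward the region of highest density instead of cutting equal mass from each tail.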
We can also look at credible intervals for the posterior distribution of latent infections:
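# Shift the time index so that 0 is the first observed day
# (all_latent_infections also covers the initialization period)
x_data = (
    idata.posterior["all_latent_infections_dim_0"] - n_initialization_points
)
y_data = idata.posterior["all_latent_infections"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)
# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10)
plt.show()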
Predictive checks and forecasting
We can use the Model’s posterior_predictive and prior_predictive methods to generate posterior and prior predictive samples for observed admissions.
import arviz as az
idata = az.from_numpyro(
    hosp_model.mcmc,
    posterior_predictive=hosp_model.posterior_predictive(
        n_datapoints=len(daily_hosp_admits)
    ),
    prior=hosp_model.prior_predictive(
        n_datapoints=len(daily_hosp_admits),
        numpyro_predictive_args={"num_samples": 1000},
    ),
)
We use ArviZ’s plot_lm function to plot the posterior predictive distribution against the observed data:
fig, ax = plt.subplots()
az.plot_lm(
    "negbinom_rv",
    idata=idata,
    kind_pp="hdi",
    y_kwargs={"color": "black"},
    y_hat_fill_kwargs={"color": "C0"},
    axes=ax,
)
ax.set_title("Posterior Predictive Plot")
ax.set_ylabel("Hospital Admissions")
ax.set_xlabel("Days")
plt.show()
By setting n_datapoints beyond the number of observed days, we can use the posterior predictive distribution to forecast. Here we forecast 28 days ahead:
n_forecast_points = 28
idata = az.from_numpyro(
    hosp_model.mcmc,
    posterior_predictive=hosp_model.posterior_predictive(
        n_datapoints=len(daily_hosp_admits) + n_forecast_points,
    ),
    prior=hosp_model.prior_predictive(
        n_datapoints=len(daily_hosp_admits),
        numpyro_predictive_args={"num_samples": 1000},
    ),
)
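Forecast uncertainty grows with the horizon partly because the log-scale random walk for \(\mathcal{R}(t)\) keeps accumulating independent steps after the last observed day. The NumPy sketch below illustrates this. It assumes a fixed step standard deviation of 0.025 (as in `MyRt`) and uses hypothetical posterior draws of the final \(\log \mathcal{R}\). It is a conceptual illustration, not how `posterior_predictive` is implemented, and `extend_log_rt` is a hypothetical helper.

```python
import numpy as np

def extend_log_rt(last_log_rt, n_ahead, step_sd, rng):
    """Continue a log-scale random walk n_ahead steps past the last fitted day.

    last_log_rt holds one (hypothetical) posterior draw of log R per entry.
    Returns future R values with shape (n_draws, n_ahead).
    """
    last_log_rt = np.asarray(last_log_rt, dtype=float)
    steps = rng.normal(0.0, step_sd, size=(last_log_rt.size, n_ahead))
    return np.exp(last_log_rt[:, None] + np.cumsum(steps, axis=1))

rng = np.random.default_rng(4)
last_log_rt = np.zeros(20_000)  # illustrative draws that all end at R = 1
future_rt = extend_log_rt(last_log_rt, n_ahead=28, step_sd=0.025, rng=rng)
log_sd = np.log(future_rt).std(axis=0)
# The spread of log R grows like step_sd * sqrt(h) with horizon h.
print(log_sd[0], log_sd[-1])  # roughly 0.025 and 0.025 * sqrt(28) ≈ 0.13
```

In the fitted model, posterior uncertainty about the final \(\log \mathcal{R}\) adds to this spread, and the renewal process then propagates it into admissions.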
Below we plot the prior predictive distribution using equal-tailed Bayesian credible intervals (ETIs):
def compute_eti(dataset, eti_prob):
    """Equal-tailed interval bounds over the chain and draw dimensions.
    Returns the lower and upper quantiles with the time dimension first,
    the layout az.plot_hdi expects for its hdi_data argument.
    """
    eti_bdry = dataset.quantile(
        ((1 - eti_prob) / 2, 1 / 2 + eti_prob / 2), dim=("chain", "draw")
    )
    return eti_bdry.T
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    idata.prior_predictive["negbinom_rv_dim_0"],
    hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.9),
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)
az.plot_hdi(
    idata.prior_predictive["negbinom_rv_dim_0"],
    hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.5),
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)
plt.scatter(
    idata.observed_data["negbinom_rv_dim_0"],
    idata.observed_data["negbinom_rv"],
    color="black",
)
axes.set_title("Prior Predictive Admissions", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Observed Admissions", fontsize=10)
plt.yscale("log")
plt.show()
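An equal-tailed interval cuts the same probability from each tail, so it comes straight from quantiles, as in `compute_eti` above. A one-dimensional NumPy version (`eti_1d` is a hypothetical helper) shows how it behaves for a right-skewed distribution such as an exponential, where it differs noticeably from the shortest (HDI) interval. Quantiles commute with monotone transformations, so ETIs are unchanged by the log-scale y-axis used in the plot above. That is one reason they can be convenient here.

```python
import numpy as np

def eti_1d(samples, prob):
    """Equal-tailed interval: drop (1 - prob) / 2 of the mass from each tail."""
    tail = (1.0 - prob) / 2.0
    lower, upper = np.quantile(np.asarray(samples, dtype=float), [tail, 1.0 - tail])
    return float(lower), float(upper)

rng = np.random.default_rng(6)
samples = rng.exponential(scale=1.0, size=100_000)
lower, upper = eti_1d(samples, 0.9)
# For Exponential(1), the exact 90% ETI is (-log 0.95, -log 0.05), roughly (0.051, 3.0).
print(lower, upper)
```

Unlike an HDI, this interval never starts at the mode of a skewed distribution. It always leaves 5% of the mass below its lower bound.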
And now we plot the posterior predictive distributions with a 28-day-ahead forecast:
x_data = idata.posterior_predictive["negbinom_rv_dim_0"]
y_data = idata.posterior_predictive["negbinom_rv"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    hdi_data=compute_eti(y_data, 0.9),
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)
az.plot_hdi(
    x_data,
    hdi_data=compute_eti(y_data, 0.5),
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)
# Add median of the posterior to the figure
median_ts = y_data.median(dim=["chain", "draw"])
plt.plot(
    x_data,
    median_ts,
    color="C0",
    label="Median",
)
plt.scatter(
    idata.observed_data["negbinom_rv_dim_0"],
    idata.observed_data["negbinom_rv"],
    color="black",
)
axes.legend()
axes.set_title(
    "Posterior Predictive Admissions, including a forecast", fontsize=10
)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Hospital Admissions", fontsize=10)
plt.show()