Fitting a basic renewal model
This document will show the steps to build a simple renewal model featuring a latent infection process, a random walk Rt process, and an observation process for the reported infections.
We start by loading the needed components to build a basic renewal model:
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from pyrenew.process import RandomWalk
from pyrenew.latent import (
Infections,
InfectionInitializationProcess,
InitializeInfectionsZeroPad,
)
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam
By default, XLA (which is used by JAX for compilation) considers all CPU cores as one device. Depending on your system’s configuration, we recommend using numpyro’s set_host_device_count() function to set the number of devices available for parallel computing. Here, we set the device count to be 2.
numpyro.set_host_device_count(2)
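numpyro's set_host_device_count() works by editing the XLA_FLAGS environment variable, and it only takes effect if it runs before JAX initializes its CPU backend, so it belongs at the top of a script. The sketch below builds the same kind of flag by hand. It assumes the flag name --xla_force_host_platform_device_count. The helper with_host_device_count is hypothetical and not part of numpyro or pyrenew, and capping the count at the number of cores is our own choice.

```python
import os
import re


def with_host_device_count(xla_flags: str, n_devices: int) -> str:
    """Hypothetical helper: return XLA_FLAGS with the host device count set."""
    # Remove any existing device-count flag so the new value is not shadowed.
    cleaned = re.sub(r"--xla_force_host_platform_device_count=\S+", "", xla_flags)
    flag = f"--xla_force_host_platform_device_count={n_devices}"
    return " ".join(cleaned.split() + [flag])


# Our choice: one device per MCMC chain, but never more devices than CPU cores.
num_chains = 2
n_devices = min(num_chains, os.cpu_count() or 1)

# Must run before JAX initializes its backend for the flag to take effect.
os.environ["XLA_FLAGS"] = with_host_device_count(
    os.environ.get("XLA_FLAGS", ""), n_devices
)
print(os.environ["XLA_FLAGS"])
```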
Architecture of RtInfectionsRenewalModel
pyrenew leverages numpyro's flexibility to build models via composition. As a principle, most objects in pyrenew can be treated as random variables we can sample. At the top level, pyrenew has two metaclasses from which most objects inherit: RandomVariable and Model. From them, the following five sub-modules arise:
The process sub-module,
The deterministic sub-module,
The observation sub-module,
The latent sub-module, and
The model sub-module.
The first four are collections of classes that inherit from RandomVariable, and the last is a collection of classes that inherit from Model. The following diagram shows a detailed view of how metaclasses, modules, and classes interact to create the RtInfectionsRenewalModel instantiated in the next section:
The pyrenew package models the real-time reproductive number \(\mathcal{R}(t)\), the average number of secondary infections caused by an infected individual, as a renewal process model. Our basic renewal process model defines five components:
generation interval, the distribution of times between an infection and the secondary infections it causes
initial infections, occurring prior to time \(t = 0\)
\(\mathcal{R}(t)\), the time-varying reproductive number,
latent infections, i.e., those infections which are known to exist but are not observed (or not observable), and
observed infections, a subset of underlying true infections that are reported, perhaps via hospital admissions, physician’s office visits, or routine biosurveillance.
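For intuition, the latent infection process in a discrete-time renewal model is commonly written as \(I(t) = \mathcal{R}(t) \sum_{\tau=1}^{T} I(t-\tau)\, g(\tau)\), where \(g\) is the generation interval PMF. The NumPy sketch below iterates that equation using the same PMF as the code further down. It is a simplified illustration built on stated assumptions: a constant \(\mathcal{R}(t)\), a hand-picked seed history, and a hypothetical function name, renewal_infections. It is not pyrenew's internal implementation.

```python
import numpy as np


def renewal_infections(rt, gen_int_pmf, init_infections):
    """Hypothetical sketch: iterate I(t) = R(t) * sum_tau I(t - tau) * g(tau).

    gen_int_pmf[k] is g(k + 1), i.e. the first entry is a lag of one step.
    init_infections must have length >= len(gen_int_pmf).
    Simplified illustration; not pyrenew's implementation.
    """
    g = np.asarray(gen_int_pmf, dtype=float)
    infections = list(np.asarray(init_infections, dtype=float))
    for r in rt:
        # Most recent T infections, newest first, so lag 1 lines up with g[0].
        recent = np.array(infections[-len(g):][::-1])
        infections.append(r * np.dot(recent, g))
    # Return only the newly generated infections.
    return np.array(infections[len(init_infections):])


pmf = np.array([0.4, 0.3, 0.2, 0.1])
# Zero history except for 10 infections right before the renewal starts.
init = np.array([0.0, 0.0, 0.0, 10.0])
new_inf = renewal_infections(np.full(5, 1.2), pmf, init)
print(new_inf)
```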
To set up these five components within the renewal modeling framework, we specify each one as follows:
generation interval: not estimated in this example, but passed as a deterministic instance of RandomVariable (DeterministicPMF).
initial infections: an instance of the InfectionInitializationProcess class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution whose logarithm has mean 2.5 and standard deviation 1 (dist.LogNormal(2.5, 1)). By specifying InitializeInfectionsZeroPad, the latent infections before this time are assumed to be 0.
\(\mathcal{R}(t)\): a random walk on the log scale, with an inferred initial value (a Normal(0, 0.2) prior on \(\log \mathcal{R}\)) and a fixed Normal(0, 0.025) step-size distribution. For this, we construct a custom RandomVariable, MyRt.
latent infections: an instance of the Infections class with default values, and
observed infections: an instance of the PoissonObservation class with default values.
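numpyro's dist.LogNormal(loc, scale) is parameterized on the log scale. Here loc and scale are the mean and standard deviation of \(\log I_0\), not of \(I_0\) itself. The check below uses the standard log-normal formulas to translate the I0 prior, dist.LogNormal(2.5, 1), onto the natural scale. It also mimics what we understand zero-pad initialization to do: put the initial value in the last slot of an \(n\)-length vector and fill the earlier slots with zeros, consistent with the three leading zeros in the simulated latent infections. The helper zero_pad_init is hypothetical and stands in for the pyrenew class.

```python
import math

loc, scale = 2.5, 1.0  # parameters of dist.LogNormal(2.5, 1) for I0

# Standard log-normal summaries on the natural (infection-count) scale.
median_i0 = math.exp(loc)
mean_i0 = math.exp(loc + scale**2 / 2)
sd_i0 = math.sqrt((math.exp(scale**2) - 1) * math.exp(2 * loc + scale**2))


def zero_pad_init(i0: float, n_timepoints: int) -> list:
    """Hypothetical stand-in for zero-pad initialization: zeros, then I0 last."""
    return [0.0] * (n_timepoints - 1) + [i0]


print(f"median={median_i0:.2f}, mean={mean_i0:.2f}, sd={sd_i0:.2f}")
print(zero_pad_init(median_i0, 4))
```

Because the distribution is right-skewed, its mean is larger than its median. A prior draw of a few dozen initial infections is therefore not unusual.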
# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
)
# (3) The random walk on log Rt, with an inferred initial value and a fixed
# step-size distribution. Here, we construct a custom RandomVariable.
class MyRt(RandomVariable):
    def validate(self):
        pass
    def sample(self, n: int, **kwargs) -> tuple:
        # Note: Rt_random_walk_sd is sampled but not used below; the step
        # distribution of the random walk is fixed at Normal(0, 0.025).
        sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
        # Random walk on log Rt, exponentiated back to the natural scale.
        rt_rv = TransformedVariable(
            name="log_rt_random_walk",
            base_rv=RandomWalk(
                name="log_rt",
                step_rv=DistributionalVariable(
                    name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                ),
            ),
            transforms=t.ExpTransform(),
        )
        # Prior on the initial value of log Rt.
        rt_init_rv = DistributionalVariable(
            name="init_log_rt", distribution=dist.Normal(0, 0.2)
        )
        init_rt = rt_init_rv.sample()
        return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
rt_proc = MyRt()
# (4) Latent infection process (which will use 1 and 2)
latent_infections = Infections()
# (5) The observed infections process (with mean at the latent infections)
observation_process = PoissonObservation("poisson_rv")
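The MyRt process above amounts to a Gaussian random walk on \(\log \mathcal{R}(t)\) that is then exponentiated back to the natural scale. It starts from an initial value drawn from Normal(0, 0.2) and adds independent Normal(0, 0.025) increments. The NumPy sketch below reproduces that generative story outside numpyro so a prior draw is easy to inspect. The function name and the use of numpy.random.default_rng are our own choices. We also assume the first of the n values is the initial value itself, which may differ from how pyrenew's RandomWalk indexes its output.

```python
import numpy as np


def sample_log_rw_rt(n, rng, init_sd=0.2, step_sd=0.025):
    """Hypothetical sketch: one R(t) path from an exponentiated Gaussian random walk.

    Assumption: the first element of the path is the initial value itself.
    """
    init_log_rt = rng.normal(0.0, init_sd)        # like init_log_rt ~ Normal(0, 0.2)
    steps = rng.normal(0.0, step_sd, size=n - 1)  # like rw_step_rv ~ Normal(0, 0.025)
    log_rt = init_log_rt + np.concatenate([[0.0], np.cumsum(steps)])
    return np.exp(log_rt)                         # back-transform keeps R(t) > 0


rng = np.random.default_rng(0)
rt_path = sample_log_rw_rt(40, rng)
print(rt_path.round(3))
```

Because the steps are small, consecutive values change by only a few percent. The random walk prior therefore favors smooth \(\mathcal{R}(t)\) trajectories like the simulated one.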
With these five pieces, we can build the basic renewal model as an instance of the RtInfectionsRenewalModel class:
model1 = RtInfectionsRenewalModel(
gen_int_rv=gen_int,
I0_rv=I0,
Rt_process_rv=rt_proc,
latent_infections_rv=latent_infections,
infection_obs_process_rv=observation_process,
)
The following diagram summarizes how the modules interact via composition. Notably, gen_int, I0, rt_proc, latent_infections, and observation_process are instances of RandomVariable, which means these can be easily replaced to generate a different instance of the RtInfectionsRenewalModel class:
Using numpyro, we can simulate data with the sample() method of RtInfectionsRenewalModel. It returns an RtInfectionsRenewalSample object with the fields Rt, latent_infections, and observed_infections, which we store as sim_data. Note that latent_infections also covers the initialization period, so it is longer than the other two sequences:
with numpyro.handlers.seed(rng_seed=53):
    sim_data = model1.sample(n_datapoints=40)
sim_data
RtInfectionsRenewalSample(Rt=[1.0193933 1.0239283 1.0201008 1.0572274 1.0299444 1.0925564 1.108174
1.0743027 1.0882349 1.1093383 1.1094797 1.1147494 1.1355134 1.1536688
1.1749985 1.2038172 1.243159 1.211333 1.2353402 1.1836537 1.2103558
1.2023542 1.2040431 1.199202 1.1868212 1.1774873 1.176468 1.2211965
1.2915479 1.2288901 1.263633 1.2350382 1.2147664 1.1920787 1.2335418
1.2487649 1.2792699 1.329618 1.3142097 1.3469969], latent_infections=[ 0. 0. 0. 28.330536 11.551984 13.433893
14.796828 15.956045 15.102532 16.531013 17.524725 17.817638
18.718653 19.957882 20.985529 22.191395 23.886274 25.847744
28.249094 31.359867 35.525574 38.58433 43.469505 46.4042
51.89032 56.786854 62.501434 68.4206 74.372246 80.603775
87.6321 98.85645 115.45172 124.639465 142.82314 157.46376
172.86795 187.65094 213.02754 239.5459 274.4592 323.1212
369.02765 435.61053 ], observed_infections=[ 8 12 15 17 12 15 20 22 20 22 13 28 24 32 27 27 36 35
53 35 71 61 60 68 62 83 101 103 122 124 136 137 173 190 231 247
258 353 379 455])
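In this model, the observed_infections field comes from the PoissonObservation process. That process treats each observed count as a Poisson draw whose mean is the corresponding latent infection count. The NumPy sketch below reproduces that observation step using made-up latent values, which are placeholders rather than model output. It does not show how pyrenew aligns the observed time points with the padded latent series.

```python
import numpy as np

# Made-up latent infection counts, for illustration only (not model output).
latent = np.array([12.0, 15.5, 20.1, 26.3, 34.0])

rng = np.random.default_rng(42)
# Each observed count ~ Poisson(mean = latent infections at that time point).
observed = rng.poisson(lam=latent)
print(observed)

# Averaging many replicate draws recovers the latent means.
replicates = rng.poisson(lam=latent, size=(20_000, latent.size))
print(replicates.mean(axis=0).round(2))
```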
To understand what has been accomplished here, visualize an \(\mathcal{R}(t)\) sample path (left panel) and the observed infections over time (right panel):
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")
fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()
Figure 1: Rt and Infections
To fit the model, we can use the run() method of the RtInfectionsRenewalModel class, which is inherited from the metaclass Model. Calling model1.run() sets up and runs a numpyro MCMC sampler on model1 and stores it in model1.mcmc. Each chain first runs 2000 warm-up iterations, which tune the parameters of the MCMC algorithm to improve sampling efficiency. It then draws 1000 samples from the posterior distribution of the model parameters. These samples are used to estimate the posterior distributions and compute summary statistics. The observed data are passed to the model through the observed_infections field of the sim_data object generated previously. mcmc_args passes additional arguments to the MCMC algorithm. Here, it turns off the progress bar and runs two chains, one for each of the two devices configured earlier.
import jax
model1.run(
num_warmup=2000,
num_samples=1000,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
Now, let’s investigate the output, particularly the posterior distribution of the \(\mathcal{R}(t)\) estimates:
import arviz as az
# Create arviz inference data object
idata = az.from_numpyro(
posterior=model1.mcmc,
)
# Extract Rt signal samples across chains
rt = az.extract(idata.posterior["Rt"], num_samples=100)["Rt"].values
# Plot Rt signal
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(
np.arange(rt.shape[0]),
rt,
color="skyblue",
alpha=0.10,
)
ax.plot([], [], color="skyblue", alpha=0.05, label="Rt Posterior Samples")
ax.plot(
np.arange(rt.shape[0]),
rt.mean(axis=1),
color="black",
linewidth=2.0,
linestyle="--",
label="Sample Mean",
)
ax.legend(loc="best")
ax.set_ylabel(r"$\mathscr{R}_t$ Signal", fontsize=20)
ax.set_xlabel("Days", fontsize=20)
plt.show()
Figure 2: Rt posterior distribution
We can use the get_samples() method of the fitted MCMC object, model1.mcmc, to extract samples from the model:
Rt_samp = model1.mcmc.get_samples()["Rt"]
latent_infection_samp = model1.mcmc.get_samples()["all_latent_infections"]
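By default (group_by_chain=False), numpyro's MCMC.get_samples() stacks the chains together. With 2 chains of 1000 draws each, Rt_samp should therefore have 2000 rows and one column per time point. The NumPy sketch below builds a placeholder array with that layout from random numbers, not model output. It then shows how per-time-point posterior summaries can be computed from such samples while skipping the NaN padding at the start of \(\mathcal{R}(t)\). The padding length of 4 is taken from the diagnostics code below.

```python
import numpy as np

num_chains, num_samples, n_timepoints, n_pad = 2, 1000, 40, 4

# Placeholder with the (chain, draw, time) layout of the Rt posterior; NOT model output.
rng = np.random.default_rng(1)
by_chain = np.exp(rng.normal(0.1, 0.05, size=(num_chains, num_samples, n_timepoints)))
by_chain[..., :n_pad] = np.nan  # mimic the NaN padding at the start of Rt

# Stacking chains as get_samples() does: shape (chain * draw, time).
flat = by_chain.reshape(num_chains * num_samples, n_timepoints)

# Per-time-point posterior median and 90% equal-tailed interval, skipping padding.
rt = flat[:, n_pad:]
median = np.median(rt, axis=0)
lower, upper = np.quantile(rt, [0.05, 0.95], axis=0)
print(flat.shape, median.shape)
```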
We can also convert the fitted model to an ArviZ InferenceData object and use the ArviZ package to extract samples, calculate statistics, and create model diagnostics and visualizations:
import arviz as az
idata = az.from_numpyro(model1.mcmc)
We can then use the InferenceData object to compute model-fit diagnostics. Here, we show the diagnostic summary for the inferred effective reproduction number \(\mathcal{R}(t)\), skipping the first four entries, which are NaN padding.
diagnostic_stats_summary = az.summary(
idata.posterior["Rt"][::, ::, 4:], # ignore nan padding
kind="diagnostics",
)
print(diagnostic_stats_summary)
mcse_mean mcse_sd ess_bulk ess_tail r_hat
Rt[4] 0.001 0.001 2423.0 1563.0 1.00
Rt[5] 0.001 0.001 2704.0 1537.0 1.00
Rt[6] 0.001 0.001 3346.0 1412.0 1.00
Rt[7] 0.001 0.000 3909.0 1632.0 1.00
Rt[8] 0.001 0.000 3603.0 1433.0 1.00
Rt[9] 0.001 0.000 3790.0 1617.0 1.00
Rt[10] 0.001 0.000 3766.0 1604.0 1.01
Rt[11] 0.001 0.000 4517.0 1239.0 1.00
Rt[12] 0.001 0.000 4372.0 1518.0 1.00
Rt[13] 0.000 0.000 4602.0 1561.0 1.00
Rt[14] 0.001 0.000 4300.0 1648.0 1.00
Rt[15] 0.001 0.000 3674.0 1585.0 1.00
Rt[16] 0.001 0.000 3474.0 1728.0 1.00
Rt[17] 0.001 0.000 3655.0 1663.0 1.00
Rt[18] 0.001 0.000 4183.0 1697.0 1.00
Rt[19] 0.001 0.000 3889.0 1792.0 1.00
Rt[20] 0.000 0.000 4536.0 1679.0 1.00
Rt[21] 0.001 0.000 4182.0 1720.0 1.00
Rt[22] 0.001 0.000 3714.0 1458.0 1.00
Rt[23] 0.001 0.000 3561.0 1522.0 1.00
Rt[24] 0.001 0.000 3262.0 1714.0 1.00
Rt[25] 0.001 0.000 3431.0 1902.0 1.00
Rt[26] 0.001 0.000 3602.0 1679.0 1.00
Rt[27] 0.000 0.000 3929.0 1340.0 1.00
Rt[28] 0.000 0.000 3446.0 1516.0 1.00
Rt[29] 0.000 0.000 3457.0 1618.0 1.00
Rt[30] 0.000 0.000 3614.0 1578.0 1.00
Rt[31] 0.000 0.000 3678.0 1885.0 1.00
Rt[32] 0.000 0.000 3680.0 1658.0 1.00
Rt[33] 0.001 0.000 3566.0 1233.0 1.00
Rt[34] 0.000 0.000 5470.0 1513.0 1.00
Rt[35] 0.000 0.000 3617.0 1695.0 1.00
Rt[36] 0.000 0.000 3082.0 1733.0 1.00
Rt[37] 0.001 0.000 2896.0 1939.0 1.00
Rt[38] 0.001 0.000 2873.0 1716.0 1.00
Rt[39] 0.001 0.001 3087.0 1611.0 1.00
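The r_hat column compares between-chain and within-chain variability. Values near 1.00, as in every row here, indicate that the chains agree. For intuition, the NumPy sketch below computes the classic split-\(\hat{R}\) for a single quantity. az.summary reports a rank-normalized variant by default, so this simplified version will not reproduce the table exactly. The function name split_rhat is hypothetical.

```python
import numpy as np


def split_rhat(draws):
    """Hypothetical sketch of classic split-R-hat; draws has shape (chains, draws)."""
    draws = np.asarray(draws, dtype=float)
    n_half = draws.shape[1] // 2
    # Split each chain into two halves, doubling the number of chains.
    halves = np.concatenate([draws[:, :n_half], draws[:, n_half : 2 * n_half]], axis=0)
    n = halves.shape[1]
    chain_means = halves.mean(axis=1)
    within = halves.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)       # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n
    return np.sqrt(var_hat / within)


rng = np.random.default_rng(7)
mixed = rng.normal(size=(2, 1000))           # two chains sampling the same target
stuck = mixed + np.array([[0.0], [2.0]])     # second chain shifted by 2
print(round(split_rhat(mixed), 3), round(split_rhat(stuck), 3))
```

Chains that agree give values close to 1. A chain stuck in a different region inflates the between-chain term and pushes \(\hat{R}\) well above 1.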
Below, we use plot_trace to inspect the traces of the first 10 inferred \(\mathcal{R}(t)\) values after the padding (indices 4 through 13).
plt.rcParams["figure.constrained_layout.use"] = True
az.plot_trace(
idata.posterior,
var_names=["Rt"],
coords={"Rt_dim_0": np.arange(4, 14)},
compact=False,
)
plt.show()
Figure 3: Trace plot of Rt posterior distribution
We inspect the posterior distribution of \(\mathcal{R}(t)\) by plotting the 90% and 50% highest density intervals:
x_data = idata.posterior["Rt_dim_0"][4:]
y_data = idata.posterior["Rt"][::, ::, 4:]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color="C0",
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color="C0",
fill_kwargs={"alpha": 0.6},
ax=axes,
)
# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Effective Reproduction Number", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("$\\mathcal{R}(t)$", fontsize=10)
plt.show()
Figure 4: Highest density intervals for the effective reproduction number
We plot the same intervals for the latent infections:
x_data = idata.posterior["all_latent_infections_dim_0"]
y_data = idata.posterior["all_latent_infections"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
# plot the posterior median
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10)
plt.show()
Figure 5: Highest density intervals for latent infections
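ArviZ computes the intervals in Figures 4 and 5 separately at each time point from that time point's posterior draws. For a single unimodal sample, the highest density interval can be approximated as the narrowest window that contains a given fraction of the sorted draws. The sketch below implements that idea in NumPy, using skewed placeholder draws rather than model output. The function name hdi_from_draws is hypothetical. ArviZ's az.hdi covers more cases, such as multimodal distributions.

```python
import numpy as np


def hdi_from_draws(draws, prob=0.9):
    """Hypothetical sketch: narrowest interval holding `prob` of the draws (unimodal)."""
    x = np.sort(np.asarray(draws, dtype=float))
    n = x.size
    k = int(np.floor(prob * n))      # number of gaps spanned by each window
    widths = x[k:] - x[: n - k]      # width of every candidate window
    i = int(np.argmin(widths))       # start of the narrowest window
    return x[i], x[i + k]


rng = np.random.default_rng(3)
draws = rng.lognormal(mean=0.1, sigma=0.2, size=4000)  # skewed placeholder draws
lo90, hi90 = hdi_from_draws(draws, 0.9)
lo50, hi50 = hdi_from_draws(draws, 0.5)
print((lo90, hi90), (lo50, hi50))
```

For a skewed posterior, the HDI shifts toward the mode and is never wider than a central interval covering the same number of draws.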