Fitting a basic renewal model
This document will show the steps to build a simple renewal model featuring a latent infection process, a random walk Rt process, and an observation process for the reported infections.
We start by loading the needed components to build a basic renewal model:
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from pyrenew.process import RandomWalk
from pyrenew.latent import (
    Infections,
    InfectionInitializationProcess,
    InitializeInfectionsZeroPad,
)
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam
By default, XLA (which JAX uses for compilation) treats all CPU cores as a single device. Depending on your system’s configuration, we recommend using numpyro’s set_host_device_count() function to set the number of devices available for parallel computing. It only takes effect if called at the start of the program, before JAX runs any computation. Here, we set the device count to 2, one per MCMC chain used later in this document.
numpyro.set_host_device_count(2)
Architecture of RtInfectionsRenewalModel
pyrenew leverages numpyro’s flexibility to build models via composition. As a principle, most objects in pyrenew can be treated as random variables we can sample. At the top level, pyrenew has two metaclasses from which most objects inherit: RandomVariable and Model. From them, the following five sub-modules arise:
The process sub-module,
The deterministic sub-module,
The observation sub-module,
The latent sub-module, and
The model sub-module.
The first four are collections of instances of RandomVariable, and the last is a collection of instances of Model. The following diagram shows a detailed view of how metaclasses, modules, and classes interact to create the RtInfectionsRenewalModel instantiated in the next section:
The pyrenew package models the real-time reproductive number \(\mathcal{R}(t)\), the average number of secondary infections caused by an infected individual, as a renewal process model. Our basic renewal process model defines five components:
generation interval, the time between the infection of an individual and the infections they cause,
initial infections, occurring prior to time \(t = 0\),
\(\mathcal{R}(t)\), the time-varying reproductive number,
latent infections, i.e., infections that are known to exist but are not observed (or not observable), and
observed infections, a subset of underlying true infections that are reported, perhaps via hospital admissions, physician’s office visits, or routine biosurveillance.
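These components are tied together by the renewal equation: latent infections at time \(t\) are \(\mathcal{R}(t)\) times a generation-interval-weighted sum of past infections, \(I(t) = \mathcal{R}(t) \sum_{\tau=1}^{T_g} g(\tau)\, I(t - \tau)\). The following is a minimal NumPy sketch of that recursion, not pyrenew’s implementation. It assumes the generation interval PMF entry `gen_int[k]` is the probability of lag `k + 1` and that at least `len(gen_int)` initial infections precede \(t = 0\); the helper name `renewal_infections` is hypothetical.

```python
import numpy as np


def renewal_infections(rt, gen_int, init_infections):
    """Hypothetical helper: iterate I(t) = R(t) * sum_tau g(tau) * I(t - tau).

    gen_int[k] is assumed to be the probability of lag k + 1, and
    init_infections must hold at least len(gen_int) values preceding t = 0.
    """
    rt = np.asarray(rt, dtype=float)
    g = np.asarray(gen_int, dtype=float)
    history = list(np.asarray(init_infections, dtype=float))
    new_infections = []
    for r in rt:
        # Most recent infections first, so recent[k] is I(t - (k + 1)).
        recent = np.array(history[::-1][: g.size])
        i_t = r * np.dot(g, recent)
        new_infections.append(i_t)
        history.append(i_t)
    return np.array(new_infections)


gen_int = np.array([0.4, 0.3, 0.2, 0.1])
# Zero-padded initial infections, similar in spirit to InitializeInfectionsZeroPad.
init = np.array([0.0, 0.0, 0.0, 10.0])
latent = renewal_infections(np.ones(5), gen_int, init)
print(latent)
```

With \(\mathcal{R}(t) = 1\) throughout, the single seed of 10 infections is redistributed over the following days according to the generation interval, so the series settles instead of growing.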
To initialize these five components within the renewal modeling framework, we specify each component as follows:
(1) The generation interval is not estimated; it is passed as a deterministic instance of RandomVariable (DeterministicPMF).
(2) The initial infections are an instance of the InfectionInitializationProcess class, where the number of latent infections immediately before the renewal process begins follows a log-normal prior, LogNormal(2.5, 1); that is, its logarithm is normally distributed with mean 2.5 and standard deviation 1. By specifying InitializeInfectionsZeroPad, the latent infections before this time are assumed to be 0.
(3) \(\mathcal{R}(t)\) is represented as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. For this, we construct a custom RandomVariable, MyRt.
(4) The latent infections are an instance of the Infections class with default values.
(5) The observed infections are an instance of the PoissonObservation class with default values.
# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
)
# (3) The random walk on log Rt, with an inferred initial value and a fixed
# step s.d. Here, we construct a custom RandomVariable.
class MyRt(RandomVariable):
    def validate(self):
        pass
    def sample(self, n: int, **kwargs) -> tuple:
        # Note: sd_rt is sampled but not used below; the random-walk step
        # s.d. is fixed at 0.025 in rw_step_rv.
        sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))
        # Exponentiate a random walk on log Rt so that Rt stays positive.
        rt_rv = TransformedVariable(
            name="log_rt_random_walk",
            base_rv=RandomWalk(
                name="log_rt",
                step_rv=DistributionalVariable(
                    name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                ),
            ),
            transforms=t.ExpTransform(),
        )
        # Prior on the initial value of log Rt.
        rt_init_rv = DistributionalVariable(
            name="init_log_rt", distribution=dist.Normal(0, 0.2)
        )
        init_rt = rt_init_rv.sample()
        return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)
rt_proc = MyRt()
# (4) Latent infection process (which will use 1 and 2)
latent_infections = Infections()
# (5) The observed infections process (with mean at the latent infections)
observation_process = PoissonObservation("poisson_rv")
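To see what MyRt generates without running numpyro, here is a plain NumPy sketch of the same idea: draw an initial log \(\mathcal{R}\) from Normal(0, 0.2), add Normal(0, 0.025) steps, and exponentiate. It assumes the walk starts at the initial value and returns `n` points; the helper `log_random_walk_rt` is hypothetical and is not part of pyrenew.

```python
import numpy as np


def log_random_walk_rt(n, rng, init_sd=0.2, step_sd=0.025):
    """Hypothetical NumPy mirror of MyRt: Rt = exp(Gaussian random walk on log Rt).

    Assumes the returned path starts at the initial value and has length n.
    """
    init_log_rt = rng.normal(0.0, init_sd)  # cf. init_log_rt ~ Normal(0, 0.2)
    steps = rng.normal(0.0, step_sd, size=n - 1)  # cf. rw_step_rv ~ Normal(0, 0.025)
    log_rt = init_log_rt + np.concatenate([[0.0], np.cumsum(steps)])
    return np.exp(log_rt)  # cf. t.ExpTransform()


rng = np.random.default_rng(0)
rt_path = log_random_walk_rt(40, rng)
print(rt_path[:5].round(3))
```

Working on the log scale keeps \(\mathcal{R}(t)\) positive, and a step s.d. of 0.025 means consecutive values typically differ by only a few percent.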
With these five pieces, we can build the basic renewal model as an instance of the RtInfectionsRenewalModel class:
model1 = RtInfectionsRenewalModel(
gen_int_rv=gen_int,
I0_rv=I0,
Rt_process_rv=rt_proc,
latent_infections_rv=latent_infections,
infection_obs_process_rv=observation_process,
)
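Because every component only needs to be something the model can sample from, the model itself is a thin layer of composition. Below is a toy, pure-Python sketch of that design. It makes no claims about pyrenew’s real class hierarchy, and all class names are hypothetical: each component exposes `sample()`, and the model only calls those methods. As a result, swapping one component (here, a constant \(\mathcal{R}(t)\) for a random-walk \(\mathcal{R}(t)\)) requires no change to the model class.

```python
import numpy as np


class ToyRandomVariable:
    """Hypothetical stand-in for a pyrenew RandomVariable: anything exposing sample()."""

    def sample(self, rng, **kwargs):
        raise NotImplementedError


class ToyFixedPMF(ToyRandomVariable):
    """Deterministic component: always returns the same array (cf. DeterministicPMF)."""

    def __init__(self, value):
        self.value = np.asarray(value, dtype=float)

    def sample(self, rng, **kwargs):
        return self.value


class ToyConstantRt(ToyRandomVariable):
    def __init__(self, level):
        self.level = level

    def sample(self, rng, n=1, **kwargs):
        return np.full(n, self.level)


class ToyRandomWalkRt(ToyRandomVariable):
    def sample(self, rng, n=1, **kwargs):
        return np.exp(np.cumsum(rng.normal(0.0, 0.025, size=n)))


class ToyModel:
    """Composes components and relies only on their sample() methods."""

    def __init__(self, gen_int_rv, rt_rv):
        self.gen_int_rv = gen_int_rv
        self.rt_rv = rt_rv

    def sample(self, rng, n):
        return {"gen_int": self.gen_int_rv.sample(rng), "Rt": self.rt_rv.sample(rng, n=n)}


rng = np.random.default_rng(0)
gen_int = ToyFixedPMF([0.4, 0.3, 0.2, 0.1])
flat = ToyModel(gen_int, ToyConstantRt(1.0)).sample(rng, n=5)
walk = ToyModel(gen_int, ToyRandomWalkRt()).sample(rng, n=5)  # only the Rt piece changed
```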
The following diagram summarizes how the modules interact via composition; notably, gen_int, I0, rt_proc, latent_infections, and observation_process are instances of RandomVariable, which means these can be easily replaced to generate a different instance of the RtInfectionsRenewalModel class:
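Using numpyro, we can simulate data with the sample() method of RtInfectionsRenewalModel. Sampling outside of inference requires a random key, so the call is wrapped in numpyro.handlers.seed. The sample() method returns a named tuple, RtInfectionsRenewalSample, holding the Rt, latent_infections, and observed_infections sequences; we store it as sim_data: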
with numpyro.handlers.seed(rng_seed=53):
sim_data = model1.sample(n_datapoints=40)
sim_data
RtInfectionsRenewalSample(Rt=[1.0193933 1.0239283 1.0201008 1.0572274 1.0299444 1.0925564 1.108174
1.0743027 1.0882349 1.1093383 1.1094797 1.1147494 1.1355134 1.1536688
1.1749985 1.2038172 1.243159 1.211333 1.2353402 1.1836537 1.2103558
1.2023542 1.2040431 1.199202 1.1868212 1.1774873 1.176468 1.2211965
1.2915479 1.2288901 1.263633 1.2350382 1.2147664 1.1920787 1.2335418
1.2487649 1.2792699 1.329618 1.3142097 1.3469969], latent_infections=[ 0. 0. 0. 5.365737 2.1879187 2.5443478
2.8024845 3.0220375 2.860384 3.1309352 3.319142 3.374619
3.545269 3.7799764 3.9746103 4.2029986 4.5240045 4.8955026
5.350312 5.939486 6.7284603 7.3077817 8.233023 8.788847
9.827905 10.755298 11.837625 12.958702 14.085928 15.266165
16.597315 18.723179 21.866283 23.60642 27.050367 29.823267
32.740788 35.54065 40.346916 45.369434 51.981926 61.198395
69.89298 82.50361 ], observed_infections=[ 5 2 2 4 1 3 7 4 3 4 6 5 6 6 7 1 7 2 9 10 4 9 11 17
20 5 17 14 33 23 26 24 22 27 42 53 65 62 71 91])
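The observed_infections above come from the PoissonObservation component: each count is a Poisson draw whose mean is the corresponding latent infection count. A NumPy sketch of that observation layer follows. It assumes one observation per time point with the latent value as its mean; the input means are illustrative, not model output, and the helper `observe_poisson` is hypothetical.

```python
import numpy as np


def observe_poisson(latent_infections, rng):
    """Hypothetical helper: one Poisson count per time point, mean = latent infections."""
    return rng.poisson(np.asarray(latent_infections, dtype=float))


rng = np.random.default_rng(53)
latent = np.linspace(2.0, 80.0, 40)  # illustrative latent infection means, not model output
observed = observe_poisson(latent, rng)

# Under a Poisson model the variance equals the mean, so noise grows with the level.
reps = rng.poisson(50.0, size=10_000)
print(observed[:5], reps.mean().round(2), reps.var().round(2))
```

This variance-equals-mean property is why the simulated counts scatter more widely around the latent curve as infections rise.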
To understand what has been accomplished here, we visualize an \(\mathcal{R}(t)\) sample path (left panel) and the observed infections over time (right panel):
import matplotlib.pyplot as plt
fig, axs = plt.subplots(1, 2)
# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")
# Infections plot
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")
fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()
To fit the model, we use the run() method of the RtInfectionsRenewalModel class, which it inherits from the metaclass Model. Calling model1.run() builds and runs an MCMC sampler for model1. The first 2000 iterations of each chain are warm-up iterations. They tune the parameters of the MCMC algorithm to improve the efficiency of the sampling process and are then discarded. After warm-up, 1000 samples per chain are drawn from the posterior distribution of the model parameters and used to estimate posterior distributions and compute summary statistics. Observed data are provided to the model through the observed_infections field of the sim_data object generated above. mcmc_args passes additional arguments to the MCMC algorithm; here, it disables the progress bar and runs two chains.
import jax
model1.run(
num_warmup=2000,
num_samples=1000,
data_observed_infections=sim_data.observed_infections,
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False, num_chains=2),
)
Now, let’s investigate the output, particularly the posterior distribution of the \(\mathcal{R}(t)\) estimates:
out = model1.plot_posterior(var="Rt")
We can use the get_samples() method of the underlying numpyro MCMC object, model1.mcmc, to extract posterior samples from the fitted model:
Rt_samp = model1.mcmc.get_samples()["Rt"]
latent_infection_samp = model1.mcmc.get_samples()["all_latent_infections"]
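By default, numpyro’s get_samples() concatenates chains, so with two chains of 1000 samples each, Rt_samp has 2000 rows (draws) and one column per time point. The following NumPy sketch turns such an array into a per-time-point median and central 90% interval. A synthetic array stands in for Rt_samp, and columns that are NaN in every draw are treated as padding and skipped; the helper `summarize_draws` is hypothetical.

```python
import numpy as np


def summarize_draws(samples, prob=0.9):
    """Per-time-point median and central interval, skipping all-NaN (padding) columns.

    samples: array of shape (num_draws, num_timepoints), e.g. Rt_samp.
    """
    samples = np.asarray(samples, dtype=float)
    valid = ~np.all(np.isnan(samples), axis=0)  # padding columns are NaN in every draw
    kept = samples[:, valid]
    lower_q, upper_q = (1 - prob) / 2, 1 - (1 - prob) / 2
    median = np.median(kept, axis=0)
    lower, upper = np.quantile(kept, [lower_q, upper_q], axis=0)
    return np.flatnonzero(valid), median, lower, upper


# Synthetic stand-in for Rt_samp: 2000 draws, 4 NaN padding columns, then 6 time points.
rng = np.random.default_rng(0)
draws = np.full((2000, 10), np.nan)
draws[:, 4:] = rng.lognormal(mean=0.1, sigma=0.05, size=(2000, 6))
idx, med, lo, hi = summarize_draws(draws)
print(idx, med.round(2))
```

This central (equal-tailed) interval is simpler than the highest density interval ArviZ plots; the two coincide for symmetric posteriors.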
We can also convert the fitted model to an ArviZ InferenceData object and use the ArviZ package to extract samples, calculate statistics, and create model diagnostics and visualizations.
import arviz as az
idata = az.from_numpyro(model1.mcmc)
We then use the InferenceData object to compute model-fit diagnostics. Here, we show the diagnostic summary for every effective reproduction number \(\mathcal{R}(t)\) value after the NaN padding (index 4 onward).
diagnostic_stats_summary = az.summary(
idata.posterior["Rt"][::, ::, 4:], # ignore nan padding
kind="diagnostics",
)
print(diagnostic_stats_summary)
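The r_hat column compares between-chain and within-chain variance; values close to 1.00, as in the table below, suggest the two chains agree. ArviZ reports a rank-normalized split R-hat. The NumPy sketch here implements only the classic split R-hat without rank normalization, so its values can differ slightly from az.summary. The helper `split_rhat` is hypothetical.

```python
import numpy as np


def split_rhat(chains):
    """Classic (non-rank-normalized) split R-hat for one scalar quantity.

    chains: array of shape (num_chains, num_draws). Each chain is split in half,
    giving 2 * num_chains sequences of length num_draws // 2.
    """
    chains = np.asarray(chains, dtype=float)
    half = chains.shape[1] // 2
    seqs = np.concatenate([chains[:, :half], chains[:, half : 2 * half]], axis=0)
    m, n = seqs.shape
    within = seqs.var(axis=1, ddof=1).mean()  # W: mean within-sequence variance
    between = n * seqs.mean(axis=1).var(ddof=1)  # B: between-sequence variance
    var_plus = (n - 1) / n * within + between / n  # pooled posterior variance estimate
    return np.sqrt(var_plus / within)


rng = np.random.default_rng(1)
mixed = rng.normal(size=(2, 1000))  # two chains from the same distribution
stuck = mixed + np.array([[0.0], [3.0]])  # second chain offset: poor mixing
print(round(split_rhat(mixed), 3), round(split_rhat(stuck), 3))
```

Splitting each chain also flags within-chain drift, because the first and second halves of a drifting chain disagree even when the chains overlap overall.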
mcse_mean mcse_sd ess_bulk ess_tail r_hat
Rt[4] 0.001 0.001 2595.0 1781.0 1.00
Rt[5] 0.001 0.001 2711.0 1710.0 1.00
Rt[6] 0.001 0.001 2741.0 1662.0 1.00
Rt[7] 0.001 0.001 2826.0 1625.0 1.00
Rt[8] 0.001 0.001 2908.0 1479.0 1.00
Rt[9] 0.001 0.001 3410.0 1478.0 1.00
Rt[10] 0.001 0.001 3674.0 1605.0 1.00
Rt[11] 0.001 0.001 3509.0 1315.0 1.00
Rt[12] 0.001 0.000 3957.0 1228.0 1.00
Rt[13] 0.001 0.000 4078.0 1212.0 1.00
Rt[14] 0.001 0.000 4022.0 1506.0 1.00
Rt[15] 0.001 0.000 4218.0 1654.0 1.00
Rt[16] 0.001 0.000 4347.0 1791.0 1.00
Rt[17] 0.001 0.000 4017.0 1452.0 1.00
Rt[18] 0.001 0.000 3697.0 1467.0 1.00
Rt[19] 0.001 0.001 3121.0 1292.0 1.00
Rt[20] 0.001 0.000 3360.0 1128.0 1.00
Rt[21] 0.001 0.000 3860.0 1756.0 1.00
Rt[22] 0.001 0.000 3601.0 1697.0 1.00
Rt[23] 0.001 0.000 4003.0 1759.0 1.00
Rt[24] 0.001 0.000 3441.0 1356.0 1.00
Rt[25] 0.001 0.000 3275.0 1644.0 1.00
Rt[26] 0.001 0.000 3400.0 1441.0 1.01
Rt[27] 0.001 0.000 3455.0 1590.0 1.00
Rt[28] 0.001 0.000 4196.0 1454.0 1.00
Rt[29] 0.001 0.000 3952.0 1554.0 1.00
Rt[30] 0.001 0.000 3363.0 1660.0 1.00
Rt[31] 0.001 0.000 3338.0 1661.0 1.00
Rt[32] 0.001 0.000 4070.0 1679.0 1.00
Rt[33] 0.001 0.000 3552.0 1855.0 1.00
Rt[34] 0.001 0.000 3785.0 1714.0 1.00
Rt[35] 0.001 0.000 3359.0 1811.0 1.00
Rt[36] 0.001 0.001 3309.0 1562.0 1.00
Rt[37] 0.001 0.001 3085.0 1555.0 1.00
Rt[38] 0.001 0.001 3578.0 1677.0 1.00
Rt[39] 0.001 0.001 3168.0 1889.0 1.00
Below, we use plot_trace to inspect the traces of the first 10 inferred \(\mathcal{R}(t)\) values after the NaN padding (indices 4 to 13).
plt.rcParams["figure.constrained_layout.use"] = True
az.plot_trace(
idata.posterior,
var_names=["Rt"],
coords={"Rt_dim_0": np.arange(4, 14)},
compact=False,
)
plt.show()
We inspect the posterior distribution of \(\mathcal{R}(t)\) by plotting the 90% and 50% highest density intervals:
x_data = idata.posterior["Rt_dim_0"][4:]
y_data = idata.posterior["Rt"][::, ::, 4:]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color="C0",
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color="C0",
fill_kwargs={"alpha": 0.6},
ax=axes,
)
# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Effective Reproduction Number", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("$\\mathcal{R}(t)$", fontsize=10)
plt.show()
and latent infections:
x_data = idata.posterior["all_latent_infections_dim_0"]
y_data = idata.posterior["all_latent_infections"]
fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.9,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.3},
ax=axes,
)
az.plot_hdi(
x_data,
y_data,
hdi_prob=0.5,
color="C0",
smooth=False,
fill_kwargs={"alpha": 0.6},
ax=axes,
)
# plot the posterior median
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10)
plt.show()
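Both figures shade highest density intervals (HDIs), the narrowest intervals containing a given share of the posterior draws. A minimal NumPy sketch of the standard sort-and-scan estimator follows. It is not ArviZ’s code, it assumes a unimodal posterior (a multimodal one may not have a single-interval HDI), and the helper `hdi_interval` is hypothetical.

```python
import numpy as np


def hdi_interval(draws, hdi_prob=0.9):
    """Shortest interval containing about hdi_prob of the draws (unimodal case)."""
    x = np.sort(np.asarray(draws, dtype=float))
    n = x.size
    span = int(np.floor(hdi_prob * n))  # index distance between interval ends
    widths = x[span:] - x[: n - span]  # width of every candidate interval
    start = int(np.argmin(widths))  # the narrowest candidate is the HDI
    return x[start], x[start + span]


rng = np.random.default_rng(2)
normal_hdi = hdi_interval(rng.normal(size=100_000))  # symmetric: about (-1.645, 1.645)
exp_hdi = hdi_interval(rng.exponential(size=100_000))  # skewed: starts near 0
print(np.round(normal_hdi, 2), np.round(exp_hdi, 2))
```

For a symmetric posterior the HDI matches the equal-tailed interval. For a skewed one, such as latent infections early in the series, it shifts toward the mode.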