Fitting a basic renewal model

This document will show the steps to build a simple renewal model featuring a latent infection process, a random walk Rt process, and an observation process for the reported infections.

We start by loading the needed components to build a basic renewal model:

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from pyrenew.process import RandomWalk
from pyrenew.latent import (
    Infections,
    InfectionInitializationProcess,
    InitializeInfectionsZeroPad,
)
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam

By default, XLA (which is used by JAX for compilation) considers all CPU cores as one device. Depending on your system’s configuration, we recommend using numpyro’s set_host_device_count() function to set the number of devices available for parallel computing. Here, we set the device count to be 2.

numpyro.set_host_device_count(2)
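For context, set_host_device_count() works, to the best of our understanding, by adding the --xla_force_host_platform_device_count flag to the XLA_FLAGS environment variable, which is why it should be called before JAX initializes its backend. The sketch below mimics that behavior with a hypothetical helper, force_host_device_count. It is an illustration only, not numpyro's implementation.

```python
import os
import re


def force_host_device_count(n: int) -> str:
    """Hypothetical helper: request n host (CPU) devices from XLA."""
    flags = os.environ.get("XLA_FLAGS", "")
    # Drop any earlier setting of the same flag so it appears only once
    flags = re.sub(
        r"--xla_force_host_platform_device_count=\S+", "", flags
    ).split()
    new_flags = " ".join(
        [f"--xla_force_host_platform_device_count={n}"] + flags
    )
    os.environ["XLA_FLAGS"] = new_flags
    return new_flags


# Only takes effect if it runs before JAX initializes its backend
force_host_device_count(2)
print(os.environ["XLA_FLAGS"])
```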

Architecture of RtInfectionsRenewalModel

pyrenew leverages numpyro’s flexibility to build models via composition. As a principle, most objects in pyrenew can be treated as random variables we can sample. At the top level, pyrenew has two metaclasses from which most objects inherit: RandomVariable and Model. From them, the following five sub-modules arise:

  • The process sub-module,

  • The deterministic sub-module,

  • The observation sub-module,

  • The latent sub-module, and

  • The model sub-module.

The first four are collections of subclasses of RandomVariable, and the last is a collection of subclasses of Model. The following diagram shows a detailed view of how metaclasses, modules, and classes interact to create the RtInfectionsRenewalModel instantiated in the next section:

%%| label: overview-of-RtInfectionsRenewalModel
%%| include: true
flowchart LR
  rand((RandomVariable\nmetaclass))
  models((Model\nmetaclass))

  subgraph observations[Observations module]
    obs["infection_obs_process_rv\n(PoissonObservation)"]
  end

  subgraph latent[Latent module]
    inf["latent_infections_rv\n(Infections)"]
    i0["I0_rv\n(DistributionalVariable)"]
  end

  subgraph process[Process module]
    rt["Rt_process_rv\n(Custom class built using RandomWalk)"]
  end

  subgraph deterministic[Deterministic module]
    detpmf["gen_int_rv\n(DeterministicPMF)"]
  end

  subgraph model[Model module]
    model1["model1\n(RtInfectionsRenewalModel)"]
  end

  rand-->|Inherited by|observations
  rand-->|Inherited by|process
  rand-->|Inherited by|latent
  rand-->|Inherited by|deterministic
  models-->|Inherited by|model

  detpmf-->|Composes|model1
  i0-->|Composes|model1
  rt-->|Composes|model1
  obs-->|Composes|model1
  inf-->|Composes|model1

The pyrenew package models the time-varying reproductive number \(\mathcal{R}(t)\), the average number of secondary infections caused by an infected individual, within a renewal process model. Our basic renewal process model defines five components:

  1. generation interval, the times between infections

  2. initial infections, occurring prior to time \(t = 0\)

  3. \(\mathcal{R}(t)\), the time-varying reproductive number,

  4. latent infections, i.e., those infections which are known to exist but are not observed (or not observable), and

  5. observed infections, a subset of underlying true infections that are reported, perhaps via hospital admissions, physician’s office visits, or routine biosurveillance.
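As a rough illustration of how components 1 to 4 fit together, the renewal equation can be written as \(I(t) = \mathcal{R}(t) \sum_{\tau \ge 1} I(t - \tau)\, w(\tau)\), where \(I(t)\) denotes latent infections and \(w(\tau)\) the generation interval probability at lag \(\tau\). The symbols \(I\) and \(w\), the function simulate_renewal, and the input values below are assumptions made for this sketch. It is written in plain NumPy and is not pyrenew's implementation.

```python
import numpy as np


def simulate_renewal(init_infections, rt, gen_int_pmf):
    """Deterministic renewal recursion (illustrative sketch).

    init_infections: infections before t = 0, ordered oldest to newest,
        with the same length as gen_int_pmf.
    rt: reproductive number for each simulated time step.
    gen_int_pmf: generation interval pmf; entry k applies to lag k + 1.
    """
    pmf = np.asarray(gen_int_pmf, dtype=float)
    infections = list(np.asarray(init_infections, dtype=float))
    for r in rt:
        # Newest infection first, so it pairs with pmf[0] (lag of one step)
        recent = np.array(infections[-pmf.size:][::-1])
        infections.append(r * np.dot(recent, pmf))
    return np.array(infections)


pmf = np.array([0.4, 0.3, 0.2, 0.1])
path = simulate_renewal(
    [0.0, 0.0, 0.0, 10.0], rt=np.full(5, 1.2), gen_int_pmf=pmf
)
print(path)
```

Each new value depends only on the most recent infections and the current \(\mathcal{R}(t)\), which is why the initial infections (component 2) must be supplied before the recursion can start.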

We specify each of these five components within the renewal modeling framework as follows:

  1. In this example, the generation interval is not estimated but passed as a deterministic instance of RandomVariable (a DeterministicPMF).

  2. an instance of the InfectionInitializationProcess class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with log-scale mean = 2.5 and log-scale standard deviation = 1, i.e., dist.LogNormal(2.5, 1). By specifying InitializeInfectionsZeroPad, the latent infections before this time are assumed to be 0.

  3. A process to represent \(\mathcal{R}(t)\) as a random walk on the log scale, with an inferred initial value and a fixed Normal step-size distribution. For this, we construct a custom RandomVariable, MyRt.

  4. an instance of the Infections class with default values, and

  5. an instance of the PoissonObservation class with default values

# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)

# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
    "I0_initialization",
    DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)),
    InitializeInfectionsZeroPad(pmf_array.size),
)


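# (3) The random walk on log Rt, with an inferred initial value and a
# fixed step s.d. Here, we construct a custom RandomVariable.
class MyRt(RandomVariable):
    def validate(self):
        pass

    def sample(self, n: int, **kwargs) -> tuple:
        # NOTE: this site is sampled but not used below; the random walk
        # steps use the fixed Normal(0, 0.025) distribution.
        sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))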

        rt_rv = TransformedVariable(
            name="log_rt_random_walk",
            base_rv=RandomWalk(
                name="log_rt",
                step_rv=DistributionalVariable(
                    name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                ),
            ),
            transforms=t.ExpTransform(),
        )
        rt_init_rv = DistributionalVariable(
            name="init_log_rt", distribution=dist.Normal(0, 0.2)
        )
        init_rt = rt_init_rv.sample()

        return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)


rt_proc = MyRt()

# (4) Latent infection process (which will use 1, 2, and 3)
latent_infections = Infections()

# (5) The observed infections process (with mean at the latent infections)
observation_process = PoissonObservation("poisson_rv")
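To make component (3) more concrete, the following plain-NumPy sketch shows what a random walk on the log scale followed by an exponential transform looks like. The function log_random_walk_rt is hypothetical, and placing the initial value as the first element of the path is an assumption of this sketch rather than a statement about how RandomWalk behaves internally.

```python
import numpy as np

rng = np.random.default_rng(0)


def log_random_walk_rt(n, init_log_rt, step_sd, rng):
    """Random walk on log Rt, mapped back to the natural scale."""
    # n - 1 Normal(0, step_sd) increments after the initial value
    steps = rng.normal(0.0, step_sd, size=n - 1)
    log_rt = init_log_rt + np.concatenate([[0.0], np.cumsum(steps)])
    # Exponentiating guarantees a strictly positive Rt
    return np.exp(log_rt)


# Initial value drawn from Normal(0, 0.2), as in MyRt above
init_log_rt = rng.normal(0.0, 0.2)
rt_path = log_random_walk_rt(40, init_log_rt, step_sd=0.025, rng=rng)
print(rt_path[:5])
```

Working on the log scale keeps \(\mathcal{R}(t)\) positive without needing a constrained step distribution.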

With these five pieces, we can build the basic renewal model as an instance of the RtInfectionsRenewalModel class:

model1 = RtInfectionsRenewalModel(
    gen_int_rv=gen_int,
    I0_rv=I0,
    Rt_process_rv=rt_proc,
    latent_infections_rv=latent_infections,
    infection_obs_process_rv=observation_process,
)

The following diagram summarizes how the modules interact via composition; notably, gen_int, I0, rt_proc, latent_infections, and observation_process are instances of RandomVariable, which means these can be easily replaced to generate a different instance of the RtInfectionsRenewalModel class:

%%| label: overview-of-RtInfectionsRenewalModel
%%| include: true
flowchart TB
  genint["(1) gen_int\n(DeterministicPMF)"]
  i0["(2) I0\n(InfectionInitializationProcess)"]
  rt["(3) rt_proc\n(MyRt, the custom RV defined above)"]
  inf["(4) latent_infections\n(Infections)"]
  obs["(5) observation_process\n(PoissonObservation)"]

  model1["model1\n(RtInfectionsRenewalModel)"]

  i0-->|Composes|model1
  genint-->|Composes|model1
  rt-->|Composes|model1
  obs-->|Composes|model1
  inf-->|Composes|model1

We can simulate data from the model by calling the sample() method of RtInfectionsRenewalModel inside a numpyro seed handler. The method returns an RtInfectionsRenewalSample, a tuple-like object with the named fields Rt, latent_infections, and observed_infections, which we store as sim_data:

with numpyro.handlers.seed(rng_seed=53):
    sim_data = model1.sample(n_datapoints=40)

sim_data
RtInfectionsRenewalSample(Rt=[1.0193933 1.0239283 1.0201008 1.0572274 1.0299444 1.0925564 1.108174
 1.0743027 1.0882349 1.1093383 1.1094797 1.1147494 1.1355134 1.1536688
 1.1749985 1.2038172 1.243159  1.211333  1.2353402 1.1836537 1.2103558
 1.2023542 1.2040431 1.199202  1.1868212 1.1774873 1.176468  1.2211965
 1.2915479 1.2288901 1.263633  1.2350382 1.2147664 1.1920787 1.2335418
 1.2487649 1.2792699 1.329618  1.3142097 1.3469969], latent_infections=[  0.         0.         0.        28.330536  11.551984  13.433893
  14.796828  15.956045  15.102532  16.531013  17.524725  17.817638
  18.718653  19.957882  20.985529  22.191395  23.886274  25.847744
  28.249094  31.359867  35.525574  38.58433   43.469505  46.4042
  51.89032   56.786854  62.501434  68.4206    74.372246  80.603775
  87.6321    98.85645  115.45172  124.639465 142.82314  157.46376
 172.86795  187.65094  213.02754  239.5459   274.4592   323.1212
 369.02765  435.61053 ], observed_infections=[  8  12  15  17  12  15  20  22  20  22  13  28  24  32  27  27  36  35
  53  35  71  61  60  68  62  83 101 103 122 124 136 137 173 190 231 247
 258 353 379 455])
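In the output above, latent_infections has 44 entries while Rt and observed_infections have 40, which is consistent with the first four entries being the zero-padded initialization. The sketch below illustrates the observation step under that reading: observed counts are Poisson draws whose mean is the latent infection value. The latent values are rounded from the output above, and the variable names are chosen only for this illustration.

```python
import numpy as np

rng = np.random.default_rng(1)

# Rounded from the simulated output above: 4 initialization entries
# followed by the first 4 renewal-process values
latent = np.array([0.0, 0.0, 0.0, 28.3, 11.6, 13.4, 14.8, 16.0])
n_init = 4

# Assumed here: only the post-initialization entries are observed
expected = latent[n_init:]

# Poisson observation with mean equal to the latent infections
observed = rng.poisson(expected)
print(observed)
```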

To understand what has been accomplished here, visualize an \(\mathcal{R}(t)\) sample path (left panel) and infections over time (right panel):

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")

# Infections plot
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")

fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()

Figure 1: Rt and Infections

To fit the model, we use the run() method of the RtInfectionsRenewalModel class (inherited from the metaclass Model). Calling model1.run() runs MCMC with 2000 warm-up iterations, which tune the sampler's parameters to improve sampling efficiency, followed by 1000 posterior draws per chain, which are used to estimate the posterior distributions and compute summary statistics. The observed data are passed via data_observed_infections, here taken from the sim_data object generated previously. mcmc_args passes additional arguments to the MCMC algorithm; here it disables the progress bar and requests 2 chains.

import jax

model1.run(
    num_warmup=2000,
    num_samples=1000,
    data_observed_infections=sim_data.observed_infections,
    rng_key=jax.random.key(54),
    mcmc_args=dict(progress_bar=False, num_chains=2),
)

Now, let’s investigate the output, particularly the posterior distribution of the \(\mathcal{R}(t)\) estimates:

import arviz as az

# Create arviz inference data object
idata = az.from_numpyro(
    posterior=model1.mcmc,
)

# Extract Rt signal samples across chains
rt = az.extract(idata.posterior["Rt"], num_samples=100)["Rt"].values


# Plot Rt signal
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(
    np.arange(rt.shape[0]),
    rt,
    color="skyblue",
    alpha=0.10,
)
ax.plot([], [], color="skyblue", alpha=0.05, label="Rt Posterior Samples")
ax.plot(
    np.arange(rt.shape[0]),
    rt.mean(axis=1),
    color="black",
    linewidth=2.0,
    linestyle="--",
    label="Sample Mean",
)
ax.legend(loc="best")
ax.set_ylabel(r"$\mathscr{R}_t$ Signal", fontsize=20)
ax.set_xlabel("Days", fontsize=20)
plt.show()

Figure 2: Rt posterior distribution

We can use the get_samples method to extract samples from the fitted model:

Rt_samp = model1.mcmc.get_samples()["Rt"]
latent_infection_samp = model1.mcmc.get_samples()["all_latent_infections"]
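By default, get_samples() returns a dictionary of arrays in which draws from all chains are concatenated along the first axis. As a hedged sketch of how such an array can be summarized, the example below uses a synthetic stand-in, fake_rt_samp, with an assumed shape of (2000, 40), i.e., 2 chains of 1000 draws for 40 time points. The values are made up for illustration and are not model output.

```python
import numpy as np

rng = np.random.default_rng(2)

# Synthetic stand-in for Rt_samp: axis 0 is draws, axis 1 is time
fake_rt_samp = np.exp(rng.normal(0.1, 0.05, size=(2000, 40)))

# Pointwise posterior median and an equal-tailed 90% interval
post_median = np.median(fake_rt_samp, axis=0)
lower, upper = np.quantile(fake_rt_samp, [0.05, 0.95], axis=0)

print(post_median[:3], lower[:3], upper[:3])
```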

We can also convert the fitted model to an ArviZ InferenceData object and use the ArviZ package to extract samples, calculate statistics, and create model diagnostics and visualizations.

import arviz as az

idata = az.from_numpyro(model1.mcmc)

We can then use the InferenceData object to compute model-fit diagnostics. Here, we show a diagnostic summary for the effective reproduction number \(\mathcal{R}(t)\) from index 4 onward, skipping the leading entries that are NaN padding.

diagnostic_stats_summary = az.summary(
    idata.posterior["Rt"][::, ::, 4:],  # ignore nan padding
    kind="diagnostics",
)

print(diagnostic_stats_summary)
        mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
Rt[4]       0.001    0.001    2423.0    1563.0   1.00
Rt[5]       0.001    0.001    2704.0    1537.0   1.00
Rt[6]       0.001    0.001    3346.0    1412.0   1.00
Rt[7]       0.001    0.000    3909.0    1632.0   1.00
Rt[8]       0.001    0.000    3603.0    1433.0   1.00
Rt[9]       0.001    0.000    3790.0    1617.0   1.00
Rt[10]      0.001    0.000    3766.0    1604.0   1.01
Rt[11]      0.001    0.000    4517.0    1239.0   1.00
Rt[12]      0.001    0.000    4372.0    1518.0   1.00
Rt[13]      0.000    0.000    4602.0    1561.0   1.00
Rt[14]      0.001    0.000    4300.0    1648.0   1.00
Rt[15]      0.001    0.000    3674.0    1585.0   1.00
Rt[16]      0.001    0.000    3474.0    1728.0   1.00
Rt[17]      0.001    0.000    3655.0    1663.0   1.00
Rt[18]      0.001    0.000    4183.0    1697.0   1.00
Rt[19]      0.001    0.000    3889.0    1792.0   1.00
Rt[20]      0.000    0.000    4536.0    1679.0   1.00
Rt[21]      0.001    0.000    4182.0    1720.0   1.00
Rt[22]      0.001    0.000    3714.0    1458.0   1.00
Rt[23]      0.001    0.000    3561.0    1522.0   1.00
Rt[24]      0.001    0.000    3262.0    1714.0   1.00
Rt[25]      0.001    0.000    3431.0    1902.0   1.00
Rt[26]      0.001    0.000    3602.0    1679.0   1.00
Rt[27]      0.000    0.000    3929.0    1340.0   1.00
Rt[28]      0.000    0.000    3446.0    1516.0   1.00
Rt[29]      0.000    0.000    3457.0    1618.0   1.00
Rt[30]      0.000    0.000    3614.0    1578.0   1.00
Rt[31]      0.000    0.000    3678.0    1885.0   1.00
Rt[32]      0.000    0.000    3680.0    1658.0   1.00
Rt[33]      0.001    0.000    3566.0    1233.0   1.00
Rt[34]      0.000    0.000    5470.0    1513.0   1.00
Rt[35]      0.000    0.000    3617.0    1695.0   1.00
Rt[36]      0.000    0.000    3082.0    1733.0   1.00
Rt[37]      0.001    0.000    2896.0    1939.0   1.00
Rt[38]      0.001    0.000    2873.0    1716.0   1.00
Rt[39]      0.001    0.001    3087.0    1611.0   1.00
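To give some intuition for the r_hat column, the sketch below computes a classic split \(\hat{R}\) for a single scalar parameter: each chain is split in half, and the between-chain variance is compared with the within-chain variance. This is a simplified version for illustration. ArviZ's default r_hat additionally rank-normalizes the draws, so its values can differ slightly. The function name split_rhat and the synthetic chains are assumptions of this sketch.

```python
import numpy as np


def split_rhat(draws):
    """Classic split R-hat for an array of shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Split each chain in two so within-chain drift is also detected
    chains = np.concatenate(
        [draws[:, :half], draws[:, half : 2 * half]], axis=0
    )
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()
    between = n * chains.mean(axis=1).var(ddof=1)
    var_hat = (n - 1) / n * within + between / n
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(3)
good = rng.normal(size=(2, 1000))  # two chains, same distribution
bad = good + np.array([[0.0], [3.0]])  # second chain shifted by 3

print(split_rhat(good), split_rhat(bad))
```

Values close to 1, as in the table above, indicate that the chains agree with each other; values well above 1 suggest the chains have not mixed.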

Below we use plot_trace to inspect the trace of the first 10 inferred \(\mathcal{R}(t)\) values.

plt.rcParams["figure.constrained_layout.use"] = True

az.plot_trace(
    idata.posterior,
    var_names=["Rt"],
    coords={"Rt_dim_0": np.arange(4, 14)},
    compact=False,
)
plt.show()

Figure 3: Trace plot of Rt posterior distribution

We inspect the posterior distribution of \(\mathcal{R}(t)\) by plotting the 90% and 50% highest density intervals:

x_data = idata.posterior["Rt_dim_0"][4:]
y_data = idata.posterior["Rt"][::, ::, 4:]

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Effective Reproduction Number", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("$\\mathcal{R}(t)$", fontsize=10)
plt.show()

Figure 4: Highest density interval for Effective Reproduction Number

and latent infections:

x_data = idata.posterior["all_latent_infections_dim_0"]
y_data = idata.posterior["all_latent_infections"]

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# plot the posterior median
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")

axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10)
plt.show()

Figure 5: Highest density interval for Latent Infections
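For readers unfamiliar with highest density intervals, the following sketch shows one common way to compute an HDI from posterior draws of a single quantity: sort the draws and pick the narrowest window that contains the requested fraction of them. The function hdi_from_samples and the synthetic log-normal draws are assumptions made for illustration; az.plot_hdi may use a different algorithm internally (for example, with smoothing).

```python
import numpy as np


def hdi_from_samples(samples, prob=0.9):
    """Narrowest interval containing at least `prob` of the draws."""
    x = np.sort(np.asarray(samples, dtype=float))
    n = x.size
    k = int(np.ceil(prob * n))  # number of draws the interval must hold
    # Width of every window of k consecutive sorted draws
    widths = x[k - 1 :] - x[: n - k + 1]
    start = int(np.argmin(widths))
    return x[start], x[start + k - 1]


rng = np.random.default_rng(4)
draws = rng.lognormal(mean=0.0, sigma=0.5, size=5000)  # skewed example

low, high = hdi_from_samples(draws, prob=0.9)
q05, q95 = np.quantile(draws, [0.05, 0.95])
print((low, high), (q05, q95))
```

For a skewed distribution such as this one, the HDI is shifted toward the mode and is no wider than the equal-tailed interval with the same coverage.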