Fitting a basic renewal model


This document will show the steps to build a simple renewal model featuring a latent infection process, a random walk Rt process, and an observation process for the reported infections.

We start by loading the needed components to build a basic renewal model:

import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from pyrenew.process import RandomWalk
from pyrenew.latent import (
    Infections,
    InfectionInitializationProcess,
    InitializeInfectionsZeroPad,
)
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam

By default, XLA (which is used by JAX for compilation) considers all CPU cores as one device. Depending on your system’s configuration, we recommend using numpyro’s set_host_device_count() function to set the number of devices available for parallel computing. Here, we set the device count to be 2.

numpyro.set_host_device_count(2)
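
For reference, here is a minimal sketch of what this call amounts to. It assumes that set_host_device_count() works by adding the XLA flag `--xla_force_host_platform_device_count` to the `XLA_FLAGS` environment variable, which is how current numpyro releases behave as far as we know. The flag only takes effect if it is set before JAX initializes its backend, so it belongs at the very top of a script.

```python
import os

# Assumption: this mirrors the effect of numpyro.set_host_device_count(2).
# It must run before JAX initializes its backend to have any effect.
n_devices = 2
flag = f"--xla_force_host_platform_device_count={n_devices}"

existing = os.environ.get("XLA_FLAGS", "")
if "xla_force_host_platform_device_count" not in existing:
    # Append the flag while keeping any XLA flags that were already set.
    os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()

print(os.environ["XLA_FLAGS"])
```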

Architecture of RtInfectionsRenewalModel

pyrenew leverages numpyro’s flexibility to build models via composition. As a principle, most objects in pyrenew can be treated as random variables that we can sample. At the top level, pyrenew has two metaclasses from which most objects inherit: RandomVariable and Model. From them, the following five sub-modules arise:

  • The process sub-module,

  • The deterministic sub-module,

  • The observation sub-module,

  • The latent sub-module, and

  • The model sub-module.

The first four are collections of instances of RandomVariable, and the last is a collection of instances of Model. The following diagram shows a detailed view of how metaclasses, modules, and classes interact to create the RtInfectionsRenewalModel instantiated in the next section:

%%| label: overview-of-RtInfectionsRenewalModel
%%| include: true
flowchart LR
  rand((RandomVariable\nmetaclass))
  models((Model\nmetaclass))

  subgraph observations[Observations module]
    obs["infection_obs_process_rv\n(PoissonObservation)"]
  end

  subgraph latent[Latent module]
    inf["latent_infections_rv\n(Infections)"]
    i0["I0_rv\n(DistributionalVariable)"]
  end

  subgraph process[Process module]
    rt["Rt_process_rv\n(Custom class built using RandomWalk)"]
  end

  subgraph deterministic[Deterministic module]
    detpmf["gen_int_rv\n(DeterministicPMF)"]
  end

  subgraph model[Model module]
    model1["model1\n(RtInfectionsRenewalModel)"]
  end

  rand-->|Inherited by|observations
  rand-->|Inherited by|process
  rand-->|Inherited by|latent
  rand-->|Inherited by|deterministic
  models-->|Inherited by|model

  detpmf-->|Composes|model1
  i0-->|Composes|model1
  rt-->|Composes|model1
  obs-->|Composes|model1
  inf-->|Composes|model1

The pyrenew package estimates the time-varying reproductive number \(\mathcal{R}(t)\), the average number of secondary infections caused by an infected individual, using a renewal process model. Our basic renewal process model defines five components:

  1. generation interval, the time between an infection and the secondary infections it causes

  2. initial infections, occurring prior to time \(t = 0\)

  3. \(\mathcal{R}(t)\), the time-varying reproductive number,

  4. latent infections, i.e., those infections which are known to exist but are not observed (or not observable), and

  5. observed infections, a subset of underlying true infections that are reported, perhaps via hospital admissions, physician’s office visits, or routine biosurveillance.
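
In symbols, these components fit together roughly as follows. This is a sketch that assumes the standard discrete renewal equation and the Poisson observation model used later in this document. The notation \(I(t)\), \(Y(t)\), \(g_\tau\), and \(G\) is introduced here for illustration only.

\[
I(t) = \mathcal{R}(t) \sum_{\tau = 1}^{G} g_\tau \, I(t - \tau),
\qquad
Y(t) \sim \mathrm{Poisson}\big(I(t)\big)
\]

Here \(I(t)\) denotes the latent infections at time \(t\), \(Y(t)\) the observed infections, and \(g_\tau\) the generation interval probability at lag \(\tau\), with \(G\) lags in total. The initial infections supply the values of \(I\) before the renewal process starts.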

To set up these five components within the renewal modeling framework, we specify each one as follows:

  1. The generation interval is not estimated in this example; it is passed as a deterministic instance of RandomVariable.

  2. The initial infections are an instance of the InfectionInitializationProcess class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution with log-scale mean 2.5 and log-scale standard deviation 1. By specifying InitializeInfectionsZeroPad, the latent infections before this time are assumed to be 0.

  3. \(\mathcal{R}(t)\) is represented as a random walk on the log scale, with an inferred initial value and a fixed Normal step distribution. For this, we construct a custom RandomVariable, MyRt.

  4. The latent infections are an instance of the Infections class with default values.

  5. The observed infections are an instance of the PoissonObservation class with default values.

# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.4, 0.3, 0.2, 0.1])
gen_int = DeterministicPMF(name="gen_int", value=pmf_array)

# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
    "I0_initialization",
    DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)),
    InitializeInfectionsZeroPad(pmf_array.size),
)


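# (3) The random walk on log Rt, with an inferred initial value and a
# fixed step s.d. Here, we construct a custom RandomVariable.
class MyRt(RandomVariable):

    def validate(self):
        pass

    def sample(self, n: int, **kwargs) -> tuple:
        # NOTE: this site is sampled but is not used by the step
        # distribution below, which has a fixed s.d. of 0.025.
        sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))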

        rt_rv = TransformedVariable(
            name="log_rt_random_walk",
            base_rv=RandomWalk(
                name="log_rt",
                step_rv=DistributionalVariable(
                    name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                ),
            ),
            transforms=t.ExpTransform(),
        )
        rt_init_rv = DistributionalVariable(
            name="init_log_rt", distribution=dist.Normal(0, 0.2)
        )
        init_rt = rt_init_rv.sample()

        return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)


rt_proc = MyRt()

# (4) Latent infection process (which will use 1 and 2)
latent_infections = Infections()

# (5) The observed infections process (with mean at the latent infections)
observation_process = PoissonObservation("poisson_rv")
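
To make component (3) concrete, here is a minimal NumPy sketch of a random walk on log Rt that mirrors the priors used in MyRt (initial value Normal(0, 0.2), steps Normal(0, 0.025)). It is illustrative only: the helper name `simulate_log_rw_rt` is hypothetical, it does not use pyrenew, and the indexing convention (the initial value is the first element of the path) is an assumption rather than a statement about pyrenew's RandomWalk.

```python
import numpy as np


def simulate_log_rw_rt(n, rng, init_sd=0.2, step_sd=0.025):
    """Simulate n values of Rt whose logarithm follows a Gaussian random walk."""
    log_rt_init = rng.normal(0.0, init_sd)        # initial log Rt
    steps = rng.normal(0.0, step_sd, size=n - 1)  # i.i.d. increments
    # Assumption: the first element of the path is the initial value itself.
    log_rt = log_rt_init + np.concatenate([[0.0], np.cumsum(steps)])
    return np.exp(log_rt)                         # back-transform to Rt


rng = np.random.default_rng(0)
rt_path = simulate_log_rw_rt(n=40, rng=rng)
print(rt_path[:5])
```

Because the walk lives on the log scale, exponentiating always yields positive values, and a small step standard deviation gives a smoothly varying Rt path.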

With these five pieces, we can build the basic renewal model as an instance of the RtInfectionsRenewalModel class:

model1 = RtInfectionsRenewalModel(
    gen_int_rv=gen_int,
    I0_rv=I0,
    Rt_process_rv=rt_proc,
    latent_infections_rv=latent_infections,
    infection_obs_process_rv=observation_process,
)

The following diagram summarizes how the modules interact via composition. Notably, gen_int, I0, rt_proc, latent_infections, and observation_process are instances of RandomVariable, which means any of them can be replaced to generate a different instance of the RtInfectionsRenewalModel class:

%%| label: overview-of-RtInfectionsRenewalModel
%%| include: true
flowchart TB
  genint["(1) gen_int\n(DeterministicPMF)"]
  i0["(2) I0\n(InfectionInitializationProcess)"]
  rt["(3) rt_proc\n(MyRt, the custom RV defined above)"]
  inf["(4) latent_infections\n(Infections)"]
  obs["(5) observation_process\n(PoissonObservation)"]

  model1["model1\n(RtInfectionsRenewalModel)"]

  i0-->|Composes|model1
  genint-->|Composes|model1
  rt-->|Composes|model1
  obs-->|Composes|model1
  inf-->|Composes|model1
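
The composition idea can be illustrated without pyrenew. The sketch below uses hypothetical toy classes (ConstantRt, PoissonObs, RoundedObs, and ToyModel are not part of pyrenew, and the dynamics are deliberately simplistic). It shows that a model relying only on each component's sample() method lets one component be swapped for another.

```python
import numpy as np


class ConstantRt:
    """Toy Rt component: returns a constant reproduction number."""

    def __init__(self, value):
        self.value = value

    def sample(self, n, rng):
        return np.full(n, self.value)


class PoissonObs:
    """Toy observation component: Poisson counts around the mean."""

    def sample(self, mu, rng):
        return rng.poisson(mu)


class RoundedObs:
    """Toy observation component: deterministic rounding of the mean."""

    def sample(self, mu, rng):
        return np.rint(mu).astype(int)


class ToyModel:
    """Toy model composed of an Rt component and an observation component."""

    def __init__(self, rt_rv, obs_rv):
        self.rt_rv = rt_rv
        self.obs_rv = obs_rv

    def sample(self, n, rng, i0=10.0):
        rt = self.rt_rv.sample(n, rng)
        # Simplest possible dynamics: each step multiplies infections by Rt.
        latent = i0 * np.cumprod(rt)
        return rt, latent, self.obs_rv.sample(latent, rng)


rng = np.random.default_rng(1)
model_a = ToyModel(ConstantRt(1.1), PoissonObs())
model_b = ToyModel(ConstantRt(1.1), RoundedObs())  # only the observation part differs
_, latent_a, obs_a = model_a.sample(5, rng)
_, latent_b, obs_b = model_b.sample(5, rng)
```

Both toy models share the same latent trajectory; only the way observations are generated changes, which is the kind of substitution the diagram above describes.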

Using numpyro, we can simulate data with the sample() method of RtInfectionsRenewalModel. It returns an RtInfectionsRenewalSample object holding the Rt, latent_infections, and observed_infections sequences, which we store in sim_data:

with numpyro.handlers.seed(rng_seed=53):
    sim_data = model1.sample(n_datapoints=40)

sim_data
RtInfectionsRenewalSample(Rt=[1.0193933 1.0239283 1.0201008 1.0572274 1.0299444 1.0925564 1.108174
 1.0743027 1.0882349 1.1093383 1.1094797 1.1147494 1.1355134 1.1536688
 1.1749985 1.2038172 1.243159  1.211333  1.2353402 1.1836537 1.2103558
 1.2023542 1.2040431 1.199202  1.1868212 1.1774873 1.176468  1.2211965
 1.2915479 1.2288901 1.263633  1.2350382 1.2147664 1.1920787 1.2335418
 1.2487649 1.2792699 1.329618  1.3142097 1.3469969], latent_infections=[ 0.         0.         0.         5.365737   2.1879187  2.5443478
  2.8024845  3.0220375  2.860384   3.1309352  3.319142   3.374619
  3.545269   3.7799764  3.9746103  4.2029986  4.5240045  4.8955026
  5.350312   5.939486   6.7284603  7.3077817  8.233023   8.788847
  9.827905  10.755298  11.837625  12.958702  14.085928  15.266165
 16.597315  18.723179  21.866283  23.60642   27.050367  29.823267
 32.740788  35.54065   40.346916  45.369434  51.981926  61.198395
 69.89298   82.50361  ], observed_infections=[ 5  2  2  4  1  3  7  4  3  4  6  5  6  6  7  1  7  2  9 10  4  9 11 17
 20  5 17 14 33 23 26 24 22 27 42 53 65 62 71 91])
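
As a sanity check on the output above, the first few latent infections can be reproduced by hand. This is a minimal NumPy sketch, assuming the usual discrete renewal recursion: each new value is Rt times the generation-interval-weighted sum of the most recent infections, most recent first. The numeric inputs are copied from the printed sim_data, and the recursion is written out here rather than taken from pyrenew's source.

```python
import numpy as np

gen_int = np.array([0.4, 0.3, 0.2, 0.1])          # pmf over lags 1..4
rt = np.array([1.0193933, 1.0239283, 1.0201008])  # first three Rt values above
# Initialization window from the output above: zero padding, then I0.
infections = [0.0, 0.0, 0.0, 5.365737]

for r in rt:
    # Take the last four infections and reverse them so lag 1 comes first.
    recent = np.array(infections[-gen_int.size:])[::-1]
    infections.append(float(r * np.dot(gen_int, recent)))

print(np.round(infections[4:], 4))
```

The three computed values agree with the printed latent_infections entries 2.1879187, 2.5443478, and 2.8024845 to within single-precision rounding. This also explains the array lengths: latent_infections has 44 entries (a 4-step initialization window plus 40 renewal steps), while Rt and observed_infections each have 40.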

To understand what has been accomplished here, we visualize the simulated \(\mathcal{R}(t)\) sample path (left panel) and the observed infections over time (right panel):

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")

# Infections plot
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")

fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()

image1

To fit the model, we use the run() method of the RtInfectionsRenewalModel class (inherited from the metaclass Model). Here, model1.run() runs MCMC with 2000 warm-up iterations, which are used to tune the sampler and improve sampling efficiency, followed by 1000 posterior draws per chain, which are used to estimate the posterior distributions and compute summary statistics. The observed data are taken from the sim_data object generated above. mcmc_args passes additional arguments to the MCMC algorithm; here it disables the progress bar and requests 2 chains.

import jax

model1.run(
    num_warmup=2000,
    num_samples=1000,
    data_observed_infections=sim_data.observed_infections,
    rng_key=jax.random.key(54),
    mcmc_args=dict(progress_bar=False, num_chains=2),
)

Now, let’s investigate the output, particularly the posterior distribution of the \(\mathcal{R}(t)\) estimates:

out = model1.plot_posterior(var="Rt")

image2

We can use the get_samples method to extract samples from the model:

Rt_samp = model1.mcmc.get_samples()["Rt"]
latent_infection_samp = model1.mcmc.get_samples()["all_latent_infections"]
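
A note on shapes: by default, numpyro's MCMC.get_samples() concatenates the chains along the first axis, whereas an ArviZ InferenceData object keeps separate chain and draw dimensions. The NumPy sketch below uses placeholder arrays (not the real posterior) with the sizes from this run, 2 chains of 1000 draws over 40 time points, to show how the two layouts relate. The array names are hypothetical, and the chain-major ordering of the flattened layout is an assumption.

```python
import numpy as np

n_chains, n_draws, n_time = 2, 1000, 40

# Placeholder standing in for posterior draws grouped by chain.
rng = np.random.default_rng(0)
by_chain = rng.normal(size=(n_chains, n_draws, n_time))

# Flattened layout: all draws from chain 0, then all draws from chain 1.
flat = by_chain.reshape(n_chains * n_draws, n_time)

# The posterior mean per time point is the same under either layout.
mean_flat = flat.mean(axis=0)
mean_by_chain = by_chain.mean(axis=(0, 1))
print(flat.shape, by_chain.shape)
```

With numpyro, calling get_samples(group_by_chain=True) returns the chain-grouped layout directly.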

We can also convert the fitted model to an ArviZ InferenceData object and use the ArviZ package to extract samples, calculate statistics, and create model diagnostics and visualizations.

import arviz as az

idata = az.from_numpyro(model1.mcmc)

We then use the InferenceData object to compute model-fit diagnostics. Here, we show the diagnostic summary for the effective reproduction number \(\mathcal{R}(t)\) from time index 4 onward.

diagnostic_stats_summary = az.summary(
    idata.posterior["Rt"][::, ::, 4:],  # ignore nan padding
    kind="diagnostics",
)

print(diagnostic_stats_summary)
        mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
Rt[4]       0.001    0.001    2595.0    1781.0   1.00
Rt[5]       0.001    0.001    2711.0    1710.0   1.00
Rt[6]       0.001    0.001    2741.0    1662.0   1.00
Rt[7]       0.001    0.001    2826.0    1625.0   1.00
Rt[8]       0.001    0.001    2908.0    1479.0   1.00
Rt[9]       0.001    0.001    3410.0    1478.0   1.00
Rt[10]      0.001    0.001    3674.0    1605.0   1.00
Rt[11]      0.001    0.001    3509.0    1315.0   1.00
Rt[12]      0.001    0.000    3957.0    1228.0   1.00
Rt[13]      0.001    0.000    4078.0    1212.0   1.00
Rt[14]      0.001    0.000    4022.0    1506.0   1.00
Rt[15]      0.001    0.000    4218.0    1654.0   1.00
Rt[16]      0.001    0.000    4347.0    1791.0   1.00
Rt[17]      0.001    0.000    4017.0    1452.0   1.00
Rt[18]      0.001    0.000    3697.0    1467.0   1.00
Rt[19]      0.001    0.001    3121.0    1292.0   1.00
Rt[20]      0.001    0.000    3360.0    1128.0   1.00
Rt[21]      0.001    0.000    3860.0    1756.0   1.00
Rt[22]      0.001    0.000    3601.0    1697.0   1.00
Rt[23]      0.001    0.000    4003.0    1759.0   1.00
Rt[24]      0.001    0.000    3441.0    1356.0   1.00
Rt[25]      0.001    0.000    3275.0    1644.0   1.00
Rt[26]      0.001    0.000    3400.0    1441.0   1.01
Rt[27]      0.001    0.000    3455.0    1590.0   1.00
Rt[28]      0.001    0.000    4196.0    1454.0   1.00
Rt[29]      0.001    0.000    3952.0    1554.0   1.00
Rt[30]      0.001    0.000    3363.0    1660.0   1.00
Rt[31]      0.001    0.000    3338.0    1661.0   1.00
Rt[32]      0.001    0.000    4070.0    1679.0   1.00
Rt[33]      0.001    0.000    3552.0    1855.0   1.00
Rt[34]      0.001    0.000    3785.0    1714.0   1.00
Rt[35]      0.001    0.000    3359.0    1811.0   1.00
Rt[36]      0.001    0.001    3309.0    1562.0   1.00
Rt[37]      0.001    0.001    3085.0    1555.0   1.00
Rt[38]      0.001    0.001    3578.0    1677.0   1.00
Rt[39]      0.001    0.001    3168.0    1889.0   1.00
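
For intuition about the r_hat column, here is a minimal NumPy sketch of the classic Gelman-Rubin potential scale reduction factor. It is a simplified stand-in: ArviZ reports a rank-normalized split version, so its values will not match this formula exactly, and `basic_rhat` is a hypothetical helper name.

```python
import numpy as np


def basic_rhat(chains):
    """Classic (non-split) R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_draws = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()           # W: mean within-chain variance
    between = n_draws * chains.mean(axis=1).var(ddof=1)  # B: between-chain variance
    var_plus = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(var_plus / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(2, 1000))        # two chains sampling the same target
stuck = mixed + np.array([[0.0], [3.0]])  # second chain shifted by 3

print(basic_rhat(mixed), basic_rhat(stuck))
```

Values close to 1, as in the table above, indicate that the chains agree with each other; a value well above 1, as for the shifted chains in this sketch, signals that they have not converged to the same distribution.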

Below we use plot_trace to inspect the traces of ten inferred \(\mathcal{R}(t)\) values (time indices 4 through 13).

plt.rcParams["figure.constrained_layout.use"] = True

az.plot_trace(
    idata.posterior,
    var_names=["Rt"],
    coords={"Rt_dim_0": np.arange(4, 14)},
    compact=False,
)
plt.show()

image3

We inspect the posterior distribution of \(\mathcal{R}(t)\) by plotting the 90% and 50% highest density intervals:

x_data = idata.posterior["Rt_dim_0"][4:]
y_data = idata.posterior["Rt"][::, ::, 4:]

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# Add the posterior median to the figure
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")
axes.legend()
axes.set_title("Posterior Effective Reproduction Number", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("$\\mathcal{R}(t)$", fontsize=10)
plt.show()

image4
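
For reference, a highest density interval for a unimodal sample can be approximated as the narrowest window of sorted draws that contains the requested probability mass. The NumPy sketch below illustrates the idea on a placeholder sample; `narrowest_interval` is a hypothetical helper, and ArviZ's own routine may differ in its details.

```python
import numpy as np


def narrowest_interval(draws, prob=0.9):
    """Narrowest interval containing a fraction `prob` of the draws."""
    x = np.sort(np.asarray(draws, dtype=float))
    n = x.size
    k = int(np.ceil(prob * n))           # number of draws inside the interval
    widths = x[k - 1:] - x[: n - k + 1]  # width of every window of k sorted draws
    start = int(np.argmin(widths))
    return x[start], x[start + k - 1]


rng = np.random.default_rng(0)
draws = rng.lognormal(mean=0.0, sigma=0.5, size=5000)  # skewed placeholder sample
lo90, hi90 = narrowest_interval(draws, 0.9)
lo50, hi50 = narrowest_interval(draws, 0.5)
print((lo90, hi90), (lo50, hi50))
```

For a skewed distribution such as this one, the narrowest interval is shorter than the equal-tailed interval with the same coverage, which is why highest density intervals are a common choice for summarizing posteriors.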

We do the same for the latent infections:

x_data = idata.posterior["all_latent_infections_dim_0"]
y_data = idata.posterior["all_latent_infections"]

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# plot the posterior median
median_ts = y_data.median(dim=["chain", "draw"])
axes.plot(x_data, median_ts, color="C0", label="Median")

axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10)
plt.show()

image5