Getting started with pyrenew

pyrenew is a flexible tool for simulating and making statistical inferences from epidemiologic models, with an emphasis on renewal models. Built on numpyro, pyrenew provides core components for model building and pre-defined models for processing various observational processes. This document illustrates how pyrenew can be used to build a basic renewal model.

The fundamentals

pyrenew’s core components are the metaclasses RandomVariable and Model. (Strictly speaking, a Python metaclass is a class whose instances are themselves classes; here the term is used more loosely for base classes that other classes inherit from, as the examples below show.) Within the pyrenew package, a RandomVariable is a quantity that models can estimate and sample from, including deterministic quantities. The benefit of this design is that the definition of the sample() method can be arbitrary, allowing the user to sample from a distribution using numpyro.sample(), compute a deterministic quantity (like a mechanistic equation), or return a fixed value (like a pre-computed PMF). For instance, when estimating a PMF, the RandomVariable’s sample() method may roughly be defined as:

# Define a new class, MyRandVar, that inherits from the RandomVariable class
class MyRandVar(RandomVariable):
    # Define a sample method that returns an ArrayLike object
    def sample(...) -> ArrayLike:
        # Draw from a distribution with NumPyro's sample function
        return numpyro.sample(...)

In other cases, we may instead use a fixed quantity for that variable (like a pre-computed PMF), in which case the RandomVariable’s sample() method could be defined as:

# Define MyRandVar so that it still inherits from the RandomVariable class
class MyRandVar(RandomVariable):
    # The sample method still returns an ArrayLike object
    def sample(...) -> ArrayLike:
        # Return a pre-computed PMF: a JAX NumPy array with explicit elements
        return jax.numpy.array([0.2, 0.7, 0.1])

Thus, when a Model samples from MyRandVar, it could be either adding random variables to be estimated (first case) or just retrieving some quantity needed for other calculations (second case).
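The two cases can be sketched without pyrenew or numpyro at all. The block below is a minimal illustration that uses only the Python standard library; the RandomVariable class here is a stand-in written for this example, and EstimatedPMF, FixedPMF, and mean_delay are hypothetical names, not part of the pyrenew API.

```python
import random
from abc import ABC, abstractmethod


class RandomVariable(ABC):
    """Stand-in for the pattern described above: anything with a sample() method."""

    @abstractmethod
    def sample(self, **kwargs) -> list[float]:
        ...


class EstimatedPMF(RandomVariable):
    """Stochastic case: every call draws a new PMF."""

    def __init__(self, n_bins: int, rng: random.Random):
        self.n_bins = n_bins
        self.rng = rng

    def sample(self, **kwargs) -> list[float]:
        # Normalized Gamma(1, 1) draws form a flat Dirichlet sample, i.e. a random PMF
        weights = [self.rng.gammavariate(1.0, 1.0) for _ in range(self.n_bins)]
        total = sum(weights)
        return [w / total for w in weights]


class FixedPMF(RandomVariable):
    """Deterministic case: every call returns the same pre-computed PMF."""

    def __init__(self, pmf: list[float]):
        self.pmf = list(pmf)

    def sample(self, **kwargs) -> list[float]:
        return list(self.pmf)


def mean_delay(pmf_rv: RandomVariable) -> float:
    """The calling code is identical whichever kind of variable is passed in."""
    pmf = pmf_rv.sample()
    return sum(i * p for i, p in enumerate(pmf))


fixed = FixedPMF([0.2, 0.7, 0.1])
estimated = EstimatedPMF(n_bins=3, rng=random.Random(0))
print(mean_delay(fixed))      # approximately 0.9
print(mean_delay(estimated))  # depends on the seed
```

The point of the design is visible in mean_delay: it never needs to know whether the PMF it receives was drawn or fixed.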

The Model metaclass provides basic functionality for estimation and simulation. Like RandomVariable, the Model metaclass has a sample() method that defines the model structure. Ultimately, models can be nested (or inherited), providing a straightforward way to add layers of complexity.

‘Hello world’ model

This section will show the steps to build a simple renewal model featuring a latent infection process, a random walk Rt process, and an observation process for the reported infections.

We start by loading the needed components to build a basic renewal model:

import jax.numpy as jnp
import numpy as np
import numpyro as npro
import numpyro.distributions as dist
from pyrenew.process import RtRandomWalkProcess
from pyrenew.latent import (
    Infections,
    InfectionSeedingProcess,
    SeedInfectionsZeroPad,
)
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import DistributionalRV
import pyrenew.transformation as t

The pyrenew package models the real-time reproductive number \(R_t\), the ratio of new infections at time \(t\) to the infections at earlier times \(t-s\) that gave rise to them, with a renewal process model. Our basic renewal process model defines five components:

  1. generation interval, the times between infections,

  2. initial infections, occurring prior to time \(t = 0\),

  3. \(R_t\), the real-time reproductive number,

  4. latent infections, i.e., those infections which are known to exist but are not observed (or not observable), and

  5. observed infections, a subset of underlying true infections that are reported, perhaps via hospital admissions, physician’s office visits, or routine biosurveillance.
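As a sketch of how these components fit together, the standard discrete-time renewal formulation can be written as follows. The notation is introduced here for illustration and is not taken from pyrenew: \(g_s\) is the generation interval PMF with \(S\) entries, \(I(t)\) denotes latent infections, and \(Y_t\) denotes observed infections with a Poisson observation process.

```latex
\begin{aligned}
I(t) &= R_t \sum_{s=1}^{S} g_s \, I(t-s), \\
Y_t &\sim \mathrm{Poisson}\bigl(I(t)\bigr).
\end{aligned}
```

Read this way, \(R_t\) is the ratio of new infections at time \(t\) to the generation-interval-weighted sum of earlier infections, and the initial infections supply the values of \(I(t-s)\) that the first few time points need.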

To initialize these five components within the renewal modeling framework, we specify each component as follows:

  1. the generation interval, which in this example is not estimated but passed as a deterministic instance of RandomVariable,

  2. an instance of the InfectionSeedingProcess class, where the number of latent infections immediately before the renewal process begins follows a log-normal distribution whose logarithm has mean 0 and standard deviation 1. By specifying SeedInfectionsZeroPad, the latent infections before this time are assumed to be 0,

  3. an instance of the RtRandomWalkProcess class with an explicitly specified initial-value distribution, transform, and step distribution,

  4. an instance of the Infections class with default values, and

  5. an instance of the PoissonObservation class with default values.
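The parameters of a log-normal distribution are easy to misread, so here is a small standard-library check. It uses random.lognormvariate as a stand-in for the numpyro distribution, and it assumes both use the same parameterization, in which the two arguments describe the logarithm of the variable rather than the variable itself.

```python
import math
import random

rng = random.Random(1)
# lognormvariate(mu, sigma): mu and sigma describe log(X), not X itself
draws = [rng.lognormvariate(0.0, 1.0) for _ in range(100_000)]

sample_median = sorted(draws)[len(draws) // 2]
sample_mean = sum(draws) / len(draws)

print(sample_median, math.exp(0.0))  # the median of X is exp(mu) = 1
print(sample_mean, math.exp(0.5))    # the mean of X is exp(mu + sigma**2 / 2), about 1.65
```

So a prior with parameters 0 and 1 places the median number of initial infections at 1, with a long right tail.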

# (1) The generation interval (deterministic)
pmf_array = jnp.array([0.25, 0.25, 0.25, 0.25])
gen_int = DeterministicPMF(pmf_array, name="gen_int")

# (2) Initial infections (inferred with a prior)
I0 = InfectionSeedingProcess(
    "I0_seeding",
    DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"),
    SeedInfectionsZeroPad(pmf_array.size),
)

# (3) The random process for Rt
rt_proc = RtRandomWalkProcess(
    Rt0_dist=dist.TruncatedNormal(loc=1.2, scale=0.2, low=0),
    Rt_transform=t.ExpTransform().inv,
    Rt_rw_dist=dist.Normal(0, 0.025),
)

# (4) Latent infection process (which will use 1 and 2)
latent_infections = Infections()

# (5) The observed infections process (with mean at the latent infections)
observation_process = PoissonObservation()
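To give a feel for component (3), the sketch below simulates a random walk for Rt with the standard library only. It assumes that passing the inverse of an exponential transform means the walk takes place on log(Rt) and is mapped back with exp; that reading, and the helper name simulate_rt, are assumptions of this example rather than a description of pyrenew's implementation.

```python
import math
import random


def simulate_rt(n_steps: int, rng: random.Random) -> list[float]:
    # Initial value: Normal(1.2, 0.2) truncated below at 0, via rejection sampling
    rt0 = rng.gauss(1.2, 0.2)
    while rt0 <= 0:
        rt0 = rng.gauss(1.2, 0.2)

    # Walk on the log scale, then map back with exp so that Rt stays positive
    log_rt = math.log(rt0)
    path = [rt0]
    for _ in range(n_steps - 1):
        log_rt += rng.gauss(0.0, 0.025)
        path.append(math.exp(log_rt))
    return path


rt_path = simulate_rt(30, random.Random(223))
print(len(rt_path), min(rt_path) > 0)  # 30 True
```

With a step standard deviation of 0.025 on the log scale, consecutive Rt values typically differ by a few percent, which is why the simulated paths look smooth.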

With these five pieces, we can build the basic renewal model as an instance of the RtInfectionsRenewalModel class:

model1 = RtInfectionsRenewalModel(
    gen_int_rv=gen_int,
    I0_rv=I0,
    Rt_process_rv=rt_proc,
    latent_infections_rv=latent_infections,
    infection_obs_process_rv=observation_process,
)

The following diagram summarizes how the modules interact via composition; notably, gen_int, I0, rt_proc, latent_infections, and observation_process are instances of RandomVariable, which means these can be easily replaced to generate a different instance of the RtInfectionsRenewalModel class:

%%| label: overview-of-RtInfectionsRenewalModel
%%| include: true
flowchart TB
    genint["(1) gen_int\n(DeterministicPMF)"]
    i0["(2) I0\n(InfectionSeedingProcess)"]
    rt["(3) rt_proc\n(RtRandomWalkProcess)"]
    inf["(4) latent_infections\n(Infections)"]
    obs["(5) observation_process\n(PoissonObservation)"]

    model1["model1\n(RtInfectionsRenewalModel)"]

    i0-->|Composes|model1
    genint-->|Composes|model1
    rt-->|Composes|model1
    obs-->|Composes|model1
    inf-->|Composes|model1

Using numpyro, we can simulate data using the sample() member function of RtInfectionsRenewalModel. The sample() method of the RtInfectionsRenewalModel class returns an RtInfectionsRenewalSample object, here called sim_data, whose fields hold the Rt, latent infections, and observed infections sequences:

np.random.seed(223)
with npro.handlers.seed(rng_seed=np.random.randint(1, 60)):
    sim_data = model1.sample(n_timepoints_to_simulate=30)

sim_data
RtInfectionsRenewalSample(Rt=[      nan       nan       nan       nan 1.2022278 1.2111099 1.2325984
 1.2104921 1.2023039 1.1970979 1.2384264 1.2423582 1.245498  1.241344
 1.2081108 1.1938375 1.2711959 1.3189521 1.3054799 1.3317624 1.3068875
 1.3177233 1.2765354 1.3001204 1.3273207 1.3038    1.2788019 1.2726372
 1.2690852 1.3244348 1.263264  1.2540829 1.2211404 1.247711 ], latent_infections=[ 0.         0.         0.         7.7240214  2.3215084  3.0415602
  4.0327816  5.180868   4.381411   4.978916   5.750626   6.3024273
  6.66758    7.354823   7.8755097  8.416656   9.633939  10.973987
 12.043082  13.673094  15.135098  17.072838  18.485544  20.921074
 23.76387   26.155312  28.5575    31.624323  34.93189   40.15323
 42.71946   46.849056  50.266296  56.14326  ], observed_infections=[ 0  0  0  0  2  3  5  5  1  3  4  7  6  8  7  7 18  9 10 11 16 19 14 24
 17 34 34 25 38 39 43 52 50 51])
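The printed values can be checked by hand. The sketch below applies one step of a discrete renewal update (Rt times the generation-interval-weighted sum of recent infections) to the seeding values and the first three non-NaN Rt values copied from the output above. The function renewal_step is a hypothetical helper written for this check, not a pyrenew function.

```python
def renewal_step(rt: float, past_infections: list[float], gen_int: list[float]) -> float:
    """One renewal update: Rt times the weighted sum of recent infections."""
    # gen_int[0] weights the most recent time point, so reverse the recent history
    recent = past_infections[-len(gen_int):][::-1]
    return rt * sum(g * i for g, i in zip(gen_int, recent))


gen_int = [0.25, 0.25, 0.25, 0.25]
# Seeding period as printed above: three padded zeros, then the sampled I0
infections = [0.0, 0.0, 0.0, 7.7240214]
# First three non-NaN Rt values printed above
rt_values = [1.2022278, 1.2111099, 1.2325984]

for rt in rt_values:
    infections.append(renewal_step(rt, infections, gen_int))

# Compare with 2.3215084, 3.0415602, 4.0327816 in the latent_infections output
print(infections[4:])
```

The results agree with the simulated latent infections to within floating-point precision. The check also shows why the first four Rt entries are NaN: those positions belong to the seeding period, before the renewal process starts.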

To understand what has been accomplished here, we can visualize an \(R_t\) sample path (left panel) and observed infections over time (right panel):

import matplotlib.pyplot as plt

fig, axs = plt.subplots(1, 2)

# Rt plot
axs[0].plot(sim_data.Rt)
axs[0].set_ylabel("Rt")

# Infections plot
axs[1].plot(sim_data.observed_infections)
axs[1].set_ylabel("Infections")

fig.suptitle("Basic renewal model")
fig.supxlabel("Time")
plt.tight_layout()
plt.show()

image1

To fit the model, we can use the run() method of the RtInfectionsRenewalModel class (a method inherited from the metaclass Model). Calling model1.run() sets up and runs MCMC for the model. Here we request 2000 warm-up iterations, which are used to tune the sampler and improve the efficiency of the sampling process, followed by 1000 draws per chain from the posterior distribution of the model parameters, which are used to estimate posterior quantities and compute summary statistics. Observed data is provided to the model using the sim_data object previously generated, and mcmc_args passes additional arguments to the MCMC algorithm (here, two chains and no progress bar).

import jax

model1.run(
    num_warmup=2000,
    num_samples=1000,
    data_observed_infections=sim_data.observed_infections,
    rng_key=jax.random.PRNGKey(54),
    mcmc_args=dict(progress_bar=False, num_chains=2),
)
/home/runner/work/multisignal-epi-inference/multisignal-epi-inference/model/src/pyrenew/metaclass.py:304: UserWarning: There are not enough devices to run parallel chains: expected 2 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(2)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  self.mcmc = MCMC(
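To make the roles of num_warmup and num_samples concrete, here is a toy random-walk Metropolis sampler written with the standard library. It is not the sampler that numpyro runs and it does no tuning: warm-up here only discards early iterations, whereas numpyro also uses warm-up to adapt the sampler. The data, the Gamma(1, 1) prior, and the names log_posterior and run_chain are all illustrative assumptions.

```python
import math
import random


def log_posterior(log_lam: float, counts: list[int]) -> float:
    """Poisson likelihood with a Gamma(1, 1) prior on the rate, in terms of log(rate)."""
    lam = math.exp(log_lam)
    log_lik = sum(y * log_lam - lam for y in counts)  # Poisson terms, constants dropped
    log_prior = -lam                                  # Gamma(1, 1) density, constants dropped
    return log_lik + log_prior + log_lam              # log_lam is the Jacobian of exp


def run_chain(counts, num_warmup, num_samples, rng, step=0.2):
    log_lam = 0.0
    current = log_posterior(log_lam, counts)
    draws = []
    for i in range(num_warmup + num_samples):
        proposal = log_lam + rng.gauss(0.0, step)
        candidate = log_posterior(proposal, counts)
        # Metropolis accept/reject step
        if rng.random() < math.exp(min(0.0, candidate - current)):
            log_lam, current = proposal, candidate
        if i >= num_warmup:  # warm-up iterations are discarded
            draws.append(math.exp(log_lam))
    return draws


counts = [2, 3, 5, 5, 1, 3, 4, 7, 6, 8]  # illustrative counts
chains = [
    run_chain(counts, num_warmup=500, num_samples=1000, rng=random.Random(seed))
    for seed in (1, 2)
]
pooled = [draw for chain in chains for draw in chain]
# 2000 retained draws; the exact posterior mean is 45 / 11, about 4.09
print(len(pooled), sum(pooled) / len(pooled))
```

Each chain contributes num_samples draws, so two chains of 1000 give 2000 posterior draws in total, which matches the way the effective sample sizes reported later can exceed 1000.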

Now, let’s investigate the output, particularly the posterior distribution of the \(R_t\) estimates:

out = model1.plot_posterior(var="Rt")

image2

We can use the ArviZ package to create model diagnostics and visualizations. We start by converting the fitted model to an ArviZ InferenceData object:

import arviz as az

idata = az.from_numpyro(model1.mcmc)

and use the InferenceData object to compute the model-fit diagnostics. Here, we show the diagnostic summary for the first 10 values of the effective reproduction number \(R_t\).

diagnostic_stats_summary = az.summary(
    idata.posterior["Rt"],
    kind="diagnostics",
)

print(diagnostic_stats_summary[:10])
       mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
Rt[0]      0.003    0.002    1309.0    1339.0    1.0
Rt[1]      0.003    0.002    1331.0    1264.0    1.0
Rt[2]      0.003    0.002    1340.0    1395.0    1.0
Rt[3]      0.003    0.002    1421.0    1335.0    1.0
Rt[4]      0.002    0.002    1459.0    1464.0    1.0
Rt[5]      0.002    0.002    1597.0    1303.0    1.0
Rt[6]      0.002    0.002    1539.0    1454.0    1.0
Rt[7]      0.002    0.001    1707.0    1529.0    1.0
Rt[8]      0.002    0.001    1494.0    1674.0    1.0
Rt[9]      0.002    0.001    1951.0    1517.0    1.0
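The r_hat column compares variation between chains with variation within chains; values close to 1 suggest the chains agree. The sketch below computes the classic (Gelman-Rubin) version of the statistic with the standard library. ArviZ reports a more robust rank-normalized split variant by default, so this simpler formula and the function name classic_r_hat are illustrative assumptions, not a reimplementation of az.summary.

```python
import random
import statistics


def classic_r_hat(chains: list[list[float]]) -> float:
    """Classic potential scale reduction factor for equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    within = statistics.fmean([statistics.variance(c) for c in chains])
    between = n * statistics.variance(chain_means)
    pooled = (n - 1) / n * within + between / n
    return (pooled / within) ** 0.5


rng = random.Random(0)
# Two chains sampling the same distribution
mixed = [[rng.gauss(0.0, 1.0) for _ in range(1000)] for _ in range(2)]
# Two chains stuck in different regions
stuck = [mixed[0], [rng.gauss(3.0, 1.0) for _ in range(1000)]]

print(classic_r_hat(mixed))  # close to 1
print(classic_r_hat(stuck))  # well above 1
```

In the table above every r_hat is 1.0 and the effective sample sizes are above 1000, which is consistent with two chains that have mixed well.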

Below we use plot_trace to inspect the trace of the first 10 \(R_t\) estimates.

plt.rcParams["figure.constrained_layout.use"] = True

az.plot_trace(
    idata.posterior,
    var_names=["Rt"],
    coords={"Rt_dim_0": np.arange(10)},
    compact=False,
);

image3

We inspect the posterior distribution of \(R_t\) by plotting the 90% and 50% highest density intervals:

x_data = idata.posterior["Rt_dim_0"]
y_data = idata.posterior["Rt"]

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# Add the posterior mean over draws (first chain) to the figure
mean_Rt = np.mean(idata.posterior["Rt"], axis=1)
axes.plot(x_data, mean_Rt[0], color="C0", label="Mean")
axes.legend()
axes.set_title("Posterior Effective Reproduction Number", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("$R_t$", fontsize=10);

image4

and latent infections:

x_data = idata.posterior["all_latent_infections_dim_0"]
y_data = idata.posterior["all_latent_infections"]

fig, axes = plt.subplots(figsize=(6, 5))
az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.9,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.3},
    ax=axes,
)

az.plot_hdi(
    x_data,
    y_data,
    hdi_prob=0.5,
    color="C0",
    smooth=False,
    fill_kwargs={"alpha": 0.6},
    ax=axes,
)

# Add the posterior mean over draws (first chain) to the figure
mean_latent_infection = np.mean(
    idata.posterior["all_latent_infections"], axis=1
)
axes.plot(x_data, mean_latent_infection[0], color="C0", label="Mean")
axes.legend()
axes.set_title("Posterior Latent Infections", fontsize=10)
axes.set_xlabel("Time", fontsize=10)
axes.set_ylabel("Latent Infections", fontsize=10);

image5
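The shaded bands in the two figures above are highest density intervals: for a given probability, the narrowest interval that contains that share of the posterior draws. A minimal standard-library sketch of the idea for a single time point follows. The function hdi is written for this example and only approximates what ArviZ computes; it assumes a unimodal sample.

```python
import math


def hdi(samples: list[float], prob: float) -> tuple[float, float]:
    """Narrowest interval spanning roughly `prob` of the sorted samples."""
    if not 0 < prob < 1:
        raise ValueError("prob must be strictly between 0 and 1")
    ordered = sorted(samples)
    n = len(ordered)
    offset = math.floor(prob * n)  # index distance between the two end points
    # Slide a window of fixed index width over the sorted draws and keep the narrowest
    candidates = [(ordered[i + offset] - ordered[i], i) for i in range(n - offset)]
    _, start = min(candidates)
    return ordered[start], ordered[start + offset]


draws = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 100.0]
print(hdi(draws, 0.8))  # (0.0, 8.0)
```

Unlike an equal-tailed interval, the highest density interval is free to exclude a long tail on one side, which is why the outlying value 100.0 is left out here.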

Architecture of pyrenew

pyrenew leverages numpyro’s flexibility to build models via composition. As a principle, most objects in pyrenew can be treated as random variables we can sample. At the top level, pyrenew has two metaclasses from which most objects inherit: RandomVariable and Model. From them, the following five sub-modules arise:

  • The process sub-module,

  • The deterministic sub-module,

  • The observation sub-module,

  • The latent sub-module, and

  • The model sub-module

The first four are collections of classes that inherit from RandomVariable, and the last is a collection of classes that inherit from Model. The following diagram shows a detailed view of how metaclasses, modules, and classes interact to create the RtInfectionsRenewalModel instantiated in the previous section:

%%| label: overview-of-RtInfectionsRenewalModel
%%| include: true
flowchart LR
    rand((RandomVariable\nmetaclass))
    models((Model\nmetaclass))

    subgraph observations[Observations module]
        obs["observation_process\n(PoissonObservation)"]
    end

    subgraph latent[Latent module]
        inf["latent_infections\n(Infections)"]
        i0["I0\n(InfectionSeedingProcess)"]
    end

    subgraph process[Process module]
        rt["rt_proc\n(RtRandomWalkProcess)"]
    end

    subgraph deterministic[Deterministic module]
        detpmf["gen_int\n(DeterministicPMF)"]
    end

    subgraph model[Model module]
        model1["model1\n(RtInfectionsRenewalModel)"]
    end

    rand-->|Inherited by|observations
    rand-->|Inherited by|process
    rand-->|Inherited by|latent
    rand-->|Inherited by|deterministic
    models-->|Inherited by|model

    detpmf-->|Composes|model1
    i0-->|Composes|model1
    rt-->|Composes|model1
    obs-->|Composes|model1
    inf-->|Composes|model1
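The difference between the "Inherited by" and "Composes" arrows can be sketched in a few lines of plain Python. Everything in this block is a stand-in written for illustration (the base classes, Fixed, and ToyRenewalModel are hypothetical and much simpler than their pyrenew counterparts): components inherit from a common base class, while the model holds its components as attributes, so one component can be swapped without touching the others.

```python
from abc import ABC, abstractmethod


class RandomVariable(ABC):
    """Stand-in base class: anything that can be sampled."""

    @abstractmethod
    def sample(self, **kwargs):
        ...


class Model(ABC):
    """Stand-in base class: a sample() method that wires components together."""

    @abstractmethod
    def sample(self, **kwargs):
        ...


class Fixed(RandomVariable):
    """Deterministic component that always returns the same value."""

    def __init__(self, value):
        self.value = value

    def sample(self, **kwargs):
        return self.value


class ToyRenewalModel(Model):
    """Holds its components by composition; it does not inherit from them."""

    def __init__(self, gen_int_rv, i0_rv, rt_rv):
        self.gen_int_rv = gen_int_rv
        self.i0_rv = i0_rv
        self.rt_rv = rt_rv

    def sample(self, n_timepoints, **kwargs):
        gen_int = self.gen_int_rv.sample()
        rt = self.rt_rv.sample()
        # Zero-padded seeding: only the last seeding time point is non-zero
        infections = [0.0] * (len(gen_int) - 1) + [self.i0_rv.sample()]
        for _ in range(n_timepoints):
            recent = infections[-len(gen_int):][::-1]
            infections.append(rt * sum(g * i for g, i in zip(gen_int, recent)))
        return infections[len(gen_int):]


gen_int = Fixed([0.25, 0.25, 0.25, 0.25])
i0 = Fixed(8.0)

flat = ToyRenewalModel(gen_int, i0, rt_rv=Fixed(1.0))
growing = ToyRenewalModel(gen_int, i0, rt_rv=Fixed(2.0))  # only the Rt component changes

print(flat.sample(n_timepoints=3))     # [2.0, 2.5, 3.125]
print(growing.sample(n_timepoints=3))  # [4.0, 6.0, 9.0]
```

Replacing Fixed(2.0) with a stochastic component would change the model's behavior without any change to ToyRenewalModel itself, which is the property the composition arrows in the diagram describe.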