Extending pyrenew

This tutorial illustrates how to extend pyrenew with custom RandomVariable classes, using the InfectionsWithFeedback class as an example. InfectionsWithFeedback is a RandomVariable that models the number of infections at time \(t\) as a function of the infections at earlier times \(t - \tau\) and the reproduction number at time \(t\). The reproduction number at time \(t\) is, in turn, a function of the unadjusted reproduction number at time \(t\) and the infections at earlier times \(t - \tau\):

\[ \begin{align*} I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\ \mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(-\gamma(t)\sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right) \end{align*} \]

where \(\mathcal{R}^u(t)\) is the unadjusted reproduction number, \(g(\tau)\) is the generation interval PMF of length \(T_g\), \(\gamma(t)\) is the infection feedback strength, and \(f(\tau)\) is the infection feedback PMF of length \(T_f\).
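
To make the recursion concrete, the two equations can be iterated directly in plain Python. This is only a minimal sketch of the formulas above, not pyrenew's implementation: the function name `infections_with_feedback` and its arguments are hypothetical, the feedback strength is assumed constant over time, and the initial infections are assumed to be ordered oldest first.

```python
import math


def infections_with_feedback(I0, Rt_raw, gen_int, feedback_pmf, gamma):
    """Iterate the renewal equation with infection feedback (illustrative only).

    I0 holds the initial infections, oldest first, and must be at least as
    long as both PMFs. gen_int[k] is g(k + 1); feedback_pmf[k] is f(k + 1).
    """
    history = list(I0)
    infections, Rt_adj = [], []
    for r_u in Rt_raw:
        # sum_{tau=1}^{T_f} I(t - tau) f(tau)
        feedback = sum(
            history[-tau] * feedback_pmf[tau - 1]
            for tau in range(1, len(feedback_pmf) + 1)
        )
        # R(t) = R^u(t) * exp(-gamma * feedback)
        r_t = r_u * math.exp(-gamma * feedback)
        # sum_{tau=1}^{T_g} I(t - tau) g(tau)
        renewal = sum(
            history[-tau] * gen_int[tau - 1]
            for tau in range(1, len(gen_int) + 1)
        )
        i_t = r_t * renewal
        infections.append(i_t)
        Rt_adj.append(r_t)
        history.append(i_t)
    return infections, Rt_adj


pmf = [0.25, 0.5, 0.15, 0.1]
infections, Rt_adj = infections_with_feedback(
    I0=[1.0, 1.0, 1.0, 1.0],
    Rt_raw=[1.2] * 10,
    gen_int=pmf,
    feedback_pmf=pmf,
    gamma=0.01,
)
```

With a positive feedback strength, every adjusted value in `Rt_adj` falls below the unadjusted value of 1.2; setting `gamma=0.0` recovers the standard renewal equation.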

The expected outcome

Before we start, let’s simulate the model with the original InfectionsWithFeedback class. To keep it simple, we will simulate the model with no observation process, that is, with latent infections only. The following code chunk loads the required libraries:

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import InfectionsWithFeedback
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.process import RandomWalk
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
from pyrenew.latent import (
    InfectionInitializationProcess,
    InitializeInfectionsExponentialGrowth,
)
import pyrenew.transformation as t

The following code chunk defines the model components. Notice that the generation interval and the infection feedback use the same deterministic PMF:

gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1])

gen_int = DeterministicPMF(name="gen_int", value=gen_int_array)
feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01)


I0 = InfectionInitializationProcess(
    "I0_initialization",
    DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)),
    InitializeInfectionsExponentialGrowth(
        gen_int_array.size,
        DeterministicVariable(name="rate", value=0.05),
    ),
)

latent_infections = InfectionsWithFeedback(
    infection_feedback_strength=feedback_strength,
    infection_feedback_pmf=gen_int,
)


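class MyRt(RandomVariable):
    """Rt modeled as the exponential of a random walk on the log scale."""

    def validate(self):
        pass

    def sample(self, n: int, **kwargs) -> tuple:
        # Note: sd_rt is recorded as a sample site but is not used below;
        # the random-walk steps use a fixed standard deviation of 0.025.
        sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))

        # Random walk on log(Rt), exponentiated to obtain Rt
        rt_rv = TransformedVariable(
            name="log_rt_random_walk",
            base_rv=RandomWalk(
                name="log_rt",
                step_rv=DistributionalVariable(
                    name="rw_step_rv", distribution=dist.Normal(0, 0.025)
                ),
            ),
            transforms=t.ExpTransform(),
        )

        # Initial value of the random walk, on the log scale
        rt_init_rv = DistributionalVariable(
            name="init_log_rt", distribution=dist.Normal(0, 0.2)
        )
        init_rt = rt_init_rv.sample()

        return rt_rv.sample(n=n, init_vals=init_rt, **kwargs)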

With all the components defined, we can build the model:

model0 = RtInfectionsRenewalModel(
    gen_int_rv=gen_int,
    I0_rv=I0,
    latent_infections_rv=latent_infections,
    Rt_process_rv=MyRt(),
    infection_obs_process_rv=None,
)

And simulate from it:

# Sampling from model 0 (with no observation process for infections)
with numpyro.handlers.seed(rng_seed=223):
    model0_samp = model0.sample(n_datapoints=30)

import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(model0_samp.latent_infections)
ax.set_xlabel("Time")
ax.set_ylabel("Infections")
plt.show()

Figure 1: Simulated infections with no observation process

PyRenew’s random variable class

Fundamentals

All subclasses of PyRenew’s RandomVariable should have at least three methods: __init__(), validate(), and sample(). The __init__() method is the constructor and initializes the instance. The validate() method checks that the arguments used to initialize the instance are valid. Finally, the sample() method contains the core of the class; it should return a tuple or named tuple. The following is a minimal example of a RandomVariable class based on numpyro.distributions.Normal:

from pyrenew.metaclass import RandomVariable


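class MyNormal(RandomVariable):
    def __init__(self, loc, scale):
        # Validate the raw argument before storing it on the instance
        self.validate(scale)
        self.loc = loc
        self.scale = scale
        return None

    @staticmethod
    def validate(scale):
        # A static method receives no instance, so it checks the value directly
        if scale <= 0:
            raise ValueError("Scale must be positive")
        return None

    def sample(self, **kwargs):
        # Draw from the distribution; the site name "my_normal" is illustrative
        return (
            numpyro.sample(
                "my_normal", dist.Normal(loc=self.loc, scale=self.scale)
            ),
        )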

The @staticmethod decorator lets validate() be called without an instance, both inside __init__() before any attribute is set and from outside the class. Next, we show how to build a more complex RandomVariable class: the InfectionsWithFeedback class.
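
As a plain-Python illustration of the validate-on-construction pattern, the sketch below uses a hypothetical stand-in class, `PositiveScale`, that does not depend on pyrenew. It only shows how a static validate() can be called on the class itself and how an invalid argument is rejected before the instance is fully built.

```python
class PositiveScale:
    """Hypothetical stand-in for a RandomVariable subclass (no pyrenew needed)."""

    def __init__(self, scale):
        # Validate the raw argument before storing it on the instance
        self.validate(scale)
        self.scale = scale

    @staticmethod
    def validate(scale):
        if scale <= 0:
            raise ValueError("Scale must be positive")


# Callable on the class itself; no instance is required
PositiveScale.validate(1.0)

# An invalid value is rejected inside the constructor
try:
    PositiveScale(scale=-1.0)
except ValueError as err:
    message = str(err)
```

Because validation happens in the constructor, a misconfigured component fails early, before it is ever passed to a model.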

The InfectionsWithFeedback class

Although returning namedtuples is not strictly required, they are the recommended return type, as they make the code more readable. The following code-chunk shows how to create a named tuple for the InfectionsWithFeedback class:

from collections import namedtuple

# Creating a tuple to store the output
InfFeedbackSample = namedtuple(
    typename="InfFeedbackSample",
    field_names=["post_initialization_infections", "rt"],
    defaults=(None, None),
)
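
For readers less familiar with named tuples, the short sketch below shows how such a return value can be consumed. The numeric values are made up for illustration; only the field names come from the definition above.

```python
from collections import namedtuple

InfFeedbackSample = namedtuple(
    typename="InfFeedbackSample",
    field_names=["post_initialization_infections", "rt"],
    defaults=(None, None),
)

# Illustrative values only
sample = InfFeedbackSample(
    post_initialization_infections=[1.0, 1.2, 1.5],
    rt=[1.1, 1.0, 0.9],
)

# Access by field name, or unpack like a regular tuple
rt_by_name = sample.rt
infections, rt = sample

# Fields that are not passed fall back to the declared defaults
empty = InfFeedbackSample()
```

Named access keeps downstream code readable, while positional unpacking still works for callers that expect a plain tuple.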

The next step is to create the actual class. The bulk of its implementation lies in the function pyrenew.latent.compute_infections_from_rt_with_feedback(). We will also use the pyrenew.arrayutils.pad_edges_to_match() function to ensure the passed vectors match their lengths. The following code-chunk shows most of the implementation of the InfectionsWithFeedback class:

# Creating the class
from pyrenew.metaclass import RandomVariable
from pyrenew.latent import compute_infections_from_rt_with_feedback
from pyrenew import arrayutils as au
from jax.typing import ArrayLike
import jax.numpy as jnp


class InfFeedback(RandomVariable):
    """Latent infections"""

    def __init__(
        self,
        infection_feedback_strength: RandomVariable,
        infection_feedback_pmf: RandomVariable,
    ) -> None:
        """Constructor"""

        self.infection_feedback_strength = infection_feedback_strength
        self.infection_feedback_pmf = infection_feedback_pmf

        return None

    def validate(self):
        """
        Generally, this method should be more meaningful, but we will skip it for now
        """
        return None

    def sample(
        self,
        Rt: ArrayLike,
        I0: ArrayLike,
        gen_int: ArrayLike,
        **kwargs,
    ) -> tuple:
        """Sample infections with feedback"""

        # Generation interval
        gen_int_rev = jnp.flip(gen_int)

        # Baseline infections
        I0_vec = I0[-gen_int_rev.size :]

        # Sampling inf feedback strength and adjusting the shape
        inf_feedback_strength = self.infection_feedback_strength(
            **kwargs,
        )

        inf_feedback_strength = jnp.atleast_1d(inf_feedback_strength)

        inf_feedback_strength, _ = au.pad_edges_to_match(
            x=inf_feedback_strength,
            y=Rt,
        )

        # Sampling inf feedback and adjusting the shape
        inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)
        inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)

        # Generating the infections with feedback
        all_infections, Rt_adj = compute_infections_from_rt_with_feedback(
            I0=I0_vec,
            Rt_raw=Rt,
            infection_feedback_strength=inf_feedback_strength,
            reversed_generation_interval_pmf=gen_int_rev,
            reversed_infection_feedback_pmf=inf_fb_pmf_rev,
        )

        # Storing adjusted Rt for future use
        numpyro.deterministic("Rt_adjusted", Rt_adj)

        # Preparing the output

        return InfFeedbackSample(
            post_initialization_infections=all_infections,
            rt=Rt_adj,
        )

The core of the class is implemented in the sample() method. Things to highlight from the above code:

  1. Arguments of sample: The InfFeedback class will be used within RtInfectionsRenewalModel to generate latent infections. During the sampling process, InfFeedback() will receive the reproduction number, the initial number of infections, and the generation interval. RandomVariable() calls are expected to include the **kwargs argument, even if unused.

  2. Saving computed quantities: Since Rt_adj is not generated via numpyro.sample(), we use numpyro.deterministic() to record the quantity to a site, allowing us to access it later.

  3. Return type of InfFeedback(): As said before, the sample() method should return a tuple or named tuple. In our case, we return a named tuple InfFeedbackSample with two fields: post_initialization_infections and rt.
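
Two array manipulations in sample() may benefit from a small illustration: flipping the PMFs and padding the feedback strength. The plain-Python sketch below uses made-up numbers. It shows that, with infections stored oldest first, the sum over \(\tau\) in the renewal equation equals an elementwise product with the reversed PMF. The padding step assumes that edge padding repeats the last value until the lengths match, as the name pad_edges_to_match() suggests; this is an assumption, not a statement of the library's exact behavior.

```python
gen_int = [0.25, 0.5, 0.15, 0.1]     # g(1), ..., g(4)
recent = [10.0, 20.0, 30.0, 40.0]    # I(t-4), ..., I(t-1), oldest first

# Direct form of the sum in the renewal equation
direct = sum(recent[-tau] * gen_int[tau - 1] for tau in range(1, 5))

# Reversed form: flip the PMF once, then multiply elementwise and sum
gen_int_rev = gen_int[::-1]
flipped = sum(i * g for i, g in zip(recent, gen_int_rev))

# Assumed edge padding: repeat the last value to match the length of Rt
Rt = [1.2, 1.1, 1.0, 0.9, 0.8]
strength = [0.01]
padded_strength = strength + [strength[-1]] * (len(Rt) - len(strength))
```

Flipping the PMF once up front means each time step only needs a dot product against the most recent infections, which is presumably why the class passes reversed PMFs to the compute function.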

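We can now use InfFeedback in place of InfectionsWithFeedback and build a second model with otherwise identical components:

latent_infections2 = InfFeedback(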
    infection_feedback_strength=feedback_strength,
    infection_feedback_pmf=gen_int,
)

model1 = RtInfectionsRenewalModel(
    gen_int_rv=gen_int,
    I0_rv=I0,
    latent_infections_rv=latent_infections2,
    Rt_process_rv=MyRt(),
    infection_obs_process_rv=None,
)

# Sampling from model 1 (with no observation process for infections)
with numpyro.handlers.seed(rng_seed=223):
    model1_samp = model1.sample(n_datapoints=30)

Both models were sampled with the same seed, so the latent infections from model0 and model1 should match:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(ncols=2)
ax[0].plot(model0_samp.latent_infections)
ax[1].plot(model1_samp.latent_infections)
ax[0].set_xlabel("Time (model 0)")
ax[1].set_xlabel("Time (model 1)")
ax[0].set_ylabel("Infections")
plt.show()

Figure 2: Comparing latent infections from model 0 and model 1