# Source code for pyrenew.latent.infection_functions

# numpydoc ignore=GL08

from __future__ import annotations

import jax
import jax.numpy as jnp
from jax.typing import ArrayLike

from pyrenew.convolve import new_convolve_scanner, new_double_convolve_scanner
from pyrenew.transformation import ExpTransform, IdentityTransform


def compute_infections_from_rt(
    I0: ArrayLike,
    Rt: ArrayLike,
    reversed_generation_interval_pmf: ArrayLike,
) -> jnp.ndarray:
    r"""
    Simulate incident infections from a renewal process driven by a
    time-varying reproduction number :math:`\mathcal{R}(t)`.

    Parameters
    ----------
    I0 : ArrayLike
        Initial infections; expected to have the same length as
        ``reversed_generation_interval_pmf``.
    Rt : ArrayLike
        Timeseries of :math:`\mathcal{R}(t)` values. The scan takes one
        step per entry along the leading axis.
    reversed_generation_interval_pmf : ArrayLike
        Discrete generation-interval probability mass vector, stored in
        reverse: the last entry corresponds to an infection 1 time unit in
        the past, the second-to-last to 2 time units in the past, and so on.

    Returns
    -------
    jnp.ndarray
        The timeseries of infections, stacked over the scan steps (one
        entry per element of ``Rt`` along its leading axis).
    """
    step = new_convolve_scanner(
        reversed_generation_interval_pmf,
        IdentityTransform(),
    )
    # jax.lax.scan returns (final_carry, stacked_outputs); only the stacked
    # per-step outputs are needed, so the final carry is discarded.
    _, infections = jax.lax.scan(step, I0, Rt)
    return infections
def logistic_susceptibility_adjustment(
    I_raw_t: float,
    frac_susceptible: float,
    n_population: float,
) -> float:
    """
    Apply the logistic susceptibility adjustment to a potential new
    incidence ``I_raw_t`` proposed in equation 6 of
    `Bhatt et al 2023 <https://doi.org/10.1093/jrsssa/qnad030>`_.

    Computes
    ``n_population * frac_susceptible * (1 - exp(-I_raw_t / n_population))``.
    The computation uses only elementwise ``jax.numpy`` operations, so array
    inputs are also accepted.

    Parameters
    ----------
    I_raw_t : float
        The "unadjusted" incidence at time t, i.e. the incidence
        given an infinite number of available susceptible individuals.
    frac_susceptible : float
        fraction of remaining susceptible individuals in the population
    n_population : float
        Total size of the population.

    Returns
    -------
    float
        The adjusted value of :math:`I(t)` (returned as a JAX array).

    Notes
    -----
    ``1 - exp(-x)`` is evaluated as ``-expm1(-x)``. The two are
    mathematically identical, but the naive form cancels catastrophically
    when ``x = I_raw_t / n_population`` is small. For example, in float32,
    ``exp(-1e-8) == 1.0``, so 1 raw infection in a population of 1e8
    would otherwise be adjusted to exactly 0.
    """
    # -expm1(-x) == 1 - exp(-x), but stays accurate for small x.
    approx_frac_infected = -jnp.expm1(-I_raw_t / n_population)
    return n_population * frac_susceptible * approx_frac_infected
def compute_infections_from_rt_with_feedback(
    I0: ArrayLike,
    Rt_raw: ArrayLike,
    infection_feedback_strength: ArrayLike,
    reversed_generation_interval_pmf: ArrayLike,
    reversed_infection_feedback_pmf: ArrayLike,
) -> tuple:
    r"""
    Generate infections according to
    a renewal process with infection
    feedback (generalizing `Asher 2018
    <https://doi.org/10.1016/j.epidem.2017.02.009>`_).

    Parameters
    ----------
    I0 : ArrayLike
        Array of initial infections of the
        same length as the generation interval
        pmf vector.
    Rt_raw : ArrayLike
        Timeseries of raw :math:`\mathcal{R}(t)` values not
        adjusted by infection feedback
    infection_feedback_strength : ArrayLike
        Strength of the infection feedback.
        Either a scalar (constant feedback
        strength in time) or a vector representing
        the infection feedback strength at a
        given point in time. A scalar is expanded
        to one value per time step of ``Rt_raw``.
    reversed_generation_interval_pmf : ArrayLike
        discrete probability mass vector
        representing the generation interval
        of the infection process, where the final
        entry represents an infection 1 time unit in the
        past, the second-to-last entry represents
        an infection two time units in the past, etc.
    reversed_infection_feedback_pmf : ArrayLike
        discrete probability mass vector
        representing the infection feedback
        process, where the final entry represents
        the relative contribution to infection
        feedback from infections that occurred
        1 time unit in the past, the second-to-last
        entry represents the contribution from infections
        that occurred 2 time units in the past, etc.

    Returns
    -------
    tuple
        A tuple ``(infections, Rt_adjusted)``,
        where ``Rt_adjusted`` is the infection-feedback-adjusted
        timeseries of the reproduction number :math:`\mathcal{R}(t)`
        and ``infections`` is the incident infection timeseries.

    Notes
    -----
    This function implements the following renewal process:

    .. math::

        \begin{aligned}
        I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\
        \mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(\gamma(t)\ \sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right)
        \end{aligned}

    where :math:`\mathcal{R}(t)` is the reproductive number,
    :math:`\gamma(t)` is the infection feedback strength,
    :math:`g(\tau)` is the generation interval pmf,
    :math:`T_g` is the max-length of the
    generation interval, :math:`\mathcal{R}^u(t)` is the raw reproduction
    number, :math:`f(\tau)` is the infection feedback pmf, and :math:`T_f`
    is the max-length of the infection feedback pmf.

    Note that negative :math:`\gamma(t)` implies
    that recent incident infections reduce :math:`\mathcal{R}(t)`
    below its raw value in the absence of feedback, while
    positive :math:`\gamma` implies that recent incident infections
    *increase* :math:`\mathcal{R}(t)` above its raw value, and
    :math:`\gamma(t)=0` implies no feedback.

    In general, negative :math:`\gamma` is the more common modeling
    choice, as it can be used to model susceptible depletion,
    reductions in contact rate due to awareness of high incidence,
    et cetera.
    """
    # jax.lax.scan requires every element of ``xs`` to have a leading
    # (time) axis. A 0-d feedback strength is documented as supported
    # (constant in time), so repeat it once per step of Rt_raw.
    # Vector strengths are passed through unchanged.
    if jnp.ndim(infection_feedback_strength) == 0:
        infection_feedback_strength = jnp.broadcast_to(
            infection_feedback_strength, jnp.shape(Rt_raw)[:1]
        )

    # The exp transform is paired with the feedback pmf and the identity
    # transform with the generation interval pmf, matching the two
    # equations in the Notes section.
    feedback_scanner = new_double_convolve_scanner(
        arrays_to_convolve=(
            reversed_infection_feedback_pmf,
            reversed_generation_interval_pmf,
        ),
        transforms=(ExpTransform(), IdentityTransform()),
    )

    # The final scan carry is not needed; keep only the stacked outputs.
    _, (infections, R_adjustment) = jax.lax.scan(
        f=feedback_scanner,
        init=I0,
        xs=(infection_feedback_strength, Rt_raw),
    )

    Rt_adjusted = R_adjustment * Rt_raw

    return infections, Rt_adjusted