Source code for pyrenew.math
# -*- coding: utf-8 -*-
"""
Helper functions for doing analytical
and/or numerical calculations about
a given renewal process.
"""
from __future__ import annotations
import jax.numpy as jnp
from jax.typing import ArrayLike
from pyrenew.distutil import validate_discrete_dist_vector
[docs]
def get_leslie_matrix(
    R: float, generation_interval_pmf: ArrayLike
) -> ArrayLike:
    """
    Build the Leslie matrix of a basic renewal
    process from its R value and its discrete
    generation interval pmf vector.

    The first row of the result is the generation
    interval pmf scaled by R. The rows below it form
    an "aging" block: an identity matrix with one
    extra column of zeros appended on the right.

    Parameters
    ----------
    R : float
        The reproduction number of the renewal process
    generation_interval_pmf: ArrayLike
        The discrete generation interval probability
        mass vector of the renewal process

    Returns
    -------
    ArrayLike
        The Leslie matrix for the
        renewal process, as a jax array.
    """
    # The pmf is passed through the package validator first; anything
    # it raises propagates to the caller. Its return value is not used.
    validate_discrete_dist_vector(generation_interval_pmf)

    n_ages = generation_interval_pmf.size

    # Rectangular identity with (n_ages - 1) rows and n_ages columns:
    # ones on the main diagonal and an all-zero final column, i.e. the
    # same array as stacking identity(n_ages - 1) next to a zero column.
    shift_rows = jnp.eye(n_ages - 1, n_ages)

    reproduction_row = R * generation_interval_pmf
    return jnp.vstack([reproduction_row, shift_rows])
[docs]
def get_asymptotic_growth_rate_and_age_dist(
    R: float, generation_interval_pmf: ArrayLike
) -> tuple[float, ArrayLike]:
    """
    Get the asymptotic per-timestep growth
    rate of the renewal process (the dominant
    eigenvalue of its Leslie matrix) and the
    associated stable age distribution
    (a normalized eigenvector associated to
    that eigenvalue).

    Parameters
    ----------
    R : float
        The reproduction number of the renewal process
    generation_interval_pmf: ArrayLike
        The discrete generation interval probability
        mass vector of the renewal process

    Returns
    -------
    tuple[float, ArrayLike]
        A tuple consisting of the asymptotic growth rate of
        the process, as jax float, and the stable age distribution
        of the process, as a jax array probability vector of the
        same shape as the generation interval probability vector.

    Raises
    ------
    ValueError
        If an age distribution vector with non-zero imaginary
        part is produced, or if an asymptotic growth rate with
        non-zero imaginary part is produced.
    """
    L = get_leslie_matrix(R, generation_interval_pmf)
    eigenvals, eigenvecs = jnp.linalg.eig(L)

    # The dominant eigenvalue is the one of largest modulus; the
    # matching eigenvector is the corresponding column of eigenvecs.
    d = jnp.argmax(jnp.abs(eigenvals))
    d_vec, d_val = eigenvecs[:, d], eigenvals[d]

    # eig returns complex output; keep the real parts and verify below
    # that nothing was lost by discarding the imaginary parts.
    d_vec_real, d_val_real = jnp.real(d_vec), jnp.real(d_val)

    # Single array reduction rather than the Python builtin all(),
    # which would iterate over the array one element at a time.
    if not jnp.all(d_vec_real == d_vec):
        raise ValueError(
            "get_asymptotic_growth_rate_and_age_dist() "
            "produced an age distribution vector with "
            "non-zero imaginary part. "
            "Check your generation interval distribution "
            "vector and R value"
        )
    if d_val_real != d_val:
        raise ValueError(
            "get_asymptotic_growth_rate_and_age_dist() "
            "produced an asymptotic growth rate with "
            "non-zero imaginary part. "
            "Check your generation interval distribution "
            "vector and R value"
        )

    # Dividing by the sum rescales the eigenvector to sum to one;
    # an overall sign flip of the eigenvector cancels in the ratio.
    d_vec_norm = d_vec_real / jnp.sum(d_vec_real)
    return d_val_real, d_vec_norm
[docs]
def get_stable_age_distribution(
    R: float, generation_interval_pmf: ArrayLike
) -> ArrayLike:
    """
    Get the stable age distribution of a renewal
    process with the given R value and discrete
    generation interval probability mass vector.

    The computation is delegated to
    get_asymptotic_growth_rate_and_age_dist(), which
    finds and normalizes an eigenvector associated to
    the dominant eigenvalue of the renewal process's
    Leslie matrix; only the age distribution part of
    its result is returned here.

    Parameters
    ----------
    R : float
        The reproduction number of the renewal process
    generation_interval_pmf: ArrayLike
        The discrete generation interval probability
        mass vector of the renewal process

    Returns
    -------
    ArrayLike
        The stable age distribution for the
        process, as a jax array probability vector of
        the same shape as the generation interval
        probability vector.
    """
    _, stable_age_dist = get_asymptotic_growth_rate_and_age_dist(
        R, generation_interval_pmf
    )
    return stable_age_dist
[docs]
def get_asymptotic_growth_rate(
    R: float, generation_interval_pmf: ArrayLike
) -> float:
    """
    Get the asymptotic per-timestep growth rate of
    a renewal process with the given R value and
    discrete generation interval probability mass
    vector.

    The computation is delegated to
    get_asymptotic_growth_rate_and_age_dist(), which
    obtains the growth rate by finding the dominant
    eigenvalue of the renewal process's Leslie matrix;
    only the growth rate part of its result is
    returned here.

    Parameters
    ----------
    R : float
        The reproduction number of the renewal process
    generation_interval_pmf: ArrayLike
        The discrete generation interval probability
        mass vector of the renewal process

    Returns
    -------
    float
        The asymptotic growth rate of the renewal process,
        as a jax float.
    """
    growth_rate, _ = get_asymptotic_growth_rate_and_age_dist(
        R, generation_interval_pmf
    )
    return growth_rate