"""
Helper functions for analytical
and numerical calculations.
"""
from __future__ import annotations
import jax.numpy as jnp
from jax.lax import broadcast_shapes, scan
from jax.typing import ArrayLike
from pyrenew.distutil import validate_discrete_dist_vector
def get_leslie_matrix(
R: float, generation_interval_pmf: ArrayLike
) -> ArrayLike:
"""
Create the Leslie matrix
corresponding to a basic
renewal process with the
given :math:`\\mathcal{R}`
value and discrete
generation interval pmf
vector.
Parameters
----------
R : float
The reproduction number of the renewal process
generation_interval_pmf : ArrayLike
The discrete generation interval probability
mass vector of the renewal process
Returns
-------
ArrayLike
The Leslie matrix for the
renewal process, as a jax array.
"""
validate_discrete_dist_vector(generation_interval_pmf)
gen_int_len = generation_interval_pmf.size
aging_matrix = jnp.hstack(
[
jnp.identity(gen_int_len - 1),
jnp.zeros(gen_int_len - 1)[..., jnp.newaxis],
]
)
return jnp.vstack([R * generation_interval_pmf, aging_matrix])
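
# Illustrative usage sketch (hypothetical R value and generation-interval
# PMF, not taken from pyrenew's documentation): the first row of the
# returned matrix is R * generation_interval_pmf, and each row below it
# shifts one "age class" of infections forward by a timestep.
#
#     gen_int = jnp.array([0.4, 0.3, 0.2, 0.1])
#     leslie = get_leslie_matrix(2.0, gen_int)
#     # leslie is, up to floating point,
#     # [[0.8, 0.6, 0.4, 0.2],
#     #  [1.0, 0.0, 0.0, 0.0],
#     #  [0.0, 1.0, 0.0, 0.0],
#     #  [0.0, 0.0, 1.0, 0.0]]
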
def get_asymptotic_growth_rate_and_age_dist(
R: float, generation_interval_pmf: ArrayLike
) -> tuple[float, ArrayLike]:
"""
Get the asymptotic per-timestep growth
rate of the renewal process (the dominant
eigenvalue of its Leslie matrix) and the
associated stable age distribution
(a normalized eigenvector associated to
that eigenvalue).
Parameters
----------
R : float
The reproduction number of the renewal process
generation_interval_pmf : ArrayLike
The discrete generation interval probability
mass vector of the renewal process
Returns
-------
tuple[float, ArrayLike]
A tuple consisting of the asymptotic growth rate of
the process, as a jax float, and the stable age distribution
of the process, as a jax array probability vector of the
same shape as the generation interval probability vector.
Raises
------
ValueError
If an age distribution vector or an asymptotic
growth rate with non-zero imaginary part is produced.
"""
L = get_leslie_matrix(R, generation_interval_pmf)
eigenvals, eigenvecs = jnp.linalg.eig(L)
d = jnp.argmax(jnp.abs(eigenvals)) # index of dominant eigenvalue
d_vec, d_val = eigenvecs[:, d], eigenvals[d]
d_vec_real, d_val_real = jnp.real(d_vec), jnp.real(d_val)
if not jnp.all(d_vec_real == d_vec):
raise ValueError(
"get_asymptotic_growth_rate_and_age_dist() "
"produced an age distribution vector with "
"non-zero imaginary part. "
"Check your generation interval distribution "
"vector and R value"
)
if d_val_real != d_val:
raise ValueError(
"get_asymptotic_growth_rate_and_age_dist() "
"produced an asymptotic growth rate with "
"non-zero imaginary part. "
"Check your generation interval distribution "
"vector and R value"
)
d_vec_norm = d_vec_real / jnp.sum(d_vec_real)
return d_val_real, d_vec_norm
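
# Usage sketch (same hypothetical inputs as above): the returned growth
# rate is the dominant eigenvalue of the Leslie matrix and the age
# distribution is the matching eigenvector, normalized to sum to 1.
#
#     r, age_dist = get_asymptotic_growth_rate_and_age_dist(
#         2.0, jnp.array([0.4, 0.3, 0.2, 0.1])
#     )
#     # r > 1 here because R > 1; age_dist has the same length (4) as
#     # the generation interval PMF and sums to 1.
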
def get_stable_age_distribution(
R: float, generation_interval_pmf: ArrayLike
) -> ArrayLike:
"""
Get the stable age distribution for a
renewal process with a given value of
:math:`\\mathcal{R}` and a given discrete generation
interval probability mass vector.
This function computes that stable age
distribution by finding and then normalizing
an eigenvector associated to the dominant
eigenvalue of the renewal process's
Leslie matrix.
Parameters
----------
R : float
The reproduction number of the renewal process
generation_interval_pmf : ArrayLike
The discrete generation interval probability
mass vector of the renewal process
Returns
-------
ArrayLike
The stable age distribution for the
process, as a jax array probability vector of
the same shape as the generation interval
probability vector.
"""
return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[
1
]
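
# Usage sketch (hypothetical inputs): equivalent to taking the second
# element of the tuple returned by get_asymptotic_growth_rate_and_age_dist().
#
#     age_dist = get_stable_age_distribution(
#         2.0, jnp.array([0.4, 0.3, 0.2, 0.1])
#     )
#     # age_dist is a length-4 probability vector.
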
def get_asymptotic_growth_rate(
R: float, generation_interval_pmf: ArrayLike
) -> float:
"""
Get the asymptotic per-timestep growth rate
for a renewal process with a given value of
:math:`\\mathcal{R}` and a given discrete
generation interval probability mass vector.
This function computes that growth rate by
finding the dominant eigenvalue of the
renewal process's Leslie matrix.
Parameters
----------
R : float
The reproduction number of the renewal process
generation_interval_pmf : ArrayLike
The discrete generation interval probability
mass vector of the renewal process
Returns
-------
float
The asymptotic growth rate of the renewal process,
as a jax float.
"""
return get_asymptotic_growth_rate_and_age_dist(R, generation_interval_pmf)[
0
]
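
# Usage sketch (hypothetical inputs): equivalent to taking the first
# element of the tuple returned by get_asymptotic_growth_rate_and_age_dist().
# For this Leslie matrix, the dominant eigenvalue r should satisfy the
# discrete Euler-Lotka relation sum_k R * g_k * r**(-k) = 1, where g_k
# is the generation interval PMF.
#
#     r = get_asymptotic_growth_rate(2.0, jnp.array([0.4, 0.3, 0.2, 0.1]))
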
def integrate_discrete(
init_diff_vals: ArrayLike, highest_order_diff_vals: ArrayLike
) -> ArrayLike:
"""
Integrate (de-difference) a differenced process,
recovering the process values
:math:`X(t=0), X(t=1), ..., X(t)`
from the :math:`n^{th}`-order differences and a set of
initial process / difference values
:math:`X(t=0), X^1(t=1), X^2(t=2), ..., X^{(n-1)}(t=n-1)`,
where :math:`X^k(t)` is the value of the :math:`k^{th}`
difference at index :math:`t` of the process.
The result is a sequence of length equal to the length
of the provided `highest_order_diff_vals` vector plus
the order of the process.
Parameters
----------
init_diff_vals : ArrayLike
Values of
:math:`X(t=0), X^1(t=1), X^2(t=2) ... X^{(n-1)}(t=n-1)`.
highest_order_diff_vals : ArrayLike
Array of differences at the highest order of
differencing, i.e. the order of the overall process,
starting with :math:`X^{n}(t=n)`.
Returns
-------
ArrayLike
The integrated (de-differenced) sequence of values,
of length n_diffs + order, where n_diffs is the
number of entries in `highest_order_diff_vals` and
order is the order of the process.
"""
inits_by_order = jnp.atleast_1d(init_diff_vals)
highest_diffs = jnp.atleast_1d(highest_order_diff_vals)
order = inits_by_order.shape[0]
n_diffs = highest_diffs.shape[0]
try:
batch_shape = broadcast_shapes(
highest_diffs.shape[1:], inits_by_order.shape[1:]
)
except Exception as e:
raise ValueError(
"Non-time dimensions "
"(i.e. dimensions after the first) "
"for highest_order_diff_vals and init_diff_vals "
"must be broadcastable together. "
"Got highest_order_diff_vals of shape "
f"{highest_diffs.shape} and "
"init_diff_vals of shape "
f"{inits_by_order.shape}"
) from e
highest_diffs = jnp.broadcast_to(highest_diffs, (n_diffs,) + batch_shape)
inits_by_order = jnp.broadcast_to(inits_by_order, (order,) + batch_shape)
highest_diffs = jnp.concatenate(
[jnp.zeros_like(inits_by_order), highest_diffs],
axis=0,
)
scan_arrays = (
jnp.arange(start=order - 1, stop=-1, step=-1),
jnp.flip(inits_by_order, axis=0),
)
integrated, _ = scan(
f=_integrate_one_step, init=highest_diffs, xs=scan_arrays
)
return integrated
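
# Usage sketch (hypothetical values): de-difference a second-order
# process from its initial value X(0) = 0, its initial first difference
# X^1(1) = 2, and constant second differences of 1.
#
#     integrate_discrete(
#         init_diff_vals=jnp.array([0.0, 2.0]),
#         highest_order_diff_vals=jnp.array([1.0, 1.0, 1.0]),
#     )
#     # -> [0., 2., 5., 9., 14.] (length n_diffs + order = 3 + 2 = 5);
#     # the first differences of the result are 2, 3, 4, 5 and its
#     # second differences are 1, 1, 1.
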
def _integrate_one_step(
current_diffs: ArrayLike,
next_order_and_init: tuple[int, ArrayLike],
) -> tuple[ArrayLike, None]:
"""
Perform one step of integration
(de-differencing) for integrate_discrete().
Helper function passed to :func:`jax.lax.scan`.
Parameters
----------
current_diffs : ArrayLike
Array of differences at the current
de-differencing order.
next_order_and_init : tuple
Tuple with two entries.
First entry: the next order of de-differencing
(the current order minus 1), as an integer.
Second entry: the initial value for the next
order of de-differencing, as an ArrayLike
of appropriate shape.
Returns
-------
tuple[ArrayLike, None]
A tuple whose first entry contains the
values at the next order of de-differencing
and whose second entry is None.
"""
next_order, next_init = next_order_and_init
next_diffs = jnp.cumsum(
current_diffs.at[next_order, ...].set(next_init), axis=0
)
return next_diffs, None
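
# One scan step from the integrate_discrete() sketch above (hypothetical
# values): with the zero-padded second differences as the carry, and
# order 1 with initial first difference 2.0 as the scanned-over inputs,
# the step writes 2.0 into index 1 and takes a cumulative sum,
# recovering the first differences.
#
#     step_out, _ = _integrate_one_step(
#         jnp.array([0.0, 0.0, 1.0, 1.0, 1.0]), (1, 2.0)
#     )
#     # step_out == [0., 2., 3., 4., 5.]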