# numpydoc ignore=GL08
import jax.numpy as jnp
from jax.typing import ArrayLike
import pyrenew.arrayutils as au
from pyrenew.metaclass import RandomVariable
from pyrenew.process import ARProcess, DifferencedProcess
class RtPeriodicDiffARProcess(RandomVariable):
    r"""
    Periodic Rt with autoregressive first differences.

    Notes
    -----
    This class samples a periodic reproduction number
    :math:`\mathcal{R}(t)` by placing an AR(1) process on the first
    differences of :math:`\log[\mathcal{R}(t)]` between consecutive
    periods. One value is sampled per period and then expanded to
    ``duration`` timepoints with
    :func:`pyrenew.arrayutils.repeat_until_n`. Formally:

    .. math::

        \log[\mathcal{R}^\mathrm{u}(t_3)] \sim \mathrm{Normal}\left(
            \log[\mathcal{R}^\mathrm{u}(t_2)]
            + \beta \left(\log[\mathcal{R}^\mathrm{u}(t_2)]
            - \log[\mathcal{R}^\mathrm{u}(t_1)]\right),
            \sigma_r \right)

    where :math:`\mathcal{R}^\mathrm{u}(t)` is the periodic reproduction
    number at time :math:`t`, :math:`\beta` is the autoregressive
    parameter, and :math:`\sigma_r` is the standard deviation of the
    noise.
    """

    def __init__(
        self,
        name: str,
        offset: int,
        period_size: int,
        log_rt_rv: RandomVariable,
        autoreg_rv: RandomVariable,
        periodic_diff_sd_rv: RandomVariable,
        ar_process_suffix: str = "_first_diff_ar_process_noise",
    ) -> None:
        """
        Default constructor for RtPeriodicDiffARProcess class.

        Parameters
        ----------
        name : str
            Name of the site.
        offset : int
            Relative point at which data starts, must be between 0 and
            period_size - 1.
        period_size : int
            Number of timepoints in each period. One value of Rt is
            sampled per period. Must be at least 1.
        log_rt_rv : RandomVariable
            Log Rt prior for the first two periods. After squeezing,
            its sample must contain at least two values.
        autoreg_rv : RandomVariable
            Autoregressive parameter.
        periodic_diff_sd_rv : RandomVariable
            Standard deviation of the noise.
        ar_process_suffix : str
            Suffix to append to the :class:`RandomVariable`'s ``name``
            when naming the :class:`RandomVariable` that represents
            the underlying AR process noise.
            Default "_first_diff_ar_process_noise".

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If ``period_size`` is less than 1, or if ``offset`` is not
            between 0 and ``period_size - 1``.
        """
        self.validate(
            log_rt_rv=log_rt_rv,
            autoreg_rv=autoreg_rv,
            periodic_diff_sd_rv=periodic_diff_sd_rv,
        )
        # Enforce the documented ranges up front. Otherwise a bad value
        # only surfaces later as a division error or a mis-aligned
        # output window.
        if period_size < 1:
            raise ValueError(
                f"period_size must be a positive integer. Got {period_size}"
            )
        if not 0 <= offset < period_size:
            raise ValueError(
                "offset must be between 0 and period_size - 1 "
                f"(period_size={period_size}). Got {offset}"
            )

        self.name = name
        self.period_size = period_size
        self.offset = offset
        self.log_rt_rv = log_rt_rv
        self.autoreg_rv = autoreg_rv
        self.periodic_diff_sd_rv = periodic_diff_sd_rv
        self.ar_process_suffix = ar_process_suffix

        # An AR process on the first differences: integrating once
        # (differencing_order=1) yields the log Rt series itself.
        self.ar_diff = DifferencedProcess(
            fundamental_process=ARProcess(),
            differencing_order=1,
        )

        return None

    @staticmethod
    def validate(
        log_rt_rv: object,
        autoreg_rv: object,
        periodic_diff_sd_rv: object,
    ) -> None:
        """
        Validate the input parameters.

        Parameters
        ----------
        log_rt_rv : object
            Log Rt prior for the first two periods.
        autoreg_rv : object
            Autoregressive parameter.
        periodic_diff_sd_rv : object
            Standard deviation of the noise.

        Returns
        -------
        None

        Raises
        ------
        AssertionError
            If any argument is not a :class:`RandomVariable`.
        """
        assert isinstance(log_rt_rv, RandomVariable)
        assert isinstance(autoreg_rv, RandomVariable)
        assert isinstance(periodic_diff_sd_rv, RandomVariable)

        return None

    def sample(
        self,
        duration: int,
        **kwargs,
    ) -> ArrayLike:
        """
        Samples the periodic :math:`\\mathcal{R}(t)`
        with autoregressive first differences.

        Parameters
        ----------
        duration : int
            Duration of the sequence.
        **kwargs : dict, optional
            Additional keyword arguments passed through to
            internal :meth:`sample` calls, should there be any.

        Returns
        -------
        ArrayLike
            Sampled :math:`\\mathcal{R}(t)` values.
        """
        # Initial samples: log Rt for the first two periods, the AR
        # coefficient, and the noise standard deviation.
        log_rt_rv = self.log_rt_rv(**kwargs).squeeze()
        b = self.autoreg_rv(**kwargs).squeeze()
        s_r = self.periodic_diff_sd_rv(**kwargs).squeeze()

        # How many periods to sample? The returned window starts
        # ``offset`` timepoints into the first period. The sampled
        # periods must therefore span ``duration + offset`` timepoints
        # (ceiling division) for every returned timepoint to fall inside
        # its own sampled period. Counting ``duration`` alone comes up
        # one period short whenever offset > 0.
        n_periods = (
            duration + self.offset + self.period_size - 1
        ) // self.period_size

        # Running the process: the differenced series starts at
        # log_rt_rv[0], and the AR process on the differences starts at
        # log_rt_rv[1] - log_rt_rv[0].
        log_rt = self.ar_diff(
            noise_name=f"{self.name}{self.ar_process_suffix}",
            n=n_periods,
            init_vals=jnp.array(log_rt_rv[0]),
            autoreg=b,
            noise_sd=s_r,
            fundamental_process_init_vals=jnp.array(
                log_rt_rv[1] - log_rt_rv[0]
            ),
        )

        return au.repeat_until_n(
            data=jnp.exp(log_rt),
            n_timepoints=duration,
            offset=self.offset,
            period_size=self.period_size,
        )
class RtWeeklyDiffARProcess(RtPeriodicDiffARProcess):
    """
    Weekly Rt with autoregressive first differences.

    A :class:`RtPeriodicDiffARProcess` whose ``period_size`` is fixed
    to 7.
    """

    def __init__(
        self,
        name: str,
        offset: int,
        log_rt_rv: RandomVariable,
        autoreg_rv: RandomVariable,
        periodic_diff_sd_rv: RandomVariable,
        ar_process_suffix: str = "_first_diff_ar_process_noise",
    ) -> None:
        """
        Default constructor for RtWeeklyDiffARProcess class.

        Parameters
        ----------
        name : str
            Name of the site.
        offset : int
            Relative point at which data starts, must be between 0 and 6.
        log_rt_rv : RandomVariable
            Log Rt prior for the first two weeks.
        autoreg_rv : RandomVariable
            Autoregressive parameter.
        periodic_diff_sd_rv : RandomVariable
            Standard deviation of the noise.
        ar_process_suffix : str
            Suffix to append to ``name`` when naming the underlying
            AR process noise site. Forwarded unchanged to
            :class:`RtPeriodicDiffARProcess`.
            Default "_first_diff_ar_process_noise".

        Returns
        -------
        None
        """
        super().__init__(
            name=name,
            offset=offset,
            period_size=7,
            log_rt_rv=log_rt_rv,
            autoreg_rv=autoreg_rv,
            periodic_diff_sd_rv=periodic_diff_sd_rv,
            ar_process_suffix=ar_process_suffix,
        )

        return None