
Latent

AR1

AR1(autoreg: float, innovation_sd: float = 1.0)

Bases: TemporalProcess

AR(1) process.

Each value is autoreg times the previous value plus Gaussian noise, so the process reverts toward zero over time. Applied to log Rt, this keeps Rt from drifting far from its baseline: values that wander away are pulled back over time.

This class wraps pyrenew.process.ARProcess with a simplified, protocol-compliant interface that handles vectorization automatically.
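The recurrence behind AR1 can be sketched in plain NumPy. This is an illustration, not pyrenew's ARProcess. It assumes a zero-mean AR(1), x[t] = autoreg * x[t - 1] + eps[t], which matches the fact that no mean term is passed to ARProcess. The helper name simulate_ar1 is hypothetical.

```python
import numpy as np


def simulate_ar1(
    autoreg: float,
    innovation_sd: float,
    n_timepoints: int,
    x0: float,
    rng: np.random.Generator,
) -> np.ndarray:
    """Simulate x[t] = autoreg * x[t - 1] + eps[t], eps[t] ~ Normal(0, innovation_sd)."""
    x = np.empty(n_timepoints)
    x[0] = x0
    for t in range(1, n_timepoints):
        x[t] = autoreg * x[t - 1] + rng.normal(0.0, innovation_sd)
    return x


rng = np.random.default_rng(0)

# Without noise the pull toward zero is easy to see: each step halves the value.
print(simulate_ar1(0.5, 0.0, 4, x0=8.0, rng=rng))  # [8. 4. 2. 1.]

# With noise, a trajectory started far from zero drifts back toward zero
# and then fluctuates around it.
traj = simulate_ar1(0.8, 0.1, 200, x0=5.0, rng=rng)
```

A larger autoreg (closer to 1) makes the pull weaker and the trajectory smoother. A larger innovation_sd makes it more volatile.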

Parameters:

Name Type Description Default
autoreg float

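Autoregressive coefficient. For stationarity, |autoreg| < 1; this is not enforced (use priors to constrain if needed). Note that sample() draws the initial state with the stationary SD innovation_sd / sqrt(1 - autoreg**2), which is infinite or NaN when |autoreg| >= 1.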

required
innovation_sd float

Standard deviation of noise at each time step. Larger values produce more volatile trajectories; smaller values produce smoother ones.

1.0

Initialize AR(1) process.

Parameters:

Name Type Description Default
autoreg float

Autoregressive coefficient. For stationarity, |autoreg| < 1, but this is not enforced (use priors to constrain if needed).

required
innovation_sd float

Standard deviation of innovations

1.0

Raises:

Type Description
ValueError

If innovation_sd <= 0

Source code in pyrenew/latent/temporal_processes.py
def __init__(self, autoreg: float, innovation_sd: float = 1.0):
    """
    Initialize AR(1) process.

    Parameters
    ----------
    autoreg : float
        Autoregressive coefficient. For stationarity, |autoreg| < 1,
        but this is not enforced (use priors to constrain if needed).
    innovation_sd : float, default 1.0
        Standard deviation of innovations

    Raises
    ------
    ValueError
        If innovation_sd <= 0
    """
    if innovation_sd <= 0:
        raise ValueError(f"innovation_sd must be positive, got {innovation_sd}")
    self.autoreg = autoreg
    self.innovation_sd = innovation_sd
    self.ar_process = ARProcess(name="ar1")

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/latent/temporal_processes.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"AR1(autoreg={self.autoreg}, innovation_sd={self.innovation_sd})"

sample

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "ar1",
) -> ArrayLike

Sample AR(1) trajectory or trajectories.

Parameters:

Name Type Description Default
n_timepoints int

Number of time points to generate

required
initial_value float or ArrayLike

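Mean of the Normal prior on the initial state(s). Each initial state is sampled from Normal(initial_value, innovation_sd / sqrt(1 - autoreg**2)). Defaults to 0.0.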

None
n_processes int

Number of parallel processes.

1
name_prefix str

Prefix for numpyro sample sites

"ar1"

Returns:

Type Description
ArrayLike

Trajectories of shape (n_timepoints, n_processes)

Source code in pyrenew/latent/temporal_processes.py
def sample(
    self,
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "ar1",
) -> ArrayLike:
    """
    Sample AR(1) trajectory or trajectories.

    Parameters
    ----------
    n_timepoints : int
        Number of time points to generate
    initial_value : float or ArrayLike, optional
        Mean of the Normal prior on the initial state(s); each initial
        state is sampled with the stationary SD. Defaults to 0.0.
    n_processes : int, default 1
        Number of parallel processes.
    name_prefix : str, default "ar1"
        Prefix for numpyro sample sites

    Returns
    -------
    ArrayLike
        Trajectories of shape (n_timepoints, n_processes)
    """
    if initial_value is None:
        initial_value = jnp.zeros(n_processes)
    elif jnp.isscalar(initial_value):
        initial_value = jnp.full(n_processes, initial_value)

    # Marginal SD of a stationary AR(1); finite only when |autoreg| < 1
    stationary_sd = self.innovation_sd / jnp.sqrt(1 - self.autoreg**2)

    with numpyro.plate(f"{name_prefix}_init_plate", n_processes):
        init_states = numpyro.sample(
            f"{name_prefix}_init",
            dist.Normal(initial_value, stationary_sd),
        )

    trajectories = self.ar_process(
        n=n_timepoints,
        init_vals=init_states[jnp.newaxis, :],
        autoreg=jnp.full((1, n_processes), self.autoreg),
        noise_sd=self.innovation_sd,
        noise_name=f"{name_prefix}_noise",
    )

    return trajectories
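The stationary_sd line in sample() draws the initial state from the AR(1)'s stationary distribution, so the trajectory has no warm-up transient. The NumPy sketch below illustrates why this works. It assumes a zero-mean AR(1) and an initial-state mean of 0, which is the default initial_value. The helper names are hypothetical.

```python
import numpy as np


def stationary_sd(autoreg: float, innovation_sd: float) -> float:
    """Marginal SD of a stationary zero-mean AR(1); only defined for |autoreg| < 1."""
    if not -1.0 < autoreg < 1.0:
        raise ValueError("stationary SD requires |autoreg| < 1")
    return innovation_sd / np.sqrt(1.0 - autoreg**2)


def simulate_from_stationary_start(autoreg, innovation_sd, n_timepoints, n_reps, rng):
    # Draw the initial states from the stationary distribution (mean 0 here)
    x = np.empty((n_timepoints, n_reps))
    x[0] = rng.normal(0.0, stationary_sd(autoreg, innovation_sd), size=n_reps)
    for t in range(1, n_timepoints):
        x[t] = autoreg * x[t - 1] + rng.normal(0.0, innovation_sd, size=n_reps)
    return x


rng = np.random.default_rng(1)
x = simulate_from_stationary_start(0.9, 0.2, n_timepoints=50, n_reps=20_000, rng=rng)
target = stationary_sd(0.9, 0.2)             # about 0.459
sd_first, sd_last = x[0].std(), x[-1].std()  # both close to target
```

The marginal SD is the same at the first and last time points, so every time point, including t = 0, has the same prior scale.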

BaseLatentInfectionProcess

BaseLatentInfectionProcess(
    *, name: str, gen_int_rv: RandomVariable, n_initialization_points: int
)

Bases: RandomVariable

Base class for latent infection processes with subpopulation structure.

Provides common functionality for hierarchical and partitioned infection models:

- Population fraction validation and parsing (at sample time)
- Standard output structure via LatentSample

All subclasses return infections as a LatentSample named tuple with fields: (aggregate, all_subpops). Observation processes are responsible for selecting which subpopulations they observe via indexing.

The constructor specifies model structure (generation interval, priors, temporal processes). Population structure (subpop_fractions) is provided at sample time, allowing a single model to be fit to multiple jurisdictions.
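Here is a minimal sketch of the output contract. It assumes per-capita infections and the population-weighted aggregate that HierarchicalInfections computes. LatentSampleSketch is a hypothetical stand-in, not pyrenew's LatentSample class, and the numbers are illustrative.

```python
from typing import NamedTuple

import numpy as np


class LatentSampleSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented LatentSample fields."""

    aggregate: np.ndarray    # shape (n_total_days,)
    all_subpops: np.ndarray  # shape (n_total_days, n_subpops)


n_total_days = 10
fractions = np.array([0.5, 0.3, 0.2])  # subpop_fractions, sum to 1
per_capita = np.tile([0.01, 0.02, 0.04], (n_total_days, 1))

latent = LatentSampleSketch(
    # Population-weighted sum across subpopulations
    aggregate=(per_capita * fractions).sum(axis=1),
    all_subpops=per_capita,
)

# An observation process that only sees subpopulations 0 and 2 indexes them
observed = latent.all_subpops[:, [0, 2]]
```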

Parameters:

Name Type Description Default
gen_int_rv RandomVariable

Generation interval PMF

required
n_initialization_points int

Number of initialization days before day 0. Must be at least len(gen_int_rv()) to provide enough history for the renewal equation convolution.

required
Notes

Population structure (subpop_fractions) is passed to the sample() method, not the constructor. This allows a single model instance to be fit to multiple datasets with different jurisdiction structures.

When using PyrenewBuilder (recommended), n_initialization_points is computed automatically from all observation processes. When constructing latent processes directly, you must specify n_initialization_points explicitly.
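The lookback rule can be sketched as follows. check_n_initialization_points mirrors the check in __init__. choose_n_initialization_points is a hypothetical guess at how a builder might combine lookbacks by taking the maximum; it is not PyrenewBuilder's actual code. The 7-day observation lookback is a made-up value.

```python
import numpy as np


def required_lookback(gen_int_pmf) -> int:
    # Mirrors get_required_lookback(): the length of the generation-interval PMF
    return len(gen_int_pmf)


def choose_n_initialization_points(gen_int_pmf, other_lookbacks=()) -> int:
    """Hypothetical builder logic: use the longest lookback any component needs."""
    return max([required_lookback(gen_int_pmf), *other_lookbacks])


def check_n_initialization_points(n_initialization_points: int, gen_int_pmf) -> None:
    # Same rule as BaseLatentInfectionProcess.__init__
    gen_int_length = required_lookback(gen_int_pmf)
    if n_initialization_points < gen_int_length:
        raise ValueError(
            f"n_initialization_points must be at least the generation "
            f"interval length ({gen_int_length}), got {n_initialization_points}"
        )


gen_int = np.array([0.2, 0.5, 0.3])
# e.g. an observation process that needs 7 days of history (illustrative)
n_init = choose_n_initialization_points(gen_int, other_lookbacks=[7])
check_n_initialization_points(n_init, gen_int)  # passes
```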

Initialize base latent infection process.

Parameters:

Name Type Description Default
name str

A name for this random variable.

required
gen_int_rv RandomVariable

Generation interval PMF

required
n_initialization_points int

Number of initialization days before day 0. Must be at least len(gen_int_rv()) to provide enough history for the renewal equation convolution.

required

Raises:

Type Description
ValueError

If gen_int_rv is None or n_initialization_points is less than len(gen_int_rv()).

Source code in pyrenew/latent/base.py
def __init__(
    self,
    *,
    name: str,
    gen_int_rv: RandomVariable,
    n_initialization_points: int,
) -> None:
    """
    Initialize base latent infection process.

    Parameters
    ----------
    name : str
        A name for this random variable.
    gen_int_rv : RandomVariable
        Generation interval PMF
    n_initialization_points : int
        Number of initialization days before day 0. Must be at least
        ``len(gen_int_rv())`` to provide enough history for the renewal
        equation convolution.

    Raises
    ------
    ValueError
        If gen_int_rv is None or n_initialization_points is insufficient.
    """
    super().__init__(name=name)
    if gen_int_rv is None:
        raise ValueError("gen_int_rv is required")
    self.gen_int_rv = gen_int_rv

    gen_int_length = len(self.gen_int_rv())
    if n_initialization_points < gen_int_length:
        raise ValueError(
            f"n_initialization_points must be at least the generation "
            f"interval length ({gen_int_length}), got "
            f"{n_initialization_points}"
        )
    self.n_initialization_points = n_initialization_points

get_required_lookback

get_required_lookback() -> int

Return the generation interval length for builder pattern support.

This method is used by PyrenewBuilder to compute n_initialization_points from all model components. Returns the generation interval PMF length.

Returns:

Type Description
int

Length of generation interval PMF

Source code in pyrenew/latent/base.py
def get_required_lookback(self) -> int:
    """
    Return the generation interval length for builder pattern support.

    This method is used by PyrenewBuilder to compute n_initialization_points
    from all model components. Returns the generation interval PMF length.

    Returns
    -------
    int
        Length of generation interval PMF
    """
    return len(self.gen_int_rv())

sample abstractmethod

sample(
    n_days_post_init: int, *, subpop_fractions: ArrayLike = None, **kwargs
) -> LatentSample

Sample latent infections for all subpopulations.

Parameters:

Name Type Description Default
n_days_post_init int

Number of days to simulate after initialization period

required
subpop_fractions ArrayLike

Population fractions for all subpopulations. Shape: (n_subpops,). Must sum to 1.0.

None
**kwargs

Additional parameters required by specific implementations

{}

Returns:

Type Description
LatentSample

Named tuple with fields:

- aggregate: shape (n_total_days,)
- all_subpops: shape (n_total_days, n_subpops)

where n_total_days = n_initialization_points + n_days_post_init.

Source code in pyrenew/latent/base.py
@abstractmethod
def sample(
    self,
    n_days_post_init: int,
    *,
    subpop_fractions: ArrayLike = None,
    **kwargs,
) -> LatentSample:
    """
    Sample latent infections for all subpopulations.

    Parameters
    ----------
    n_days_post_init : int
        Number of days to simulate after initialization period
    subpop_fractions : ArrayLike
        Population fractions for all subpopulations.
        Shape: (n_subpops,). Must sum to 1.0.
    **kwargs
        Additional parameters required by specific implementations

    Returns
    -------
    LatentSample
        Named tuple with fields:
        - aggregate: shape (n_total_days,)
        - all_subpops: shape (n_total_days, n_subpops)

        where n_total_days = n_initialization_points + n_days_post_init
    """
    pass  # pragma: no cover

validate abstractmethod

validate() -> None

Validate latent process parameters.

Subclasses must implement this method to validate all parameters specific to their implementation (e.g., temporal process parameters, I0 parameters).

Common validation (n_initialization_points, gen_int_rv) is performed in __init__. Population structure validation is performed at sample time.

Raises:

Type Description
ValueError

If any parameters fail validation

Source code in pyrenew/latent/base.py
@abstractmethod
def validate(self) -> None:
    """
    Validate latent process parameters.

    Subclasses must implement this method to validate all parameters specific
    to their implementation (e.g., temporal process parameters, I0 parameters).

    Common validation (n_initialization_points, gen_int_rv) is performed in
    __init__. Population structure validation is performed at sample time.

    Raises
    ------
    ValueError
        If any parameters fail validation
    """
    pass  # pragma: no cover

DifferencedAR1

DifferencedAR1(autoreg: float, innovation_sd: float = 1.0)

Bases: TemporalProcess

AR(1) process on first differences.

Each change in value is autoreg times the previous change plus noise, so the rate of change reverts toward zero while the level itself does not. Unlike AR1, this allows Rt to trend persistently upward or downward while its growth rate stabilizes.

This class wraps pyrenew.process.DifferencedProcess with pyrenew.process.ARProcess as the fundamental process, providing a simplified, protocol-compliant interface.
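Below is a NumPy sketch of the differenced recurrence. It assumes the level is the running sum of AR(1) changes. simulate_diff_ar1 is hypothetical, and pyrenew's DifferencedProcess may index its initial values differently. With the noise switched off, the change decays geometrically, but the level keeps the accumulated shift instead of returning to its start. That is the behavior that distinguishes this process from AR1.

```python
import numpy as np


def simulate_diff_ar1(autoreg, innovation_sd, n_timepoints, x0, d0, rng):
    """Level x[t] = x[t - 1] + d[t], where d[t] = autoreg * d[t - 1] + eps[t]."""
    x = np.empty(n_timepoints)
    x[0] = x0
    d = d0  # initial rate of change
    for t in range(1, n_timepoints):
        d = autoreg * d + rng.normal(0.0, innovation_sd)
        x[t] = x[t - 1] + d
    return x


rng = np.random.default_rng(0)
# An initial change of 0.1 with autoreg 0.5 and no noise: the changes are
# 0.05, 0.025, ... and the level settles near 0.1 instead of returning to 0.
x = simulate_diff_ar1(0.5, 0.0, 30, x0=0.0, d0=0.1, rng=rng)
```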

Parameters:

Name Type Description Default
autoreg float

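Autoregressive coefficient for differences. For stationarity, |autoreg| < 1; this is not enforced (use priors to constrain if needed). Note that sample() draws the initial rate of change with SD innovation_sd / sqrt(1 - autoreg**2), which is infinite or NaN when |autoreg| >= 1.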

required
innovation_sd float

Standard deviation of noise added to changes. Larger values produce more erratic growth rates; smaller values produce smoother trends.

1.0

Initialize differenced AR(1) process.

Parameters:

Name Type Description Default
autoreg float

Autoregressive coefficient for differences. For stationarity, |autoreg| < 1, but this is not enforced (use priors to constrain if needed).

required
innovation_sd float

Standard deviation of innovations

1.0

Raises:

Type Description
ValueError

If innovation_sd <= 0

Source code in pyrenew/latent/temporal_processes.py
def __init__(self, autoreg: float, innovation_sd: float = 1.0):
    """
    Initialize differenced AR(1) process.

    Parameters
    ----------
    autoreg : float
        Autoregressive coefficient for differences. For stationarity,
        |autoreg| < 1, but this is not enforced (use priors to constrain
        if needed).
    innovation_sd : float, default 1.0
        Standard deviation of innovations

    Raises
    ------
    ValueError
        If innovation_sd <= 0
    """
    if innovation_sd <= 0:
        raise ValueError(f"innovation_sd must be positive, got {innovation_sd}")
    self.autoreg = autoreg
    self.innovation_sd = innovation_sd
    self.process = DifferencedProcess(
        name="diff_ar1",
        fundamental_process=ARProcess(name="diff_ar1_fundamental"),
        differencing_order=1,
    )

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/latent/temporal_processes.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"DifferencedAR1(autoreg={self.autoreg}, innovation_sd={self.innovation_sd})"

sample

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "diff_ar1",
) -> ArrayLike

Sample differenced AR(1) trajectory or trajectories.

Parameters:

Name Type Description Default
n_timepoints int

Number of time points to generate

required
initial_value float or ArrayLike

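Starting level(s) of the trajectory, used as given rather than sampled. Defaults to 0.0. The initial rate(s) of change are sampled from Normal(0, innovation_sd / sqrt(1 - autoreg**2)).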

None
n_processes int

Number of parallel processes.

1
name_prefix str

Prefix for numpyro sample sites

"diff_ar1"

Returns:

Type Description
ArrayLike

Trajectories of shape (n_timepoints, n_processes)

Source code in pyrenew/latent/temporal_processes.py
def sample(
    self,
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "diff_ar1",
) -> ArrayLike:
    """
    Sample differenced AR(1) trajectory or trajectories.

    Parameters
    ----------
    n_timepoints : int
        Number of time points to generate
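    initial_value : float or ArrayLike, optional
        Starting level(s) of the trajectory, passed through unsampled.
        Defaults to 0.0. The initial rate(s) of change are sampled from
        Normal(0, stationary_sd).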
    n_processes : int, default 1
        Number of parallel processes.
    name_prefix : str, default "diff_ar1"
        Prefix for numpyro sample sites

    Returns
    -------
    ArrayLike
        Trajectories of shape (n_timepoints, n_processes)
    """
    if initial_value is None:
        initial_value = jnp.zeros(n_processes)
    elif jnp.isscalar(initial_value):
        initial_value = jnp.full(n_processes, initial_value)

    # Stationary SD of the AR(1) changes; finite only when |autoreg| < 1
    stationary_sd = self.innovation_sd / jnp.sqrt(1 - self.autoreg**2)

    with numpyro.plate(f"{name_prefix}_init_rate_plate", n_processes):
        init_rates = numpyro.sample(
            f"{name_prefix}_init_rate",
            dist.Normal(0, stationary_sd),
        )

    trajectories = self.process(
        n=n_timepoints,
        init_vals=initial_value[jnp.newaxis, :],
        autoreg=jnp.full((1, n_processes), self.autoreg),
        noise_sd=self.innovation_sd,
        fundamental_process_init_vals=init_rates[jnp.newaxis, :],
        noise_name=f"{name_prefix}_noise",
    )

    return trajectories

GammaGroupSdPrior

GammaGroupSdPrior(
    name: str,
    sd_mean_rv: RandomVariable,
    sd_concentration_rv: RandomVariable,
    sd_min: float = 0.05,
)

Bases: RandomVariable

Gamma prior for group-level standard deviations, bounded away from zero.

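Samples n_groups values from Gamma(concentration, rate), with rate = concentration / sd_mean so that the Gamma mean is sd_mean, then floors each value at sd_min, i.e. max(raw, sd_min).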

Parameters:

Name Type Description Default
name str

Unique name for the sampled parameter in numpyro.

required
sd_mean_rv RandomVariable

RandomVariable returning the mean of the Gamma distribution.

required
sd_concentration_rv RandomVariable

RandomVariable returning the concentration (shape) parameter of Gamma.

required
sd_min float

Minimum SD value (lower bound).

0.05

Initialize gamma group SD prior.

Parameters:

Name Type Description Default
name str

Unique name for the sampled parameter.

required
sd_mean_rv RandomVariable

RandomVariable returning the mean of the Gamma distribution.

required
sd_concentration_rv RandomVariable

RandomVariable returning the concentration parameter.

required
sd_min float

Minimum SD value (lower bound).

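0.05

Raises:

Type Description
TypeError

If sd_mean_rv or sd_concentration_rv is not a RandomVariable.

ValueError

If sd_min is negative.

Source code in pyrenew/latent/hierarchical_priors.py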
def __init__(
    self,
    name: str,
    sd_mean_rv: RandomVariable,
    sd_concentration_rv: RandomVariable,
    sd_min: float = 0.05,
) -> None:
    """
    Initialize gamma group SD prior.

    Parameters
    ----------
    name : str
        Unique name for the sampled parameter.
    sd_mean_rv : RandomVariable
        RandomVariable returning the mean of the Gamma distribution.
    sd_concentration_rv : RandomVariable
        RandomVariable returning the concentration parameter.
    sd_min : float, default=0.05
        Minimum SD value (lower bound).
    """
    if not isinstance(sd_mean_rv, RandomVariable):
        raise TypeError(
            f"sd_mean_rv must be a RandomVariable, got {type(sd_mean_rv).__name__}. "
            "Use DeterministicVariable(name, value) to wrap a fixed value."
        )
    if not isinstance(sd_concentration_rv, RandomVariable):
        raise TypeError(
            f"sd_concentration_rv must be a RandomVariable, got {type(sd_concentration_rv).__name__}. "
            "Use DeterministicVariable(name, value) to wrap a fixed value."
        )
    if sd_min < 0:
        raise ValueError(f"sd_min must be non-negative, got {sd_min}")

    super().__init__(name=name)
    self.sd_mean_rv = sd_mean_rv
    self.sd_concentration_rv = sd_concentration_rv
    self.sd_min = sd_min

sample

sample(n_groups: int, **kwargs)

Sample group-level standard deviations.

Parameters:

Name Type Description Default
n_groups int

Number of groups.

required

Returns:

Type Description
ArrayLike

Array of shape (n_groups,) with values >= sd_min.

Source code in pyrenew/latent/hierarchical_priors.py
def sample(self, n_groups: int, **kwargs):
    """
    Sample group-level standard deviations.

    Parameters
    ----------
    n_groups : int
        Number of groups.

    Returns
    -------
    ArrayLike
        Array of shape (n_groups,) with values >= sd_min.
    """
    sd_mean = self.sd_mean_rv()
    concentration = self.sd_concentration_rv()
    rate = concentration / sd_mean

    with numpyro.plate(f"n_{self.name}", n_groups):
        raw_sd = numpyro.sample(
            f"{self.name}_raw",
            dist.Gamma(concentration, rate),
        )

        group_sd = numpyro.deterministic(
            self.name,
            jnp.maximum(raw_sd, self.sd_min),
        )
    return group_sd
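The mean/concentration parameterization and the floor in sample() can be checked with NumPy. In this sketch, sample_group_sd is hypothetical. It converts the rate to NumPy's scale argument and applies the same floor as the jnp.maximum call above.

```python
import numpy as np


def sample_group_sd(sd_mean, concentration, sd_min, n_groups, rng):
    # rate = concentration / mean, so E[raw] = concentration / rate = sd_mean
    rate = concentration / sd_mean
    # NumPy's gamma takes shape and scale (scale = 1 / rate)
    raw_sd = rng.gamma(shape=concentration, scale=1.0 / rate, size=n_groups)
    # Floor at sd_min, like jnp.maximum(raw_sd, self.sd_min)
    return np.maximum(raw_sd, sd_min)


rng = np.random.default_rng(2)
group_sd = sample_group_sd(
    sd_mean=0.3, concentration=4.0, sd_min=0.05, n_groups=100_000, rng=rng
)
```

A higher concentration tightens the draws around sd_mean. The floor only affects the small fraction of draws that fall below sd_min.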

validate

validate()

Validate the random variable (no-op for this class).

Source code in pyrenew/latent/hierarchical_priors.py
def validate(self):
    """Validate the random variable (no-op for this class)."""
    pass

HierarchicalInfections

HierarchicalInfections(
    *,
    gen_int_rv: RandomVariable,
    I0_rv: RandomVariable,
    baseline_rt_process: TemporalProcess,
    subpop_rt_deviation_process: TemporalProcess,
    initial_log_rt_rv: RandomVariable,
    n_initialization_points: int,
    name: str = "latent_infections",
)

Bases: BaseLatentInfectionProcess

Multi-subpopulation renewal model with hierarchical Rt structure.

Each subpopulation has its own renewal equation with Rt deviating from a shared baseline. Suitable when transmission dynamics vary substantially across subpopulations.

Mathematical form:

- Baseline Rt: log R_baseline(t) ~ TemporalProcess
- Subpopulation Rt: log R_k(t) = log R_baseline(t) + delta_k(t)
- Deviations: delta_k(t) ~ TemporalProcess, with the sum-to-zero constraint sum_k delta_k(t) = 0
- Renewal per subpopulation: I_k(t) = R_k(t) * sum_tau I_k(t - tau) * g(tau)
- Aggregate: I_aggregate(t) = sum_k p_k * I_k(t), where p_k are the subpop_fractions
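These equations can be sketched in NumPy for a handful of subpopulations. This is an illustration under stated assumptions, not pyrenew's implementation. The seed infections are given directly (no exponential initialization), rt covers only the simulated days, and gen_int[0] is taken to be the weight on a one-day lag. Both helper names are hypothetical.

```python
import numpy as np


def renewal(I_init, rt, gen_int):
    """Run I[t] = R[t] * sum_{tau=1..L} I[t - tau] * g[tau - 1] forward.

    I_init:  seed infections, length >= L = len(gen_int)
    rt:      reproduction numbers for the days to simulate
    gen_int: PMF with gen_int[0] the weight on a 1-day lag
    """
    infections = list(I_init)
    n_lags = len(gen_int)
    for r in rt:
        recent = np.array(infections[-n_lags:])[::-1]  # lags 1, 2, ..., L
        infections.append(r * np.dot(recent, gen_int))
    return np.array(infections[len(I_init):])


def hierarchical_infections(log_rt_baseline, deviations_raw, I_init, gen_int, fractions):
    # Sum-to-zero across subpopulations at every time point
    deviations = deviations_raw - deviations_raw.mean(axis=1, keepdims=True)
    rt_subpop = np.exp(log_rt_baseline[:, None] + deviations)  # (n_days, n_subpops)
    per_subpop = np.column_stack(
        [renewal(I_init[:, k], rt_subpop[:, k], gen_int) for k in range(rt_subpop.shape[1])]
    )
    aggregate = per_subpop @ fractions  # sum_k p_k * I_k(t)
    return aggregate, per_subpop, rt_subpop


rng = np.random.default_rng(3)
n_days, n_subpops = 20, 3
gen_int = np.array([0.25, 0.5, 0.25])
fractions = np.array([0.5, 0.3, 0.2])
I_init = np.full((3, n_subpops), 0.01)  # per-capita seed infections
log_rt_baseline = np.zeros(n_days)      # baseline Rt = 1
deviations_raw = rng.normal(0.0, 0.1, size=(n_days, n_subpops))

aggregate, per_subpop, rt_subpop = hierarchical_infections(
    log_rt_baseline, deviations_raw, I_init, gen_int, fractions
)
```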

The constructor specifies model structure (priors, temporal processes). Population structure (subpop_fractions) is provided at sample time, allowing a single model to be fit to multiple jurisdictions.

Parameters:

Name Type Description Default
gen_int_rv RandomVariable

Generation interval PMF

required
I0_rv RandomVariable

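Initial infection prevalence (proportion of population). Must return values in the interval (0, 1). It can be a scalar (shared by all subpopulations) or an (n_subpops,) array (one value per subpopulation). During sampling, the full initialization matrix is generated by exponential growth. Infections on initialization day t are I0 * exp(r * t), where r is the growth rate implied by each subpopulation's initial Rt. I0 is therefore the value on the first initialization day.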

required
baseline_rt_process TemporalProcess

Temporal process for baseline Rt dynamics

required
subpop_rt_deviation_process TemporalProcess

Temporal process for subpopulation deviations

required
initial_log_rt_rv RandomVariable

Initial value for log(Rt) at time 0. Can be estimated from data or given a prior distribution.

required
n_initialization_points int

Number of initialization days before day 0. Must be at least len(gen_int_rv()) to provide enough history for the renewal equation convolution. When using PyrenewBuilder, this is computed automatically from all observation processes.

required
Notes

The sum-to-zero constraint on the deviations makes R_baseline(t) the geometric mean of the subpopulation Rt values (unweighted by population fractions), which makes the baseline identifiable.
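Here is a quick NumPy check of the identifiability argument. It assumes the unweighted centering used in sample(), which subtracts the across-subpopulation mean at each time point. center is a hypothetical helper.

```python
import numpy as np


def center(deviations_raw):
    # Subtract the across-subpopulation mean at each time point
    return deviations_raw - deviations_raw.mean(axis=1, keepdims=True)


rng = np.random.default_rng(4)
raw = rng.normal(size=(5, 4))           # (n_timepoints, n_subpops)
log_baseline = rng.normal(size=(5, 1))  # one baseline value per time point

dev = center(raw)
log_rt_subpop = log_baseline + dev

# 1. Centered deviations sum to zero across subpopulations
row_sums = dev.sum(axis=1)
# 2. So the baseline Rt is the geometric mean of the subpopulation Rt values
geo_mean = np.exp(log_rt_subpop.mean(axis=1, keepdims=True))
# 3. A common shift of every raw deviation is removed by centering, so it
#    cannot be confused with a change in the baseline
shifted = center(raw + 0.7)
```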

When using PyrenewBuilder (recommended), n_initialization_points is computed automatically from all observation processes.

Initialize hierarchical infections process.

Parameters:

Name Type Description Default
gen_int_rv RandomVariable

Generation interval PMF

required
I0_rv RandomVariable

Initial infection prevalence (proportion of population)

required
baseline_rt_process TemporalProcess

Temporal process for baseline Rt dynamics

required
subpop_rt_deviation_process TemporalProcess

Temporal process for subpopulation deviations

required
initial_log_rt_rv RandomVariable

Initial value for log(Rt) at time 0.

required
n_initialization_points int

Number of initialization days before day 0.

required
name str

Name prefix for numpyro sample sites. All deterministic quantities are recorded under this scope (e.g., "{name}/rt_baseline"). Default: "latent_infections".

'latent_infections'

Raises:

Type Description
ValueError

If required parameters are missing or invalid

Source code in pyrenew/latent/hierarchical_infections.py
def __init__(
    self,
    *,
    gen_int_rv: RandomVariable,
    I0_rv: RandomVariable,
    baseline_rt_process: TemporalProcess,
    subpop_rt_deviation_process: TemporalProcess,
    initial_log_rt_rv: RandomVariable,
    n_initialization_points: int,
    name: str = "latent_infections",
) -> None:
    """
    Initialize hierarchical infections process.

    Parameters
    ----------
    gen_int_rv : RandomVariable
        Generation interval PMF
    I0_rv : RandomVariable
        Initial infection prevalence (proportion of population)
    baseline_rt_process : TemporalProcess
        Temporal process for baseline Rt dynamics
    subpop_rt_deviation_process : TemporalProcess
        Temporal process for subpopulation deviations
    initial_log_rt_rv : RandomVariable
        Initial value for log(Rt) at time 0.
    n_initialization_points : int
        Number of initialization days before day 0.
    name : str
        Name prefix for numpyro sample sites. All deterministic
        quantities are recorded under this scope (e.g.,
        ``"{name}/rt_baseline"``). Default: ``"latent_infections"``.

    Raises
    ------
    ValueError
        If required parameters are missing or invalid
    """
    super().__init__(
        name=name,
        gen_int_rv=gen_int_rv,
        n_initialization_points=n_initialization_points,
    )

    if I0_rv is None:
        raise ValueError("I0_rv is required")
    self.I0_rv = I0_rv

    # Validate I0 at construction time if it's deterministic
    if isinstance(I0_rv, DeterministicVariable):
        self._validate_I0(I0_rv.value)

    # Validate initial log Rt
    if initial_log_rt_rv is None:
        raise ValueError("initial_log_rt_rv is required")
    self.initial_log_rt_rv = initial_log_rt_rv

    if baseline_rt_process is None:
        raise ValueError("baseline_rt_process is required")
    self.baseline_rt_process = baseline_rt_process

    if subpop_rt_deviation_process is None:
        raise ValueError("subpop_rt_deviation_process is required")
    self.subpop_rt_deviation_process = subpop_rt_deviation_process

sample

sample(
    n_days_post_init: int, *, subpop_fractions: ArrayLike = None, **kwargs
) -> LatentSample

Sample hierarchical infections for all subpopulations.

Generates baseline Rt, subpopulation deviations with sum-to-zero constraint, initial infections, and runs n_subpops independent renewal processes.

Parameters:

Name Type Description Default
n_days_post_init int

Number of days to simulate after initialization period

required
subpop_fractions ArrayLike

Population fractions for all subpopulations. Shape: (n_subpops,). Must sum to 1.0.

None
**kwargs

Additional arguments (unused, for compatibility)

{}

Returns:

Type Description
LatentSample

Named tuple with fields:

- aggregate: shape (n_total_days,)
- all_subpops: shape (n_total_days, n_subpops)

Source code in pyrenew/latent/hierarchical_infections.py
def sample(
    self,
    n_days_post_init: int,
    *,
    subpop_fractions: ArrayLike = None,
    **kwargs,
) -> LatentSample:
    """
    Sample hierarchical infections for all subpopulations.

    Generates baseline Rt, subpopulation deviations with sum-to-zero
    constraint, initial infections, and runs n_subpops independent renewal processes.

    Parameters
    ----------
    n_days_post_init : int
        Number of days to simulate after initialization period
    subpop_fractions : ArrayLike
        Population fractions for all subpopulations. Shape: (n_subpops,).
        Must sum to 1.0.
    **kwargs
        Additional arguments (unused, for compatibility)

    Returns
    -------
    LatentSample
        Named tuple with fields:
        - aggregate: shape (n_total_days,)
        - all_subpops: shape (n_total_days, n_subpops)
    """
    # Parse and validate population structure
    pop = self._parse_and_validate_fractions(
        subpop_fractions=subpop_fractions,
    )

    n_total_days = self.n_initialization_points + n_days_post_init

    initial_log_rt = self.initial_log_rt_rv()

    log_rt_baseline = self.baseline_rt_process.sample(
        n_timepoints=n_total_days,
        initial_value=initial_log_rt,
        name_prefix="log_rt_baseline",
    )

    deviations_raw = self.subpop_rt_deviation_process.sample(
        n_timepoints=n_total_days,
        n_processes=pop.n_subpops,
        initial_value=jnp.zeros(pop.n_subpops),
        name_prefix="subpop_deviations",
    )

    # Sum-to-zero constraint ensures identifiability
    mean_deviation = jnp.mean(deviations_raw, axis=1, keepdims=True)
    deviations = deviations_raw - mean_deviation

    log_rt_subpop = log_rt_baseline + deviations
    rt_subpop = jnp.exp(log_rt_subpop)

    gen_int = self.gen_int_rv()

    I0 = jnp.asarray(self.I0_rv())
    self._validate_I0(I0)

    if I0.ndim == 0:
        I0_subpop = jnp.full(pop.n_subpops, I0)
    else:
        I0_subpop = I0

    initial_r_subpop = jax.vmap(
        partial(r_approx_from_R, g=gen_int, n_newton_steps=4)
    )(rt_subpop[0, :])

    # Vectorized exponential growth initialization for all subpopulations
    # Formula: I0_subpop[k] * exp(initial_r_subpop[k] * t) for t in [0, n_init)
    time_indices = jnp.arange(self.n_initialization_points)
    I0_all = I0_subpop[jnp.newaxis, :] * jnp.exp(
        initial_r_subpop[jnp.newaxis, :] * time_indices[:, jnp.newaxis]
    )

    gen_int_reversed = jnp.flip(gen_int)
    recent_I0_all = I0_all[-gen_int.size :, :]

    # Vectorized renewal equation for all subpopulations via vmap
    post_init_infections_all = jax.vmap(
        lambda I0_col, Rt_col: compute_infections_from_rt(
            I0=I0_col,
            Rt=Rt_col,
            reversed_generation_interval_pmf=gen_int_reversed,
        ),
        in_axes=1,
        out_axes=1,
    )(recent_I0_all, rt_subpop[self.n_initialization_points :, :])

    infections_all = jnp.vstack([I0_all, post_init_infections_all])

    infections_aggregate = jnp.sum(
        infections_all * pop.fractions[jnp.newaxis, :], axis=1
    )

    self._validate_output_shapes(
        infections_aggregate,
        infections_all,
        n_total_days,
        pop,
    )

    # Record key quantities for diagnostics and posterior analysis
    with numpyro.handlers.scope(prefix=self.name):
        numpyro.deterministic("I0_init_all_subpops", I0_all)
        numpyro.deterministic("log_rt_baseline", log_rt_baseline)
        numpyro.deterministic("rt_baseline", jnp.exp(log_rt_baseline))
        numpyro.deterministic("rt_subpop", rt_subpop)
        numpyro.deterministic("subpop_deviations", deviations)
        numpyro.deterministic("infections_aggregate", infections_aggregate)

    return LatentSample(
        aggregate=infections_aggregate,
        all_subpops=infections_all,
    )
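The initialization steps in sample() can be sketched in NumPy: compute a growth rate from the initial Rt, then fill the initialization days with I0 * exp(r * t). growth_rate_from_R is a hypothetical stand-in for r_approx_from_R. It applies Newton's method to the discrete Euler-Lotka equation 1 = R * sum_tau g(tau) * exp(-r * tau), assuming gen_int[0] is the one-day lag. It also uses more Newton steps than the n_newton_steps=4 above.

```python
import numpy as np


def growth_rate_from_R(R, gen_int, n_newton_steps=20):
    """Solve 1 = R * sum_tau g[tau] * exp(-r * tau) for r, starting from r = 0."""
    taus = np.arange(1, len(gen_int) + 1)  # lags 1, ..., L
    r = 0.0
    for _ in range(n_newton_steps):
        weights = gen_int * np.exp(-r * taus)
        f = R * weights.sum() - 1.0
        f_prime = -R * (taus * weights).sum()
        r = r - f / f_prime
    return r


def exponential_init(I0, r, n_initialization_points):
    # I0 on the first initialization day, then growth at rate r
    t = np.arange(n_initialization_points)
    return I0 * np.exp(r * t)


gen_int = np.array([0.2, 0.5, 0.3])
R0 = 1.3
r = growth_rate_from_R(R0, gen_int)
I_init = exponential_init(0.001, r, n_initialization_points=5)

# One renewal step at the same R continues the exponential curve
recent = I_init[-len(gen_int):][::-1]  # lags 1, 2, 3
next_day = R0 * np.dot(recent, gen_int)
```

When r solves that equation and Rt stays at its initial value, the first post-initialization day lands on the same exponential curve, so the trajectory does not jump at day 0.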

validate

validate() -> None

Validate hierarchical infections parameters.

Checks that the generation interval is a valid PMF.

Raises:

Type Description
ValueError

If gen_int_rv does not return a valid discrete distribution

Source code in pyrenew/latent/hierarchical_infections.py
def validate(self) -> None:
    """
    Validate hierarchical infections parameters.

    Checks that the generation interval is a valid PMF.

    Raises
    ------
    ValueError
        If gen_int_rv does not return a valid discrete distribution
    """
    validate_discrete_dist_vector(self.gen_int_rv())
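The exact behavior of validate_discrete_dist_vector is not shown here. A hypothetical check in the same spirit, requiring non-negative entries that sum to 1 within a tolerance, might look like this:

```python
import numpy as np


def check_pmf(pmf, atol=1e-6) -> None:
    """Hypothetical PMF check; not pyrenew's validate_discrete_dist_vector."""
    pmf = np.asarray(pmf, dtype=float)
    if pmf.ndim != 1 or pmf.size == 0:
        raise ValueError("PMF must be a non-empty 1-D array")
    if np.any(pmf < 0):
        raise ValueError("PMF entries must be non-negative")
    if not np.isclose(pmf.sum(), 1.0, atol=atol):
        raise ValueError(f"PMF must sum to 1, got {pmf.sum()}")


check_pmf([0.2, 0.5, 0.3])  # passes silently
```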

HierarchicalNormalPrior

HierarchicalNormalPrior(name: str, sd_rv: RandomVariable)

Bases: RandomVariable

Zero-centered Normal prior for group-level effects.

Samples n_groups values from Normal(0, sd).
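In NumPy terms, sample() draws one zero-centered effect per group, and all groups share a single standard deviation. sample_group_effects is a hypothetical helper. In pyrenew the SD would come from sd_rv rather than a fixed number.

```python
import numpy as np


def sample_group_effects(sd, n_groups, rng):
    # Zero-centered effects sharing one standard deviation
    return rng.normal(loc=0.0, scale=sd, size=n_groups)


rng = np.random.default_rng(5)
effects = sample_group_effects(sd=0.2, n_groups=50_000, rng=rng)
```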

Parameters:

Name Type Description Default
name str

Unique name for the sampled parameter in numpyro.

required
sd_rv RandomVariable

RandomVariable returning the standard deviation.

required

Initialize hierarchical normal prior.

Parameters:

Name Type Description Default
name str

Unique name for the sampled parameter.

required
sd_rv RandomVariable

RandomVariable returning the standard deviation.

required
Source code in pyrenew/latent/hierarchical_priors.py
def __init__(
    self,
    name: str,
    sd_rv: RandomVariable,
) -> None:
    """
    Initialize hierarchical normal prior.

    Parameters
    ----------
    name : str
        Unique name for the sampled parameter.
    sd_rv : RandomVariable
        RandomVariable returning the standard deviation.
    """
    if not isinstance(sd_rv, RandomVariable):
        raise TypeError(
            f"sd_rv must be a RandomVariable, got {type(sd_rv).__name__}. "
            "Use DeterministicVariable(name, value) to wrap a fixed value."
        )

    super().__init__(name=name)
    self.sd_rv = sd_rv

sample

sample(n_groups: int, **kwargs)

Sample group-level effects.

Parameters:

Name Type Description Default
n_groups int

Number of groups.

required

Returns:

Type Description
ArrayLike

Array of shape (n_groups,).

Source code in pyrenew/latent/hierarchical_priors.py
def sample(self, n_groups: int, **kwargs):
    """
    Sample group-level effects.

    Parameters
    ----------
    n_groups : int
        Number of groups.

    Returns
    -------
    ArrayLike
        Array of shape (n_groups,).
    """
    sd = self.sd_rv()

    with numpyro.plate(f"n_{self.name}", n_groups):
        effects = numpyro.sample(
            self.name,
            dist.Normal(0.0, sd),
        )
    return effects

validate

validate()

Validate the random variable (no-op for this class).

Source code in pyrenew/latent/hierarchical_priors.py
def validate(self):
    """Validate the random variable (no-op for this class)."""
    pass

InfectionInitializationMethod

InfectionInitializationMethod(n_timepoints: int)

Method for initializing infections in a renewal process.

Default constructor for pyrenew.latent.infection_initialization_method.InfectionInitializationMethod.

Parameters:

Name Type Description Default
n_timepoints int

the number of time points for which to generate initial infections

required

Returns:

Type Description
None
Source code in pyrenew/latent/infection_initialization_method.py
def __init__(self, n_timepoints: int):
    """Default constructor for
    [`pyrenew.latent.infection_initialization_method.InfectionInitializationMethod`][].

    Parameters
    ----------
    n_timepoints
        the number of time points for which to
        generate initial infections

    Returns
    -------
    None
    """
    self.validate(n_timepoints)
    self.n_timepoints = n_timepoints

initialize_infections abstractmethod

initialize_infections(I_pre_init: ArrayLike)

Generate the number of initialized infections at each time point.

Parameters:

Name Type Description Default
I_pre_init ArrayLike

An array representing some number of latent infections to be used with the specified [pyrenew.latent.infection_initialization_method.InfectionInitializationMethod][].

required

Returns:

Type Description
ArrayLike

An array of length n_timepoints with the number of initialized infections at each time point.

Source code in pyrenew/latent/infection_initialization_method.py
@abstractmethod
def initialize_infections(self, I_pre_init: ArrayLike):
    """Generate the number of initialized infections at each time point.

    Parameters
    ----------
    I_pre_init
        An array representing some number of latent infections to be used with the specified [`pyrenew.latent.infection_initialization_method.InfectionInitializationMethod`][].

    Returns
    -------
    ArrayLike
        An array of length ``n_timepoints`` with the number of initialized infections at each time point.
    """

validate staticmethod

validate(n_timepoints: int) -> None

Validate inputs to the pyrenew.latent.infection_initialization_method.InfectionInitializationMethod constructor.

Parameters:

Name Type Description Default
n_timepoints int

the number of time points to generate initial infections for

required

Returns:

Type Description
None
Source code in pyrenew/latent/infection_initialization_method.py
@staticmethod
def validate(n_timepoints: int) -> None:
    """
    Validate inputs to the
    [`pyrenew.latent.infection_initialization_method.InfectionInitializationMethod`][]
    constructor.

    Parameters
    ----------
    n_timepoints
        the number of time points to generate initial infections for

    Returns
    -------
    None
    """
    if not isinstance(n_timepoints, int):
        raise TypeError(
            f"n_timepoints must be an integer. Got {type(n_timepoints)}"
        )
    if n_timepoints <= 0:
        raise ValueError(f"n_timepoints must be positive. Got {n_timepoints}")

InfectionInitializationProcess

InfectionInitializationProcess(
    name,
    I_pre_init_rv: RandomVariable,
    infection_init_method: InfectionInitializationMethod,
)

Bases: RandomVariable

Generate an initial infection history

Default class constructor for InfectionInitializationProcess

Parameters:

Name Type Description Default
name

A name to assign to the RandomVariable.

required
I_pre_init_rv RandomVariable

A RandomVariable representing the number of infections that occur at some time before the renewal process begins. Each infection_init_method uses this random variable in different ways.

required
infection_init_method InfectionInitializationMethod

A pyrenew.latent.infection_initialization_method.InfectionInitializationMethod that generates the initial infections for the renewal process.

required

Returns:

Type Description
None
Source code in pyrenew/latent/infection_initialization_process.py
def __init__(
    self,
    name,
    I_pre_init_rv: RandomVariable,
    infection_init_method: InfectionInitializationMethod,
) -> None:
    """Default class constructor for InfectionInitializationProcess

    Parameters
    ----------
    name
        A name to assign to the RandomVariable.
    I_pre_init_rv
        A RandomVariable representing the number of infections that occur at some time before the renewal process begins. Each `infection_init_method` uses this random variable in different ways.
    infection_init_method
        A [`pyrenew.latent.infection_initialization_method.InfectionInitializationMethod`][] that generates the initial infections for the renewal process.

    Returns
    -------
    None
    """
    InfectionInitializationProcess.validate(I_pre_init_rv, infection_init_method)

    super().__init__(name=name)
    self.I_pre_init_rv = I_pre_init_rv
    self.infection_init_method = infection_init_method

sample

sample() -> ArrayLike

Sample the Infection Initialization Process.

Returns:

Type Description
ArrayLike

the number of initialized infections at each time point.

Source code in pyrenew/latent/infection_initialization_process.py
def sample(self) -> ArrayLike:
    """Sample the Infection Initialization Process.

    Returns
    -------
    ArrayLike
        the number of initialized infections at each time point.
    """

    I_pre_init = self.I_pre_init_rv()

    infection_initialization = self.infection_init_method(
        I_pre_init,
    )

    return infection_initialization

validate staticmethod

validate(
    I_pre_init_rv: RandomVariable,
    infection_init_method: InfectionInitializationMethod,
) -> None

Validate the input arguments to the InfectionInitializationProcess class constructor

Parameters:

Name Type Description Default
I_pre_init_rv RandomVariable

A random variable representing the number of infections that occur at some time before the renewal process begins.

required
infection_init_method InfectionInitializationMethod

A method to generate the initial infections.

required

Returns:

Type Description
None
Source code in pyrenew/latent/infection_initialization_process.py
@staticmethod
def validate(
    I_pre_init_rv: RandomVariable,
    infection_init_method: InfectionInitializationMethod,
) -> None:
    """Validate the input arguments to the InfectionInitializationProcess class constructor

    Parameters
    ----------
    I_pre_init_rv
        A random variable representing the number of infections that occur at some time before the renewal process begins.
    infection_init_method
        A method to generate the initial infections.

    Returns
    -------
    None
    """
    _assert_type("I_pre_init_rv", I_pre_init_rv, RandomVariable)
    _assert_type(
        "infection_init_method",
        infection_init_method,
        InfectionInitializationMethod,
    )

Infections

Infections(name: str)

Bases: RandomVariable

Latent infections

This class samples infections given \(\mathcal{R}(t)\), initial infections, and generation interval.

Parameters:

Name Type Description Default
name str

A name for this random variable.

required
Notes

The mathematical model is given by:

\[ I(t) = \mathcal{R}(t) \times \sum_{\tau < t} I(\tau) g(t-\tau) \]

where \(I(t)\) is the number of infections at time \(t\), \(\mathcal{R}(t)\) is the reproduction number at time \(t\), and \(g(t-\tau)\) is the generation interval.
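The renewal equation can be sketched in NumPy as a loop over time, truncating the sum to the length of the generation interval. This follows the convention in `sample` (the PMF is flipped and dotted with the most recent initial infections), so the sketch assumes `gen_int[0]` is the probability of a one-step generation interval. The function name is hypothetical, and the library routine `compute_infections_from_rt` may differ in details such as batching over extra dimensions.

```python
import numpy as np


def renewal_infections(I0, Rt, gen_int):
    """Compute I(t) = R(t) * sum_{tau=1}^{T_g} I(t - tau) g(tau) for 1-D Rt.

    Assumed convention: gen_int[k] = g(k + 1), the probability of a
    generation interval of k + 1 time steps.
    """
    I0 = np.asarray(I0, dtype=float)
    gen_int = np.asarray(gen_int, dtype=float)
    if I0.size < gen_int.size:
        raise ValueError("I0 must be at least as long as gen_int")
    # Reversed PMF lines up g(T_g), ..., g(1) with oldest-to-newest history
    rev_gen_int = gen_int[::-1]
    # Only the most recent T_g initial values affect the first new day
    history = list(I0[-gen_int.size:])
    infections = []
    for r in np.asarray(Rt, dtype=float):
        new_inf = r * np.dot(history[-gen_int.size:], rev_gen_int)
        infections.append(new_inf)
        history.append(new_inf)
    return np.array(infections)


print(renewal_infections(I0=[1.0, 3.0], Rt=[2.0], gen_int=[0.25, 0.75]))  # [3.]
```

Here the weighted history is 3 × 0.25 + 1 × 0.75 = 1.5, and multiplying by R = 2 gives 3. With R fixed at 1 and a flat history, infections stay flat, which is a quick sanity check.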

Default constructor.

Parameters:

Name Type Description Default
name str

A name for this random variable.

required
Source code in pyrenew/latent/infections.py
def __init__(self, name: str) -> None:
    """
    Default constructor.

    Parameters
    ----------
    name : str
        A name for this random variable.
    """
    super().__init__(name=name)

sample

sample(
    Rt: ArrayLike, I0: ArrayLike, gen_int: ArrayLike, **kwargs
) -> InfectionsSample

Sample infections given \(\mathcal{R}(t)\), initial infections, and generation interval.

Parameters:

Name Type Description Default
Rt ArrayLike

Reproduction number.

required
I0 ArrayLike

Initial infections vector, at least as long as the generation interval. Only its most recent entries (as many as the generation interval has) are used.

required
gen_int ArrayLike

Generation interval pmf vector.

required
**kwargs

Additional keyword arguments passed through to internal sample calls, should there be any.

{}

Returns:

Type Description
InfectionsSample

A named tuple with a post_initialization_infections field.

Source code in pyrenew/latent/infections.py
def sample(
    self,
    Rt: ArrayLike,
    I0: ArrayLike,
    gen_int: ArrayLike,
    **kwargs,
) -> InfectionsSample:
    r"""
    Sample infections given
    $\mathcal{R}(t)$, initial infections,
    and generation interval.

    Parameters
    ----------
    Rt
        Reproduction number.
    I0
        Initial infections vector
        at least as long as the
        generation interval.
    gen_int
        Generation interval pmf vector.
    **kwargs
        Additional keyword arguments passed through to internal
        sample calls, should there be any.

    Returns
    -------
    InfectionsSample
        A named tuple with a
        `post_initialization_infections` field.
    """
    if I0.shape[0] < gen_int.size:
        raise ValueError(
            "Initial infections vector must be at least as long as "
            "the generation interval. "
            f"Initial infections vector length: {I0.shape[0]}, "
            f"generation interval length: {gen_int.size}."
        )

    if I0.shape[1:] != Rt.shape[1:]:
        raise ValueError(
            "Initial infections and Rt must have the "
            "same batch shapes. "
            f"Got initial infections of batch shape {I0.shape[1:]} "
            f"and Rt of batch shape {Rt.shape[1:]}."
        )

    gen_int_rev = jnp.flip(gen_int)
    recent_I0 = I0[-gen_int_rev.size :]

    post_initialization_infections = inf.compute_infections_from_rt(
        I0=recent_I0,
        Rt=Rt,
        reversed_generation_interval_pmf=gen_int_rev,
    )

    return InfectionsSample(post_initialization_infections)

InfectionsWithFeedback

InfectionsWithFeedback(
    name: str,
    infection_feedback_strength: RandomVariable,
    infection_feedback_pmf: RandomVariable,
)

Bases: RandomVariable

Latent infections

This class computes infections given Rt, initial infections, and a generation interval, adjusting Rt downward according to recent infections (infection feedback).

Parameters:

Name Type Description Default
infection_feedback_strength RandomVariable

Infection feedback strength.

required
infection_feedback_pmf RandomVariable

Infection feedback pmf.

required
Notes

This function implements the following renewal process (reproduced from pyrenew.latent.infection_functions.compute_infections_from_rt_with_feedback):

\[ \begin{aligned} I(t) &= \mathcal{R}(t)\sum_{\tau=1}^{T_g} I(t - \tau)\, g(\tau) \\ \mathcal{R}(t) &= \mathcal{R}^u(t)\exp\left(-\gamma(t) \sum_{\tau=1}^{T_f} I(t - \tau)\, f(\tau)\right) \end{aligned} \]

where \(\mathcal{R}(t)\) is the feedback-adjusted reproduction number, \(\mathcal{R}^u(t)\) is the raw (unadjusted) reproduction number, \(\gamma(t)\) is the infection feedback strength, \(g\) is the generation interval PMF with maximum length \(T_g\), and \(f\) is the infection feedback PMF with maximum length \(T_f\).
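The two equations can be iterated together: at each step, recent infections weighted by \(f\) shrink the raw reproduction number, and the adjusted value drives the usual renewal step. This NumPy sketch assumes a scalar, time-constant \(\gamma\) and the same index convention as the `gen_int` flip in `sample` (`gen_int[k]` is \(g(k+1)\), `feedback_pmf[k]` is \(f(k+1)\)). It is not the library's `compute_infections_from_rt_with_feedback`, whose handling of history length and batching may differ.

```python
import numpy as np


def renewal_with_feedback(I0, Rt_raw, gamma, gen_int, feedback_pmf):
    """Return (infections, Rt_adjusted) for 1-D Rt_raw and scalar gamma."""
    I0 = np.asarray(I0, dtype=float)
    g = np.asarray(gen_int, dtype=float)
    f = np.asarray(feedback_pmf, dtype=float)
    if I0.size < max(g.size, f.size):
        raise ValueError("I0 must be at least as long as both PMFs")
    history = list(I0)
    infections, rt_adj = [], []
    for r_raw in np.asarray(Rt_raw, dtype=float):
        # Recent infections weighted by the feedback PMF
        feedback = np.dot(history[-f.size:], f[::-1])
        # Feedback shrinks the raw reproduction number
        r = r_raw * np.exp(-gamma * feedback)
        # Standard renewal step using the adjusted reproduction number
        new_inf = r * np.dot(history[-g.size:], g[::-1])
        rt_adj.append(r)
        infections.append(new_inf)
        history.append(new_inf)
    return np.array(infections), np.array(rt_adj)


inf_no_fb, rt_no_fb = renewal_with_feedback([5.0, 5.0], [1.5, 1.5], 0.0, [0.5, 0.5], [1.0])
inf_fb, rt_fb = renewal_with_feedback([5.0, 5.0], [1.5, 1.5], 0.1, [0.5, 0.5], [1.0])
```

With \(\gamma = 0\), the adjusted Rt equals the raw Rt. With \(\gamma > 0\), both Rt and infections come out lower.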

Default constructor for InfectionsWithFeedback class.

Parameters:

Name Type Description Default
name str

A name for this random variable.

required
infection_feedback_strength RandomVariable

Infection feedback strength.

required
infection_feedback_pmf RandomVariable

Infection feedback pmf.

required

Returns:

Type Description
None
Source code in pyrenew/latent/infectionswithfeedback.py
def __init__(
    self,
    name: str,
    infection_feedback_strength: RandomVariable,
    infection_feedback_pmf: RandomVariable,
) -> None:
    """
    Default constructor for InfectionsWithFeedback class.

    Parameters
    ----------
    name : str
        A name for this random variable.
    infection_feedback_strength
        Infection feedback strength.
    infection_feedback_pmf
        Infection feedback pmf.

    Returns
    -------
    None
    """

    super().__init__(name=name)
    self.validate(infection_feedback_strength, infection_feedback_pmf)

    self.infection_feedback_strength = infection_feedback_strength
    self.infection_feedback_pmf = infection_feedback_pmf

    return None

sample

sample(
    Rt: ArrayLike, I0: ArrayLike, gen_int: ArrayLike, **kwargs
) -> InfectionsRtFeedbackSample

Samples infections given Rt, initial infections, and generation interval.

Parameters:

Name Type Description Default
Rt ArrayLike

Reproduction number.

required
I0 ArrayLike

Initial infections, as an array at least as long as the generation interval PMF.

required
gen_int ArrayLike

Generation interval PMF.

required
**kwargs

Additional keyword arguments passed through to internal sample calls, should there be any.

{}

Returns:

Type Description
InfectionsRtFeedbackSample

Named tuple with post_initialization_infections and rt (the feedback-adjusted reproduction number) fields.

Source code in pyrenew/latent/infectionswithfeedback.py
def sample(
    self,
    Rt: ArrayLike,
    I0: ArrayLike,
    gen_int: ArrayLike,
    **kwargs,
) -> InfectionsRtFeedbackSample:
    """
    Samples infections given Rt, initial infections, and generation
    interval.

    Parameters
    ----------
    Rt
        Reproduction number.
    I0
        Initial infections, as an array
        at least as long as the generation
        interval PMF.
    gen_int
        Generation interval PMF.
    **kwargs
        Additional keyword arguments passed through to internal
        sample calls, should there be any.

    Returns
    -------
    InfectionsRtFeedbackSample
        Named tuple with ``post_initialization_infections`` and
        ``rt`` (the feedback-adjusted reproduction number) fields.
    """
    if I0.shape[0] < gen_int.size:
        raise ValueError(
            "Initial infections must be at least as long as the "
            f"generation interval. Got initial infections length {I0.shape[0]} "
            f"and generation interval length {gen_int.size}."
        )

    if I0.shape[1:] != Rt.shape[1:]:
        raise ValueError(
            "Initial infections and Rt must have the same batch shapes. "
            f"Got initial infections of batch shape {I0.shape[1:]} "
            f"and Rt of batch shape {Rt.shape[1:]}."
        )

    gen_int_rev = jnp.flip(gen_int)

    I0 = I0[-gen_int_rev.size :]

    # Sampling inf feedback strength
    inf_feedback_strength = jnp.atleast_1d(
        self.infection_feedback_strength(
            **kwargs,
        )
    )

    try:
        inf_feedback_strength = jnp.broadcast_to(inf_feedback_strength, Rt.shape)
    except Exception as e:
        raise ValueError(
            "Could not broadcast inf_feedback_strength "
            f"(shape {inf_feedback_strength.shape}) "
            "to the shape of Rt "
            f"{Rt.shape}."
        ) from e

    # Sampling inf feedback pmf
    inf_feedback_pmf = self.infection_feedback_pmf(**kwargs)

    inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf)

    (
        post_initialization_infections,
        Rt_adj,
    ) = inf.compute_infections_from_rt_with_feedback(
        I0=I0,
        Rt_raw=Rt,
        infection_feedback_strength=inf_feedback_strength,
        reversed_generation_interval_pmf=gen_int_rev,
        reversed_infection_feedback_pmf=inf_fb_pmf_rev,
    )

    return InfectionsRtFeedbackSample(
        post_initialization_infections=post_initialization_infections,
        rt=Rt_adj,
    )

validate staticmethod

validate(inf_feedback_strength: any, inf_feedback_pmf: any) -> None

Validates the input parameters.

Parameters:

Name Type Description Default
inf_feedback_strength any

Infection feedback strength.

required
inf_feedback_pmf any

Infection feedback pmf.

required

Returns:

Type Description
None
Source code in pyrenew/latent/infectionswithfeedback.py
@staticmethod
def validate(
    inf_feedback_strength: any,
    inf_feedback_pmf: any,
) -> None:  # numpydoc ignore=GL08
    """
    Validates the input parameters.

    Parameters
    ----------
    inf_feedback_strength
        Infection feedback strength.
    inf_feedback_pmf
        Infection feedback pmf.

    Returns
    -------
    None
    """
    assert isinstance(inf_feedback_strength, RandomVariable)
    assert isinstance(inf_feedback_pmf, RandomVariable)

    return None

InitializeInfectionsExponentialGrowth

InitializeInfectionsExponentialGrowth(
    n_timepoints: int, rate_rv: RandomVariable, t_pre_init: int | None = None
)

Bases: InfectionInitializationMethod

Generate initial infections according to exponential growth.

Notes

The number of incident infections at time t is given by:

\[ I(t) = I_p \exp \left( r (t - t_p) \right) \]

where \(I_p\) is I_pre_init, \(r\) is the growth rate sampled from rate_rv, and \(t_p\) is t_pre_init. This ensures that \(I(t_p) = I_p\). By default, t_pre_init = n_timepoints - 1, so that I_pre_init is the number of incident infections at the last initialization time point, immediately before the renewal process begins.
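A NumPy sketch of this formula with the default t_pre_init, using a doubling rate of log(2) per time step so the values are easy to check by eye. The function name is hypothetical. The library version draws the rate from `rate_rv` and broadcasts over array-valued I_pre_init.

```python
import numpy as np


def exp_growth_init(I_pre_init, rate, n_timepoints, t_pre_init=None):
    """I(t) = I_pre_init * exp(rate * (t - t_pre_init)) for t = 0, ..., n_timepoints - 1."""
    if t_pre_init is None:
        # Anchor I_pre_init at the last initialization time point
        t_pre_init = n_timepoints - 1
    t = np.arange(n_timepoints)
    return I_pre_init * np.exp(rate * (t - t_pre_init))


init = exp_growth_init(I_pre_init=100.0, rate=np.log(2), n_timepoints=4)
print(init)  # approximately [12.5, 25, 50, 100]
```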

Default constructor for the pyrenew.latent.infection_initialization_method.InitializeInfectionsExponentialGrowth class.

Parameters:

Name Type Description Default
n_timepoints int

the number of time points to generate initial infections for

required
rate_rv RandomVariable

A random variable representing the rate of exponential growth

required
t_pre_init int | None

The time point whose number of infections is described by I_pre_init. Defaults to n_timepoints - 1.

None
Source code in pyrenew/latent/infection_initialization_method.py
def __init__(
    self,
    n_timepoints: int,
    rate_rv: RandomVariable,
    t_pre_init: int | None = None,
):
    """Default constructor for the [`pyrenew.latent.infection_initialization_method.InitializeInfectionsExponentialGrowth`][] class.

    Parameters
    ----------
    n_timepoints
        the number of time points to generate initial infections for
    rate_rv
        A random variable representing the rate of exponential growth
    t_pre_init
        The time point whose number of infections is described by ``I_pre_init``. Defaults to ``n_timepoints - 1``.
    """
    super().__init__(n_timepoints)
    self.rate_rv = rate_rv
    if t_pre_init is None:
        t_pre_init = n_timepoints - 1
    self.t_pre_init = t_pre_init

initialize_infections

initialize_infections(I_pre_init: ArrayLike)

Generate initial infections according to exponential growth.

Parameters:

Name Type Description Default
I_pre_init ArrayLike

An array of size 1 representing the number of infections at time t_pre_init.

required

Returns:

Type Description
ArrayLike

An array of length n_timepoints with the number of initialized infections at each time point.

Source code in pyrenew/latent/infection_initialization_method.py
def initialize_infections(self, I_pre_init: ArrayLike):
    """Generate initial infections according to exponential growth.

    Parameters
    ----------
    I_pre_init
        An array of size 1 representing the number of infections at time ``t_pre_init``.

    Returns
    -------
    ArrayLike
        An array of length ``n_timepoints`` with the number of initialized infections at each time point.
    """
    I_pre_init = jnp.array(I_pre_init)
    rate = jnp.array(self.rate_rv())
    initial_infections = I_pre_init * jnp.exp(
        rate * (jnp.arange(self.n_timepoints)[:, jnp.newaxis] - self.t_pre_init)
    )
    return jnp.squeeze(initial_infections)

InitializeInfectionsFromVec

InitializeInfectionsFromVec(n_timepoints: int)

Bases: InfectionInitializationMethod

Create initial infections from a vector of infections.

Source code in pyrenew/latent/infection_initialization_method.py
def __init__(self, n_timepoints: int):
    """Default constructor for
    [`pyrenew.latent.infection_initialization_method.InfectionInitializationMethod`][].

    Parameters
    ----------
    n_timepoints
        the number of time points for which to
        generate initial infections

    Returns
    -------
    None
    """
    self.validate(n_timepoints)
    self.n_timepoints = n_timepoints

initialize_infections

initialize_infections(I_pre_init: ArrayLike) -> ArrayLike

Create initial infections from a vector of infections.

Parameters:

Name Type Description Default
I_pre_init ArrayLike

An array of length n_timepoints to be used as the initial infections.

required

Returns:

Type Description
ArrayLike

An array of length n_timepoints with the number of initialized infections at each time point.

Source code in pyrenew/latent/infection_initialization_method.py
def initialize_infections(self, I_pre_init: ArrayLike) -> ArrayLike:
    """Create initial infections from a vector of infections.

    Parameters
    ----------
    I_pre_init
        An array with the same length as ``n_timepoints`` to be
        used as the initial infections.

    Returns
    -------
    ArrayLike
        An array of length ``n_timepoints`` with the number of
        initialized infections at each time point.
    """
    I_pre_init = jnp.array(I_pre_init)
    if I_pre_init.size != self.n_timepoints:
        raise ValueError(
            "I_pre_init must have the same size as n_timepoints. "
            f"Got I_pre_init of size {I_pre_init.size} "
            f"and n_timepoints = {self.n_timepoints}."
        )
    return I_pre_init

InitializeInfectionsZeroPad

InitializeInfectionsZeroPad(n_timepoints: int)

Bases: InfectionInitializationMethod

Create an initial infection vector of specified length by padding a shorter vector with an appropriate number of zeros at the beginning of the time series.
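Padding places the supplied values at the end of the window, immediately before the renewal process starts, and fills the earlier time points with zeros. Here is a NumPy sketch mirroring `initialize_infections` (the function name is hypothetical):

```python
import numpy as np


def zero_pad_init(I_pre_init, n_timepoints):
    """Left-pad I_pre_init with zeros so the result has length n_timepoints."""
    I_pre_init = np.atleast_1d(np.asarray(I_pre_init, dtype=float))
    if I_pre_init.size > n_timepoints:
        raise ValueError("I_pre_init must be no longer than n_timepoints")
    # (before, after) padding: all zeros go at the start of the series
    return np.pad(I_pre_init, (n_timepoints - I_pre_init.size, 0))


print(zero_pad_init([5.0, 7.0], n_timepoints=5))  # [0. 0. 0. 5. 7.]
```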

Source code in pyrenew/latent/infection_initialization_method.py
def __init__(self, n_timepoints: int):
    """Default constructor for
    [`pyrenew.latent.infection_initialization_method.InfectionInitializationMethod`][].

    Parameters
    ----------
    n_timepoints
        the number of time points for which to
        generate initial infections

    Returns
    -------
    None
    """
    self.validate(n_timepoints)
    self.n_timepoints = n_timepoints

initialize_infections

initialize_infections(I_pre_init: ArrayLike)

Pad the initial infections with zeros at the beginning of the time series.

Parameters:

Name Type Description Default
I_pre_init ArrayLike

An array with initialized infections to be padded with zeros.

required

Returns:

Type Description
ArrayLike

An array of length n_timepoints with the number of initialized infections at each time point.

Source code in pyrenew/latent/infection_initialization_method.py
def initialize_infections(self, I_pre_init: ArrayLike):
    """Pad the initial infections with zeros at the beginning of the time series.

    Parameters
    ----------
    I_pre_init
        An array with initialized infections to be padded with zeros.

    Returns
    -------
    ArrayLike
        An array of length ``n_timepoints`` with the number of initialized infections at each time point.
    """
    I_pre_init = jnp.atleast_1d(I_pre_init)
    if self.n_timepoints < I_pre_init.size:
        raise ValueError(
            "I_pre_init must be no longer than n_timepoints. "
            f"Got I_pre_init of size {I_pre_init.size} and "
            f"n_timepoints = {self.n_timepoints}."
        )
    return jnp.pad(I_pre_init, (self.n_timepoints - I_pre_init.size, 0))

LatentSample

Bases: NamedTuple

Output from latent infection process sampling.

Attributes:

Name Type Description
aggregate ArrayLike

Total infections aggregated across all subpopulations. Shape: (n_total_days,)

all_subpops ArrayLike

Infections for all subpopulations. Shape: (n_total_days, n_subpops)
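In the sampling code at the top of this section, `aggregate` is formed by weighting each column of `all_subpops` by that subpopulation's population fraction and summing across columns. Here is a NumPy sketch with made-up numbers (the fractions and infection values are illustrative, not library output):

```python
import numpy as np

# Hypothetical values: 3 days, 2 subpopulations
all_subpops = np.array([[10.0, 20.0],
                        [12.0, 18.0],
                        [14.0, 16.0]])  # shape (n_total_days, n_subpops)
fractions = np.array([0.25, 0.75])      # population fractions, summing to 1

# Fraction-weighted sum across subpopulations -> shape (n_total_days,)
aggregate = np.sum(all_subpops * fractions[np.newaxis, :], axis=1)
print(aggregate)  # [17.5 16.5 15.5]
```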

PopulationStructure dataclass

PopulationStructure(fractions: ArrayLike)

Parsed and validated population structure for a jurisdiction.

Attributes:

Name Type Description
fractions ArrayLike

Population fractions for all subpopulations. Shape: (n_subpops,)

n_subpops property

n_subpops: int

Total number of subpopulations.

Returns:

Type Description
int

The number of subpopulations.

RandomWalk

RandomWalk(innovation_sd: float = 1.0)

Bases: TemporalProcess

Random walk process for log(Rt).

Each value equals the previous value plus noise, with no reversion toward a mean. Allows Rt to drift without bound — suitable when you have no prior expectation that Rt will return to a baseline.

This class wraps pyrenew.process.RandomWalk with a simplified, protocol-compliant interface that handles vectorization automatically.

Parameters:

Name Type Description Default
innovation_sd float

Standard deviation of noise at each time step. Larger values produce faster drift; smaller values produce more gradual changes.

1.0
Notes

Unlike AR(1), variance grows over time — the process can wander arbitrarily far from its starting point. For long time horizons, consider AR(1) if you want Rt to stay bounded near a baseline.

For non-centered parameterization (to avoid funnel problems in inference), apply LocScaleReparam(centered=0) to the step sample site ({name_prefix}_step) via numpyro.handlers.reparam.
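The variance contrast described in these Notes can be checked by simulation. With innovation standard deviation σ, a random walk started at 0 has variance tσ² after t steps, while an AR(1) around 0 with coefficient φ (|φ| < 1) settles near σ²/(1 - φ²). The NumPy sketch below simulates many independent paths of each process. It illustrates the two processes and does not call pyrenew; the parameter values are arbitrary.

```python
import numpy as np


def simulate(n_steps, n_paths, sd, autoreg=None, seed=0):
    """Simulate paths starting at 0.

    autoreg=None gives a random walk; otherwise x_t = autoreg * x_{t-1} + eps_t.
    """
    rng = np.random.default_rng(seed)
    eps = rng.normal(0.0, sd, size=(n_steps, n_paths))
    phi = 1.0 if autoreg is None else autoreg
    x = np.zeros((n_steps, n_paths))
    prev = np.zeros(n_paths)
    for t in range(n_steps):
        prev = phi * prev + eps[t]
        x[t] = prev
    return x


sd, phi, n_steps = 0.1, 0.8, 200
rw = simulate(n_steps, 5000, sd)
ar = simulate(n_steps, 5000, sd, autoreg=phi)
print(rw[-1].var())  # close to n_steps * sd**2 = 2.0
print(ar[-1].var())  # close to sd**2 / (1 - phi**2), about 0.028
```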

Initialize random walk process.

Parameters:

Name Type Description Default
innovation_sd float

Standard deviation of innovations

1.0

Raises:

Type Description
ValueError

If innovation_sd <= 0

Source code in pyrenew/latent/temporal_processes.py
def __init__(self, innovation_sd: float = 1.0):
    """
    Initialize random walk process.

    Parameters
    ----------
    innovation_sd : float, default 1.0
        Standard deviation of innovations

    Raises
    ------
    ValueError
        If innovation_sd <= 0
    """
    if innovation_sd <= 0:
        raise ValueError(f"innovation_sd must be positive, got {innovation_sd}")
    self.innovation_sd = innovation_sd

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/latent/temporal_processes.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"RandomWalk(innovation_sd={self.innovation_sd})"

sample

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "rw",
) -> ArrayLike

Sample random walk trajectory or trajectories.

Parameters:

Name Type Description Default
n_timepoints int

Number of time points to generate

required
initial_value float or ArrayLike

Initial value(s). Defaults to 0.0.

None
n_processes int

Number of parallel processes.

1
name_prefix str

Prefix for numpyro sample sites

"rw"

Returns:

Type Description
ArrayLike

Trajectories of shape (n_timepoints, n_processes)
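Here is a NumPy sketch of the vectorized sampling. A scalar `initial_value` is broadcast to `n_processes`, independent Normal(0, `innovation_sd`) steps are drawn for each process, and each column is the running sum. The sketch assumes that the first row equals the initial value and that `n_timepoints` rows are returned in total. Check `pyrenew.process.RandomWalk` for the exact convention the library uses. The function name is hypothetical.

```python
import numpy as np


def random_walk_paths(n_timepoints, innovation_sd=1.0, initial_value=None,
                      n_processes=1, seed=0):
    """Return an array of shape (n_timepoints, n_processes).

    Assumption: row 0 is the initial value, followed by n_timepoints - 1 steps.
    """
    if initial_value is None:
        initial_value = np.zeros(n_processes)
    elif np.isscalar(initial_value):
        # Broadcast a scalar start to every process
        initial_value = np.full(n_processes, initial_value, dtype=float)
    rng = np.random.default_rng(seed)
    steps = rng.normal(0.0, innovation_sd, size=(n_timepoints - 1, n_processes))
    # Prepend a zero row so the cumulative sum starts at the initial value
    increments = np.vstack([np.zeros((1, n_processes)), steps])
    return initial_value[np.newaxis, :] + np.cumsum(increments, axis=0)


paths = random_walk_paths(n_timepoints=30, innovation_sd=0.05,
                          initial_value=np.log(1.2), n_processes=3)
print(paths.shape)  # (30, 3)
```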

Source code in pyrenew/latent/temporal_processes.py
def sample(
    self,
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "rw",
) -> ArrayLike:
    """
    Sample random walk trajectory or trajectories.

    Parameters
    ----------
    n_timepoints : int
        Number of time points to generate
    initial_value : float or ArrayLike, optional
        Initial value(s). Defaults to 0.0.
    n_processes : int, default 1
        Number of parallel processes.
    name_prefix : str, default "rw"
        Prefix for numpyro sample sites

    Returns
    -------
    ArrayLike
        Trajectories of shape (n_timepoints, n_processes)
    """
    if initial_value is None:
        initial_value = jnp.zeros(n_processes)
    elif jnp.isscalar(initial_value):
        initial_value = jnp.full(n_processes, initial_value)

    rw = ProcessRandomWalk(
        name=f"{name_prefix}_random_walk",
        step_rv=DistributionalVariable(
            name=f"{name_prefix}_step",
            distribution=dist.Normal(
                jnp.zeros(n_processes),
                self.innovation_sd,
            ),
        ),
    )

    return rw.sample(
        init_vals=initial_value[jnp.newaxis, :],
        n=n_timepoints,
    )

StudentTGroupModePrior

StudentTGroupModePrior(name: str, sd_rv: RandomVariable, df_rv: RandomVariable)

Bases: RandomVariable

Zero-centered Student-t prior for group-level modes (robust alternative to Normal).

Samples n_groups values from StudentT(df, 0, sd).
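The Student-t is described as robust because, with few degrees of freedom, it places much more mass far from zero. As a result, a group with an unusually large mode is penalized less than under a Normal prior of the same scale. Here is a NumPy sketch comparing tail frequencies (df = 3 and the sample size are arbitrary choices for illustration):

```python
import numpy as np

rng = np.random.default_rng(seed=0)
n, sd, df = 200_000, 1.0, 3.0

normal_draws = rng.normal(0.0, sd, size=n)
student_draws = sd * rng.standard_t(df, size=n)  # StudentT(df, 0, sd)

# Fraction of draws more than 3 scale units from zero
print(np.mean(np.abs(normal_draws) > 3 * sd))   # about 0.003
print(np.mean(np.abs(student_draws) > 3 * sd))  # about 0.06
```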

Parameters:

Name Type Description Default
name str

Unique name for the sampled parameter in numpyro.

required
sd_rv RandomVariable

RandomVariable returning the scale parameter.

required
df_rv RandomVariable

RandomVariable returning the degrees of freedom.

required

Initialize Student-t group mode prior.

Parameters:

Name Type Description Default
name str

Unique name for the sampled parameter.

required
sd_rv RandomVariable

RandomVariable returning the scale parameter.

required
df_rv RandomVariable

RandomVariable returning the degrees of freedom.

required
Source code in pyrenew/latent/hierarchical_priors.py
def __init__(
    self,
    name: str,
    sd_rv: RandomVariable,
    df_rv: RandomVariable,
) -> None:
    """
    Initialize Student-t group mode prior.

    Parameters
    ----------
    name : str
        Unique name for the sampled parameter.
    sd_rv : RandomVariable
        RandomVariable returning the scale parameter.
    df_rv : RandomVariable
        RandomVariable returning the degrees of freedom.
    """
    if not isinstance(sd_rv, RandomVariable):
        raise TypeError(
            f"sd_rv must be a RandomVariable, got {type(sd_rv).__name__}. "
            "Use DeterministicVariable(name, value) to wrap a fixed value."
        )
    if not isinstance(df_rv, RandomVariable):
        raise TypeError(
            f"df_rv must be a RandomVariable, got {type(df_rv).__name__}. "
            "Use DeterministicVariable(name, value) to wrap a fixed value."
        )

    super().__init__(name=name)
    self.sd_rv = sd_rv
    self.df_rv = df_rv

sample

sample(n_groups: int, **kwargs)

Sample group-level modes.

Parameters:

Name Type Description Default
n_groups int

Number of groups.

required

Returns:

Type Description
ArrayLike

Array of shape (n_groups,).

Source code in pyrenew/latent/hierarchical_priors.py
def sample(self, n_groups: int, **kwargs):
    """
    Sample group-level modes.

    Parameters
    ----------
    n_groups : int
        Number of groups.

    Returns
    -------
    ArrayLike
        Array of shape (n_groups,).
    """
    sd = self.sd_rv()
    df = self.df_rv()

    with numpyro.plate(f"n_{self.name}", n_groups):
        effects = numpyro.sample(
            self.name,
            dist.StudentT(df=df, loc=0.0, scale=sd),
        )
    return effects
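To see what StudentTGroupModePrior.sample draws, here is a sketch that produces the same distribution, StudentT(df, 0, sd), using jax.random.t instead of a numpyro sample site. The function name is hypothetical, and sd and df are plain numbers here rather than the RandomVariable objects the class requires.

```python
import jax
import jax.numpy as jnp


def sketch_student_t_modes(key, n_groups, sd, df):
    """Draw n_groups zero-centered Student-t values with scale sd (hypothetical helper)."""
    # jax.random.t draws standard Student-t variates (location 0, scale 1).
    standard = jax.random.t(key, df, shape=(n_groups,))
    # Multiplying by sd gives StudentT(df, loc=0, scale=sd).
    return sd * standard
```

Smaller df gives heavier tails, which is what makes this prior more tolerant of an occasional outlying group than a Normal prior with the same scale.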

validate

validate()

Validate the random variable (no-op for this class).

Source code in pyrenew/latent/hierarchical_priors.py
def validate(self):
    """Validate the random variable (no-op for this class)."""
    pass

TemporalProcess

Bases: Protocol

Protocol for temporal processes generating time-varying parameters.

Used for jurisdiction-level Rt dynamics, subpopulation deviations, or allocation trajectories. All processes return 2D arrays of shape (n_timepoints, n_processes) for consistent handling.

sample

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "temporal",
) -> ArrayLike

Sample temporal trajectory or trajectories.

Parameters:

Name Type Description Default
n_timepoints int

Number of time points to generate

required
initial_value float or ArrayLike

Initial value(s) for the process(es). Scalar (broadcast to all processes) or array of shape (n_processes,). Defaults to 0.0.

None
n_processes int

Number of parallel processes.

1
name_prefix str

Prefix for numpyro sample site names to avoid collisions

"temporal"

Returns:

Type Description
ArrayLike

Trajectories of shape (n_timepoints, n_processes)

Source code in pyrenew/latent/temporal_processes.py
def sample(
    self,
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "temporal",
) -> ArrayLike:
    """
    Sample temporal trajectory or trajectories.

    Parameters
    ----------
    n_timepoints : int
        Number of time points to generate
    initial_value : float or ArrayLike, optional
        Initial value(s) for the process(es).
        Scalar (broadcast to all processes) or array of shape (n_processes,).
        Defaults to 0.0.
    n_processes : int, default 1
        Number of parallel processes.
    name_prefix : str, default "temporal"
        Prefix for numpyro sample site names to avoid collisions

    Returns
    -------
    ArrayLike
        Trajectories of shape (n_timepoints, n_processes)
    """
    ...
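Because TemporalProcess is a typing.Protocol, a static type checker accepts any class whose sample() has this signature and returns an (n_timepoints, n_processes) array; inheritance is not required for that. The toy ConstantProcess below is hypothetical and only illustrates the contract, and TemporalProcessLike is a local stand-in for the real protocol so the sketch runs without pyrenew.

```python
from typing import Protocol

import jax.numpy as jnp


class TemporalProcessLike(Protocol):
    """Local stand-in mirroring the TemporalProcess.sample signature."""

    def sample(
        self,
        n_timepoints: int,
        initial_value=None,
        n_processes: int = 1,
        name_prefix: str = "temporal",
    ) -> jnp.ndarray: ...


class ConstantProcess:
    """Toy process that holds each process at its initial value."""

    def sample(
        self,
        n_timepoints: int,
        initial_value=None,
        n_processes: int = 1,
        name_prefix: str = "constant",  # unused: this toy has no sample sites
    ) -> jnp.ndarray:
        # None -> 0.0; a scalar is broadcast; an array must have shape (n_processes,).
        init = 0.0 if initial_value is None else initial_value
        init = jnp.broadcast_to(jnp.asarray(init, dtype=jnp.float32), (n_processes,))
        # Repeat along a leading time axis to get (n_timepoints, n_processes).
        return jnp.tile(init[jnp.newaxis, :], (n_timepoints, 1))


def sample_any(process: TemporalProcessLike, n_timepoints: int, n_processes: int) -> jnp.ndarray:
    # Callers depend only on the protocol, so processes are interchangeable.
    return process.sample(n_timepoints, initial_value=0.5, n_processes=n_processes)
```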

compute_infections_from_rt

compute_infections_from_rt(
    I0: ArrayLike, Rt: ArrayLike, reversed_generation_interval_pmf: ArrayLike
) -> ndarray

Generate infections according to a renewal process with a time-varying reproduction number \(\mathcal{R}(t)\).

Parameters:

Name Type Description Default
I0 ArrayLike

Array of initial infections of the same length as the generation interval pmf vector.

required
Rt ArrayLike

Timeseries of \(\mathcal{R}(t)\) values

required
reversed_generation_interval_pmf ArrayLike

Discrete probability mass vector for the generation interval, in reverse order: the final entry is the weight for infections one time unit in the past, the second-to-last entry the weight for infections two time units in the past, and so on.

required

Returns:

Type Description
ndarray

The timeseries of infections.

Source code in pyrenew/latent/infection_functions.py
def compute_infections_from_rt(
    I0: ArrayLike,
    Rt: ArrayLike,
    reversed_generation_interval_pmf: ArrayLike,
) -> jnp.ndarray:
    """
    Generate infections according to a
    renewal process with a time-varying
    reproduction number $\\mathcal{R}(t)$

    Parameters
    ----------
    I0
        Array of initial infections of the
        same length as the generation interval
        pmf vector.
    Rt
        Timeseries of $\\mathcal{R}(t)$ values
    reversed_generation_interval_pmf
        discrete probability mass vector
        representing the generation interval
        of the infection process, where the final
        entry represents an infection 1 time unit in the
        past, the second-to-last entry represents
        an infection two time units in the past, etc.

    Returns
    -------
    jnp.ndarray
        The timeseries of infections.
    """
    incidence_func = new_convolve_scanner(
        reversed_generation_interval_pmf, IdentityTransform()
    )

    latest, all_infections = jax.lax.scan(f=incidence_func, init=I0, xs=Rt)

    return all_infections
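The scan in compute_infections_from_rt carries a window of the most recent infections (initialized with I0) and, at each step, is expected to multiply \(\mathcal{R}(t)\) by the dot product of that window with the reversed generation interval pmf. The sketch below spells out that step with jax.lax.scan. It is an illustrative re-implementation, not pyrenew's new_convolve_scanner, and its function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def sketch_infections_from_rt(I0, Rt, reversed_gi_pmf):
    """Renewal sketch: I(t) = R(t) * sum_tau I(t - tau) g(tau)."""
    pmf = jnp.asarray(reversed_gi_pmf, dtype=jnp.float32)

    def step(history, rt):
        # history runs oldest to newest, matching the reversed pmf, so the
        # last pmf entry weights the most recent infection (lag 1).
        new_infections = rt * jnp.dot(pmf, history)
        # Drop the oldest value and append the newest one.
        history = jnp.concatenate([history[1:], new_infections[jnp.newaxis]])
        return history, new_infections

    _, infections = jax.lax.scan(
        step,
        jnp.asarray(I0, dtype=jnp.float32),
        jnp.asarray(Rt, dtype=jnp.float32),
    )
    return infections
```

With reversed pmf [0.25, 0.75] (lag-1 weight 0.75, lag-2 weight 0.25), I0 = [1, 1] and a constant \(\mathcal{R}(t) = 2\), the three steps give 2 · (0.25 + 0.75) = 2, then 2 · (0.25 · 1 + 0.75 · 2) = 3.5, then 2 · (0.25 · 2 + 0.75 · 3.5) = 6.25.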

compute_infections_from_rt_with_feedback

compute_infections_from_rt_with_feedback(
    I0: ArrayLike,
    Rt_raw: ArrayLike,
    infection_feedback_strength: ArrayLike,
    reversed_generation_interval_pmf: ArrayLike,
    reversed_infection_feedback_pmf: ArrayLike,
) -> tuple

Generate infections according to a renewal process with infection feedback (generalizing Asher 2018, https://doi.org/10.1016/j.epidem.2017.02.009).

Parameters:

Name Type Description Default
I0 ArrayLike

Array of initial infections of the same length as the generation interval pmf vector.

required
Rt_raw ArrayLike

Timeseries of raw \(\mathcal{R}(t)\) values not adjusted by infection feedback

required
infection_feedback_strength ArrayLike

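Strength of the infection feedback, either constant in time or a vector giving the feedback strength at each time point. Because jax.lax.scan scans over the leading axis of every input, a constant strength must be passed as an array of the same length as Rt_raw (for example jnp.full_like(Rt_raw, value)) rather than as a bare scalar.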

required
reversed_generation_interval_pmf ArrayLike

Discrete probability mass vector for the generation interval, in reverse order: the final entry is the weight for infections one time unit in the past, the second-to-last entry the weight for infections two time units in the past, and so on.

required
reversed_infection_feedback_pmf ArrayLike

Discrete probability mass vector for the infection feedback process, in reverse order: the final entry is the relative contribution to infection feedback from infections one time unit in the past, the second-to-last entry the contribution from infections two time units in the past, and so on.

required

Returns:

Type Description
tuple

A tuple (infections, Rt_adjusted), where Rt_adjusted is the infection-feedback-adjusted timeseries of the reproduction number \(\mathcal{R}(t)\) and infections is the incident infection timeseries.

Notes

This function implements the following renewal process:

\[ \begin{aligned} I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\ \mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(\gamma(t)\ \sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right) \end{aligned} \]

where \(\mathcal{R}(t)\) is the feedback-adjusted reproduction number, \(\mathcal{R}^u(t)\) is the raw (unadjusted) reproduction number, \(\gamma(t)\) is the infection feedback strength, \(g(\tau)\) is the generation interval pmf with maximum length \(T_g\), and \(f(\tau)\) is the infection feedback pmf with maximum length \(T_f\).

A negative \(\gamma(t)\) means recent incident infections reduce \(\mathcal{R}(t)\) below its raw value \(\mathcal{R}^u(t)\), a positive \(\gamma(t)\) means they increase it above that value, and \(\gamma(t) = 0\) means no feedback.

Negative \(\gamma\) is the more common modeling choice, since it can represent susceptible depletion, reduced contact rates due to awareness of high incidence, and similar effects.

Source code in pyrenew/latent/infection_functions.py
def compute_infections_from_rt_with_feedback(
    I0: ArrayLike,
    Rt_raw: ArrayLike,
    infection_feedback_strength: ArrayLike,
    reversed_generation_interval_pmf: ArrayLike,
    reversed_infection_feedback_pmf: ArrayLike,
) -> tuple:
    r"""
    Generate infections according to
    a renewal process with infection
    feedback (generalizing `Asher 2018
    <https://doi.org/10.1016/j.epidem.2017.02.009>`_).

    Parameters
    ----------
    I0
        Array of initial infections of the
        same length as the generation interval
        pmf vector.
    Rt_raw
        Timeseries of raw $\mathcal{R}(t)$ values not
        adjusted by infection feedback
    infection_feedback_strength
        Strength of the infection feedback.
        Either a scalar (constant feedback
        strength in time) or a vector representing
        the infection feedback strength at a
        given point in time.
    reversed_generation_interval_pmf
        discrete probability mass vector
        representing the generation interval
        of the infection process, where the final
        entry represents an infection 1 time unit in the
        past, the second-to-last entry represents
        an infection two time units in the past, etc.
    reversed_infection_feedback_pmf
        discrete probability mass vector
        representing the infection feedback
        process, where the final entry represents
        the relative contribution to infection
        feedback from infections that occurred
        1 time unit in the past, the second-to-last
        entry represents the contribution from infections
        that occurred 2 time units in the past, etc.

    Returns
    -------
    tuple
        A tuple ``(infections, Rt_adjusted)``,
        where `Rt_adjusted` is the infection-feedback-adjusted
        timeseries of the reproduction number $\mathcal{R}(t)$
        and `infections` is the incident infection timeseries.

    Notes
    -----
    This function implements the following renewal process:

    ```math
    \begin{aligned}
    I(t) & = \mathcal{R}(t)\sum_{\tau=1}^{T_g}I(t - \tau)g(\tau) \\
    \mathcal{R}(t) & = \mathcal{R}^u(t)\exp\left(\gamma(t)\
        \sum_{\tau=1}^{T_f}I(t - \tau)f(\tau)\right)
    \end{aligned}
    ```

    where $\mathcal{R}(t)$ is the reproductive number,
    $\gamma(t)$ is the infection feedback strength,
    $T_g$ is the max-length of the
    generation interval, $\mathcal{R}^u(t)$ is the raw reproduction
    number, $f(t)$ is the infection feedback pmf, and $T_f$
    is the max-length of the infection feedback pmf.

    Note that negative $\gamma(t)$ implies
    that recent incident infections reduce $\mathcal{R}(t)$
    below its raw value in the absence of feedback, while
    positive $\gamma$ implies that recent incident infections
    *increase* $\mathcal{R}(t)$ above its raw value, and
    $\gamma(t)=0$ implies no feedback.

    In general, negative $\gamma$ is the more common modeling
    choice, as it can be used to model susceptible depletion,
    reductions in contact rate due to awareness of high incidence,
    et cetera.
    """
    feedback_scanner = new_double_convolve_scanner(
        arrays_to_convolve=(
            reversed_infection_feedback_pmf,
            reversed_generation_interval_pmf,
        ),
        transforms=(ExpTransform(), IdentityTransform()),
    )
    latest, infs_and_R_adj = jax.lax.scan(
        f=feedback_scanner,
        init=I0,
        xs=(infection_feedback_strength, Rt_raw),
    )

    infections, R_adjustment = infs_and_R_adj
    Rt_adjusted = R_adjustment * Rt_raw
    return infections, Rt_adjusted
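A sketch of the two equations in the Notes as a single scan step: the adjustment \(\exp(\gamma(t)\sum_\tau I(t-\tau) f(\tau))\) is computed first and then scales \(\mathcal{R}^u(t)\) before the renewal sum is applied. This is an illustrative re-implementation with a hypothetical name. It assumes both pmfs have the same length as I0, since the sketch applies both to the same window of past infections, and it broadcasts a scalar feedback strength because jax.lax.scan needs a leading time axis on every input.

```python
import jax
import jax.numpy as jnp


def sketch_infections_with_feedback(
    I0, Rt_raw, feedback_strength, reversed_gi_pmf, reversed_feedback_pmf
):
    """Sketch of renewal infections with an exp(gamma * feedback) adjustment of Rt."""
    gi = jnp.asarray(reversed_gi_pmf, dtype=jnp.float32)
    fb = jnp.asarray(reversed_feedback_pmf, dtype=jnp.float32)
    Rt_raw = jnp.asarray(Rt_raw, dtype=jnp.float32)
    # lax.scan needs a leading time axis on every input, so broadcast a scalar gamma.
    gamma = jnp.broadcast_to(jnp.asarray(feedback_strength, dtype=jnp.float32), Rt_raw.shape)

    def step(history, inputs):
        gamma_t, r_raw_t = inputs
        # R(t) = R^u(t) * exp(gamma(t) * sum_tau I(t - tau) f(tau))
        r_adj_t = r_raw_t * jnp.exp(gamma_t * jnp.dot(fb, history))
        # I(t) = R(t) * sum_tau I(t - tau) g(tau)
        new_infections = r_adj_t * jnp.dot(gi, history)
        history = jnp.concatenate([history[1:], new_infections[jnp.newaxis]])
        return history, (new_infections, r_adj_t)

    _, (infections, Rt_adjusted) = jax.lax.scan(
        step, jnp.asarray(I0, dtype=jnp.float32), (gamma, Rt_raw)
    )
    return infections, Rt_adjusted
```

Setting the feedback strength to 0 recovers the plain renewal process, and a negative value pulls \(\mathcal{R}(t)\) below \(\mathcal{R}^u(t)\), as described in the Notes.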

logistic_susceptibility_adjustment

logistic_susceptibility_adjustment(
    I_raw_t: float, frac_susceptible: float, n_population: float
) -> float

Apply the logistic susceptibility adjustment proposed in equation 6 of Bhatt et al. 2023 (https://doi.org/10.1093/jrsssa/qnad030) to a potential new incidence I_raw_t.

Parameters:

Name Type Description Default
I_raw_t float

The "unadjusted" incidence at time t, i.e. the incidence given an infinite number of available susceptible individuals.

required
frac_susceptible float

Fraction of the population that remains susceptible.

required
n_population float

Total size of the population.

required

Returns:

Type Description
float

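The adjusted incidence \(I(t) = N \, s \left(1 - e^{-I_{\text{raw}}(t)/N}\right)\), where \(N\) is n_population and \(s\) is frac_susceptible.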

Source code in pyrenew/latent/infection_functions.py
def logistic_susceptibility_adjustment(
    I_raw_t: float,
    frac_susceptible: float,
    n_population: float,
) -> float:
    """
    Apply the logistic susceptibility
    adjustment to a potential new
    incidence `I_raw_t` proposed in
    equation 6 of `Bhatt et al 2023
    <https://doi.org/10.1093/jrsssa/qnad030>`_.

    Parameters
    ----------
    I_raw_t
        The "unadjusted" incidence at time t,
        i.e. the incidence given an infinite
        number of available susceptible individuals.
    frac_susceptible
        fraction of remaining susceptible individuals
        in the population
    n_population
        Total size of the population.

    Returns
    -------
    float
        The adjusted value of $I(t)$.
    """
    approx_frac_infected = 1 - jnp.exp(-I_raw_t / n_population)
    return n_population * frac_susceptible * approx_frac_infected
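The adjustment has two useful limits. When I_raw_t is much smaller than n_population, 1 - exp(-x) ≈ x, so the result is approximately frac_susceptible · I_raw_t. When I_raw_t is very large, it saturates at frac_susceptible · n_population, so incidence cannot exceed the susceptible pool. The standalone copy of the formula below (hypothetical name) checks both limits numerically.

```python
import jax.numpy as jnp


def sketch_logistic_adjustment(I_raw_t, frac_susceptible, n_population):
    # Same formula as logistic_susceptibility_adjustment.
    approx_frac_infected = 1 - jnp.exp(-I_raw_t / n_population)
    return n_population * frac_susceptible * approx_frac_infected


small = sketch_logistic_adjustment(10.0, 0.5, 1000.0)  # about 4.975, close to 0.5 * 10
large = sketch_logistic_adjustment(1e6, 0.5, 1000.0)  # saturates at 0.5 * 1000 = 500
```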

Temporal Processes

Temporal processes for latent infection models.

Provides time-series processes for modeling Rt dynamics and subpopulation deviations in hierarchical infection models. All processes return 2D arrays of shape (n_timepoints, n_processes) through a unified TemporalProcess protocol.

Relationship to pyrenew.process:

This module provides high-level, domain-specific wrappers around the low-level building blocks in pyrenew.process. The key differences:

| Aspect | pyrenew.process | pyrenew.latent.temporal_processes |
|---|---|---|
| Abstraction level | Low-level composable primitives | High-level domain-specific API |
| Interface | Varied signatures per class | Unified TemporalProcess protocol |
| Target use | General time-series modeling | Rt dynamics, hierarchical infections |
| Vectorization | Caller manages array shapes | Automatic via n_processes parameter |
| Validation | Minimal constraints | Validates positive innovation_sd |

When to use which:

  • Use pyrenew.process classes (ARProcess, DifferencedProcess, RandomWalk) when building novel statistical models or when you need fine-grained control over array shapes and numpyro sampling semantics.

  • Use this module's classes (AR1, DifferencedAR1, RandomWalk) when modeling Rt trajectories in hierarchical infection models. These provide a consistent interface and automatic vectorization, and they validate that innovation_sd is positive.

Temporal processes provided: AR1, DifferencedAR1, and RandomWalk.

All implementations satisfy the TemporalProcess protocol and can be used interchangeably in hierarchical infection models.

AR1

AR1(autoreg: float, innovation_sd: float = 1.0)

Bases: TemporalProcess

AR(1) process.

Each value is autoreg times the previous value plus Gaussian noise, so the process reverts toward zero. Applied to deviations from a baseline (for example, of log Rt), it keeps Rt near that baseline: values that drift away are "pulled back" over time.

This class wraps pyrenew.process.ARProcess with a simplified, protocol-compliant interface that handles vectorization automatically.
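The stationary standard deviation that AR1.sample uses for its initial state comes from taking variances on both sides of the AR(1) recursion. Writing \(\phi\) for autoreg and \(\sigma\) for innovation_sd, and assuming \(|\phi| < 1\) with noise independent of \(x_{t-1}\):

\[ \begin{aligned} x_t & = \phi\, x_{t-1} + \varepsilon_t, \qquad \varepsilon_t \sim \mathcal{N}(0, \sigma^2) \\ \operatorname{Var}(x) & = \phi^2 \operatorname{Var}(x) + \sigma^2 \quad\Longrightarrow\quad \operatorname{sd}(x) = \frac{\sigma}{\sqrt{1 - \phi^2}} \end{aligned} \]

For example, autoreg = 0.5 and innovation_sd = 1.0 give a stationary sd of \(1/\sqrt{0.75} \approx 1.155\).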

Parameters:

Name Type Description Default
autoreg float

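Autoregressive coefficient. For stationarity, |autoreg| < 1, but this is not enforced (use priors to constrain if needed). Note that sample() sets the spread of the initial state to innovation_sd / sqrt(1 - autoreg**2), which is undefined (NaN or infinite) when |autoreg| >= 1.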

required
innovation_sd float

Standard deviation of noise at each time step. Larger values produce more volatile trajectories; smaller values produce smoother ones.

1.0

Initialize AR(1) process.

Parameters:

Name Type Description Default
autoreg float

Autoregressive coefficient. For stationarity, |autoreg| < 1, but this is not enforced (use priors to constrain if needed).

required
innovation_sd float

Standard deviation of innovations

1.0

Raises:

Type Description
ValueError

If innovation_sd <= 0

Source code in pyrenew/latent/temporal_processes.py
def __init__(self, autoreg: float, innovation_sd: float = 1.0):
    """
    Initialize AR(1) process.

    Parameters
    ----------
    autoreg : float
        Autoregressive coefficient. For stationarity, |autoreg| < 1,
        but this is not enforced (use priors to constrain if needed).
    innovation_sd : float, default 1.0
        Standard deviation of innovations

    Raises
    ------
    ValueError
        If innovation_sd <= 0
    """
    if innovation_sd <= 0:
        raise ValueError(f"innovation_sd must be positive, got {innovation_sd}")
    self.autoreg = autoreg
    self.innovation_sd = innovation_sd
    self.ar_process = ARProcess(name="ar1")

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/latent/temporal_processes.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"AR1(autoreg={self.autoreg}, innovation_sd={self.innovation_sd})"

sample

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "ar1",
) -> ArrayLike

Sample AR(1) trajectory or trajectories.

Parameters:

Name Type Description Default
n_timepoints int

Number of time points to generate

required
initial_value float or ArrayLike

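Mean of the initial state(s). sample() draws the initial state from Normal(initial_value, innovation_sd / sqrt(1 - autoreg**2)), and the process then reverts toward zero rather than toward initial_value. Defaults to 0.0.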

None
n_processes int

Number of parallel processes.

1
name_prefix str

Prefix for numpyro sample sites

"ar1"

Returns:

Type Description
ArrayLike

Trajectories of shape (n_timepoints, n_processes)

Source code in pyrenew/latent/temporal_processes.py
def sample(
    self,
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "ar1",
) -> ArrayLike:
    """
    Sample AR(1) trajectory or trajectories.

    Parameters
    ----------
    n_timepoints : int
        Number of time points to generate
    initial_value : float or ArrayLike, optional
        Initial value(s). Defaults to 0.0.
    n_processes : int, default 1
        Number of parallel processes.
    name_prefix : str, default "ar1"
        Prefix for numpyro sample sites

    Returns
    -------
    ArrayLike
        Trajectories of shape (n_timepoints, n_processes)
    """
    if initial_value is None:
        initial_value = jnp.zeros(n_processes)
    elif jnp.isscalar(initial_value):
        initial_value = jnp.full(n_processes, initial_value)

    stationary_sd = self.innovation_sd / jnp.sqrt(1 - self.autoreg**2)

    with numpyro.plate(f"{name_prefix}_init_plate", n_processes):
        init_states = numpyro.sample(
            f"{name_prefix}_init",
            dist.Normal(initial_value, stationary_sd),
        )

    trajectories = self.ar_process(
        n=n_timepoints,
        init_vals=init_states[jnp.newaxis, :],
        autoreg=jnp.full((1, n_processes), self.autoreg),
        noise_sd=self.innovation_sd,
        noise_name=f"{name_prefix}_noise",
    )

    return trajectories
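The sketch below reproduces the structure of AR1.sample with jax.random in place of numpyro sample sites: the initial state is drawn from Normal(initial_value, innovation_sd / sqrt(1 - autoreg**2)), and each later value is autoreg times the previous value plus Normal(0, innovation_sd) noise. It assumes the first row of the output is the initial state, which may differ from how ARProcess aligns its output, and the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def sketch_ar1(key, n_timepoints, autoreg, innovation_sd=1.0, initial_value=0.0, n_processes=1):
    """Hypothetical stand-in for AR1.sample, using jax.random instead of numpyro sites."""
    init_key, noise_key = jax.random.split(key)
    mean0 = jnp.broadcast_to(jnp.asarray(initial_value, dtype=jnp.float32), (n_processes,))
    # Same stationary sd as AR1.sample; NaN or inf if |autoreg| >= 1.
    stationary_sd = innovation_sd / jnp.sqrt(1 - autoreg**2)
    x0 = mean0 + stationary_sd * jax.random.normal(init_key, (n_processes,))
    noise = innovation_sd * jax.random.normal(noise_key, (n_timepoints - 1, n_processes))

    def step(x_prev, eps):
        # x_t = autoreg * x_{t-1} + eps_t has no intercept, so it reverts toward zero.
        x_t = autoreg * x_prev + eps
        return x_t, x_t

    _, rest = jax.lax.scan(step, x0, noise)
    # Assumption of this sketch: row 0 is the sampled initial state.
    return jnp.concatenate([x0[jnp.newaxis, :], rest], axis=0)
```

Because the initial state is drawn from the stationary distribution, the spread across many processes stays near innovation_sd / sqrt(1 - autoreg**2) at every time point; with autoreg = 0.5 that is about 1.155.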

DifferencedAR1

DifferencedAR1(autoreg: float, innovation_sd: float = 1.0)

Bases: TemporalProcess

AR(1) process on first differences.

Each change in value is autoreg times the previous change plus noise, so the rate of change reverts toward zero. Because the rate of change is autocorrelated, Rt can trend upward or downward for extended periods, and, unlike AR(1), the level itself is not pulled back toward a baseline.

This class wraps pyrenew.process.DifferencedProcess with pyrenew.process.ARProcess as the fundamental process, providing a simplified, protocol-compliant interface.

Parameters:

Name Type Description Default
autoreg float

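Autoregressive coefficient for differences. For stationarity, |autoreg| < 1, but this is not enforced (use priors to constrain if needed). Note that sample() sets the spread of the initial rate of change to innovation_sd / sqrt(1 - autoreg**2), which is undefined (NaN or infinite) when |autoreg| >= 1.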

required
innovation_sd float

Standard deviation of noise added to changes. Larger values produce more erratic growth rates; smaller values produce smoother trends.

1.0

Initialize differenced AR(1) process.

Parameters:

Name Type Description Default
autoreg float

Autoregressive coefficient for differences. For stationarity, |autoreg| < 1, but this is not enforced (use priors to constrain if needed).

required
innovation_sd float

Standard deviation of innovations

1.0

Raises:

Type Description
ValueError

If innovation_sd <= 0

Source code in pyrenew/latent/temporal_processes.py
def __init__(self, autoreg: float, innovation_sd: float = 1.0):
    """
    Initialize differenced AR(1) process.

    Parameters
    ----------
    autoreg : float
        Autoregressive coefficient for differences. For stationarity,
        |autoreg| < 1, but this is not enforced (use priors to constrain
        if needed).
    innovation_sd : float, default 1.0
        Standard deviation of innovations

    Raises
    ------
    ValueError
        If innovation_sd <= 0
    """
    if innovation_sd <= 0:
        raise ValueError(f"innovation_sd must be positive, got {innovation_sd}")
    self.autoreg = autoreg
    self.innovation_sd = innovation_sd
    self.process = DifferencedProcess(
        name="diff_ar1",
        fundamental_process=ARProcess(name="diff_ar1_fundamental"),
        differencing_order=1,
    )

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/latent/temporal_processes.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"DifferencedAR1(autoreg={self.autoreg}, innovation_sd={self.innovation_sd})"

sample

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "diff_ar1",
) -> ArrayLike

Sample differenced AR(1) trajectory or trajectories.

Parameters:

Name Type Description Default
n_timepoints int

Number of time points to generate

required
initial_value float or ArrayLike

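Initial level(s) of the process. Defaults to 0.0. The initial rate(s) of change are drawn from Normal(0, innovation_sd / sqrt(1 - autoreg**2)).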

None
n_processes int

Number of parallel processes.

1
name_prefix str

Prefix for numpyro sample sites

"diff_ar1"

Returns:

Type Description
ArrayLike

Trajectories of shape (n_timepoints, n_processes)

Source code in pyrenew/latent/temporal_processes.py
def sample(
    self,
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "diff_ar1",
) -> ArrayLike:
    """
    Sample differenced AR(1) trajectory or trajectories.

    Parameters
    ----------
    n_timepoints : int
        Number of time points to generate
    initial_value : float or ArrayLike, optional
        Initial value(s). Defaults to 0.0.
    n_processes : int, default 1
        Number of parallel processes.
    name_prefix : str, default "diff_ar1"
        Prefix for numpyro sample sites

    Returns
    -------
    ArrayLike
        Trajectories of shape (n_timepoints, n_processes)
    """
    if initial_value is None:
        initial_value = jnp.zeros(n_processes)
    elif jnp.isscalar(initial_value):
        initial_value = jnp.full(n_processes, initial_value)

    stationary_sd = self.innovation_sd / jnp.sqrt(1 - self.autoreg**2)

    with numpyro.plate(f"{name_prefix}_init_rate_plate", n_processes):
        init_rates = numpyro.sample(
            f"{name_prefix}_init_rate",
            dist.Normal(0, stationary_sd),
        )

    trajectories = self.process(
        n=n_timepoints,
        init_vals=initial_value[jnp.newaxis, :],
        autoreg=jnp.full((1, n_processes), self.autoreg),
        noise_sd=self.innovation_sd,
        fundamental_process_init_vals=init_rates[jnp.newaxis, :],
        noise_name=f"{name_prefix}_noise",
    )

    return trajectories
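A corresponding sketch for DifferencedAR1: the rate of change follows an AR(1) recursion started from Normal(0, innovation_sd / sqrt(1 - autoreg**2)), and the level is initial_value plus the running sum of the rates. The alignment (row 0 equals initial_value, then one rate per later time point) is an assumption of this sketch rather than a description of DifferencedProcess, and the function name is hypothetical. It requires n_timepoints >= 2.

```python
import jax
import jax.numpy as jnp


def sketch_differenced_ar1(key, n_timepoints, autoreg, innovation_sd=1.0, initial_value=0.0, n_processes=1):
    """Hypothetical stand-in for DifferencedAR1.sample; requires n_timepoints >= 2."""
    rate_key, noise_key = jax.random.split(key)
    level0 = jnp.broadcast_to(jnp.asarray(initial_value, dtype=jnp.float32), (n_processes,))
    # Initial rate of change ~ Normal(0, stationary sd), as in DifferencedAR1.sample.
    stationary_sd = innovation_sd / jnp.sqrt(1 - autoreg**2)
    rate0 = stationary_sd * jax.random.normal(rate_key, (n_processes,))
    noise = innovation_sd * jax.random.normal(noise_key, (n_timepoints - 2, n_processes))

    def step(d_prev, eps):
        # The rate of change follows AR(1) and reverts toward zero.
        d_t = autoreg * d_prev + eps
        return d_t, d_t

    _, later_rates = jax.lax.scan(step, rate0, noise)
    rates = jnp.concatenate([rate0[jnp.newaxis, :], later_rates], axis=0)
    # The level integrates the rates: row 0 is initial_value, row t adds t rates.
    cumulative = jnp.concatenate([jnp.zeros((1, n_processes)), jnp.cumsum(rates, axis=0)], axis=0)
    return level0[jnp.newaxis, :] + cumulative, rates
```

Differencing the returned levels recovers the rates, which is the defining property of a process specified on first differences.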

RandomWalk

RandomWalk(innovation_sd: float = 1.0)

Bases: TemporalProcess

Random walk process for log(Rt).

Each value equals the previous value plus noise, with no reversion toward a mean. Rt can therefore drift without bound, which suits settings with no prior expectation that Rt will return to a baseline.

This class wraps pyrenew.process.RandomWalk with a simplified, protocol-compliant interface that handles vectorization automatically.

Parameters:

Name Type Description Default
innovation_sd float

Standard deviation of noise at each time step. Larger values produce faster drift; smaller values produce more gradual changes.

1.0
Notes

Unlike AR(1), the variance grows over time: with a fixed starting point, the variance after t steps is t * innovation_sd**2, so the process can wander arbitrarily far from where it started. For long time horizons, consider AR(1) if you want Rt to stay near a baseline.

For non-centered parameterization (to avoid funnel problems in inference), apply LocScaleReparam(centered=0) to the step sample site ({name_prefix}_step) via numpyro.handlers.reparam.
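As a sketch of what the non-centered form means here, the random walk below draws standard-normal latent variables and scales them by innovation_sd afterwards, which mirrors the reparameterization LocScaleReparam(centered=0) applies to a Normal site. It also checks the linear variance growth described above. This uses jax.random directly rather than numpyro, and the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def sketch_noncentered_random_walk(key, n_timepoints, n_processes, innovation_sd=1.0, initial_value=0.0):
    # Latent draws are Normal(0, 1) and do not depend on innovation_sd (non-centered).
    z = jax.random.normal(key, (n_timepoints - 1, n_processes))
    # Scaling afterwards gives Normal(0, innovation_sd) steps.
    steps = innovation_sd * z
    # Row 0 is the fixed starting point; row t is the sum of the first t steps.
    walk = jnp.concatenate([jnp.zeros((1, n_processes)), jnp.cumsum(steps, axis=0)], axis=0)
    return initial_value + walk
```

With innovation_sd = 0.5, the variance across processes should be close to 25 * 0.25 = 6.25 at row 25 and 100 * 0.25 = 25 at row 100.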

Initialize random walk process.

Parameters:

Name Type Description Default
innovation_sd float

Standard deviation of innovations

1.0

Raises:

Type Description
ValueError

If innovation_sd <= 0

Source code in pyrenew/latent/temporal_processes.py
def __init__(self, innovation_sd: float = 1.0):
    """
    Initialize random walk process.

    Parameters
    ----------
    innovation_sd : float, default 1.0
        Standard deviation of innovations

    Raises
    ------
    ValueError
        If innovation_sd <= 0
    """
    if innovation_sd <= 0:
        raise ValueError(f"innovation_sd must be positive, got {innovation_sd}")
    self.innovation_sd = innovation_sd

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/latent/temporal_processes.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"RandomWalk(innovation_sd={self.innovation_sd})"

sample

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "rw",
) -> ArrayLike

Sample random walk trajectory or trajectories.

Parameters:

Name Type Description Default
n_timepoints int

Number of time points to generate

required
initial_value float or ArrayLike

Initial value(s). Defaults to 0.0.

None
n_processes int

Number of parallel processes.

1
name_prefix str

Prefix for numpyro sample sites

"rw"

Returns:

Type Description
ArrayLike

Trajectories of shape (n_timepoints, n_processes)

Source code in pyrenew/latent/temporal_processes.py
def sample(
    self,
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "rw",
) -> ArrayLike:
    """
    Sample random walk trajectory or trajectories.

    Parameters
    ----------
    n_timepoints : int
        Number of time points to generate
    initial_value : float or ArrayLike, optional
        Initial value(s). Defaults to 0.0.
    n_processes : int, default 1
        Number of parallel processes.
    name_prefix : str, default "rw"
        Prefix for numpyro sample sites

    Returns
    -------
    ArrayLike
        Trajectories of shape (n_timepoints, n_processes)
    """
    if initial_value is None:
        initial_value = jnp.zeros(n_processes)
    elif jnp.isscalar(initial_value):
        initial_value = jnp.full(n_processes, initial_value)

    rw = ProcessRandomWalk(
        name=f"{name_prefix}_random_walk",
        step_rv=DistributionalVariable(
            name=f"{name_prefix}_step",
            distribution=dist.Normal(
                jnp.zeros(n_processes),
                self.innovation_sd,
            ),
        ),
    )

    return rw.sample(
        init_vals=initial_value[jnp.newaxis, :],
        n=n_timepoints,
    )

TemporalProcess

Bases: Protocol

Protocol for temporal processes generating time-varying parameters.

Used for jurisdiction-level Rt dynamics, subpopulation deviations, or allocation trajectories. All processes return 2D arrays of shape (n_timepoints, n_processes) for consistent handling.

sample

sample(
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "temporal",
) -> ArrayLike

Sample temporal trajectory or trajectories.

Parameters:

Name Type Description Default
n_timepoints int

Number of time points to generate

required
initial_value float or ArrayLike

Initial value(s) for the process(es). Scalar (broadcast to all processes) or array of shape (n_processes,). Defaults to 0.0.

None
n_processes int

Number of parallel processes.

1
name_prefix str

Prefix for numpyro sample site names to avoid collisions

"temporal"

Returns:

Type Description
ArrayLike

Trajectories of shape (n_timepoints, n_processes)

Source code in pyrenew/latent/temporal_processes.py
def sample(
    self,
    n_timepoints: int,
    initial_value: float | ArrayLike | None = None,
    n_processes: int = 1,
    name_prefix: str = "temporal",
) -> ArrayLike:
    """
    Sample temporal trajectory or trajectories.

    Parameters
    ----------
    n_timepoints : int
        Number of time points to generate
    initial_value : float or ArrayLike, optional
        Initial value(s) for the process(es).
        Scalar (broadcast to all processes) or array of shape (n_processes,).
        Defaults to 0.0.
    n_processes : int, default 1
        Number of parallel processes.
    name_prefix : str, default "temporal"
        Prefix for numpyro sample site names to avoid collisions

    Returns
    -------
    ArrayLike
        Trajectories of shape (n_timepoints, n_processes)
    """
    ...