
Model

MultiSignalModel

MultiSignalModel(
    latent_process: BaseLatentInfectionProcess,
    observations: dict[str, BaseObservationProcess],
    ascertainment_models: dict[str, AscertainmentModel] | None = None,
)

Bases: Model

Multi-signal renewal model.

Combines a latent infection process (e.g., SubpopulationInfections) with multiple observation processes (e.g., CountObservation, WastewaterObservation).

Typically built via PyrenewBuilder, which ensures that n_initialization_points is computed correctly from all components. It can also be constructed manually for advanced use cases.

Parameters:

Name Type Description Default
latent_process BaseLatentInfectionProcess

Latent infection process generating infections at jurisdiction and/or subpopulation levels

required
observations dict[str, BaseObservationProcess]

Dictionary mapping names to observation process instances. Names are used when passing observation data to sample().

required
ascertainment_models dict[str, AscertainmentModel] | None

Optional dictionary mapping names to ascertainment model instances. Each ascertainment model is sampled once per model execution before observation processes run.

None
Notes

The model automatically routes latent infections to observations based on each observation's infection_resolution() method:

- "aggregate" → receives aggregate infections from the latent process
- "subpop" → receives all subpopulation infections; the observation selects via indices
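The routing rule in the note above can be sketched with plain Python stand-ins. This is only an illustration: `ToyObservation`, `route_infections`, and the list-valued infections are hypothetical and not part of pyrenew. Only the resolution strings "aggregate" and "subpop" come from the text.

```python
from dataclasses import dataclass


@dataclass
class ToyObservation:
    """Hypothetical stand-in for an observation process."""

    name: str
    resolution: str  # "aggregate" or "subpop"

    def infection_resolution(self) -> str:
        return self.resolution


def route_infections(observations, inf_aggregate, inf_all):
    """Pick the infection array each observation should receive."""
    latent_map = {"aggregate": inf_aggregate, "subpop": inf_all}
    routed = {}
    for name, obs in observations.items():
        resolution = obs.infection_resolution()
        if resolution not in latent_map:
            raise ValueError(
                f"Observation {name!r} has invalid resolution {resolution!r}"
            )
        routed[name] = latent_map[resolution]
    return routed


observations = {
    "hospital": ToyObservation("hospital", "aggregate"),
    "wastewater": ToyObservation("wastewater", "subpop"),
}
inf_aggregate = [10.0, 12.0, 15.0]  # one series for the whole population
inf_all = [[6.0, 4.0], [7.0, 5.0], [9.0, 6.0]]  # rows: days, columns: subpopulations

routed = route_infections(observations, inf_aggregate, inf_all)
```

Here the "hospital" observation receives the aggregate series, while "wastewater" receives the full per-subpopulation array and would pick out its own columns.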

Initialize multi-signal model.

Parameters:

Name Type Description Default
latent_process BaseLatentInfectionProcess

Configured latent infection process

required
observations dict[str, BaseObservationProcess]

Dictionary mapping observation names to observation process instances

required
ascertainment_models dict[str, AscertainmentModel] | None

Optional dictionary mapping names to ascertainment model instances

None

Raises:

Type Description
ValueError

If validation fails (e.g., observation requires subpopulations but latent process doesn't support them)

Source code in pyrenew/model/multisignal_model.py
def __init__(
    self,
    latent_process: BaseLatentInfectionProcess,
    observations: dict[str, BaseObservationProcess],
    ascertainment_models: dict[str, AscertainmentModel] | None = None,
) -> None:
    """
    Initialize multi-signal model.

    Parameters
    ----------
    latent_process
        Configured latent infection process
    observations
        Dictionary mapping observation names to observation process instances
    ascertainment_models
        Optional dictionary mapping names to ascertainment model instances

    Raises
    ------
    ValueError
        If validation fails (e.g., observation requires subpopulations
        but latent process doesn't support them)
    """
    self.latent = latent_process
    self.observations = observations
    if ascertainment_models is None:
        ascertainment_models = {}
    self.ascertainment_models = ascertainment_models
    self.validate()

pad_observations

pad_observations(obs: ndarray, axis: int = 0) -> ndarray

Pad observations with NaN for the initialization period.

Observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days. This method prepends n_init NaN values to align user data (starting at day 0 of observations) with the shared axis.

Parameters:

Name Type Description Default
obs ndarray

Observations in natural coordinates (index 0 = first observation day). Integer arrays are converted to float (required for NaN).

required
axis int

Axis along which to pad (typically 0 for time axis).

0

Returns:

Type Description
ndarray

Padded observations. First n_init values are NaN.
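To make the padding concrete, here is a small NumPy analogue of the method. It is a sketch, not the library call itself: `pad_with_nan` is a hypothetical free function, and `n_init = 3` is an assumed value that the real model reads from its latent process.

```python
import numpy as np


def pad_with_nan(obs, n_init, axis=0):
    """Prepend n_init NaN entries along `axis` (NumPy analogue, hypothetical helper)."""
    obs = np.asarray(obs, dtype=float)  # integers become floats so NaN is representable
    pad_shape = list(obs.shape)
    pad_shape[axis] = n_init
    padding = np.full(pad_shape, np.nan)
    return np.concatenate([padding, obs], axis=axis)


n_init = 3  # assumed value; in the model it comes from the latent process
daily_counts = [12, 15, 9, 20]  # index 0 = first observation day
padded = pad_with_nan(daily_counts, n_init)

site_counts = np.ones((4, 2))  # 4 days, 2 sites
padded_2d = pad_with_nan(site_counts, n_init, axis=0)
```

The padded arrays have length n_init + n_days along the time axis, so they line up with the shared axis used by the model.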

Source code in pyrenew/model/multisignal_model.py
def pad_observations(
    self,
    obs: jnp.ndarray,
    axis: int = 0,
) -> jnp.ndarray:
    """
    Pad observations with NaN for the initialization period.

    Observation data uses a shared time axis [0, n_total) where
    n_total = n_init + n_days. This method prepends n_init NaN values
    to align user data (starting at day 0 of observations) with the
    shared axis.

    Parameters
    ----------
    obs
        Observations in natural coordinates (index 0 = first observation day).
        Integer arrays are converted to float (required for NaN).
    axis
        Axis along which to pad (typically 0 for time axis).

    Returns
    -------
    jnp.ndarray
        Padded observations. First n_init values are NaN.
    """
    n_init = self.latent.n_initialization_points
    obs = jnp.asarray(obs, dtype=float)
    pad_shape = list(obs.shape)
    pad_shape[axis] = n_init
    padding = jnp.full(pad_shape, jnp.nan)
    return jnp.concatenate([padding, obs], axis=axis)

sample

sample(
    n_days_post_init: int,
    population_size: float,
    *,
    subpop_fractions: ArrayLike | None = None,
    obs_start_date: date | datetime | datetime64 | None = None,
    **observation_data: dict[str, object],
) -> None

Sample from the joint generative model.

This is the model function called by NumPyro during inference.

Parameters:

Name Type Description Default
n_days_post_init int

Number of days to simulate after initialization period

required
population_size float

Total population size. Used to convert infection proportions (from latent process) to infection counts (for observation processes).

required
subpop_fractions ArrayLike | None

Population fractions for all subpopulations. Shape: (n_subpops,).

None
obs_start_date date | datetime | datetime64 | None

Date of the first observation day. Converted once to the axis-origin first_day_dow (day-of-week of element 0 of the padded axis, after subtracting n_init) and forwarded to the latent process and every observation. Required when any observation uses aggregation="weekly" or a day-of-week effect, or when a latent temporal process uses calendar-week alignment.

None
**observation_data dict[str, object]

Data for each observation process, keyed by observation name (the name attribute of each observation process). Each value should be a dict of kwargs for that observation's sample().

{}

Returns:

Type Description
None

All quantities are recorded as NumPyro deterministic sites (latent_infections, latent_infections_by_subpop) and observation sites. Use numpyro.infer.Predictive for forward sampling.
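The obs_start_date conversion described above can be illustrated with the standard library. This is a sketch under assumptions: `axis_origin_dow` is a hypothetical helper, and the Monday=0 numbering is Python's `date.weekday()` convention, which may or may not match the convention pyrenew uses internally.

```python
import datetime as dt


def axis_origin_dow(obs_start_date: dt.date, n_init: int) -> int:
    """Day of week (Monday=0, an assumed convention) of element 0 of the padded axis."""
    axis_origin = obs_start_date - dt.timedelta(days=n_init)
    return axis_origin.weekday()


# 2024-01-10 is a Wednesday; stepping back 9 initialization days
# lands on Monday 2024-01-01.
first_day_dow = axis_origin_dow(dt.date(2024, 1, 10), n_init=9)
```

Because the shared axis starts n_init days before the first observation, the day of week has to be taken at that earlier date rather than at obs_start_date itself.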

Source code in pyrenew/model/multisignal_model.py
def sample(
    self,
    n_days_post_init: int,
    population_size: float,
    *,
    subpop_fractions: ArrayLike | None = None,
    obs_start_date: dt.date | dt.datetime | np.datetime64 | None = None,
    **observation_data: dict[str, object],
) -> None:
    """
    Sample from the joint generative model.

    This is the model function called by NumPyro during inference.

    Parameters
    ----------
    n_days_post_init
        Number of days to simulate after initialization period
    population_size
        Total population size. Used to convert infection proportions
        (from latent process) to infection counts (for observation processes).
    subpop_fractions
        Population fractions for all subpopulations. Shape: (n_subpops,).
    obs_start_date
        Date of the first observation day. Converted once to the
        axis-origin ``first_day_dow`` (day-of-week of element 0 of
        the padded axis, after subtracting ``n_init``) and forwarded
        to the latent process and every observation. Required when
        any observation uses ``aggregation="weekly"`` or a
        day-of-week effect, or when a latent temporal process uses
        calendar-week alignment.
    **observation_data
        Data for each observation process, keyed by observation name
        (the ``name`` attribute of each observation process).
        Each value should be a dict of kwargs for that observation's sample().

    Returns
    -------
    None
        All quantities are recorded as NumPyro deterministic sites
        (``latent_infections``, ``latent_infections_by_subpop``) and
        observation sites. Use ``numpyro.infer.Predictive`` for forward
        sampling.
    """
    self._check_obs_start_date(obs_start_date)
    first_day_dow = self._resolve_first_day_dow(obs_start_date)

    # Generate latent infections (proportions)
    latent_sample = self.latent.sample(
        n_days_post_init=n_days_post_init,
        subpop_fractions=subpop_fractions,
        first_day_dow=first_day_dow,
    )

    # Scale from proportions to counts
    inf_aggregate = latent_sample.aggregate * population_size
    inf_all = latent_sample.all_subpops * population_size

    # Record scaled infections for posterior analysis
    numpyro.deterministic("latent_infections", inf_aggregate)
    numpyro.deterministic("latent_infections_by_subpop", inf_all)

    # Map infection resolution to infection arrays
    latent_map = {
        "aggregate": inf_aggregate,
        "subpop": inf_all,
    }

    ascertainment_values = {
        name: ascertainment_model.sample()
        for name, ascertainment_model in self.ascertainment_models.items()
    }

    with ascertainment_context(ascertainment_values):
        # Apply each observation process
        for name, obs_process in self.observations.items():
            # Get the appropriate latent infections based on observation type
            resolution = obs_process.infection_resolution()
            if resolution not in latent_map:
                raise ValueError(
                    f"Observation '{name}' returned invalid infection_resolution "
                    f"'{resolution}'. Expected one of "
                    f"{self._SUPPORTED_RESOLUTIONS}."
                )
            latent_infections = latent_map[resolution]

            # Get observation-specific data
            obs_data = observation_data.get(name, {})

            # Sample from observation process
            obs_process.sample(
                infections=latent_infections,
                first_day_dow=first_day_dow,
                **obs_data,
            )

    return None

shift_times

shift_times(times: ndarray) -> ndarray

Shift time indices from natural coordinates to shared time axis.

Observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days. User-provided times in natural coordinates (0 = first observation day) must be shifted by n_init to align with the shared axis.

Parameters:

Name Type Description Default
times ndarray

Time indices in natural coordinates (0 = first observation day).

required

Returns:

Type Description
ndarray

Time indices on the shared axis [n_init, n_total).
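A short worked example of the shift, using NumPy in place of the JAX arrays used by the method. The values of `n_init` and `n_days` are assumptions chosen for illustration.

```python
import numpy as np

n_init = 3  # assumed initialization length
n_days = 10  # assumed number of post-initialization days
n_total = n_init + n_days

natural_times = np.array([0, 4, 9])  # 0 = first observation day
shared_times = natural_times + n_init  # same days, indexed on the shared axis
```

Every shifted index falls in [n_init, n_total), which is the range sparse observations are expected to use.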

Source code in pyrenew/model/multisignal_model.py
def shift_times(self, times: jnp.ndarray) -> jnp.ndarray:
    """
    Shift time indices from natural coordinates to shared time axis.

    Observation data uses a shared time axis [0, n_total) where
    n_total = n_init + n_days. User-provided times in natural coordinates
    (0 = first observation day) must be shifted by n_init to align with
    the shared axis.

    Parameters
    ----------
    times
        Time indices in natural coordinates (0 = first observation day).

    Returns
    -------
    jnp.ndarray
        Time indices on the shared axis [n_init, n_total).
    """
    n_init = self.latent.n_initialization_points
    return jnp.asarray(times) + n_init

validate

validate() -> None

Validate that observation processes are compatible with latent process.

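Checks that each observation implements infection_resolution() and returns a supported resolution, and that each ascertainment model is an AscertainmentModel registered under its own name.

Raises:

Type Description
ValueError

If an observation doesn't implement infection_resolution() or returns an unsupported resolution, or if an ascertainment model's dictionary key does not match its name.

TypeError

If an ascertainment model is not an AscertainmentModel.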
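The observation checks can be sketched as follows. The classes here are hypothetical stand-ins, and the tuple `SUPPORTED_RESOLUTIONS` is an assumption based on the two resolutions named in the class notes; the actual value of `_SUPPORTED_RESOLUTIONS` is not shown on this page.

```python
SUPPORTED_RESOLUTIONS = ("aggregate", "subpop")  # assumed contents


class GoodObs:
    """Hypothetical observation that declares a supported resolution."""

    def infection_resolution(self):
        return "aggregate"


class NoResolutionObs:
    """Hypothetical observation that forgot to implement the method."""


def check_observations(observations):
    """Raise ValueError for observations that cannot be routed."""
    for name, obs in observations.items():
        if not hasattr(obs, "infection_resolution"):
            raise ValueError(
                f"Observation {name!r} must implement infection_resolution()"
            )
        resolution = obs.infection_resolution()
        if resolution not in SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"Observation {name!r} returned invalid resolution {resolution!r}"
            )


check_observations({"hospital": GoodObs()})  # passes silently

try:
    check_observations({"broken": NoResolutionObs()})
    error_message = None
except ValueError as err:
    error_message = str(err)
```

Running the check at construction time means a misconfigured observation fails immediately rather than partway through sampling.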

Source code in pyrenew/model/multisignal_model.py
def validate(self) -> None:
    """
    Validate that observation processes are compatible with latent process.

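    Checks that each observation implements infection_resolution()
    and returns a supported resolution, and that each ascertainment
    model is an AscertainmentModel registered under its own name.

    Raises
    ------
    ValueError
        If an observation doesn't implement infection_resolution()
        or returns an unsupported resolution, or if an ascertainment
        model's dictionary key does not match its name.
    TypeError
        If an ascertainment model is not an AscertainmentModel.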
    """
    for name, obs in self.observations.items():
        if not hasattr(obs, "infection_resolution"):
            raise ValueError(
                f"Observation '{name}' must implement infection_resolution()"
            )
        resolution = obs.infection_resolution()
        if resolution not in self._SUPPORTED_RESOLUTIONS:
            raise ValueError(
                f"Observation '{name}' returned invalid infection_resolution "
                f"'{resolution}'. Expected one of {self._SUPPORTED_RESOLUTIONS}."
            )
    for name, ascertainment_model in self.ascertainment_models.items():
        if not isinstance(ascertainment_model, AscertainmentModel):
            raise TypeError(
                f"Ascertainment model '{name}' must be an AscertainmentModel, "
                f"got {type(ascertainment_model).__name__}."
            )
        if ascertainment_model.name != name:
            raise ValueError(
                f"Ascertainment model dictionary key {name!r} must match "
                f"the model name {ascertainment_model.name!r}."
            )

validate_data

validate_data(
    n_days_post_init: int,
    *,
    subpop_fractions: ArrayLike | None = None,
    obs_start_date: date | datetime | datetime64 | None = None,
    **observation_data: dict[str, object],
) -> None

Validate observation data before running MCMC.

All observation data uses a shared time axis [0, n_total) where n_total = n_init + n_days_post_init. Dense observations must have length n_total with NaN padding for the initialization period. Sparse observations provide times indices on this shared axis.

This method must be called with concrete (non-traced) values before running inference. Validation using Python control flow (if/raise) cannot be done during JAX tracing.

Parameters:

Name Type Description Default
n_days_post_init int

Number of days to simulate after initialization period.

required
subpop_fractions ArrayLike | None

Population fractions for all subpopulations. Shape: (n_subpops,).

None
obs_start_date date | datetime | datetime64 | None

Date of the first observation day. Required when any observation uses aggregation="weekly" or a day-of-week effect, or when a latent temporal process uses calendar-week alignment. Converted once to the axis-origin first_day_dow and forwarded to the latent process and every observation.

None
**observation_data dict[str, object]

Data for each observation process, keyed by observation name. Each value should be a dict of kwargs for that observation's sample().

{}

Raises:

Type Description
ValueError

If times indices are out of bounds or negative, if dense obs length doesn't match n_total, if data shapes are inconsistent, or if obs_start_date is missing when an observation requires it.
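The kinds of checks listed above can be illustrated in plain Python. This is a sketch only: in the library each observation process performs its own validation, and `check_dense` and `check_sparse_times` are hypothetical helpers written for this example.

```python
def check_dense(obs, n_total):
    """Dense observations must span the whole shared axis."""
    if len(obs) != n_total:
        raise ValueError(
            f"Dense obs has length {len(obs)}, expected n_total={n_total}"
        )


def check_sparse_times(times, n_total):
    """Sparse time indices must lie in [0, n_total)."""
    for t in times:
        if t < 0 or t >= n_total:
            raise ValueError(f"Time index {t} is outside [0, {n_total})")


n_init, n_days_post_init = 3, 5  # assumed values
n_total = n_init + n_days_post_init

dense_obs = [float("nan")] * n_init + [4.0, 6.0, 5.0, 7.0, 9.0]
check_dense(dense_obs, n_total)  # OK: length 8
check_sparse_times([3, 5, 7], n_total)  # OK: all within [0, 8)

try:
    check_sparse_times([8], n_total)  # first index past the end
    out_of_bounds_caught = False
except ValueError:
    out_of_bounds_caught = True
```

Checks like these rely on ordinary Python `if`/`raise`, which is why they have to run on concrete values before inference rather than inside the traced model function.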

Source code in pyrenew/model/multisignal_model.py
def validate_data(
    self,
    n_days_post_init: int,
    *,
    subpop_fractions: ArrayLike | None = None,
    obs_start_date: dt.date | dt.datetime | np.datetime64 | None = None,
    **observation_data: dict[str, object],
) -> None:
    """
    Validate observation data before running MCMC.

    All observation data uses a shared time axis [0, n_total) where
    n_total = n_init + n_days_post_init. Dense observations must have
    length n_total with NaN padding for the initialization period.
    Sparse observations provide times indices on this shared axis.

    This method must be called with concrete (non-traced) values
    before running inference. Validation using Python control flow
    (if/raise) cannot be done during JAX tracing.

    Parameters
    ----------
    n_days_post_init
        Number of days to simulate after initialization period.
    subpop_fractions
        Population fractions for all subpopulations. Shape: (n_subpops,).
    obs_start_date
        Date of the first observation day. Required when any
        observation uses ``aggregation="weekly"`` or a day-of-week
        effect, or when a latent temporal process uses
        calendar-week alignment. Converted once to the axis-origin
        ``first_day_dow`` and forwarded to the latent process and
        every observation.
    **observation_data
        Data for each observation process, keyed by observation name.
        Each value should be a dict of kwargs for that observation's sample().

    Raises
    ------
    ValueError
        If times indices are out of bounds or negative, if dense obs
        length doesn't match n_total, if data shapes are inconsistent,
        or if ``obs_start_date`` is missing when an observation requires it.
    """
    self._check_obs_start_date(obs_start_date)

    pop = self.latent._parse_and_validate_fractions(
        subpop_fractions=subpop_fractions,
    )

    n_init = self.latent.n_initialization_points
    n_total = n_init + n_days_post_init
    first_day_dow = self._resolve_first_day_dow(obs_start_date)

    for name, obs_data in observation_data.items():
        if name not in self.observations:
            raise ValueError(
                f"Unknown observation '{name}'. "
                f"Available: {list(self.observations.keys())}"
            )

        obs = self.observations[name]
        obs.validate_data(
            n_total=n_total,
            n_subpops=pop.n_subpops,
            first_day_dow=first_day_dow,
            **obs_data,
        )

PyrenewBuilder

PyrenewBuilder()

Builder for multi-signal renewal models.

Automatically computes n_initialization_points from observation processes and constructs a properly configured model.

The builder pattern ensures that:

1. n_initialization_points is computed correctly from all components
2. Validation happens at build time (fail-fast)
3. Latent infections are routed to the correct observations
4. The API is clean and easy to use

Population structure (subpop_fractions) is passed at sample/fit time, not at configure time. This allows a single model to be fit to multiple jurisdictions with different population structures.

Initialize a new model builder.

The builder starts empty and is configured by calling:

1. configure_latent() - set up the latent infection process
2. add_observation() - add one or more observation processes
3. build() - construct the final model
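The call order above can be sketched with a toy builder. `ToyBuilder` is hypothetical and deliberately simplified: it stores integer lookbacks instead of real observation processes and returns a plain dict instead of a MultiSignalModel. It only illustrates the configure, add, build sequence and the method chaining.

```python
class ToyBuilder:
    """Hypothetical builder that mirrors the configure/add/build call order."""

    def __init__(self):
        self.latent_class = None
        self.latent_params = {}
        self.observations = {}

    def configure_latent(self, latent_class, **params):
        if self.latent_class is not None:
            raise RuntimeError("Latent process already configured.")
        self.latent_class = latent_class
        self.latent_params = params
        return self  # returning self enables chaining

    def add_observation(self, name, lookback):
        if name in self.observations:
            raise ValueError(f"Observation {name!r} already added.")
        self.observations[name] = lookback
        return self

    def build(self):
        if self.latent_class is None:
            raise ValueError("Must call configure_latent() before build()")
        lookbacks = [self.latent_params["gen_int_length"], *self.observations.values()]
        return {
            "latent": self.latent_class,
            "n_initialization_points": max(lookbacks),
            "observations": dict(self.observations),
        }


model = (
    ToyBuilder()
    .configure_latent("ToyLatent", gen_int_length=7)
    .add_observation("hospital", lookback=14)
    .add_observation("wastewater", lookback=21)
    .build()
)
```

Deferring the initialization length to build() is what lets observations be added in any order without the user recomputing it by hand.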

Source code in pyrenew/model/pyrenew_builder.py
def __init__(self) -> None:
    """
    Initialize a new model builder.

    The builder starts empty and is configured by calling:
    1. configure_latent() - set up the latent infection process
    2. add_observation() - add one or more observation processes
    3. build() - construct the final model
    """
    self.latent_class: type[BaseLatentInfectionProcess] | None = None
    self.latent_params: dict[str, Any] = {}
    self.observations: dict[str, BaseObservationProcess] = {}
    self.ascertainment_models: dict[str, AscertainmentModel] = {}

add_ascertainment

add_ascertainment(ascertainment_model: AscertainmentModel) -> PyrenewBuilder

Add shared ascertainment structure to the model.

Use this method when observation probabilities are related across signals. Independent scalar ascertainment rates do not require an ascertainment model; those can be passed directly to an observation process as ordinary RandomVariable objects.

A registered ascertainment model is sampled once per model execution, before observation processes run. Observation processes receive signal-specific accessors from ascertainment_model.for_signal(...):

ascertainment = JointAscertainment(
    name="he_ascertainment",
    signals=("hospital", "ed_visits"),
    baseline_rates=...,
    scale_tril=...,
)
builder.add_ascertainment(ascertainment)

builder.add_observation(
    PopulationCounts(
        name="hospital",
        ascertainment_rate_rv=ascertainment.for_signal("hospital"),
        ...
    )
)

The ascertainment model's name attribute is used as the unique identifier in the built MultiSignalModel.

Parameters:

Name Type Description Default
ascertainment_model AscertainmentModel

Configured ascertainment model instance, such as JointAscertainment.

required

Returns:

Type Description
PyrenewBuilder

Self, for method chaining.

Raises:

Type Description
TypeError

If ascertainment_model is not an AscertainmentModel.

ValueError

If an ascertainment model with this name already exists.

Source code in pyrenew/model/pyrenew_builder.py
def add_ascertainment(
    self,
    ascertainment_model: AscertainmentModel,
) -> PyrenewBuilder:
    """
    Add shared ascertainment structure to the model.

    Use this method when observation probabilities are related across
    signals. Independent scalar ascertainment rates do not require an
    ascertainment model; those can be passed directly to an observation
    process as ordinary ``RandomVariable`` objects.

    A registered ascertainment model is sampled once per model execution,
    before observation processes run. Observation processes receive
    signal-specific accessors from ``ascertainment_model.for_signal(...)``:

    ```python
    ascertainment = JointAscertainment(
        name="he_ascertainment",
        signals=("hospital", "ed_visits"),
        baseline_rates=...,
        scale_tril=...,
    )
    builder.add_ascertainment(ascertainment)

    builder.add_observation(
        PopulationCounts(
            name="hospital",
            ascertainment_rate_rv=ascertainment.for_signal("hospital"),
            ...
        )
    )
    ```

    The ascertainment model's ``name`` attribute is used as the unique
    identifier in the built ``MultiSignalModel``.

    Parameters
    ----------
    ascertainment_model
        Configured ascertainment model instance, such as
        ``JointAscertainment``.

    Returns
    -------
    PyrenewBuilder
        Self, for method chaining.

    Raises
    ------
    TypeError
        If ``ascertainment_model`` is not an ``AscertainmentModel``.
    ValueError
        If an ascertainment model with this name already exists.
    """
    if not isinstance(ascertainment_model, AscertainmentModel):
        raise TypeError(
            "ascertainment_model must be an AscertainmentModel, "
            f"got {type(ascertainment_model).__name__}."
        )

    name = ascertainment_model.name
    if name in self.ascertainment_models:
        raise ValueError(
            f"Ascertainment model '{name}' already added. "
            "Each ascertainment model must have a unique name."
        )

    self.ascertainment_models[name] = ascertainment_model
    return self

add_observation

add_observation(obs_process: BaseObservationProcess) -> PyrenewBuilder

Add an observation process to the model.

The observation process's name attribute is used as the unique identifier for this observation. This name is used when passing observation data to model.sample() and model.fit(), and also prefixes all numpyro sample sites created by the process.

Parameters:

Name Type Description Default
obs_process BaseObservationProcess

Configured observation process instance (e.g., PopulationCounts, Wastewater, SubpopulationCounts). Must have a name attribute.

required

Returns:

Type Description
PyrenewBuilder

Self, for method chaining

Raises:

Type Description
ValueError

If an observation with this name already exists

Source code in pyrenew/model/pyrenew_builder.py
def add_observation(
    self,
    obs_process: BaseObservationProcess,
) -> PyrenewBuilder:
    """
    Add an observation process to the model.

    The observation process's ``name`` attribute is used as the unique
    identifier for this observation. This name is used when passing
    observation data to ``model.sample()`` and ``model.fit()``, and also
    prefixes all numpyro sample sites created by the process.

    Parameters
    ----------
    obs_process
        Configured observation process instance (e.g., PopulationCounts,
        Wastewater, SubpopulationCounts). Must have a ``name`` attribute.

    Returns
    -------
    PyrenewBuilder
        Self, for method chaining

    Raises
    ------
    ValueError
        If an observation with this name already exists
    """
    name = obs_process.name
    if name in self.observations:
        raise ValueError(
            f"Observation '{name}' already added. "
            "Each observation must have a unique name."
        )

    self.observations[name] = obs_process
    return self

build

build() -> MultiSignalModel

Build the multi-signal model with computed n_initialization_points.

This method:

1. Computes n_initialization_points from all components
2. Constructs the latent process with the computed value
3. Creates a MultiSignalModel with automatic infection routing

Can be called multiple times to create multiple model instances.

Returns:

Type Description
MultiSignalModel

Configured model ready for sampling

Raises:

Type Description
ValueError

If latent process not configured.
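The second build step, constructing the latent process from stored parameters plus the computed initialization length, can be sketched like this. `ToyLatent` and `construct_latent` are hypothetical; the parameter name `gen_int_length` is made up for the example.

```python
class ToyLatent:
    """Hypothetical latent process that just stores its constructor arguments."""

    def __init__(self, name, n_initialization_points, **params):
        self.name = name
        self.n_initialization_points = n_initialization_points
        self.params = params


def construct_latent(latent_class, latent_params, n_init):
    # Merge the user-supplied parameters with the computed initialization length.
    params = {**latent_params, "n_initialization_points": n_init}
    # Fall back to the class name when the user did not name the process.
    params.setdefault("name", latent_class.__name__)
    return latent_class(**params)


latent = construct_latent(ToyLatent, {"gen_int_length": 7}, n_init=21)
```

Since the stored parameters are copied into a new dict on each call, repeated calls construct independent latent process instances.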

Source code in pyrenew/model/pyrenew_builder.py
def build(self) -> MultiSignalModel:
    """
    Build the multi-signal model with computed n_initialization_points.

    This method:
    1. Computes n_initialization_points from all components
    2. Constructs the latent process with the computed value
    3. Creates a MultiSignalModel with automatic infection routing

    Can be called multiple times to create multiple model instances.

    Returns
    -------
    MultiSignalModel
        Configured model ready for sampling

    Raises
    ------
    ValueError
        If latent process not configured.
    """
    if self.latent_class is None:
        raise ValueError("Must call configure_latent() before build()")

    # Compute n_initialization_points
    n_init = self.compute_n_initialization_points()

    # Construct latent process with computed n_initialization_points
    latent_params = {**self.latent_params, "n_initialization_points": n_init}
    if "name" not in latent_params:
        latent_params["name"] = self.latent_class.__name__

    latent_process = self.latent_class(**latent_params)

    # Build model
    model = MultiSignalModel(
        latent_process=latent_process,
        observations=self.observations,
        ascertainment_models=self.ascertainment_models,
    )

    return model

compute_n_initialization_points

compute_n_initialization_points() -> int

Compute required n_initialization_points from all components.

Formula: n_initialization_points = max(all lookbacks)

Where lookbacks include:

- Generation interval length (from latent process)
- All observation delay/shedding PMF lengths

Useful for inspection before building the model.

Returns:

Type Description
int

Minimum n_initialization_points needed to satisfy all components

Raises:

Type Description
ValueError

If the latent process is not configured or gen_int_rv is missing. If any observation process doesn't implement lookback_days().
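A worked instance of the formula, with made-up numbers. The generation interval PMF and the per-observation lookbacks below are assumptions chosen only to show the arithmetic.

```python
gen_int_pmf = [0.2, 0.5, 0.3]  # assumed 3-element generation interval PMF
observation_lookbacks = {"hospital": 14, "wastewater": 21}  # assumed lookbacks in days

# The generation interval contributes its length; each observation its own lookback.
lookbacks = [len(gen_int_pmf), *observation_lookbacks.values()]
n_initialization_points = max(lookbacks)
```

With these values the longest lookback belongs to the wastewater signal, so it alone determines the initialization length.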

Source code in pyrenew/model/pyrenew_builder.py
def compute_n_initialization_points(self) -> int:
    """
    Compute required n_initialization_points from all components.

    Formula: n_initialization_points = max(all lookbacks)

    Where lookbacks include:
    - Generation interval length (from latent process)
    - All observation delay/shedding PMF lengths

    Useful for inspection before building the model.

    Returns
    -------
    int
        Minimum n_initialization_points needed to satisfy all components

    Raises
    ------
    ValueError
        If latent process not configured or gen_int_rv missing.
        If any observation process doesn't implement lookback_days().
    """
    if self.latent_class is None:
        raise ValueError(
            "Must call configure_latent() before computing n_initialization_points"
        )

    # Get generation interval length from latent params
    gen_int_rv = self.latent_params.get("gen_int_rv")
    if gen_int_rv is None:
        raise ValueError("gen_int_rv is required in latent process parameters")

    # Start with generation interval lookback
    lookbacks = [len(gen_int_rv())]

    # Add lookback from each observation process
    for obs_process in self.observations.values():
        lookbacks.append(obs_process.lookback_days())

    # Formula: max(all lookbacks)
    # For generation interval (1-indexed): L-element PMF has max lag L days → need L init points
    # For delay distributions (0-indexed): L-element PMF has max delay L-1 days
    # We need at least max(all lookbacks) to satisfy the renewal equation extraction
    n_init = max(lookbacks)

    return n_init

configure_latent

configure_latent(
    latent_class: type[BaseLatentInfectionProcess], **params: dict[str, object]
) -> PyrenewBuilder

Configure the latent infection process.

Parameters:

Name Type Description Default
latent_class type[BaseLatentInfectionProcess]

Class to use for latent infections (e.g., PopulationInfections, SubpopulationInfections, or a custom implementation)

required
**params dict[str, object]

Parameters for latent class constructor (model structure). DO NOT include n_initialization_points - it will be computed automatically from observation processes. DO NOT include population structure params (subpop_fractions) - these are passed at sample/fit time to allow fitting to multiple jurisdictions.

{}

Returns:

Type Description
PyrenewBuilder

Self, for method chaining

Raises:

Type Description
ValueError

If n_initialization_points or population structure params are included

RuntimeError

If latent has already been configured
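The parameter guards can be sketched as below. This is an illustration under assumptions: `check_latent_params` is a hypothetical helper, and the contents of `SAMPLE_TIME_PARAMS` are guessed from the description above, since the library's `_SAMPLE_TIME_PARAMS` is not shown on this page.

```python
SAMPLE_TIME_PARAMS = frozenset({"subpop_fractions"})  # assumed contents


def check_latent_params(params):
    """Reject parameters that must not be supplied at configure time."""
    if "n_initialization_points" in params:
        raise ValueError(
            "Do not specify n_initialization_points; it is computed at build time."
        )
    found = SAMPLE_TIME_PARAMS.intersection(params)
    if found:
        raise ValueError(f"Do not specify {sorted(found)} at configure time.")
    return params


ok_params = check_latent_params({"gen_int_length": 7})

try:
    check_latent_params({"subpop_fractions": [0.6, 0.4]})
    rejected = False
except ValueError:
    rejected = True
```

Keeping population structure out of the configuration is what allows one built model to be reused across jurisdictions with different subpopulation fractions.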

Source code in pyrenew/model/pyrenew_builder.py
def configure_latent(
    self,
    latent_class: type[BaseLatentInfectionProcess],
    **params: dict[str, object],
) -> PyrenewBuilder:
    """
    Configure the latent infection process.

    Parameters
    ----------
    latent_class
        Class to use for latent infections (e.g., PopulationInfections,
        SubpopulationInfections, or a custom implementation)
    **params
        Parameters for latent class constructor (model structure).
        DO NOT include n_initialization_points - it will be computed
        automatically from observation processes.
        DO NOT include population structure params (subpop_fractions) -
        these are passed at sample/fit time to allow fitting to multiple
        jurisdictions.

    Returns
    -------
    PyrenewBuilder
        Self, for method chaining

    Raises
    ------
    ValueError
        If n_initialization_points or population structure params are included
    RuntimeError
        If latent has already been configured
    """
    if self.latent_class is not None:
        raise RuntimeError(
            "Latent process already configured. Create a new builder for "
            "a different configuration."
        )

    if "n_initialization_points" in params:
        raise ValueError(
            "Do not specify n_initialization_points - it will be computed "
            "automatically from observation processes and generation interval. "
            "Use PyrenewBuilder.build() to create the model with the correct value."
        )

    # Check for population structure params that should be at sample time
    sample_time_found = _SAMPLE_TIME_PARAMS.intersection(params.keys())
    if sample_time_found:
        raise ValueError(
            f"Do not specify {sample_time_found} at configure time. "
            "Population structure is passed at sample/fit time to allow "
            "fitting the same model to multiple jurisdictions."
        )

    self.latent_class = latent_class
    self.latent_params = params
    return self