
Observation

Observation processes for connecting infections to observed data.

BaseObservationProcess is the abstract base. Subclasses include:

  • CountObservation: Abstract base for count observations
  • PopulationCounts: Aggregate counts (e.g., admissions, deaths)
  • SubpopulationCounts: Subpopulation-level counts
  • MeasurementObservation: Abstract base for continuous subpopulation-level signals (e.g., wastewater)

All observation processes implement:

  • sample(): Sample observations given infections
  • infection_resolution(): Return "aggregate" or "subpop"
  • lookback_days(): Return the required infection history length in days
  • validate_data(): Validate observation data before inference

Noise models (CountNoise, MeasurementNoise) are composable: pass one to an observation constructor to control the output distribution.

BaseObservationProcess

BaseObservationProcess(name: str, temporal_pmf_rv: RandomVariable)

Bases: RandomVariable

Abstract base class for observation processes that use convolution with temporal distributions.

This class provides common functionality for connecting infections to observed data (e.g., hospital admissions, wastewater concentrations) through temporal convolution operations.

Key features provided:

  • PMF validation (sum to 1, non-negative)
  • Minimum observation day calculation
  • Convolution wrapper with timeline alignment
  • Deterministic quantity tracking

Subclasses must implement:

  • lookback_days(): Return PMF length for initialization
  • infection_resolution(): Return "aggregate" or "subpop"
  • _predicted_obs(): Transform infections to predicted values
  • sample(): Apply the noise model to predicted observations
  • validate_data(): Validate observation data before inference

Notes

Computing predicted observations on day t requires infection history from previous days (determined by the temporal PMF length). The first len(pmf) - 1 days have insufficient history and return NaN.
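A minimal pure-Python sketch of that convolution, showing why the first len(pmf) - 1 days come back as NaN. It assumes a scalar ascertainment rate and plain lists in place of JAX arrays; `predicted_counts` is a hypothetical name and does not reproduce the signature of `pyrenew.convolve.compute_delay_ascertained_incidence`.

```python
import math


def predicted_counts(infections, delay_pmf, ascertainment_rate):
    """Hypothetical helper: expected observations on each day.

    Day t needs infections on days t, t-1, ..., t-(L-1) where L = len(delay_pmf),
    so the first L - 1 days lack full history and are returned as NaN.
    """
    lag_max = len(delay_pmf) - 1
    out = []
    for t in range(len(infections)):
        if t < lag_max:
            out.append(math.nan)  # insufficient infection history
            continue
        # Sum over lags 0..L-1: infections d days ago weighted by P(delay = d).
        total = sum(delay_pmf[d] * infections[t - d] for d in range(len(delay_pmf)))
        out.append(ascertainment_rate * total)
    return out


counts = predicted_counts([100, 200, 300, 400], [0.5, 0.3, 0.2], 0.1)
```

With a 3-day PMF, days 0 and 1 are NaN; day 2 is 0.1 × (0.5·300 + 0.3·200 + 0.2·100) = 23.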

See Also

pyrenew.convolve.compute_delay_ascertained_incidence : Underlying convolution function
pyrenew.metaclass.RandomVariable : Base class for all random variables

Initialize base observation process.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names, enabling multiple observations of the same type in a single model.
  • temporal_pmf_rv (RandomVariable, required): The temporal distribution PMF (e.g., delay or shedding distribution). Must sample to a 1D array that sums to ~1.0 with non-negative values. Subclasses may have additional parameters.

Notes

Subclasses should call super().__init__(name, temporal_pmf_rv) in their constructors and may add additional parameters.
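The subclassing pattern in these Notes can be sketched with the standard library's abc module. `ToyObservationBase` and `ToyAdmissions` are hypothetical stand-ins, not pyrenew classes: in pyrenew the PMF is a RandomVariable called as `self.temporal_pmf_rv()`, whereas here it is a plain list so the sketch runs on its own.

```python
from abc import ABC, abstractmethod


class ToyObservationBase(ABC):
    """Hypothetical stand-in for BaseObservationProcess (not pyrenew's class)."""

    def __init__(self, name, temporal_pmf):
        self.name = name
        self.temporal_pmf = temporal_pmf  # plain list here, not a RandomVariable

    @abstractmethod
    def infection_resolution(self): ...

    @abstractmethod
    def lookback_days(self): ...


class ToyAdmissions(ToyObservationBase):
    def __init__(self, name, temporal_pmf, rate):
        super().__init__(name, temporal_pmf)  # base class stores shared state
        self.rate = rate  # subclass-specific parameter

    def infection_resolution(self):
        return "aggregate"

    def lookback_days(self):
        return len(self.temporal_pmf) - 1  # 0-indexed PMF covers lags 0..L-1


obs = ToyAdmissions("hosp", [0.2, 0.5, 0.3], rate=0.01)
```

Because the abstract methods are left unimplemented on the base, Python refuses to instantiate `ToyObservationBase` directly, which is the same guarantee the `abstractmethod` markers give the real base class.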

Source code in pyrenew/observation/base.py
def __init__(self, name: str, temporal_pmf_rv: RandomVariable) -> None:
    """
    Initialize base observation process.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names, enabling multiple
        observations of the same type in a single model.
    temporal_pmf_rv
        The temporal distribution PMF (e.g., delay or shedding distribution).
        Must sample to a 1D array that sums to ~1.0 with non-negative values.
        Subclasses may have additional parameters.

    Notes
    -----
    Subclasses should call ``super().__init__(name, temporal_pmf_rv)``
    in their constructors and may add additional parameters.
    """
    super().__init__(name=name)
    self.temporal_pmf_rv = temporal_pmf_rv

infection_resolution abstractmethod

infection_resolution() -> str

Return whether this observation uses aggregate or subpop infections.

Returns one of:

  • "aggregate": Uses a single aggregated infection trajectory. Shape: (n_days,)
  • "subpop": Uses subpopulation-level infection trajectories. Shape: (n_days, n_subpops), indexed via subpop_indices.

Returns:

  • str: Either "aggregate" or "subpop".

Notes

This is used by multi-signal models to route the correct infection output to each observation process.
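A hypothetical sketch of that routing step, using nested Python lists in place of arrays: an aggregate trajectory of shape (n_days,) and a subpopulation trajectory of shape (n_days, n_subpops). `route_infections` is an illustrative name, not a pyrenew function.

```python
def route_infections(resolution, aggregate, subpop):
    """Hypothetical router: pick the infection trajectory an observation needs.

    aggregate: list of length n_days
    subpop: list of n_days rows, each of length n_subpops
    """
    if resolution == "aggregate":
        return aggregate
    if resolution == "subpop":
        return subpop
    raise ValueError(f"unknown infection_resolution: {resolution!r}")


aggregate = [10, 12, 15]
subpop = [[4, 6], [5, 7], [7, 8]]  # each row sums to the aggregate value
```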

Source code in pyrenew/observation/base.py
@abstractmethod
def infection_resolution(self) -> str:
    """
    Return whether this observation uses aggregate or subpop infections.

    Returns one of:

    - ``"aggregate"``: Uses a single aggregated infection trajectory.
      Shape: ``(n_days,)``
    - ``"subpop"``: Uses subpopulation-level infection trajectories.
      Shape: ``(n_days, n_subpops)``, indexed via ``subpop_indices``.

    Returns
    -------
    str
        Either ``"aggregate"`` or ``"subpop"``

    Notes
    -----
    This is used by multi-signal models to route the correct infection
    output to each observation process.
    """
    pass  # pragma: no cover

lookback_days abstractmethod

lookback_days() -> int

Return the number of days this observation process needs to look back.

This determines the minimum n_initialization_points required by the latent process when this observation is included in a multi-signal model.

Returns:

  • int: Number of days of infection history required.

Notes

Delay/shedding PMFs are 0-indexed (effect can occur on day 0), so a PMF of length L covers lags 0 to L-1, requiring L-1 initialization points. Implementations should return len(pmf) - 1.

This is used by model builders to automatically compute n_initialization_points as max(gen_int_length, max(all lookbacks)).
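A worked example of that rule. The generation-interval length and PMF lengths below are illustrative assumptions, not pyrenew defaults, and `n_initialization_points` is a hypothetical helper.

```python
def n_initialization_points(gen_int_length, pmf_lengths):
    """Hypothetical helper mirroring max(gen_int_length, max(all lookbacks)).

    Each observation's lookback is len(pmf) - 1 because a PMF of length L
    covers lags 0..L-1.
    """
    lookbacks = [length - 1 for length in pmf_lengths]
    return max(gen_int_length, max(lookbacks))


# A 7-day generation interval, a 40-day admissions delay PMF, and a
# 15-day shedding PMF: the admissions lookback (39) dominates.
n_init = n_initialization_points(7, [40, 15])
```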

Source code in pyrenew/observation/base.py
@abstractmethod
def lookback_days(self) -> int:
    """
    Return the number of days this observation process needs to look back.

    This determines the minimum n_initialization_points required by the
    latent process when this observation is included in a multi-signal model.

    Returns
    -------
    int
        Number of days of infection history required.

    Notes
    -----
    Delay/shedding PMFs are 0-indexed (effect can occur on day 0), so a
    PMF of length L covers lags 0 to L-1, requiring L-1 initialization
    points. Implementations should return ``len(pmf) - 1``.

    This is used by model builders to automatically compute
    n_initialization_points as ``max(gen_int_length, max(all lookbacks))``.
    """
    pass  # pragma: no cover

sample abstractmethod

sample(obs: ArrayLike | None = None, **kwargs: object) -> ArrayLike

Sample from the observation process.

Subclasses must implement this method to define the specific observation model. Typically calls _predicted_obs first, then applies the noise model.

Parameters:

  • obs (ArrayLike | None, default None): Observed data for conditioning, or None for prior predictive sampling.
  • **kwargs (object, default {}): Subclass-specific parameters (e.g., infections from the infection process).

Returns:

  • ArrayLike: Observed or sampled values from the observation process.

Source code in pyrenew/observation/base.py
@abstractmethod
def sample(
    self,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ArrayLike:
    """
    Sample from the observation process.

    Subclasses must implement this method to define the specific
    observation model. Typically calls ``_predicted_obs`` first,
    then applies the noise model.

    Parameters
    ----------
    obs
        Observed data for conditioning, or None for prior predictive sampling.
    **kwargs
        Subclass-specific parameters (e.g., infections from the infection process).

    Returns
    -------
    ArrayLike
        Observed or sampled values from the observation process.
    """
    pass  # pragma: no cover

validate_data abstractmethod

validate_data(
    n_total: int, n_subpops: int, **obs_data: dict[str, object]
) -> None

Validate observation data before running inference.

Each observation process validates its own data requirements. Called by the model's validate_data() method with concrete (non-traced) values before JAX tracing begins.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations.
  • **obs_data (dict[str, object], default {}): Observation-specific data kwargs (the same as those passed to sample(), minus infections, which comes from the latent process).

Raises:

  • ValueError: If any data fails validation.
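A hypothetical validate_data-style check for measurement data, written in plain Python on concrete values. Running before JAX tracing is what allows ordinary Python comparisons and exceptions like these. The argument names `times` and `subpop_indices` follow MeasurementObservation.sample(); the exact checks pyrenew performs are not shown on this page, so treat these as illustrative.

```python
def validate_measurement_data(n_total, n_subpops, times, subpop_indices):
    """Hypothetical check run on concrete (non-traced) inputs before inference."""
    if len(times) != len(subpop_indices):
        raise ValueError("times and subpop_indices must have the same length")
    for t in times:
        if not 0 <= t < n_total:
            raise ValueError(f"time index {t} outside [0, {n_total})")
    for s in subpop_indices:
        if not 0 <= s < n_subpops:
            raise ValueError(f"subpop index {s} outside [0, {n_subpops})")


# Valid input: returns None without raising.
validate_measurement_data(n_total=50, n_subpops=3, times=[10, 20], subpop_indices=[0, 2])
```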

Source code in pyrenew/observation/base.py
@abstractmethod
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    **obs_data: dict[str, object],
) -> None:
    """
    Validate observation data before running inference.

    Each observation process validates its own data requirements.
    Called by the model's ``validate_data()`` method with concrete
    (non-traced) values before JAX tracing begins.

    Parameters
    ----------
    n_total
        Total number of time steps (n_init + n_days_post_init).
    n_subpops
        Number of subpopulations.
    **obs_data
        Observation-specific data kwargs (same as passed to ``sample()``,
        minus ``infections`` which comes from the latent process).

    Raises
    ------
    ValueError
        If any data fails validation.
    """
    pass  # pragma: no cover

CountNoise

Bases: ABC

Abstract base for count observation noise models.

Defines how discrete count observations are distributed around predicted values.

sample abstractmethod

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample count observations given predicted counts.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted count values (non-negative).
  • obs (ArrayLike | None, default None): Observed counts for conditioning, or None for prior sampling.
  • mask (ArrayLike | None, default None): Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. If provided, observations where mask is False are excluded from the likelihood.

Returns:

  • ArrayLike: Sampled or conditioned counts, same shape as predicted.

Notes

Implementations use numpyro.handlers.mask rather than the obs_mask parameter of numpyro.sample. This avoids creating latent variables for masked entries, which would fail with NUTS for discrete distributions.
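The effect of masking on the likelihood can be sketched without numpyro: masked-out entries contribute nothing to the total log-likelihood. This pure-Python illustration assumes a Poisson noise model and is not how pyrenew implements CountNoise, which wraps `numpyro.sample` in `numpyro.handlers.mask` as noted above.

```python
import math


def poisson_log_likelihood(predicted, obs, mask=None):
    """Sum of Poisson log-pmfs, skipping entries where mask is False."""
    if mask is None:
        mask = [True] * len(obs)  # no mask: every observation counts
    total = 0.0
    for lam, k, keep in zip(predicted, obs, mask):
        if not keep:
            continue  # excluded from the likelihood entirely
        # log P(K = k | lam) = k * log(lam) - lam - log(k!)
        total += k * math.log(lam) - lam - math.lgamma(k + 1)
    return total


full = poisson_log_likelihood([2.0, 3.0], [1, 4])
masked = poisson_log_likelihood([2.0, 3.0], [1, 4], mask=[True, False])
```

Dropping the second term this way leaves no latent variable behind, which is the property the note relies on for discrete distributions under NUTS.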

Source code in pyrenew/observation/noise.py
@abstractmethod
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample count observations given predicted counts.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted count values (non-negative).
    obs
        Observed counts for conditioning, or None for prior sampling.
    mask
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included. If provided,
        observations where mask is False are excluded from the likelihood.

    Returns
    -------
    ArrayLike
        Sampled or conditioned counts, same shape as predicted.

    Notes
    -----
    Implementations use ``numpyro.handlers.mask`` rather than the
    ``obs_mask`` parameter of ``numpyro.sample``. This avoids creating
    latent variables for masked entries, which would fail with NUTS
    for discrete distributions.
    """
    pass  # pragma: no cover

CountObservation

CountObservation(
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
    aggregation: Literal["daily", "weekly"] = "daily",
    reporting_schedule: Literal["regular", "irregular"] = "regular",
    start_dow: int | None = None,
)

Bases: BaseObservationProcess

Abstract base class for count observation processes.

Subclasses map infections to counts through an ascertainment × delay convolution with a composable noise model. Count observations always compute predictions on the model's daily time axis and then, if requested, aggregate those daily predictions to the observation reporting grid before evaluating the likelihood.

Initialize count observation base.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • ascertainment_rate_rv (RandomVariable, required): Ascertainment rate in [0, 1] (e.g., IHR, IER).
  • delay_distribution_rv (RandomVariable, required): Delay distribution PMF (must sum to ~1.0).
  • noise (CountNoise, required): Noise model for count observations (Poisson, NegBin, etc.).
  • right_truncation_rv (RandomVariable | None, default None): Optional reporting delay PMF for right-truncation adjustment. When provided (along with right_truncation_offset at sample time), predicted counts are scaled down for recent timepoints to account for incomplete reporting.
  • day_of_week_rv (RandomVariable | None, default None): Optional day-of-week multiplicative effect. Must sample to shape (7,) with non-negative values, where entry j is the multiplier for day-of-week j (0=Monday, 6=Sunday, ISO convention). An effect of 1.0 means no adjustment for that day. Values summing to 7.0 preserve weekly totals and keep the ascertainment rate interpretable; other sums rescale overall predicted counts. When provided (along with first_day_dow at sample time), predicted counts are scaled by a periodic weekly pattern.
  • aggregation (Literal['daily', 'weekly'], default 'daily'): Observation reporting cadence; one of "daily" or "weekly". Controls only the scale on which the count likelihood is evaluated; it does not control how often the latent Rt temporal process samples new parameters.
  • reporting_schedule (Literal['regular', 'irregular'], default 'regular'): Either "regular" (dense observation array, one entry per period, NaN for unobserved periods) or "irregular" (sparse observation array with user-supplied period-end time indices).
  • start_dow (int | None, default None): Day-of-week on which the weekly aggregation cycle begins (0=Monday, 6=Sunday, ISO convention). Use 6 for MMWR Sunday-Saturday epiweeks or 0 for ISO Monday-Sunday weeks. Required when aggregation == "weekly"; ignored otherwise. Daily predictions are bucketed into weeks starting on start_dow and summed before scoring.

Raises:

  • ValueError: If aggregation, reporting_schedule, or start_dow are invalid, or if a day-of-week effect is combined with aggregation == "weekly" (within-period structure is destroyed by aggregation).
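A hypothetical sketch of the day-of-week adjustment described for day_of_week_rv: entry j multiplies day-of-week j, and first_day_dow fixes the weekday of the first daily prediction. `apply_day_of_week` and the effect values are illustrative assumptions; the validation here only approximates whatever pyrenew's own check does.

```python
def apply_day_of_week(daily, effect, first_day_dow):
    """Hypothetical helper: scale daily predictions by a weekly pattern.

    effect[j] multiplies day-of-week j (0=Monday, ..., 6=Sunday);
    first_day_dow is the day-of-week of daily[0].
    """
    if len(effect) != 7 or any(e < 0 for e in effect):
        raise ValueError("effect must have 7 non-negative entries")
    return [x * effect[(first_day_dow + t) % 7] for t, x in enumerate(daily)]


# Weekday-heavy reporting; entries sum to 7.0, so weekly totals are preserved.
effect = [1.2, 1.2, 1.1, 1.1, 1.0, 0.7, 0.7]
adjusted = apply_day_of_week([10.0] * 14, effect, first_day_dow=6)  # starts on a Sunday
```

Over two full weeks the adjusted total stays at 140, illustrating why effects summing to 7.0 keep the ascertainment rate interpretable.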

Source code in pyrenew/observation/count_observations.py
def __init__(
    self,
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
    aggregation: Literal["daily", "weekly"] = "daily",
    reporting_schedule: Literal["regular", "irregular"] = "regular",
    start_dow: int | None = None,
) -> None:
    """
    Initialize count observation base.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    ascertainment_rate_rv
        Ascertainment rate in [0, 1] (e.g., IHR, IER).
    delay_distribution_rv
        Delay distribution PMF (must sum to ~1.0).
    noise
        Noise model for count observations (Poisson, NegBin, etc.).
    right_truncation_rv
        Optional reporting delay PMF for right-truncation adjustment.
        When provided (along with ``right_truncation_offset`` at sample
        time), predicted counts are scaled down for recent timepoints
        to account for incomplete reporting.
    day_of_week_rv : RandomVariable | None
        Optional day-of-week multiplicative effect. Must sample to
        shape (7,) with non-negative values, where entry j is the
        multiplier for day-of-week j (0=Monday, 6=Sunday, ISO
        convention). An effect of 1.0 means no adjustment for that
        day. Values summing to 7.0 preserve weekly totals and keep
        the ascertainment rate interpretable; other sums rescale
        overall predicted counts. When provided (along with
        ``first_day_dow`` at sample time), predicted counts are
        scaled by a periodic weekly pattern.
    aggregation
        Observation reporting cadence; one of ``"daily"`` or
        ``"weekly"``. Controls only the scale on which the count
        likelihood is evaluated; it does not control how often the
        latent Rt temporal process samples new parameters.
    reporting_schedule
        Either ``"regular"`` (dense observation array, one entry
        per period, NaN for unobserved periods) or ``"irregular"``
        (sparse observation array with user-supplied period-end
        time indices).
    start_dow
        Day-of-week on which the weekly aggregation cycle begins
        (0=Monday, 6=Sunday, ISO convention). Use ``6`` for
        MMWR Sunday-Saturday epiweeks or ``0`` for ISO
        Monday-Sunday weeks. Required when
        ``aggregation == "weekly"``; ignored otherwise. Daily
        predictions are bucketed into weeks starting on
        ``start_dow`` and summed before scoring.

    Raises
    ------
    ValueError
        If ``aggregation``, ``reporting_schedule``, or ``start_dow``
        are invalid, or if a day-of-week effect is combined with
        ``aggregation == "weekly"`` (within-period structure is
        destroyed by aggregation).
    """
    super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv)
    self.ascertainment_rate_rv = ascertainment_rate_rv
    self.noise = noise
    self.right_truncation_rv = right_truncation_rv
    self.day_of_week_rv = day_of_week_rv
    self._validate_aggregation_start_dow(aggregation, start_dow)
    if reporting_schedule not in self._SUPPORTED_SCHEDULES:
        raise ValueError(
            f"reporting_schedule must be one of {self._SUPPORTED_SCHEDULES}, "
            f"got {reporting_schedule!r}"
        )
    if aggregation == "weekly" and day_of_week_rv is not None:
        raise ValueError(
            "day_of_week_rv cannot be combined with aggregation == 'weekly'; "
            "aggregation destroys within-period structure."
        )
    self.aggregation = aggregation
    self.reporting_schedule = reporting_schedule
    self.start_dow = start_dow

aggregation_period property

aggregation_period: int

Width of the observation reporting period in days.

Returns:

  • int: 1 for daily aggregation, 7 for weekly.
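A hypothetical sketch of weekly aggregation with a 7-day aggregation period: daily predictions are bucketed into weeks that begin on start_dow and summed. How pyrenew treats partial weeks is not stated here; this sketch assumes leading days before the first start_dow and a trailing partial week are simply dropped.

```python
def aggregate_weekly(daily, first_day_dow, start_dow):
    """Hypothetical helper: sum daily predictions into weeks starting on start_dow.

    Assumption for this sketch: leading days before the first start_dow and a
    trailing partial week are dropped rather than scored.
    """
    period = 7  # aggregation_period for weekly reporting
    offset = (start_dow - first_day_dow) % period  # days until the first week start
    return [
        sum(daily[i:i + period])
        for i in range(offset, len(daily) - period + 1, period)
    ]


# 20 days starting on a Monday (0), bucketed into MMWR Sunday-Saturday weeks (6)
weekly = aggregate_weekly(list(range(20)), first_day_dow=0, start_dow=6)
```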

infection_resolution

infection_resolution() -> str

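Return required infection resolution. This base class does not implement it: calling it raises NotImplementedError, so concrete subclasses must override it.

Returns:

  • str: "aggregate" or "subpop".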

Source code in pyrenew/observation/count_observations.py
def infection_resolution(self) -> str:
    """
    Return required infection resolution.

    Returns
    -------
    str
        "aggregate" or "subpop".
    """
    raise NotImplementedError("Subclasses must implement infection_resolution()")

lookback_days

lookback_days() -> int

Return required lookback days for this observation.

Delay PMFs are 0-indexed (delay can be 0), so a PMF of length L covers delays 0 to L-1, requiring L-1 initialization points.

Returns:

  • int: Length of the delay distribution PMF minus 1.

Source code in pyrenew/observation/count_observations.py
def lookback_days(self) -> int:
    """
    Return required lookback days for this observation.

    Delay PMFs are 0-indexed (delay can be 0), so a PMF of length L
    covers delays 0 to L-1, requiring L-1 initialization points.

    Returns
    -------
    int
        Length of delay distribution PMF minus 1.
    """
    return len(self.temporal_pmf_rv()) - 1

validate

validate() -> None

Validate observation parameters.

Raises:

  • ValueError: If the delay PMF is invalid, the ascertainment rate is outside [0, 1], the noise parameters are invalid, or an optional right-truncation PMF or day-of-week effect is invalid.

Source code in pyrenew/observation/count_observations.py
def validate(self) -> None:
    """
    Validate observation parameters.

    Raises
    ------
    ValueError
        If delay PMF invalid, ascertainment rate outside [0,1],
        or noise params invalid.
    """
    delay_pmf = self.temporal_pmf_rv()
    self._validate_pmf(delay_pmf, "delay_distribution_rv")

    ascertainment_rate = self.ascertainment_rate_rv()
    if jnp.any(ascertainment_rate < 0) or jnp.any(ascertainment_rate > 1):
        raise ValueError(
            "ascertainment_rate_rv must be in [0, 1], "
            "got value(s) outside this range"
        )

    self.noise.validate()

    if self.right_truncation_rv is not None:
        rt_pmf = self.right_truncation_rv()
        self._validate_pmf(rt_pmf, "right_truncation_rv")

    if self.day_of_week_rv is not None:
        dow_effect = self.day_of_week_rv()
        self._validate_dow_effect(dow_effect, "day_of_week_rv")
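A hypothetical pure-Python sketch of the checks validate() performs on the delay PMF and ascertainment rate: a 1D, non-negative PMF summing to ~1.0, and a rate in [0, 1]. The helper names and the 1e-6 tolerance are assumptions; pyrenew's own `_validate_pmf` is not reproduced here.

```python
import math


def validate_pmf(pmf, name, atol=1e-6):
    """Hypothetical PMF check: 1D, non-negative, sums to ~1.0."""
    if any(isinstance(p, (list, tuple)) for p in pmf):
        raise ValueError(f"{name} must be 1D")
    if any(p < 0 for p in pmf):
        raise ValueError(f"{name} must be non-negative")
    if not math.isclose(sum(pmf), 1.0, abs_tol=atol):
        raise ValueError(f"{name} must sum to 1.0, got {sum(pmf)}")


def validate_rate(rate, name):
    """Hypothetical check that a rate lies in [0, 1]."""
    if not 0 <= rate <= 1:
        raise ValueError(f"{name} must be in [0, 1], got {rate}")


validate_pmf([0.2, 0.5, 0.3], "delay_distribution_rv")
validate_rate(0.01, "ascertainment_rate_rv")
```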

HierarchicalNormalNoise

HierarchicalNormalNoise(
    sensor_mode_rv: RandomVariable, sensor_sd_rv: RandomVariable
)

Bases: MeasurementNoise

Normal noise with hierarchical sensor-level effects.

Observation model: obs ~ Normal(predicted + sensor_mode, sensor_sd) where sensor_mode and sensor_sd are sampled per-sensor.

Parameters:

  • sensor_mode_rv (RandomVariable, required): Prior for sensor-level modes. Must implement sample(n_groups=...) -> ArrayLike.
  • sensor_sd_rv (RandomVariable, required): Prior for sensor-level SDs (must be > 0). Must implement sample(n_groups=...) -> ArrayLike.

Notes

Use VectorizedVariable to wrap simple RVs that lack this interface.

Initialize hierarchical Normal noise.

Source code in pyrenew/observation/noise.py
def __init__(
    self,
    sensor_mode_rv: RandomVariable,
    sensor_sd_rv: RandomVariable,
) -> None:
    """
    Initialize hierarchical Normal noise.

    Parameters
    ----------
    sensor_mode_rv
        Prior for sensor-level modes.
        Must implement ``sample(n_groups=...) -> ArrayLike``.
    sensor_sd_rv
        Prior for sensor-level SDs (must be > 0).
        Must implement ``sample(n_groups=...) -> ArrayLike``.
    """
    self.sensor_mode_rv = sensor_mode_rv
    self.sensor_sd_rv = sensor_sd_rv

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"HierarchicalNormalNoise("
        f"sensor_mode_rv={self.sensor_mode_rv!r}, "
        f"sensor_sd_rv={self.sensor_sd_rv!r})"
    )

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    *,
    sensor_indices: ArrayLike,
    n_sensors: int,
) -> ArrayLike

Sample from a Normal distribution with sensor-level hierarchical effects.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted measurement values. Shape: (n_obs,)
  • obs (ArrayLike | None, default None): Observed measurements for conditioning. Shape: (n_obs,)
  • sensor_indices (ArrayLike, required): Sensor index for each observation (0-indexed). Shape: (n_obs,)
  • n_sensors (int, required): Total number of sensors.

Returns:

  • ArrayLike: Normally distributed measurements with hierarchical sensor effects. Shape: (n_obs,)

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    *,
    sensor_indices: ArrayLike,
    n_sensors: int,
) -> ArrayLike:
    """
    Sample from Normal distribution with sensor-level hierarchical effects.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted measurement values.
        Shape: (n_obs,)
    obs
        Observed measurements for conditioning.
        Shape: (n_obs,)
    sensor_indices
        Sensor index for each observation (0-indexed).
        Shape: (n_obs,)
    n_sensors
        Total number of sensors.

    Returns
    -------
    ArrayLike
        Normal distributed measurements with hierarchical sensor effects.
        Shape: (n_obs,)
    """
    sensor_mode = self.sensor_mode_rv(n_groups=n_sensors)
    sensor_sd = self.sensor_sd_rv(n_groups=n_sensors)

    loc = predicted + sensor_mode[sensor_indices]
    scale = sensor_sd[sensor_indices]

    return numpyro.sample(name, dist.Normal(loc=loc, scale=scale), obs=obs)
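The indexing in sample() above can be mirrored as a plain log-density computation: each observation looks up its sensor's mode and SD through sensor_indices. In this sketch the per-sensor effects are passed in as fixed lists instead of being drawn from sensor_mode_rv and sensor_sd_rv; `hierarchical_normal_logpdf` is a hypothetical name.

```python
import math


def hierarchical_normal_logpdf(predicted, obs, sensor_indices, sensor_mode, sensor_sd):
    """Log-density of obs under Normal(predicted + sensor_mode[s], sensor_sd[s]).

    sensor_mode and sensor_sd hold one value per sensor; sensor_indices maps
    each observation to its sensor.
    """
    total = 0.0
    for mu, y, s in zip(predicted, obs, sensor_indices):
        loc = mu + sensor_mode[s]  # sensor-level shift of the prediction
        scale = sensor_sd[s]       # sensor-level noise scale
        z = (y - loc) / scale
        total += -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)
    return total


logp = hierarchical_normal_logpdf(
    predicted=[1.0, 2.0],
    obs=[1.5, 1.0],
    sensor_indices=[0, 1],
    sensor_mode=[0.5, -1.0],
    sensor_sd=[1.0, 2.0],
)
```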

MeasurementNoise

Bases: ABC

Abstract base for continuous measurement noise models.

Defines how continuous observations are distributed around predicted values.

sample abstractmethod

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ArrayLike

Sample continuous observations given predicted values.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted measurement values.
  • obs (ArrayLike | None, default None): Observed measurements for conditioning, or None for prior sampling.
  • **kwargs (object, default {}): Additional context (e.g., sensor indices).

Returns:

  • ArrayLike: Sampled or conditioned measurements, same shape as predicted.

Source code in pyrenew/observation/noise.py
@abstractmethod
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ArrayLike:
    """
    Sample continuous observations given predicted values.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted measurement values.
    obs
        Observed measurements for conditioning, or None for prior sampling.
    **kwargs
        Additional context (e.g., sensor indices).

    Returns
    -------
    ArrayLike
        Sampled or conditioned measurements, same shape as predicted.
    """
    pass  # pragma: no cover

MeasurementObservation

MeasurementObservation(
    name: str, temporal_pmf_rv: RandomVariable, noise: MeasurementNoise
)

Bases: BaseObservationProcess

Abstract base for continuous measurement observations.

Subclasses implement signal-specific transformations from infections to predicted measurement values, then add measurement noise.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • temporal_pmf_rv (RandomVariable, required): Temporal distribution PMF (e.g., shedding kinetics for wastewater).
  • noise (MeasurementNoise, required): Noise model for continuous measurements (e.g., HierarchicalNormalNoise).

Notes

Subclasses must implement _predicted_obs() according to their specific signal processing (e.g., wastewater shedding kinetics, dilution factors, etc.).

See Also

pyrenew.observation.noise.HierarchicalNormalNoise : Suitable noise model for sensor-level measurements
pyrenew.observation.base.BaseObservationProcess : Parent class with common observation utilities

Initialize measurement observation base.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • temporal_pmf_rv (RandomVariable, required): Temporal distribution PMF (e.g., shedding kinetics).
  • noise (MeasurementNoise, required): Noise model (e.g., HierarchicalNormalNoise with sensor effects).

Source code in pyrenew/observation/measurement_observations.py
def __init__(
    self,
    name: str,
    temporal_pmf_rv: RandomVariable,
    noise: MeasurementNoise,
) -> None:
    """
    Initialize measurement observation base.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    temporal_pmf_rv
        Temporal distribution PMF (e.g., shedding kinetics).
    noise
        Noise model (e.g., HierarchicalNormalNoise with sensor effects).
    """
    super().__init__(name=name, temporal_pmf_rv=temporal_pmf_rv)
    self.noise = noise

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/measurement_observations.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"{self.__class__.__name__}(name={self.name!r}, "
        f"temporal_pmf_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r})"
    )

infection_resolution

infection_resolution() -> str

Return "subpop" for measurement observations.

Measurement observations require subpopulation-level infections because each measurement corresponds to a specific catchment area.

Returns:

  • str: "subpop"

Source code in pyrenew/observation/measurement_observations.py
def infection_resolution(self) -> str:
    """
    Return "subpop" for measurement observations.

    Measurement observations require subpopulation-level infections
    because each measurement corresponds to a specific catchment area.

    Returns
    -------
    str
        ``"subpop"``
    """
    return "subpop"

lookback_days

lookback_days() -> int

Return required lookback days for this observation.

Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF of length L covers lags 0 to L-1, requiring L-1 initialization points.

Returns:

  • int: Length of the temporal PMF minus 1.

Source code in pyrenew/observation/measurement_observations.py
def lookback_days(self) -> int:
    """
    Return required lookback days for this observation.

    Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF
    of length L covers lags 0 to L-1, requiring L-1 initialization points.

    Returns
    -------
    int
        Length of temporal PMF minus 1.
    """
    return len(self.temporal_pmf_rv()) - 1

sample

sample(
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    sensor_indices: ArrayLike,
    n_sensors: int,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ObservationSample

Sample measurements from observed sensors.

Times are on the shared time axis [0, n_total) where n_total = n_init + n_days. This method performs direct indexing without any offset adjustment.

Transforms infections to predicted values via signal-specific processing (_predicted_obs), then applies noise model.

Parameters:

Name Type Description Default
infections ArrayLike

Infections from the infection process. Shape: (n_total, n_subpops)

required
times ArrayLike

Day index for each observation on the shared time axis. Must be in range [0, n_total). Shape: (n_obs,)

required
subpop_indices ArrayLike

Subpopulation index for each observation (0-indexed). Shape: (n_obs,)

required
sensor_indices ArrayLike

Sensor index for each observation (0-indexed). Shape: (n_obs,)

required
n_sensors int

Total number of measurement sensors.

required
obs ArrayLike | None

Observed measurements (n_obs,), or None for prior sampling.

None
**kwargs object

Additional keyword arguments forwarded by the model dispatch (e.g., first_day_dow); ignored here because measurement observations index the shared axis directly.

{}

Returns:

Type Description
ObservationSample

Named tuple with observed (sampled/conditioned measurements) and predicted (predicted values before noise, shape: n_total x n_subpops).

Source code in pyrenew/observation/measurement_observations.py
def sample(
    self,
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    sensor_indices: ArrayLike,
    n_sensors: int,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ObservationSample:
    """
    Sample measurements from observed sensors.

    Times are on the shared time axis [0, n_total) where
    n_total = n_init + n_days. This method performs direct indexing
    without any offset adjustment.

    Transforms infections to predicted values via signal-specific processing
    (``_predicted_obs``), then applies noise model.

    Parameters
    ----------
    infections
        Infections from the infection process.
        Shape: (n_total, n_subpops)
    times
        Day index for each observation on the shared time axis.
        Must be in range [0, n_total). Shape: (n_obs,)
    subpop_indices
        Subpopulation index for each observation (0-indexed).
        Shape: (n_obs,)
    sensor_indices
        Sensor index for each observation (0-indexed).
        Shape: (n_obs,)
    n_sensors
        Total number of measurement sensors.
    obs
        Observed measurements (n_obs,), or None for prior sampling.
    **kwargs
        Additional keyword arguments forwarded by the model
        dispatch (e.g., ``first_day_dow``); ignored here because
        measurement observations index the shared axis directly.

    Returns
    -------
    ObservationSample
        Named tuple with `observed` (sampled/conditioned measurements) and
        `predicted` (predicted values before noise, shape: n_total x n_subpops).
    """
    predicted_values = self._predicted_obs(infections)
    self._deterministic("predicted", predicted_values)

    # Direct indexing on shared time axis - no offset needed
    predicted_obs = predicted_values[times, subpop_indices]

    observed = self.noise.sample(
        name=self._sample_site_name("obs"),
        predicted=predicted_obs,
        obs=obs,
        sensor_indices=sensor_indices,
        n_sensors=n_sensors,
    )

    return ObservationSample(observed=observed, predicted=predicted_values)
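Because times already lives on the shared axis, the gather step in sample is a plain (time, subpopulation) lookup into the prediction grid. The sketch below reproduces that lookup with Python lists instead of JAX arrays. gather_predicted and the per-sensor sensor_scale list are hypothetical illustrations, not pyrenew API; how a concrete noise model actually consumes sensor_indices and n_sensors is not documented in this section.

```python
from typing import Sequence


def gather_predicted(
    predicted: Sequence[Sequence[float]],
    times: Sequence[int],
    subpop_indices: Sequence[int],
) -> list[float]:
    """Return predicted[t][s] for each (t, s) observation pair.

    Hypothetical helper mirroring predicted_values[times, subpop_indices]
    on a (n_total, n_subpops) grid. Times are already on the shared axis
    [0, n_total), so no offset is applied.
    """
    if len(times) != len(subpop_indices):
        raise ValueError("times and subpop_indices must have the same length")
    return [predicted[t][s] for t, s in zip(times, subpop_indices)]


# Shared axis with n_init = 2 and n_days = 3, so n_total = 5; two subpops.
predicted = [
    [0.0, 0.0],
    [1.0, 10.0],
    [2.0, 20.0],
    [3.0, 30.0],
    [4.0, 40.0],
]
times = [2, 2, 4]           # day index on the shared axis
subpop_indices = [0, 1, 1]  # catchment for each measurement
sensor_indices = [0, 1, 2]  # one sensor per measurement in this toy example

per_obs = gather_predicted(predicted, times, subpop_indices)  # [2.0, 20.0, 40.0]

# Hypothetical per-sensor noise scale, looked up the same way by sensor index.
sensor_scale = [0.1, 0.2, 0.3]
per_obs_scale = [sensor_scale[k] for k in sensor_indices]  # [0.1, 0.2, 0.3]
```

The full prediction grid is still what gets recorded as the predicted deterministic site; only the gathered per-observation values are passed to the noise model.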

validate_data

validate_data(
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    sensor_indices: ArrayLike | None = None,
    n_sensors: int | None = None,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> None

Validate measurement observation data.

Parameters:

Name Type Description Default
n_total int

Total number of time steps (n_init + n_days_post_init).

required
n_subpops int

Number of subpopulations.

required
times ArrayLike | None

Day index for each observation on the shared time axis.

None
subpop_indices ArrayLike | None

Subpopulation index for each observation (0-indexed).

None
sensor_indices ArrayLike | None

Sensor index for each observation (0-indexed).

None
n_sensors int | None

Total number of measurement sensors.

None
obs ArrayLike | None

Observed measurements (n_obs,).

None
**kwargs object

Additional keyword arguments (ignored).

{}

Raises:

Type Description
ValueError

If times, subpop_indices, or sensor_indices are out of bounds, or if obs and times have mismatched lengths.

Source code in pyrenew/observation/measurement_observations.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    sensor_indices: ArrayLike | None = None,
    n_sensors: int | None = None,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> None:
    """
    Validate measurement observation data.

    Parameters
    ----------
    n_total
        Total number of time steps (n_init + n_days_post_init).
    n_subpops
        Number of subpopulations.
    times
        Day index for each observation on the shared time axis.
    subpop_indices
        Subpopulation index for each observation (0-indexed).
    sensor_indices
        Sensor index for each observation (0-indexed).
    n_sensors
        Total number of measurement sensors.
    obs
        Observed measurements (n_obs,).
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If times, subpop_indices, or sensor_indices are out of bounds,
        or if obs and times have mismatched lengths.
    """
    if times is not None:
        self._validate_times(times, n_total)
        if obs is not None:
            self._validate_shapes_match(obs, times, "obs", "times")
    if subpop_indices is not None:
        self._validate_subpop_indices(subpop_indices, n_subpops)
    if sensor_indices is not None and n_sensors is not None:
        self._validate_index_array(sensor_indices, n_sensors, "sensor_indices")
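The branching in validate_data can be restated without pyrenew's private helpers. The sketch below is a plain-Python approximation under the assumption that each index check means "every value lies in [0, upper)"; check_index_array and validate_measurement_data are hypothetical names, and the real _validate_times, _validate_subpop_indices, and _validate_index_array may check more (for example dtype or dimensionality).

```python
from typing import Optional, Sequence


def check_index_array(values: Sequence[int], upper: int, label: str) -> None:
    """Raise ValueError unless every value lies in [0, upper)."""
    bad = [v for v in values if not 0 <= v < upper]
    if bad:
        raise ValueError(f"{label} must be in [0, {upper}), got {bad}")


def validate_measurement_data(
    n_total: int,
    n_subpops: int,
    times: Optional[Sequence[int]] = None,
    subpop_indices: Optional[Sequence[int]] = None,
    sensor_indices: Optional[Sequence[int]] = None,
    n_sensors: Optional[int] = None,
    obs: Optional[Sequence[float]] = None,
) -> None:
    """Hypothetical mirror of validate_data: each check runs only if its inputs are given."""
    if times is not None:
        check_index_array(times, n_total, "times")
        # obs is compared with times only when times is present.
        if obs is not None and len(obs) != len(times):
            raise ValueError(
                f"obs length {len(obs)} does not match times length {len(times)}"
            )
    if subpop_indices is not None:
        check_index_array(subpop_indices, n_subpops, "subpop_indices")
    # Sensor indices need both the indices and the sensor count.
    if sensor_indices is not None and n_sensors is not None:
        check_index_array(sensor_indices, n_sensors, "sensor_indices")
```

As in the listed source, an obs array supplied without times is not length-checked, and sensor_indices supplied without n_sensors are not bounds-checked.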

NegativeBinomialNoise

NegativeBinomialNoise(concentration_rv: RandomVariable)

Bases: CountNoise

Negative Binomial noise for overdispersed counts (variance > mean).

Uses NB2 parameterization. Higher concentration reduces overdispersion.

Parameters:

Name Type Description Default
concentration_rv RandomVariable

Concentration parameter (must be > 0). Higher values reduce overdispersion.

required
Notes

The NB2 parameterization has variance = mean + mean^2 / concentration. As concentration -> infinity, this approaches Poisson.
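The variance formula in the Notes can be checked numerically. The sketch below evaluates the NB2 variance and maps (mean, concentration) onto the "number of successes n, success probability p" parameterization used by many other libraries, assuming n = concentration and p = concentration / (concentration + mean). The helper names are hypothetical; this is arithmetic on the stated formula, not pyrenew code.

```python
def nb2_variance(mean: float, concentration: float) -> float:
    """NB2 variance: mean + mean**2 / concentration (concentration > 0)."""
    if concentration <= 0:
        raise ValueError("concentration must be > 0")
    return mean + mean**2 / concentration


def nb2_to_n_p(mean: float, concentration: float) -> tuple[float, float]:
    """Map NB2 (mean, concentration) to the (n, p) successes parameterization.

    With n = concentration and p = concentration / (concentration + mean),
    n * (1 - p) / p recovers the mean and n * (1 - p) / p**2 recovers
    nb2_variance(mean, concentration).
    """
    n = concentration
    p = concentration / (concentration + mean)
    return n, p


# Strong overdispersion: variance is three times the mean.
var_small_k = nb2_variance(10.0, 5.0)   # 30.0
# Large concentration: variance is close to the Poisson value (= mean).
var_large_k = nb2_variance(10.0, 1e6)   # about 10.0001

n, p = nb2_to_n_p(10.0, 5.0)
implied_mean = n * (1 - p) / p          # about 10.0
implied_var = n * (1 - p) / p**2        # about 30.0
```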

Initialize Negative Binomial noise.

Parameters:

Name Type Description Default
concentration_rv RandomVariable

Concentration parameter (must be > 0). Higher values reduce overdispersion.

required
Source code in pyrenew/observation/noise.py
def __init__(self, concentration_rv: RandomVariable) -> None:
    """
    Initialize Negative Binomial noise.

    Parameters
    ----------
    concentration_rv
        Concentration parameter (must be > 0).
        Higher values reduce overdispersion.
    """
    self.concentration_rv = concentration_rv

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"NegativeBinomialNoise(concentration_rv={self.concentration_rv!r})"

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample from Negative Binomial distribution.

Parameters:

Name Type Description Default
name str

Numpyro sample site name.

required
predicted ArrayLike

Predicted count values.

required
obs ArrayLike | None

Observed counts for conditioning.

None
mask ArrayLike | None

Boolean mask indicating which observations to include in the likelihood. If None, all observations are included.

None

Returns:

Type Description
ArrayLike

Negative Binomial-distributed counts.

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample from Negative Binomial distribution.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted count values.
    obs
        Observed counts for conditioning.
    mask
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included.

    Returns
    -------
    ArrayLike
        Negative Binomial-distributed counts.
    """
    concentration = self.concentration_rv()
    with numpyro.handlers.mask(mask=True if mask is None else mask):
        return numpyro.sample(
            name,
            dist.NegativeBinomial2(
                mean=predicted,
                concentration=concentration,
            ),
            obs=obs,
        )
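Wrapping the sample site in numpyro.handlers.mask means entries whose mask is False contribute nothing to the log density, while mask=None scores every entry. The plain-Python sketch below reproduces that accounting with the NB2 log-pmf written out via math.lgamma; nb2_logpmf and masked_nb2_loglik are hypothetical helpers for illustration and assume mean > 0.

```python
import math
from typing import Optional, Sequence


def nb2_logpmf(y: int, mean: float, concentration: float) -> float:
    """log P(Y = y) for NB2 with the given mean and concentration (both > 0)."""
    phi = concentration
    return (
        math.lgamma(y + phi)
        - math.lgamma(phi)
        - math.lgamma(y + 1)
        + phi * math.log(phi / (phi + mean))
        + y * math.log(mean / (phi + mean))
    )


def masked_nb2_loglik(
    obs: Sequence[int],
    predicted: Sequence[float],
    concentration: float,
    mask: Optional[Sequence[bool]] = None,
) -> float:
    """Sum NB2 log-likelihood terms, skipping entries where mask is False.

    mask=None includes every entry, mirroring `True if mask is None else mask`.
    """
    if mask is None:
        mask = [True] * len(obs)
    return sum(
        nb2_logpmf(y, mu, concentration)
        for y, mu, keep in zip(obs, predicted, mask)
        if keep
    )
```

For instance, masked_nb2_loglik([0, 5], [2.0, 2.0], 2.0, mask=[True, False]) scores only the first observation.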

validate_concentration_rv

validate_concentration_rv() -> None

Validate concentration is positive.

Raises:

Type Description
ValueError

If concentration <= 0.

Source code in pyrenew/observation/noise.py
def validate_concentration_rv(self) -> None:
    """
    Validate concentration is positive.

    Raises
    ------
    ValueError
        If concentration <= 0.
    """
    concentration = self.concentration_rv()
    if jnp.any(concentration <= 0):
        raise ValueError(
            f"NegativeBinomialNoise: concentration must be positive, "
            f"got {float(concentration)}"
        )

NegativeBinomialObservation

NegativeBinomialObservation(name: str, concentration_rv: RandomVariable)

Bases: RandomVariable

Negative Binomial observation.

Default constructor.

Parameters:

Name Type Description Default
name str

Name for the numpyro variable.

required
concentration_rv RandomVariable

Random variable from which to sample the positive concentration parameter of the negative binomial. This parameter is sometimes called k, phi, or the "dispersion" or "overdispersion" parameter, despite the fact that larger values imply that the distribution becomes more Poissonian, while smaller ones imply a greater degree of dispersion.

required

Returns:

Type Description
None
Source code in pyrenew/observation/negativebinomial.py
def __init__(
    self,
    name: str,
    concentration_rv: RandomVariable,
) -> None:
    """
    Default constructor

    Parameters
    ----------
    name
        Name for the numpyro variable.
    concentration_rv
        Random variable from which to sample the positive concentration
        parameter of the negative binomial. This parameter is sometimes
        called k, phi, or the "dispersion" or "overdispersion" parameter,
        despite the fact that larger values imply that the distribution
        becomes more Poissonian, while smaller ones imply a greater degree
        of dispersion.

    Returns
    -------
    None
    """
    super().__init__(name=name)
    self.concentration_rv = concentration_rv

sample

sample(
    mu: ArrayLike, obs: ArrayLike | None = None, **kwargs: object
) -> ArrayLike

Sample from the negative binomial distribution.

Parameters:

Name Type Description Default
mu ArrayLike

Mean parameter of the negative binomial distribution.

required
obs ArrayLike | None

Observed data, by default None.

None
**kwargs object

Additional keyword arguments. The implementation listed below does not forward them to any internal sample call.

{}

Returns:

Type Description
ArrayLike
Source code in pyrenew/observation/negativebinomial.py
def sample(
    self,
    mu: ArrayLike,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ArrayLike:
    """
    Sample from the negative binomial distribution

    Parameters
    ----------
    mu
        Mean parameter of the negative binomial distribution.
    obs
        Observed data, by default None.
    **kwargs
        Additional keyword arguments passed through to internal sample calls, should there be any.

    Returns
    -------
    ArrayLike
    """
    concentration = self.concentration_rv.sample()

    negative_binomial_sample = numpyro.sample(
        name=self.name,
        fn=dist.NegativeBinomial2(
            mean=mu,
            concentration=concentration,
        ),
        obs=obs,
    )

    return negative_binomial_sample

ObservationSample

Bases: NamedTuple

Return type for observation process sample() methods.

Attributes:

Name Type Description
observed ArrayLike

Sampled or conditioned observations. Shape depends on the observation process and indexing.

predicted ArrayLike

Predicted values before noise is applied. Useful for diagnostics and posterior predictive checks.
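Since ObservationSample is a NamedTuple, results can be read by field name or unpacked positionally. The stand-in below, ObservationSampleSketch, is hypothetical: it mirrors the two documented fields with plain lists instead of JAX arrays and assumes the field order given in the Attributes table (observed, then predicted).

```python
from typing import NamedTuple, Sequence


class ObservationSampleSketch(NamedTuple):
    """Hypothetical stand-in for ObservationSample with the two documented fields."""

    observed: Sequence[float]   # sampled or conditioned observations
    predicted: Sequence[float]  # predictions before noise is applied


result = ObservationSampleSketch(observed=[3.0, 5.0], predicted=[2.5, 4.8])

# Access by field name, e.g. for posterior predictive checks...
obs_values = result.observed
# ...or unpack positionally, like any tuple.
observed, predicted = result
```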

PoissonNoise

PoissonNoise()

Bases: CountNoise

Poisson noise for equidispersed counts (variance = mean).

Initialize Poisson noise (no parameters).

Source code in pyrenew/observation/noise.py
def __init__(self) -> None:
    """Initialize Poisson noise (no parameters)."""
    pass

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return "PoissonNoise()"

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample from Poisson distribution.

Parameters:

Name Type Description Default
name str

Numpyro sample site name.

required
predicted ArrayLike

Predicted count values.

required
obs ArrayLike | None

Observed counts for conditioning.

None
mask ArrayLike | None

Boolean mask indicating which observations to include in the likelihood. If None, all observations are included.

None

Returns:

Type Description
ArrayLike

Poisson-distributed counts.

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample from Poisson distribution.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted count values.
    obs
        Observed counts for conditioning.
    mask
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included.

    Returns
    -------
    ArrayLike
        Poisson-distributed counts.
    """
    with numpyro.handlers.mask(mask=True if mask is None else mask):
        return numpyro.sample(
            name,
            dist.Poisson(rate=predicted),
            obs=obs,
        )
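PoissonNoise applies the same masking rule as the negative binomial noise but has no parameters to sample. The sketch below scores a dense observation array in which NaN marks unobserved entries, deriving the mask from the NaN positions. Building the mask this way is an assumption about how a caller (such as the "regular" reporting schedule) might construct it, not a description of pyrenew internals; poisson_logpmf and masked_poisson_loglik are hypothetical helpers.

```python
import math
from typing import Sequence


def poisson_logpmf(y: int, rate: float) -> float:
    """log P(Y = y) for Y ~ Poisson(rate), rate > 0."""
    return y * math.log(rate) - rate - math.lgamma(y + 1)


def masked_poisson_loglik(obs: Sequence[float], predicted: Sequence[float]) -> float:
    """Score only the non-NaN entries of a dense, NaN-padded obs array.

    The mask is True where obs is not NaN; masked-out entries contribute
    nothing, so their predicted values never enter the likelihood.
    """
    mask = [not math.isnan(y) for y in obs]
    return sum(
        poisson_logpmf(int(y), mu)
        for y, mu, keep in zip(obs, predicted, mask)
        if keep
    )


# The middle period is unobserved, so its prediction (100.0) is ignored.
loglik = masked_poisson_loglik([2.0, float("nan"), 0.0], [2.0, 100.0, 1.0])
```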

validate staticmethod

validate() -> None

PoissonNoise always passes validation.

Source code in pyrenew/observation/noise.py
@staticmethod
def validate() -> None:
    """
    PoissonNoise always passes validation.
    """
    pass

PopulationCounts

PopulationCounts(
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
    aggregation: Literal["daily", "weekly"] = "daily",
    reporting_schedule: Literal["regular", "irregular"] = "regular",
    start_dow: int | None = None,
)

Bases: CountObservation

Aggregated count observation.

Maps aggregate infections to counts through an ascertainment × delay convolution with a composable noise model. Predictions are constructed on the daily model axis; the aggregation setting controls whether those predictions are scored as daily counts or summed to weekly counts before the likelihood is evaluated.
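The ascertainment × delay mapping can be written out directly. The sketch below assumes a scalar ascertainment rate and a fixed delay PMF, and applies the base-class conventions that the PMF is non-negative and sums to 1 and that the first len(pmf) - 1 days lack full history and are NaN. pyrenew's own routine is pyrenew.convolve.compute_delay_ascertained_incidence operating on JAX arrays; delay_ascertained here is a hypothetical plain-Python illustration.

```python
import math
from typing import Sequence


def delay_ascertained(
    infections: Sequence[float],
    ascertainment_rate: float,
    delay_pmf: Sequence[float],
) -> list[float]:
    """predicted[t] = rate * sum_d delay_pmf[d] * infections[t - d].

    Days t < len(delay_pmf) - 1 lack full infection history and are NaN.
    """
    if abs(sum(delay_pmf) - 1.0) > 1e-6 or any(p < 0 for p in delay_pmf):
        raise ValueError("delay_pmf must be non-negative and sum to ~1")
    lookback = len(delay_pmf) - 1
    out = []
    for t in range(len(infections)):
        if t < lookback:
            out.append(math.nan)  # insufficient history for this day
            continue
        expected = sum(delay_pmf[d] * infections[t - d] for d in range(len(delay_pmf)))
        out.append(ascertainment_rate * expected)
    return out


# A 3-day delay PMF (lags 0, 1, 2) needs 2 days of history.
predicted = delay_ascertained([100.0, 100.0, 200.0, 200.0, 400.0], 0.01, [0.5, 0.3, 0.2])
# approximately [nan, nan, 1.5, 1.8, 3.0]
```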

Parameters:

Name Type Description Default
name str

Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names (e.g., "hospital" produces sites "hospital_obs", "hospital_predicted").

required
ascertainment_rate_rv RandomVariable

Ascertainment rate in [0, 1] (e.g., IHR, IER).

required
delay_distribution_rv RandomVariable

Delay distribution PMF (must sum to ~1.0).

required
noise CountNoise

Noise model (PoissonNoise, NegativeBinomialNoise, etc.).

required
Source code in pyrenew/observation/count_observations.py
def __init__(
    self,
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
    aggregation: Literal["daily", "weekly"] = "daily",
    reporting_schedule: Literal["regular", "irregular"] = "regular",
    start_dow: int | None = None,
) -> None:
    """
    Initialize count observation base.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    ascertainment_rate_rv
        Ascertainment rate in [0, 1] (e.g., IHR, IER).
    delay_distribution_rv
        Delay distribution PMF (must sum to ~1.0).
    noise
        Noise model for count observations (Poisson, NegBin, etc.).
    right_truncation_rv
        Optional reporting delay PMF for right-truncation adjustment.
        When provided (along with ``right_truncation_offset`` at sample
        time), predicted counts are scaled down for recent timepoints
        to account for incomplete reporting.
    day_of_week_rv : RandomVariable | None
        Optional day-of-week multiplicative effect. Must sample to
        shape (7,) with non-negative values, where entry j is the
        multiplier for day-of-week j (0=Monday, 6=Sunday, ISO
        convention). An effect of 1.0 means no adjustment for that
        day. Values summing to 7.0 preserve weekly totals and keep
        the ascertainment rate interpretable; other sums rescale
        overall predicted counts. When provided (along with
        ``first_day_dow`` at sample time), predicted counts are
        scaled by a periodic weekly pattern.
    aggregation
        Observation reporting cadence; one of ``"daily"`` or
        ``"weekly"``. Controls only the scale on which the count
        likelihood is evaluated; it does not control how often the
        latent Rt temporal process samples new parameters.
    reporting_schedule
        Either ``"regular"`` (dense observation array, one entry
        per period, NaN for unobserved periods) or ``"irregular"``
        (sparse observation array with user-supplied period-end
        time indices).
    start_dow
        Day-of-week on which the weekly aggregation cycle begins
        (0=Monday, 6=Sunday, ISO convention). Use ``6`` for
        MMWR Sunday-Saturday epiweeks or ``0`` for ISO
        Monday-Sunday weeks. Required when
        ``aggregation == "weekly"``; ignored otherwise. Daily
        predictions are bucketed into weeks starting on
        ``start_dow`` and summed before scoring.

    Raises
    ------
    ValueError
        If ``aggregation``, ``reporting_schedule``, or ``start_dow``
        are invalid, or if a day-of-week effect is combined with
        ``aggregation == "weekly"`` (within-period structure is
        destroyed by aggregation).
    """
    super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv)
    self.ascertainment_rate_rv = ascertainment_rate_rv
    self.noise = noise
    self.right_truncation_rv = right_truncation_rv
    self.day_of_week_rv = day_of_week_rv
    self._validate_aggregation_start_dow(aggregation, start_dow)
    if reporting_schedule not in self._SUPPORTED_SCHEDULES:
        raise ValueError(
            f"reporting_schedule must be one of {self._SUPPORTED_SCHEDULES}, "
            f"got {reporting_schedule!r}"
        )
    if aggregation == "weekly" and day_of_week_rv is not None:
        raise ValueError(
            "day_of_week_rv cannot be combined with aggregation == 'weekly'; "
            "aggregation destroys within-period structure."
        )
    self.aggregation = aggregation
    self.reporting_schedule = reporting_schedule
    self.start_dow = start_dow
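The day-of-week effect described for day_of_week_rv is a periodic multiplier aligned to the shared axis by first_day_dow. The sketch below assumes element t of the daily axis falls on ISO day (first_day_dow + t) % 7 and is scaled by the matching entry of a length-7 effect vector; apply_day_of_week is a hypothetical helper, not pyrenew's implementation.

```python
from typing import Sequence


def apply_day_of_week(
    predicted: Sequence[float],
    dow_effect: Sequence[float],
    first_day_dow: int,
) -> list[float]:
    """Scale daily predictions by a periodic weekly multiplier.

    dow_effect[j] multiplies ISO day-of-week j (0=Monday, 6=Sunday);
    element t of the daily axis is assumed to fall on (first_day_dow + t) % 7.
    """
    if len(dow_effect) != 7 or any(v < 0 for v in dow_effect):
        raise ValueError("dow_effect must have 7 non-negative entries")
    if not 0 <= first_day_dow <= 6:
        raise ValueError("first_day_dow must be in 0..6")
    return [
        value * dow_effect[(first_day_dow + t) % 7]
        for t, value in enumerate(predicted)
    ]


# Weekday reporting boosted, weekend reporting suppressed; entries sum to 7.0,
# so any 7 consecutive days keep the same total as the unadjusted series.
dow_effect = [1.2, 1.1, 1.1, 1.1, 1.1, 0.7, 0.7]
flat = [10.0] * 14  # two weeks of constant predictions
adjusted = apply_day_of_week(flat, dow_effect, first_day_dow=6)  # axis starts on a Sunday
```

Because the constructor rejects a day-of-week effect combined with weekly aggregation, this adjustment only matters when counts are scored daily.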

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/count_observations.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"PopulationCounts(name={self.name!r}, "
        f"ascertainment_rate_rv={self.ascertainment_rate_rv!r}, "
        f"delay_distribution_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r}, "
        f"right_truncation_rv={self.right_truncation_rv!r}, "
        f"day_of_week_rv={self.day_of_week_rv!r})"
    )

infection_resolution

infection_resolution() -> str

Return "aggregate" for aggregated observations.

Returns:

Type Description
str

The string "aggregate".

Source code in pyrenew/observation/count_observations.py
def infection_resolution(self) -> str:
    """
    Return "aggregate" for aggregated observations.

    Returns
    -------
    str
        The string "aggregate".
    """
    return "aggregate"

sample

sample(
    infections: ArrayLike,
    obs: ArrayLike | None = None,
    right_truncation_offset: int | None = None,
    first_day_dow: int | None = None,
    period_end_times: ArrayLike | None = None,
) -> ObservationSample

Sample aggregated counts.

Daily transforms (right-truncation, day-of-week) run on the daily axis. When aggregation == "weekly" the daily predictions are summed onto the reporting-period grid before the noise model. Likelihood path depends on reporting_schedule: "regular" uses a dense-with-NaN array plus a mask; "irregular" fancy-indexes the aggregated array at period indices derived from period_end_times.

The aggregation setting describes the observation scale only. The latent infection process may use a daily or coarser Rt parameter cadence, but by the time this method is called, infections are supplied on the daily model axis.

Parameters:

Name Type Description Default
infections ArrayLike

Aggregate infections from the infection process. Shape (n_total,).

required
obs ArrayLike | None

Observed counts. Shape depends on reporting_schedule: "regular" expects a dense array on the period grid with NaN for unobserved periods; "irregular" expects an array of the same length as period_end_times. None for prior predictive sampling.

None
right_truncation_offset int | None

If provided (and right_truncation_rv was set at construction), apply right-truncation adjustment to the daily predictions.

None
first_day_dow int | None

Day-of-week index of the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when day_of_week_rv was set at construction or when aggregation == "weekly". This aligns observation-level day-of-week effects or weekly aggregation to the shared daily model axis.

None
period_end_times ArrayLike | None

Daily-axis indices of each observed period's final day. Required when reporting_schedule == "irregular".

None

Returns:

Type Description
ObservationSample

Named tuple with observed (sampled/conditioned counts) and predicted (predictions on the reporting-period grid; equal to daily predictions when aggregation == "daily").

Source code in pyrenew/observation/count_observations.py
def sample(
    self,
    infections: ArrayLike,
    obs: ArrayLike | None = None,
    right_truncation_offset: int | None = None,
    first_day_dow: int | None = None,
    period_end_times: ArrayLike | None = None,
) -> ObservationSample:
    """
    Sample aggregated counts.

    Daily transforms (right-truncation, day-of-week) run on the
    daily axis. When ``aggregation == "weekly"`` the daily
    predictions are summed onto the reporting-period grid before
    the noise model. Likelihood path depends on
    ``reporting_schedule``: ``"regular"`` uses a dense-with-NaN
    array plus a mask; ``"irregular"`` fancy-indexes the
    aggregated array at period indices derived from
    ``period_end_times``.

    ``aggregation_period`` describes the observation scale only. The
    latent infection process may use daily or coarser Rt parameter
    cadence, but by the time this method is called it supplies infections
    on the daily model axis.

    Parameters
    ----------
    infections
        Aggregate infections from the infection process.
        Shape ``(n_total,)``.
    obs
        Observed counts. Shape depends on ``reporting_schedule``:
        ``"regular"`` expects a dense array on the period grid
        with NaN for unobserved periods; ``"irregular"`` expects
        an array of the same length as ``period_end_times``.
        ``None`` for prior predictive sampling.
    right_truncation_offset
        If provided (and ``right_truncation_rv`` was set at
        construction), apply right-truncation adjustment to the
        daily predictions.
    first_day_dow
        Day-of-week index of the first timepoint on the shared
        time axis (0=Monday, 6=Sunday, ISO convention). Required
        when ``day_of_week_rv`` was set at construction or when
        ``aggregation == "weekly"``. This aligns observation-level
        day-of-week effects or weekly aggregation to the shared daily
        model axis.
    period_end_times
        Daily-axis indices of each observed period's final day.
        Required when ``reporting_schedule == "irregular"``.

    Returns
    -------
    ObservationSample
        Named tuple with ``observed`` (sampled/conditioned counts)
        and ``predicted`` (predictions on the reporting-period
        grid; equal to daily predictions when
        ``aggregation == "daily"``).
    """
    predicted = self._compute_predicted(
        infections, first_day_dow, right_truncation_offset
    )

    if self.reporting_schedule == "regular":
        observed = self._score_masked(predicted, obs)
    else:
        if period_end_times is None:
            raise ValueError(
                f"Observation '{self.name}': period_end_times is "
                f"required when reporting_schedule == 'irregular'"
            )
        period_idx = self._period_indices(period_end_times, first_day_dow)
        observed = self.noise.sample(
            name=self._sample_site_name("obs"),
            predicted=predicted[period_idx],
            obs=obs,
        )

    return ObservationSample(observed=observed, predicted=predicted)
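Weekly aggregation buckets daily predictions into 7-day periods starting on start_dow, and the irregular schedule then picks periods by their final day. The sketch below assumes the leading days before the first start_dow are trimmed (offset = (start_dow - first_day_dow) % 7) and any trailing partial week is dropped. aggregate_weekly and period_index are hypothetical stand-ins for the private _compute_period_offset and _period_indices, whose exact conventions are not shown in this section.

```python
from typing import Sequence


def period_offset(first_day_dow: int, start_dow: int) -> int:
    """Days to trim from the front so the first kept day falls on start_dow."""
    return (start_dow - first_day_dow) % 7


def aggregate_weekly(
    daily: Sequence[float], first_day_dow: int, start_dow: int
) -> list[float]:
    """Sum complete 7-day blocks that begin on start_dow.

    Leading days before the first start_dow and any trailing partial
    week are dropped (assumption of this sketch).
    """
    offset = period_offset(first_day_dow, start_dow)
    n_periods = (len(daily) - offset) // 7
    return [
        sum(daily[offset + 7 * k : offset + 7 * (k + 1)]) for k in range(n_periods)
    ]


def period_index(period_end_time: int, first_day_dow: int, start_dow: int) -> int:
    """Map the daily-axis index of a period's last day to its weekly period index."""
    offset = period_offset(first_day_dow, start_dow)
    k, remainder = divmod(period_end_time - offset + 1, 7)
    if remainder != 0 or k < 1:
        raise ValueError(f"{period_end_time} is not the last day of a full period")
    return k - 1


# Daily axis starts on a Wednesday (2); MMWR epiweeks start on Sunday (6).
daily = [float(i) for i in range(20)]
weekly = aggregate_weekly(daily, first_day_dow=2, start_dow=6)  # days 4-10 and 11-17
first_week = weekly[period_index(10, 2, 6)]                     # period 0
```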

validate_data

validate_data(
    n_total: int,
    n_subpops: int,
    obs: ArrayLike | None = None,
    period_end_times: ArrayLike | None = None,
    first_day_dow: int | None = None,
    **kwargs: object,
) -> None

Validate aggregated count observation data.

Parameters:

Name Type Description Default
n_total int

Total number of daily time steps (n_init + n_days_post_init).

required
n_subpops int

Number of subpopulations (unused for aggregate observations).

required
obs ArrayLike | None

Observed counts. Shape depends on reporting_schedule: "regular" expects a dense array of length n_total // P after front-trim (where P is the aggregation period in days), with NaN for unobserved periods; "irregular" expects an array matching period_end_times.

None
period_end_times ArrayLike | None

Daily-axis indices of each observed period's final day. Required for reporting_schedule="irregular".

None
first_day_dow int | None

Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when aggregation == "weekly" so weekly observation periods can be aligned to the shared daily model axis.

None
**kwargs object

Additional keyword arguments (ignored).

{}

Raises:

Type Description
ValueError

If obs length or period_end_times fail their respective checks, or if first_day_dow is missing when aggregation == "weekly".

Source code in pyrenew/observation/count_observations.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    obs: ArrayLike | None = None,
    period_end_times: ArrayLike | None = None,
    first_day_dow: int | None = None,
    **kwargs: object,
) -> None:
    """
    Validate aggregated count observation data.

    Parameters
    ----------
    n_total
        Total number of daily time steps (``n_init + n_days_post_init``).
    n_subpops
        Number of subpopulations (unused for aggregate observations).
    obs
        Observed counts. Shape depends on ``reporting_schedule``:
        ``"regular"`` expects a dense array of length ``n_total // P``
        after front-trim, with NaN for unobserved periods;
        ``"irregular"`` expects an array matching ``period_end_times``.
    period_end_times
        Daily-axis indices of each observed period's final day. Required
        for ``reporting_schedule="irregular"``.
    first_day_dow
        Day-of-week index of element 0 of the shared time axis
        (0=Monday, 6=Sunday, ISO convention). Required when
        ``aggregation == "weekly"`` so weekly observation periods can be
        aligned to the shared daily model axis.
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If obs length or period_end_times fail their respective
        checks, or if ``first_day_dow`` is missing when
        ``aggregation == "weekly"``.
    """
    if self.reporting_schedule == "regular":
        if obs is None:
            return
        if self.aggregation == "daily":
            self._validate_obs_dense(obs, n_total)
            return
        n_periods = self._n_periods(n_total, first_day_dow)
        obs = jnp.asarray(obs)
        if obs.ndim != 1:
            raise ValueError(
                f"Observation '{self.name}': obs must be 1D, got shape {obs.shape}"
            )
        if obs.shape[0] != n_periods:
            raise ValueError(
                f"Observation '{self.name}': obs length {obs.shape[0]} "
                f"must equal n_periods ({n_periods}). "
                f"Pad with NaN for unobserved periods."
            )
        return

    if period_end_times is None:
        if obs is None:
            return
        raise ValueError(
            f"Observation '{self.name}': period_end_times is required "
            f"when reporting_schedule='irregular' and obs is provided."
        )
    offset = self._compute_period_offset(first_day_dow, self.start_dow)
    self._validate_period_end_times(
        period_end_times, n_total, offset, self.aggregation_period
    )
    if obs is not None:
        self._validate_shapes_match(
            obs, period_end_times, "obs", "period_end_times"
        )
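For the irregular schedule, the checks above delegate to `_compute_period_offset` and `_validate_period_end_times`, whose bodies are not shown on this page. The sketch below is a hypothetical stand-in, and `period_indices_from_end_times` is our own name, not a pyrenew function. It assumes that period k covers daily indices `offset + k*P` through `offset + (k+1)*P - 1`, where `P` is the aggregation period (7 for weekly, 1 for daily). Under that assumption, a valid end time must fall on the last day of some period, and its period index follows by integer division.

```python
import numpy as np


def period_indices_from_end_times(period_end_times, n_total, offset, period):
    """Hypothetical sketch: check period end times and map them to period indices.

    Assumes period k covers daily indices offset + k*period through
    offset + (k+1)*period - 1. This is an assumption, not pyrenew's code.
    """
    ends = np.asarray(period_end_times)
    if ends.ndim != 1:
        raise ValueError("period_end_times must be 1D")
    # The first complete period ends at offset + period - 1; nothing may pass n_total - 1.
    if np.any(ends < offset + period - 1) or np.any(ends >= n_total):
        raise ValueError("period_end_times out of bounds")
    # Each end time must be the final day of a period.
    if np.any((ends - offset + 1) % period != 0):
        raise ValueError("period_end_times not aligned to period ends")
    return (ends - offset + 1) // period - 1
```

Indexing the aggregated predictions with the returned indices then lines each observation up with its own period.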

SubpopulationCounts

SubpopulationCounts(
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
    aggregation: Literal["daily", "weekly"] = "daily",
    reporting_schedule: Literal["regular", "irregular"] = "regular",
    start_dow: int | None = None,
)

Bases: CountObservation

Subpopulation-level count observation.

Maps subpopulation-level infections to counts through an ascertainment × delay convolution followed by a composable noise model. Predictions are built on the daily model axis for each subpopulation. aggregation_period controls whether those predictions are scored as daily counts or summed to weekly counts before the likelihood is evaluated.
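The ascertainment × delay step can be sketched in plain NumPy. This is an illustration of the idea, not pyrenew's code. The library's own routine is pyrenew.convolve.compute_delay_ascertained_incidence (listed under See Also for BaseObservationProcess), which this sketch does not call. `predicted_counts` is a hypothetical name. It follows the convention stated for BaseObservationProcess: the first `len(pmf) - 1` days lack a full infection history and are returned as NaN.

```python
import numpy as np


def predicted_counts(infections, delay_pmf, ascertainment_rate):
    """Hypothetical sketch of an ascertainment x delay convolution.

    infections: (n_total, n_subpops) daily infections per subpopulation.
    delay_pmf: (L,) infection-to-observation delay PMF.
    ascertainment_rate: scalar or (n_subpops,) rate in [0, 1].
    Returns (n_total, n_subpops) expected counts. The first L - 1 rows
    are NaN because they lack a full infection history.
    """
    infections = np.asarray(infections, dtype=float)
    pmf = np.asarray(delay_pmf, dtype=float)
    # PMF validation: non-negative and summing to 1.
    if np.any(pmf < 0) or not np.isclose(pmf.sum(), 1.0):
        raise ValueError("delay PMF must be non-negative and sum to 1")
    n_total, n_subpops = infections.shape
    n_delay = pmf.shape[0]
    out = np.full((n_total, n_subpops), np.nan)
    for t in range(n_delay - 1, n_total):
        # Rows t, t-1, ..., t-L+1, so window[d] holds infections on day t - d.
        window = infections[t - n_delay + 1 : t + 1][::-1]
        out[t] = ascertainment_rate * (pmf @ window)
    return out
```

For example, 100 infections on day 1 with a delay PMF of `[0.2, 0.8]` and a rate of 0.5 yield expected counts of 10 on day 1 and 40 on day 2.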

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • ascertainment_rate_rv (RandomVariable, required): Ascertainment rate in [0, 1].
  • delay_distribution_rv (RandomVariable, required): Delay distribution PMF (must sum to ~1.0).
  • noise (CountNoise, required): Noise model (PoissonNoise, NegativeBinomialNoise, etc.).
Source code in pyrenew/observation/count_observations.py
def __init__(
    self,
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
    aggregation: Literal["daily", "weekly"] = "daily",
    reporting_schedule: Literal["regular", "irregular"] = "regular",
    start_dow: int | None = None,
) -> None:
    """
    Initialize count observation base.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    ascertainment_rate_rv
        Ascertainment rate in [0, 1] (e.g., IHR, IER).
    delay_distribution_rv
        Delay distribution PMF (must sum to ~1.0).
    noise
        Noise model for count observations (Poisson, NegBin, etc.).
    right_truncation_rv
        Optional reporting delay PMF for right-truncation adjustment.
        When provided (along with ``right_truncation_offset`` at sample
        time), predicted counts are scaled down for recent timepoints
        to account for incomplete reporting.
    day_of_week_rv
        Optional day-of-week multiplicative effect. Must sample to
        shape (7,) with non-negative values, where entry j is the
        multiplier for day-of-week j (0=Monday, 6=Sunday, ISO
        convention). An effect of 1.0 means no adjustment for that
        day. Values summing to 7.0 preserve weekly totals and keep
        the ascertainment rate interpretable; other sums rescale
        overall predicted counts. When provided (along with
        ``first_day_dow`` at sample time), predicted counts are
        scaled by a periodic weekly pattern.
    aggregation
        Observation reporting cadence; one of ``"daily"`` or
        ``"weekly"``. Controls only the scale on which the count
        likelihood is evaluated; it does not control how often the
        latent Rt temporal process samples new parameters.
    reporting_schedule
        Either ``"regular"`` (dense observation array, one entry
        per period, NaN for unobserved periods) or ``"irregular"``
        (sparse observation array with user-supplied period-end
        time indices).
    start_dow
        Day-of-week on which the weekly aggregation cycle begins
        (0=Monday, 6=Sunday, ISO convention). Use ``6`` for
        MMWR Sunday-Saturday epiweeks or ``0`` for ISO
        Monday-Sunday weeks. Required when
        ``aggregation == "weekly"``; ignored otherwise. Daily
        predictions are bucketed into weeks starting on
        ``start_dow`` and summed before scoring.

    Raises
    ------
    ValueError
        If ``aggregation``, ``reporting_schedule``, or ``start_dow``
        are invalid, or if a day-of-week effect is combined with
        ``aggregation == "weekly"`` (within-period structure is
        destroyed by aggregation).
    """
    super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv)
    self.ascertainment_rate_rv = ascertainment_rate_rv
    self.noise = noise
    self.right_truncation_rv = right_truncation_rv
    self.day_of_week_rv = day_of_week_rv
    self._validate_aggregation_start_dow(aggregation, start_dow)
    if reporting_schedule not in self._SUPPORTED_SCHEDULES:
        raise ValueError(
            f"reporting_schedule must be one of {self._SUPPORTED_SCHEDULES}, "
            f"got {reporting_schedule!r}"
        )
    if aggregation == "weekly" and day_of_week_rv is not None:
        raise ValueError(
            "day_of_week_rv cannot be combined with aggregation == 'weekly'; "
            "aggregation destroys within-period structure."
        )
    self.aggregation = aggregation
    self.reporting_schedule = reporting_schedule
    self.start_dow = start_dow
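The day-of-week effect described in the docstring above can be illustrated as a periodic multiplier on the daily axis. `apply_day_of_week` is a hypothetical helper, not part of pyrenew. It relies only on the stated conventions: entry j of the effect applies to weekday j (0=Monday), and `first_day_dow` is the weekday of element 0. From these it follows that element t falls on weekday `(first_day_dow + t) % 7`.

```python
import numpy as np


def apply_day_of_week(daily, effect, first_day_dow):
    """Hypothetical sketch: scale daily predictions by a weekly pattern.

    effect[j] is the multiplier for day-of-week j (0=Monday, 6=Sunday).
    first_day_dow is the day-of-week of element 0 of the daily axis.
    """
    daily = np.asarray(daily, dtype=float)
    effect = np.asarray(effect, dtype=float)
    if effect.shape != (7,) or np.any(effect < 0):
        raise ValueError("effect must have shape (7,) with non-negative values")
    # Day-of-week of every element on the daily axis.
    dow = (first_day_dow + np.arange(daily.shape[0])) % 7
    multiplier = effect[dow]
    if daily.ndim == 2:
        # Broadcast the same weekly pattern across subpopulation columns.
        multiplier = multiplier[:, None]
    return daily * multiplier
```

Any seven consecutive days cover each weekday exactly once. An effect summing to 7.0 therefore leaves every 7-day total unchanged, which is why that normalization keeps the ascertainment rate interpretable. The same observation explains why the constructor rejects `day_of_week_rv` together with weekly aggregation: once days are summed into full weeks, only the sum of the multipliers survives.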

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/count_observations.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"SubpopulationCounts(name={self.name!r}, "
        f"ascertainment_rate_rv={self.ascertainment_rate_rv!r}, "
        f"delay_distribution_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r}, "
        f"right_truncation_rv={self.right_truncation_rv!r}, "
        f"day_of_week_rv={self.day_of_week_rv!r})"
    )

infection_resolution

infection_resolution() -> str

Return "subpop" for subpopulation-level observations.

Returns:

  • str: The string "subpop".

Source code in pyrenew/observation/count_observations.py
def infection_resolution(self) -> str:
    """
    Return "subpop" for subpopulation-level observations.

    Returns
    -------
    str
        The string "subpop".
    """
    return "subpop"

sample

sample(
    infections: ArrayLike,
    obs: ArrayLike | None = None,
    right_truncation_offset: int | None = None,
    first_day_dow: int | None = None,
    period_end_times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
) -> ObservationSample

Sample subpopulation-level counts.

Daily transforms (right-truncation, day-of-week) are applied on the daily axis. When aggregation == "weekly", the daily predictions are summed onto the reporting-period grid before the noise model is applied. The likelihood path depends on reporting_schedule. "regular" selects the observed subpopulation columns and scores a dense array (NaN for unobserved periods) under a mask. "irregular" fancy-indexes the aggregated array at the period indices derived from period_end_times, pairing each one with the corresponding entry of subpop_indices.

aggregation_period describes the observation scale only. The latent infection process may update Rt daily or at a coarser cadence, but by the time this method is called, it supplies infections on the daily model axis.
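Weekly aggregation can be pictured with the sketch below. It is an assumption-laden stand-in, not the library's code. `aggregate_weekly` is a hypothetical name, and it assumes that the first week starts on the first day whose weekday equals `start_dow` (leading days are trimmed) and that a trailing partial week is dropped. The private helpers the class actually uses (`_n_periods`, `_compute_period_offset`) may handle partial weeks differently.

```python
import numpy as np


def aggregate_weekly(daily, first_day_dow, start_dow):
    """Hypothetical sketch: sum daily values into weeks beginning on start_dow.

    Assumes leading days before the first start_dow are trimmed and a
    trailing partial week is dropped.
    """
    daily = np.asarray(daily, dtype=float)
    # Number of leading days before the first day whose weekday is start_dow.
    offset = (start_dow - first_day_dow) % 7
    n_periods = max(0, (daily.shape[0] - offset) // 7)
    trimmed = daily[offset : offset + 7 * n_periods]
    # Group into (n_periods, 7, ...) blocks and sum within each week.
    return trimmed.reshape((n_periods, 7) + daily.shape[1:]).sum(axis=1)
```

With `start_dow=6` this gives MMWR Sunday-Saturday weeks, and with `start_dow=0` it gives ISO Monday-Sunday weeks. Any NaN left in the daily predictions, such as those from the convolution's initial days, propagates into the weekly sum that contains it.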

Parameters:

  • infections (ArrayLike, required): Subpopulation-level infections from the infection process. Shape (n_total, n_subpops).
  • obs (ArrayLike | None, default None): Observed counts. For reporting_schedule="regular", shape (n_periods, n_observed_subpops), with NaN for unobserved periods. For reporting_schedule="irregular", shape (n_obs,), matching period_end_times and subpop_indices. None for prior predictive sampling.
  • right_truncation_offset (int | None, default None): If provided (and right_truncation_rv was set at construction), apply the right-truncation adjustment to the daily predictions.
  • first_day_dow (int | None, default None): Day-of-week index of the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when day_of_week_rv was set at construction or when aggregation == "weekly". This aligns observation-level day-of-week effects or weekly aggregation to the shared daily model axis.
  • period_end_times (ArrayLike | None, default None): Daily-axis indices of each observed period's final day. Required when reporting_schedule == "irregular".
  • subpop_indices (ArrayLike | None, default None): Subpopulation indices (0-indexed). Required despite the None default; sample raises ValueError if it is omitted. For reporting_schedule="regular", shape (n_observed_subpops,), selecting which subpopulation columns of the aggregated array enter the likelihood. For reporting_schedule="irregular", shape (n_obs,), with one subpopulation per observation.
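The right-truncation adjustment can be sketched as follows. This is our reading of "scaled down for recent timepoints", not the library's implementation. `right_truncate` is hypothetical, it treats the last day on the axis as the reporting date, and it ignores `right_truncation_offset` entirely because its exact semantics are not spelled out here. Each day's prediction is multiplied by the cumulative reporting-delay probability for the number of days elapsed since that day.

```python
import numpy as np


def right_truncate(predicted, reporting_pmf):
    """Hypothetical sketch: scale recent predictions by the fraction reported so far.

    reporting_pmf[d] is the probability that a count is reported d days late.
    The last row of predicted is treated as "today" (0 days of reporting).
    """
    predicted = np.asarray(predicted, dtype=float)
    cdf = np.cumsum(np.asarray(reporting_pmf, dtype=float))
    n = predicted.shape[0]
    # Days elapsed between each timepoint and the end of the daily axis.
    elapsed = (n - 1) - np.arange(n)
    # Older timepoints are capped at the final CDF value (fully reported).
    reported_fraction = cdf[np.minimum(elapsed, cdf.shape[0] - 1)]
    if predicted.ndim == 2:
        # Same reporting fraction for every subpopulation column.
        reported_fraction = reported_fraction[:, None]
    return predicted * reported_fraction
```

With a reporting PMF of `[0.5, 0.3, 0.2]`, a constant prediction of 10 becomes 5 on the last day and 8 on the day before, and it is unchanged for older days.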

Returns:

  • ObservationSample: Named tuple with observed (sampled or conditioned counts) and predicted (predictions on the reporting-period grid, shape (n_periods, n_subpops); equal to the daily predictions when aggregation == "daily").
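The two likelihood paths differ only in how predictions are lined up with data. The sketch below uses NumPy and substitutes a plain Poisson log-likelihood for the configured `noise` model and for the `_score_masked` helper, whose body is not shown. All function names are hypothetical. Both paths score the same two observations: first as a dense NaN-padded array, then as sparse (period, subpopulation, count) triples.

```python
import numpy as np
from math import lgamma


def poisson_loglik(obs, mu):
    """Summed Poisson log-likelihood, standing in for the configured noise model."""
    obs = np.asarray(obs, dtype=float)
    mu = np.asarray(mu, dtype=float)
    log_factorial = np.array([lgamma(k + 1.0) for k in obs.ravel()]).reshape(obs.shape)
    return float(np.sum(obs * np.log(mu) - mu - log_factorial))


def regular_loglik(predicted, obs, subpop_indices):
    # Select the observed subpopulation columns, then score only non-NaN cells.
    selected = predicted[:, subpop_indices]
    obs = np.asarray(obs, dtype=float)
    mask = ~np.isnan(obs)
    return poisson_loglik(obs[mask], selected[mask])


def irregular_loglik(predicted, obs, period_idx, subpop_indices):
    # Two equal-length integer arrays select element-wise pairs:
    # predicted[period_idx[i], subpop_indices[i]] for each observation i.
    mu = predicted[np.asarray(period_idx), np.asarray(subpop_indices)]
    return poisson_loglik(obs, mu)
```

The pairing behaviour of `predicted[period_idx, subpop_indices]` is standard NumPy advanced indexing, which JAX arrays follow as well.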

Source code in pyrenew/observation/count_observations.py
def sample(
    self,
    infections: ArrayLike,
    obs: ArrayLike | None = None,
    right_truncation_offset: int | None = None,
    first_day_dow: int | None = None,
    period_end_times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
) -> ObservationSample:
    """
    Sample subpopulation-level counts.

    Daily transforms (right-truncation, day-of-week) run on the
    daily axis. When ``aggregation == "weekly"`` the daily
    predictions are summed onto the reporting-period grid before
    the noise model. Likelihood path depends on
    ``reporting_schedule``: ``"regular"`` selects the observed
    subpopulation columns and uses a dense-with-NaN array plus
    a mask; ``"irregular"`` fancy-indexes the aggregated array
    at period indices derived from ``period_end_times``.

    ``aggregation_period`` describes the observation scale only. The
    latent infection process may use daily or coarser Rt parameter
    cadence, but by the time this method is called it supplies infections
    on the daily model axis.

    Parameters
    ----------
    infections
        Subpopulation-level infections from the infection process.
        Shape ``(n_total, n_subpops)``.
    obs
        Observed counts. For ``reporting_schedule="regular"``:
        shape ``(n_periods, n_observed_subpops)`` with NaN for
        unobserved periods. For
        ``reporting_schedule="irregular"``: shape ``(n_obs,)``
        matching ``period_end_times`` and ``subpop_indices``.
        ``None`` for prior predictive sampling.
    right_truncation_offset
        If provided (and ``right_truncation_rv`` was set at
        construction), apply right-truncation adjustment to the
        daily predictions.
    first_day_dow
        Day-of-week index of the first timepoint on the shared
        time axis (0=Monday, 6=Sunday, ISO convention). Required
        when ``day_of_week_rv`` was set at construction or when
        ``aggregation == "weekly"``. This aligns observation-level
        day-of-week effects or weekly aggregation to the shared daily
        model axis.
    period_end_times
        Daily-axis indices of each observed period's final day.
        Required when ``reporting_schedule == "irregular"``.
    subpop_indices
        Subpopulation indices (0-indexed). Required. For
        ``reporting_schedule="regular"``: shape
        ``(n_observed_subpops,)`` selecting which subpopulation
        columns of the aggregated array enter the likelihood.
        For ``reporting_schedule="irregular"``: shape ``(n_obs,)``
        with one subpopulation per observation.

    Returns
    -------
    ObservationSample
        Named tuple with ``observed`` (sampled/conditioned counts)
        and ``predicted`` (predictions on the reporting-period
        grid, shape ``(n_periods, n_subpops)``; equal to daily
        predictions when ``aggregation == "daily"``).
    """
    if subpop_indices is None:
        raise ValueError(f"Observation '{self.name}': subpop_indices is required.")

    predicted = self._compute_predicted(
        infections, first_day_dow, right_truncation_offset
    )

    if self.reporting_schedule == "regular":
        observed = self._score_masked(predicted[:, subpop_indices], obs)
    else:
        if period_end_times is None:
            raise ValueError(
                f"Observation '{self.name}': period_end_times is "
                f"required when reporting_schedule == 'irregular'"
            )
        period_idx = self._period_indices(period_end_times, first_day_dow)
        observed = self.noise.sample(
            name=self._sample_site_name("obs"),
            predicted=predicted[period_idx, subpop_indices],
            obs=obs,
        )

    return ObservationSample(observed=observed, predicted=predicted)

validate_data

validate_data(
    n_total: int,
    n_subpops: int,
    obs: ArrayLike | None = None,
    period_end_times: ArrayLike | None = None,
    first_day_dow: int | None = None,
    subpop_indices: ArrayLike | None = None,
    **kwargs: object,
) -> None

Validate subpopulation-level count observation data.

Parameters:

  • n_total (int, required): Total number of daily time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations.
  • obs (ArrayLike | None, default None): Observed counts. For reporting_schedule="regular", shape (n_periods, n_observed_subpops), with NaN for unobserved periods. For reporting_schedule="irregular", shape (n_obs,), matching period_end_times and subpop_indices.
  • period_end_times (ArrayLike | None, default None): Daily-axis indices of each observed period's final day. Required for reporting_schedule="irregular".
  • first_day_dow (int | None, default None): Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when aggregation == "weekly" so that weekly observation periods can be aligned to the shared daily model axis.
  • subpop_indices (ArrayLike | None, default None): Subpopulation indices (0-indexed). For reporting_schedule="regular", shape (n_observed_subpops,), selecting which subpopulation columns appear in obs. For reporting_schedule="irregular", shape (n_obs,), with one subpopulation per observation.
  • **kwargs (object): Additional keyword arguments (ignored).

Raises:

  • ValueError: If any index array is out of bounds, any shape check fails, subpop_indices is missing when obs is provided, or first_day_dow is missing when aggregation == "weekly".
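A regular-schedule `obs` array must be 2D, with one row per period and one column per entry of `subpop_indices`, and padded with NaN wherever nothing was observed. If data arrive as long-format records, a helper such as `to_dense_obs` below can build that array. The helper is hypothetical and not part of pyrenew, and it assumes `n_periods` is computed the same way the model computes it.

```python
import numpy as np


def to_dense_obs(records, n_periods, subpop_indices):
    """Hypothetical helper: build a regular-schedule obs array.

    records: iterable of (period_index, subpop, count) triples.
    Returns shape (n_periods, len(subpop_indices)) with NaN where no
    record exists. Column j corresponds to subpop_indices[j].
    """
    column_of = {int(s): j for j, s in enumerate(subpop_indices)}
    dense = np.full((n_periods, len(subpop_indices)), np.nan)
    for period, subpop, count in records:
        if not 0 <= period < n_periods:
            raise ValueError(f"period {period} outside [0, {n_periods})")
        if subpop not in column_of:
            raise ValueError(f"subpop {subpop} not listed in subpop_indices")
        dense[period, column_of[subpop]] = count
    return dense
```

Because the column order follows `subpop_indices`, the same `subpop_indices` must be passed to both `validate_data` and `sample`.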

Source code in pyrenew/observation/count_observations.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    obs: ArrayLike | None = None,
    period_end_times: ArrayLike | None = None,
    first_day_dow: int | None = None,
    subpop_indices: ArrayLike | None = None,
    **kwargs: object,
) -> None:
    """
    Validate subpopulation-level count observation data.

    Parameters
    ----------
    n_total
        Total number of daily time steps (``n_init + n_days_post_init``).
    n_subpops
        Number of subpopulations.
    obs
        Observed counts. For ``reporting_schedule="regular"``
        has shape ``(n_periods, n_observed_subpops)`` with NaN
        for unobserved periods. For
        ``reporting_schedule="irregular"`` has shape ``(n_obs,)``
        matching ``period_end_times`` and ``subpop_indices``.
    period_end_times
        Daily-axis indices of each observed period's final day.
        Required for ``reporting_schedule="irregular"``.
    first_day_dow
        Day-of-week index of element 0 of the shared time axis
        (0=Monday, 6=Sunday, ISO convention). Required when
        ``aggregation == "weekly"`` so weekly observation periods can be
        aligned to the shared daily model axis.
    subpop_indices
        Subpopulation indices (0-indexed). For
        ``reporting_schedule="regular"``: shape
        ``(n_observed_subpops,)`` selecting which subpopulation
        columns appear in ``obs``. For
        ``reporting_schedule="irregular"``: shape ``(n_obs,)``
        with one subpopulation per observation.
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If any index array is out of bounds, any shape check
        fails, ``subpop_indices`` is missing when ``obs`` is
        provided, or ``first_day_dow`` is missing when
        ``aggregation == "weekly"``.
    """
    if obs is not None and subpop_indices is None:
        raise ValueError(
            f"Observation '{self.name}': subpop_indices is required "
            f"when obs is provided."
        )

    if subpop_indices is not None:
        self._validate_subpop_indices(subpop_indices, n_subpops)

    if self.reporting_schedule == "regular":
        if obs is None:
            return
        n_periods = self._n_periods(n_total, first_day_dow)
        obs = jnp.asarray(obs)
        if obs.ndim != 2:
            raise ValueError(
                f"Observation '{self.name}': regular-schedule obs must "
                f"be 2D (n_periods, n_observed_subpops); got shape {obs.shape}"
            )
        if obs.shape[0] != n_periods:
            raise ValueError(
                f"Observation '{self.name}': obs dimension 0 length "
                f"{obs.shape[0]} must equal n_periods ({n_periods}). "
                f"Pad with NaN for unobserved periods."
            )
        if subpop_indices is not None:
            n_observed = jnp.asarray(subpop_indices).shape[0]
            if obs.shape[1] != n_observed:
                raise ValueError(
                    f"Observation '{self.name}': obs dimension 1 length "
                    f"{obs.shape[1]} must equal len(subpop_indices) "
                    f"({n_observed})"
                )
        return

    if period_end_times is None:
        if obs is None:
            return
        raise ValueError(
            f"Observation '{self.name}': period_end_times is required "
            f"when reporting_schedule='irregular' and obs is provided."
        )
    offset = self._compute_period_offset(first_day_dow, self.start_dow)
    self._validate_period_end_times(
        period_end_times, n_total, offset, self.aggregation_period
    )
    if obs is not None:
        self._validate_shapes_match(
            obs, period_end_times, "obs", "period_end_times"
        )
    if subpop_indices is not None:
        self._validate_shapes_match(
            subpop_indices,
            period_end_times,
            "subpop_indices",
            "period_end_times",
        )