
Observation

Observation processes for connecting infections to observed data.

BaseObservationProcess is the abstract base. Concrete subclasses:

  • Counts: Aggregate counts (admissions, deaths)
  • CountsBySubpop: Subpopulation-level counts
  • Measurements: Continuous subpopulation-level signals (e.g., wastewater)

All observation processes implement:

  • sample(): Sample observations given infections
  • infection_resolution(): returns "aggregate" or "subpop"
  • lookback_days(): returns required infection history length

Noise models (CountNoise, MeasurementNoise) are composable: pass one to an observation constructor to control the output distribution.

BaseObservationProcess

BaseObservationProcess(name: str, temporal_pmf_rv: RandomVariable)

Bases: RandomVariable

Abstract base class for observation processes that use convolution with temporal distributions.

This class provides common functionality for connecting infections to observed data (e.g., hospital admissions, wastewater concentrations) through temporal convolution operations.

Key features provided:

  • PMF validation (sum to 1, non-negative)
  • Minimum observation day calculation
  • Convolution wrapper with timeline alignment
  • Deterministic quantity tracking
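The PMF checks listed above (non-negative values summing to 1) can be sketched in plain Python. This is a hypothetical illustration, not pyrenew's private _validate_pmf(): the function name validate_pmf_sketch, the explicit 1D check and the atol tolerance are assumptions.

```python
import math
from collections.abc import Sequence


def validate_pmf_sketch(pmf: Sequence[float], name: str = "pmf", atol: float = 1e-6) -> None:
    """Raise ValueError unless pmf is a non-empty 1D PMF of non-negative values summing to ~1.

    Hypothetical illustration only; not pyrenew's _validate_pmf().
    """
    if len(pmf) == 0:
        raise ValueError(f"{name} must be non-empty")
    # Reject nested sequences: the PMF must be one-dimensional.
    if any(isinstance(p, Sequence) for p in pmf):
        raise ValueError(f"{name} must be 1D")
    if any(p < 0 for p in pmf):
        raise ValueError(f"{name} must have non-negative values")
    # math.fsum reduces floating-point error when summing many small terms.
    total = math.fsum(pmf)
    if not math.isclose(total, 1.0, abs_tol=atol):
        raise ValueError(f"{name} must sum to ~1.0, got {total}")
```

For example, validate_pmf_sketch([0.2, 0.5, 0.3]) returns None, while [0.5, 0.6] raises ValueError because it sums to 1.1.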

Subclasses must implement:

  • validate(): Validate parameters (call _validate_pmf() for PMFs)
  • lookback_days(): Return PMF length for initialization
  • infection_resolution(): Return "aggregate" or "subpop"
  • _predicted_obs(): Transform infections to predicted values
  • sample(): Apply noise model to predicted observations

Notes

Computing predicted observations on day t requires infection history from previous days (determined by the temporal PMF length). The first len(pmf) - 1 days have insufficient history and return NaN.
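A minimal plain-Python sketch of this behavior, assuming the predicted observation on day t is an ascertainment rate times the PMF-weighted sum of infections on days t, t-1, ..., t-(len(pmf)-1). The name delay_ascertained_sketch is hypothetical; pyrenew's compute_delay_ascertained_incidence works on JAX arrays and its signature is not shown here.

```python
import math
from collections.abc import Sequence


def delay_ascertained_sketch(
    infections: Sequence[float], rate: float, pmf: Sequence[float]
) -> list[float]:
    """predicted[t] = rate * sum_d pmf[d] * infections[t - d]; NaN for t < len(pmf) - 1."""
    lookback = len(pmf) - 1
    predicted = []
    for t in range(len(infections)):
        if t < lookback:
            # Not enough infection history to cover every lag 0..len(pmf)-1.
            predicted.append(math.nan)
        else:
            # pmf[0] weights same-day infections, pmf[1] the previous day, and so on.
            predicted.append(rate * sum(p * infections[t - d] for d, p in enumerate(pmf)))
    return predicted
```

With infections [100, 100, 200, 200, 200], rate 0.1 and PMF [0.5, 0.3, 0.2], the first two entries are NaN and the remaining three are approximately 15, 18 and 20.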

See Also

  • pyrenew.convolve.compute_delay_ascertained_incidence: Underlying convolution function
  • pyrenew.metaclass.RandomVariable: Base class for all random variables

Initialize base observation process.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names, enabling multiple observations of the same type in a single model.
  • temporal_pmf_rv (RandomVariable, required): The temporal distribution PMF (e.g., delay or shedding distribution). Must sample to a 1D array that sums to ~1.0 with non-negative values. Subclasses may have additional parameters.

Notes

Subclasses should call super().__init__(name, temporal_pmf_rv) in their constructors and may add additional parameters.

Source code in pyrenew/observation/base.py
def __init__(self, name: str, temporal_pmf_rv: RandomVariable) -> None:
    """
    Initialize base observation process.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names, enabling multiple
        observations of the same type in a single model.
    temporal_pmf_rv
        The temporal distribution PMF (e.g., delay or shedding distribution).
        Must sample to a 1D array that sums to ~1.0 with non-negative values.
        Subclasses may have additional parameters.

    Notes
    -----
    Subclasses should call ``super().__init__(name, temporal_pmf_rv)``
    in their constructors and may add additional parameters.
    """
    super().__init__(name=name)
    self.temporal_pmf_rv = temporal_pmf_rv

infection_resolution abstractmethod

infection_resolution() -> str

Return whether this observation uses aggregate or subpop infections.

Returns one of:

  • "aggregate": Uses a single aggregated infection trajectory. Shape: (n_days,)
  • "subpop": Uses subpopulation-level infection trajectories. Shape: (n_days, n_subpops), indexed via subpop_indices.

Returns:

  • str: Either "aggregate" or "subpop".

Notes

This is used by multi-signal models to route the correct infection output to each observation process.

Source code in pyrenew/observation/base.py
@abstractmethod
def infection_resolution(self) -> str:
    """
    Return whether this observation uses aggregate or subpop infections.

    Returns one of:

    - ``"aggregate"``: Uses a single aggregated infection trajectory.
      Shape: ``(n_days,)``
    - ``"subpop"``: Uses subpopulation-level infection trajectories.
      Shape: ``(n_days, n_subpops)``, indexed via ``subpop_indices``.

    Returns
    -------
    str
        Either ``"aggregate"`` or ``"subpop"``

    Notes
    -----
    This is used by multi-signal models to route the correct infection
    output to each observation process.
    """
    pass  # pragma: no cover

lookback_days abstractmethod

lookback_days() -> int

Return the number of days this observation process needs to look back.

This determines the minimum n_initialization_points required by the latent process when this observation is included in a multi-signal model.

Returns:

  • int: Number of days of infection history required.

Notes

Delay/shedding PMFs are 0-indexed (effect can occur on day 0), so a PMF of length L covers lags 0 to L-1, requiring L-1 initialization points. Implementations should return len(pmf) - 1.

This is used by model builders to automatically compute n_initialization_points as max(gen_int_length, max(all lookbacks)).
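A small sketch of how these two rules combine, using made-up PMF lengths. The helper names lookback_from_pmf and n_initialization_points are hypothetical and are not part of the model-builder API.

```python
def lookback_from_pmf(pmf_length: int) -> int:
    """A PMF of length L covers lags 0..L-1, so L - 1 days of history are needed."""
    return pmf_length - 1


def n_initialization_points(gen_int_length: int, lookbacks: list[int]) -> int:
    """max(gen_int_length, max(all lookbacks)), as stated in the Notes above.

    default=0 (an assumption) handles a model with no observation processes.
    """
    return max(gen_int_length, max(lookbacks, default=0))


# Illustrative lengths only: a 55-day admissions delay and a 26-day shedding PMF.
lookbacks = [lookback_from_pmf(55), lookback_from_pmf(26)]  # [54, 25]
n_init = n_initialization_points(gen_int_length=7, lookbacks=lookbacks)  # 54
```

Here the longest delay PMF, not the generation interval, sets the number of initialization points.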

Source code in pyrenew/observation/base.py
@abstractmethod
def lookback_days(self) -> int:
    """
    Return the number of days this observation process needs to look back.

    This determines the minimum n_initialization_points required by the
    latent process when this observation is included in a multi-signal model.

    Returns
    -------
    int
        Number of days of infection history required.

    Notes
    -----
    Delay/shedding PMFs are 0-indexed (effect can occur on day 0), so a
    PMF of length L covers lags 0 to L-1, requiring L-1 initialization
    points. Implementations should return ``len(pmf) - 1``.

    This is used by model builders to automatically compute
    n_initialization_points as ``max(gen_int_length, max(all lookbacks))``.
    """
    pass  # pragma: no cover

sample abstractmethod

sample(obs: ArrayLike | None = None, **kwargs: object) -> ArrayLike

Sample from the observation process.

Subclasses must implement this method to define the specific observation model. Typically calls _predicted_obs first, then applies the noise model.

Parameters:

  • obs (ArrayLike | None, default None): Observed data for conditioning, or None for prior predictive sampling.
  • **kwargs (object): Subclass-specific parameters (e.g., infections from the infection process).

Returns:

  • ArrayLike: Observed or sampled values from the observation process.

Source code in pyrenew/observation/base.py
@abstractmethod
def sample(
    self,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ArrayLike:
    """
    Sample from the observation process.

    Subclasses must implement this method to define the specific
    observation model. Typically calls ``_predicted_obs`` first,
    then applies the noise model.

    Parameters
    ----------
    obs
        Observed data for conditioning, or None for prior predictive sampling.
    **kwargs
        Subclass-specific parameters (e.g., infections from the infection process).

    Returns
    -------
    ArrayLike
        Observed or sampled values from the observation process.
    """
    pass  # pragma: no cover

validate abstractmethod

validate() -> None

Validate observation process parameters.

Subclasses must implement this method to validate all parameters. Typically this involves calling _validate_pmf() for the PMF and adding any additional parameter-specific validation.

Raises:

  • ValueError: If any parameters fail validation.

Source code in pyrenew/observation/base.py
@abstractmethod
def validate(self) -> None:
    """
    Validate observation process parameters.

    Subclasses must implement this method to validate all parameters.
    Typically this involves calling ``_validate_pmf()`` for the PMF
    and adding any additional parameter-specific validation.

    Raises
    ------
    ValueError
        If any parameters fail validation.
    """
    pass  # pragma: no cover

validate_data abstractmethod

validate_data(
    n_total: int, n_subpops: int, **obs_data: dict[str, object]
) -> None

Validate observation data before running inference.

Each observation process validates its own data requirements. Called by the model's validate_data() method with concrete (non-traced) values before JAX tracing begins.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations.
  • **obs_data (dict[str, object]): Observation-specific data kwargs (the same as those passed to sample(), minus infections, which come from the latent process).

Raises:

  • ValueError: If any data fails validation.

Source code in pyrenew/observation/base.py
@abstractmethod
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    **obs_data: dict[str, object],
) -> None:
    """
    Validate observation data before running inference.

    Each observation process validates its own data requirements.
    Called by the model's ``validate_data()`` method with concrete
    (non-traced) values before JAX tracing begins.

    Parameters
    ----------
    n_total
        Total number of time steps (n_init + n_days_post_init).
    n_subpops
        Number of subpopulations.
    **obs_data
        Observation-specific data kwargs (same as passed to ``sample()``,
        minus ``infections`` which comes from the latent process).

    Raises
    ------
    ValueError
        If any data fails validation.
    """
    pass  # pragma: no cover

CountNoise

Bases: ABC

Abstract base for count observation noise models.

Defines how discrete count observations are distributed around predicted values.

sample abstractmethod

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample count observations given predicted counts.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted count values (non-negative).
  • obs (ArrayLike | None, default None): Observed counts for conditioning, or None for prior sampling.
  • mask (ArrayLike | None, default None): Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. If provided, observations where mask is False are excluded from the likelihood.

Returns:

  • ArrayLike: Sampled or conditioned counts, same shape as predicted.

Notes

Implementations use numpyro.handlers.mask rather than the obs_mask parameter of numpyro.sample. This avoids creating latent variables for masked entries, which would fail with NUTS for discrete distributions.

Source code in pyrenew/observation/noise.py
@abstractmethod
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample count observations given predicted counts.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted count values (non-negative).
    obs
        Observed counts for conditioning, or None for prior sampling.
    mask
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included. If provided,
        observations where mask is False are excluded from the likelihood.

    Returns
    -------
    ArrayLike
        Sampled or conditioned counts, same shape as predicted.

    Notes
    -----
    Implementations use ``numpyro.handlers.mask`` rather than the
    ``obs_mask`` parameter of ``numpyro.sample``. This avoids creating
    latent variables for masked entries, which would fail with NUTS
    for discrete distributions.
    """
    pass  # pragma: no cover

validate abstractmethod

validate() -> None

Validate noise model parameters.

Raises:

  • ValueError: If parameters are invalid.

Source code in pyrenew/observation/noise.py
@abstractmethod
def validate(self) -> None:
    """
    Validate noise model parameters.

    Raises
    ------
    ValueError
        If parameters are invalid.
    """
    pass  # pragma: no cover

Counts

Counts(
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
)

Bases: _CountBase

Aggregated count observation.

Maps aggregate infections to counts by applying an ascertainment rate and convolving with a delay distribution, then passing the result through a composable noise model.

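Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names (e.g., "hospital" produces sites "hospital_obs", "hospital_predicted").
  • ascertainment_rate_rv (RandomVariable, required): Ascertainment rate in [0, 1] (e.g., IHR, IER).
  • delay_distribution_rv (RandomVariable, required): Delay distribution PMF (must sum to ~1.0).
  • noise (CountNoise, required): Noise model (PoissonNoise, NegativeBinomialNoise, etc.).
  • right_truncation_rv (RandomVariable | None, default None): Optional reporting-delay PMF for right-truncation adjustment. When provided (along with right_truncation_offset at sample time), predicted counts are scaled down for recent timepoints to account for incomplete reporting.
  • day_of_week_rv (RandomVariable | None, default None): Optional day-of-week multiplicative effect. Must sample to shape (7,) with non-negative values, where entry j is the multiplier for day-of-week j (0=Monday, 6=Sunday, ISO convention). When provided (along with first_day_dow at sample time), predicted counts are scaled by a periodic weekly pattern; values summing to 7.0 preserve weekly totals.
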
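The weekly scaling controlled by day_of_week_rv and first_day_dow can be sketched in plain Python, assuming the multiplier for timepoint t on the shared time axis is dow_effect[(first_day_dow + t) % 7]. apply_day_of_week_sketch is hypothetical; the library's private _apply_day_of_week is not documented here.

```python
from collections.abc import Sequence


def apply_day_of_week_sketch(
    predicted: Sequence[float], dow_effect: Sequence[float], first_day_dow: int
) -> list[float]:
    """Multiply predicted[t] by dow_effect[(first_day_dow + t) % 7].

    dow_effect[j] is the multiplier for ISO weekday j (0=Monday, 6=Sunday).
    NaN predictions (initialization period) stay NaN after multiplication.
    """
    if len(dow_effect) != 7:
        raise ValueError("dow_effect must have length 7")
    if not 0 <= first_day_dow <= 6:
        raise ValueError("first_day_dow must be in 0..6")
    return [p * dow_effect[(first_day_dow + t) % 7] for t, p in enumerate(predicted)]


# Illustrative multipliers that sum to 7.0, so weekly totals are preserved.
effect = [1.2, 1.1, 1.0, 1.0, 1.0, 0.8, 0.9]
weekly = apply_day_of_week_sketch([10.0] * 14, effect, first_day_dow=2)  # day 0 is a Wednesday
```

With a flat prediction of 10 per day, any seven consecutive days of weekly still total about 70, while individual days are shifted up or down by their multiplier.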
Source code in pyrenew/observation/count_observations.py
def __init__(
    self,
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
) -> None:
    """
    Initialize count observation base.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    ascertainment_rate_rv
        Ascertainment rate in [0, 1] (e.g., IHR, IER).
    delay_distribution_rv
        Delay distribution PMF (must sum to ~1.0).
    noise
        Noise model for count observations (Poisson, NegBin, etc.).
    right_truncation_rv
        Optional reporting delay PMF for right-truncation adjustment.
        When provided (along with ``right_truncation_offset`` at sample
        time), predicted counts are scaled down for recent timepoints
        to account for incomplete reporting.
    day_of_week_rv : RandomVariable | None
        Optional day-of-week multiplicative effect. Must sample to
        shape (7,) with non-negative values, where entry j is the
        multiplier for day-of-week j (0=Monday, 6=Sunday, ISO
        convention). An effect of 1.0 means no adjustment for that
        day. Values summing to 7.0 preserve weekly totals and keep
        the ascertainment rate interpretable; other sums rescale
        overall predicted counts. When provided (along with
        ``first_day_dow`` at sample time), predicted counts are
        scaled by a periodic weekly pattern.
    """
    super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv)
    self.ascertainment_rate_rv = ascertainment_rate_rv
    self.noise = noise
    self.right_truncation_rv = right_truncation_rv
    self.day_of_week_rv = day_of_week_rv

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/count_observations.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"Counts(name={self.name!r}, "
        f"ascertainment_rate_rv={self.ascertainment_rate_rv!r}, "
        f"delay_distribution_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r}, "
        f"right_truncation_rv={self.right_truncation_rv!r}, "
        f"day_of_week_rv={self.day_of_week_rv!r})"
    )

infection_resolution

infection_resolution() -> str

Return "aggregate" for aggregated observations.

Returns:

  • str: The string "aggregate".

Source code in pyrenew/observation/count_observations.py
def infection_resolution(self) -> str:
    """
    Return "aggregate" for aggregated observations.

    Returns
    -------
    str
        The string "aggregate".
    """
    return "aggregate"

sample

sample(
    infections: ArrayLike,
    obs: ArrayLike | None = None,
    right_truncation_offset: int | None = None,
    first_day_dow: int | None = None,
) -> ObservationSample

Sample aggregated counts.

Both infections and obs use a shared time axis [0, n_total) where n_total = n_init + n_days. NaN in obs marks unobserved timepoints (initialization period or missing data).

Parameters:

  • infections (ArrayLike, required): Aggregate infections from the infection process. Shape: (n_total,) where n_total = n_init + n_days.
  • obs (ArrayLike | None, default None): Observed counts on shared time axis. Shape: (n_total,). Use NaN for initialization period and any missing observations. None for prior predictive sampling.
  • right_truncation_offset (int | None, default None): If provided (and right_truncation_rv was set at construction), apply right-truncation adjustment to predicted counts.
  • first_day_dow (int | None, default None): Day of the week for the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when day_of_week_rv was set at construction. Use model.compute_first_day_dow(obs_start_dow) to convert from the day of the week of the first observation.

Returns:

  • ObservationSample: Named tuple with observed (sampled or conditioned counts, shape: (n_total,)) and predicted (predicted counts before noise, shape: (n_total,)).
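The idea behind the right-truncation adjustment (scaling recent predicted counts down because not all of them have been reported yet) can be sketched in plain Python. How pyrenew actually interprets right_truncation_offset is not documented here; the sketch assumes the last timepoint has had offset days to be reported, each earlier timepoint one more day, and that the reported fraction is the reporting-delay CDF at that lag. apply_right_truncation_sketch is hypothetical.

```python
from collections.abc import Sequence
from itertools import accumulate


def apply_right_truncation_sketch(
    predicted: Sequence[float], reporting_pmf: Sequence[float], offset: int = 0
) -> list[float]:
    """Scale predicted[t] by the fraction of events at t reported so far.

    Assumption (hypothetical, not pyrenew's definition): timepoint t has had
    (len(predicted) - 1 - t + offset) days to be reported, and the reported
    fraction is the reporting-delay CDF at that lag (1.0 beyond the PMF).
    """
    if offset < 0:
        raise ValueError("offset must be non-negative")
    cdf = list(accumulate(reporting_pmf))  # cdf[d] = P(reported within d days)
    n = len(predicted)
    scaled = []
    for t, value in enumerate(predicted):
        lag = (n - 1 - t) + offset
        fraction = cdf[lag] if lag < len(cdf) else 1.0
        scaled.append(value * fraction)
    return scaled
```

For a flat prediction of 10 and a reporting PMF of [0.5, 0.3, 0.2], the last two timepoints are scaled to about 8 and 5, while older timepoints are left unchanged.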

Source code in pyrenew/observation/count_observations.py
def sample(
    self,
    infections: ArrayLike,
    obs: ArrayLike | None = None,
    right_truncation_offset: int | None = None,
    first_day_dow: int | None = None,
) -> ObservationSample:
    """
    Sample aggregated counts.

    Both infections and obs use a shared time axis [0, n_total) where
    n_total = n_init + n_days. NaN in obs marks unobserved timepoints
    (initialization period or missing data).

    Parameters
    ----------
    infections
        Aggregate infections from the infection process.
        Shape: (n_total,) where n_total = n_init + n_days.
    obs
        Observed counts on shared time axis. Shape: (n_total,).
        Use NaN for initialization period and any missing observations.
        None for prior predictive sampling.
    right_truncation_offset
        If provided (and ``right_truncation_rv`` was set at construction),
        apply right-truncation adjustment to predicted counts.
    first_day_dow : int | None
        Day of the week for the first timepoint on the shared time
        axis (0=Monday, 6=Sunday, ISO convention). Required when
        ``day_of_week_rv`` was set at construction. Use
        ``model.compute_first_day_dow(obs_start_dow)`` to convert
        from the day of the week of the first observation.

    Returns
    -------
    ObservationSample
        Named tuple with `observed` (sampled/conditioned counts) and
        `predicted` (predicted counts before noise, shape: n_total).
    """
    predicted_counts = self._predicted_obs(infections)
    if self.day_of_week_rv is not None:
        if first_day_dow is None:
            raise ValueError(
                "first_day_dow is required when day_of_week_rv is set."
            )
        predicted_counts = self._apply_day_of_week(predicted_counts, first_day_dow)
    if self.right_truncation_rv is not None and right_truncation_offset is not None:
        predicted_counts = self._apply_right_truncation(
            predicted_counts, right_truncation_offset
        )
    self._deterministic("predicted", predicted_counts)

    # Compute mask: True where observation contributes to likelihood.
    # NaN in predictions (initialization period) or obs (missing data)
    # are excluded via mask.
    valid_pred = ~jnp.isnan(predicted_counts)
    if obs is not None:
        valid_obs = ~jnp.isnan(obs)
        mask = valid_pred & valid_obs
    else:
        mask = valid_pred

    # JAX evaluates log_prob for all array elements even when mask
    # excludes them from the likelihood sum. Replace NaN with safe values
    # to avoid NaN propagation in JAX's computation graph. These values
    # do not affect inference since mask=False excludes them.
    safe_predicted = jnp.where(jnp.isnan(predicted_counts), 1.0, predicted_counts)
    safe_obs = None
    if obs is not None:
        safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs)

    observed = self.noise.sample(
        name=self._sample_site_name("obs"),
        predicted=safe_predicted,
        obs=safe_obs,
        mask=mask,
    )

    return ObservationSample(observed=observed, predicted=predicted_counts)

validate_data

validate_data(
    n_total: int, n_subpops: int, obs: ArrayLike | None = None, **kwargs: object
) -> None

Validate aggregated count observation data.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations (unused for aggregate observations).
  • obs (ArrayLike | None, default None): Observed counts on shared time axis. Shape: (n_total,).
  • **kwargs (object): Additional keyword arguments (ignored).

Raises:

  • ValueError: If obs length doesn't match n_total.
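A plain-Python sketch of preparing and checking dense observations on the shared time axis. Both helpers are hypothetical: pad_to_shared_axis shows the NaN padding for the initialization period, and validate_obs_dense_sketch performs only the documented length check (the private _validate_obs_dense may check more).

```python
import math
from collections.abc import Sequence


def pad_to_shared_axis(post_init_obs: Sequence[float], n_init: int) -> list[float]:
    """Prepend NaN for the initialization period so obs[t] lines up with infections[t]."""
    return [math.nan] * n_init + list(post_init_obs)


def validate_obs_dense_sketch(obs: Sequence[float], n_total: int) -> None:
    """Raise ValueError if obs does not cover the full shared time axis.

    NaN entries are allowed (initialization period or missing data).
    """
    if len(obs) != n_total:
        raise ValueError(f"obs has length {len(obs)}, expected n_total={n_total}")


# Illustrative sizes: 3 initialization days followed by 4 observed days.
n_init, n_days = 3, 4
obs = pad_to_shared_axis([5, 7, 6, 8], n_init)
validate_obs_dense_sketch(obs, n_total=n_init + n_days)  # passes: length 7
```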

Source code in pyrenew/observation/count_observations.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> None:
    """
    Validate aggregated count observation data.

    Parameters
    ----------
    n_total
        Total number of time steps (n_init + n_days_post_init).
    n_subpops
        Number of subpopulations (unused for aggregate observations).
    obs
        Observed counts on shared time axis. Shape: (n_total,).
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If obs length doesn't match n_total.
    """
    if obs is not None:
        self._validate_obs_dense(obs, n_total)

CountsBySubpop

CountsBySubpop(
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
)

Bases: _CountBase

Subpopulation-level count observation.

Maps subpopulation-level infections to counts by applying an ascertainment rate and convolving with a delay distribution, then passing the result through a composable noise model.

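Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • ascertainment_rate_rv (RandomVariable, required): Ascertainment rate in [0, 1].
  • delay_distribution_rv (RandomVariable, required): Delay distribution PMF (must sum to ~1.0).
  • noise (CountNoise, required): Noise model (PoissonNoise, NegativeBinomialNoise, etc.).
  • right_truncation_rv (RandomVariable | None, default None): Optional reporting-delay PMF for right-truncation adjustment, applied when right_truncation_offset is also passed at sample time.
  • day_of_week_rv (RandomVariable | None, default None): Optional day-of-week multiplicative effect of shape (7,), indexed 0=Monday to 6=Sunday (ISO convention), applied when first_day_dow is also passed at sample time.
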
Source code in pyrenew/observation/count_observations.py
def __init__(
    self,
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
    right_truncation_rv: RandomVariable | None = None,
    day_of_week_rv: RandomVariable | None = None,
) -> None:
    """
    Initialize count observation base.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    ascertainment_rate_rv
        Ascertainment rate in [0, 1] (e.g., IHR, IER).
    delay_distribution_rv
        Delay distribution PMF (must sum to ~1.0).
    noise
        Noise model for count observations (Poisson, NegBin, etc.).
    right_truncation_rv
        Optional reporting delay PMF for right-truncation adjustment.
        When provided (along with ``right_truncation_offset`` at sample
        time), predicted counts are scaled down for recent timepoints
        to account for incomplete reporting.
    day_of_week_rv : RandomVariable | None
        Optional day-of-week multiplicative effect. Must sample to
        shape (7,) with non-negative values, where entry j is the
        multiplier for day-of-week j (0=Monday, 6=Sunday, ISO
        convention). An effect of 1.0 means no adjustment for that
        day. Values summing to 7.0 preserve weekly totals and keep
        the ascertainment rate interpretable; other sums rescale
        overall predicted counts. When provided (along with
        ``first_day_dow`` at sample time), predicted counts are
        scaled by a periodic weekly pattern.
    """
    super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv)
    self.ascertainment_rate_rv = ascertainment_rate_rv
    self.noise = noise
    self.right_truncation_rv = right_truncation_rv
    self.day_of_week_rv = day_of_week_rv

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/count_observations.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"CountsBySubpop(name={self.name!r}, "
        f"ascertainment_rate_rv={self.ascertainment_rate_rv!r}, "
        f"delay_distribution_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r}, "
        f"right_truncation_rv={self.right_truncation_rv!r}, "
        f"day_of_week_rv={self.day_of_week_rv!r})"
    )

infection_resolution

infection_resolution() -> str

Return "subpop" for subpopulation-level observations.

Returns:

  • str: The string "subpop".

Source code in pyrenew/observation/count_observations.py
def infection_resolution(self) -> str:
    """
    Return "subpop" for subpopulation-level observations.

    Returns
    -------
    str
        The string "subpop".
    """
    return "subpop"

sample

sample(
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    obs: ArrayLike | None = None,
    right_truncation_offset: int | None = None,
    first_day_dow: int | None = None,
) -> ObservationSample

Sample subpopulation-level counts.

Times are on the shared time axis [0, n_total) where n_total = n_init + n_days. This method performs direct indexing without any offset adjustment.

Parameters:

  • infections (ArrayLike, required): Subpopulation-level infections from the infection process. Shape: (n_total, n_subpops).
  • times (ArrayLike, required): Day index for each observation on the shared time axis. Must be in range [0, n_total). Shape: (n_obs,).
  • subpop_indices (ArrayLike, required): Subpopulation index for each observation (0-indexed). Shape: (n_obs,).
  • obs (ArrayLike | None, default None): Observed counts. Shape: (n_obs,). None for prior sampling.
  • right_truncation_offset (int | None, default None): If provided (and right_truncation_rv was set at construction), apply right-truncation adjustment to predicted counts.
  • first_day_dow (int | None, default None): Day of the week for the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when day_of_week_rv was set at construction. Use model.compute_first_day_dow(obs_start_dow) to convert from the day of the week of the first observation.

Returns:

  • ObservationSample: Named tuple with observed (sampled or conditioned counts, shape: (n_obs,)) and predicted (predicted counts before noise, shape: (n_total, n_subpops)).

Source code in pyrenew/observation/count_observations.py
def sample(
    self,
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    obs: ArrayLike | None = None,
    right_truncation_offset: int | None = None,
    first_day_dow: int | None = None,
) -> ObservationSample:
    """
    Sample subpopulation-level counts.

    Times are on the shared time axis [0, n_total) where
    n_total = n_init + n_days. This method performs direct indexing
    without any offset adjustment.

    Parameters
    ----------
    infections
        Subpopulation-level infections from the infection process.
        Shape: (n_total, n_subpops)
    times
        Day index for each observation on the shared time axis.
        Must be in range [0, n_total). Shape: (n_obs,)
    subpop_indices
        Subpopulation index for each observation (0-indexed).
        Shape: (n_obs,)
    obs
        Observed counts (n_obs,), or None for prior sampling.
    right_truncation_offset
        If provided (and ``right_truncation_rv`` was set at construction),
        apply right-truncation adjustment to predicted counts.
    first_day_dow : int | None
        Day of the week for the first timepoint on the shared time
        axis (0=Monday, 6=Sunday, ISO convention). Required when
        ``day_of_week_rv`` was set at construction. Use
        ``model.compute_first_day_dow(obs_start_dow)`` to convert
        from the day of the week of the first observation.

    Returns
    -------
    ObservationSample
        Named tuple with `observed` (sampled/conditioned counts) and
        `predicted` (predicted counts before noise, shape: n_total x n_subpops).
    """
    predicted_counts = self._predicted_obs(infections)
    if self.day_of_week_rv is not None:
        if first_day_dow is None:
            raise ValueError(
                "first_day_dow is required when day_of_week_rv is set."
            )
        predicted_counts = self._apply_day_of_week(predicted_counts, first_day_dow)
    if self.right_truncation_rv is not None and right_truncation_offset is not None:
        predicted_counts = self._apply_right_truncation(
            predicted_counts, right_truncation_offset
        )
    self._deterministic("predicted", predicted_counts)

    # Direct indexing on shared time axis - no offset needed
    predicted_obs = predicted_counts[times, subpop_indices]

    observed = self.noise.sample(
        name=self._sample_site_name("obs"),
        predicted=predicted_obs,
        obs=obs,
    )

    return ObservationSample(observed=observed, predicted=predicted_counts)

validate_data

validate_data(
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> None

Validate subpopulation-level count observation data.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations.
  • times (ArrayLike | None, default None): Day index for each observation on the shared time axis.
  • subpop_indices (ArrayLike | None, default None): Subpopulation index for each observation (0-indexed).
  • obs (ArrayLike | None, default None): Observed counts (n_obs,).
  • **kwargs (object): Additional keyword arguments (ignored).

Raises:

  • ValueError: If times or subpop_indices are out of bounds, or if obs and times have mismatched lengths.

Source code in pyrenew/observation/count_observations.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> None:
    """
    Validate subpopulation-level count observation data.

    Parameters
    ----------
    n_total
        Total number of time steps (n_init + n_days_post_init).
    n_subpops
        Number of subpopulations.
    times
        Day index for each observation on the shared time axis.
    subpop_indices
        Subpopulation index for each observation (0-indexed).
    obs
        Observed counts (n_obs,).
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If times or subpop_indices are out of bounds, or if
        obs and times have mismatched lengths.
    """
    if times is not None:
        self._validate_times(times, n_total)
        if obs is not None:
            self._validate_obs_times_shape(obs, times)
    if subpop_indices is not None:
        self._validate_subpop_indices(subpop_indices, n_subpops)
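The private helpers called above (`_validate_times`, `_validate_obs_times_shape`, `_validate_subpop_indices`) are not shown on this page. A minimal sketch of the checks the docstring describes follows. `check_index_array` and `check_obs_times_shape` are hypothetical standalone functions written for illustration, not pyrenew's actual code.

```python
import numpy as np


def check_index_array(indices, upper: int, label: str) -> None:
    """Raise ValueError unless every index lies in [0, upper).

    Hypothetical stand-in for the bounds checks described above.
    """
    indices = np.asarray(indices)
    if indices.size and (indices.min() < 0 or indices.max() >= upper):
        raise ValueError(
            f"{label} must be in [0, {upper}), "
            f"got range [{indices.min()}, {indices.max()}]"
        )


def check_obs_times_shape(obs, times) -> None:
    """Raise ValueError if obs and times have different lengths."""
    if np.shape(obs)[0] != np.shape(times)[0]:
        raise ValueError(
            f"obs has length {np.shape(obs)[0]} but times has length "
            f"{np.shape(times)[0]}"
        )


# Example: 30 days on the shared axis, 3 subpopulations; both calls pass.
check_index_array([0, 5, 29], upper=30, label="times")
check_index_array([0, 2, 1], upper=3, label="subpop_indices")
```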

HierarchicalNormalNoise

HierarchicalNormalNoise(
    sensor_mode_rv: RandomVariable, sensor_sd_rv: RandomVariable
)

Bases: MeasurementNoise

Normal noise with hierarchical sensor-level effects.

Observation model: obs ~ Normal(predicted + sensor_mode[s], sensor_sd[s]), where s is the sensor of each observation and sensor_mode and sensor_sd are sampled once per sensor.

Parameters:

  • sensor_mode_rv (RandomVariable, required): Prior for sensor-level modes. Must implement sample(n_groups=...) -> ArrayLike.
  • sensor_sd_rv (RandomVariable, required): Prior for sensor-level SDs (must be > 0). Must implement sample(n_groups=...) -> ArrayLike.

Notes

Use VectorizedRV to wrap simple RVs that lack this interface.

Initialize hierarchical Normal noise.

Source code in pyrenew/observation/noise.py
def __init__(
    self,
    sensor_mode_rv: RandomVariable,
    sensor_sd_rv: RandomVariable,
) -> None:
    """
    Initialize hierarchical Normal noise.

    Parameters
    ----------
    sensor_mode_rv
        Prior for sensor-level modes.
        Must implement ``sample(n_groups=...) -> ArrayLike``.
    sensor_sd_rv
        Prior for sensor-level SDs (must be > 0).
        Must implement ``sample(n_groups=...) -> ArrayLike``.
    """
    self.sensor_mode_rv = sensor_mode_rv
    self.sensor_sd_rv = sensor_sd_rv

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"HierarchicalNormalNoise("
        f"sensor_mode_rv={self.sensor_mode_rv!r}, "
        f"sensor_sd_rv={self.sensor_sd_rv!r})"
    )

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    *,
    sensor_indices: ArrayLike,
    n_sensors: int,
) -> ArrayLike

Sample from Normal distribution with sensor-level hierarchical effects.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted measurement values. Shape: (n_obs,)
  • obs (ArrayLike | None, default None): Observed measurements for conditioning. Shape: (n_obs,)
  • sensor_indices (ArrayLike, required, keyword-only): Sensor index for each observation (0-indexed). Shape: (n_obs,)
  • n_sensors (int, required, keyword-only): Total number of sensors.

Returns:

  • ArrayLike: Normally distributed measurements with hierarchical sensor effects. Shape: (n_obs,)

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    *,
    sensor_indices: ArrayLike,
    n_sensors: int,
) -> ArrayLike:
    """
    Sample from Normal distribution with sensor-level hierarchical effects.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted measurement values.
        Shape: (n_obs,)
    obs
        Observed measurements for conditioning.
        Shape: (n_obs,)
    sensor_indices
        Sensor index for each observation (0-indexed).
        Shape: (n_obs,)
    n_sensors
        Total number of sensors.

    Returns
    -------
    ArrayLike
        Normal distributed measurements with hierarchical sensor effects.
        Shape: (n_obs,)
    """
    sensor_mode = self.sensor_mode_rv(n_groups=n_sensors)
    sensor_sd = self.sensor_sd_rv(n_groups=n_sensors)

    loc = predicted + sensor_mode[sensor_indices]
    scale = sensor_sd[sensor_indices]

    return numpyro.sample(name, dist.Normal(loc=loc, scale=scale), obs=obs)
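To make the indexing in `sample` concrete: each sensor gets one mode and one SD, and every observation picks up the pair that belongs to its sensor. The NumPy sketch below mirrors that arithmetic for forward simulation only. The Normal(0, 0.5) and half-normal(0.3) priors are illustrative assumptions, and `simulate_hierarchical_normal` is a hypothetical function. pyrenew itself draws these values through `sensor_mode_rv` and `sensor_sd_rv` inside a numpyro model.

```python
import numpy as np


def simulate_hierarchical_normal(
    predicted: np.ndarray,
    sensor_indices: np.ndarray,
    n_sensors: int,
    rng: np.random.Generator,
) -> np.ndarray:
    """Forward-simulate obs ~ Normal(predicted + mode[s], sd[s]).

    Illustrative priors (assumptions): mode ~ Normal(0, 0.5) and
    sd ~ HalfNormal(0.3), one draw of each per sensor.
    """
    sensor_mode = rng.normal(0.0, 0.5, size=n_sensors)
    sensor_sd = np.abs(rng.normal(0.0, 0.3, size=n_sensors))
    # Gather each observation's sensor-level parameters by index.
    loc = predicted + sensor_mode[sensor_indices]
    scale = sensor_sd[sensor_indices]
    return rng.normal(loc, scale)


rng = np.random.default_rng(0)
predicted = np.array([1.0, 1.2, 0.9, 1.1, 1.3])
sensor_indices = np.array([0, 0, 1, 2, 2])  # 3 sensors, 5 observations
simulated = simulate_hierarchical_normal(predicted, sensor_indices, 3, rng)
```

Observations that share a sensor share the same offset and spread, which is what lets the model separate sensor bias from the underlying signal.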

validate

validate() -> None

Validate noise parameters.

Notes

Full validation requires n_groups, which is only available during sample().

Source code in pyrenew/observation/noise.py
def validate(self) -> None:
    """
    Validate noise parameters.

    Notes
    -----
    Full validation requires n_groups, which is only available during sample().
    """
    pass

MeasurementNoise

Bases: ABC

Abstract base for continuous measurement noise models.

Defines how continuous observations are distributed around predicted values.

sample abstractmethod

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ArrayLike

Sample continuous observations given predicted values.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted measurement values.
  • obs (ArrayLike | None, default None): Observed measurements for conditioning, or None for prior sampling.
  • **kwargs (object): Additional context (e.g., sensor indices).

Returns:

  • ArrayLike: Sampled or conditioned measurements, same shape as predicted.

Source code in pyrenew/observation/noise.py
@abstractmethod
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ArrayLike:
    """
    Sample continuous observations given predicted values.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted measurement values.
    obs
        Observed measurements for conditioning, or None for prior sampling.
    **kwargs
        Additional context (e.g., sensor indices).

    Returns
    -------
    ArrayLike
        Sampled or conditioned measurements, same shape as predicted.
    """
    pass  # pragma: no cover

validate abstractmethod

validate() -> None

Validate noise model parameters.

Raises:

  • ValueError: If parameters are invalid.

Source code in pyrenew/observation/noise.py
@abstractmethod
def validate(self) -> None:
    """
    Validate noise model parameters.

    Raises
    ------
    ValueError
        If parameters are invalid.
    """
    pass  # pragma: no cover

Measurements

Measurements(
    name: str, temporal_pmf_rv: RandomVariable, noise: MeasurementNoise
)

Bases: BaseObservationProcess

Abstract base for continuous measurement observations.

Subclasses implement signal-specific transformations from infections to predicted measurement values, then add measurement noise.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • temporal_pmf_rv (RandomVariable, required): Temporal distribution PMF (e.g., shedding kinetics for wastewater).
  • noise (MeasurementNoise, required): Noise model for continuous measurements (e.g., HierarchicalNormalNoise).

Notes

Subclasses must implement _predicted_obs() according to their specific signal processing (e.g., wastewater shedding kinetics, dilution factors, etc.).

See Also

pyrenew.observation.noise.HierarchicalNormalNoise : Suitable noise model for sensor-level measurements
pyrenew.observation.base.BaseObservationProcess : Parent class with common observation utilities

Initialize measurement observation base.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • temporal_pmf_rv (RandomVariable, required): Temporal distribution PMF (e.g., shedding kinetics).
  • noise (MeasurementNoise, required): Noise model (e.g., HierarchicalNormalNoise with sensor effects).

Source code in pyrenew/observation/measurements.py
def __init__(
    self,
    name: str,
    temporal_pmf_rv: RandomVariable,
    noise: MeasurementNoise,
) -> None:
    """
    Initialize measurement observation base.

    Parameters
    ----------
    name
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    temporal_pmf_rv
        Temporal distribution PMF (e.g., shedding kinetics).
    noise
        Noise model (e.g., HierarchicalNormalNoise with sensor effects).
    """
    super().__init__(name=name, temporal_pmf_rv=temporal_pmf_rv)
    self.noise = noise

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/measurements.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"{self.__class__.__name__}(name={self.name!r}, "
        f"temporal_pmf_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r})"
    )

infection_resolution

infection_resolution() -> str

Return "subpop" for measurement observations.

Measurement observations require subpopulation-level infections because each measurement corresponds to a specific catchment area.

Returns:

  • str: "subpop"

Source code in pyrenew/observation/measurements.py
def infection_resolution(self) -> str:
    """
    Return "subpop" for measurement observations.

    Measurement observations require subpopulation-level infections
    because each measurement corresponds to a specific catchment area.

    Returns
    -------
    str
        ``"subpop"``
    """
    return "subpop"

lookback_days

lookback_days() -> int

Return required lookback days for this observation.

Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF of length L covers lags 0 to L-1, requiring L-1 initialization points.

Returns:

  • int: Length of temporal PMF minus 1.

Source code in pyrenew/observation/measurements.py
def lookback_days(self) -> int:
    """
    Return required lookback days for this observation.

    Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF
    of length L covers lags 0 to L-1, requiring L-1 initialization points.

    Returns
    -------
    int
        Length of temporal PMF minus 1.
    """
    return len(self.temporal_pmf_rv()) - 1
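As a worked example of the lookback rule: with a PMF of length L, the predicted value on day t is a weighted sum of infections on days t, t-1, ..., t-(L-1), so the first L-1 days of the axis lack full history. The NumPy sketch below assumes a plain discrete convolution rather than pyrenew's `compute_delay_ascertained_incidence`, and marks those incomplete days as NaN. `delayed_signal` is a hypothetical helper.

```python
import numpy as np


def delayed_signal(infections: np.ndarray, pmf: np.ndarray) -> np.ndarray:
    """Convolve a 1-D infection series with a 0-indexed delay PMF.

    Entry t is sum_{d=0}^{L-1} pmf[d] * infections[t - d]; the first
    L - 1 entries lack full history and are set to NaN.
    """
    n, lag = len(infections), len(pmf)
    out = np.convolve(infections, pmf)[:n].astype(float)
    out[: lag - 1] = np.nan
    return out


pmf = np.array([0.1, 0.4, 0.3, 0.2])  # length L = 4, sums to 1
lookback = len(pmf) - 1  # 3 initialization points needed
infections = np.full(10, 100.0)
signal = delayed_signal(infections, pmf)
```

With constant infections and a PMF that sums to 1, every fully supported day reproduces the infection level, which makes the boundary at index L-1 easy to see.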

sample

sample(
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    sensor_indices: ArrayLike,
    n_sensors: int,
    obs: ArrayLike | None = None,
) -> ObservationSample

Sample measurements from observed sensors.

Times are on the shared time axis [0, n_total) where n_total = n_init + n_days. This method performs direct indexing without any offset adjustment.

Transforms infections to predicted values via signal-specific processing (_predicted_obs), then applies noise model.

Parameters:

  • infections (ArrayLike, required): Infections from the infection process. Shape: (n_total, n_subpops)
  • times (ArrayLike, required): Day index for each observation on the shared time axis. Must be in range [0, n_total). Shape: (n_obs,)
  • subpop_indices (ArrayLike, required): Subpopulation index for each observation (0-indexed). Shape: (n_obs,)
  • sensor_indices (ArrayLike, required): Sensor index for each observation (0-indexed). Shape: (n_obs,)
  • n_sensors (int, required): Total number of measurement sensors.
  • obs (ArrayLike | None, default None): Observed measurements (n_obs,), or None for prior sampling.

Returns:

  • ObservationSample: Named tuple with observed (sampled/conditioned measurements) and predicted (predicted values before noise, shape: n_total x n_subpops).

Source code in pyrenew/observation/measurements.py
def sample(
    self,
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    sensor_indices: ArrayLike,
    n_sensors: int,
    obs: ArrayLike | None = None,
) -> ObservationSample:
    """
    Sample measurements from observed sensors.

    Times are on the shared time axis [0, n_total) where
    n_total = n_init + n_days. This method performs direct indexing
    without any offset adjustment.

    Transforms infections to predicted values via signal-specific processing
    (``_predicted_obs``), then applies noise model.

    Parameters
    ----------
    infections
        Infections from the infection process.
        Shape: (n_total, n_subpops)
    times
        Day index for each observation on the shared time axis.
        Must be in range [0, n_total). Shape: (n_obs,)
    subpop_indices
        Subpopulation index for each observation (0-indexed).
        Shape: (n_obs,)
    sensor_indices
        Sensor index for each observation (0-indexed).
        Shape: (n_obs,)
    n_sensors
        Total number of measurement sensors.
    obs
        Observed measurements (n_obs,), or None for prior sampling.

    Returns
    -------
    ObservationSample
        Named tuple with `observed` (sampled/conditioned measurements) and
        `predicted` (predicted values before noise, shape: n_total x n_subpops).
    """
    predicted_values = self._predicted_obs(infections)
    self._deterministic("predicted", predicted_values)

    # Direct indexing on shared time axis - no offset needed
    predicted_obs = predicted_values[times, subpop_indices]

    observed = self.noise.sample(
        name=self._sample_site_name("obs"),
        predicted=predicted_obs,
        obs=obs,
        sensor_indices=sensor_indices,
        n_sensors=n_sensors,
    )

    return ObservationSample(observed=observed, predicted=predicted_values)
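`sample` expects flat, aligned index arrays rather than a table. The sketch below builds them from long-format records. The record layout, and the assumption that observation day d (counted from the end of initialization) maps to shared-axis index n_init + d, are illustrative and not taken from pyrenew. The final gather uses the same fancy indexing as `predicted_values[times, subpop_indices]` above.

```python
import numpy as np

# Hypothetical long-format records: (day after init, subpop, sensor, value).
records = [
    (0, 0, 0, 1.10),
    (0, 1, 1, 0.95),
    (3, 0, 0, 1.30),
    (3, 0, 2, 1.25),  # a second sensor in the same subpopulation
]
n_init = 7  # assumed length of the initialization period

days, subpops, sensors, values = map(np.array, zip(*records))
times = n_init + days  # position on the shared axis [0, n_total)
subpop_indices = subpops
sensor_indices = sensors
obs = values.astype(float)
n_sensors = int(sensor_indices.max()) + 1

# Gather one predicted value per observation from a (n_total, n_subpops)
# matrix; here the matrix is a toy arange so the result is easy to check.
n_total, n_subpops = n_init + 10, 2
predicted_values = np.arange(n_total * n_subpops, dtype=float).reshape(
    n_total, n_subpops
)
predicted_obs = predicted_values[times, subpop_indices]
```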

validate_data

validate_data(
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    sensor_indices: ArrayLike | None = None,
    n_sensors: int | None = None,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> None

Validate measurement observation data.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations.
  • times (ArrayLike | None, default None): Day index for each observation on the shared time axis.
  • subpop_indices (ArrayLike | None, default None): Subpopulation index for each observation (0-indexed).
  • sensor_indices (ArrayLike | None, default None): Sensor index for each observation (0-indexed). Checked only when n_sensors is also provided.
  • n_sensors (int | None, default None): Total number of measurement sensors.
  • obs (ArrayLike | None, default None): Observed measurements (n_obs,).
  • **kwargs (object): Additional keyword arguments (ignored).

Raises:

  • ValueError: If times, subpop_indices, or sensor_indices are out of bounds, or if obs and times have mismatched lengths.

Source code in pyrenew/observation/measurements.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    sensor_indices: ArrayLike | None = None,
    n_sensors: int | None = None,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> None:
    """
    Validate measurement observation data.

    Parameters
    ----------
    n_total
        Total number of time steps (n_init + n_days_post_init).
    n_subpops
        Number of subpopulations.
    times
        Day index for each observation on the shared time axis.
    subpop_indices
        Subpopulation index for each observation (0-indexed).
    sensor_indices
        Sensor index for each observation (0-indexed).
    n_sensors
        Total number of measurement sensors.
    obs
        Observed measurements (n_obs,).
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If times, subpop_indices, or sensor_indices are out of bounds,
        or if obs and times have mismatched lengths.
    """
    if times is not None:
        self._validate_times(times, n_total)
        if obs is not None:
            self._validate_obs_times_shape(obs, times)
    if subpop_indices is not None:
        self._validate_subpop_indices(subpop_indices, n_subpops)
    if sensor_indices is not None and n_sensors is not None:
        self._validate_index_array(sensor_indices, n_sensors, "sensor_indices")

NegativeBinomialNoise

NegativeBinomialNoise(concentration_rv: RandomVariable)

Bases: CountNoise

Negative Binomial noise for overdispersed counts (variance > mean).

Uses NB2 parameterization. Higher concentration reduces overdispersion.

Parameters:

  • concentration_rv (RandomVariable, required): Concentration parameter (must be > 0). Higher values reduce overdispersion.

Notes

The NB2 parameterization has variance = mean + mean^2 / concentration. As concentration -> infinity, this approaches Poisson.
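A quick numerical check of the NB2 variance formula, using the gamma-Poisson mixture representation of the negative binomial. This is a NumPy-only sketch independent of pyrenew. `nb2_variance` and `sample_nb2` are hypothetical helpers, and the mean and concentration values are illustrative.

```python
import numpy as np


def nb2_variance(mean: float, concentration: float) -> float:
    """Variance of an NB2 distribution: mean + mean**2 / concentration."""
    return mean + mean**2 / concentration


def sample_nb2(mean, concentration, size, rng):
    """Draw NB2 counts as a gamma-Poisson mixture.

    rate ~ Gamma(shape=concentration, scale=mean / concentration) has the
    requested mean; Poisson draws around it add the extra variance.
    """
    rate = rng.gamma(shape=concentration, scale=mean / concentration, size=size)
    return rng.poisson(rate)


rng = np.random.default_rng(42)
draws = sample_nb2(mean=10.0, concentration=2.0, size=200_000, rng=rng)
# Theory: 10 + 10**2 / 2 = 60, far above the Poisson variance of 10.
print(draws.mean(), draws.var(), nb2_variance(10.0, 2.0))
print(nb2_variance(10.0, 1e6))  # 10.0001: large concentration is near-Poisson
```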

Initialize Negative Binomial noise.

Source code in pyrenew/observation/noise.py
def __init__(self, concentration_rv: RandomVariable) -> None:
    """
    Initialize Negative Binomial noise.

    Parameters
    ----------
    concentration_rv
        Concentration parameter (must be > 0).
        Higher values reduce overdispersion.
    """
    self.concentration_rv = concentration_rv

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"NegativeBinomialNoise(concentration_rv={self.concentration_rv!r})"

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample from Negative Binomial distribution.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted count values.
  • obs (ArrayLike | None, default None): Observed counts for conditioning.
  • mask (ArrayLike | None, default None): Boolean mask indicating which observations to include in the likelihood. If None, all observations are included.

Returns:

  • ArrayLike: Negative Binomial-distributed counts.

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample from Negative Binomial distribution.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted count values.
    obs
        Observed counts for conditioning.
    mask
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included.

    Returns
    -------
    ArrayLike
        Negative Binomial-distributed counts.
    """
    concentration = self.concentration_rv()
    with numpyro.handlers.mask(mask=True if mask is None else mask):
        return numpyro.sample(
            name,
            dist.NegativeBinomial2(
                mean=predicted + _EPSILON,
                concentration=concentration,
            ),
            obs=obs,
        )

validate

validate() -> None

Validate concentration is positive.

Raises:

  • ValueError: If concentration <= 0.

Source code in pyrenew/observation/noise.py
def validate(self) -> None:
    """
    Validate concentration is positive.

    Raises
    ------
    ValueError
        If concentration <= 0.
    """
    concentration = self.concentration_rv()
    if jnp.any(concentration <= 0):
        raise ValueError(
            f"NegativeBinomialNoise: concentration must be positive, "
            f"got {float(concentration)}"
        )

NegativeBinomialObservation

NegativeBinomialObservation(
    name: str, concentration_rv: RandomVariable, eps: float = 1e-10
)

Bases: RandomVariable

Negative Binomial observation.

Default constructor.

Parameters:

  • name (str, required): Name for the numpyro variable.
  • concentration_rv (RandomVariable, required): Random variable from which to sample the positive concentration parameter of the negative binomial. This parameter is sometimes called k, phi, or the "dispersion" or "overdispersion" parameter, even though larger values make the distribution more Poisson-like and smaller values imply greater dispersion.
  • eps (float, default 1e-10): Small value added to the predicted mean to prevent numerical instability.

Returns:

  • None

Source code in pyrenew/observation/negativebinomial.py
def __init__(
    self,
    name: str,
    concentration_rv: RandomVariable,
    eps: float = 1e-10,
) -> None:
    """
    Default constructor

    Parameters
    ----------
    name
        Name for the numpyro variable.
    concentration_rv
        Random variable from which to sample the positive concentration
        parameter of the negative binomial. This parameter is sometimes
        called k, phi, or the "dispersion" or "overdispersion" parameter,
        despite the fact that larger values imply that the distribution
        becomes more Poissonian, while smaller ones imply a greater degree
        of dispersion.
    eps
        Small value to add to the predicted mean to prevent numerical
        instability. Defaults to 1e-10.

    Returns
    -------
    None
    """

    NegativeBinomialObservation.validate(concentration_rv)

    super().__init__(name=name)
    self.concentration_rv = concentration_rv
    self.eps = eps

sample

sample(
    mu: ArrayLike, obs: ArrayLike | None = None, **kwargs: object
) -> ArrayLike

Sample from the negative binomial distribution.

Parameters:

  • mu (ArrayLike, required): Mean parameter of the negative binomial distribution.
  • obs (ArrayLike | None, default None): Observed data, or None for prior sampling.
  • **kwargs (object): Additional keyword arguments. The implementation below accepts them but does not forward them to its internal sample calls.

Returns:

  • ArrayLike: Sampled counts, or obs itself when conditioning on observed data.

Source code in pyrenew/observation/negativebinomial.py
def sample(
    self,
    mu: ArrayLike,
    obs: ArrayLike | None = None,
    **kwargs: object,
) -> ArrayLike:
    """
    Sample from the negative binomial distribution

    Parameters
    ----------
    mu
        Mean parameter of the negative binomial distribution.
    obs
        Observed data, by default None.
    **kwargs
        Additional keyword arguments passed through to internal sample calls, should there be any.

    Returns
    -------
    ArrayLike
    """
    concentration = self.concentration_rv.sample()

    negative_binomial_sample = numpyro.sample(
        name=self.name,
        fn=dist.NegativeBinomial2(
            mean=mu + self.eps,
            concentration=concentration,
        ),
        obs=obs,
    )

    return negative_binomial_sample

validate staticmethod

validate(concentration_rv: RandomVariable) -> None

Check that concentration_rv is a RandomVariable.

Parameters:

Name Type Description Default
concentration_rv RandomVariable

RandomVariable from which to sample the positive concentration parameter of the negative binomial.

required

Returns:

  • None

Source code in pyrenew/observation/negativebinomial.py
@staticmethod
def validate(concentration_rv: RandomVariable) -> None:
    """
    Check that the concentration_rv is actually a RandomVariable

    Parameters
    ----------
    concentration_rv
        RandomVariable from which to sample the positive concentration
        parameter of the negative binomial.

    Returns
    -------
    None
    """
    assert isinstance(concentration_rv, RandomVariable)
    return None

ObservationSample

Bases: NamedTuple

Return type for observation process sample() methods.

Attributes:

  • observed (ArrayLike): Sampled or conditioned observations. Shape depends on the observation process and indexing.
  • predicted (ArrayLike): Predicted values before noise is applied. Useful for diagnostics and posterior predictive checks.
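Because `ObservationSample` is a `NamedTuple`, callers can use either attribute access or tuple unpacking. The stand-in below reproduces only the two documented fields so that the example runs without pyrenew. `ObservationSampleStandIn` is hypothetical; in real code you would work with the object returned by `sample()`.

```python
from typing import NamedTuple

import numpy as np


class ObservationSampleStandIn(NamedTuple):
    """Stand-in mirroring the documented fields of ObservationSample."""

    observed: np.ndarray
    predicted: np.ndarray


result = ObservationSampleStandIn(
    observed=np.array([4, 7]),  # shape (n_obs,)
    predicted=np.array([[3.5, 6.8], [4.1, 7.2], [3.9, 7.0]]),  # (n_total, n_subpops)
)
observed, predicted = result  # tuple unpacking
same_array = result.observed  # attribute access returns the same object
```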

PoissonNoise

PoissonNoise()

Bases: CountNoise

Poisson noise for equidispersed counts (variance = mean).

Initialize Poisson noise (no parameters).

Source code in pyrenew/observation/noise.py
def __init__(self) -> None:
    """Initialize Poisson noise (no parameters)."""
    pass

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return "PoissonNoise()"

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample from Poisson distribution.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted count values.
  • obs (ArrayLike | None, default None): Observed counts for conditioning.
  • mask (ArrayLike | None, default None): Boolean mask indicating which observations to include in the likelihood. If None, all observations are included.

Returns:

  • ArrayLike: Poisson-distributed counts.

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample from Poisson distribution.

    Parameters
    ----------
    name
        Numpyro sample site name.
    predicted
        Predicted count values.
    obs
        Observed counts for conditioning.
    mask
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included.

    Returns
    -------
    ArrayLike
        Poisson-distributed counts.
    """
    with numpyro.handlers.mask(mask=True if mask is None else mask):
        return numpyro.sample(
            name,
            dist.Poisson(rate=predicted + _EPSILON),
            obs=obs,
        )
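The `mask` argument works through `numpyro.handlers.mask`, which removes masked-out factors from the model's log density. In the NumPy sketch below (an illustration, not pyrenew or numpyro code), masked entries simply contribute zero to the summed Poisson log-likelihood. The sketch also shows why a small epsilon is added to the rate: without it, a predicted value of exactly 0 makes log(rate) equal to -inf. The `EPSILON` value is an assumption, since pyrenew's `_EPSILON` is not shown on this page, and `masked_poisson_loglik` is a hypothetical helper.

```python
import math

import numpy as np

EPSILON = 1e-10  # assumed value; pyrenew's _EPSILON is not shown here


def masked_poisson_loglik(obs, predicted, mask=None) -> float:
    """Sum of Poisson log-pmf terms over observations where mask is True.

    log p(y | lam) = y * log(lam) - lam - log(y!)
    """
    obs = np.asarray(obs, dtype=float)
    rate = np.asarray(predicted, dtype=float) + EPSILON
    log_fact = np.array([math.lgamma(y + 1.0) for y in obs])
    terms = obs * np.log(rate) - rate - log_fact
    if mask is None:
        mask = np.ones_like(obs, dtype=bool)
    # Masked-out observations contribute nothing to the likelihood.
    return float(np.sum(np.where(mask, terms, 0.0)))


obs = [3, 0, 5, 2]
predicted = [2.5, 0.0, 4.0, 10.0]  # a zero prediction stays finite via EPSILON
mask = [True, True, True, False]  # e.g. exclude a day with incomplete data
full = masked_poisson_loglik(obs, predicted)
partial = masked_poisson_loglik(obs, predicted, mask)
```

The difference between `full` and `partial` is exactly the log-likelihood term of the masked observation.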

validate

validate() -> None

Validate Poisson noise (always valid).

Source code in pyrenew/observation/noise.py
def validate(self) -> None:
    """Validate Poisson noise (always valid)."""
    pass

VectorizedRV

VectorizedRV(name: str, rv: RandomVariable)

Bases: RandomVariable

Wrapper that adds n_groups support to simple RandomVariables.

Uses numpyro.plate to vectorize sampling, enabling simple RVs to work with noise models expecting the group-level interface.

Parameters:

  • name (str, required): A name for this random variable. The numpyro plate is named f"{name}_plate".
  • rv (RandomVariable, required): The underlying RandomVariable to wrap.

Initialize VectorizedRV wrapper.

Source code in pyrenew/observation/noise.py
def __init__(self, name: str, rv: RandomVariable) -> None:
    """
    Initialize VectorizedRV wrapper.

    Parameters
    ----------
    name
        A name for this random variable.
        The numpyro plate is named ``f"{name}_plate"``.
    rv
        The underlying RandomVariable to wrap.
    """
    super().__init__(name=name)
    self.rv = rv
    self.plate_name = f"{name}_plate"

sample

sample(n_groups: int, **kwargs: object) -> ArrayLike

Sample n_groups values using numpyro.plate.

Parameters:

  • n_groups (int, required): Number of group-level values to sample.

Returns:

  • ArrayLike: Array of shape (n_groups,).

Source code in pyrenew/observation/noise.py
def sample(self, n_groups: int, **kwargs: object) -> ArrayLike:
    """
    Sample n_groups values using numpyro.plate.

    Parameters
    ----------
    n_groups
        Number of group-level values to sample.

    Returns
    -------
    ArrayLike
        Array of shape (n_groups,).
    """
    with numpyro.plate(self.plate_name, n_groups):
        return self.rv(**kwargs)
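`VectorizedRV` lets a scalar prior satisfy the `sample(n_groups=...)` contract that `HierarchicalNormalNoise` expects. Inside numpyro, the plate typically broadcasts a scalar sample site of the wrapped RV to shape (n_groups,). The NumPy analog below reproduces only that shape contract, one independent draw per group, and does not model plates or sample sites. `VectorizedSampler` is a hypothetical class, not part of pyrenew, and the half-normal prior is an illustrative assumption.

```python
from typing import Callable

import numpy as np


class VectorizedSampler:
    """NumPy analog of the n_groups interface (hypothetical).

    Wraps a sampler draw(rng, size) and returns one independent draw
    per group, shape (n_groups,).
    """

    def __init__(self, name: str, draw: Callable) -> None:
        self.name = name
        self.draw = draw

    def sample(self, n_groups: int, rng: np.random.Generator) -> np.ndarray:
        return np.asarray(self.draw(rng, size=n_groups))


# Illustrative half-normal prior for per-sensor SDs.
sensor_sd = VectorizedSampler(
    "sensor_sd", lambda rng, size: np.abs(rng.normal(0.0, 0.3, size=size))
)
sds = sensor_sd.sample(n_groups=4, rng=np.random.default_rng(1))
```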

validate

validate() -> None

Validate the underlying RV.

Source code in pyrenew/observation/noise.py
def validate(self) -> None:  # pragma: no cover
    """Validate the underlying RV."""
    self.rv.validate()