
Observation

Observation processes for connecting infections to observed data.

BaseObservationProcess is the abstract base. Concrete subclasses:

  • Counts: Aggregate counts (admissions, deaths)
  • CountsBySubpop: Subpopulation-level counts
  • Measurements: Continuous subpopulation-level signals (e.g., wastewater)

All observation processes implement:

  • sample(): Sample observations given infections
  • infection_resolution(): returns "aggregate" or "subpop"
  • lookback_days(): returns required infection history length

Noise models (CountNoise, MeasurementNoise) are composable: pass one to an observation process constructor to control the distribution of its output.
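The sketch below illustrates the composition pattern in plain numpy. The class names (CountNoiseSketch, PoissonNoiseSketch, NegBinNoiseSketch, CountsSketch) are hypothetical stand-ins, not pyrenew classes, and the delay convolution is omitted. The only point it makes is that you can swap the noise object without touching the observation process.

```python
from abc import ABC, abstractmethod

import numpy as np


class CountNoiseSketch(ABC):
    """Hypothetical stand-in for a count noise model: predicted counts -> draws."""

    @abstractmethod
    def sample(self, predicted: np.ndarray, rng: np.random.Generator) -> np.ndarray:
        ...


class PoissonNoiseSketch(CountNoiseSketch):
    def sample(self, predicted, rng):
        # Poisson draws with mean equal to the predicted counts
        return rng.poisson(predicted)


class NegBinNoiseSketch(CountNoiseSketch):
    def __init__(self, concentration: float):
        self.concentration = concentration

    def sample(self, predicted, rng):
        # Gamma-Poisson mixture: mean = predicted, variance = predicted + predicted**2 / concentration
        rate = rng.gamma(shape=self.concentration, scale=predicted / self.concentration)
        return rng.poisson(rate)


class CountsSketch:
    """Hypothetical observation process that delegates its output distribution to `noise`."""

    def __init__(self, rate: float, noise: CountNoiseSketch):
        self.rate = rate
        self.noise = noise

    def sample(self, infections, rng):
        predicted = self.rate * np.asarray(infections, dtype=float)
        return self.noise.sample(predicted, rng)


rng = np.random.default_rng(0)
infections = np.array([100.0, 200.0, 300.0])
poisson_obs = CountsSketch(rate=0.01, noise=PoissonNoiseSketch()).sample(infections, rng)
negbin_obs = CountsSketch(rate=0.01, noise=NegBinNoiseSketch(concentration=10.0)).sample(infections, rng)
```

Only the noise argument changes between the two calls, which mirrors how the constructors above accept a noise model.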

BaseObservationProcess

BaseObservationProcess(name: str, temporal_pmf_rv: RandomVariable)

Bases: RandomVariable

Abstract base class for observation processes that use convolution with temporal distributions.

This class provides common functionality for connecting infections to observed data (e.g., hospital admissions, wastewater concentrations) through temporal convolution operations.

Key features provided:

  • PMF validation (sum to 1, non-negative)
  • Minimum observation day calculation
  • Convolution wrapper with timeline alignment
  • Deterministic quantity tracking
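Here is a minimal sketch of the PMF checks listed above (1D, non-negative, sums to about 1). The function name validate_pmf_sketch and the tolerance atol=1e-6 are assumptions. pyrenew's own _validate_pmf() may apply different checks or tolerances.

```python
import numpy as np


def validate_pmf_sketch(pmf, atol: float = 1e-6) -> None:
    """Raise ValueError unless `pmf` is a 1D, non-negative array summing to ~1."""
    pmf = np.asarray(pmf, dtype=float)
    if pmf.ndim != 1:
        raise ValueError(f"PMF must be 1D, got shape {pmf.shape}")
    if np.any(pmf < 0):
        raise ValueError("PMF must be non-negative")
    if not np.isclose(pmf.sum(), 1.0, atol=atol):
        raise ValueError(f"PMF must sum to 1, got {pmf.sum():.6f}")


validate_pmf_sketch([0.2, 0.5, 0.3])  # passes silently
```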

Subclasses must implement:

  • validate(): Validate parameters (call _validate_pmf() for PMFs)
  • validate_data(): Validate observation data before inference
  • lookback_days(): Return len(pmf) - 1, the number of days of infection history needed for initialization
  • infection_resolution(): Return "aggregate" or "subpop"
  • _predicted_obs(): Transform infections to predicted values
  • sample(): Apply noise model to predicted observations

Notes

Computing predicted observations on day t requires infection history from previous days (determined by the temporal PMF length). The first len(pmf) - 1 days have insufficient history and return NaN.
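Here is a minimal numpy sketch of this behavior. It assumes the prediction on day t is rate * sum_d pmf[d] * infections[t - d]. This is an illustration only, not pyrenew's compute_delay_ascertained_incidence, and the name predicted_obs_sketch is hypothetical. It also assumes the infection series is at least as long as the PMF.

```python
import numpy as np


def predicted_obs_sketch(infections, pmf, rate: float) -> np.ndarray:
    """predicted[t] = rate * sum_d pmf[d] * infections[t - d]; NaN where t < len(pmf) - 1."""
    infections = np.asarray(infections, dtype=float)
    pmf = np.asarray(pmf, dtype=float)
    n, lag = len(infections), len(pmf) - 1
    predicted = np.full(n, np.nan)
    # "valid" mode keeps only outputs with a full infection history (length n - lag)
    predicted[lag:] = rate * np.convolve(infections, pmf, mode="valid")
    return predicted


infections = np.arange(1.0, 7.0)  # [1, 2, 3, 4, 5, 6]
predicted = predicted_obs_sketch(infections, pmf=[0.5, 0.3, 0.2], rate=0.1)
# predicted[:2] is NaN; predicted[2] = 0.1 * (0.5*3 + 0.3*2 + 0.2*1) = 0.23
```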

See Also

  • pyrenew.convolve.compute_delay_ascertained_incidence: Underlying convolution function
  • pyrenew.metaclass.RandomVariable: Base class for all random variables

Initialize base observation process.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names, enabling multiple observations of the same type in a single model.
  • temporal_pmf_rv (RandomVariable, required): The temporal distribution PMF (e.g., delay or shedding distribution). Must sample to a 1D array that sums to ~1.0 with non-negative values. Subclasses may have additional parameters.

Notes

Subclasses should call super().__init__(name, temporal_pmf_rv) in their constructors and may add additional parameters.

Source code in pyrenew/observation/base.py
def __init__(self, name: str, temporal_pmf_rv: RandomVariable) -> None:
    """
    Initialize base observation process.

    Parameters
    ----------
    name : str
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names, enabling multiple
        observations of the same type in a single model.
    temporal_pmf_rv : RandomVariable
        The temporal distribution PMF (e.g., delay or shedding distribution).
        Must sample to a 1D array that sums to ~1.0 with non-negative values.
        Subclasses may have additional parameters.

    Notes
    -----
    Subclasses should call ``super().__init__(name, temporal_pmf_rv)``
    in their constructors and may add additional parameters.
    """
    super().__init__(name=name)
    self.temporal_pmf_rv = temporal_pmf_rv

infection_resolution abstractmethod

infection_resolution() -> str

Return whether this observation uses aggregate or subpop infections.

Returns one of:

  • "aggregate": Uses a single aggregated infection trajectory. Shape: (n_days,)
  • "subpop": Uses subpopulation-level infection trajectories. Shape: (n_days, n_subpops), indexed via subpop_indices.

Returns:

  • str: Either "aggregate" or "subpop".

Notes

This is used by multi-signal models to route the correct infection output to each observation process.
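A minimal sketch of that routing step. The function name route_infections_sketch is hypothetical. Building the aggregate trajectory as a plain sum over subpopulations is only a convenience for the demo and is not how pyrenew necessarily aggregates.

```python
import numpy as np


def route_infections_sketch(resolution: str, aggregate, subpop):
    """Pick the infection array an observation process should receive."""
    if resolution == "aggregate":
        return aggregate  # shape (n_days,)
    if resolution == "subpop":
        return subpop  # shape (n_days, n_subpops)
    raise ValueError(f"Unknown infection resolution: {resolution!r}")


subpop = np.ones((10, 3))
aggregate = subpop.sum(axis=1)  # demo-only aggregation
```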

Source code in pyrenew/observation/base.py
@abstractmethod
def infection_resolution(self) -> str:
    """
    Return whether this observation uses aggregate or subpop infections.

    Returns one of:

    - ``"aggregate"``: Uses a single aggregated infection trajectory.
      Shape: ``(n_days,)``
    - ``"subpop"``: Uses subpopulation-level infection trajectories.
      Shape: ``(n_days, n_subpops)``, indexed via ``subpop_indices``.

    Returns
    -------
    str
        Either ``"aggregate"`` or ``"subpop"``

    Notes
    -----
    This is used by multi-signal models to route the correct infection
    output to each observation process.
    """
    pass  # pragma: no cover

lookback_days abstractmethod

lookback_days() -> int

Return the number of days this observation process needs to look back.

This determines the minimum n_initialization_points required by the latent process when this observation is included in a multi-signal model.

Returns:

  • int: Number of days of infection history required.

Notes

Delay/shedding PMFs are 0-indexed (effect can occur on day 0), so a PMF of length L covers lags 0 to L-1, requiring L-1 initialization points. Implementations should return len(pmf) - 1.

This is used by model builders to automatically compute n_initialization_points as max(gen_int_length, max(all lookbacks)).
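The arithmetic in these notes can be sketched as follows. Both function names are hypothetical helpers, not pyrenew API, and the sketch assumes at least one PMF is supplied.

```python
def lookback_days_sketch(pmf) -> int:
    # A PMF of length L covers lags 0..L-1, so L-1 days of prior history are needed.
    return len(pmf) - 1


def n_initialization_points_sketch(gen_int_length: int, pmfs) -> int:
    # max(gen_int_length, max(all lookbacks)), as described in the notes
    return max(gen_int_length, max(lookback_days_sketch(p) for p in pmfs))


# A 20-day delay PMF dominates a 7-day generation interval: 20 - 1 = 19 points.
n_init = n_initialization_points_sketch(7, [[0.1] * 10, [0.05] * 20])
```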

Source code in pyrenew/observation/base.py
@abstractmethod
def lookback_days(self) -> int:
    """
    Return the number of days this observation process needs to look back.

    This determines the minimum n_initialization_points required by the
    latent process when this observation is included in a multi-signal model.

    Returns
    -------
    int
        Number of days of infection history required.

    Notes
    -----
    Delay/shedding PMFs are 0-indexed (effect can occur on day 0), so a
    PMF of length L covers lags 0 to L-1, requiring L-1 initialization
    points. Implementations should return ``len(pmf) - 1``.

    This is used by model builders to automatically compute
    n_initialization_points as ``max(gen_int_length, max(all lookbacks))``.
    """
    pass  # pragma: no cover

sample abstractmethod

sample(obs: ArrayLike | None = None, **kwargs) -> ArrayLike

Sample from the observation process.

Subclasses must implement this method to define the specific observation model. Typically calls _predicted_obs first, then applies the noise model.

Parameters:

  • obs (ArrayLike | None, default None): Observed data for conditioning, or None for prior predictive sampling.
  • **kwargs: Subclass-specific parameters (e.g., infections from the infection process).

Returns:

  • ArrayLike: Observed or sampled values from the observation process.

Source code in pyrenew/observation/base.py
@abstractmethod
def sample(
    self,
    obs: ArrayLike | None = None,
    **kwargs,
) -> ArrayLike:
    """
    Sample from the observation process.

    Subclasses must implement this method to define the specific
    observation model. Typically calls ``_predicted_obs`` first,
    then applies the noise model.

    Parameters
    ----------
    obs : ArrayLike | None
        Observed data for conditioning, or None for prior predictive sampling.
    **kwargs
        Subclass-specific parameters (e.g., infections from the infection process).

    Returns
    -------
    ArrayLike
        Observed or sampled values from the observation process.
    """
    pass  # pragma: no cover

validate abstractmethod

validate() -> None

Validate observation process parameters.

Subclasses must implement this method to validate all parameters. Typically this involves calling _validate_pmf() for the PMF and adding any additional parameter-specific validation.

Raises:

  • ValueError: If any parameters fail validation.

Source code in pyrenew/observation/base.py
@abstractmethod
def validate(self) -> None:
    """
    Validate observation process parameters.

    Subclasses must implement this method to validate all parameters.
    Typically this involves calling ``_validate_pmf()`` for the PMF
    and adding any additional parameter-specific validation.

    Raises
    ------
    ValueError
        If any parameters fail validation.
    """
    pass  # pragma: no cover

validate_data abstractmethod

validate_data(n_total: int, n_subpops: int, **obs_data) -> None

Validate observation data before running inference.

Each observation process validates its own data requirements. Called by the model's validate_data() method with concrete (non-traced) values before JAX tracing begins.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations.
  • **obs_data: Observation-specific data kwargs (the same as those passed to sample(), minus infections, which come from the latent process).

Raises:

  • ValueError: If any data fails validation.

Source code in pyrenew/observation/base.py
@abstractmethod
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    **obs_data,
) -> None:
    """
    Validate observation data before running inference.

    Each observation process validates its own data requirements.
    Called by the model's ``validate_data()`` method with concrete
    (non-traced) values before JAX tracing begins.

    Parameters
    ----------
    n_total : int
        Total number of time steps (n_init + n_days_post_init).
    n_subpops : int
        Number of subpopulations.
    **obs_data
        Observation-specific data kwargs (same as passed to ``sample()``,
        minus ``infections`` which comes from the latent process).

    Raises
    ------
    ValueError
        If any data fails validation.
    """
    pass  # pragma: no cover

CountNoise

Bases: ABC

Abstract base for count observation noise models.

Defines how discrete count observations are distributed around predicted values.

sample abstractmethod

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample count observations given predicted counts.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted count values (non-negative).
  • obs (ArrayLike | None, default None): Observed counts for conditioning, or None for prior sampling.
  • mask (ArrayLike | None, default None): Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. If provided, observations where mask is False are excluded from the likelihood.

Returns:

  • ArrayLike: Sampled or conditioned counts, same shape as predicted.

Notes

Implementations use numpyro.handlers.mask rather than the obs_mask parameter of numpyro.sample. This avoids creating latent variables for masked entries, which would fail with NUTS for discrete distributions.

Source code in pyrenew/observation/noise.py
@abstractmethod
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample count observations given predicted counts.

    Parameters
    ----------
    name : str
        Numpyro sample site name.
    predicted : ArrayLike
        Predicted count values (non-negative).
    obs : ArrayLike | None
        Observed counts for conditioning, or None for prior sampling.
    mask : ArrayLike | None
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included. If provided,
        observations where mask is False are excluded from the likelihood.

    Returns
    -------
    ArrayLike
        Sampled or conditioned counts, same shape as predicted.

    Notes
    -----
    Implementations use ``numpyro.handlers.mask`` rather than the
    ``obs_mask`` parameter of ``numpyro.sample``. This avoids creating
    latent variables for masked entries, which would fail with NUTS
    for discrete distributions.
    """
    pass  # pragma: no cover

validate abstractmethod

validate() -> None

Validate noise model parameters.

Raises:

  • ValueError: If parameters are invalid.

Source code in pyrenew/observation/noise.py
@abstractmethod
def validate(self) -> None:
    """
    Validate noise model parameters.

    Raises
    ------
    ValueError
        If parameters are invalid.
    """
    pass  # pragma: no cover

Counts

Counts(
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
)

Bases: _CountBase

Aggregated count observation.

Maps aggregate infections to counts through an ascertainment × delay convolution, then applies a composable noise model.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names (e.g., "hospital" produces sites "hospital_obs", "hospital_predicted").
  • ascertainment_rate_rv (RandomVariable, required): Ascertainment rate in [0, 1] (e.g., IHR, IER).
  • delay_distribution_rv (RandomVariable, required): Delay distribution PMF (must sum to ~1.0).
  • noise (CountNoise, required): Noise model (PoissonNoise, NegativeBinomialNoise, etc.).

Source code in pyrenew/observation/count_observations.py
def __init__(
    self,
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
) -> None:
    """
    Initialize count observation base.

    Parameters
    ----------
    name : str
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    ascertainment_rate_rv : RandomVariable
        Ascertainment rate in [0, 1] (e.g., IHR, IER).
    delay_distribution_rv : RandomVariable
        Delay distribution PMF (must sum to ~1.0).
    noise : CountNoise
        Noise model for count observations (Poisson, NegBin, etc.).
    """
    super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv)
    self.ascertainment_rate_rv = ascertainment_rate_rv
    self.noise = noise

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/count_observations.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"Counts(name={self.name!r}, "
        f"ascertainment_rate_rv={self.ascertainment_rate_rv!r}, "
        f"delay_distribution_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r})"
    )

infection_resolution

infection_resolution() -> str

Return "aggregate" for aggregated observations.

Returns:

  • str: The string "aggregate".

Source code in pyrenew/observation/count_observations.py
def infection_resolution(self) -> str:
    """
    Return "aggregate" for aggregated observations.

    Returns
    -------
    str
        The string "aggregate".
    """
    return "aggregate"

sample

sample(
    infections: ArrayLike, obs: ArrayLike | None = None
) -> ObservationSample

Sample aggregated counts.

Both infections and obs use a shared time axis [0, n_total) where n_total = n_init + n_days. NaN in obs marks unobserved timepoints (initialization period or missing data).
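The sketch below shows how an obs array on the shared time axis can be assembled and how the NaN mask keeps unobserved days out of the likelihood. It mirrors the mask and placeholder logic in the source of this method. The Poisson likelihood is written out by hand in numpy and stands in for whatever noise model is configured. It assumes strictly positive predictions, and the name masked_poisson_loglik_sketch is hypothetical.

```python
import math

import numpy as np


def masked_poisson_loglik_sketch(predicted, obs) -> float:
    """Poisson log-likelihood summed only where both predicted and obs are non-NaN."""
    predicted = np.asarray(predicted, dtype=float)
    obs = np.asarray(obs, dtype=float)
    mask = ~np.isnan(predicted) & ~np.isnan(obs)
    # Placeholder values keep every term finite; masked-out terms are zeroed below.
    safe_pred = np.where(np.isnan(predicted), 1.0, predicted)
    safe_obs = np.where(np.isnan(obs), safe_pred, obs)
    logp = np.array(
        [k * math.log(mu) - mu - math.lgamma(k + 1) for k, mu in zip(safe_obs, safe_pred)]
    )
    return float(np.sum(np.where(mask, logp, 0.0)))


n_init = 2
obs_post_init = np.array([3.0, np.nan, 5.0])  # NaN marks a missing report
obs = np.concatenate([np.full(n_init, np.nan), obs_post_init])  # length n_total = 5
predicted = np.array([np.nan, np.nan, 4.0, 4.0, 4.0])
loglik = masked_poisson_loglik_sketch(predicted, obs)  # only days 2 and 4 contribute
```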

Parameters:

  • infections (ArrayLike, required): Aggregate infections from the infection process. Shape: (n_total,) where n_total = n_init + n_days.
  • obs (ArrayLike | None, default None): Observed counts on the shared time axis. Shape: (n_total,). Use NaN for the initialization period and any missing observations; pass None for prior predictive sampling.

Returns:

  • ObservationSample: Named tuple with observed (sampled/conditioned counts) and predicted (predicted counts before noise, shape: n_total).

Source code in pyrenew/observation/count_observations.py
def sample(
    self,
    infections: ArrayLike,
    obs: ArrayLike | None = None,
) -> ObservationSample:
    """
    Sample aggregated counts.

    Both infections and obs use a shared time axis [0, n_total) where
    n_total = n_init + n_days. NaN in obs marks unobserved timepoints
    (initialization period or missing data).

    Parameters
    ----------
    infections : ArrayLike
        Aggregate infections from the infection process.
        Shape: (n_total,) where n_total = n_init + n_days.
    obs : ArrayLike | None
        Observed counts on shared time axis. Shape: (n_total,).
        Use NaN for initialization period and any missing observations.
        None for prior predictive sampling.

    Returns
    -------
    ObservationSample
        Named tuple with `observed` (sampled/conditioned counts) and
        `predicted` (predicted counts before noise, shape: n_total).
    """
    predicted_counts = self._predicted_obs(infections)
    self._deterministic("predicted", predicted_counts)

    # Compute mask: True where observation contributes to likelihood.
    # NaN in predictions (initialization period) or obs (missing data)
    # are excluded via mask.
    valid_pred = ~jnp.isnan(predicted_counts)
    if obs is not None:
        valid_obs = ~jnp.isnan(obs)
        mask = valid_pred & valid_obs
    else:
        mask = valid_pred

    # JAX evaluates log_prob for all array elements even when mask
    # excludes them from the likelihood sum. Replace NaN with safe values
    # to avoid NaN propagation in JAX's computation graph. These values
    # do not affect inference since mask=False excludes them.
    safe_predicted = jnp.where(jnp.isnan(predicted_counts), 1.0, predicted_counts)
    safe_obs = None
    if obs is not None:
        safe_obs = jnp.where(jnp.isnan(obs), safe_predicted, obs)

    observed = self.noise.sample(
        name=self._sample_site_name("obs"),
        predicted=safe_predicted,
        obs=safe_obs,
        mask=mask,
    )

    return ObservationSample(observed=observed, predicted=predicted_counts)

validate_data

validate_data(
    n_total: int, n_subpops: int, obs: ArrayLike | None = None, **kwargs
) -> None

Validate aggregated count observation data.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations (unused for aggregate observations).
  • obs (ArrayLike | None, default None): Observed counts on the shared time axis. Shape: (n_total,).
  • **kwargs: Additional keyword arguments (ignored).

Raises:

  • ValueError: If obs length doesn't match n_total.

Source code in pyrenew/observation/count_observations.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    obs: ArrayLike | None = None,
    **kwargs,
) -> None:
    """
    Validate aggregated count observation data.

    Parameters
    ----------
    n_total : int
        Total number of time steps (n_init + n_days_post_init).
    n_subpops : int
        Number of subpopulations (unused for aggregate observations).
    obs : ArrayLike | None
        Observed counts on shared time axis. Shape: (n_total,).
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If obs length doesn't match n_total.
    """
    if obs is not None:
        self._validate_obs_dense(obs, n_total)

CountsBySubpop

CountsBySubpop(
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
)

Bases: _CountBase

Subpopulation-level count observation.

Maps subpopulation-level infections to counts through an ascertainment × delay convolution, then applies a composable noise model.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • ascertainment_rate_rv (RandomVariable, required): Ascertainment rate in [0, 1].
  • delay_distribution_rv (RandomVariable, required): Delay distribution PMF (must sum to ~1.0).
  • noise (CountNoise, required): Noise model (PoissonNoise, NegativeBinomialNoise, etc.).

Source code in pyrenew/observation/count_observations.py
def __init__(
    self,
    name: str,
    ascertainment_rate_rv: RandomVariable,
    delay_distribution_rv: RandomVariable,
    noise: CountNoise,
) -> None:
    """
    Initialize count observation base.

    Parameters
    ----------
    name : str
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    ascertainment_rate_rv : RandomVariable
        Ascertainment rate in [0, 1] (e.g., IHR, IER).
    delay_distribution_rv : RandomVariable
        Delay distribution PMF (must sum to ~1.0).
    noise : CountNoise
        Noise model for count observations (Poisson, NegBin, etc.).
    """
    super().__init__(name=name, temporal_pmf_rv=delay_distribution_rv)
    self.ascertainment_rate_rv = ascertainment_rate_rv
    self.noise = noise

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/count_observations.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"CountsBySubpop(name={self.name!r}, "
        f"ascertainment_rate_rv={self.ascertainment_rate_rv!r}, "
        f"delay_distribution_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r})"
    )

infection_resolution

infection_resolution() -> str

Return "subpop" for subpopulation-level observations.

Returns:

  • str: The string "subpop".

Source code in pyrenew/observation/count_observations.py
def infection_resolution(self) -> str:
    """
    Return "subpop" for subpopulation-level observations.

    Returns
    -------
    str
        The string "subpop".
    """
    return "subpop"

sample

sample(
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    obs: ArrayLike | None = None,
) -> ObservationSample

Sample subpopulation-level counts.

Times are on the shared time axis [0, n_total) where n_total = n_init + n_days. This method performs direct indexing without any offset adjustment.
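Direct indexing here means paired integer indexing. The i-th observation reads the entry at row times[i] and column subpop_indices[i]. A small numpy illustration with made-up values (jax.numpy arrays follow the same indexing rule):

```python
import numpy as np

# Predicted counts on the shared time axis: shape (n_total, n_subpops) = (4, 3)
predicted_counts = np.arange(12.0).reshape(4, 3)
times = np.array([0, 2, 3])
subpop_indices = np.array([1, 0, 2])
# One (time, subpop) entry per observation; result has shape (n_obs,)
predicted_obs = predicted_counts[times, subpop_indices]
```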

Parameters:

  • infections (ArrayLike, required): Subpopulation-level infections from the infection process. Shape: (n_total, n_subpops)
  • times (ArrayLike, required): Day index for each observation on the shared time axis. Must be in range [0, n_total). Shape: (n_obs,)
  • subpop_indices (ArrayLike, required): Subpopulation index for each observation (0-indexed). Shape: (n_obs,)
  • obs (ArrayLike | None, default None): Observed counts (n_obs,), or None for prior sampling.

Returns:

  • ObservationSample: Named tuple with observed (sampled/conditioned counts) and predicted (predicted counts before noise, shape: n_total x n_subpops).

Source code in pyrenew/observation/count_observations.py
def sample(
    self,
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    obs: ArrayLike | None = None,
) -> ObservationSample:
    """
    Sample subpopulation-level counts.

    Times are on the shared time axis [0, n_total) where
    n_total = n_init + n_days. This method performs direct indexing
    without any offset adjustment.

    Parameters
    ----------
    infections : ArrayLike
        Subpopulation-level infections from the infection process.
        Shape: (n_total, n_subpops)
    times : ArrayLike
        Day index for each observation on the shared time axis.
        Must be in range [0, n_total). Shape: (n_obs,)
    subpop_indices : ArrayLike
        Subpopulation index for each observation (0-indexed).
        Shape: (n_obs,)
    obs : ArrayLike | None
        Observed counts (n_obs,), or None for prior sampling.

    Returns
    -------
    ObservationSample
        Named tuple with `observed` (sampled/conditioned counts) and
        `predicted` (predicted counts before noise, shape: n_total x n_subpops).
    """
    predicted_counts = self._predicted_obs(infections)
    self._deterministic("predicted", predicted_counts)

    # Direct indexing on shared time axis - no offset needed
    predicted_obs = predicted_counts[times, subpop_indices]

    observed = self.noise.sample(
        name=self._sample_site_name("obs"),
        predicted=predicted_obs,
        obs=obs,
    )

    return ObservationSample(observed=observed, predicted=predicted_counts)

validate_data

validate_data(
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    obs: ArrayLike | None = None,
    **kwargs,
) -> None

Validate subpopulation-level count observation data.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations.
  • times (ArrayLike | None, default None): Day index for each observation on the shared time axis.
  • subpop_indices (ArrayLike | None, default None): Subpopulation index for each observation (0-indexed).
  • obs (ArrayLike | None, default None): Observed counts (n_obs,).
  • **kwargs: Additional keyword arguments (ignored).

Raises:

  • ValueError: If times or subpop_indices are out of bounds, or if obs and times have mismatched lengths.

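The sketch below shows the kind of bounds and shape checks described under Raises. The helper name validate_subpop_data_sketch is hypothetical. The real private helpers (_validate_times, _validate_subpop_indices, _validate_obs_times_shape) are not shown on this page, so their exact checks and messages may differ.

```python
import numpy as np


def validate_subpop_data_sketch(n_total, n_subpops, times, subpop_indices, obs=None):
    times = np.asarray(times)
    subpop_indices = np.asarray(subpop_indices)
    # Times must index the shared axis [0, n_total)
    if np.any(times < 0) or np.any(times >= n_total):
        raise ValueError(f"times must lie in [0, {n_total})")
    # Subpopulation indices are 0-indexed and must be < n_subpops
    if np.any(subpop_indices < 0) or np.any(subpop_indices >= n_subpops):
        raise ValueError(f"subpop_indices must lie in [0, {n_subpops})")
    # One observation per (time, subpop) pair
    if obs is not None and np.shape(obs) != times.shape:
        raise ValueError("obs and times must have the same shape")


validate_subpop_data_sketch(10, 3, times=[0, 9], subpop_indices=[0, 2], obs=[1.0, 2.0])
```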
Source code in pyrenew/observation/count_observations.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    obs: ArrayLike | None = None,
    **kwargs,
) -> None:
    """
    Validate subpopulation-level count observation data.

    Parameters
    ----------
    n_total : int
        Total number of time steps (n_init + n_days_post_init).
    n_subpops : int
        Number of subpopulations.
    times : ArrayLike | None
        Day index for each observation on the shared time axis.
    subpop_indices : ArrayLike | None
        Subpopulation index for each observation (0-indexed).
    obs : ArrayLike | None
        Observed counts (n_obs,).
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If times or subpop_indices are out of bounds, or if
        obs and times have mismatched lengths.
    """
    if times is not None:
        self._validate_times(times, n_total)
        if obs is not None:
            self._validate_obs_times_shape(obs, times)
    if subpop_indices is not None:
        self._validate_subpop_indices(subpop_indices, n_subpops)

HierarchicalNormalNoise

HierarchicalNormalNoise(
    sensor_mode_rv: RandomVariable, sensor_sd_rv: RandomVariable
)

Bases: MeasurementNoise

Normal noise with hierarchical sensor-level effects.

Observation model: obs ~ Normal(predicted + sensor_mode, sensor_sd) where sensor_mode and sensor_sd are sampled per-sensor.
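A numpy sketch of this observation model follows. The per-sensor priors used here, Normal(0, 0.5) for the modes and a log-normal for the SDs, are illustrative assumptions. In pyrenew they come from sensor_mode_rv and sensor_sd_rv. The sketch shows how each observation picks up its sensor's mode and SD through sensor_indices.

```python
import numpy as np

rng = np.random.default_rng(42)
n_sensors = 3
sensor_indices = np.array([0, 0, 1, 2, 2])
predicted = np.array([1.0, 1.2, 0.8, 2.0, 2.1])

# One mode (offset) and one SD per sensor, shared by all of that sensor's observations
sensor_mode = rng.normal(loc=0.0, scale=0.5, size=n_sensors)
sensor_sd = rng.lognormal(mean=np.log(0.2), sigma=0.3, size=n_sensors)

loc = predicted + sensor_mode[sensor_indices]
scale = sensor_sd[sensor_indices]
obs = rng.normal(loc=loc, scale=scale)  # shape (n_obs,)
```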

Parameters:

  • sensor_mode_rv (RandomVariable, required): Prior for sensor-level modes. Must implement sample(n_groups=...) -> ArrayLike.
  • sensor_sd_rv (RandomVariable, required): Prior for sensor-level SDs (must be > 0). Must implement sample(n_groups=...) -> ArrayLike.

Notes

Use VectorizedRV to wrap simple RVs that lack this interface.

Initialize hierarchical Normal noise.

Source code in pyrenew/observation/noise.py
def __init__(
    self,
    sensor_mode_rv: RandomVariable,
    sensor_sd_rv: RandomVariable,
) -> None:
    """
    Initialize hierarchical Normal noise.

    Parameters
    ----------
    sensor_mode_rv : RandomVariable
        Prior for sensor-level modes.
        Must implement ``sample(n_groups=...) -> ArrayLike``.
    sensor_sd_rv : RandomVariable
        Prior for sensor-level SDs (must be > 0).
        Must implement ``sample(n_groups=...) -> ArrayLike``.
    """
    self.sensor_mode_rv = sensor_mode_rv
    self.sensor_sd_rv = sensor_sd_rv

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"HierarchicalNormalNoise("
        f"sensor_mode_rv={self.sensor_mode_rv!r}, "
        f"sensor_sd_rv={self.sensor_sd_rv!r})"
    )

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    *,
    sensor_indices: ArrayLike,
    n_sensors: int,
) -> ArrayLike

Sample from Normal distribution with sensor-level hierarchical effects.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted measurement values. Shape: (n_obs,)
  • obs (ArrayLike | None, default None): Observed measurements for conditioning. Shape: (n_obs,)
  • sensor_indices (ArrayLike, required, keyword-only): Sensor index for each observation (0-indexed). Shape: (n_obs,)
  • n_sensors (int, required, keyword-only): Total number of sensors.

Returns:

  • ArrayLike: Normally distributed measurements with hierarchical sensor effects. Shape: (n_obs,)

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    *,
    sensor_indices: ArrayLike,
    n_sensors: int,
) -> ArrayLike:
    """
    Sample from Normal distribution with sensor-level hierarchical effects.

    Parameters
    ----------
    name : str
        Numpyro sample site name.
    predicted : ArrayLike
        Predicted measurement values.
        Shape: (n_obs,)
    obs : ArrayLike | None
        Observed measurements for conditioning.
        Shape: (n_obs,)
    sensor_indices : ArrayLike
        Sensor index for each observation (0-indexed).
        Shape: (n_obs,)
    n_sensors : int
        Total number of sensors.

    Returns
    -------
    ArrayLike
        Normally distributed measurements with hierarchical sensor effects.
        Shape: (n_obs,)
    """
    sensor_mode = self.sensor_mode_rv(n_groups=n_sensors)
    sensor_sd = self.sensor_sd_rv(n_groups=n_sensors)

    loc = predicted + sensor_mode[sensor_indices]
    scale = sensor_sd[sensor_indices]

    return numpyro.sample(name, dist.Normal(loc=loc, scale=scale), obs=obs)
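
The gather step above gives every observation the offset and spread of the sensor that produced it. The following is a minimal pure-Python sketch of that step, assuming the per-sensor values have already been sampled. The helper `sensor_loc_scale` and all numbers are hypothetical, and no numpyro sampling is performed.

```python
def sensor_loc_scale(predicted, sensor_mode, sensor_sd, sensor_indices):
    """Hypothetical helper mirroring how sample() builds loc and scale.

    loc[i]   = predicted[i] + sensor_mode[sensor_indices[i]]
    scale[i] = sensor_sd[sensor_indices[i]]
    """
    if len(predicted) != len(sensor_indices):
        raise ValueError("predicted and sensor_indices must have the same length")
    loc = [p + sensor_mode[s] for p, s in zip(predicted, sensor_indices)]
    scale = [sensor_sd[s] for s in sensor_indices]
    return loc, scale


# Three observations from two sensors (made-up values, n_sensors = 2).
predicted = [1.0, 2.0, 3.0]
sensor_mode = [0.5, -0.25]  # one additive offset per sensor
sensor_sd = [0.1, 0.2]      # one standard deviation per sensor
sensor_indices = [0, 1, 0]

loc, scale = sensor_loc_scale(predicted, sensor_mode, sensor_sd, sensor_indices)
print(loc)    # [1.5, 1.75, 3.5]
print(scale)  # [0.1, 0.2, 0.1]
```

In the method itself, `sensor_mode` and `sensor_sd` come from `sensor_mode_rv(n_groups=n_sensors)` and `sensor_sd_rv(n_groups=n_sensors)`, and the resulting `loc` and `scale` parameterize `dist.Normal`.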

validate

validate() -> None

Validate noise parameters.

Notes

Full validation requires n_groups, which is only available during sample(), so this method currently performs no checks.

Source code in pyrenew/observation/noise.py
def validate(self) -> None:
    """
    Validate noise parameters.

    Notes
    -----
    Full validation requires n_groups, which is only available during sample().
    """
    pass

MeasurementNoise

Bases: ABC

Abstract base for continuous measurement noise models.

Defines how continuous observations are distributed around predicted values.

sample abstractmethod

sample(
    name: str, predicted: ArrayLike, obs: ArrayLike | None = None, **kwargs
) -> ArrayLike

Sample continuous observations given predicted values.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted measurement values.
  • obs (ArrayLike | None, default None): Observed measurements for conditioning, or None for prior sampling.
  • **kwargs: Additional context (e.g., sensor indices).

Returns:

  • ArrayLike: Sampled or conditioned measurements, with the same shape as predicted.

Source code in pyrenew/observation/noise.py
@abstractmethod
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    **kwargs,
) -> ArrayLike:
    """
    Sample continuous observations given predicted values.

    Parameters
    ----------
    name : str
        Numpyro sample site name.
    predicted : ArrayLike
        Predicted measurement values.
    obs : ArrayLike | None
        Observed measurements for conditioning, or None for prior sampling.
    **kwargs
        Additional context (e.g., sensor indices).

    Returns
    -------
    ArrayLike
        Sampled or conditioned measurements, same shape as predicted.
    """
    pass  # pragma: no cover

validate abstractmethod

validate() -> None

Validate noise model parameters.

Raises:

  • ValueError: If parameters are invalid.

Source code in pyrenew/observation/noise.py
@abstractmethod
def validate(self) -> None:
    """
    Validate noise model parameters.

    Raises
    ------
    ValueError
        If parameters are invalid.
    """
    pass  # pragma: no cover
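
To make the abstract contract concrete, here is a minimal sketch of a subclass that satisfies it. It is a standalone illustration, not pyrenew code: `ToyMeasurementNoise` is a hypothetical stand-in for `MeasurementNoise`, and `FixedScaleNormalNoise` uses Python's `random` module instead of numpyro, so it only mimics the "return obs when conditioning, draw otherwise" behavior of `numpyro.sample`.

```python
import random
from abc import ABC, abstractmethod


class ToyMeasurementNoise(ABC):
    """Hypothetical stand-in mirroring the MeasurementNoise interface."""

    @abstractmethod
    def sample(self, name, predicted, obs=None, **kwargs):
        """Return sampled or conditioned values, same shape as predicted."""

    @abstractmethod
    def validate(self):
        """Raise ValueError if parameters are invalid."""


class FixedScaleNormalNoise(ToyMeasurementNoise):
    """Normal noise with one shared, fixed standard deviation (illustrative only)."""

    def __init__(self, sd):
        self.sd = sd

    def validate(self):
        if self.sd <= 0:
            raise ValueError(f"sd must be positive, got {self.sd}")

    def sample(self, name, predicted, obs=None, **kwargs):
        # `name` is unused here; a numpyro implementation passes it to numpyro.sample.
        # Extra context such as sensor_indices arrives via **kwargs and is ignored.
        if obs is not None:
            # Conditioning: return the observed values unchanged.
            return list(obs)
        # Prior sampling: one Normal draw around each predicted value.
        return [random.gauss(mu, self.sd) for mu in predicted]


noise = FixedScaleNormalNoise(sd=0.5)
noise.validate()
draws = noise.sample("obs", [1.0, 2.0, 3.0])
conditioned = noise.sample("obs", [1.0, 2.0], obs=[1.1, 1.9], sensor_indices=[0, 0])
```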

Measurements

Measurements(
    name: str, temporal_pmf_rv: RandomVariable, noise: MeasurementNoise
)

Bases: BaseObservationProcess

Abstract base for continuous measurement observations.

Subclasses implement signal-specific transformations from infections to predicted measurement values, then add measurement noise.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • temporal_pmf_rv (RandomVariable, required): Temporal distribution PMF (e.g., shedding kinetics for wastewater).
  • noise (MeasurementNoise, required): Noise model for continuous measurements (e.g., HierarchicalNormalNoise).

Notes

Subclasses must implement _predicted_obs() according to their specific signal processing (e.g., wastewater shedding kinetics or dilution factors).

See Also

pyrenew.observation.noise.HierarchicalNormalNoise : Suitable noise model for sensor-level measurements
pyrenew.observation.base.BaseObservationProcess : Parent class with common observation utilities

Initialize measurement observation base.

Parameters:

  • name (str, required): Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names.
  • temporal_pmf_rv (RandomVariable, required): Temporal distribution PMF (e.g., shedding kinetics).
  • noise (MeasurementNoise, required): Noise model (e.g., HierarchicalNormalNoise with sensor effects).

Source code in pyrenew/observation/measurements.py
def __init__(
    self,
    name: str,
    temporal_pmf_rv: RandomVariable,
    noise: MeasurementNoise,
) -> None:
    """
    Initialize measurement observation base.

    Parameters
    ----------
    name : str
        Unique name for this observation process. Used to prefix all
        numpyro sample and deterministic site names.
    temporal_pmf_rv : RandomVariable
        Temporal distribution PMF (e.g., shedding kinetics).
    noise : MeasurementNoise
        Noise model (e.g., HierarchicalNormalNoise with sensor effects).
    """
    super().__init__(name=name, temporal_pmf_rv=temporal_pmf_rv)
    self.noise = noise

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/measurements.py
def __repr__(self) -> str:
    """Return string representation."""
    return (
        f"{self.__class__.__name__}(name={self.name!r}, "
        f"temporal_pmf_rv={self.temporal_pmf_rv!r}, "
        f"noise={self.noise!r})"
    )

infection_resolution

infection_resolution() -> str

Return "subpop" for measurement observations.

Measurement observations require subpopulation-level infections because each measurement corresponds to a specific catchment area.

Returns:

  • str: "subpop"

Source code in pyrenew/observation/measurements.py
def infection_resolution(self) -> str:
    """
    Return "subpop" for measurement observations.

    Measurement observations require subpopulation-level infections
    because each measurement corresponds to a specific catchment area.

    Returns
    -------
    str
        ``"subpop"``
    """
    return "subpop"

lookback_days

lookback_days() -> int

Return required lookback days for this observation.

Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF of length L covers lags 0 to L-1, requiring L-1 initialization points.

Returns:

  • int: Length of temporal PMF minus 1.

Source code in pyrenew/observation/measurements.py
def lookback_days(self) -> int:
    """
    Return required lookback days for this observation.

    Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF
    of length L covers lags 0 to L-1, requiring L-1 initialization points.

    Returns
    -------
    int
        Length of temporal PMF minus 1.
    """
    return len(self.temporal_pmf_rv()) - 1
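
The lookback arithmetic can be checked with a small pure-Python delay convolution. This is a sketch, not pyrenew's `compute_delay_ascertained_incidence`: the helper `delay_convolve` is hypothetical and only illustrates why a PMF of length L leaves the first L - 1 days without a full infection history (returned here as NaN).

```python
import math


def lookback_days(pmf):
    """A PMF of length L covers lags 0..L-1, so L - 1 earlier days are needed."""
    return len(pmf) - 1


def delay_convolve(infections, pmf):
    """Hypothetical sketch: out[t] = sum over d of pmf[d] * infections[t - d].

    Days with an incomplete history (t < len(pmf) - 1) are set to NaN.
    """
    out = []
    for t in range(len(infections)):
        if t < lookback_days(pmf):
            out.append(math.nan)
        else:
            out.append(sum(p * infections[t - d] for d, p in enumerate(pmf)))
    return out


pmf = [0.5, 0.25, 0.25]  # lags 0, 1, 2 -> lookback of 2 days
infections = [4.0, 8.0, 12.0, 16.0]
print(lookback_days(pmf))                # 2
print(delay_convolve(infections, pmf))   # [nan, nan, 9.0, 13.0]
```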

sample

sample(
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    sensor_indices: ArrayLike,
    n_sensors: int,
    obs: ArrayLike | None = None,
) -> ObservationSample

Sample measurements from observed sensors.

Times are on the shared time axis [0, n_total) where n_total = n_init + n_days. This method performs direct indexing without any offset adjustment.

Transforms infections to predicted values via signal-specific processing (_predicted_obs), then applies noise model.

Parameters:

  • infections (ArrayLike, required): Infections from the infection process. Shape: (n_total, n_subpops)
  • times (ArrayLike, required): Day index for each observation on the shared time axis. Must be in range [0, n_total). Shape: (n_obs,)
  • subpop_indices (ArrayLike, required): Subpopulation index for each observation (0-indexed). Shape: (n_obs,)
  • sensor_indices (ArrayLike, required): Sensor index for each observation (0-indexed). Shape: (n_obs,)
  • n_sensors (int, required): Total number of measurement sensors.
  • obs (ArrayLike | None, default None): Observed measurements, shape (n_obs,), or None for prior sampling.

Returns:

  • ObservationSample: Named tuple with observed (sampled or conditioned measurements) and predicted (predicted values before noise, shape (n_total, n_subpops)).

Source code in pyrenew/observation/measurements.py
def sample(
    self,
    infections: ArrayLike,
    times: ArrayLike,
    subpop_indices: ArrayLike,
    sensor_indices: ArrayLike,
    n_sensors: int,
    obs: ArrayLike | None = None,
) -> ObservationSample:
    """
    Sample measurements from observed sensors.

    Times are on the shared time axis [0, n_total) where
    n_total = n_init + n_days. This method performs direct indexing
    without any offset adjustment.

    Transforms infections to predicted values via signal-specific processing
    (``_predicted_obs``), then applies noise model.

    Parameters
    ----------
    infections : ArrayLike
        Infections from the infection process.
        Shape: (n_total, n_subpops)
    times : ArrayLike
        Day index for each observation on the shared time axis.
        Must be in range [0, n_total). Shape: (n_obs,)
    subpop_indices : ArrayLike
        Subpopulation index for each observation (0-indexed).
        Shape: (n_obs,)
    sensor_indices : ArrayLike
        Sensor index for each observation (0-indexed).
        Shape: (n_obs,)
    n_sensors : int
        Total number of measurement sensors.
    obs : ArrayLike | None
        Observed measurements (n_obs,), or None for prior sampling.

    Returns
    -------
    ObservationSample
        Named tuple with `observed` (sampled/conditioned measurements) and
        `predicted` (predicted values before noise, shape: n_total x n_subpops).
    """
    predicted_values = self._predicted_obs(infections)
    self._deterministic("predicted", predicted_values)

    # Direct indexing on shared time axis - no offset needed
    predicted_obs = predicted_values[times, subpop_indices]

    observed = self.noise.sample(
        name=self._sample_site_name("obs"),
        predicted=predicted_obs,
        obs=obs,
        sensor_indices=sensor_indices,
        n_sensors=n_sensors,
    )

    return ObservationSample(observed=observed, predicted=predicted_values)
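
The expression `predicted_values[times, subpop_indices]` pairs the two index arrays elementwise, picking one value per observation rather than a sub-grid (this is how NumPy/JAX advanced indexing with two equal-length integer arrays behaves). The pure-Python sketch below reproduces that selection with made-up sizes, assuming observation days have already been placed on the shared axis, so the first post-initialization day has index `n_init`.

```python
# predicted_values has shape (n_total, n_subpops), with n_total = n_init + n_days.
n_init, n_days, n_subpops = 2, 3, 2
n_total = n_init + n_days

# Made-up predicted values: row t, column j holds 10 * t + j.
predicted_values = [[10 * t + j for j in range(n_subpops)] for t in range(n_total)]

# Indices on the shared axis [0, n_total); no offset is applied inside sample().
times = [2, 3, 4, 4]
subpop_indices = [0, 1, 0, 1]

# Pure-Python equivalent of predicted_values[times, subpop_indices]
predicted_obs = [predicted_values[t][j] for t, j in zip(times, subpop_indices)]
print(predicted_obs)  # [20, 31, 40, 41]
```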

validate_data

validate_data(
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    sensor_indices: ArrayLike | None = None,
    n_sensors: int | None = None,
    obs: ArrayLike | None = None,
    **kwargs,
) -> None

Validate measurement observation data.

Parameters:

  • n_total (int, required): Total number of time steps (n_init + n_days_post_init).
  • n_subpops (int, required): Number of subpopulations.
  • times (ArrayLike | None, default None): Day index for each observation on the shared time axis.
  • subpop_indices (ArrayLike | None, default None): Subpopulation index for each observation (0-indexed).
  • sensor_indices (ArrayLike | None, default None): Sensor index for each observation (0-indexed).
  • n_sensors (int | None, default None): Total number of measurement sensors.
  • obs (ArrayLike | None, default None): Observed measurements, shape (n_obs,).
  • **kwargs: Additional keyword arguments (ignored).

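Raises:

  • ValueError: If times, subpop_indices, or sensor_indices are out of bounds, or if obs and times have mismatched lengths. sensor_indices are checked only when n_sensors is also provided, and the obs/times length check runs only when both obs and times are provided.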

Source code in pyrenew/observation/measurements.py
def validate_data(
    self,
    n_total: int,
    n_subpops: int,
    times: ArrayLike | None = None,
    subpop_indices: ArrayLike | None = None,
    sensor_indices: ArrayLike | None = None,
    n_sensors: int | None = None,
    obs: ArrayLike | None = None,
    **kwargs,
) -> None:
    """
    Validate measurement observation data.

    Parameters
    ----------
    n_total : int
        Total number of time steps (n_init + n_days_post_init).
    n_subpops : int
        Number of subpopulations.
    times : ArrayLike | None
        Day index for each observation on the shared time axis.
    subpop_indices : ArrayLike | None
        Subpopulation index for each observation (0-indexed).
    sensor_indices : ArrayLike | None
        Sensor index for each observation (0-indexed).
    n_sensors : int | None
        Total number of measurement sensors.
    obs : ArrayLike | None
        Observed measurements (n_obs,).
    **kwargs
        Additional keyword arguments (ignored).

    Raises
    ------
    ValueError
        If times, subpop_indices, or sensor_indices are out of bounds,
        or if obs and times have mismatched lengths.
    """
    if times is not None:
        self._validate_times(times, n_total)
        if obs is not None:
            self._validate_obs_times_shape(obs, times)
    if subpop_indices is not None:
        self._validate_subpop_indices(subpop_indices, n_subpops)
    if sensor_indices is not None and n_sensors is not None:
        self._validate_index_array(sensor_indices, n_sensors, "sensor_indices")
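
The branch structure of `validate_data` can be mirrored in a few lines. This is a hypothetical sketch: the private helpers (`_validate_times`, `_validate_subpop_indices`, `_validate_index_array`, `_validate_obs_times_shape`) are not shown on this page, so `check_index_array` and `check_measurement_data` below are illustrative stand-ins whose error messages and exact checks may differ from pyrenew's.

```python
def check_index_array(indices, upper, label):
    """Hypothetical check: every index must lie in [0, upper)."""
    bad = [i for i in indices if not 0 <= i < upper]
    if bad:
        raise ValueError(f"{label} out of bounds [0, {upper}): {bad}")


def check_measurement_data(n_total, n_subpops, times=None, subpop_indices=None,
                           sensor_indices=None, n_sensors=None, obs=None):
    """Mirror of validate_data's branching (checks are illustrative)."""
    if times is not None:
        check_index_array(times, n_total, "times")
        # obs length is only compared when times is also given.
        if obs is not None and len(obs) != len(times):
            raise ValueError(f"obs has length {len(obs)}, times has length {len(times)}")
    if subpop_indices is not None:
        check_index_array(subpop_indices, n_subpops, "subpop_indices")
    # sensor_indices are only checked when n_sensors is also supplied.
    if sensor_indices is not None and n_sensors is not None:
        check_index_array(sensor_indices, n_sensors, "sensor_indices")


# Passes: all indices in range and obs matches times.
check_measurement_data(5, 2, times=[2, 4], subpop_indices=[0, 1], obs=[1.0, 2.0])
```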

NegativeBinomialNoise

NegativeBinomialNoise(concentration_rv: RandomVariable)

Bases: CountNoise

Negative Binomial noise for overdispersed counts (variance > mean).

Uses the NB2 parameterization. Higher concentration reduces overdispersion.

Parameters:

  • concentration_rv (RandomVariable, required): Concentration parameter (must be > 0). Higher values reduce overdispersion.

Notes

The NB2 parameterization has variance = mean + mean^2 / concentration. As concentration -> infinity, the variance approaches the mean and the distribution approaches a Poisson.
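
The variance formula in the Notes can be checked numerically. The pure-Python helper below is hypothetical (not part of pyrenew); it shows the variance shrinking toward the mean, which is the Poisson variance, as concentration grows.

```python
def nb2_variance(mean, concentration):
    """Variance of NB2: mean + mean**2 / concentration."""
    if concentration <= 0:
        raise ValueError("concentration must be positive")
    return mean + mean ** 2 / concentration


mean = 10.0
for k in (1.0, 10.0, 1000.0):
    print(k, nb2_variance(mean, k))
# 1.0 110.0
# 10.0 20.0
# 1000.0 10.1
```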

Initialize Negative Binomial noise.

Source code in pyrenew/observation/noise.py
def __init__(self, concentration_rv: RandomVariable) -> None:
    """
    Initialize Negative Binomial noise.

    Parameters
    ----------
    concentration_rv : RandomVariable
        Concentration parameter (must be > 0).
        Higher values reduce overdispersion.
    """
    self.concentration_rv = concentration_rv

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return f"NegativeBinomialNoise(concentration_rv={self.concentration_rv!r})"

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample from Negative Binomial distribution.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted count values.
  • obs (ArrayLike | None, default None): Observed counts for conditioning.
  • mask (ArrayLike | None, default None): Boolean mask indicating which observations to include in the likelihood. If None, all observations are included.

Returns:

  • ArrayLike: Negative Binomial-distributed counts.

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample from Negative Binomial distribution.

    Parameters
    ----------
    name : str
        Numpyro sample site name.
    predicted : ArrayLike
        Predicted count values.
    obs : ArrayLike | None
        Observed counts for conditioning.
    mask : ArrayLike | None
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included.

    Returns
    -------
    ArrayLike
        Negative Binomial-distributed counts.
    """
    concentration = self.concentration_rv()
    with numpyro.handlers.mask(mask=True if mask is None else mask):
        return numpyro.sample(
            name,
            dist.NegativeBinomial2(
                mean=predicted + _EPSILON,
                concentration=concentration,
            ),
            obs=obs,
        )
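
`numpyro.handlers.mask` with a boolean mask drops the log-probability contributions of masked-out entries from the likelihood. The pure-Python sketch below computes that "keep only these terms" effect by hand using the NB2 (gamma-Poisson) log-PMF. The helpers are hypothetical, and the `eps` default is an assumed value standing in for `_EPSILON`, whose actual value is not shown on this page.

```python
import math


def nb2_logpmf(y, mean, concentration):
    """Log PMF of the NB2 distribution with the given mean and concentration."""
    r = concentration
    return (
        math.lgamma(y + r) - math.lgamma(r) - math.lgamma(y + 1)
        + r * math.log(r / (r + mean))
        + y * math.log(mean / (r + mean))
    )


def masked_loglik(obs, predicted, concentration, mask=None, eps=1e-10):
    """Sum log-likelihood terms only where mask is True (all terms if mask is None).

    eps plays the role of the small constant added to the mean (assumed value).
    """
    if mask is None:
        mask = [True] * len(obs)
    return sum(
        nb2_logpmf(y, mu + eps, concentration)
        for y, mu, keep in zip(obs, predicted, mask)
        if keep
    )


obs = [3, 0, 7]
predicted = [4.0, 1.0, 5.0]
full = masked_loglik(obs, predicted, concentration=5.0)
partial = masked_loglik(obs, predicted, concentration=5.0, mask=[True, False, True])
```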

validate

validate() -> None

Validate concentration is positive.

Raises:

  • ValueError: If any concentration value is <= 0.

Source code in pyrenew/observation/noise.py
def validate(self) -> None:
    """
    Validate concentration is positive.

    Raises
    ------
    ValueError
        If concentration <= 0.
    """
    concentration = self.concentration_rv()
    if jnp.any(concentration <= 0):
        raise ValueError(
            f"NegativeBinomialNoise: concentration must be positive, "
            f"got {float(concentration)}"
        )

NegativeBinomialObservation

NegativeBinomialObservation(
    name: str, concentration_rv: RandomVariable, eps: float = 1e-10
)

Bases: RandomVariable

Negative Binomial observation.

Default constructor.

Parameters:

  • name (str, required): Name for the numpyro variable.
  • concentration_rv (RandomVariable, required): Random variable from which to sample the positive concentration parameter of the negative binomial. This parameter is sometimes called k, phi, or the "dispersion" or "overdispersion" parameter, even though larger values make the distribution more Poisson-like and smaller values make it more dispersed.
  • eps (float, default 1e-10): Small value added to the predicted mean to prevent numerical instability.

Returns:

  • None

Source code in pyrenew/observation/negativebinomial.py
def __init__(
    self,
    name: str,
    concentration_rv: RandomVariable,
    eps: float = 1e-10,
) -> None:
    """
    Default constructor

    Parameters
    ----------
    name
        Name for the numpyro variable.
    concentration_rv
        Random variable from which to sample the positive concentration
        parameter of the negative binomial. This parameter is sometimes
        called k, phi, or the "dispersion" or "overdispersion" parameter,
        despite the fact that larger values imply that the distribution
        becomes more Poissonian, while smaller ones imply a greater degree
        of dispersion.
    eps
        Small value to add to the predicted mean to prevent numerical
        instability. Defaults to 1e-10.

    Returns
    -------
    None
    """

    NegativeBinomialObservation.validate(concentration_rv)

    super().__init__(name=name)
    self.concentration_rv = concentration_rv
    self.eps = eps

sample

sample(mu: ArrayLike, obs: ArrayLike | None = None, **kwargs) -> ArrayLike

Sample from the negative binomial distribution.

Parameters:

  • mu (ArrayLike, required): Mean parameter of the negative binomial distribution.
  • obs (ArrayLike | None, default None): Observed data.
  • **kwargs: Additional keyword arguments. They are accepted for interface compatibility, but the implementation below does not pass them to any internal sample call.

Returns:

  • ArrayLike: Negative binomial samples, or the conditioned observations when obs is provided.

Source code in pyrenew/observation/negativebinomial.py
def sample(
    self,
    mu: ArrayLike,
    obs: ArrayLike | None = None,
    **kwargs,
) -> ArrayLike:
    """
    Sample from the negative binomial distribution

    Parameters
    ----------
    mu
        Mean parameter of the negative binomial distribution.
    obs
        Observed data, by default None.
    **kwargs
        Additional keyword arguments passed through to internal sample calls, should there be any.

    Returns
    -------
    ArrayLike
    """
    concentration = self.concentration_rv.sample()

    negative_binomial_sample = numpyro.sample(
        name=self.name,
        fn=dist.NegativeBinomial2(
            mean=mu + self.eps,
            concentration=concentration,
        ),
        obs=obs,
    )

    return negative_binomial_sample

validate staticmethod

validate(concentration_rv: RandomVariable) -> None

Check that concentration_rv is a RandomVariable. The check uses assert, so an invalid argument raises AssertionError.

Parameters:

  • concentration_rv (RandomVariable, required): RandomVariable from which to sample the positive concentration parameter of the negative binomial.

Returns:

  • None

Source code in pyrenew/observation/negativebinomial.py
@staticmethod
def validate(concentration_rv: RandomVariable) -> None:
    """
    Check that the concentration_rv is actually a RandomVariable

    Parameters
    ----------
    concentration_rv
        RandomVariable from which to sample the positive concentration
        parameter of the negative binomial.

    Returns
    -------
    None
    """
    assert isinstance(concentration_rv, RandomVariable)
    return None

ObservationSample

Bases: NamedTuple

Return type for observation process sample() methods.

Attributes:

  • observed (ArrayLike): Sampled or conditioned observations. Shape depends on the observation process and indexing.
  • predicted (ArrayLike): Predicted values before noise is applied. Useful for diagnostics and posterior predictive checks.
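
Because ObservationSample is a NamedTuple, results support both attribute access and positional unpacking. The sketch below uses a hypothetical stand-in class, `ObservationSampleSketch`, with the same two fields and plain lists in place of ArrayLike values; the numbers are made up.

```python
from typing import NamedTuple


class ObservationSampleSketch(NamedTuple):
    """Hypothetical mirror of ObservationSample with the same two fields."""

    observed: list
    predicted: list


result = ObservationSampleSketch(observed=[3, 5], predicted=[[2.5], [4.8]])

# Fields can be read by name ...
print(result.observed)  # [3, 5]
# ... or unpacked positionally, in declaration order.
observed, predicted = result
print(predicted)  # [[2.5], [4.8]]
```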

PoissonNoise

PoissonNoise()

Bases: CountNoise

Poisson noise for equidispersed counts (variance = mean).

Initialize Poisson noise (no parameters).

Source code in pyrenew/observation/noise.py
def __init__(self) -> None:
    """Initialize Poisson noise (no parameters)."""
    pass

__repr__

__repr__() -> str

Return string representation.

Source code in pyrenew/observation/noise.py
def __repr__(self) -> str:
    """Return string representation."""
    return "PoissonNoise()"

sample

sample(
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike

Sample from Poisson distribution.

Parameters:

  • name (str, required): Numpyro sample site name.
  • predicted (ArrayLike, required): Predicted count values.
  • obs (ArrayLike | None, default None): Observed counts for conditioning.
  • mask (ArrayLike | None, default None): Boolean mask indicating which observations to include in the likelihood. If None, all observations are included.

Returns:

  • ArrayLike: Poisson-distributed counts.

Source code in pyrenew/observation/noise.py
def sample(
    self,
    name: str,
    predicted: ArrayLike,
    obs: ArrayLike | None = None,
    mask: ArrayLike | None = None,
) -> ArrayLike:
    """
    Sample from Poisson distribution.

    Parameters
    ----------
    name : str
        Numpyro sample site name.
    predicted : ArrayLike
        Predicted count values.
    obs : ArrayLike | None
        Observed counts for conditioning.
    mask : ArrayLike | None
        Boolean mask indicating which observations to include in the
        likelihood. If None, all observations are included.

    Returns
    -------
    ArrayLike
        Poisson-distributed counts.
    """
    with numpyro.handlers.mask(mask=True if mask is None else mask):
        return numpyro.sample(
            name,
            dist.Poisson(rate=predicted + _EPSILON),
            obs=obs,
        )
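
The source does not state why `_EPSILON` is added to the rate, but a plausible reason is that a predicted count of exactly 0 makes the log-likelihood of any positive observed count undefined or -inf. The pure-Python sketch below illustrates this with a naive Poisson log-PMF. The helper `poisson_logpmf` is hypothetical, and `EPS` is an assumed value, not pyrenew's `_EPSILON`.

```python
import math


def poisson_logpmf(y, rate):
    """log P(Y = y) for Y ~ Poisson(rate): y*log(rate) - rate - log(y!)."""
    return y * math.log(rate) - rate - math.lgamma(y + 1)


EPS = 1e-10  # assumed value; pyrenew's _EPSILON may differ

# With a zero rate, log(rate) is undefined for this naive formula ...
try:
    poisson_logpmf(2, 0.0)
except ValueError as err:
    print("rate 0 fails:", err)  # rate 0 fails: math domain error

# ... while a tiny positive rate gives a finite (very negative) log-probability.
print(poisson_logpmf(2, 0.0 + EPS))
```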

validate

validate() -> None

Validate Poisson noise (always valid).

Source code in pyrenew/observation/noise.py
def validate(self) -> None:
    """Validate Poisson noise (always valid)."""
    pass

VectorizedRV

VectorizedRV(name: str, rv: RandomVariable)

Bases: RandomVariable

Wrapper that adds n_groups support to simple RandomVariables.

Uses numpyro.plate to vectorize sampling, enabling simple RVs to work with noise models expecting the group-level interface.

Parameters:

  • name (str, required): A name for this random variable. The numpyro plate is named f"{name}_plate".
  • rv (RandomVariable, required): The underlying RandomVariable to wrap.

Initialize VectorizedRV wrapper.

Source code in pyrenew/observation/noise.py
def __init__(self, name: str, rv: RandomVariable) -> None:
    """
    Initialize VectorizedRV wrapper.

    Parameters
    ----------
    name : str
        A name for this random variable.
        The numpyro plate is named ``f"{name}_plate"``.
    rv : RandomVariable
        The underlying RandomVariable to wrap.
    """
    super().__init__(name=name)
    self.rv = rv
    self.plate_name = f"{name}_plate"

sample

sample(n_groups: int, **kwargs)

Sample n_groups values using numpyro.plate.

Parameters:

  • n_groups (int, required): Number of group-level values to sample.

Returns:

  • ArrayLike: Array of shape (n_groups,).

Source code in pyrenew/observation/noise.py
def sample(self, n_groups: int, **kwargs):
    """
    Sample n_groups values using numpyro.plate.

    Parameters
    ----------
    n_groups : int
        Number of group-level values to sample.

    Returns
    -------
    ArrayLike
        Array of shape (n_groups,).
    """
    with numpyro.plate(self.plate_name, n_groups):
        return self.rv(**kwargs)
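
Conceptually, the wrapper turns a scalar random variable into one that returns `n_groups` independent draws, which is the interface `HierarchicalNormalNoise` expects from `sensor_mode_rv` and `sensor_sd_rv`. The sketch below is a loop-based stand-in, not pyrenew code: `numpyro.plate` vectorizes by giving sample sites a batch dimension rather than by looping, so the classes here (both hypothetical) reproduce only the resulting shape and independence, not the mechanics.

```python
import random


class ScalarNormalRV:
    """Hypothetical scalar RV: each call returns one Normal draw."""

    def __init__(self, loc, scale, rng):
        self.loc, self.scale, self.rng = loc, scale, rng

    def __call__(self):
        return self.rng.gauss(self.loc, self.scale)


class VectorizedSketch:
    """Loop-based stand-in for VectorizedRV: returns n_groups independent draws."""

    def __init__(self, name, rv):
        self.name = name
        self.rv = rv
        self.plate_name = f"{name}_plate"  # same naming rule as VectorizedRV

    def __call__(self, n_groups, **kwargs):
        return [self.rv(**kwargs) for _ in range(n_groups)]


rng = random.Random(0)
sensor_mode_rv = VectorizedSketch("sensor_mode", ScalarNormalRV(0.0, 0.5, rng))
draws = sensor_mode_rv(n_groups=4)
print(sensor_mode_rv.plate_name, len(draws))  # sensor_mode_plate 4
```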

validate

validate()

Validate the underlying RV.

Source code in pyrenew/observation/noise.py
def validate(self):  # pragma: no cover
    """Validate the underlying RV."""
    self.rv.validate()