Observation
Observation processes for connecting infections to observed data.
BaseObservationProcess is the abstract base. Concrete subclasses:
- PopulationCounts: Aggregate counts (admissions, deaths)
- SubpopulationCounts: Subpopulation-level counts
- MeasurementObservation: Continuous subpopulation-level signals (e.g., wastewater)
All observation processes implement:
- sample(): Sample observations given infections
- infection_resolution(): Return "aggregate" or "subpop"
- lookback_days(): Return the required infection history length
Noise models (CountNoise, MeasurementNoise) are composable: pass them
to observation constructors to control the output distribution.
BaseObservationProcess
BaseObservationProcess(name: str, temporal_pmf_rv: RandomVariable)
Bases: RandomVariable
Abstract base class for observation processes that use convolution with temporal distributions.
This class provides common functionality for connecting infections to observed data (e.g., hospital admissions, wastewater concentrations) through temporal convolution operations.
Key features provided:
- PMF validation (sum to 1, non-negative)
- Minimum observation day calculation
- Convolution wrapper with timeline alignment
- Deterministic quantity tracking
Subclasses must implement:
- lookback_days(): Return the number of initialization days needed (len(pmf) - 1)
- infection_resolution(): Return "aggregate" or "subpop"
- _predicted_obs(): Transform infections to predicted values
- sample(): Apply the noise model to predicted observations
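The contract can be illustrated without pyrenew itself. The sketch below uses a hypothetical stand-in ABC (`ObservationContract`) and a toy subclass (`ToyDelayedCounts`). It covers only `lookback_days()` and `infection_resolution()`. A real subclass would inherit from `BaseObservationProcess` and would also implement `_predicted_obs()` and `sample()`.

```python
from abc import ABC, abstractmethod


class ObservationContract(ABC):
    """Hypothetical stand-in for part of the abstract interface (not pyrenew code)."""

    @abstractmethod
    def lookback_days(self) -> int:
        """Days of infection history needed before the first prediction."""

    @abstractmethod
    def infection_resolution(self) -> str:
        """Either "aggregate" or "subpop"."""


class ToyDelayedCounts(ObservationContract):
    """Toy subclass driven by a fixed 0-indexed delay PMF."""

    def __init__(self, pmf: list[float]):
        self.pmf = list(pmf)

    def lookback_days(self) -> int:
        # A PMF of length L covers lags 0..L-1, so L-1 days of history are needed.
        return len(self.pmf) - 1

    def infection_resolution(self) -> str:
        # This toy observation reads the single aggregated infection trajectory.
        return "aggregate"


toy = ToyDelayedCounts([0.2, 0.5, 0.3])
print(toy.lookback_days(), toy.infection_resolution())
```

Instantiating `ObservationContract` directly raises `TypeError`, because abstract methods remain unimplemented. This is the same mechanism that forces concrete observation classes to provide these methods.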
Notes
Computing predicted observations on day t requires infection history
from previous days (determined by the temporal PMF length).
The first len(pmf) - 1 days have insufficient history and return NaN.
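The following minimal NumPy sketch makes the NaN padding concrete. It performs a delay convolution with the same alignment: day t combines infections from days t, t-1, ..., t-(L-1), so the first L-1 days are left as NaN. The helper name `delay_convolve` is hypothetical. It is not the implementation of `pyrenew.convolve.compute_delay_ascertained_incidence`, and it omits ascertainment scaling.

```python
import numpy as np


def delay_convolve(infections, pmf):
    """Hypothetical helper: expected observations per day from a 0-indexed delay PMF.

    pmf[d] is the probability that an observation occurs d days after infection.
    Days without len(pmf) - 1 prior days of history are returned as NaN.
    """
    infections = np.asarray(infections, dtype=float)
    pmf = np.asarray(pmf, dtype=float)
    n_days, L = infections.shape[0], pmf.shape[0]
    out = np.full(n_days, np.nan)
    if n_days >= L:
        # "valid" mode keeps only days whose full lag window lies inside the data:
        # out[t] = sum_d pmf[d] * infections[t - d] for t = L-1, ..., n_days-1
        out[L - 1:] = np.convolve(infections, pmf, mode="valid")
    return out


print(delay_convolve([10, 20, 30, 40], [0.5, 0.5]))
```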
See Also
- pyrenew.convolve.compute_delay_ascertained_incidence : Underlying convolution function
- pyrenew.metaclass.RandomVariable : Base class for all random variables
Initialize base observation process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names, enabling multiple observations of the same type in a single model. | required |
| temporal_pmf_rv | RandomVariable | The temporal distribution PMF (e.g., delay or shedding distribution). Must sample to a 1D array that sums to ~1.0 with non-negative values. Subclasses may have additional parameters. | required |
Notes
Subclasses should call super().__init__(name, temporal_pmf_rv)
in their constructors and may add additional parameters.
Source code in pyrenew/observation/base.py
infection_resolution
abstractmethod
infection_resolution() -> str
Return whether this observation uses aggregate or subpop infections.
Returns one of:
- "aggregate": Uses a single aggregated infection trajectory. Shape: (n_days,)
- "subpop": Uses subpopulation-level infection trajectories. Shape: (n_days, n_subpops), indexed via subpop_indices.
Returns:
| Type | Description |
|---|---|
| str | Either |
Notes
This is used by multi-signal models to route the correct infection output to each observation process.
Source code in pyrenew/observation/base.py
lookback_days
abstractmethod
lookback_days() -> int
Return the number of days this observation process needs to look back.
This determines the minimum n_initialization_points required by the latent process when this observation is included in a multi-signal model.
Returns:
| Type | Description |
|---|---|
| int | Number of days of infection history required. |
Notes
Delay/shedding PMFs are 0-indexed (effect can occur on day 0), so a
PMF of length L covers lags 0 to L-1, requiring L-1 initialization
points. Implementations should return len(pmf) - 1.
This is used by model builders to automatically compute
n_initialization_points as max(gen_int_length, max(all lookbacks)).
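Here is a worked example of that rule with hypothetical PMF lengths. The plain-Python sketch mirrors the formula above. The function names are illustrative and are not pyrenew API.

```python
def lookback_from_pmf(pmf) -> int:
    """A 0-indexed PMF of length L covers lags 0..L-1, so it needs L-1 days of history."""
    return len(pmf) - 1


def n_initialization_points(gen_int_length: int, lookbacks: list[int]) -> int:
    """Initialization days needed to cover the generation interval and every observation."""
    return max(gen_int_length, max(lookbacks))


hospital_delay = [0.1] * 10      # hypothetical 10-day delay PMF -> lookback 9
shedding_kinetics = [0.05] * 20  # hypothetical 20-day shedding PMF -> lookback 19
lookbacks = [lookback_from_pmf(hospital_delay), lookback_from_pmf(shedding_kinetics)]
print(n_initialization_points(gen_int_length=7, lookbacks=lookbacks))  # 19
```

In this example the wastewater shedding PMF, not the generation interval, determines how much infection history the latent process must initialize.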
Source code in pyrenew/observation/base.py
sample
abstractmethod
Sample from the observation process.
Subclasses must implement this method to define the specific
observation model. Typically calls _predicted_obs first,
then applies the noise model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | ArrayLike \| None | Observed data for conditioning, or None for prior predictive sampling. | None |
| **kwargs | object | Subclass-specific parameters (e.g., infections from the infection process). | {} |
|
Returns:
| Type | Description |
|---|---|
| ArrayLike | Observed or sampled values from the observation process. |
Source code in pyrenew/observation/base.py
validate_data
abstractmethod
Validate observation data before running inference.
Each observation process validates its own data requirements.
Called by the model's validate_data() method with concrete
(non-traced) values before JAX tracing begins.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_total | int | Total number of time steps (n_init + n_days_post_init). | required |
| n_subpops | int | Number of subpopulations. | required |
| **obs_data | dict[str, object] | Observation-specific data kwargs (same as passed to | {} |
|
Raises:
| Type | Description |
|---|---|
| ValueError | If any data fails validation. |
Source code in pyrenew/observation/base.py
CountNoise
Bases: ABC
Abstract base for count observation noise models.
Defines how discrete count observations are distributed around predicted values.
sample
abstractmethod
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
mask: ArrayLike | None = None,
) -> ArrayLike
Sample count observations given predicted counts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted count values (non-negative). | required |
| obs | ArrayLike \| None | Observed counts for conditioning, or None for prior sampling. | None |
| mask | ArrayLike \| None | Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. If provided, observations where mask is False are excluded from the likelihood. | None |
|
Returns:
| Type | Description |
|---|---|
| ArrayLike | Sampled or conditioned counts, same shape as predicted. |
Notes
Implementations use numpyro.handlers.mask rather than the
obs_mask parameter of numpyro.sample. This avoids creating
latent variables for masked entries, which would fail with NUTS
for discrete distributions.
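The effect of masking on the likelihood can be sketched in plain Python. Masked entries contribute nothing, so their (possibly NaN) observed values are never evaluated. This hand-written Poisson log-likelihood is only an illustration with a hypothetical function name. The actual noise models wrap `numpyro.sample` in `numpyro.handlers.mask`, as described above.

```python
import math


def masked_poisson_loglik(predicted, obs, mask=None):
    """Sum of Poisson log-pmf terms over the entries where mask is True.

    Illustrative sketch: entries with mask False (e.g., NaN placeholders) are skipped.
    """
    if mask is None:
        mask = [True] * len(obs)
    total = 0.0
    for lam, y, keep in zip(predicted, obs, mask):
        if keep:
            # log P(Y = y) for Y ~ Poisson(lam)
            total += y * math.log(lam) - lam - math.lgamma(y + 1)
    return total


nan = float("nan")
# The second observation is missing; masking it out keeps the NaN out of the sum.
print(masked_poisson_loglik([2.0, 3.0], [1, nan], [True, False]))
```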
Source code in pyrenew/observation/noise.py
CountObservation
CountObservation(
name: str,
ascertainment_rate_rv: RandomVariable,
delay_distribution_rv: RandomVariable,
noise: CountNoise,
right_truncation_rv: RandomVariable | None = None,
day_of_week_rv: RandomVariable | None = None,
aggregation: Literal["daily", "weekly"] = "daily",
reporting_schedule: Literal["regular", "irregular"] = "regular",
start_dow: int | None = None,
)
Bases: BaseObservationProcess
Abstract base class for count observation processes.
Subclasses map infections to counts through ascertainment x delay convolution with a composable noise model. Predictions are always constructed on the model's daily time axis and then, if requested, aggregated to the observation reporting grid before the likelihood is evaluated.
Initialize count observation base.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names. | required |
| ascertainment_rate_rv | RandomVariable | Ascertainment rate in [0, 1] (e.g., IHR, IER). | required |
| delay_distribution_rv | RandomVariable | Delay distribution PMF (must sum to ~1.0). | required |
| noise | CountNoise | Noise model for count observations (Poisson, NegBin, etc.). | required |
| right_truncation_rv | RandomVariable \| None | Optional reporting delay PMF for right-truncation adjustment. When provided (along with | None |
| day_of_week_rv | RandomVariable \| None | Optional day-of-week multiplicative effect. Must sample to shape (7,) with non-negative values, where entry j is the multiplier for day-of-week j (0=Monday, 6=Sunday, ISO convention). An effect of 1.0 means no adjustment for that day. Values summing to 7.0 preserve weekly totals and keep the ascertainment rate interpretable; other sums rescale overall predicted counts. When provided (along with | None |
| aggregation | Literal['daily', 'weekly'] | Observation reporting cadence; one of | 'daily' |
| reporting_schedule | Literal['regular', 'irregular'] | Either | 'regular' |
| start_dow | int \| None | Day-of-week on which the weekly aggregation cycle begins (0=Monday, 6=Sunday, ISO convention). Use | None |
|
Raises:
| Type | Description |
|---|---|
| ValueError | If |
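The day-of-week adjustment described for `day_of_week_rv` can be sketched in NumPy. Each daily prediction is multiplied by the effect for its weekday, with weekdays counted from `first_day_dow` (ISO, 0=Monday). The sketch also shows why effects that sum to 7.0 preserve the total over any full 7-day window. The helper name and effect values are hypothetical, and this is not the library's implementation.

```python
import numpy as np


def apply_day_of_week(daily, effect, first_day_dow):
    """Multiply each day's prediction by the effect for its weekday (0=Monday, ISO)."""
    daily = np.asarray(daily, dtype=float)
    effect = np.asarray(effect, dtype=float)
    if effect.shape != (7,):
        raise ValueError("day-of-week effect must have shape (7,)")
    # Weekday of each element on the daily axis, starting from first_day_dow.
    weekday = (first_day_dow + np.arange(daily.shape[0])) % 7
    return daily * effect[weekday]


effect = np.array([1.4, 1.2, 1.0, 1.0, 1.0, 0.7, 0.7])  # hypothetical; sums to 7.0
adjusted = apply_day_of_week(np.full(14, 100.0), effect, first_day_dow=3)
# Any 7 consecutive days contain each weekday once, so the weekly total stays at 700.
print(adjusted[:7].sum(), adjusted[5:12].sum())
```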
Source code in pyrenew/observation/count_observations.py
aggregation_period
property
aggregation_period: int
Width of the observation reporting period in days.
Returns:
| Type | Description |
|---|---|
| int |  |
infection_resolution
infection_resolution() -> str
Return required infection resolution.
Returns:
| Type | Description |
|---|---|
| str | "aggregate" or "subpop". |
Source code in pyrenew/observation/count_observations.py
lookback_days
lookback_days() -> int
Return required lookback days for this observation.
Delay PMFs are 0-indexed (delay can be 0), so a PMF of length L covers delays 0 to L-1, requiring L-1 initialization points.
Returns:
| Type | Description |
|---|---|
| int | Length of delay distribution PMF minus 1. |
Source code in pyrenew/observation/count_observations.py
validate
validate() -> None
Validate observation parameters.
Raises:
| Type | Description |
|---|---|
| ValueError | If the delay PMF is invalid, the ascertainment rate is outside [0, 1], or the noise parameters are invalid. |
Source code in pyrenew/observation/count_observations.py
HierarchicalNormalNoise
HierarchicalNormalNoise(
sensor_mode_rv: RandomVariable, sensor_sd_rv: RandomVariable
)
Bases: MeasurementNoise
Normal noise with hierarchical sensor-level effects.
Observation model: obs ~ Normal(predicted + sensor_mode, sensor_sd)
where sensor_mode and sensor_sd are sampled per-sensor.
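The following NumPy sketch shows how the per-sensor parameters line up with individual observations. It takes already-drawn sensor modes and SDs, which in the real model come from `sensor_mode_rv` and `sensor_sd_rv`, and gathers them with `sensor_indices`. It is a forward simulation only, with a hypothetical helper name and made-up values, and it is not the numpyro implementation.

```python
import numpy as np


def hierarchical_normal_draw(predicted, sensor_indices, sensor_mode, sensor_sd, rng):
    """Draw obs ~ Normal(predicted + sensor_mode[s], sensor_sd[s]), with s the observation's sensor."""
    predicted = np.asarray(predicted, dtype=float)
    idx = np.asarray(sensor_indices)
    # Gather one mode and one SD per observation from the per-sensor arrays.
    loc = predicted + np.asarray(sensor_mode, dtype=float)[idx]
    scale = np.asarray(sensor_sd, dtype=float)[idx]
    return rng.normal(loc, scale), loc, scale


rng = np.random.default_rng(1)
draws, loc, scale = hierarchical_normal_draw(
    predicted=[2.0, 2.0, 2.0],
    sensor_indices=[0, 1, 0],  # observations 0 and 2 come from sensor 0
    sensor_mode=[0.5, -0.5],   # hypothetical per-sensor offsets
    sensor_sd=[0.1, 0.2],      # hypothetical per-sensor SDs (> 0)
    rng=rng,
)
print(loc, scale)
```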
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sensor_mode_rv | RandomVariable | Prior for sensor-level modes. Must implement | required |
| sensor_sd_rv | RandomVariable | Prior for sensor-level SDs (must be > 0). Must implement | required |
Notes
Use VectorizedVariable
to wrap simple RVs that lack this interface.
Initialize hierarchical Normal noise.
Source code in pyrenew/observation/noise.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/noise.py
sample
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
*,
sensor_indices: ArrayLike,
n_sensors: int,
) -> ArrayLike
Sample from Normal distribution with sensor-level hierarchical effects.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted measurement values. Shape: (n_obs,) | required |
| obs | ArrayLike \| None | Observed measurements for conditioning. Shape: (n_obs,) | None |
| sensor_indices | ArrayLike | Sensor index for each observation (0-indexed). Shape: (n_obs,) | required |
| n_sensors | int | Total number of sensors. | required |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Normally distributed measurements with hierarchical sensor effects. Shape: (n_obs,) |
Source code in pyrenew/observation/noise.py
MeasurementNoise
Bases: ABC
Abstract base for continuous measurement noise models.
Defines how continuous observations are distributed around predicted values.
sample
abstractmethod
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
**kwargs: object,
) -> ArrayLike
Sample continuous observations given predicted values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted measurement values. | required |
| obs | ArrayLike \| None | Observed measurements for conditioning, or None for prior sampling. | None |
| **kwargs | object | Additional context (e.g., sensor indices). | {} |
|
Returns:
| Type | Description |
|---|---|
| ArrayLike | Sampled or conditioned measurements, same shape as predicted. |
Source code in pyrenew/observation/noise.py
MeasurementObservation
MeasurementObservation(
name: str, temporal_pmf_rv: RandomVariable, noise: MeasurementNoise
)
Bases: BaseObservationProcess
Abstract base for continuous measurement observations.
Subclasses implement signal-specific transformations from infections to predicted measurement values, then add measurement noise.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names. | required |
| temporal_pmf_rv | RandomVariable | Temporal distribution PMF (e.g., shedding kinetics for wastewater). | required |
| noise | MeasurementNoise | Noise model for continuous measurements (e.g., HierarchicalNormalNoise). | required |
Notes
Subclasses must implement _predicted_obs() according to their
specific signal processing (e.g., wastewater shedding kinetics,
dilution factors, etc.).
See Also
- pyrenew.observation.noise.HierarchicalNormalNoise : Suitable noise model for sensor-level measurements
- pyrenew.observation.base.BaseObservationProcess : Parent class with common observation utilities
Initialize measurement observation base.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names. | required |
| temporal_pmf_rv | RandomVariable | Temporal distribution PMF (e.g., shedding kinetics). | required |
| noise | MeasurementNoise | Noise model (e.g., HierarchicalNormalNoise with sensor effects). | required |
Source code in pyrenew/observation/measurement_observations.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/measurement_observations.py
infection_resolution
infection_resolution() -> str
Return "subpop" for measurement observations.
Measurement observations require subpopulation-level infections because each measurement corresponds to a specific catchment area.
Returns:
| Type | Description |
|---|---|
| str | The string "subpop". |
Source code in pyrenew/observation/measurement_observations.py
lookback_days
lookback_days() -> int
Return required lookback days for this observation.
Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF of length L covers lags 0 to L-1, requiring L-1 initialization points.
Returns:
| Type | Description |
|---|---|
| int | Length of temporal PMF minus 1. |
Source code in pyrenew/observation/measurement_observations.py
sample
sample(
infections: ArrayLike,
times: ArrayLike,
subpop_indices: ArrayLike,
sensor_indices: ArrayLike,
n_sensors: int,
obs: ArrayLike | None = None,
**kwargs: object,
) -> ObservationSample
Sample measurements from observed sensors.
Times are on the shared time axis [0, n_total) where n_total = n_init + n_days. This method performs direct indexing without any offset adjustment.
Transforms infections to predicted values via signal-specific processing
(_predicted_obs), then applies the noise model.
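Because `times` already sit on the shared axis, selecting one prediction per observation is plain paired indexing with no offset. The NumPy sketch below uses a hypothetical helper name and a toy `(n_total, n_subpops)` grid standing in for signal-specific predictions. It is not the library's code.

```python
import numpy as np


def predicted_at_observations(predicted_grid, times, subpop_indices):
    """Gather one predicted value per observation from an (n_total, n_subpops) grid."""
    grid = np.asarray(predicted_grid)
    times = np.asarray(times)
    subpop_indices = np.asarray(subpop_indices)
    if times.min() < 0 or times.max() >= grid.shape[0]:
        raise ValueError("times must lie in [0, n_total)")
    # Paired fancy indexing: observation k reads grid[times[k], subpop_indices[k]].
    return grid[times, subpop_indices]


n_total, n_subpops = 6, 2
grid = np.arange(n_total * n_subpops, dtype=float).reshape(n_total, n_subpops)
print(predicted_at_observations(grid, times=[3, 3, 5], subpop_indices=[0, 1, 1]))
```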
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| infections | ArrayLike | Infections from the infection process. Shape: (n_total, n_subpops) | required |
| times | ArrayLike | Day index for each observation on the shared time axis. Must be in range [0, n_total). Shape: (n_obs,) | required |
| subpop_indices | ArrayLike | Subpopulation index for each observation (0-indexed). Shape: (n_obs,) | required |
| sensor_indices | ArrayLike | Sensor index for each observation (0-indexed). Shape: (n_obs,) | required |
| n_sensors | int | Total number of measurement sensors. | required |
| obs | ArrayLike \| None | Observed measurements (n_obs,), or None for prior sampling. | None |
| **kwargs | object | Additional keyword arguments forwarded by the model dispatch (e.g., | {} |
|
Returns:
| Type | Description |
|---|---|
| ObservationSample | Named tuple with |
Source code in pyrenew/observation/measurement_observations.py
validate_data
validate_data(
n_total: int,
n_subpops: int,
times: ArrayLike | None = None,
subpop_indices: ArrayLike | None = None,
sensor_indices: ArrayLike | None = None,
n_sensors: int | None = None,
obs: ArrayLike | None = None,
**kwargs: object,
) -> None
Validate measurement observation data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_total | int | Total number of time steps (n_init + n_days_post_init). | required |
| n_subpops | int | Number of subpopulations. | required |
| times | ArrayLike \| None | Day index for each observation on the shared time axis. | None |
| subpop_indices | ArrayLike \| None | Subpopulation index for each observation (0-indexed). | None |
| sensor_indices | ArrayLike \| None | Sensor index for each observation (0-indexed). | None |
| n_sensors | int \| None | Total number of measurement sensors. | None |
| obs | ArrayLike \| None | Observed measurements (n_obs,). | None |
| **kwargs | object | Additional keyword arguments (ignored). | {} |
|
Raises:
| Type | Description |
|---|---|
| ValueError | If times, subpop_indices, or sensor_indices are out of bounds, or if obs and times have mismatched lengths. |
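The checks named in the Raises entry can be sketched in plain Python. This is not the library's code: the function name, argument order, and error messages are assumptions. It shows only the bounds and length checks listed above, which run on concrete values before JAX tracing begins.

```python
def check_measurement_data(n_total, n_subpops, times, subpop_indices,
                           sensor_indices, n_sensors, obs=None):
    """Hypothetical re-creation of the listed checks; raises ValueError on failure."""

    def check_bounds(label, values, upper):
        # Every index must fall in the half-open range [0, upper).
        for v in values:
            if not 0 <= v < upper:
                raise ValueError(f"{label} contains {v}, outside [0, {upper})")

    check_bounds("times", times, n_total)
    check_bounds("subpop_indices", subpop_indices, n_subpops)
    check_bounds("sensor_indices", sensor_indices, n_sensors)
    if obs is not None and len(obs) != len(times):
        raise ValueError(f"obs has length {len(obs)} but times has length {len(times)}")


check_measurement_data(10, 2, times=[0, 9], subpop_indices=[0, 1],
                       sensor_indices=[0, 0], n_sensors=1, obs=[1.2, 3.4])
try:
    check_measurement_data(10, 2, times=[0, 10], subpop_indices=[0, 1],
                           sensor_indices=[0, 0], n_sensors=1)
except ValueError as err:
    print(err)  # times contains 10, outside [0, 10)
```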
Source code in pyrenew/observation/measurement_observations.py
NegativeBinomialNoise
NegativeBinomialNoise(concentration_rv: RandomVariable)
Bases: CountNoise
Negative Binomial noise for overdispersed counts (variance > mean).
Uses NB2 parameterization. Higher concentration reduces overdispersion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| concentration_rv | RandomVariable | Concentration parameter (must be > 0). Higher values reduce overdispersion. | required |
Notes
The NB2 parameterization has variance = mean + mean^2 / concentration. As concentration -> infinity, this approaches Poisson.
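The following quick numerical check of the NB2 mean-variance relationship uses NumPy's random generator rather than numpyro. NumPy's `negative_binomial(n, p)` counts failures before `n` successes. Setting `n = concentration` and `p = concentration / (concentration + mean)` gives the NB2 form with the requested mean. The parameter values are arbitrary.

```python
import numpy as np


def nb2_variance(mean, concentration):
    """NB2 variance: mean + mean**2 / concentration."""
    return mean + mean**2 / concentration


mean, concentration = 10.0, 5.0
rng = np.random.default_rng(0)
p = concentration / (concentration + mean)
draws = rng.negative_binomial(concentration, p, size=200_000)

print(draws.mean(), draws.var())  # close to 10 and 30
print(nb2_variance(mean, concentration))  # 30.0
# A very large concentration pushes the variance toward the Poisson value (variance == mean).
print(nb2_variance(mean, 1e6))
```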
Initialize Negative Binomial noise.
Source code in pyrenew/observation/noise.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/noise.py
sample
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
mask: ArrayLike | None = None,
) -> ArrayLike
Sample from Negative Binomial distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted count values. | required |
| obs | ArrayLike \| None | Observed counts for conditioning. | None |
| mask | ArrayLike \| None | Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. | None |
|
Returns:
| Type | Description |
|---|---|
| ArrayLike | Negative Binomial-distributed counts. |
Source code in pyrenew/observation/noise.py
validate_concentration_rv
validate_concentration_rv() -> None
Validate concentration is positive.
Raises:
| Type | Description |
|---|---|
| ValueError | If concentration <= 0. |
Source code in pyrenew/observation/noise.py
NegativeBinomialObservation
NegativeBinomialObservation(name: str, concentration_rv: RandomVariable)
Bases: RandomVariable
Negative Binomial observation.
Default constructor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name for the numpyro variable. | required |
| concentration_rv | RandomVariable | Random variable from which to sample the positive concentration parameter of the negative binomial. This parameter is sometimes called k, phi, or the "dispersion" or "overdispersion" parameter, even though larger values make the distribution more Poisson-like and smaller values imply greater dispersion. | required |
Returns:
| Type | Description |
|---|---|
| None |  |
Source code in pyrenew/observation/negativebinomial.py
sample
Sample from the negative binomial distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mu | ArrayLike | Mean parameter of the negative binomial distribution. | required |
| obs | ArrayLike \| None | Observed data, by default None. | None |
| **kwargs | object | Additional keyword arguments passed through to internal sample calls, should there be any. | {} |
|
Returns:
| Type | Description |
|---|---|
| ArrayLike |  |
Source code in pyrenew/observation/negativebinomial.py
ObservationSample
Bases: NamedTuple
Return type for observation process sample() methods.
Attributes:
| Name | Type | Description |
|---|---|---|
| observed | ArrayLike | Sampled or conditioned observations. Shape depends on the observation process and indexing. |
| predicted | ArrayLike | Predicted values before noise is applied. Useful for diagnostics and posterior predictive checks. |
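Because the return type is a NamedTuple, a sample result can be unpacked positionally or accessed by field name. The sketch below uses a hypothetical stand-in class with the same two fields, rather than importing pyrenew, and toy values. Comparing `observed` with `predicted` is one simple diagnostic.

```python
from typing import NamedTuple

import numpy as np


class ToyObservationSample(NamedTuple):
    """Hypothetical stand-in with the same fields as ObservationSample."""
    observed: np.ndarray
    predicted: np.ndarray


result = ToyObservationSample(observed=np.array([3, 5]), predicted=np.array([2.7, 4.9]))
observed, predicted = result                     # positional unpacking
residuals = result.observed - result.predicted   # or access by field name
print(residuals)
```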
PoissonNoise
PoissonNoise()
Bases: CountNoise
Poisson noise for equidispersed counts (variance = mean).
Initialize Poisson noise (no parameters).
Source code in pyrenew/observation/noise.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/noise.py
sample
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
mask: ArrayLike | None = None,
) -> ArrayLike
Sample from Poisson distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted count values. | required |
| obs | ArrayLike \| None | Observed counts for conditioning. | None |
| mask | ArrayLike \| None | Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. | None |
|
Returns:
| Type | Description |
|---|---|
| ArrayLike | Poisson-distributed counts. |
Source code in pyrenew/observation/noise.py
validate
staticmethod
validate() -> None
PoissonNoise always passes validation.
Source code in pyrenew/observation/noise.py
PopulationCounts
PopulationCounts(
name: str,
ascertainment_rate_rv: RandomVariable,
delay_distribution_rv: RandomVariable,
noise: CountNoise,
right_truncation_rv: RandomVariable | None = None,
day_of_week_rv: RandomVariable | None = None,
aggregation: Literal["daily", "weekly"] = "daily",
reporting_schedule: Literal["regular", "irregular"] = "regular",
start_dow: int | None = None,
)
Bases: CountObservation
Aggregated count observation.
Maps aggregate infections to counts through ascertainment x delay
convolution with a composable noise model. Predictions are constructed on
the daily model axis; aggregation_period controls whether those
predictions are scored as daily counts or summed to weekly counts before
the likelihood.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names (e.g., "hospital" produces sites "hospital_obs", "hospital_predicted"). | required |
| ascertainment_rate_rv | RandomVariable | Ascertainment rate in [0, 1] (e.g., IHR, IER). | required |
| delay_distribution_rv | RandomVariable | Delay distribution PMF (must sum to ~1.0). | required |
| noise | CountNoise | Noise model (PoissonNoise, NegativeBinomialNoise, etc.). | required |
Source code in pyrenew/observation/count_observations.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/count_observations.py
infection_resolution
infection_resolution() -> str
Return "aggregate" for aggregated observations.
Returns:
| Type | Description |
|---|---|
| str | The string "aggregate". |
Source code in pyrenew/observation/count_observations.py
sample
sample(
infections: ArrayLike,
obs: ArrayLike | None = None,
right_truncation_offset: int | None = None,
first_day_dow: int | None = None,
period_end_times: ArrayLike | None = None,
) -> ObservationSample
Sample aggregated counts.
Daily transforms (right-truncation, day-of-week) run on the
daily axis. When aggregation == "weekly" the daily
predictions are summed onto the reporting-period grid before
the noise model. The likelihood path depends on
reporting_schedule: "regular" uses a dense-with-NaN
array plus a mask; "irregular" fancy-indexes the
aggregated array at period indices derived from
period_end_times.
aggregation_period describes the observation scale only. The
latent infection process may use daily or coarser Rt parameter
cadence, but by the time this method is called it supplies infections
on the daily model axis.
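The daily-to-weekly step can be sketched in NumPy. The sketch rests on several assumptions that are not documented here: weeks begin on `start_dow`, leading and trailing partial weeks are dropped, and a period's row is recovered from its final day. `weekly_sums` and `period_index` are hypothetical helpers, not pyrenew functions.

```python
import numpy as np


def weekly_sums(daily, first_day_dow, start_dow):
    """Sum daily predictions into complete 7-day periods that start on start_dow.

    Returns (weekly_totals, offset), where offset is the daily index of the first
    day of the first complete week. Partial weeks at either end are dropped.
    """
    daily = np.asarray(daily, dtype=float)
    offset = (start_dow - first_day_dow) % 7
    n_weeks = max(0, (daily.shape[0] - offset) // 7)
    window = daily[offset : offset + 7 * n_weeks]
    return window.reshape(n_weeks, 7).sum(axis=1), offset


def period_index(period_end_times, offset):
    """Map each period's final daily index to its row in the weekly array."""
    return (np.asarray(period_end_times) - offset - 6) // 7


daily = np.arange(20, dtype=float)  # hypothetical daily predictions
totals, offset = weekly_sums(daily, first_day_dow=5, start_dow=0)  # axis starts on a Saturday
print(totals, offset)  # weeks cover days 2-8 and 9-15
# Irregular schedule: score only the periods that were actually reported.
print(totals[period_index([15], offset)])
```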
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| infections | ArrayLike | Aggregate infections from the infection process. Shape | required |
| obs | ArrayLike \| None | Observed counts. Shape depends on | None |
| right_truncation_offset | int \| None | If provided (and | None |
| first_day_dow | int \| None | Day-of-week index of the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when | None |
| period_end_times | ArrayLike \| None | Daily-axis indices of each observed period's final day. Required when | None |
|
Returns:
| Type | Description |
|---|---|
| ObservationSample | Named tuple with |
Source code in pyrenew/observation/count_observations.py
validate_data
validate_data(
n_total: int,
n_subpops: int,
obs: ArrayLike | None = None,
period_end_times: ArrayLike | None = None,
first_day_dow: int | None = None,
**kwargs: object,
) -> None
Validate aggregated count observation data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_total | int | Total number of daily time steps ( | required |
| n_subpops | int | Number of subpopulations (unused for aggregate observations). | required |
| obs | ArrayLike \| None | Observed counts. Shape depends on | None |
| period_end_times | ArrayLike \| None | Daily-axis indices of each observed period's final day. Required for | None |
| first_day_dow | int \| None | Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when | None |
| **kwargs | object | Additional keyword arguments (ignored). | {} |
|
Raises:
| Type | Description |
|---|---|
| ValueError | If obs length or period_end_times fail their respective checks, or if |
Source code in pyrenew/observation/count_observations.py
SubpopulationCounts
SubpopulationCounts(
name: str,
ascertainment_rate_rv: RandomVariable,
delay_distribution_rv: RandomVariable,
noise: CountNoise,
right_truncation_rv: RandomVariable | None = None,
day_of_week_rv: RandomVariable | None = None,
aggregation: Literal["daily", "weekly"] = "daily",
reporting_schedule: Literal["regular", "irregular"] = "regular",
start_dow: int | None = None,
)
Bases: CountObservation
Subpopulation-level count observation.
Maps subpopulation-level infections to counts through
ascertainment x delay convolution with a composable noise model.
Predictions are constructed on the daily model axis for each
subpopulation; aggregation_period controls whether those predictions
are scored as daily counts or summed to weekly counts before the
likelihood.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names. | required |
| ascertainment_rate_rv | RandomVariable | Ascertainment rate in [0, 1]. | required |
| delay_distribution_rv | RandomVariable | Delay distribution PMF (must sum to ~1.0). | required |
| noise | CountNoise | Noise model (PoissonNoise, NegativeBinomialNoise, etc.). | required |
Source code in pyrenew/observation/count_observations.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/count_observations.py
infection_resolution
infection_resolution() -> str
Return "subpop" for subpopulation-level observations.
Returns:
| Type | Description |
|---|---|
| str | The string "subpop". |
Source code in pyrenew/observation/count_observations.py
sample
sample(
infections: ArrayLike,
obs: ArrayLike | None = None,
right_truncation_offset: int | None = None,
first_day_dow: int | None = None,
period_end_times: ArrayLike | None = None,
subpop_indices: ArrayLike | None = None,
) -> ObservationSample
Sample subpopulation-level counts.
Daily transforms (right-truncation, day-of-week) run on the
daily axis. When aggregation == "weekly", the daily
predictions are summed onto the reporting-period grid before
the noise model is applied. The likelihood path depends on
reporting_schedule: "regular" selects the observed
subpopulation columns and uses a dense array with NaN for
missing values plus a mask; "irregular" fancy-indexes the
aggregated array at period indices derived from period_end_times.
aggregation describes the observation scale only. The
latent infection process may use a daily or coarser Rt parameter
cadence, but by the time this method is called it supplies infections
on the daily model axis.
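The weekly aggregation and the two likelihood paths can be illustrated with plain NumPy. This sketch assumes a regular grid of consecutive 7-day periods anchored at a chosen period end; weekly_grid, regular_mask, and irregular_indices are hypothetical helpers, and the real class may anchor and index its grid differently.

```python
import numpy as np


def weekly_grid(daily: np.ndarray, first_period_end: int) -> np.ndarray:
    """Sum daily predictions (n_days, n_subpops) into consecutive 7-day periods.

    Hypothetical: the first period ends on daily index first_period_end,
    and a trailing partial week is dropped.
    """
    start = first_period_end - 6
    if start < 0:
        raise ValueError("first period would start before day 0")
    usable = daily[start:]
    n_periods = usable.shape[0] // 7
    return usable[: n_periods * 7].reshape(n_periods, 7, daily.shape[1]).sum(axis=1)


def regular_mask(obs_dense: np.ndarray) -> np.ndarray:
    """Regular schedule: obs is dense on the grid, NaN where nothing was reported."""
    return ~np.isnan(obs_dense)


def irregular_indices(period_end_times, first_period_end: int) -> np.ndarray:
    """Irregular schedule: map each reported period end onto its grid row."""
    ends = np.asarray(period_end_times, dtype=int)
    return (ends - first_period_end) // 7


daily = np.arange(28, dtype=float).reshape(14, 2)
grid = weekly_grid(daily, first_period_end=6)
```

Under the regular schedule, grid[:, subpop_indices] would select the observed subpopulation columns and the mask would drop NaN cells before scoring; under the irregular schedule, grid[irregular_indices(period_end_times, first_period_end)] picks the rows paired with reported totals.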
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| infections | ArrayLike | Subpopulation-level infections from the infection process. Shape … | required |
| obs | ArrayLike \| None | Observed counts. For … | None |
| right_truncation_offset | int \| None | If provided (and … | None |
| first_day_dow | int \| None | Day-of-week index of the first timepoint on the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when … | None |
| period_end_times | ArrayLike \| None | Daily-axis indices of each observed period's final day. Required when … | None |
| subpop_indices | ArrayLike \| None | Subpopulation indices (0-indexed). Required. For … | None |
Returns:
| Type | Description |
|---|---|
| ObservationSample | Named tuple with … |
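The first_day_dow convention (0=Monday, 6=Sunday) determines which weekday each daily row falls on. A minimal sketch of a multiplicative day-of-week adjustment follows; the seven-element effect vector and apply_day_of_week are assumptions for illustration, since this page does not document how day_of_week_rv parameterizes the effect.

```python
import numpy as np


def apply_day_of_week(daily: np.ndarray, dow_effect, first_day_dow: int) -> np.ndarray:
    """Hypothetical sketch: scale each daily row by a weekday multiplier.

    dow_effect[0] is Monday ... dow_effect[6] is Sunday, matching the
    0=Monday convention of first_day_dow.
    """
    dow_effect = np.asarray(dow_effect, dtype=float)
    if dow_effect.shape != (7,):
        raise ValueError("dow_effect must have 7 entries")
    if not 0 <= first_day_dow <= 6:
        raise ValueError("first_day_dow must be in 0..6")
    # Weekday index of every row on the shared daily axis
    dow = (first_day_dow + np.arange(daily.shape[0])) % 7
    # Broadcast one multiplier per row across all subpopulation columns
    return daily * dow_effect[dow][:, None]


daily = np.ones((9, 2))
weekend_dip = [1, 1, 1, 1, 1, 0.5, 0.5]
adjusted = apply_day_of_week(daily, weekend_dip, first_day_dow=4)  # axis starts on a Friday
```

Starting on a Friday, rows 1 and 2 (Saturday, Sunday) and row 8 (the next Saturday) are halved.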
Source code in pyrenew/observation/count_observations.py
validate_data
validate_data(
n_total: int,
n_subpops: int,
obs: ArrayLike | None = None,
period_end_times: ArrayLike | None = None,
first_day_dow: int | None = None,
subpop_indices: ArrayLike | None = None,
**kwargs: object,
) -> None
Validate subpopulation-level count observation data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_total | int | Total number of daily time steps (…). | required |
| n_subpops | int | Number of subpopulations. | required |
| obs | ArrayLike \| None | Observed counts. For … | None |
| period_end_times | ArrayLike \| None | Daily-axis indices of each observed period's final day. Required for … | None |
| first_day_dow | int \| None | Day-of-week index of element 0 of the shared time axis (0=Monday, 6=Sunday, ISO convention). Required when … | None |
| subpop_indices | ArrayLike \| None | Subpopulation indices (0-indexed). For … | None |
| **kwargs | object | Additional keyword arguments (ignored). | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If any index array is out of bounds, any shape check fails, … |
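The index-bounds portion of the Raises contract might look like the following. check_index_bounds is a hypothetical helper covering only out-of-bounds and dimensionality checks; the full list of conditions validate_data enforces is not listed on this page and is not reproduced here.

```python
import numpy as np


def check_index_bounds(n_total, n_subpops, period_end_times=None, subpop_indices=None):
    """Hypothetical bounds checks in the spirit of validate_data."""
    if period_end_times is not None:
        ends = np.asarray(period_end_times)
        if ends.ndim != 1:
            raise ValueError("period_end_times must be 1-D")
        # Every period must end on a day that exists on the daily axis
        if ends.size and (ends.min() < 0 or ends.max() >= n_total):
            raise ValueError(f"period_end_times must lie in [0, {n_total - 1}]")
    if subpop_indices is not None:
        idx = np.asarray(subpop_indices)
        # Indices are 0-based, so n_subpops itself is already out of range
        if idx.size and (idx.min() < 0 or idx.max() >= n_subpops):
            raise ValueError(f"subpop_indices must lie in [0, {n_subpops - 1}]")
```

For example, with n_total=28 the period ends [6, 13, 20, 27] pass, while 28 would point one day past the end of the axis.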
Source code in pyrenew/observation/count_observations.py