Observation
Observation processes for connecting infections to observed data.
BaseObservationProcess is the abstract base. Concrete subclasses:
- Counts: Aggregate counts (admissions, deaths)
- CountsBySubpop: Subpopulation-level counts
- Measurements: Continuous subpopulation-level signals (e.g., wastewater)
All observation processes implement:
- sample(): Sample observations given infections
- infection_resolution(): Returns "aggregate" or "subpop"
- lookback_days(): Returns the required infection history length
Noise models (CountNoise, MeasurementNoise) are composable: pass them to observation constructors to control the output distribution.
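The following plain-Python analogy shows the composition pattern. It is not pyrenew code. The observation computes its prediction and then hands the final draw to whichever noise object it was given. ToyPoissonNoise, ToyRoundingNoise, and ToyScaledCounts are hypothetical names. pyrenew's own noise models sample through numpyro, not through Python's random module.

```python
import math
import random


class ToyPoissonNoise:
    """Hypothetical count noise: one Poisson draw per predicted value."""

    def sample(self, predicted, rng):
        draws = []
        for lam in predicted:
            # Knuth's multiplication method; adequate for the small means used here.
            threshold = math.exp(-lam)
            k, p = 0, rng.random()
            while p > threshold:
                k += 1
                p *= rng.random()
            draws.append(k)
        return draws


class ToyRoundingNoise:
    """Hypothetical 'noise' that only rounds the prediction (handy for tests)."""

    def sample(self, predicted, rng):
        return [round(x) for x in predicted]


class ToyScaledCounts:
    """Hypothetical observation: predicted = rate * infections, then noise.

    Any object with a sample(predicted, rng) method can be plugged in.
    """

    def __init__(self, rate, noise):
        self.rate = rate
        self.noise = noise

    def sample(self, infections, rng):
        predicted = [self.rate * x for x in infections]
        return predicted, self.noise.sample(predicted, rng)


rng = random.Random(42)
infections = [2.0, 4.0, 6.0]
# Same observation logic, two interchangeable noise models.
pred, exact = ToyScaledCounts(0.5, ToyRoundingNoise()).sample(infections, rng)
_, noisy = ToyScaledCounts(0.5, ToyPoissonNoise()).sample(infections, rng)
```

The observation class never inspects its noise object. This is the property that lets a Poisson model be swapped for an overdispersed one without changing the observation code.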
BaseObservationProcess
BaseObservationProcess(name: str, temporal_pmf_rv: RandomVariable)
Bases: RandomVariable
Abstract base class for observation processes that use convolution with temporal distributions.
This class provides common functionality for connecting infections to observed data (e.g., hospital admissions, wastewater concentrations) through temporal convolution operations.
Key features provided:
- PMF validation (sum to 1, non-negative)
- Minimum observation day calculation
- Convolution wrapper with timeline alignment
- Deterministic quantity tracking
Subclasses must implement:
- validate(): Validate parameters (call _validate_pmf() for PMFs)
- lookback_days(): Return the number of infection-history days needed for initialization (len(pmf) - 1)
- infection_resolution(): Return "aggregate" or "subpop"
- _predicted_obs(): Transform infections to predicted values
- sample(): Apply noise model to predicted observations
Notes
Computing predicted observations on day t requires infection history
from previous days (determined by the temporal PMF length).
The first len(pmf) - 1 days have insufficient history and return NaN.
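A plain-Python sketch can illustrate the note above. The value on day t is rate * sum over d of pmf[d] * infections[t - d]. The first len(pmf) - 1 days do not have a full history, so they are set to NaN. The function delay_convolve_sketch and its rate argument are hypothetical. The sketch reproduces the behavior described here and does not reproduce the internals of compute_delay_ascertained_incidence.

```python
import math


def delay_convolve_sketch(infections, pmf, rate=1.0):
    """Hypothetical helper: out[t] = rate * sum_d pmf[d] * infections[t - d].

    Days t < len(pmf) - 1 lack a full infection history and are NaN.
    """
    lag_max = len(pmf) - 1
    out = []
    for t in range(len(infections)):
        if t < lag_max:
            out.append(math.nan)  # insufficient history
            continue
        total = sum(pmf[d] * infections[t - d] for d in range(len(pmf)))
        out.append(rate * total)
    return out


print(delay_convolve_sketch([10.0, 20.0, 30.0, 40.0], [0.5, 0.5], rate=0.5))
# [nan, 7.5, 12.5, 17.5]
```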
See Also
pyrenew.convolve.compute_delay_ascertained_incidence : Underlying convolution function
pyrenew.metaclass.RandomVariable : Base class for all random variables
Initialize base observation process.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names, enabling multiple observations of the same type in a single model. | required |
| temporal_pmf_rv | RandomVariable | The temporal distribution PMF (e.g., delay or shedding distribution). Must sample to a 1D array that sums to ~1.0 with non-negative values. Subclasses may have additional parameters. | required |
Notes
Subclasses should call super().__init__(name, temporal_pmf_rv)
in their constructors and may add additional parameters.
Source code in pyrenew/observation/base.py
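The following sketch reduces the subclass contract to what Python's abc module can express on its own. The base class declares the five required methods as abstract. A subclass calls super().__init__ as the Notes recommend. A subclass that leaves any of those methods unimplemented cannot be instantiated. ToyObservationBase, ToyAggregateCounts, and IncompleteObservation are hypothetical names. The real BaseObservationProcess also derives from RandomVariable and provides PMF validation and convolution helpers.

```python
from abc import ABC, abstractmethod


class ToyObservationBase(ABC):
    """Hypothetical stand-in for the abstract interface described above."""

    def __init__(self, name, pmf):
        self.name = name
        self.pmf = list(pmf)

    @abstractmethod
    def validate(self):
        """Check parameters; raise ValueError on failure."""

    @abstractmethod
    def lookback_days(self):
        """Days of infection history needed (len(pmf) - 1)."""

    @abstractmethod
    def infection_resolution(self):
        """Return "aggregate" or "subpop"."""

    @abstractmethod
    def _predicted_obs(self, infections):
        """Map infections to predicted (noise-free) values."""

    @abstractmethod
    def sample(self, infections, obs=None):
        """Apply a noise model to the predictions."""


class ToyAggregateCounts(ToyObservationBase):
    """Minimal concrete subclass; skips convolution and noise for brevity."""

    def __init__(self, name, pmf, rate):
        super().__init__(name, pmf)  # as the Notes recommend
        self.rate = rate

    def validate(self):
        if min(self.pmf) < 0 or abs(sum(self.pmf) - 1.0) > 1e-6:
            raise ValueError("PMF must be non-negative and sum to 1")

    def lookback_days(self):
        return len(self.pmf) - 1

    def infection_resolution(self):
        return "aggregate"

    def _predicted_obs(self, infections):
        return [self.rate * x for x in infections]

    def sample(self, infections, obs=None):
        # A real subclass would pass these predictions through its noise
        # model and condition on obs; this toy returns them unchanged.
        return self._predicted_obs(infections)


class IncompleteObservation(ToyObservationBase):
    """Implements only validate(), so instantiation raises TypeError."""

    def validate(self):
        pass
```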
infection_resolution
abstractmethod
infection_resolution() -> str
Return whether this observation uses aggregate or subpop infections.
Returns one of:
- "aggregate": Uses a single aggregated infection trajectory. Shape: (n_days,)
- "subpop": Uses subpopulation-level infection trajectories. Shape: (n_days, n_subpops), indexed via subpop_indices.
Returns:
| Type | Description |
|---|---|
| str | Either "aggregate" or "subpop". |
Notes
This is used by multi-signal models to route the correct infection output to each observation process.
Source code in pyrenew/observation/base.py
lookback_days
abstractmethod
lookback_days() -> int
Return the number of days this observation process needs to look back.
This determines the minimum n_initialization_points required by the latent process when this observation is included in a multi-signal model.
Returns:
| Type | Description |
|---|---|
| int | Number of days of infection history required. |
Notes
Delay/shedding PMFs are 0-indexed (effect can occur on day 0), so a
PMF of length L covers lags 0 to L-1, requiring L-1 initialization
points. Implementations should return len(pmf) - 1.
This is used by model builders to automatically compute
n_initialization_points as max(gen_int_length, max(all lookbacks)).
Source code in pyrenew/observation/base.py
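Here is a worked example of the rule max(gen_int_length, max(all lookbacks)). The lengths are hypothetical: a generation interval of length 7, a hospital-delay PMF of length 15, and a wastewater-shedding PMF of length 10. Each lookback is len(pmf) - 1, which gives 14 and 9, so the model needs 14 initialization points.

```python
def lookback_days(pmf_length):
    # A PMF of length L covers lags 0..L-1, so it needs L - 1 days of history.
    return pmf_length - 1


def n_initialization_points(gen_int_length, pmf_lengths):
    """Rule stated above: max(gen_int_length, max(all lookbacks))."""
    return max(gen_int_length, max(lookback_days(n) for n in pmf_lengths))


# Hypothetical lengths: generation interval 7, delay PMF 15, shedding PMF 10.
print(n_initialization_points(7, [15, 10]))  # 14
```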
sample
abstractmethod
Sample from the observation process.
Subclasses must implement this method to define the specific
observation model. Typically calls _predicted_obs first,
then applies the noise model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs | ArrayLike \| None | Observed data for conditioning, or None for prior predictive sampling. | None |
| **kwargs | | Subclass-specific parameters (e.g., infections from the infection process). | {} |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Observed or sampled values from the observation process. |
Source code in pyrenew/observation/base.py
validate
abstractmethod
validate() -> None
Validate observation process parameters.
Subclasses must implement this method to validate all parameters.
Typically this involves calling _validate_pmf() for the PMF
and adding any additional parameter-specific validation.
Raises:
| Type | Description |
|---|---|
| ValueError | If any parameters fail validation. |
Source code in pyrenew/observation/base.py
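The following plain-Python sketch shows the kind of PMF checks that validate() typically delegates to _validate_pmf(): the PMF must be non-empty, its entries must be non-negative, and they must sum to 1. The name validate_pmf_sketch and the tolerance atol=1e-6 are assumptions. pyrenew's helper may use different tolerances and error messages.

```python
import math


def validate_pmf_sketch(pmf, atol=1e-6):
    """Hypothetical helper: raise ValueError unless pmf is non-empty,
    non-negative, and sums to 1 within atol."""
    values = [float(p) for p in pmf]
    if not values:
        raise ValueError("PMF must be non-empty")
    if any(math.isnan(p) or p < 0 for p in values):
        raise ValueError("PMF entries must be non-negative")
    total = math.fsum(values)
    if abs(total - 1.0) > atol:
        raise ValueError(f"PMF must sum to 1 (got {total:.6g})")


validate_pmf_sketch([0.2, 0.5, 0.3])  # passes silently
```

The [1.2, -0.2] case in the checks below is the reason the non-negativity test is separate: that PMF sums to exactly 1 but still fails.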
validate_data
abstractmethod
Validate observation data before running inference.
Each observation process validates its own data requirements.
Called by the model's validate_data() method with concrete
(non-traced) values before JAX tracing begins.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_total | int | Total number of time steps (n_init + n_days_post_init). | required |
| n_subpops | int | Number of subpopulations. | required |
| **obs_data | | Observation-specific data kwargs (same as passed to …) | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If any data fails validation. |
Source code in pyrenew/observation/base.py
CountNoise
Bases: ABC
Abstract base for count observation noise models.
Defines how discrete count observations are distributed around predicted values.
sample
abstractmethod
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
mask: ArrayLike | None = None,
) -> ArrayLike
Sample count observations given predicted counts.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted count values (non-negative). | required |
| obs | ArrayLike \| None | Observed counts for conditioning, or None for prior sampling. | None |
| mask | ArrayLike \| None | Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. If provided, observations where mask is False are excluded from the likelihood. | None |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Sampled or conditioned counts, same shape as predicted. |
Notes
Implementations use numpyro.handlers.mask rather than the
obs_mask parameter of numpyro.sample. This avoids creating
latent variables for masked entries, which would fail with NUTS
for discrete distributions.
Source code in pyrenew/observation/noise.py
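The following plain-Python sketch shows the arithmetic of a masked Poisson log-likelihood. Entries where the mask is False add no term to the sum. It mirrors what excluding observations from the likelihood means for CountNoise.sample, but it does not use numpyro.handlers.mask. The helper names poisson_logpmf and masked_poisson_loglik are hypothetical.

```python
import math


def poisson_logpmf(k, lam):
    """log P(K = k) for K ~ Poisson(lam), with k a non-negative integer."""
    return k * math.log(lam) - lam - math.lgamma(k + 1)


def masked_poisson_loglik(predicted, obs, mask=None):
    """Sum Poisson log-likelihood terms only where mask is True.

    mask=None includes every observation, matching the documented default.
    """
    if mask is None:
        mask = [True] * len(obs)
    return math.fsum(
        poisson_logpmf(k, lam)
        for lam, k, keep in zip(predicted, obs, mask)
        if keep
    )


predicted = [2.0, 3.0, 4.0]
obs = [7, 3, 5]
mask = [False, True, True]  # exclude the first observation
total = masked_poisson_loglik(predicted, obs, mask)
# Only the second and third observations contribute to total.
```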
validate
abstractmethod
validate() -> None
Validate noise model parameters.
Raises:
| Type | Description |
|---|---|
| ValueError | If parameters are invalid. |
Source code in pyrenew/observation/noise.py
Counts
Counts(
name: str,
ascertainment_rate_rv: RandomVariable,
delay_distribution_rv: RandomVariable,
noise: CountNoise,
)
Bases: _CountBase
Aggregated count observation.
Maps aggregate infections to counts through an ascertainment × delay convolution with a composable noise model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names (e.g., "hospital" produces sites "hospital_obs", "hospital_predicted"). | required |
| ascertainment_rate_rv | RandomVariable | Ascertainment rate in [0, 1] (e.g., IHR, IER). | required |
| delay_distribution_rv | RandomVariable | Delay distribution PMF (must sum to ~1.0). | required |
| noise | CountNoise | Noise model (PoissonNoise, NegativeBinomialNoise, etc.). | required |
Source code in pyrenew/observation/count_observations.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/count_observations.py
infection_resolution
infection_resolution() -> str
Return "aggregate" for aggregated observations.
Returns:
| Type | Description |
|---|---|
| str | The string "aggregate". |
Source code in pyrenew/observation/count_observations.py
sample
sample(
infections: ArrayLike, obs: ArrayLike | None = None
) -> ObservationSample
Sample aggregated counts.
Both infections and obs use a shared time axis [0, n_total) where n_total = n_init + n_days. NaN in obs marks unobserved timepoints (initialization period or missing data).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| infections | ArrayLike | Aggregate infections from the infection process. Shape: (n_total,) where n_total = n_init + n_days. | required |
| obs | ArrayLike \| None | Observed counts on shared time axis. Shape: (n_total,). Use NaN for initialization period and any missing observations. None for prior predictive sampling. | None |
Returns:
| Type | Description |
|---|---|
| ObservationSample | Named tuple with observed and predicted fields. |
Source code in pyrenew/observation/count_observations.py
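Counts.sample expects infections and obs to share one time axis of length n_total = n_init + n_days. NaN marks any day without an observation. The following plain-Python sketch builds such an obs array from post-initialization reports. The helper pad_obs_to_shared_axis and the example numbers are hypothetical.

```python
import math


def pad_obs_to_shared_axis(n_init, daily_counts):
    """Hypothetical helper: prepend n_init NaNs (initialization period) and
    turn None entries (missing reports) into NaN, giving n_init + n_days values."""
    padded = [math.nan] * n_init
    padded.extend(math.nan if c is None else float(c) for c in daily_counts)
    return padded


n_init = 3
reported = [12, None, 15, 9]  # 4 post-initialization days, one report missing
obs = pad_obs_to_shared_axis(n_init, reported)
n_total = n_init + len(reported)  # infections must have this length too
print(len(obs) == n_total)  # True
```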
validate_data
Validate aggregated count observation data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_total | int | Total number of time steps (n_init + n_days_post_init). | required |
| n_subpops | int | Number of subpopulations (unused for aggregate observations). | required |
| obs | ArrayLike \| None | Observed counts on shared time axis. Shape: (n_total,). | None |
| **kwargs | | Additional keyword arguments (ignored). | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If obs length doesn't match n_total. |
Source code in pyrenew/observation/count_observations.py
CountsBySubpop
CountsBySubpop(
name: str,
ascertainment_rate_rv: RandomVariable,
delay_distribution_rv: RandomVariable,
noise: CountNoise,
)
Bases: _CountBase
Subpopulation-level count observation.
Maps subpopulation-level infections to counts through an ascertainment × delay convolution with a composable noise model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names. | required |
| ascertainment_rate_rv | RandomVariable | Ascertainment rate in [0, 1]. | required |
| delay_distribution_rv | RandomVariable | Delay distribution PMF (must sum to ~1.0). | required |
| noise | CountNoise | Noise model (PoissonNoise, NegativeBinomialNoise, etc.). | required |
Source code in pyrenew/observation/count_observations.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/count_observations.py
infection_resolution
infection_resolution() -> str
Return "subpop" for subpopulation-level observations.
Returns:
| Type | Description |
|---|---|
| str | The string "subpop". |
Source code in pyrenew/observation/count_observations.py
sample
sample(
infections: ArrayLike,
times: ArrayLike,
subpop_indices: ArrayLike,
obs: ArrayLike | None = None,
) -> ObservationSample
Sample subpopulation-level counts.
Times are on the shared time axis [0, n_total) where n_total = n_init + n_days. This method performs direct indexing without any offset adjustment.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| infections | ArrayLike | Subpopulation-level infections from the infection process. Shape: (n_total, n_subpops) | required |
| times | ArrayLike | Day index for each observation on the shared time axis. Must be in range [0, n_total). Shape: (n_obs,) | required |
| subpop_indices | ArrayLike | Subpopulation index for each observation (0-indexed). Shape: (n_obs,) | required |
| obs | ArrayLike \| None | Observed counts (n_obs,), or None for prior sampling. | None |
Returns:
| Type | Description |
|---|---|
| ObservationSample | Named tuple with observed and predicted fields. |
Source code in pyrenew/observation/count_observations.py
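In this context, direct indexing means that the predicted value for observation i is read from row times[i] and column subpop_indices[i] of an (n_total, n_subpops) array. No offset is applied to times. The following plain-Python sketch uses hypothetical numbers. With a NumPy or JAX array, the same gather is written predicted[times, subpop_indices].

```python
def gather_predicted(predicted, times, subpop_indices):
    """Pick predicted[t][s] for each (t, s) pair; t is used as-is (no offset)."""
    if len(times) != len(subpop_indices):
        raise ValueError("times and subpop_indices must have the same length")
    return [predicted[t][s] for t, s in zip(times, subpop_indices)]


# Hypothetical predicted counts: n_total = 4 days (rows) x n_subpops = 2 (columns).
predicted = [
    [1.0, 10.0],
    [2.0, 20.0],
    [3.0, 30.0],
    [4.0, 40.0],
]
# Subpop 0 observed on day 3; subpop 1 observed on days 2 and 3.
print(gather_predicted(predicted, times=[3, 2, 3], subpop_indices=[0, 1, 1]))
# [4.0, 30.0, 40.0]
```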
validate_data
validate_data(
n_total: int,
n_subpops: int,
times: ArrayLike | None = None,
subpop_indices: ArrayLike | None = None,
obs: ArrayLike | None = None,
**kwargs,
) -> None
Validate subpopulation-level count observation data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_total | int | Total number of time steps (n_init + n_days_post_init). | required |
| n_subpops | int | Number of subpopulations. | required |
| times | ArrayLike \| None | Day index for each observation on the shared time axis. | None |
| subpop_indices | ArrayLike \| None | Subpopulation index for each observation (0-indexed). | None |
| obs | ArrayLike \| None | Observed counts (n_obs,). | None |
| **kwargs | | Additional keyword arguments (ignored). | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If times or subpop_indices are out of bounds, or if obs and times have mismatched lengths. |
Source code in pyrenew/observation/count_observations.py
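The failure conditions listed for CountsBySubpop.validate_data can be written as concrete checks on ordinary Python values, of the kind that run before JAX tracing begins. The function below is a hypothetical stand-in. pyrenew's actual check order and error messages may differ.

```python
def validate_subpop_counts_sketch(n_total, n_subpops, times, subpop_indices, obs=None):
    """Hypothetical checks: indices in bounds, obs aligned with times."""
    if any(t < 0 or t >= n_total for t in times):
        raise ValueError(f"times must lie in [0, {n_total})")
    if any(s < 0 or s >= n_subpops for s in subpop_indices):
        raise ValueError(f"subpop_indices must lie in [0, {n_subpops})")
    if obs is not None and len(obs) != len(times):
        raise ValueError("obs and times must have the same length")


# Valid: 10 time steps, 3 subpopulations, two observations.
validate_subpop_counts_sketch(10, 3, times=[0, 9], subpop_indices=[0, 2], obs=[1, 2])
```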
HierarchicalNormalNoise
HierarchicalNormalNoise(
sensor_mode_rv: RandomVariable, sensor_sd_rv: RandomVariable
)
Bases: MeasurementNoise
Normal noise with hierarchical sensor-level effects.
Observation model: obs ~ Normal(predicted + sensor_mode, sensor_sd)
where sensor_mode and sensor_sd are sampled per-sensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| sensor_mode_rv | RandomVariable | Prior for sensor-level modes. Must implement the group-level sampling interface; see Notes. | required |
| sensor_sd_rv | RandomVariable | Prior for sensor-level SDs (must be > 0). Must implement the group-level sampling interface; see Notes. | required |
Notes
Use VectorizedRV to wrap simple RVs that lack this interface.
Initialize hierarchical Normal noise.
Source code in pyrenew/observation/noise.py
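The observation model obs ~ Normal(predicted + sensor_mode, sensor_sd) can be simulated in plain Python. The simulation shows that each sensor draws one mode and one SD, and every observation from that sensor reuses them. The priors here are arbitrary assumptions standing in for sensor_mode_rv and sensor_sd_rv: a Normal(0, 0.5) mode and a Uniform(0.05, 0.2) SD. The function name is hypothetical.

```python
import random


def simulate_hierarchical_normal(predicted, sensor_indices, n_sensors, rng):
    """Draw one mode and one SD per sensor, then
    obs[i] ~ Normal(predicted[i] + mode[s_i], sd[s_i]) with s_i = sensor_indices[i].
    """
    # Hypothetical priors standing in for sensor_mode_rv and sensor_sd_rv.
    modes = [rng.gauss(0.0, 0.5) for _ in range(n_sensors)]
    sds = [rng.uniform(0.05, 0.2) for _ in range(n_sensors)]
    obs = [
        rng.gauss(mu + modes[s], sds[s])
        for mu, s in zip(predicted, sensor_indices)
    ]
    return obs, modes, sds


rng = random.Random(1)
obs, modes, sds = simulate_hierarchical_normal(
    predicted=[1.0, 1.2, 2.0], sensor_indices=[0, 0, 1], n_sensors=2, rng=rng
)
```

Observations 0 and 1 share sensor 0's offset, which is how the model separates persistent sensor bias from day-to-day noise.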
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/noise.py
sample
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
*,
sensor_indices: ArrayLike,
n_sensors: int,
) -> ArrayLike
Sample from Normal distribution with sensor-level hierarchical effects.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted measurement values. Shape: (n_obs,) | required |
| obs | ArrayLike \| None | Observed measurements for conditioning. Shape: (n_obs,) | None |
| sensor_indices | ArrayLike | Sensor index for each observation (0-indexed). Shape: (n_obs,) | required |
| n_sensors | int | Total number of sensors. | required |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Normally distributed measurements with hierarchical sensor effects. Shape: (n_obs,) |
Source code in pyrenew/observation/noise.py
validate
validate() -> None
Validate noise parameters.
Notes
Full validation requires n_groups, which is only available during sample().
Source code in pyrenew/observation/noise.py
MeasurementNoise
Bases: ABC
Abstract base for continuous measurement noise models.
Defines how continuous observations are distributed around predicted values.
sample
abstractmethod
Sample continuous observations given predicted values.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted measurement values. | required |
| obs | ArrayLike \| None | Observed measurements for conditioning, or None for prior sampling. | None |
| **kwargs | | Additional context (e.g., sensor indices). | {} |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Sampled or conditioned measurements, same shape as predicted. |
Source code in pyrenew/observation/noise.py
validate
abstractmethod
validate() -> None
Validate noise model parameters.
Raises:
| Type | Description |
|---|---|
| ValueError | If parameters are invalid. |
Source code in pyrenew/observation/noise.py
Measurements
Measurements(
name: str, temporal_pmf_rv: RandomVariable, noise: MeasurementNoise
)
Bases: BaseObservationProcess
Abstract base for continuous measurement observations.
Subclasses implement signal-specific transformations from infections to predicted measurement values, then add measurement noise.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names. | required |
| temporal_pmf_rv | RandomVariable | Temporal distribution PMF (e.g., shedding kinetics for wastewater). | required |
| noise | MeasurementNoise | Noise model for continuous measurements (e.g., HierarchicalNormalNoise). | required |
Notes
Subclasses must implement _predicted_obs() according to their
specific signal processing (e.g., wastewater shedding kinetics,
dilution factors, etc.).
See Also
pyrenew.observation.noise.HierarchicalNormalNoise : Suitable noise model for sensor-level measurements
pyrenew.observation.base.BaseObservationProcess : Parent class with common observation utilities
Initialize measurement observation base.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Unique name for this observation process. Used to prefix all numpyro sample and deterministic site names. | required |
| temporal_pmf_rv | RandomVariable | Temporal distribution PMF (e.g., shedding kinetics). | required |
| noise | MeasurementNoise | Noise model (e.g., HierarchicalNormalNoise with sensor effects). | required |
Source code in pyrenew/observation/measurements.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/measurements.py
infection_resolution
infection_resolution() -> str
Return "subpop" for measurement observations.
Measurement observations require subpopulation-level infections because each measurement corresponds to a specific catchment area.
Returns:
| Type | Description |
|---|---|
| str | The string "subpop". |
Source code in pyrenew/observation/measurements.py
lookback_days
lookback_days() -> int
Return required lookback days for this observation.
Temporal PMFs are 0-indexed (effect can occur on day 0), so a PMF of length L covers lags 0 to L-1, requiring L-1 initialization points.
Returns:
| Type | Description |
|---|---|
| int | Length of temporal PMF minus 1. |
Source code in pyrenew/observation/measurements.py
sample
sample(
infections: ArrayLike,
times: ArrayLike,
subpop_indices: ArrayLike,
sensor_indices: ArrayLike,
n_sensors: int,
obs: ArrayLike | None = None,
) -> ObservationSample
Sample measurements from observed sensors.
Times are on the shared time axis [0, n_total) where n_total = n_init + n_days. This method performs direct indexing without any offset adjustment.
Transforms infections to predicted values via signal-specific processing
(_predicted_obs), then applies noise model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| infections | ArrayLike | Infections from the infection process. Shape: (n_total, n_subpops) | required |
| times | ArrayLike | Day index for each observation on the shared time axis. Must be in range [0, n_total). Shape: (n_obs,) | required |
| subpop_indices | ArrayLike | Subpopulation index for each observation (0-indexed). Shape: (n_obs,) | required |
| sensor_indices | ArrayLike | Sensor index for each observation (0-indexed). Shape: (n_obs,) | required |
| n_sensors | int | Total number of measurement sensors. | required |
| obs | ArrayLike \| None | Observed measurements (n_obs,), or None for prior sampling. | None |
Returns:
| Type | Description |
|---|---|
| ObservationSample | Named tuple with observed and predicted fields. |
Source code in pyrenew/observation/measurements.py
validate_data
validate_data(
n_total: int,
n_subpops: int,
times: ArrayLike | None = None,
subpop_indices: ArrayLike | None = None,
sensor_indices: ArrayLike | None = None,
n_sensors: int | None = None,
obs: ArrayLike | None = None,
**kwargs,
) -> None
Validate measurement observation data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_total | int | Total number of time steps (n_init + n_days_post_init). | required |
| n_subpops | int | Number of subpopulations. | required |
| times | ArrayLike \| None | Day index for each observation on the shared time axis. | None |
| subpop_indices | ArrayLike \| None | Subpopulation index for each observation (0-indexed). | None |
| sensor_indices | ArrayLike \| None | Sensor index for each observation (0-indexed). | None |
| n_sensors | int \| None | Total number of measurement sensors. | None |
| obs | ArrayLike \| None | Observed measurements (n_obs,). | None |
| **kwargs | | Additional keyword arguments (ignored). | {} |
Raises:
| Type | Description |
|---|---|
| ValueError | If times, subpop_indices, or sensor_indices are out of bounds, or if obs and times have mismatched lengths. |
Source code in pyrenew/observation/measurements.py
NegativeBinomialNoise
NegativeBinomialNoise(concentration_rv: RandomVariable)
Bases: CountNoise
Negative Binomial noise for overdispersed counts (variance > mean).
Uses NB2 parameterization. Higher concentration reduces overdispersion.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| concentration_rv | RandomVariable | Concentration parameter (must be > 0). Higher values reduce overdispersion. | required |
Notes
The NB2 parameterization has variance = mean + mean^2 / concentration. As concentration -> infinity, this approaches Poisson.
Initialize Negative Binomial noise.
Source code in pyrenew/observation/noise.py
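The NB2 relation variance = mean + mean^2 / concentration can be checked numerically from the gamma-Poisson form of the probability mass function. This plain-Python sketch sums the PMF over a long, truncated support. The cutoff max_y=2000 and the example parameters are arbitrary choices. NumPyro provides the same parameterization as numpyro.distributions.NegativeBinomial2(mean, concentration). This page does not say whether NegativeBinomialNoise uses that class.

```python
import math


def nb2_logpmf(y, mean, concentration):
    """log P(Y = y) under NB2 with the given mean and concentration k:
    Gamma(y + k) / (Gamma(k) y!) * (k / (k + mean))**k * (mean / (k + mean))**y."""
    k = concentration
    return (
        math.lgamma(y + k)
        - math.lgamma(k)
        - math.lgamma(y + 1)
        + k * math.log(k / (k + mean))
        + y * math.log(mean / (k + mean))
    )


def nb2_moments(mean, concentration, max_y=2000):
    """Mean and variance obtained by summing the PMF over 0..max_y."""
    probs = [math.exp(nb2_logpmf(y, mean, concentration)) for y in range(max_y + 1)]
    m = math.fsum(y * p for y, p in enumerate(probs))
    v = math.fsum((y - m) ** 2 * p for y, p in enumerate(probs))
    return m, v


for k in (1.0, 10.0, 1000.0):
    m, v = nb2_moments(mean=20.0, concentration=k)
    # Theory: variance = 20 + 400 / k, i.e. 420, 60 and 20.4 (Poisson would be 20).
    print(k, round(m, 3), round(v, 3))
```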
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/noise.py
sample
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
mask: ArrayLike | None = None,
) -> ArrayLike
Sample from Negative Binomial distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted count values. | required |
| obs | ArrayLike \| None | Observed counts for conditioning. | None |
| mask | ArrayLike \| None | Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. | None |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Negative Binomial-distributed counts. |
Source code in pyrenew/observation/noise.py
validate
validate() -> None
Validate concentration is positive.
Raises:
| Type | Description |
|---|---|
| ValueError | If concentration <= 0. |
Source code in pyrenew/observation/noise.py
NegativeBinomialObservation
NegativeBinomialObservation(
name: str, concentration_rv: RandomVariable, eps: float = 1e-10
)
Bases: RandomVariable
Negative Binomial observation.
Default constructor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Name for the numpyro variable. | required |
| concentration_rv | RandomVariable | Random variable from which to sample the positive concentration parameter of the negative binomial. This parameter is sometimes called k, phi, or the "dispersion" or "overdispersion" parameter, despite the fact that larger values imply that the distribution becomes more Poissonian, while smaller ones imply a greater degree of dispersion. | required |
| eps | float | Small value to add to the predicted mean to prevent numerical instability. Defaults to 1e-10. | 1e-10 |
Returns:
| Type | Description |
|---|---|
| None | |
Source code in pyrenew/observation/negativebinomial.py
sample
Sample from the negative binomial distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| mu | ArrayLike | Mean parameter of the negative binomial distribution. | required |
| obs | ArrayLike \| None | Observed data, by default None. | None |
| **kwargs | | Additional keyword arguments passed through to internal sample calls, should there be any. | {} |
Returns:
| Type | Description |
|---|---|
| ArrayLike | |
Source code in pyrenew/observation/negativebinomial.py
validate
staticmethod
validate(concentration_rv: RandomVariable) -> None
Check that the concentration_rv is actually a RandomVariable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| concentration_rv | RandomVariable | RandomVariable from which to sample the positive concentration parameter of the negative binomial. | required |
Returns:
| Type | Description |
|---|---|
| None | |
Source code in pyrenew/observation/negativebinomial.py
ObservationSample
Bases: NamedTuple
Return type for observation process sample() methods.
Attributes:
| Name | Type | Description |
|---|---|---|
| observed | ArrayLike | Sampled or conditioned observations. Shape depends on the observation process and indexing. |
| predicted | ArrayLike | Predicted values before noise is applied. Useful for diagnostics and posterior predictive checks. |
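ObservationSample is a NamedTuple, so callers can unpack it by position or read its fields by name. The stand-in below is defined locally and is not imported from pyrenew. It mirrors only the two documented attributes, and its values are made up.

```python
from typing import Any, NamedTuple


class ObservationSample(NamedTuple):
    """Local stand-in with the two documented fields."""

    observed: Any
    predicted: Any


result = ObservationSample(observed=[3, 5, 4], predicted=[3.2, 4.8, 4.1])
observed, predicted = result  # positional unpacking
# Field access by name, e.g. for a quick posterior-predictive residual check.
residuals = [o - p for o, p in zip(result.observed, result.predicted)]
```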
PoissonNoise
PoissonNoise()
Bases: CountNoise
Poisson noise for equidispersed counts (variance = mean).
Initialize Poisson noise (no parameters).
Source code in pyrenew/observation/noise.py
__repr__
__repr__() -> str
Return string representation.
Source code in pyrenew/observation/noise.py
sample
sample(
name: str,
predicted: ArrayLike,
obs: ArrayLike | None = None,
mask: ArrayLike | None = None,
) -> ArrayLike
Sample from Poisson distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Numpyro sample site name. | required |
| predicted | ArrayLike | Predicted count values. | required |
| obs | ArrayLike \| None | Observed counts for conditioning. | None |
| mask | ArrayLike \| None | Boolean mask indicating which observations to include in the likelihood. If None, all observations are included. | None |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Poisson-distributed counts. |
Source code in pyrenew/observation/noise.py
validate
validate() -> None
Validate Poisson noise (always valid).
Source code in pyrenew/observation/noise.py
VectorizedRV
VectorizedRV(name: str, rv: RandomVariable)
Bases: RandomVariable
Wrapper that adds n_groups support to simple RandomVariables.
Uses numpyro.plate to vectorize sampling, enabling simple RVs to work with noise models expecting the group-level interface.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | A name for this random variable. The numpyro plate is named … | required |
| rv | RandomVariable | The underlying RandomVariable to wrap. | required |
Initialize VectorizedRV wrapper.
Source code in pyrenew/observation/noise.py
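VectorizedRV turns a scalar prior into one that returns n_groups independent draws, which the hierarchical noise model needs for its per-sensor parameters. According to the description above, pyrenew does this with numpyro.plate. The plain-Python sketch below shows only the shape contract, using a loop. ToyScalarRV and ToyVectorizedRV are hypothetical names.

```python
import random


class ToyScalarRV:
    """Hypothetical scalar prior: one Normal(loc, scale) draw per call."""

    def __init__(self, loc, scale, rng):
        self.loc, self.scale, self.rng = loc, scale, rng

    def sample(self, **kwargs):
        return self.rng.gauss(self.loc, self.scale)


class ToyVectorizedRV:
    """Mimics the n_groups interface: returns n_groups i.i.d. draws as a list."""

    def __init__(self, name, rv):
        self.name = name
        self.rv = rv

    def sample(self, n_groups, **kwargs):
        # pyrenew's VectorizedRV uses numpyro.plate instead of this loop, so a
        # single sample site is vectorized across the group dimension.
        return [self.rv.sample(**kwargs) for _ in range(n_groups)]


rng = random.Random(0)
sensor_modes = ToyVectorizedRV("sensor_mode", ToyScalarRV(0.0, 0.5, rng))
modes = sensor_modes.sample(n_groups=3)  # one value per sensor
```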
sample
sample(n_groups: int, **kwargs)
Sample n_groups values using numpyro.plate.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| n_groups | int | Number of group-level values to sample. | required |
Returns:
| Type | Description |
|---|---|
| ArrayLike | Array of shape (n_groups,). |
Source code in pyrenew/observation/noise.py
validate
validate()
Validate the underlying RV.
Source code in pyrenew/observation/noise.py